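Remove unnecessary `prefix_len > 0` check: when `prefix_len` is zero, `attn_mask[:, :prefix_len] = True` assigns to an empty slice and is a no-op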
This commit is contained in:
@@ -98,7 +98,6 @@ class CausalSelfAttention(nn.Module):
|
|||||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||||
prefix_len = Tk - Tq
|
prefix_len = Tk - Tq
|
||||||
if prefix_len > 0: # can't be negative but could be zero
|
|
||||||
attn_mask[:, :prefix_len] = True
|
attn_mask[:, :prefix_len] = True
|
||||||
# Then, causal attention within this chunk
|
# Then, causal attention within this chunk
|
||||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||||
|
|||||||
Reference in New Issue
Block a user