remove unnecessary check
This commit is contained in:
+1
-2
@@ -98,8 +98,7 @@ class CausalSelfAttention(nn.Module):
|
|||||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||||
prefix_len = Tk - Tq
|
prefix_len = Tk - Tq
|
||||||
if prefix_len > 0: # can't be negative but could be zero
|
attn_mask[:, :prefix_len] = True
|
||||||
attn_mask[:, :prefix_len] = True
|
|
||||||
# Then, causal attention within this chunk
|
# Then, causal attention within this chunk
|
||||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||||
|
|||||||
Reference in New Issue
Block a user