remove unnecessary check

This commit is contained in:
Eric Silberstein
2025-11-19 16:31:41 -05:00
parent 4a87a0d19f
commit 5c93a56be5
-1
View File
@@ -98,7 +98,6 @@ class CausalSelfAttention(nn.Module):
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))