remove unnecessary check to make the logic in CausalSelfAttention.forward() clearer

This commit is contained in:
Andrej
2025-12-08 18:30:37 -08:00
committed by GitHub
-1
View File
@@ -98,7 +98,6 @@ class CausalSelfAttention(nn.Module):
# First, each query attends to all the cached keys/values (i.e. full prefix) # First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk # Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)) attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))