Fix SDPA KV-cache decode to respect sliding window (#456)

SDPA fallback now respects sliding window during single-token KV-cache
decode by slicing K/V to the last (window + 1) tokens.

Also simplifies the mask building for chunk inference to properly apply
sliding window in that path as well.

Fixes #452

Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Andrej Karpathy
2026-01-30 17:32:12 +00:00
parent ace6740bdd
commit 3ba42e8135
2 changed files with 47 additions and 15 deletions
+14 -15
View File
@@ -71,27 +71,26 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:
# window is "left" tokens we need to include (window + 1) keys total
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask
# Need explicit mask for sliding window/chunk inference
device = q.device
if Tq == Tk:
# Causal + sliding window
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
if window > 0 and window < Tq:
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = mask & ((row_idx - col_idx) <= window)
else:
# Chunk inference: attend to prefix + causal within chunk
prefix_len = Tk - Tq
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
mask[:, :prefix_len] = True
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================