Fix SDPA KV-cache decode to respect sliding window (#456)
SDPA fallback now respects sliding window during single-token KV-cache decode by slicing K/V to the last (window + 1) tokens. Also simplifies the mask building for chunk inference to properly apply sliding window in that path as well. Fixes #452 Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com> Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
+14
-15
@@ -71,27 +71,26 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||||||
|
|
||||||
# Single token generation
|
# Single token generation
|
||||||
if Tq == 1:
|
if Tq == 1:
|
||||||
|
if window >= 0 and window < Tk:
|
||||||
|
# window is "left" tokens we need to include (window + 1) keys total
|
||||||
|
start = max(0, Tk - (window + 1))
|
||||||
|
k = k[:, :, start:, :]
|
||||||
|
v = v[:, :, start:, :]
|
||||||
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||||
|
|
||||||
# Need explicit mask
|
# Need explicit mask for sliding window/chunk inference
|
||||||
device = q.device
|
device = q.device
|
||||||
if Tq == Tk:
|
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
||||||
# Causal + sliding window
|
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
||||||
mask = torch.tril(torch.ones(Tq, Tk, device=device, dtype=torch.bool))
|
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
||||||
if window > 0 and window < Tq:
|
mask = col_idx <= row_idx
|
||||||
row_idx = torch.arange(Tq, device=device).unsqueeze(1)
|
|
||||||
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
# sliding window (left)
|
||||||
mask = mask & ((row_idx - col_idx) <= window)
|
if window >= 0 and window < Tk:
|
||||||
else:
|
mask = mask & ((row_idx - col_idx) <= window)
|
||||||
# Chunk inference: attend to prefix + causal within chunk
|
|
||||||
prefix_len = Tk - Tq
|
|
||||||
mask = torch.zeros(Tq, Tk, device=device, dtype=torch.bool)
|
|
||||||
mask[:, :prefix_len] = True
|
|
||||||
mask[:, prefix_len:] = torch.tril(torch.ones(Tq, Tq, device=device, dtype=torch.bool))
|
|
||||||
|
|
||||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Public API: Same interface as FA3
|
# Public API: Same interface as FA3
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -178,6 +178,39 @@ class TestFA3VsSDPA:
|
|||||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
||||||
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||||
|
|
||||||
|
def test_kvcache_single_token_sliding_window(self):
|
||||||
|
"""Test single token decode with sliding window smaller than cache size.
|
||||||
|
|
||||||
|
This catches the bug where SDPA ignores window_size during Tq=1 decode.
|
||||||
|
When window < Tk, FA3 only attends to the last (window+1) tokens,
|
||||||
|
but SDPA was attending to all cached tokens.
|
||||||
|
"""
|
||||||
|
B, T_max, H, D = 2, 64, 4, 32
|
||||||
|
T_prefill = 32 # Enough tokens to exceed window
|
||||||
|
window = 8 # Window SMALLER than cache size
|
||||||
|
|
||||||
|
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||||
|
k_cache[:, :T_prefill, :, :] = k_init
|
||||||
|
v_cache[:, :T_prefill, :, :] = v_init
|
||||||
|
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
|
||||||
|
return flash_attn.flash_attn_with_kvcache(
|
||||||
|
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
causal=True, window_size=(window, 0) # window=8 < Tk=33
|
||||||
|
)
|
||||||
|
|
||||||
|
y_fa3, y_sdpa = run_both_impls(run)
|
||||||
|
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window")
|
||||||
|
print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||||
|
|
||||||
def test_backward_gradients_match(self):
|
def test_backward_gradients_match(self):
|
||||||
"""Verify gradients are similar between FA3 and SDPA."""
|
"""Verify gradients are similar between FA3 and SDPA."""
|
||||||
B, T, H, D = 2, 32, 4, 16
|
B, T, H, D = 2, 32, 4, 16
|
||||||
|
|||||||
Reference in New Issue
Block a user