Fix SDPA KV-cache decode to respect sliding window (#456)
SDPA fallback now respects sliding window during single-token KV-cache decode by slicing K/V to the last (window + 1) tokens. Also simplifies the mask building for chunk inference to properly apply sliding window in that path as well. Fixes #452 Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com> Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -178,6 +178,39 @@ class TestFA3VsSDPA:
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token")
|
||||
print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_kvcache_single_token_sliding_window(self):
|
||||
"""Test single token decode with sliding window smaller than cache size.
|
||||
|
||||
This catches the bug where SDPA ignores window_size during Tq=1 decode.
|
||||
When window < Tk, FA3 only attends to the last (window+1) tokens,
|
||||
but SDPA was attending to all cached tokens.
|
||||
"""
|
||||
B, T_max, H, D = 2, 64, 4, 32
|
||||
T_prefill = 32 # Enough tokens to exceed window
|
||||
window = 8 # Window SMALLER than cache size
|
||||
|
||||
k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_cache[:, :T_prefill, :, :] = k_init
|
||||
v_cache[:, :T_prefill, :, :] = v_init
|
||||
cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE)
|
||||
return flash_attn.flash_attn_with_kvcache(
|
||||
q_single, k_cache, v_cache, k=k_single, v=v_single,
|
||||
cache_seqlens=cache_seqlens,
|
||||
causal=True, window_size=(window, 0) # window=8 < Tk=33
|
||||
)
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window")
|
||||
print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
||||
|
||||
def test_backward_gradients_match(self):
|
||||
"""Verify gradients are similar between FA3 and SDPA."""
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
|
||||
Reference in New Issue
Block a user