integrate Flash Attention 3. +9% tok_per_sec for d12 with ctx even as low as 2048 out of the box nice. also, ready to tune windows huge

This commit is contained in:
Andrej Karpathy
2026-01-11 20:33:19 +00:00
parent 201d705957
commit 2ff7d51252
6 changed files with 177 additions and 143 deletions
+57 -47
View File
@@ -39,13 +39,9 @@ class MockModel:
def forward(self, ids, kv_cache=None):
"""Return uniform logits so sampling is spread across vocab."""
B, T = ids.shape
# Simulate what a real transformer does: insert k,v into the cache for each layer
# With FA3, flash_attn_with_kvcache updates cache in-place and we advance position
if kv_cache is not None:
head_dim = self.config.n_embd // self.config.n_head
for layer_idx in range(self.config.n_layer):
k = torch.zeros(B, self.config.n_kv_head, T, head_dim)
v = torch.zeros(B, self.config.n_kv_head, T, head_dim)
kv_cache.insert_kv(layer_idx, k, v)
kv_cache.advance(T)
# Uniform logits -> equal probability for all tokens
logits = torch.zeros(B, T, self.vocab_size)
return logits
@@ -85,16 +81,11 @@ class ByteTokenizer:
byte_tokens = [t for t in tokens if t < 256]
return bytes(byte_tokens).decode("utf-8", errors="replace")
def test_kv_cache_resize():
"""
The KV cache was not resized correctly, more information here:
https://github.com/karpathy/nanochat/pull/186
This test reproduces the issue and will be merged alongside the fix.
"""
def test_kv_cache_basic():
"""Test basic KVCache functionality for FA3."""
batch_size = 2
num_heads = 3
seq_len = 4
seq_len = 64
head_dim = 5
num_layers = 6
@@ -103,45 +94,64 @@ def test_kv_cache_resize():
num_heads=num_heads,
seq_len=seq_len,
head_dim=head_dim,
num_layers=num_layers
num_layers=num_layers,
device="cpu",
)
# Insert a single token with a distinct fill value to all layers
def insert_token(token_idx):
for layer_idx in range(num_layers):
k = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx), dtype=torch.float32)
v = torch.full((batch_size, num_heads, 1, head_dim), fill_value=float(token_idx * 100), dtype=torch.float32)
kv_cache.insert_kv(layer_idx, k, v)
# Check initial state
assert kv_cache.get_pos() == 0
assert kv_cache.k_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
assert kv_cache.v_cache.shape == (num_layers, batch_size, seq_len, num_heads, head_dim)
# Insert 4 tokens (fills the initial seq_len=4)
for i in range(4):
insert_token(i)
# Test advance
kv_cache.advance(10)
assert kv_cache.get_pos() == 10
# Record the original state of the cache
original_cache = kv_cache.kv_cache.clone()
original_seq_len = original_cache.shape[4]
kv_cache.advance(5)
assert kv_cache.get_pos() == 15
# Insert the 5th token, which will trigger a resize
insert_token(4)
# Verify that the cache actually resized
new_seq_len = kv_cache.kv_cache.shape[4]
assert new_seq_len > original_seq_len, f"Cache did not resize: original seq_len={original_seq_len}, new seq_len={new_seq_len}"
# Test reset
kv_cache.reset()
assert kv_cache.get_pos() == 0
# Verify that the original 4 tokens are still intact after resize
for layer_idx in range(num_layers):
for token_idx in range(4):
# Check that resized cache matches expected values
expected_k = float(token_idx)
expected_v = float(token_idx * 100)
actual_k = kv_cache.kv_cache[layer_idx, 0, :, :, token_idx, :]
actual_v = kv_cache.kv_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == expected_k).all(), f"Layer {layer_idx}, token {token_idx}: key corrupted, expected {expected_k}"
assert (actual_v == expected_v).all(), f"Layer {layer_idx}, token {token_idx}: value corrupted, expected {expected_v}"
# And that the original cache matches resized cache
original_k = original_cache[layer_idx, 0, :, :, token_idx, :]
original_v = original_cache[layer_idx, 1, :, :, token_idx, :]
assert (actual_k == original_k).all(), f"Layer {layer_idx}, token {token_idx}: key doesn't match original"
assert (actual_v == original_v).all(), f"Layer {layer_idx}, token {token_idx}: value doesn't match original"
# Test get_layer_cache returns correct views
k_layer0, v_layer0 = kv_cache.get_layer_cache(0)
assert k_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
assert v_layer0.shape == (batch_size, seq_len, num_heads, head_dim)
def test_kv_cache_prefill():
"""Test KVCache.prefill() copies data correctly."""
batch_size = 1
num_heads = 4
head_dim = 8
num_layers = 2
# Create source cache and advance it
src_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=32,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Write some data to source cache
src_cache.k_cache[0, 0, :16, :, :] = 1.0
src_cache.v_cache[0, 0, :16, :, :] = 2.0
src_cache.advance(16)
# Create destination cache with larger seq_len
dst_cache = KVCache(
batch_size=batch_size, num_heads=num_heads, seq_len=64,
head_dim=head_dim, num_layers=num_layers, device="cpu",
)
# Prefill
dst_cache.prefill(src_cache)
# Check position was copied
assert dst_cache.get_pos() == 16
# Check data was copied
assert (dst_cache.k_cache[0, 0, :16, :, :] == 1.0).all()
assert (dst_cache.v_cache[0, 0, :16, :, :] == 2.0).all()
def test_multi_sample_first_token_diversity():