integrate Flash Attention 3. +9% tok_per_sec for d12 with ctx even as low as 2048 out of the box nice. also, ready to tune windows huge

This commit is contained in:
Andrej Karpathy
2026-01-11 20:33:19 +00:00
parent 201d705957
commit 2ff7d51252
6 changed files with 177 additions and 143 deletions
+40 -67
View File
@@ -82,83 +82,54 @@ def use_calculator(expr):
# -----------------------------------------------------------------------------
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
Key differences from FA2-style cache:
- Tensors are (B, T, H, D) not (B, H, T, D)
- FA3 updates the cache in-place during flash_attn_with_kvcache
- Position tracked per batch element via cache_seqlens tensor
"""
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
self.pos = 0 # current position in time in the cache
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype=torch.bfloat16):
self.batch_size = batch_size
self.max_seq_len = seq_len
self.n_layers = num_layers
self.n_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors: (n_layers, B, T, H, D)
self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
# Current sequence length per batch element (FA3 needs int32)
self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
def reset(self):
self.pos = 0
"""Reset cache to empty state."""
self.cache_seqlens.zero_()
def get_pos(self):
return self.pos
"""Get current position (assumes all batch elements at same position)."""
return self.cache_seqlens[0].item()
def get_layer_cache(self, layer_idx):
"""Return (k_cache, v_cache) views for a specific layer."""
return self.k_cache[layer_idx], self.v_cache[layer_idx]
def advance(self, num_tokens):
"""Advance the cache position by num_tokens."""
self.cache_seqlens += num_tokens
def prefill(self, other):
"""
Prefill given another KV cache. Optionally expand along batch dim.
This is used when we do batch 1 prefill and then want to generate
multiple samples in parallel from there.
Copy cached KV from another cache into this one.
Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
"""
# 1) validate the shapes
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
# Extract dimensions explicitly
self_layers, self_kv, self_batch, self_heads, self_seq, self_head_dim = self.kv_shape
other_layers, other_kv, other_batch, other_heads, other_seq, other_head_dim = other.kv_shape
# Validate dimensions
assert self_layers == other_layers, f"Layer count mismatch: {self_layers} != {other_layers}"
assert self_kv == other_kv, f"K/V dimension mismatch: {self_kv} != {other_kv}"
assert self_heads == other_heads, f"Head count mismatch: {self_heads} != {other_heads}"
assert self_head_dim == other_head_dim, f"Head dim mismatch: {self_head_dim} != {other_head_dim}"
# Batch size can be expanded (other can be 1, self can be larger)
assert self_batch == other_batch or other_batch == 1, f"Batch size mismatch: {self_batch} vs {other_batch} (other must be 1 or equal)"
# Sequence length: self must be longer than other
assert self_seq >= other_seq, f"Sequence length mismatch: {self_seq} < {other_seq}"
# 2) initialize the cache
dtype, device = other.kv_cache.dtype, other.kv_cache.device
self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
# 3) copy the data over
self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
# 4) update the pos
self.pos = other.pos
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
# Dynamically grow the cache if needed
if t1 > self.kv_cache.size(4):
t_needed = t1 + 1024 # as much as we need plus buffer of 1024
t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
additional_shape = list(self.kv_cache.shape)
additional_shape[4] = t_needed - self.kv_cache.size(4)
additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4).contiguous()
self.kv_shape = self.kv_cache.shape
# Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1, :] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1, :] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1, :]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1, :]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
assert self.max_seq_len >= other.max_seq_len
other_pos = other.get_pos()
self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
self.cache_seqlens.fill_(other_pos)
# -----------------------------------------------------------------------------
@torch.inference_mode()
@@ -219,6 +190,7 @@ class Engine:
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
device=device,
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
@@ -230,6 +202,7 @@ class Engine:
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
device=device,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
+29 -29
View File
@@ -9,9 +9,9 @@ Notable features:
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
"""
import math
from functools import partial
from dataclasses import dataclass
@@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
@dataclass
class GPTConfig:
sequence_len: int = 1024
@@ -65,44 +73,36 @@ class CausalSelfAttention(nn.Module):
B, T, C = x.size()
# Project the input to get queries, keys, and values
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
if kv_cache is None:
# Training: simple causal attention
y = flash_attn.flash_attn_func(q, k, v, causal=True)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y