implement flash attention 3 fallback to pytorch sdpa by touching as few lines of code as possible in main files and keeping all implementation to a single file. add tests. add helpful warning messages for the user.

This commit is contained in:
Andrej Karpathy
2026-01-16 17:37:51 +00:00
parent 50413d2d67
commit 8203efa919
3 changed files with 354 additions and 9 deletions
+3 -9
View File
@@ -23,13 +23,8 @@ from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
from kernels import get_kernel
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
@@ -87,8 +82,7 @@ class CausalSelfAttention(nn.Module):
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Attention with Flash Attention 3
# FA3 handles GQA automatically when n_kv_heads < n_heads
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window