integrate Flash Attention 3. +9% tok_per_sec for d12 with ctx even as low as 2048 out of the box nice. also, ready to tune windows huge
This commit is contained in:
+29
-29
@@ -9,9 +9,9 @@ Notable features:
|
||||
- no learnable params in rmsnorm
|
||||
- no bias in linear layers
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -23,6 +23,14 @@ from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
from nanochat.adamw import DistAdamW
|
||||
|
||||
# Load Flash Attention 3 from HuggingFace Hub (and silence the progress bar)
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
|
||||
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
|
||||
from kernels import get_kernel
|
||||
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
sequence_len: int = 1024
|
||||
@@ -65,44 +73,36 @@ class CausalSelfAttention(nn.Module):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
# Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
|
||||
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
|
||||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
# Attention with Flash Attention 3
|
||||
# FA3 handles GQA automatically when n_kv_heads < n_heads
|
||||
if kv_cache is None:
|
||||
# Training: simple causal attention
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
|
||||
prefix_len = Tk - Tq
|
||||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
)
|
||||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
|
||||
Reference in New Issue
Block a user