delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+27 -18
View File
@@ -19,7 +19,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
@@ -40,8 +40,14 @@ class GPTConfig:
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
class Linear(nn.Linear):
"""nn.Linear that casts weights to match input dtype in forward.
Replaces autocast: master weights stay fp32 for optimizer precision,
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
def forward(self, x):
return F.linear(x, self.weight.to(dtype=x.dtype))
def has_ve(layer_idx, n_layer):
@@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module):
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
@@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module):
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
@@ -164,7 +170,7 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
@@ -234,11 +240,13 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
# because GradScaler cannot unscale fp16 gradients.
if COMPUTE_DTYPE != torch.float16:
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
ve.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
@@ -253,7 +261,7 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@@ -391,18 +399,19 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer
x = self.transformer.wte(idx) # embed current token
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
x = norm(x)
x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x)