delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+35
View File
@@ -4,6 +4,41 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler
Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler.
### Motivation
autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.
### What changed
**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`):
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16.
**Autocast removal** (11 files):
- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`.
**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`):
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()`
- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX``scaler.step(optimizer)``scaler.update()`
- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels).
**FP8 fix** (`nanochat/fp8.py`, `base_train.py`):
- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast).
- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval.
**Flash Attention** (`flash_attention.py`):
- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (module-level constant, resolved once at import) returns False, falling back to SDPA.
---
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B
Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction.