add feb2 new leaderboard record from upgrading to fp8 training, +4.3% speedup to time to GPT-2

Andrej Karpathy
2026-02-03 20:54:30 +00:00
parent 6079f78fc3
commit a67eba35dc
2 changed files with 75 additions and 29 deletions
@@ -4,6 +4,74 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-02-02: FP8 Training with torchao
Integrated FP8 training using `torchao.float8` to accelerate Linear layer matmuls on H100 GPUs.
### Background
FP8 (8-bit floating point) uses H100's FP8 tensor cores for ~2x theoretical matmul throughput over BF16. The tradeoff is quantization overhead: computing scales and casting tensors to/from FP8. Still, torchtitan (Meta's distributed training framework), for example, reports 25-28% speedups with FP8 in some of their experiments.
**Previous attempt (Jan 2026):** FP8 on just `lm_head`, following modded-nanogpt with custom ops → 1% speedup, +2GB memory. Failed due to fragile torch.compile interaction. That experiment was also run at ~d12 scale, not at the ~d24 scale that reaches GPT-2 capability.
**This attempt:** Use torchao's `convert_to_float8_training()` on ALL Linear layers and increase model size to d24. The core snippet is:
```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=config)
```
But in practice it's more involved (see base_train.py).
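One part of what makes it more involved is deciding which Linear layers to convert. The following is a minimal sketch, not necessarily what base_train.py does: `fp8_module_filter` and `convert_model_to_fp8` are hypothetical helper names, and the divisibility-by-16 rule is taken from the Key Learnings in this entry. The filter is handed to torchao through the `module_filter_fn` argument of `convert_to_float8_training`.

```python
def fp8_module_filter(mod, fqn: str) -> bool:
    """Return True if this layer should be converted to FP8 (hypothetical helper).

    `mod` is expected to look like torch.nn.Linear (in_features / out_features);
    `fqn` is the module's fully qualified name, unused in this sketch.
    """
    in_features = getattr(mod, "in_features", None)
    out_features = getattr(mod, "out_features", None)
    if in_features is None or out_features is None:
        return False  # not a Linear-like module, leave it alone
    # FP8 matmuls need both dims divisible by 16, so skip layers that are not.
    return in_features % 16 == 0 and out_features % 16 == 0


def convert_model_to_fp8(model):
    """Convert eligible Linear layers of `model` in place (hypothetical wrapper)."""
    # Imported lazily so the filter above can be used without torchao installed.
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    config = Float8LinearConfig.from_recipe_name("tensorwise")
    convert_to_float8_training(
        model, config=config, module_filter_fn=fp8_module_filter
    )
    return model
```

As stressed in the results below, the converted model still has to go through `torch.compile` to see any benefit.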
### Results
**Microbenchmark (d26 MLP, 65536x1664 @ 1664x6656):**
| Method | Forward | Fwd+Bwd | Speedup |
|--------|---------|---------|---------|
| BF16 + compile | 2.00ms | 4.79ms | 1.00x |
| FP8 rowwise + compile | 1.84ms | 4.55ms | 1.08x |
| FP8 tensorwise + compile | 1.45ms | 4.06ms | **1.38x** |
| FP8 rowwise (no compile) | 2.89ms | 21.86ms | 0.23x ❌ |
torch.compile is MANDATORY. Without it, FP8 rowwise is more than 4x slower than compiled BF16 on fwd+bwd (21.86ms vs 4.79ms) due to unfused scaling ops.
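The table has a single Speedup column, so it can help to recompute the ratios against the BF16 + compile row for forward-only and fwd+bwd separately. A small sketch using only the timings from the table (the dictionary and function names are made up for illustration):

```python
# Timings in milliseconds, copied from the microbenchmark table: (forward, fwd+bwd)
timings_ms = {
    "BF16 + compile": (2.00, 4.79),
    "FP8 rowwise + compile": (1.84, 4.55),
    "FP8 tensorwise + compile": (1.45, 4.06),
    "FP8 rowwise (no compile)": (2.89, 21.86),
}


def speedups_vs_baseline(timings, baseline="BF16 + compile"):
    """Return {method: (forward speedup, fwd+bwd speedup)} relative to the baseline."""
    base_fwd, base_fwdbwd = timings[baseline]
    return {
        name: (base_fwd / fwd, base_fwdbwd / fwdbwd)
        for name, (fwd, fwdbwd) in timings.items()
    }


for name, (s_fwd, s_fwdbwd) in speedups_vs_baseline(timings_ms).items():
    print(f"{name:26s} fwd {s_fwd:.2f}x   fwd+bwd {s_fwdbwd:.2f}x")
```

For the tensorwise row, the forward-only ratio is 2.00/1.45 ≈ 1.38x, while the fwd+bwd ratio is 4.79/4.06 ≈ 1.18x, which is close to the 1.17x seen in full training.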
**Full training (d26):**
| Config | tok/sec | vs baseline |
|--------|---------|-------------|
| BF16 baseline | 630K | 1.00x |
| FP8 rowwise | 564K | 0.90x ❌ |
| FP8 tensorwise | 740K | **1.17x** ✓ |
Memory usage also decreases quite a bit, by ~9GB (activations stored as FP8 instead of BF16).
Seeing a 17% speedup is encouraging, but we're not done yet: each step is now in lower precision and individually less powerful, so we have to train longer to make up for the precision drop. Empirically, running some sweeps overnight at d24 scale, I saw that the actual speedup (when you match performance) is closer to 5%. It's possible that our LLMs at ~d24 scale are still too small to fully enjoy the speedups that fp8 brings for bigger models.
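The gap between 17% and ~5% can be put into rough numbers. This is a back-of-envelope sketch that assumes wall-clock time to a target is proportional to tokens trained divided by tok/sec. The "implied token ratio" below is derived from the figures above, not measured, and it mixes the d26 throughput numbers with the d24 sweep result.

```python
def capability_matched_speedup(throughput_ratio: float, token_ratio: float) -> float:
    """Net speedup when FP8 runs `throughput_ratio` times faster per token
    but needs `token_ratio` times as many tokens to match the baseline."""
    return throughput_ratio / token_ratio


# Full-training throughput from the table above: FP8 tensorwise vs BF16 baseline.
throughput_ratio = 740 / 630

# Working backwards from the ~5% capability-matched speedup:
# how many more tokens would FP8 need under this simple model?
implied_token_ratio = throughput_ratio / 1.05

print(f"throughput ratio:    {throughput_ratio:.3f}x")
print(f"implied token ratio: {implied_token_ratio:.3f}x")
print(f"net speedup:         "
      f"{capability_matched_speedup(throughput_ratio, implied_token_ratio):.3f}x")
```

Under these assumptions the FP8 run would need roughly 12% more tokens, which eats most of the raw throughput gain.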
### Key Learnings
For nanochat at approximate scale of interest (~GPT-2 capability, ~d24):
1. **Tensorwise >> Rowwise** - Rowwise computes per-row scales, overhead exceeds benefit. Tensorwise uses one scale per tensor.
2. **Filter small layers** - Layers with dims not divisible by 16 must be skipped (FP8 hardware requirement)
3. **Larger models benefit more** - d12 was still slower with FP8; d26+ shows gains. So fp8 helps at some depths and not at others. Keeping it configurable for now, passed in via kwargs and off by default.
4. **The effective, capability-matched speedup is lower still** - because each step is of slightly lower precision/quality.
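To make the tensorwise vs rowwise distinction in point 1 concrete, here is a minimal pure-Python sketch of the scale computation. It assumes the common amax-based scheme (scale = FP8 max / max absolute value) with the e4m3 format, whose largest finite value is 448. torchao's actual kernels differ in detail, and the function names are illustrative only.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3
EPS = 1e-12       # guards against division by zero for all-zero inputs


def tensorwise_scale(mat):
    """One scale for the whole tensor: a single reduction over all elements."""
    amax = max(abs(x) for row in mat for x in row)
    return E4M3_MAX / max(amax, EPS)


def rowwise_scales(mat):
    """One scale per row: a separate reduction (and a stored scale) for every row."""
    return [E4M3_MAX / max(max(abs(x) for x in row), EPS) for row in mat]


mat = [
    [0.5, -2.0],
    [0.125, 0.25],
]
print("tensorwise:", tensorwise_scale(mat))  # a single float
print("rowwise:   ", rowwise_scales(mat))    # one float per row
```

Rowwise keeps more precision for rows with small values (the second row gets a much larger scale), but it produces and applies one scale per row instead of one per tensor, which is the extra overhead referred to above.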
### Integration
Added an `--fp8` flag to `base_train.py` (default recipe: "tensorwise"). Example of turning it on:
```bash
torchrun --nproc_per_node=8 -m scripts.base_train --depth=24 --fp8
```
Requires `torchao==0.15.0` (compatible with torch 2.9.1), which was added to dependencies.
**TLDR**: turning on fp8 for GPT-2 capability nanochat model gives approx +5% capability-matched speedup.
---
## 2026-01-29: Hyperball/MuonH Experiments (Negative Result)
Explored Hyperball optimization from [this post](https://psychedelic-sunstone-851.notion.site/Fantastic-Pretraining-Optimizers-and-Where-to-Find-Them-2-1-Hyperball-Optimization-2e924306e6f280e7a5ffee00eb40a0dd) (saved to `knowledge/muonh.md`). Constrains weights to sphere of radius R (initial norm): `W_{t+1} = R · Normalize(W_t - η·R · Normalize(u_t))`. Had to change a number of details in a branch, e.g. not use zero init for our projections (or the initial norm would be zero), keep track of the initial norm, adjust Muon -> MuonH for the update.
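The update rule can be sketched in a few lines of plain Python. This is only an illustration of the formula as written above. It assumes `Normalize` divides by the Frobenius norm of the whole matrix (the post may normalize differently), and `hyperball_step` and the other names are hypothetical, not the code from the branch.

```python
import math


def frobenius_norm(mat):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in mat for x in row))


def normalize(mat, eps=1e-12):
    """Scale a matrix to unit Frobenius norm (assumed meaning of Normalize)."""
    n = frobenius_norm(mat) + eps
    return [[x / n for x in row] for row in mat]


def hyperball_step(W, u, R, lr):
    """W_{t+1} = R * Normalize(W_t - lr * R * Normalize(u_t))."""
    u_hat = normalize(u)
    stepped = [
        [w - lr * R * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(W, u_hat)
    ]
    return [[R * x for x in row] for row in normalize(stepped)]


W = [[3.0, 0.0], [0.0, 4.0]]
R = frobenius_norm(W)          # the initial norm; zero init would give R = 0
u = [[1.0, 0.0], [0.0, 0.0]]   # stand-in for the optimizer's update direction
W_next = hyperball_step(W, u, R, lr=0.1)
print(frobenius_norm(W_next))  # stays at (approximately) R
```

The sketch also shows why zero-initialized projections are a problem here: with `R = 0` every step would map the weights back to zero.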