diff --git a/dev/LOG.md b/dev/LOG.md
index c7d8b80..7944526 100644
--- a/dev/LOG.md
+++ b/dev/LOG.md
@@ -4,6 +4,67 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
 
 ---
 
+## 2026-01-13: FP8 Training for lm_head
+
+Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
+
+### Implementation Approaches Tried
+
+**1. Dynamic Scaling (failed)**
+- Compute `x.abs().max()` and `w.abs().max()` on each forward pass to determine scales
+- Problem: `.item()` calls cause graph breaks with torch.compile
+- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
+- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
+- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
+
+**2. Static Scaling (partial success)**
+- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
+- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344` due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so this should probably be `1/57344`, 1 being the amax of any individual element of the cross-entropy gradient with respect to the logits, and with no normalization by B,T because they use sum reduction, not mean reduction.
+- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
+- This works correctly - no NaNs, proper gradients
+
+### Results (d12)
+
+| Metric | BF16 Baseline | FP8 lm_head |
+|--------|---------------|-------------|
+| GPU Memory | 34 GB | 36 GB |
+| tok/sec | baseline | ~1% faster |
+
+### The Memory Mystery
+
+FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see a 2 GB *increase*. Suspected causes:
+- `torch.compile` on inner kernels creating extra buffers/specializations
+- `torch._scaled_mm` internal workspace allocations
+- Custom op registration machinery overhead
+
+Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` for backward, then re-quantizing on the spot during backward - didn't help, still saw the bump.
+
+### Microbenchmark vs Reality
+
+Raw microbenchmark showed promise:
+- BF16 matmul: 16.95 ms
+- FP8 matmul (static scales): 10.31 ms (1.64x faster)
+- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
+
+But in full training, the ~1% tok/sec improvement doesn't justify the 2 GB memory increase, the added code complexity, and the need to tune scale factors for both x and w.
+
+### Code Artifacts
+
+See the branch `fp8_attempt_fail` for:
+
+- `nanochat/fp8_static.py` - Static scaling implementation (working)
+- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
+- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
+
+### Open Questions
+
+- Why does the custom op approach use more memory than vanilla BF16?
+- Why is the bump in tok_per_sec so low? We should see ~1.6x speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the answer because our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.
+
+**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I don't yet understand the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO: study in more detail how this is implemented in other libraries, e.g. torchao.
+
+---
+
 ## 2026-01-12: Multi-Token Prediction (MTP)
 
 Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.