auto-calculate optimal batch size. the original setting of 0.5M was only optimal for d12, but d26 prefers 1M and so on

This commit is contained in:
Andrej Karpathy
2026-02-05 19:40:37 +00:00
parent 98eed6df18
commit f41dd3cbd7
2 changed files with 156 additions and 91 deletions
+46
View File
@@ -4,6 +4,52 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-02-05: Auto Batch Size Scaling
### Background
So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches.
### Implementation
Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch:
```python
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
if args.total_batch_size == -1:
D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
B_REF = 2**19
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
```
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
### Results
With this formula, we currently get:
| Depth | Scaling Params | Target Tokens | Auto Batch |
|-------|---------------|---------------|------------|
| d=8 | 42M | 0.44B | 2^18 = 262K |
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
### Code Cleanup
Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling.
### Useful references
- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738)
- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162)
- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
- [Merrill et al., The Batch SizeCritical Batch Size Myth](https://arxiv.org/abs/2505.23971)
---
## 2026-02-05: SwiGLU Activation (Negative Result)
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×: