auto-calculate optimal batch size. the original setting of 0.5M was only optimal for d12, but d26 prefers 1M and so on
This commit is contained in:
+46
@@ -4,6 +4,52 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-05: Auto Batch Size Scaling
|
||||
|
||||
### Background
|
||||
|
||||
So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches.
|
||||
|
||||
### Implementation
|
||||
|
||||
Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch:
|
||||
|
||||
```python
|
||||
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
|
||||
if args.total_batch_size == -1:
|
||||
D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
|
||||
B_REF = 2**19
|
||||
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
|
||||
```
|
||||
|
||||
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
|
||||
|
||||
### Results
|
||||
|
||||
With this formula, we currently get:
|
||||
|
||||
| Depth | Scaling Params | Target Tokens | Auto Batch |
|
||||
|-------|---------------|---------------|------------|
|
||||
| d=8 | 42M | 0.44B | 2^18 = 262K |
|
||||
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
|
||||
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
|
||||
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
|
||||
|
||||
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
|
||||
|
||||
### Code Cleanup
|
||||
|
||||
Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling.
|
||||
|
||||
### Useful references
|
||||
|
||||
- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738)
|
||||
- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162)
|
||||
- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
|
||||
- [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971)
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-05: SwiGLU Activation (Negative Result)
|
||||
|
||||
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
|
||||
|
||||
Reference in New Issue
Block a user