auto-calculate optimal batch size. the original setting of 0.5M was only optimal for d12, but d26 prefers 1M and so on
This commit is contained in:
+110
-91
@@ -11,11 +11,14 @@ If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Ex
|
||||
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
from contextlib import nullcontext, contextmanager
|
||||
|
||||
import wandb
|
||||
@@ -53,8 +56,8 @@ parser.add_argument("--num-iterations", type=int, default=-1, help="explicit num
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
|
||||
# Optimization
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
|
||||
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
@@ -78,8 +81,8 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compute init and wandb logging
|
||||
|
||||
# Compute init
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
@@ -109,65 +112,39 @@ else:
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# Tokenizer will be useful for evaluation, also we need the vocab size
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
|
||||
tokenizer = get_tokenizer()
|
||||
token_bytes = get_token_bytes(device=device)
|
||||
vocab_size = tokenizer.get_vocab_size()
|
||||
print0(f"Vocab size: {vocab_size:,}")
|
||||
|
||||
# Model kwargs are derived from the desired depth of the model
|
||||
# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
|
||||
num_layers = args.depth
|
||||
base_dim = args.depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
head_dim = model_dim // num_heads
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
print0(f"head_dim: {head_dim}")
|
||||
print0(f"num_kv_heads: {num_kv_heads}")
|
||||
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, window_pattern=args.window_pattern)
|
||||
with torch.device("meta"):
|
||||
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # All tensors get initialized
|
||||
def build_model_meta(depth):
|
||||
"""Build a model on meta device for a given depth (shapes/dtypes only, no data)."""
|
||||
# Model dim is nudged up to nearest multiple of head_dim for clean division
|
||||
# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
|
||||
base_dim = depth * args.aspect_ratio
|
||||
model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
|
||||
num_heads = model_dim // args.head_dim
|
||||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
return model_meta
|
||||
|
||||
# Build the model, move to device, init the weights
|
||||
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
|
||||
model_config = model.config
|
||||
model_config_kwargs = asdict(model_config)
|
||||
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
|
||||
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
|
||||
model.init_weights() # 3) All tensors get initialized
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
@@ -181,41 +158,7 @@ if resuming:
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the length of the training run based on model size
|
||||
|
||||
# Detailed parameter counts
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_scaling_params = param_counts['transformer_matrices'] + param_counts['lm_head'] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif args.target_flops > 0:
|
||||
# calculate the number of iterations from the target flops
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# FP8 training initialization and management (has to be done before torch.compile)
|
||||
# FP8 training initialization and management (this has to be done before torch.compile)
|
||||
|
||||
# Convert Linear layers to Float8Linear if --fp8 is set
|
||||
if args.fp8:
|
||||
@@ -293,6 +236,82 @@ def disable_fp8(model):
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the optimization horizon based on the model size
|
||||
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
|
||||
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
|
||||
|
||||
# Get the parameter counts of the model
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
num_params = param_counts['total']
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Scaling params: transformer matrices + lm_head (gives cleanest scaling laws, see dev/LOG.md Jan 27, 2026)
|
||||
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
|
||||
num_scaling_params = get_scaling_params(model)
|
||||
target_tokens = int(args.target_param_data_ratio * num_scaling_params)
|
||||
|
||||
# Auto-compute optimal batch size based on Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
|
||||
if args.total_batch_size == -1:
|
||||
d12_ref = build_model_meta(12) # d12 is where the optimal batch size was measured to be 2**19 tokens
|
||||
d12_num_scaling_params = get_scaling_params(d12_ref)
|
||||
D_REF = args.target_param_data_ratio * d12_num_scaling_params
|
||||
B_REF = 2**19
|
||||
args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383)) # also clamp to power of 2
|
||||
print0(f"Auto-computed optimal batch size: {args.total_batch_size:,} tokens")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0
|
||||
if args.num_iterations > 0:
|
||||
# Override num_iterations to a specific value if given
|
||||
num_iterations = args.num_iterations
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif args.target_flops > 0:
|
||||
# Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
|
||||
num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
|
||||
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
|
||||
elif args.target_param_data_ratio > 0:
|
||||
# Calculate the number of iterations from the target param data ratio (the most common use case)
|
||||
num_iterations = target_tokens // args.total_batch_size
|
||||
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
|
||||
else:
|
||||
raise ValueError("No training horizon specified")
|
||||
total_tokens = args.total_batch_size * num_iterations
|
||||
print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Optimizer / data / training length related hyperparameters
|
||||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert args.total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
|
||||
batch_lr_scale = 1.0
|
||||
reference_batch_size = 2**19
|
||||
batch_ratio = args.total_batch_size / reference_batch_size
|
||||
if batch_ratio != 1.0:
|
||||
# SGD: linear scaling with batch size is standard (not used in nanochat)
|
||||
# AdamW: sqrt scaling is standard
|
||||
# Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
|
||||
batch_lr_scale = batch_ratio ** 0.5
|
||||
print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
|
||||
|
||||
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2 (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
|
||||
weight_decay_scaled = args.weight_decay * (12 / args.depth)**2
|
||||
if args.depth != 12:
|
||||
print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
|
||||
Reference in New Issue
Block a user