All of these improvements were developed by Claude running autonomously over ~2 days using autoresearch. I didn't touch anything - incredible. All tuning was done on d12 but generalized easily to larger models (e.g. d24 in particular). This means we will also get a new "Time to GPT-2" Leaderboard entry, which I will push separately.
Optimizer & schedule changes: - Increase unembedding LR 0.004 -> 0.008, weight decay 0.2 -> 0.28 - Per-group Adam betas and weight decay (instead of shared global betas) - Muon beta2 0.95 -> 0.9, momentum warmup target 0.95 -> 0.97 over 400 steps - Warmup: ratio-based -> absolute steps (default 40) - Warmdown ratio 0.5 -> 0.65, final LR fraction 0.0 -> 0.05 - Weight decay schedule: linear -> cosine decay - Polar express norm factor 1.02 -> 1.01 Architecture & init changes: - VE gate: channels 32 -> 12, scale range 2x -> 3x, init small positive - Add post-QK-norm scaling (q,k *= 1.15) for sharper attention - Embedding init std 1.0 -> 0.8, MLP c_fc init 0.5x smaller - RoPE base theta 10K -> 100K - Short attention window: seq_len/2 -> ~seq_len/3 (ceil to 128 tile) - Logit softcap 20 -> 15
This commit is contained in:
+12
-15
@@ -60,15 +60,13 @@ parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help=
|
||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
|
||||
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
|
||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.2, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
|
||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
|
||||
parser.add_argument("--adam-beta1", type=float, default=0.8, help="Adam beta1 for embedding/unembedding")
|
||||
parser.add_argument("--adam-beta2", type=float, default=0.95, help="Adam beta2 for embedding/unembedding")
|
||||
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
|
||||
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
|
||||
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
@@ -311,7 +309,6 @@ optimizer = model.setup_optimizer(
|
||||
unembedding_lr=args.unembedding_lr * batch_lr_scale,
|
||||
embedding_lr=args.embedding_lr * batch_lr_scale,
|
||||
scalar_lr=args.scalar_lr * batch_lr_scale,
|
||||
adam_betas=(args.adam_beta1, args.adam_beta2),
|
||||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
@@ -360,7 +357,7 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmup_iters = args.warmup_steps
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
@@ -370,15 +367,15 @@ def get_lr_multiplier(it):
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.95 over the first 300 steps)
|
||||
# Momentum scheduler for Muon optimizer (warms up to 0.97 over the first 400 steps)
|
||||
def get_muon_momentum(it):
|
||||
frac = min(it / 300, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.95
|
||||
frac = min(it / 400, 1)
|
||||
momentum = (1 - frac) * 0.85 + frac * 0.97
|
||||
return momentum
|
||||
|
||||
# Weight decay scheduler for Muon optimizer (linearly decays to zero over the course of training)
|
||||
# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
|
||||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * (1 - it / num_iterations)
|
||||
return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
@@ -605,7 +602,7 @@ get_report().log(section="Base model training", data=[
|
||||
"Number of training tokens": total_tokens,
|
||||
"Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
|
||||
"DDP world size": ddp_world_size,
|
||||
"warmup_ratio": args.warmup_ratio,
|
||||
"warmup_steps": args.warmup_steps,
|
||||
"warmdown_ratio": args.warmdown_ratio,
|
||||
"final_lr_frac": args.final_lr_frac,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user