320 experiments just to tune the adam beta1 of x0 a little bit up from 0.8 to 0.96

This commit is contained in:
Andrej Karpathy
2026-01-25 00:03:55 +00:00
parent 6a477eedbd
commit 85b3e95e09
2 changed files with 67 additions and 1 deletions
+1 -1
View File
@@ -349,7 +349,7 @@ class GPT(nn.Module):
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
-dict(params=x0_params, lr=scalar_lr),
+dict(params=x0_params, lr=scalar_lr, betas=(0.96, 0.95)), # higher beta1 for x0 scalars
]
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)