Big Muon optimizer changes inspired by latest of modded-nanogpt. Added Polar Express, Adafactor-style variance reduction, cautious weight decay, schedule weight decay linearly to ramp down to zero. Tuned optimum weight decay for multiple model sizes d8, d12, d16, d20 and found a scaling law with optimum wd \propto 1/channels^2, including it as default into code. --weight_decay of base_train is now default on and configured optimally according to all of these experiments. Solid bump to val_bpb observed as a result of these changes.

This commit is contained in:
Andrej Karpathy
2026-01-11 16:56:59 +00:00
parent f5a0ea4d3f
commit 2c4473dd1b
4 changed files with 198 additions and 22 deletions
+121 -13
View File
@@ -1,11 +1,50 @@
"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
Muon optimizer adapted (simplified) from modded-nanogpt.
https://github.com/KellerJordan/modded-nanogpt
"""
import torch
from torch import Tensor
import torch.distributed as dist
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
polar_express_coeffs = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
@torch.compile
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
"""
Polar Express Sign Method for orthogonalization.
https://arxiv.org/pdf/2505.16932
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
Alternative to Newton-Schulz iteration with potentially better convergence properties.
"""
assert G.ndim >= 2
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1 (with 2% safety factor)
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
# Perform the iterations (cap at available coefficients)
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
@@ -35,6 +74,40 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
X = X.mT
return X
@torch.compile
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
"""
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
https://arxiv.org/pdf/2510.05491
Normalizes updates based on a running estimate of per-row (or per-column) variance.
The reduction dimension is determined by the shape of second_momentum_buffer.
"""
# Determine reduction dimension from buffer shape
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
# Compute per-row/col mean of squared values
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
red_dim_size = v.size(red_dim)
# Compute current norm
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
v_norm = v_norm_sq.sqrt()
# Update second momentum buffer (EMA of variance)
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
# Compute scaling factor from second momentum
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
# Final scale preserves overall norm while adjusting per-row/col
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
return v.mul(final_scale.to(v.dtype))
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
@@ -56,9 +129,11 @@ class Muon(torch.optim.Optimizer):
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params: list[Tensor] = [*params]
param_groups = []
for size in {p.numel() for p in params}:
@@ -79,13 +154,29 @@ class Muon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
class DistMuon(torch.optim.Optimizer):
"""
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via NewtonSchulz,
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
- reduce_scatter(AVG) for gradient averaging
- all_gather to replicate updated weights
@@ -102,11 +193,13 @@ class DistMuon(torch.optim.Optimizer):
lr: learning rate
momentum: momentum coefficient in [0,1)
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
ns_steps: number of NewtonSchulz iterations for the orthogonalization
ns_steps: number of Newton-Schulz iterations for the orthogonalization
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
"""
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
nesterov: bool = True, ns_steps: int = 5):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
params = list(params)
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
rank = dist.get_rank()
@@ -173,9 +266,24 @@ class DistMuon(torch.optim.Optimizer):
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1.0 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
p.add_(g, alpha=-group["lr"] * scale)
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
# Variance reduction (NorMuon-style)
if group["beta2"] is not None:
if "second_momentum_buffer" not in state:
# Buffer shape determines reduction dim: reduce along larger dimension
if p.size(-2) >= p.size(-1):
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
else:
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
# Parameter update with cautious weight decay
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
wd = group["weight_decay"]
if wd != 0:
mask = (g * p) >= 0
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
else:
p.sub_(effective_lr * g)
# Replicate updated parameters to all ranks
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]