Big Muon optimizer changes inspired by the latest modded-nanogpt. Added Polar Express, Adafactor-style variance reduction, cautious weight decay, and a schedule that ramps weight decay linearly down to zero. Tuned the optimal weight decay for multiple model sizes (d8, d12, d16, d20) and found a scaling law with optimal wd \propto 1/channels^2, which is now included in the code as the default. --weight_decay of base_train is now on by default and configured optimally according to all of these experiments. A solid improvement in val_bpb was observed as a result of these changes.
This commit is contained in:
+2
-2
@@ -260,11 +260,11 @@ class GPT(nn.Module):
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
]
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=weight_decay)
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
|
||||
muon_kwargs = dict(lr=matrix_lr, momentum=0.95, weight_decay=weight_decay)
|
||||
MuonFactory = DistMuon if ddp else Muon
|
||||
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
|
||||
# Combine the two optimizers into one list
|
||||
|
||||
+121
-13
@@ -1,11 +1,50 @@
|
||||
"""
|
||||
Muon optimizer from Keller et al.
|
||||
Also a lot of borrowing of ideas from modded-nanogpt.
|
||||
Muon optimizer adapted (simplified) from modded-nanogpt.
|
||||
https://github.com/KellerJordan/modded-nanogpt
|
||||
"""
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
# Each (a, b, c) triple drives one step X <- a*X + (b*A + c*A@A) @ X with A = X @ X^T.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]


@torch.compile
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
    """
    Polar Express Sign Method for orthogonalization.
    https://arxiv.org/pdf/2505.16932
    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.

    Alternative to Newton-Schulz iteration with potentially better convergence properties.

    Args:
        G: tensor with at least 2 dims; the last two dims are treated as the matrix.
        steps: number of iterations to run. Clamped to the range
            [0, len(polar_express_coeffs)], so a negative value runs no iterations
            and a value above the table length runs every available step.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    assert G.ndim >= 2
    X = G.bfloat16()

    # Work on the wide orientation so that X @ X.mT is the smaller Gram matrix.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT

    # Ensure spectral norm is at most 1 (with 2% safety factor): the Frobenius norm
    # upper-bounds the spectral norm, and the 1e-6 guards against division by zero.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)

    # Clamp on both sides. Without the lower bound a negative `steps` would become a
    # negative slice end and silently run len(coeffs) + steps iterations.
    num_steps = max(0, min(steps, len(polar_express_coeffs)))
    for a, b, c in polar_express_coeffs[:num_steps]:
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    # Undo the orientation change so the output shape matches the input shape.
    if transposed:
        X = X.mT
    return X
|
||||
|
||||
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
"""
|
||||
@@ -35,6 +74,40 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
||||
X = X.mT
|
||||
return X
|
||||
|
||||
|
||||
@torch.compile
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
    """
    Rescale an update by a running per-row (or per-column) second-moment estimate.

    NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
    https://arxiv.org/pdf/2510.05491

    Args:
        v: update tensor; its last two dims form the matrix being rescaled.
        second_momentum_buffer: running average of the mean squared entries of v,
            updated in place. If its last dim has size 1, v is reduced along its last
            dim; otherwise v is reduced along its second-to-last dim.
        beta2: decay rate; the buffer moves toward the new statistic with weight 1 - beta2.

    Returns:
        A new tensor in v's dtype: v scaled per row/column by the inverse square root
        of the buffer, then rescaled so the norm over the last two dims matches that
        of v (up to clamping and rounding).
    """
    # The buffer's shape tells us which axis of v is collapsed.
    if second_momentum_buffer.size(-1) == 1:
        axis = -1
    else:
        axis = -2
    axis_len = v.size(axis)

    # Mean of squared entries along the collapsed axis, computed in float32.
    mean_sq = torch.square(v.float()).mean(dim=axis, keepdim=True)

    # Fold the fresh statistic into the running average (in place).
    second_momentum_buffer.lerp_(mean_sq.to(dtype=second_momentum_buffer.dtype), 1 - beta2)

    # Norm of the incoming update: mean * axis_len recovers each row/col sum of squares.
    norm_before = (mean_sq.sum(dim=(-2, -1), keepdim=True) * axis_len).sqrt()

    # Per-row/col inverse-RMS factor, and the norm v would have after applying it.
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    rescaled_sq = (mean_sq * axis_len) * inv_rms.float().square()
    norm_after = rescaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()

    # Fold in a global correction so the overall norm is preserved.
    scale = inv_rms * (norm_before / norm_after.clamp_min(1e-10))
    return v.mul(scale.to(v.dtype))
|
||||
|
||||
|
||||
class Muon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
@@ -56,9 +129,11 @@ class Muon(torch.optim.Optimizer):
|
||||
momentum: The momentum used by the internal SGD.
|
||||
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params: list[Tensor] = [*params]
|
||||
param_groups = []
|
||||
for size in {p.numel() for p in params}:
|
||||
@@ -79,13 +154,29 @@ class Muon(torch.optim.Optimizer):
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
|
||||
|
||||
class DistMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Newton–Schulz,
|
||||
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
|
||||
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
||||
- reduce_scatter(AVG) for gradient averaging
|
||||
- all_gather to replicate updated weights
|
||||
@@ -102,11 +193,13 @@ class DistMuon(torch.optim.Optimizer):
|
||||
lr: learning rate
|
||||
momentum: momentum coefficient in [0,1)
|
||||
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
||||
ns_steps: number of Newton–Schulz iterations for the orthogonalization
|
||||
ns_steps: number of Newton-Schulz iterations for the orthogonalization
|
||||
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
||||
nesterov: bool = True, ns_steps: int = 5):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
|
||||
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
||||
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
||||
params = list(params)
|
||||
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
||||
rank = dist.get_rank()
|
||||
@@ -173,9 +266,24 @@ class DistMuon(torch.optim.Optimizer):
|
||||
buf: Tensor = state["momentum_buffer"]
|
||||
buf.lerp_(g, 1.0 - group["momentum"])
|
||||
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
||||
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
|
||||
scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
p.add_(g, alpha=-group["lr"] * scale)
|
||||
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
||||
# Variance reduction (NorMuon-style)
|
||||
if group["beta2"] is not None:
|
||||
if "second_momentum_buffer" not in state:
|
||||
# Buffer shape determines reduction dim: reduce along larger dimension
|
||||
if p.size(-2) >= p.size(-1):
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
||||
else:
|
||||
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
||||
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
||||
# Parameter update with cautious weight decay
|
||||
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
||||
wd = group["weight_decay"]
|
||||
if wd != 0:
|
||||
mask = (g * p) >= 0
|
||||
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
||||
else:
|
||||
p.sub_(effective_lr * g)
|
||||
# Replicate updated parameters to all ranks
|
||||
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
||||
ag_output = params[base_i:base_i + world_size]
|
||||
|
||||
Reference in New Issue
Block a user