296 lines
14 KiB
Python
296 lines
14 KiB
Python
"""
|
|
Muon optimizer adapted (simplified) from modded-nanogpt.
|
|
https://github.com/KellerJordan/modded-nanogpt
|
|
"""
|
|
import torch
|
|
from torch import Tensor
|
|
import torch.distributed as dist
|
|
|
|
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
|
# From https://arxiv.org/pdf/2505.16932
|
|
polar_express_coeffs = [
|
|
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
|
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
|
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
|
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
|
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
|
]
|
|
|
|
|
|
@torch.compile
|
|
def zeropower_via_polar_express(G: Tensor, steps: int = 5) -> Tensor:
|
|
"""
|
|
Polar Express Sign Method for orthogonalization.
|
|
https://arxiv.org/pdf/2505.16932
|
|
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
|
|
|
Alternative to Newton-Schulz iteration with potentially better convergence properties.
|
|
"""
|
|
assert G.ndim >= 2
|
|
X = G.bfloat16()
|
|
if G.size(-2) > G.size(-1):
|
|
X = X.mT
|
|
|
|
# Ensure spectral norm is at most 1 (with 2% safety factor)
|
|
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
|
|
|
# Perform the iterations (cap at available coefficients)
|
|
for a, b, c in polar_express_coeffs[:min(steps, len(polar_express_coeffs))]:
|
|
A = X @ X.mT
|
|
B = b * A + c * (A @ A)
|
|
X = a * X + B @ X
|
|
|
|
if G.size(-2) > G.size(-1):
|
|
X = X.mT
|
|
return X
|
|
|
|
|
|
@torch.compile
|
|
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
|
|
"""
|
|
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
|
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
|
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
|
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
|
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
|
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
|
performance at all relative to UV^T, where USV^T = G is the SVD.
|
|
"""
|
|
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
|
X = G.bfloat16()
|
|
if G.size(-2) > G.size(-1):
|
|
X = X.mT
|
|
|
|
# Ensure spectral norm is at most 1
|
|
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
|
# Perform the NS iterations
|
|
for _ in range(steps):
|
|
A = X @ X.mT
|
|
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
|
X = a * X + B @ X
|
|
|
|
if G.size(-2) > G.size(-1):
|
|
X = X.mT
|
|
return X
|
|
|
|
|
|
@torch.compile
|
|
def apply_variance_reduction(v: Tensor, second_momentum_buffer: Tensor, beta2: float) -> Tensor:
|
|
"""
|
|
NorMuon-style variance reduction, similar to Adafactor's low-rank variance estimator.
|
|
https://arxiv.org/pdf/2510.05491
|
|
|
|
Normalizes updates based on a running estimate of per-row (or per-column) variance.
|
|
The reduction dimension is determined by the shape of second_momentum_buffer.
|
|
"""
|
|
# Determine reduction dimension from buffer shape
|
|
red_dim = -1 if second_momentum_buffer.size(-1) == 1 else -2
|
|
|
|
# Compute per-row/col mean of squared values
|
|
v_mean = v.float().square().mean(dim=red_dim, keepdim=True)
|
|
red_dim_size = v.size(red_dim)
|
|
|
|
# Compute current norm
|
|
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
|
v_norm = v_norm_sq.sqrt()
|
|
|
|
# Update second momentum buffer (EMA of variance)
|
|
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
|
|
|
# Compute scaling factor from second momentum
|
|
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
|
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
|
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
|
|
|
# Final scale preserves overall norm while adjusting per-row/col
|
|
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
|
return v.mul(final_scale.to(v.dtype))
|
|
|
|
|
|
class Muon(torch.optim.Optimizer):
|
|
"""
|
|
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
https://kellerjordan.github.io/posts/muon/
|
|
|
|
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
|
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
|
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
|
the advantage that it can be stably run in bfloat16 on the GPU.
|
|
|
|
Some warnings:
|
|
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
|
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
|
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
|
|
|
Arguments:
|
|
lr: The learning rate used by the internal SGD.
|
|
momentum: The momentum used by the internal SGD.
|
|
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
|
|
ns_steps: The number of Newton-Schulz iteration steps to use.
|
|
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
|
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
|
"""
|
|
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, beta2=0.95, weight_decay=0.0):
|
|
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
|
params: list[Tensor] = [*params]
|
|
param_groups = []
|
|
for size in {p.numel() for p in params}:
|
|
group = dict(params=[p for p in params if p.numel() == size])
|
|
param_groups.append(group)
|
|
super().__init__(param_groups, defaults)
|
|
|
|
@torch.no_grad()
|
|
def step(self):
|
|
for group in self.param_groups:
|
|
params: list[Tensor] = group["params"]
|
|
for p in params:
|
|
g = p.grad
|
|
assert g is not None
|
|
state = self.state[p]
|
|
if "momentum_buffer" not in state:
|
|
state["momentum_buffer"] = torch.zeros_like(g)
|
|
buf: Tensor = state["momentum_buffer"]
|
|
buf.lerp_(g, 1 - group["momentum"])
|
|
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
|
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
|
# Variance reduction (NorMuon-style)
|
|
if group["beta2"] is not None:
|
|
if "second_momentum_buffer" not in state:
|
|
# Buffer shape determines reduction dim: reduce along larger dimension
|
|
if p.size(-2) >= p.size(-1):
|
|
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
|
else:
|
|
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
|
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
|
# Parameter update with cautious weight decay
|
|
effective_lr = group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5
|
|
wd = group["weight_decay"]
|
|
if wd != 0:
|
|
mask = (g * p) >= 0
|
|
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
|
else:
|
|
p.sub_(effective_lr * g)
|
|
|
|
|
|
class DistMuon(torch.optim.Optimizer):
|
|
"""
|
|
Muon: SGD-momentum + (optional) Nesterov, then orthogonalize the 2D update via Polar Express,
|
|
finally apply aspect-ratio scaled step. Performs its own distributed synchronization:
|
|
- reduce_scatter(AVG) for gradient averaging
|
|
- all_gather to replicate updated weights
|
|
|
|
Notes:
|
|
* Designed for 2D parameters (e.g., linear/conv kernels reshaped to 2D). Do not use for 0D/1D
|
|
params like embeddings or scalars.
|
|
* Momentum buffers are maintained only on the 'owner' rank for each parameter (rank chosen
|
|
by block-cyclic assignment below). If you checkpoint optimizer state on a single rank,
|
|
consolidate states beforehand.
|
|
|
|
Args:
|
|
params: iterable of Tensors
|
|
lr: learning rate
|
|
momentum: momentum coefficient in [0,1)
|
|
nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
|
|
ns_steps: number of Newton-Schulz iterations for the orthogonalization
|
|
beta2: decay rate for second moment (variance) estimate. Set to None to disable.
|
|
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
|
"""
|
|
def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
|
|
nesterov: bool = True, ns_steps: int = 5, beta2: float = 0.95, weight_decay: float = 0.0):
|
|
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, beta2=beta2, weight_decay=weight_decay)
|
|
params = list(params)
|
|
assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
|
|
rank = dist.get_rank()
|
|
# Group all parameters by their shape
|
|
shapes = sorted({p.shape for p in params}) # sort to ensure consistent / deterministic ordering
|
|
param_groups = []
|
|
for shape in shapes:
|
|
group_params = [p for p in params if p.shape == shape]
|
|
device, dtype = group_params[0].device, group_params[0].dtype
|
|
assert all(p.device == device for p in group_params)
|
|
assert all(p.dtype == dtype for p in group_params)
|
|
if rank == 0:
|
|
print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
|
|
param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
|
|
super().__init__(param_groups, defaults)
|
|
|
|
@torch.no_grad()
|
|
def step(self):
|
|
rank = dist.get_rank()
|
|
world_size = dist.get_world_size()
|
|
|
|
# Ensure all grads exist
|
|
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
|
|
|
# Kick off all the reduce scatter operations to average up the gradients across all ranks
|
|
all_reduce_futures = []
|
|
for group in self.param_groups:
|
|
params = group["params"]
|
|
zero_buffer = group["zero_buffer"]
|
|
# Go through params in groups of world_size.
|
|
for base_i in range(0, len(params), world_size):
|
|
# The compute owner of each param is rank i % world_size
|
|
owner_idx = base_i + rank
|
|
# each rank stacks up its chunk of world_size params into a list
|
|
rs_input = [p.grad for p in params[base_i:base_i + world_size]]
|
|
# pad rs_input with the zero buffer to complete the group
|
|
rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
|
|
# the output buffer gets strided across the group based on the rank
|
|
rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
|
|
# reduce scatter the gradients within this group of world_size params
|
|
work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
|
all_reduce_futures.append(work)
|
|
|
|
# Now each rank computes the update and gathers
|
|
future_idx = 0
|
|
all_gather_futures = []
|
|
for group in self.param_groups:
|
|
params = group["params"]
|
|
zero_buffer = group["zero_buffer"]
|
|
# Go through params in groups of world_size.
|
|
for base_i in range(0, len(params), world_size):
|
|
# The compute owner of each param is rank i % world_size
|
|
owner_idx = base_i + rank # calculate the index of the param that this rank owns
|
|
# Wait for the reduce scatter to complete
|
|
all_reduce_futures[future_idx].wait() # possibly later we could use wait_any polling instead
|
|
future_idx += 1
|
|
# Owner computes the Muon update, result is in its param
|
|
if owner_idx < len(params):
|
|
p = params[owner_idx]
|
|
g = p.grad # now averaged across ranks
|
|
state = self.state[p]
|
|
if "momentum_buffer" not in state:
|
|
state["momentum_buffer"] = torch.zeros_like(g)
|
|
buf: Tensor = state["momentum_buffer"]
|
|
buf.lerp_(g, 1.0 - group["momentum"])
|
|
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
|
|
g = zeropower_via_polar_express(g, steps=group["ns_steps"])
|
|
# Variance reduction (NorMuon-style)
|
|
if group["beta2"] is not None:
|
|
if "second_momentum_buffer" not in state:
|
|
# Buffer shape determines reduction dim: reduce along larger dimension
|
|
if p.size(-2) >= p.size(-1):
|
|
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1])
|
|
else:
|
|
state["second_momentum_buffer"] = torch.zeros_like(g[..., :1, :])
|
|
g = apply_variance_reduction(g, state["second_momentum_buffer"], group["beta2"])
|
|
# Parameter update with cautious weight decay
|
|
effective_lr = group["lr"] * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
|
|
wd = group["weight_decay"]
|
|
if wd != 0:
|
|
mask = (g * p) >= 0
|
|
p.sub_(effective_lr * g + effective_lr * wd * p * mask)
|
|
else:
|
|
p.sub_(effective_lr * g)
|
|
# Replicate updated parameters to all ranks
|
|
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
|
|
ag_output = params[base_i:base_i + world_size]
|
|
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # pad
|
|
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
|
|
all_gather_futures.append(work)
|
|
|
|
# Wait for all work to finish
|
|
torch.futures.collect_all(all_gather_futures).wait()
|