188 lines
9.4 KiB
Python
188 lines
9.4 KiB
Python
"""
|
||
Muon optimizer from Keller et al.
|
||
Also a lot of borrowing of ideas from modded-nanogpt.
|
||
"""
|
||
import torch
|
||
from torch import Tensor
|
||
import torch.distributed as dist
|
||
|
||
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Approximate the zeroth power (orthogonalization) of G with a quintic Newton-Schulz iteration.

    The polynomial coefficients are chosen to maximize the slope at zero. To minimize the number
    of steps it is empirically effective to keep raising that slope even past the point where the
    iteration stops converging all the way to one on the whole interval. Consequently the result
    is not UV^T (with USV^T = G the SVD) but something like US'V^T, where S' is diagonal with
    S_{ii}' ~ Uniform(0.5, 1.5); this turns out not to hurt model performance relative to UV^T.

    Arguments:
        G: Tensor with at least 2 dimensions; the last two are treated as the matrix dimensions.
        steps: Number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    # Batched Muon implementation by @scottjmaddox, put into practice in the record by @YouJiacheng.
    assert G.ndim >= 2
    coef_x, coef_a, coef_aa = 3.4445, -4.7750, 2.0315

    # Work on the wide orientation so the Gram matrix Y @ Y.mT uses the smaller dimension.
    is_tall = G.size(-2) > G.size(-1)
    Y = G.bfloat16()
    if is_tall:
        Y = Y.mT

    # Ensure spectral norm is at most 1 (the epsilon guards against an all-zero input).
    matrix_norm = Y.norm(dim=(-2, -1), keepdim=True)
    Y = Y / (matrix_norm + 1e-7)

    # Newton-Schulz iterations. Quintic computation strategy adapted from a suggestion by
    # @jxbz, @leloykun, and @YouJiacheng.
    for _ in range(steps):
        gram = Y @ Y.mT
        poly = coef_a * gram + (coef_aa * gram) @ gram
        Y = coef_x * Y + poly @ Y

    # Undo the transpose applied to tall inputs.
    return Y.mT if is_tall else Y
|
||
|
||
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - With nesterov=True the gradient tensor is blended with the momentum buffer in place, so
    `p.grad` no longer holds the raw gradient after `step()`.

    Arguments:
        params: Iterable of parameter tensors (at least 2D) to optimize.
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        # One param group per distinct element count. The iteration order of the set is kept
        # as-is so the group layout (and therefore state_dict layout) does not change.
        param_groups = []
        for size in {p.numel() for p in params}:
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Arguments:
            closure: Optional callable that re-evaluates the model and returns the loss. It is
                run with gradients enabled before the update.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            params: list[Tensor] = group["params"]
            for p in params:
                g = p.grad
                if g is None:
                    # Frozen or unused parameter: skip it, as the built-in optimizers do,
                    # instead of failing the whole step.
                    continue
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                # Exponential moving average of the gradient.
                buf.lerp_(g, 1 - momentum)
                # Nesterov look-ahead reuses the gradient's storage (in-place) to avoid a copy.
                g = g.lerp_(buf, momentum) if nesterov else buf
                g = zeropower_via_newtonschulz5(g, steps=ns_steps)
                # Scale the step up for tall matrices by sqrt(rows / cols).
                p.add_(g, alpha=-lr * max(1, p.size(-2) / p.size(-1))**0.5)

        return loss
|
||
|
||
|
||
class DistMuon(torch.optim.Optimizer):
    """
    Distributed Muon: SGD-momentum with optional Nesterov blending, Newton-Schulz
    orthogonalization of each 2D update, then an aspect-ratio scaled step. The optimizer
    handles its own cross-rank synchronization:
    - reduce_scatter(AVG) averages gradients, delivering each one to the rank that owns it
    - all_gather replicates the updated weights back to every rank

    Notes:
    * Only 2D parameters are accepted (e.g., linear/conv kernels reshaped to 2D). Do not use
      for 0D/1D params like embeddings or scalars.
    * Within each same-shape group, parameters are processed in chunks of world_size and the
      parameter at offset `rank` inside a chunk is owned by that rank. Momentum buffers exist
      only on the owning rank, so consolidate optimizer state before checkpointing from a
      single rank.

    Args:
        params: iterable of Tensors
        lr: learning rate
        momentum: momentum coefficient in [0,1)
        nesterov: if True, Nesterov-style update (g <- lerp(g, buf, momentum)); else use buf
        ns_steps: number of Newton–Schulz iterations for the orthogonalization
    """

    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
                 nesterov: bool = True, ns_steps: int = 5):
        defaults = {"lr": lr, "momentum": momentum, "nesterov": nesterov, "ns_steps": ns_steps}
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        is_rank_zero = dist.get_rank() == 0

        # One group per distinct shape; sorted to ensure consistent / deterministic ordering.
        param_groups = []
        for shape in sorted({p.shape for p in params}):
            same_shape = [p for p in params if p.shape == shape]
            first = same_shape[0]
            device, dtype = first.device, first.dtype
            assert all(p.device == device for p in same_shape)
            assert all(p.dtype == dtype for p in same_shape)
            if is_rank_zero:
                print(f"Muon: Grouping {len(same_shape)} params of shape {shape}, device {device}, dtype {dtype}")
            # zero_buffer pads the collectives when a chunk has fewer than world_size params.
            param_groups.append({"params": same_shape, "zero_buffer": torch.zeros_like(first)})
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        """
        Average gradients across ranks, apply the Muon update on each parameter's owner rank,
        and broadcast the updated parameters to all ranks. Every parameter must have a gradient.
        """
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # Ensure all grads exist
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"

        # Phase 1: launch every reduce-scatter up front so communication can overlap.
        pending_reductions = []
        for group in self.param_groups:
            group_params = group["params"]
            pad = group["zero_buffer"]
            n_params = len(group_params)
            for start in range(0, n_params, world_size):
                # Index of the parameter in this chunk that the current rank owns.
                mine = start + rank
                # Gradients of this chunk, padded with the zero buffer up to world_size entries.
                grads_in = [q.grad for q in group_params[start:start + world_size]]
                grads_in += [pad] * (world_size - len(grads_in))
                # The averaged gradient lands in the owner's .grad; ranks that own nothing in
                # this chunk receive into a scratch tensor.
                if mine < n_params:
                    reduced_out = group_params[mine].grad
                else:
                    reduced_out = torch.empty_like(pad)
                handle = dist.reduce_scatter(reduced_out, grads_in, op=dist.ReduceOp.AVG, async_op=True)
                pending_reductions.append(handle.get_future())

        # Phase 2: visit the chunks in the same order, update the owned parameter, and gather.
        reductions = iter(pending_reductions)
        pending_gathers = []
        for group in self.param_groups:
            group_params = group["params"]
            pad = group["zero_buffer"]
            n_params = len(group_params)
            for start in range(0, n_params, world_size):
                mine = start + rank
                owns_param = mine < n_params
                # Wait for this chunk's reduce-scatter (possibly later: wait_any polling instead).
                next(reductions).wait()

                if owns_param:
                    p = group_params[mine]
                    grad = p.grad  # now averaged across ranks
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(grad)
                    buf: Tensor = state["momentum_buffer"]
                    buf.lerp_(grad, 1.0 - group["momentum"])
                    if group["nesterov"]:
                        update = grad.lerp_(buf, group["momentum"])
                    else:
                        update = buf
                    update = zeropower_via_newtonschulz5(update, steps=group["ns_steps"])
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(update, alpha=-group["lr"] * scale)
                    send = p
                else:
                    send = pad

                # Replicate updated parameters to all ranks; pad the receive list with scratch
                # tensors when the chunk is short.
                recv = group_params[start:start + world_size]
                recv += [torch.empty_like(pad) for _ in range(world_size - len(recv))]
                handle = dist.all_gather(recv, send, async_op=True)
                pending_gathers.append(handle.get_future())

        # Wait for all work to finish
        torch.futures.collect_all(pending_gathers).wait()
|