bring back an assert guarding against bad param sizing

This commit is contained in:
Andrej Karpathy
2026-02-05 18:14:09 +00:00
parent 012da1a78b
commit 98eed6df18
+1
View File
@@ -377,6 +377,7 @@ class DistMuonAdamW(torch.optim.Optimizer):
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
else:
# Large params: reduce_scatter
assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()