Add learnable lambdas that gate the residual connection and a skip connection to the input embeddings, solid bump to val_bpb

This commit is contained in:
Andrej Karpathy
2026-01-11 18:47:35 +00:00
parent 2c4473dd1b
commit aa530cdad5
4 changed files with 121 additions and 23 deletions
+36 -18
View File
@@ -16,23 +16,31 @@ class DistAdamW(torch.optim.Optimizer):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(param_groups, defaults)
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
reduce_scatter_futures: list[torch.Future] = []
all_reduce_futures: list[torch.Future] = []
reduce_futures: list[torch.Future] = []
gather_futures: list[torch.Future] = []
grad_slices = []
is_small = [] # track which params are small (use all_reduce) vs large (use reduce_scatter)
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
for p in params:
grad = p.grad
# Small params: use all_reduce (no scatter/gather needed)
if p.numel() < 1024:
is_small.append(True)
reduce_futures.append(dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad)
else:
is_small.append(False)
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])
reduce_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
grad_slices.append(grad_slice)
idx = 0
for group in self.param_groups:
@@ -40,14 +48,19 @@ class DistAdamW(torch.optim.Optimizer):
eps = group['eps']
wd = group['weight_decay']
params = group['params']
for base in range(len(params)):
reduce_scatter_futures[idx].wait()
p = params[base]
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
for p in params:
reduce_futures[idx].wait()
g_slice = grad_slices[idx]
lr = group['lr'] * getattr(p, "lr_mul", 1.0)
state = self.state[p]
g_slice = grad_slices[idx]
# For small params, operate on full param; for large, operate on slice
if is_small[idx]:
p_slice = p
else:
rank_size = p.shape[0] // world_size
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
# State init
if not state:
state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
@@ -72,6 +85,11 @@ class DistAdamW(torch.optim.Optimizer):
step_size = lr / bias1
update = exp_avg.div(denom).mul_(step_size)
p_slice.add_(other=update, alpha=-1.0)
# Only large params need all_gather
if not is_small[idx]:
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
idx += 1
all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
torch.futures.collect_all(all_reduce_futures).wait()
if gather_futures:
torch.futures.collect_all(gather_futures).wait()