feat: pad vocab size to 64 for DDP optimizers and efficiency

This commit is contained in:
Matěj Kripner
2025-12-09 12:38:18 +01:00
parent d5759400f9
commit f1bf69d562
3 changed files with 14 additions and 6 deletions
+1
View File
@@ -27,6 +27,7 @@ class DistAdamW(torch.optim.Optimizer):
for group in self.param_groups:
params: list[Tensor] = group["params"]
for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size])