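Fix a slight bug in AdamW: the second-moment bias correction is now applied to exp_avg_sq before the square root and before eps is added, instead of being folded into the step size as sqrt(bias2), which effectively rescaled eps. This chunk was originally copy-pasted from modded-nanogpt, which still seems to have the bug.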
This commit is contained in:
+2
-2
@@ -68,8 +68,8 @@ class DistAdamW(torch.optim.Optimizer):
|
|||||||
bias1 = 1 - beta1 ** t
|
bias1 = 1 - beta1 ** t
|
||||||
bias2 = 1 - beta2 ** t
|
bias2 = 1 - beta2 ** t
|
||||||
# compute step
|
# compute step
|
||||||
denom = exp_avg_sq.sqrt().add_(eps)
|
denom = (exp_avg_sq / bias2).sqrt().add_(eps)
|
||||||
step_size = lr * (torch.sqrt(bias2) / bias1)
|
step_size = lr / bias1
|
||||||
update = exp_avg.div(denom).mul_(step_size)
|
update = exp_avg.div(denom).mul_(step_size)
|
||||||
p_slice.add_(other=update, alpha=-1.0)
|
p_slice.add_(other=update, alpha=-1.0)
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user