nudge hyperparameters of the base script with the results of the sweeps and miniseries. vocab size down to 32K. D:N ratio from 20 to 8. add miniseries script
This commit is contained in:
+24
-3
@@ -216,14 +216,35 @@ class GPT(nn.Module):
|
||||
return self.transformer.wte.weight.device
|
||||
|
||||
def estimate_flops(self):
|
||||
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
|
||||
"""
|
||||
Return the estimated FLOPs per token for the model (forward + backward).
|
||||
Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
|
||||
Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
|
||||
On top of that, the term 12 * l * h * q * t accounts for key @ query matmul flops inside attention.
|
||||
Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
|
||||
This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
|
||||
- Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
|
||||
- Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
nparams_embedding = self.transformer.wte.weight.numel()
|
||||
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
|
||||
return num_flops_per_token
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
|
||||
def num_scaling_params(self):
|
||||
"""
|
||||
Return all of the parameters, same as Chinchilla paper.
|
||||
Kaplan et al. did not include embedding parameters and said that this led to cleaner scaling laws.
|
||||
But Kaplan et al. also had a bug in their results (as pointed out by Chinchilla).
|
||||
My own experiments in nanochat confirm the Chinchilla approach gives the much cleaner scaling law.
|
||||
Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper <- good).
|
||||
Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper <- bad)
|
||||
"""
|
||||
nparams = sum(p.numel() for p in self.parameters())
|
||||
return nparams
|
||||
|
||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95)):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
# Separate out all parameters into 3 groups (matrix, embedding, lm_head)
|
||||
@@ -239,7 +260,7 @@ class GPT(nn.Module):
|
||||
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
|
||||
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
|
||||
]
|
||||
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
|
||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=weight_decay)
|
||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
|
||||
# Create the Muon optimizer for the linear layers
|
||||
|
||||
Reference in New Issue
Block a user