feat: pad vocab size to a multiple of 64 for DDP optimizers and efficiency

This commit is contained in:
Matěj Kripner
2025-12-09 12:38:18 +01:00
parent d5759400f9
commit f1bf69d562
3 changed files with 14 additions and 6 deletions
+1
View File
@@ -27,6 +27,7 @@ class DistAdamW(torch.optim.Optimizer):
for group in self.param_groups: for group in self.param_groups:
params: list[Tensor] = group["params"] params: list[Tensor] = group["params"]
for base_i in range(len(params)): for base_i in range(len(params)):
assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
grad = params[base_i].grad grad = params[base_i].grad
rank_size = grad.shape[0] // world_size rank_size = grad.shape[0] // world_size
grad_slice = torch.empty_like(grad[:rank_size]) grad_slice = torch.empty_like(grad[:rank_size])
+1
View File
@@ -26,6 +26,7 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches(): def document_batches():
parquet_paths = list_parquet_files() parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:] parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0 resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+12 -6
View File
@@ -135,14 +135,19 @@ class Block(nn.Module):
class GPT(nn.Module): class GPT(nn.Module):
def __init__(self, config): def __init__(self, config, pad_vocab_size_to=64):
super().__init__() super().__init__()
self.config = config self.config = config
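# For DDP, we want vocab_size divisible by world_size; rounding up to a multiple of pad_vocab_size_to guarantees this whenever world_size divides pad_vocab_size_to. Also, there are potential performance benefits, see: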
# https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
if padded_vocab_size != config.vocab_size:
print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} to be divisible by {pad_vocab_size_to}")
self.transformer = nn.ModuleDict({ self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd), "wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}) })
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
# To support meta device initialization, we init the rotary embeddings here, but it's fake # To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount. # so let's just over-compute them, but assert fail if we ever reach that amount.
@@ -220,8 +225,7 @@ class GPT(nn.Module):
# Create the AdamW optimizer for the embedding and lm_head # Create the AdamW optimizer for the embedding and lm_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5 dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0: print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [ adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
@@ -260,7 +264,9 @@ class GPT(nn.Module):
# Forward the lm_head (compute logits) # Forward the lm_head (compute logits)
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap] softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
logits = self.lm_head(x) # (B, T, vocab_size) <- very big tensor, large amount of memory logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
# slice to remove padding
logits = logits[..., :self.config.vocab_size]
logits = logits.float() # switch to fp32 for logit softcap and loss computation logits = logits.float() # switch to fp32 for logit softcap and loss computation
logits = softcap * torch.tanh(logits / softcap) # squash the logits logits = softcap * torch.tanh(logits / softcap) # squash the logits