All of these improvements were developed by Claude running autonomously over ~2 days using autoresearch. I didn't touch anything - incredible. All tuning was done on d12 but generalized easily to larger models (e.g. d24 in particular). This means we will also get a new "Time to GPT-2" Leaderboard entry, which I will push separately.
Optimizer & schedule changes: - Increase unembedding LR 0.004 -> 0.008, weight decay 0.2 -> 0.28 - Per-group Adam betas and weight decay (instead of shared global betas) - Muon beta2 0.95 -> 0.9, momentum warmup target 0.95 -> 0.97 over 400 steps - Warmup: ratio-based -> absolute steps (default 40) - Warmdown ratio 0.5 -> 0.65, final LR fraction 0.0 -> 0.05 - Weight decay schedule: linear -> cosine decay - Polar express norm factor 1.02 -> 1.01 Architecture & init changes: - VE gate: channels 32 -> 12, scale range 2x -> 3x, init small positive - Add post-QK-norm scaling (q,k *= 1.15) for sharper attention - Embedding init std 1.0 -> 0.8, MLP c_fc init 0.5x smaller - RoPE base theta 10K -> 100K - Short attention window: seq_len/2 -> ~seq_len/3 (ceil to 128 tile) - Logit softcap 20 -> 15
This commit is contained in:
+17
-15
@@ -76,7 +76,7 @@ class CausalSelfAttention(nn.Module):
|
||||
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate_channels = 12
|
||||
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
@@ -91,13 +91,15 @@ class CausalSelfAttention(nn.Module):
|
||||
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||
if ve is not None:
|
||||
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
|
||||
gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
|
||||
v = v + gate.unsqueeze(-1) * ve
|
||||
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = cos_sin
|
||||
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
||||
q, k = norm(q), norm(k) # QK norm
|
||||
q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
|
||||
k = k * 1.15
|
||||
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
@@ -208,7 +210,7 @@ class GPT(nn.Module):
|
||||
"""
|
||||
|
||||
# Embedding and unembedding
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
|
||||
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||
|
||||
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||
@@ -219,7 +221,7 @@ class GPT(nn.Module):
|
||||
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
|
||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||
|
||||
# Per-layer scalars
|
||||
@@ -230,10 +232,10 @@ class GPT(nn.Module):
|
||||
for ve in self.value_embeds.values():
|
||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||
|
||||
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||
# Gate weights init with small positive values so gates start slightly above neutral
|
||||
for block in self.transformer.h:
|
||||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
@@ -248,7 +250,7 @@ class GPT(nn.Module):
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
# autodetect the device from model embeddings
|
||||
if device is None:
|
||||
@@ -280,7 +282,7 @@ class GPT(nn.Module):
|
||||
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
||||
# Map characters to window sizes
|
||||
long_window = config.sequence_len
|
||||
short_window = long_window // 2
|
||||
short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
||||
char_to_window = {
|
||||
"L": (long_window, 0),
|
||||
"S": (short_window, 0),
|
||||
@@ -353,7 +355,7 @@ class GPT(nn.Module):
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
@@ -373,10 +375,10 @@ class GPT(nn.Module):
|
||||
# Build param_groups with all required fields explicit
|
||||
param_groups = [
|
||||
# AdamW groups (embeddings, lm_head, scalars)
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
||||
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
|
||||
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
|
||||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
]
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
@@ -384,7 +386,7 @@ class GPT(nn.Module):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
param_groups.append(dict(
|
||||
kind='muon', params=group_params, lr=matrix_lr,
|
||||
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
||||
momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
|
||||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
@@ -416,7 +418,7 @@ class GPT(nn.Module):
|
||||
x = norm(x)
|
||||
|
||||
# Forward the lm_head (compute logits)
|
||||
softcap = 20 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
|
||||
logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
|
||||
logits = logits[..., :self.config.vocab_size] # slice to remove padding
|
||||
logits = logits.float() # switch to fp32 for logit softcap and loss computation
|
||||
|
||||
Reference in New Issue
Block a user