alternating design
This commit is contained in:
+31
-25
@@ -45,6 +45,10 @@ def norm(x):
|
|||||||
return F.rms_norm(x, (x.size(-1),))
|
return F.rms_norm(x, (x.size(-1),))
|
||||||
|
|
||||||
|
|
||||||
|
def has_ve(layer_idx, n_layer):
|
||||||
|
"""Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
|
||||||
|
return layer_idx % 2 == (n_layer - 1) % 2
|
||||||
|
|
||||||
def apply_rotary_emb(x, cos, sin):
|
def apply_rotary_emb(x, cos, sin):
|
||||||
assert x.ndim == 4 # multihead attention
|
assert x.ndim == 4 # multihead attention
|
||||||
d = x.shape[3] // 2
|
d = x.shape[3] // 2
|
||||||
@@ -67,8 +71,10 @@ class CausalSelfAttention(nn.Module):
|
|||||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||||
|
self.ve_gate_channels = 32
|
||||||
|
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||||
|
|
||||||
def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
|
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||||
B, T, C = x.size()
|
B, T, C = x.size()
|
||||||
|
|
||||||
# Project the input to get queries, keys, and values
|
# Project the input to get queries, keys, and values
|
||||||
@@ -77,10 +83,11 @@ class CausalSelfAttention(nn.Module):
|
|||||||
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||||
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
||||||
|
|
||||||
# Value residual (ResFormer): mix in projected initial embedding for later layers
|
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
||||||
if v0 is not None:
|
if ve is not None:
|
||||||
v0_reshaped = v0.view(B, T, self.n_kv_head, self.head_dim)
|
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
||||||
v = v + v0_lambda * v0_reshaped
|
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2)
|
||||||
|
v = v + gate.unsqueeze(-1) * ve
|
||||||
|
|
||||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||||
cos, sin = cos_sin
|
cos, sin = cos_sin
|
||||||
@@ -131,8 +138,8 @@ class Block(nn.Module):
|
|||||||
self.attn = CausalSelfAttention(config, layer_idx)
|
self.attn = CausalSelfAttention(config, layer_idx)
|
||||||
self.mlp = MLP(config)
|
self.mlp = MLP(config)
|
||||||
|
|
||||||
def forward(self, x, cos_sin, window_size, kv_cache, v0, v0_lambda):
|
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||||
x = x + self.attn(norm(x), cos_sin, window_size, kv_cache, v0, v0_lambda)
|
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||||
x = x + self.mlp(norm(x))
|
x = x + self.mlp(norm(x))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -165,12 +172,10 @@ class GPT(nn.Module):
|
|||||||
# Separate parameters so they can have different optimizer treatment
|
# Separate parameters so they can have different optimizer treatment
|
||||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||||
# Value residual (ResFormer-style): every layer gets its own value embedding
|
# Value embeddings (ResFormer-style): alternating layers, last layer always included
|
||||||
# Paper: "Value Residual Learning" (arXiv:2410.17897) shows this improves information flow
|
|
||||||
head_dim = config.n_embd // config.n_head
|
head_dim = config.n_embd // config.n_head
|
||||||
kv_dim = config.n_kv_head * head_dim
|
kv_dim = config.n_kv_head * head_dim
|
||||||
self.value_embeds = nn.ModuleList([nn.Embedding(padded_vocab_size, kv_dim) for _ in range(config.n_layer)])
|
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||||
self.v0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
|
||||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||||
@@ -181,6 +186,7 @@ class GPT(nn.Module):
|
|||||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||||
self.register_buffer("sin", sin, persistent=False)
|
self.register_buffer("sin", sin, persistent=False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""
|
"""
|
||||||
Initialize the full model in this one function for maximum clarity.
|
Initialize the full model in this one function for maximum clarity.
|
||||||
@@ -212,15 +218,18 @@ class GPT(nn.Module):
|
|||||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||||
|
|
||||||
# Per-layer scalars
|
# Per-layer scalars
|
||||||
with torch.no_grad():
|
|
||||||
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
|
||||||
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
|
||||||
self.v0_lambdas.fill_(0.0) # 0.0 => value residual is disabled at init
|
|
||||||
|
|
||||||
# Value embeddings (init like c_v: uniform with same std)
|
# Value embeddings (init like c_v: uniform with same std)
|
||||||
for ve in self.value_embeds:
|
for ve in self.value_embeds.values():
|
||||||
torch.nn.init.uniform_(ve.weight, -s, s)
|
torch.nn.init.uniform_(ve.weight, -s, s)
|
||||||
|
|
||||||
|
# Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral)
|
||||||
|
for block in self.transformer.h:
|
||||||
|
if block.attn.ve_gate is not None:
|
||||||
|
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
||||||
|
|
||||||
# Rotary embeddings
|
# Rotary embeddings
|
||||||
head_dim = self.config.n_embd // self.config.n_head
|
head_dim = self.config.n_embd // self.config.n_head
|
||||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||||
@@ -229,7 +238,7 @@ class GPT(nn.Module):
|
|||||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||||
if self.transformer.wte.weight.device.type == "cuda":
|
if self.transformer.wte.weight.device.type == "cuda":
|
||||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||||
for ve in self.value_embeds:
|
for ve in self.value_embeds.values():
|
||||||
ve.to(dtype=torch.bfloat16)
|
ve.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||||
@@ -295,9 +304,9 @@ class GPT(nn.Module):
|
|||||||
"""
|
"""
|
||||||
nparams = sum(p.numel() for p in self.parameters())
|
nparams = sum(p.numel() for p in self.parameters())
|
||||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds)
|
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.v0_lambdas.numel())
|
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
||||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||||
# Sum attention FLOPs per layer, accounting for sliding window
|
# Sum attention FLOPs per layer, accounting for sliding window
|
||||||
attn_flops = 0
|
attn_flops = 0
|
||||||
@@ -323,15 +332,14 @@ class GPT(nn.Module):
|
|||||||
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
||||||
model_dim = self.config.n_embd
|
model_dim = self.config.n_embd
|
||||||
ddp, rank, local_rank, world_size = get_dist_info()
|
ddp, rank, local_rank, world_size = get_dist_info()
|
||||||
# Separate out all parameters into groups (matrix, embedding, lm_head, value_embeds, resid_lambdas, x0_lambdas, v0_lambdas)
|
# Separate out all parameters into groups
|
||||||
matrix_params = list(self.transformer.h.parameters())
|
matrix_params = list(self.transformer.h.parameters())
|
||||||
|
value_embeds_params = list(self.value_embeds.parameters())
|
||||||
embedding_params = list(self.transformer.wte.parameters())
|
embedding_params = list(self.transformer.wte.parameters())
|
||||||
lm_head_params = list(self.lm_head.parameters())
|
lm_head_params = list(self.lm_head.parameters())
|
||||||
value_embeds_params = list(self.value_embeds.parameters())
|
|
||||||
resid_params = [self.resid_lambdas]
|
resid_params = [self.resid_lambdas]
|
||||||
x0_params = [self.x0_lambdas]
|
x0_params = [self.x0_lambdas]
|
||||||
v0_params = [self.v0_lambdas]
|
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
|
||||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(v0_params)
|
|
||||||
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
# Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars
|
||||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
|
||||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||||
@@ -342,7 +350,6 @@ class GPT(nn.Module):
|
|||||||
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding
|
||||||
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream
|
||||||
dict(params=x0_params, lr=scalar_lr),
|
dict(params=x0_params, lr=scalar_lr),
|
||||||
dict(params=v0_params, lr=scalar_lr),
|
|
||||||
]
|
]
|
||||||
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
adamw_kwargs = dict(betas=adam_betas, eps=1e-10, weight_decay=0.0) # NOTE: weight decay is hardcoded to 0.0 for AdamW, only used in Muon
|
||||||
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
|
||||||
@@ -373,11 +380,10 @@ class GPT(nn.Module):
|
|||||||
x = self.transformer.wte(idx)
|
x = self.transformer.wte(idx)
|
||||||
x = norm(x)
|
x = norm(x)
|
||||||
x0 = x # save initial normalized embedding for x0 residual
|
x0 = x # save initial normalized embedding for x0 residual
|
||||||
# Value residual (ResFormer): every layer gets its own value embedding
|
|
||||||
v0s = [ve(idx) for ve in self.value_embeds] # n_layer x (B, T, kv_dim)
|
|
||||||
for i, block in enumerate(self.transformer.h):
|
for i, block in enumerate(self.transformer.h):
|
||||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||||
x = block(x, cos_sin, self.window_sizes[i], kv_cache, v0s[i], self.v0_lambdas[i])
|
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||||
|
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||||
x = norm(x)
|
x = norm(x)
|
||||||
|
|
||||||
# Forward the lm_head (compute logits)
|
# Forward the lm_head (compute logits)
|
||||||
|
|||||||
Reference in New Issue
Block a user