simplify, clarify and slightly tune model initialization. should be very slightly better possibly, but certainly a lot clearer
This commit is contained in:
+33
-22
@@ -146,9 +146,9 @@ class GPT(nn.Module):
|
|||||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||||
})
|
})
|
||||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||||
# To support meta device initialization, we init the rotary embeddings here, but it's fake
|
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||||
# so let's just over-compute them, but assert fail if we ever reach that amount.
|
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||||
# In the future we can dynamically grow the cache, for now it's fine.
|
# In the future we can dynamically grow the cache, for now it's fine.
|
||||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||||
head_dim = config.n_embd // config.n_head
|
head_dim = config.n_embd // config.n_head
|
||||||
@@ -157,35 +157,46 @@ class GPT(nn.Module):
|
|||||||
self.register_buffer("sin", sin, persistent=False)
|
self.register_buffer("sin", sin, persistent=False)
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
self.apply(self._init_weights)
|
"""
|
||||||
# zero out classifier weights
|
Initialize the full model in this one function for maximum clarity.
|
||||||
torch.nn.init.zeros_(self.lm_head.weight)
|
|
||||||
# zero out c_proj weights in all blocks
|
wte (embedding): normal, std=1.0
|
||||||
|
lm_head: normal, std=0.001
|
||||||
|
for each block:
|
||||||
|
attn.c_q: uniform, std=1/sqrt(n_embd)
|
||||||
|
attn.c_k: uniform, std=1/sqrt(n_embd)
|
||||||
|
attn.c_v: uniform, std=1/sqrt(n_embd)
|
||||||
|
attn.c_proj: zeros
|
||||||
|
mlp.c_fc: uniform, std=1/sqrt(n_embd)
|
||||||
|
mlp.c_proj: zeros
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Embedding and unembedding
|
||||||
|
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
||||||
|
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
||||||
|
|
||||||
|
# Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
|
||||||
|
n_embd = self.config.n_embd
|
||||||
|
s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
|
||||||
for block in self.transformer.h:
|
for block in self.transformer.h:
|
||||||
|
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
|
||||||
|
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
||||||
|
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
||||||
|
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
|
||||||
|
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
||||||
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
||||||
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
|
||||||
# init the rotary embeddings
|
# Rotary embeddings
|
||||||
head_dim = self.config.n_embd // self.config.n_head
|
head_dim = self.config.n_embd // self.config.n_head
|
||||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||||
self.cos, self.sin = cos, sin
|
self.cos, self.sin = cos, sin
|
||||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
|
||||||
|
# Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||||
if self.transformer.wte.weight.device.type == "cuda":
|
if self.transformer.wte.weight.device.type == "cuda":
|
||||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
# https://arxiv.org/pdf/2310.17813
|
|
||||||
fan_out = module.weight.size(0)
|
|
||||||
fan_in = module.weight.size(1)
|
|
||||||
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
|
||||||
if module.bias is not None:
|
|
||||||
torch.nn.init.zeros_(module.bias)
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
|
||||||
|
|
||||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
|
||||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||||
|
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||||
# autodetect the device from model embeddings
|
# autodetect the device from model embeddings
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self.transformer.wte.weight.device
|
device = self.transformer.wte.weight.device
|
||||||
|
|||||||
@@ -112,10 +112,11 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
|
|||||||
# Create a new model with random weights
|
# Create a new model with random weights
|
||||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
|
# All tensors are created as meta tensors (they have shape/dtype but no data)
|
||||||
model_config = GPTConfig(**model_config_kwargs)
|
model_config = GPTConfig(**model_config_kwargs)
|
||||||
model = GPT(model_config)
|
model = GPT(model_config)
|
||||||
model.to_empty(device=device)
|
model.to_empty(device=device) # All tensors get storage on target device but with uninitialized (garbage) data
|
||||||
model.init_weights()
|
model.init_weights() # All tensors get initialized
|
||||||
|
|
||||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||||
base_dir = get_base_dir()
|
base_dir = get_base_dir()
|
||||||
|
|||||||
Reference in New Issue
Block a user