delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
+27
-18
@@ -19,7 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
@@ -40,8 +40,14 @@ class GPTConfig:
|
||||
|
||||
|
||||
def norm(x):
|
||||
# Purely functional rmsnorm with no learnable params
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""nn.Linear that casts weights to match input dtype in forward.
|
||||
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.weight.to(dtype=x.dtype))
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
@@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module):
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
@@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module):
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
@@ -164,7 +170,7 @@ class GPT(nn.Module):
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
@@ -234,11 +240,13 @@ class GPT(nn.Module):
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
|
||||
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
|
||||
# because GradScaler cannot unscale fp16 gradients.
|
||||
if COMPUTE_DTYPE != torch.float16:
|
||||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
@@ -253,7 +261,7 @@ class GPT(nn.Module):
|
||||
# calculate the rotation frequencies at each (time, channel) pair
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
@@ -391,18 +399,19 @@ class GPT(nn.Module):
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user