delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
@@ -10,6 +10,26 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from filelock import FileLock
|
||||
|
||||
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
|
||||
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
|
||||
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
|
||||
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
||||
def _detect_compute_dtype():
|
||||
env = os.environ.get("NANOCHAT_DTYPE")
|
||||
if env is not None:
|
||||
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
|
||||
if torch.cuda.is_available():
|
||||
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
|
||||
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability >= (8, 0):
|
||||
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
|
||||
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
||||
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
||||
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
||||
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
||||
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter that adds colors to log messages."""
|
||||
# ANSI color codes
|
||||
|
||||
+9
-14
@@ -19,7 +19,6 @@ from contextlib import contextmanager
|
||||
from collections import deque
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from contextlib import nullcontext
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculator tool helpers
|
||||
@@ -308,8 +307,6 @@ if __name__ == "__main__":
|
||||
# init compute
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||
bos_token_id = tokenizer.get_bos_token_id()
|
||||
@@ -322,11 +319,10 @@ if __name__ == "__main__":
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
stream = model.generate(prompt_tokens, **kwargs)
|
||||
with autocast_ctx:
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
for token in stream:
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
@@ -338,12 +334,11 @@ if __name__ == "__main__":
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
generated_tokens.append(token)
|
||||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
@@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _use_fa3():
|
||||
"""Determine whether to use FA3 based on availability and override."""
|
||||
def _resolve_use_fa3():
|
||||
"""Decide once whether to use FA3, based on availability, override, and dtype."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
return HAS_FA3 # auto
|
||||
if HAS_FA3:
|
||||
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
if COMPUTE_DTYPE == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
# =============================================================================
|
||||
@@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
@@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
||||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
|
||||
+6
-6
@@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
|
||||
# Avoid division by zero when computing scale from an all-zeros tensor
|
||||
EPS = 1e-12
|
||||
|
||||
@@ -123,7 +125,7 @@ def _to_col_major(x):
|
||||
class _Float8Matmul(torch.autograd.Function):
|
||||
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
|
||||
|
||||
The forward quantizes input and weight to FP8 and saves
|
||||
The forward quantizes input and weight to FP8 and saves
|
||||
the quantized tensors + scales for backward.
|
||||
"""
|
||||
|
||||
@@ -198,11 +200,9 @@ class Float8Linear(nn.Linear):
|
||||
"""
|
||||
|
||||
def forward(self, input):
|
||||
# Replicate the autocast behavior of F.linear — when autocast is active,
|
||||
# we need to manually cast input to the autocast dtype (e.g. bf16),
|
||||
# since we bypass F.linear's built-in autocast handling.
|
||||
if torch.is_autocast_enabled():
|
||||
input = input.to(torch.get_autocast_gpu_dtype())
|
||||
# Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
|
||||
# reduced precision input, and we no longer rely on autocast to do this.
|
||||
input = input.to(COMPUTE_DTYPE)
|
||||
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
|
||||
orig_shape = input.shape
|
||||
input_2d = input.reshape(-1, orig_shape[-1])
|
||||
|
||||
+27
-18
@@ -19,7 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||
|
||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||
@@ -40,8 +40,14 @@ class GPTConfig:
|
||||
|
||||
|
||||
def norm(x):
|
||||
# Purely functional rmsnorm with no learnable params
|
||||
return F.rms_norm(x, (x.size(-1),))
|
||||
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""nn.Linear that casts weights to match input dtype in forward.
|
||||
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.weight.to(dtype=x.dtype))
|
||||
|
||||
|
||||
def has_ve(layer_idx, n_layer):
|
||||
@@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module):
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
||||
self.ve_gate_channels = 32
|
||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
B, T, C = x.size()
|
||||
@@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module):
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.c_fc(x)
|
||||
@@ -164,7 +170,7 @@ class GPT(nn.Module):
|
||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||
})
|
||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
@@ -234,11 +240,13 @@ class GPT(nn.Module):
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
|
||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
|
||||
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
|
||||
# because GradScaler cannot unscale fp16 gradients.
|
||||
if COMPUTE_DTYPE != torch.float16:
|
||||
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||
for ve in self.value_embeds.values():
|
||||
ve.to(dtype=torch.bfloat16)
|
||||
ve.to(dtype=COMPUTE_DTYPE)
|
||||
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||
@@ -253,7 +261,7 @@ class GPT(nn.Module):
|
||||
# calculate the rotation frequencies at each (time, channel) pair
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
cos, sin = freqs.cos(), freqs.sin()
|
||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
||||
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
|
||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||
return cos, sin
|
||||
|
||||
@@ -391,18 +399,19 @@ class GPT(nn.Module):
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x = self.transformer.wte(idx) # embed current token
|
||||
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||
x = norm(x)
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = norm(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user