delete autocast, an unnecessary thorn in my side, manage dtypes directly

Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+21
@@ -82,6 +82,27 @@ The important thing to note is that nanochat is written and configured around on
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained so that training fits into a reasonable time interval of a few tens of minutes. You will not get strong results in this way.
## Precision / dtype
nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:
| Hardware | Default dtype | Why |
|----------|--------------|-----|
| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
| CPU / MPS | `float32` | No reduced-precision tensor cores |
You can override the default with the `NANOCHAT_DTYPE` environment variable:
```bash
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
```
How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory, except under fp16, where they stay in fp32 because GradScaler cannot unscale fp16 gradients. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision.
Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too, but RL currently does not. Inference in fp16 works fine everywhere.
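The gradient underflow mentioned in the note above can be illustrated with a small torch-free sketch. This is not code from nanochat: the helper `to_fp16` and the example gradient value are assumptions made for illustration, and the scale of 65536 is used on the understanding that it matches the default initial scale of `torch.amp.GradScaler`. Python's `struct` module is used to round values to IEEE half precision:

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision (fp16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Hypothetical tiny gradient, smaller than anything fp16 can represent.
tiny_grad = 1e-8
# Assumed loss scale for this illustration.
loss_scale = 65536.0

# Without scaling, the gradient is flushed to zero when stored in fp16.
unscaled = to_fp16(tiny_grad)

# With loss scaling, the scaled gradient lands in fp16's representable range.
scaled = to_fp16(tiny_grad * loss_scale)

# Unscaling in higher precision recovers the gradient up to fp16 rounding error.
recovered = scaled / loss_scale

print(unscaled, scaled, recovered)
```

bf16 does not need this treatment because it shares fp32's exponent range, so a value of this magnitude does not underflow there.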
## Guides
I've published a number of guides that might contain helpful information, most recent to least recent:
+35
@@ -4,6 +4,41 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
---
## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler
Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler.
### Motivation
autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.
### What changed
**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`):
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16.
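The reason for keeping fp32 master weights while running matmuls in a reduced dtype can be sketched without torch. The helpers `to_fp32`, `to_bf16`, and `apply_updates`, and the update size of 1e-4, are illustrative assumptions rather than nanochat code; bf16 rounding is emulated with bit manipulation and handles finite values only:

```python
import struct

def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round to the nearest bf16 value (round-half-to-even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits of the fp32 pattern; round the dropped half.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<I".replace("I", "f"), struct.pack("<I", bits))[0]

def apply_updates(weight, step, n_steps, store):
    """Add `step` to `weight` n_steps times, rounding with `store` after each add."""
    for _ in range(n_steps):
        weight = store(weight + step)
    return weight

# Master weight kept in fp32: the small updates accumulate.
fp32_master = apply_updates(1.0, 1e-4, 100, to_fp32)

# Weight stored in bf16: every update rounds away and the weight never moves.
bf16_master = apply_updates(1.0, 1e-4, 100, to_bf16)

# What a weight-casting forward pass would see: the fp32 master rounded to bf16.
compute_copy = to_bf16(fp32_master)

print(fp32_master, bf16_master, compute_copy)
```

Near 1.0 the spacing between adjacent bf16 values is 2^-7, so an update of 1e-4 is lost if the stored weight itself is bf16, while the fp32 master accumulates it and the cast copy eventually reflects the change.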
**Autocast removal** (11 files):
- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`.
**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`):
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()`
- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()`
- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels).
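The unscale, inf-sync, step, update sequence above can be sketched as a torch-free toy. `ToyGradScaler` and `train_step` are hypothetical names and this is not how `torch.amp.GradScaler` is implemented internally; the default values are chosen on the understanding that they match the PyTorch defaults, the all-reduce is simulated with a plain `max` over ranks, and gradient averaging across ranks with plain SGD is an assumption made to keep the example small:

```python
import math

class ToyGradScaler:
    """Toy model of dynamic loss scaling; NOT torch.amp.GradScaler itself."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale_value = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def scale(self, loss):
        return loss * self.scale_value

    def unscale(self, grads):
        """Divide out the scale and report 1.0 if any gradient is inf/NaN."""
        unscaled = [g / self.scale_value for g in grads]
        found_inf = 0.0 if all(math.isfinite(g) for g in unscaled) else 1.0
        return unscaled, found_inf

    def update(self, found_inf):
        """Back off after an overflow; grow after enough clean steps."""
        if found_inf:
            self.scale_value *= self.backoff_factor
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                self.scale_value *= self.growth_factor
                self.good_steps = 0

def train_step(param, per_rank_scaled_grads, scalers, lr=0.1):
    """One step across simulated ranks; returns (new_param, skipped)."""
    results = [s.unscale(g) for s, g in zip(scalers, per_rank_scaled_grads)]
    # Stand-in for an all-reduce with ReduceOp.MAX: every rank sees the max flag.
    synced_found_inf = max(flag for _, flag in results)
    if not synced_found_inf:
        # Assumed for illustration: average gradients across ranks, then plain SGD.
        mean_grad = sum(grads[0] for grads, _ in results) / len(results)
        param -= lr * mean_grad
    for s in scalers:
        s.update(synced_found_inf)
    return param, bool(synced_found_inf)

scalers = [ToyGradScaler(), ToyGradScaler()]
# Step 1: rank 1 overflowed, so both ranks skip the step and back off together.
param, skipped_1 = train_step(1.0, [[0.5 * 65536.0], [float("inf")]], scalers)
# Step 2: at the reduced scale both ranks are finite, so the step is applied.
param, skipped_2 = train_step(param, [[0.5 * 32768.0], [0.5 * 32768.0]], scalers)
print(param, skipped_1, skipped_2, scalers[0].scale_value)
```

Syncing the flag with a max matters because a rank that saw only finite gradients would otherwise step while an overflowing rank skipped, leaving the replicas with different weights and different scales.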
**FP8 fix** (`nanochat/fp8.py`, `base_train.py`):
- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast).
- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval.
**Flash Attention** (`flash_attention.py`):
- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (a module-level constant, resolved once at import) is False for those dtypes and attention falls back to SDPA.
---
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B
Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute**, a 27% reduction.
+20
@@ -10,6 +10,26 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from filelock import FileLock from filelock import FileLock
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
def _detect_compute_dtype():
env = os.environ.get("NANOCHAT_DTYPE")
if env is not None:
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
if torch.cuda.is_available():
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
capability = torch.cuda.get_device_capability()
if capability >= (8, 0):
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
class ColoredFormatter(logging.Formatter): class ColoredFormatter(logging.Formatter):
"""Custom formatter that adds colors to log messages.""" """Custom formatter that adds colors to log messages."""
# ANSI color codes # ANSI color codes
+9 -14
@@ -19,7 +19,6 @@ from contextlib import contextmanager
from collections import deque from collections import deque
from nanochat.common import compute_init, autodetect_device_type from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Calculator tool helpers # Calculator tool helpers
@@ -308,8 +307,6 @@ if __name__ == "__main__":
# init compute # init compute
device_type = autodetect_device_type() device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer # load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval") model, tokenizer, meta = load_model("base", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id() bos_token_id = tokenizer.get_bos_token_id()
@@ -322,11 +319,10 @@ if __name__ == "__main__":
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs) stream = model.generate(prompt_tokens, **kwargs)
with autocast_ctx: for token in stream:
for token in stream: generated_tokens.append(token)
generated_tokens.append(token) chunk = tokenizer.decode([token])
chunk = tokenizer.decode([token]) print(chunk, end="", flush=True)
print(chunk, end="", flush=True)
print() print()
torch.cuda.synchronize() torch.cuda.synchronize()
t1 = time.time() t1 = time.time()
@@ -338,12 +334,11 @@ if __name__ == "__main__":
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
with autocast_ctx: for token_column, token_masks in stream:
for token_column, token_masks in stream: token = token_column[0] # only print out the first row
token = token_column[0] # only print out the first row generated_tokens.append(token)
generated_tokens.append(token) chunk = tokenizer.decode([token])
chunk = tokenizer.decode([token]) print(chunk, end="", flush=True)
print(chunk, end="", flush=True)
print() print()
torch.cuda.synchronize() torch.cuda.synchronize()
t1 = time.time() t1 = time.time()
+13 -5
@@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
_override_impl = None _override_impl = None
def _use_fa3(): def _resolve_use_fa3():
"""Determine whether to use FA3 based on availability and override.""" """Decide once whether to use FA3, based on availability, override, and dtype."""
if _override_impl == 'fa3': if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware" assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True return True
if _override_impl == 'sdpa': if _override_impl == 'sdpa':
return False return False
return HAS_FA3 # auto if HAS_FA3:
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
from nanochat.common import COMPUTE_DTYPE
if COMPUTE_DTYPE == torch.bfloat16:
return True
return False
return False
USE_FA3 = _resolve_use_fa3()
# ============================================================================= # =============================================================================
@@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
Returns: Returns:
Output tensor of shape (B, T, H, D) Output tensor of shape (B, T, H, D)
""" """
if _use_fa3(): if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
@@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
Returns: Returns:
Output tensor of shape (B, T_new, H, D) Output tensor of shape (B, T_new, H, D)
""" """
if _use_fa3(): if USE_FA3:
return _fa3.flash_attn_with_kvcache( return _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size causal=causal, window_size=window_size
+5 -5
@@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode.
import torch import torch
import torch.nn as nn import torch.nn as nn
from nanochat.common import COMPUTE_DTYPE
# Avoid division by zero when computing scale from an all-zeros tensor # Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12 EPS = 1e-12
@@ -198,11 +200,9 @@ class Float8Linear(nn.Linear):
""" """
def forward(self, input): def forward(self, input):
# Replicate the autocast behavior of F.linear — when autocast is active, # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
# we need to manually cast input to the autocast dtype (e.g. bf16), # reduced precision input, and we no longer rely on autocast to do this.
# since we bypass F.linear's built-in autocast handling. input = input.to(COMPUTE_DTYPE)
if torch.is_autocast_enabled():
input = input.to(torch.get_autocast_gpu_dtype())
# _scaled_mm only works on 2D tensors, so flatten batch dimensions # _scaled_mm only works on 2D tensors, so flatten batch dimensions
orig_shape = input.shape orig_shape = input.shape
input_2d = input.reshape(-1, orig_shape[-1]) input_2d = input.reshape(-1, orig_shape[-1])
+27 -18
@@ -19,7 +19,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nanochat.common import get_dist_info, print0 from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
from nanochat.optim import MuonAdamW, DistMuonAdamW from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
@@ -40,8 +40,14 @@ class GPTConfig:
def norm(x): def norm(x):
# Purely functional rmsnorm with no learnable params return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
return F.rms_norm(x, (x.size(-1),))
class Linear(nn.Linear):
"""nn.Linear that casts weights to match input dtype in forward.
Replaces autocast: master weights stay fp32 for optimizer precision,
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
def forward(self, x):
return F.linear(x, self.weight.to(dtype=x.dtype))
def has_ve(layer_idx, n_layer): def has_ve(layer_idx, n_layer):
@@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module):
self.head_dim = self.n_embd // self.n_head self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0 assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32 self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, ve, cos_sin, window_size, kv_cache): def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size() B, T, C = x.size()
@@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module):
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x): def forward(self, x):
x = self.c_fc(x) x = self.c_fc(x)
@@ -164,7 +170,7 @@ class GPT(nn.Module):
"wte": nn.Embedding(padded_vocab_size, config.n_embd), "wte": nn.Embedding(padded_vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
}) })
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
# Per-layer learnable scalars (inspired by modded-nanogpt) # Per-layer learnable scalars (inspired by modded-nanogpt)
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
@@ -234,11 +240,13 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin self.cos, self.sin = cos, sin
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
if self.transformer.wte.weight.device.type == "cuda": # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
self.transformer.wte.to(dtype=torch.bfloat16) # because GradScaler cannot unscale fp16 gradients.
if COMPUTE_DTYPE != torch.float16:
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
for ve in self.value_embeds.values(): for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16) ve.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently # TODO: bump base theta more? e.g. 100K is more common more recently
@@ -253,7 +261,7 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair # calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq) freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin() cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin return cos, sin
@@ -391,18 +399,19 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos() T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
# Forward the trunk of the Transformer # Forward the trunk of the Transformer
x = self.transformer.wte(idx) # embed current token x = self.transformer.wte(idx) # embed current token
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
x = norm(x) x = norm(x)
x0 = x # save initial normalized embedding for x0 residual x0 = x # save initial normalized embedding for x0 residual
for i, block in enumerate(self.transformer.h): for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = norm(x) x = norm(x)
+4 -12
@@ -29,8 +29,6 @@ import random
import zipfile import zipfile
import tempfile import tempfile
import argparse import argparse
from contextlib import nullcontext
import torch import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
@@ -199,8 +197,6 @@ def main():
# Distributed / precision setup # Distributed / precision setup
device_type = autodetect_device_type() if args.device_type == '' else args.device_type device_type = autodetect_device_type() if args.device_type == '' else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Load model and tokenizer # Load model and tokenizer
is_hf_model = args.hf_path is not None is_hf_model = args.hf_path is not None
if is_hf_model: if is_hf_model:
@@ -244,8 +240,7 @@ def main():
print0("\nConditioned samples:") print0("\nConditioned samples:")
for prompt in prompts: for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>") tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx: sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0]) sample_str = tokenizer.decode(sample[0])
print0("-" * 80) print0("-" * 80)
print0(sample_str) print0(sample_str)
@@ -253,8 +248,7 @@ def main():
print0("\nUnconditioned samples:") print0("\nUnconditioned samples:")
tokens = tokenizer("", prepend="<|bos|>") tokens = tokenizer("", prepend="<|bos|>")
with autocast_ctx: uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
for sample in uncond: for sample in uncond:
sample_str = tokenizer.decode(sample) sample_str = tokenizer.decode(sample)
print0("-" * 80) print0("-" * 80)
@@ -277,8 +271,7 @@ def main():
for split_name in ["train", "val"]: for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx: bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb_results[split_name] = bpb bpb_results[split_name] = bpb
print0(f"{split_name} bpb: {bpb:.6f}") print0(f"{split_name} bpb: {bpb:.6f}")
@@ -287,8 +280,7 @@ def main():
print0("\n" + "="*80) print0("\n" + "="*80)
print0("CORE Evaluation") print0("CORE Evaluation")
print0("="*80) print0("="*80)
with autocast_ctx: core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
# Write CSV output # Write CSV output
if ddp_rank == 0: if ddp_rank == 0:
+40 -15
@@ -19,14 +19,15 @@ import time
import math import math
import argparse import argparse
from dataclasses import asdict from dataclasses import asdict
from contextlib import nullcontext, contextmanager from contextlib import contextmanager
import wandb import wandb
import torch import torch
import torch.distributed as dist
from nanochat.gpt import GPT, GPTConfig from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb from nanochat.loss_eval import evaluate_bpb
@@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda": if device_type == "cuda":
@@ -95,17 +95,23 @@ if device_type == "cuda":
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else: else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
# wandb logging init # wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status # Flash Attention status
if HAS_FA3: from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else: else:
print0("!" * 80) print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
else:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3") print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L": if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
@@ -213,9 +219,9 @@ def disable_fp8(model):
yield # No FP8 modules, nothing to do yield # No FP8 modules, nothing to do
return return
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy) # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
for parent, attr_name, fp8_module in fp8_locations: for parent, attr_name, fp8_module in fp8_locations:
linear = nn.Linear( linear = Linear(
fp8_module.in_features, fp8_module.in_features,
fp8_module.out_features, fp8_module.out_features,
bias=fp8_module.bias is not None, bias=fp8_module.bias is not None,
@@ -315,6 +321,12 @@ if resuming:
optimizer.load_state_dict(optimizer_data) optimizer.load_state_dict(optimizer_data)
del optimizer_data del optimizer_data
# -----------------------------------------------------------------------------
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
print0("GradScaler enabled for fp16 training")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val # Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
@@ -405,7 +417,7 @@ while True:
model.eval() model.eval()
val_loader = build_val_loader() val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with disable_fp8(model), autocast_ctx: with disable_fp8(model):
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb: if val_bpb < min_val_bpb:
@@ -424,7 +436,7 @@ while True:
results = {} results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval() model.eval()
with disable_fp8(orig_model), autocast_ctx: with disable_fp8(orig_model):
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({ wandb_run.log({
@@ -451,7 +463,7 @@ while True:
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts: for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>") tokens = tokenizer(prompt, prepend="<|bos|>")
with disable_fp8(orig_model), autocast_ctx: with disable_fp8(orig_model):
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0])) print0(tokenizer.decode(sample[0]))
model.train() model.train()
@@ -491,11 +503,13 @@ while True:
synchronize() synchronize()
t0 = time.time() t0 = time.time()
for micro_step in range(grad_accum_steps): for micro_step in range(grad_accum_steps):
with autocast_ctx: loss = model(x, y)
loss = model(x, y)
train_loss = loss.detach() # for logging train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer # step the optimizer
lrm = get_lr_multiplier(step) lrm = get_lr_multiplier(step)
@@ -506,7 +520,18 @@ while True:
if group['kind'] == 'muon': if group['kind'] == 'muon':
group["momentum"] = muon_momentum group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay group["weight_decay"] = muon_weight_decay
optimizer.step() if scaler is not None:
scaler.unscale_(optimizer)
# In distributed training, all ranks must agree on whether to skip the step.
# Each rank may independently encounter inf/nan gradients, so we all-reduce
# the found_inf flag (MAX = if any rank found inf, all ranks skip).
if is_ddp_initialized():
for v in scaler._found_inf_per_device(optimizer).values():
dist.all_reduce(v, op=dist.ReduceOp.MAX)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
synchronize() synchronize()
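
A note on why the fp16 path above needs a `GradScaler` while bf16 and fp32 do not: fp16 has only 5 exponent bits, so small gradient values round to zero unless the loss is scaled up before the backward pass. The following is a minimal sketch using only Python's standard library (the `struct` format `"e"` is IEEE 754 half precision). The helper name `to_fp16` and the scale value 65536 are illustrative assumptions and are not part of nanochat.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (fp16)."""
    return struct.unpack("e", struct.pack("e", x))[0]

tiny_grad = 1e-8   # smaller than fp16's smallest subnormal (about 6e-8)
scale = 65536.0    # illustrative loss scale (2**16), an assumption for this sketch

unscaled = to_fp16(tiny_grad)        # underflows to zero when stored as fp16
scaled = to_fp16(tiny_grad * scale)  # representable because it was scaled into range
recovered = scaled / scale           # unscale again in higher precision

print(unscaled)
print(recovered)
```

The same idea is what `scaler.scale(loss).backward()` followed by `scaler.unscale_(optimizer)` achieves in the diff: gradients are computed at a magnified scale and divided back down before the optimizer step.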
+5 -10
@@ -7,7 +7,6 @@ python -m scripts.chat_cli
import argparse import argparse
import torch import torch
from nanochat.common import compute_init, autodetect_device_type from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
@@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args() args = parser.parse_args()
# Init the model and tokenizer # Init the model and tokenizer
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine # Special tokens for the chat state machine
@@ -87,12 +83,11 @@ while True:
} }
response_tokens = [] response_tokens = []
print("\nAssistant: ", end="", flush=True) print("\nAssistant: ", end="", flush=True)
with autocast_ctx: for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs): token = token_column[0] # pop the batch dimension (num_samples=1)
token = token_column[0] # pop the batch dimension (num_samples=1) response_tokens.append(token)
response_tokens.append(token) token_text = tokenizer.decode([token])
token_text = tokenizer.decode([token]) print(token_text, end="", flush=True)
print(token_text, end="", flush=True)
print() print()
# we have to ensure that the assistant end token is the last token # we have to ensure that the assistant end token is the last token
# so even if generation ends due to max tokens, we have to append it to the end # so even if generation ends due to max tokens, we have to append it to the end
+12 -18
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
import argparse import argparse
from functools import partial from functools import partial
from contextlib import nullcontext
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -185,7 +183,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl") parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.") parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0) parser.add_argument('-t', '--temperature', type=float, default=0.0)
parser.add_argument('-m', '--max-new-tokens', type=int, default=512) parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
parser.add_argument('-n', '--num-samples', type=int, default=1) parser.add_argument('-n', '--num-samples', type=int, default=1)
@@ -199,8 +196,6 @@ if __name__ == "__main__":
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
@@ -220,19 +215,18 @@ if __name__ == "__main__":
# Run all the task evaluations sequentially # Run all the task evaluations sequentially
results = {} results = {}
for task_name in task_names: for task_name in task_names:
with autocast_ctx: acc = run_chat_eval(
acc = run_chat_eval( task_name,
task_name, model, tokenizer, engine,
model, tokenizer, engine, batch_size=args.batch_size,
batch_size=args.batch_size, num_samples=args.num_samples,
num_samples=args.num_samples, max_new_tokens=args.max_new_tokens,
max_new_tokens=args.max_new_tokens, temperature=args.temperature,
temperature=args.temperature, top_k=args.top_k,
top_k=args.top_k, max_problems=args.max_problems,
max_problems=args.max_problems, )
) results[task_name] = acc
results[task_name] = acc print0(f"{task_name} accuracy: {100 * acc:.2f}%")
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
# Log to report # Log to report
from nanochat.report import get_report from nanochat.report import get_report
+11 -19
@@ -22,8 +22,6 @@ import itertools
import wandb import wandb
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine from nanochat.engine import Engine
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime # Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading # Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from") parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from") parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init # wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process use_dummy_wandb = args.run == "dummy" or not master_process
@@ -108,15 +103,14 @@ def get_batch():
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps): for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32 seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx: generated_token_sequences_batch, masks_batch = engine.generate_batch(
generated_token_sequences_batch, masks_batch = engine.generate_batch( tokens,
tokens, num_samples=args.device_batch_size,
num_samples=args.device_batch_size, max_tokens=args.max_new_tokens,
max_tokens=args.max_new_tokens, temperature=args.temperature,
temperature=args.temperature, top_k=args.top_k,
top_k=args.top_k, seed=seed, # must make sure to change the seed for each sampling step
seed=seed, # must make sure to change the seed for each sampling step )
)
generated_token_sequences.extend(generated_token_sequences_batch) generated_token_sequences.extend(generated_token_sequences_batch)
masks.extend(masks_batch) masks.extend(masks_batch)
@@ -231,9 +225,8 @@ for step in range(num_steps):
if step % args.eval_every == 0: if step % args.eval_every == 0:
model.eval() model.eval()
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx: records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0) records = list(records_iter) # collect all records
records = list(records_iter) # collect all records
for k in range(1, args.device_batch_size + 1): for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records) passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device) num_records = torch.tensor(len(records), dtype=torch.long, device=device)
@@ -268,8 +261,7 @@ for step in range(num_steps):
rewards = rewards_all[b0:b1] rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1] advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx: logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum() pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank # normalize by the number of valid tokens, number of passes, and examples_per_rank
+24 -12
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time import time
import wandb import wandb
import torch import torch
from contextlib import nullcontext from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_token_bytes from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb from nanochat.loss_eval import evaluate_bpb
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda": if device_type == "cuda":
@@ -151,6 +150,11 @@ if args.load_optimizer:
else: else:
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)") print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
# GradScaler for fp16 training (bf16/fp32 don't need it)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
print0("GradScaler enabled for fp16 training")
# Override the initial learning rate as a fraction of the base learning rate # Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups: for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac group["lr"] = group["lr"] * args.init_lr_frac
@@ -344,8 +348,7 @@ while True:
model.eval() model.eval()
val_loader = build_val_loader() val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx: val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}") print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb: if val_bpb < min_val_bpb:
min_val_bpb = val_bpb min_val_bpb = val_bpb
@@ -373,9 +376,8 @@ while True:
for task_name in all_tasks: for task_name in all_tasks:
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
max_problems = None if limit < 0 else limit # -1 means no limit max_problems = None if limit < 0 else limit # -1 means no limit
with autocast_ctx: acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
acc = run_chat_eval(task_name, orig_model, tokenizer, engine, batch_size=args.device_batch_size, max_problems=max_problems)
batch_size=args.device_batch_size, max_problems=max_problems)
task_results[task_name] = acc task_results[task_name] = acc
print0(f" {task_name}: {100*acc:.2f}%") print0(f" {task_name}: {100*acc:.2f}%")
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect) # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
@@ -428,11 +430,13 @@ while True:
synchronize() synchronize()
t0 = time.time() t0 = time.time()
for micro_step in range(grad_accum_steps): for micro_step in range(grad_accum_steps):
with autocast_ctx: loss = model(x, y)
loss = model(x, y)
train_loss = loss.detach() # for logging train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward() if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizer # step the optimizer
@@ -442,7 +446,15 @@ while True:
group["lr"] = group["initial_lr"] * lrm group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon': if group['kind'] == 'muon':
group["momentum"] = muon_momentum group["momentum"] = muon_momentum
optimizer.step() if scaler is not None:
scaler.unscale_(optimizer)
if is_ddp_initialized():
for v in scaler._found_inf_per_device(optimizer).values():
dist.all_reduce(v, op=dist.ReduceOp.MAX)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
synchronize() synchronize()
t1 = time.time() t1 = time.time()
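
The scaler branch in the hunk above relies on every rank making the same skip-or-step decision, which is why the `found_inf` flags are all-reduced with `MAX` before `scaler.step`. Below is a toy, single-process sketch of that logic. `ToyScaler`, `all_reduce_max`, `found_inf`, and all constants are hypothetical stand-ins for illustration; they do not reproduce the internals of `torch.amp.GradScaler` or `torch.distributed`.

```python
import math

def all_reduce_max(flags):
    """Simulate an all-reduce with MAX: every rank ends up holding the largest flag."""
    m = max(flags)
    return [m for _ in flags]

def found_inf(grads):
    """Return 1.0 if any gradient is inf/nan, else 0.0 (mirrors a per-rank flag)."""
    return 1.0 if any(not math.isfinite(g) for g in grads) else 0.0

class ToyScaler:
    """Hypothetical dynamic loss scaler: back off on inf/nan, grow after clean steps."""
    def __init__(self, scale=65536.0, growth=2.0, backoff=0.5, growth_interval=3):
        self.scale = scale
        self.growth = growth
        self.backoff = backoff
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def update(self, skipped):
        if skipped:
            self.scale *= self.backoff  # overflow seen: shrink the scale
            self.clean_steps = 0
        else:
            self.clean_steps += 1
            if self.clean_steps == self.growth_interval:
                self.scale *= self.growth  # stable for a while: try a larger scale
                self.clean_steps = 0

# Two simulated ranks; only rank 1 sees an overflowed gradient.
per_rank_grads = [[0.1, -0.2], [0.3, float("inf")]]
local_flags = [found_inf(g) for g in per_rank_grads]
global_flags = all_reduce_max(local_flags)

# After the reduction every rank reads the same flag, so all skip together.
skip_step = global_flags[0] > 0
scaler = ToyScaler()
scaler.update(skip_step)
print(local_flags, global_flags, skip_step, scaler.scale)
```

Without the reduction, rank 0 would apply its update while rank 1 skipped, and the replicas would drift apart. Note that the diff reaches into the private `scaler._found_inf_per_device`, so that call may change between PyTorch releases.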
+25 -33
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine from nanochat.engine import Engine
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load') parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load') parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on') parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args() args = parser.parse_args()
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass @dataclass
class Worker: class Worker:
@@ -93,7 +90,6 @@ class Worker:
device: torch.device device: torch.device
engine: Engine engine: Engine
tokenizer: object tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool: class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU.""" """Pool of workers, each with a model replica on a different GPU."""
@@ -125,14 +121,11 @@ class WorkerPool:
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step) model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer) engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker( worker = Worker(
gpu_id=gpu_id, gpu_id=gpu_id,
device=device, device=device,
engine=engine, engine=engine,
tokenizer=tokenizer, tokenizer=tokenizer,
autocast_ctx=autocast_ctx
) )
self.workers.append(worker) self.workers.append(worker)
await self.available_workers.put(worker) await self.available_workers.put(worker)
@@ -279,34 +272,33 @@ async def generate_stream(
# Track the last complete UTF-8 string (without replacement characters) # Track the last complete UTF-8 string (without replacement characters)
last_clean_text = "" last_clean_text = ""
with worker.autocast_ctx: for token_column, token_masks in worker.engine.generate(
for token_column, token_masks in worker.engine.generate( tokens,
tokens, num_samples=1,
num_samples=1, max_tokens=max_new_tokens,
max_tokens=max_new_tokens, temperature=temperature,
temperature=temperature, top_k=top_k,
top_k=top_k, seed=random.randint(0, 2**31 - 1)
seed=random.randint(0, 2**31 - 1) ):
): token = token_column[0]
token = token_column[0]
# Stopping criteria # Stopping criteria
if token == assistant_end or token == bos: if token == assistant_end or token == bos:
break break
# Append the token to sequence # Append the token to sequence
accumulated_tokens.append(token) accumulated_tokens.append(token)
# Decode all accumulated tokens to get proper UTF-8 handling # Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is a quite efficient operation, basically table lookup and string concat # Note that decode is a quite efficient operation, basically table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens) current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character # Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences # This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith(''): if not current_text.endswith(''):
# Extract only the new text since last clean decode # Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):] new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n" yield f"data: {json.dumps({'done': True})}\n\n"
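
The streaming loop above re-decodes all accumulated tokens on every step and withholds output while the decoded text ends in the Unicode replacement character (U+FFFD), which signals an incomplete UTF-8 sequence. A minimal sketch of the same pattern follows, with raw byte chunks standing in for tokens. The function name `stream_text` and the chunk boundaries are assumptions made for illustration, not nanochat code.

```python
def stream_text(byte_chunks):
    """Yield only newly completed text, holding back incomplete UTF-8 sequences."""
    accumulated = b""
    last_clean_text = ""
    for chunk in byte_chunks:
        accumulated += chunk
        # Decode everything so far; a truncated multi-byte character becomes U+FFFD.
        current_text = accumulated.decode("utf-8", errors="replace")
        if not current_text.endswith("\ufffd"):
            new_text = current_text[len(last_clean_text):]
            if new_text:
                yield new_text
            last_clean_text = current_text

data = "hi \U0001F44B".encode("utf-8")   # 3 ASCII bytes followed by a 4-byte emoji
chunks = [data[:3], data[3:5], data[5:]]  # split the emoji across two chunks
pieces = list(stream_text(chunks))
print(pieces)
```

The middle chunk produces no output because the emoji is only half received; the full character is emitted once its remaining bytes arrive.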
+5 -4
@@ -21,8 +21,9 @@ from nanochat.engine import KVCache
def set_impl(impl): def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto).""" """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
fa_module._override_impl = impl fa_module._override_impl = impl
fa_module.USE_FA3 = fa_module._resolve_use_fa3()
def run_both_impls(fn): def run_both_impls(fn):
@@ -343,19 +344,19 @@ class TestOverrideMechanism:
def test_override_fa3(self): def test_override_fa3(self):
"""Test that override='fa3' uses FA3.""" """Test that override='fa3' uses FA3."""
set_impl('fa3') set_impl('fa3')
assert fa_module._use_fa3() == True assert fa_module.USE_FA3 == True
set_impl(None) set_impl(None)
def test_override_sdpa(self): def test_override_sdpa(self):
"""Test that override='sdpa' uses SDPA.""" """Test that override='sdpa' uses SDPA."""
set_impl('sdpa') set_impl('sdpa')
assert fa_module._use_fa3() == False assert fa_module.USE_FA3 == False
set_impl(None) set_impl(None)
def test_override_auto(self): def test_override_auto(self):
"""Test that override=None uses auto-detection.""" """Test that override=None uses auto-detection."""
set_impl(None) set_impl(None)
assert fa_module._use_fa3() == HAS_FA3 assert fa_module.USE_FA3 == HAS_FA3
if __name__ == "__main__": if __name__ == "__main__":