diff --git a/README.md b/README.md index 05c7942..077fd9c 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,27 @@ The important thing to note is that nanochat is written and configured around on The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way. +## Precision / dtype + +nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware: + +| Hardware | Default dtype | Why | +|----------|--------------|-----| +| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores | +| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) | +| CPU / MPS | `float32` | No reduced-precision tensor cores | + +You can override the default with the `NANOCHAT_DTYPE` environment variable: + +```bash +NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32 +NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16 +``` + +How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision. + +Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too, but RL currently does not. Inference in fp16 works fine everywhere. 
+ ## Guides I've published a number of guides that might contain helpful information, most recent to least recent: diff --git a/dev/LOG.md b/dev/LOG.md index b4b3757..fd5c3c7 100644 --- a/dev/LOG.md +++ b/dev/LOG.md @@ -4,6 +4,41 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026 --- +## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler + +Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler. + +### Motivation + +autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction. + +### What changed + +**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`): +- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var. +- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast. +- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients. +- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path). +- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16. 
+ +**Autocast removal** (11 files): +- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`. + +**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`): +- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None` +- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()` +- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()` +- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels). + +**FP8 fix** (`nanochat/fp8.py`, `base_train.py`): +- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast). +- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval. + +**Flash Attention** (`flash_attention.py`): +- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (module-level constant, resolved once at import) returns False, falling back to SDPA. + +--- + ## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction. diff --git a/nanochat/common.py b/nanochat/common.py index 2dd0792..bd14fd2 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -10,6 +10,26 @@ import torch import torch.distributed as dist from filelock import FileLock +# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision. 
+# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast. +# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32" +_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} +def _detect_compute_dtype(): + env = os.environ.get("NANOCHAT_DTYPE") + if env is not None: + return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}" + if torch.cuda.is_available(): + # bf16 requires SM 80+ (Ampere: A100, A10, etc.) + # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores + capability = torch.cuda.get_device_capability() + if capability >= (8, 0): + return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)" + # fp16 training requires a GradScaler (enabled automatically in base_train.py and chat_sft.py), so default to fp32. + # Users can still opt in to fp16 via NANOCHAT_DTYPE=float16. + return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)" + return torch.float32, "auto-detected: no CUDA (CPU/MPS)" +COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype() + class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" # ANSI color codes diff --git a/nanochat/engine.py b/nanochat/engine.py index a1ba24c..4724c8f 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -19,7 +19,6 @@ from contextlib import contextmanager from collections import deque from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model -from contextlib import nullcontext # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -308,8 +307,6 @@ if __name__ == "__main__": # init compute device_type = autodetect_device_type() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - autocast_ctx = 
torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() - # load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="eval") bos_token_id = tokenizer.get_bos_token_id() @@ -322,11 +319,10 @@ if __name__ == "__main__": torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) - with autocast_ctx: - for token in stream: - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + for token in stream: + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() @@ -338,12 +334,11 @@ if __name__ == "__main__": stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 torch.cuda.synchronize() t0 = time.time() - with autocast_ctx: - for token_column, token_masks in stream: - token = token_column[0] # only print out the first row - generated_tokens.append(token) - chunk = tokenizer.decode([token]) - print(chunk, end="", flush=True) + for token_column, token_masks in stream: + token = token_column[0] # only print out the first row + generated_tokens.append(token) + chunk = tokenizer.decode([token]) + print(chunk, end="", flush=True) print() torch.cuda.synchronize() t1 = time.time() diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index 89ca42b..af2aee3 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None _override_impl = None -def _use_fa3(): - """Determine whether to use FA3 based on availability and override.""" +def _resolve_use_fa3(): + """Decide once whether to use FA3, based on availability, override, and dtype.""" if _override_impl == 'fa3': assert HAS_FA3, "Cannot override to FA3: not available on this hardware" return True if _override_impl == 'sdpa': return False - return HAS_FA3 # auto + if 
HAS_FA3: + # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback + from nanochat.common import COMPUTE_DTYPE + if COMPUTE_DTYPE == torch.bfloat16: + return True + return False + return False + +USE_FA3 = _resolve_use_fa3() # ============================================================================= @@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): # sliding window (left) if window >= 0 and window < Tk: mask = mask & ((row_idx - col_idx) <= window) - + return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) # ============================================================================= @@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): Returns: Output tensor of shape (B, T, H, D) """ - if _use_fa3(): + if USE_FA3: return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) @@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Returns: Output tensor of shape (B, T_new, H, D) """ - if _use_fa3(): + if USE_FA3: return _fa3.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 3e88285..f9bf8d5 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode. import torch import torch.nn as nn +from nanochat.common import COMPUTE_DTYPE + # Avoid division by zero when computing scale from an all-zeros tensor EPS = 1e-12 @@ -123,7 +125,7 @@ def _to_col_major(x): class _Float8Matmul(torch.autograd.Function): """Custom autograd for the three FP8 GEMMs of a Linear layer. - The forward quantizes input and weight to FP8 and saves + The forward quantizes input and weight to FP8 and saves the quantized tensors + scales for backward. 
""" @@ -198,11 +200,9 @@ class Float8Linear(nn.Linear): """ def forward(self, input): - # Replicate the autocast behavior of F.linear — when autocast is active, - # we need to manually cast input to the autocast dtype (e.g. bf16), - # since we bypass F.linear's built-in autocast handling. - if torch.is_autocast_enabled(): - input = input.to(torch.get_autocast_gpu_dtype()) + # Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects + # reduced precision input, and we no longer rely on autocast to do this. + input = input.to(COMPUTE_DTYPE) # _scaled_mm only works on 2D tensors, so flatten batch dimensions orig_shape = input.shape input_2d = input.reshape(-1, orig_shape[-1]) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 74e39fd..04ee5c5 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from nanochat.common import get_dist_info, print0 +from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE from nanochat.optim import MuonAdamW, DistMuonAdamW # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere @@ -40,8 +40,14 @@ class GPTConfig: def norm(x): - # Purely functional rmsnorm with no learnable params - return F.rms_norm(x, (x.size(-1),)) + return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok + +class Linear(nn.Linear): + """nn.Linear that casts weights to match input dtype in forward. 
+ Replaces autocast: master weights stay fp32 for optimizer precision, + but matmuls run in the activation dtype (typically bf16 from embeddings).""" + def forward(self, x): + return F.linear(x, self.weight.to(dtype=x.dtype)) def has_ve(layer_idx, n_layer): @@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module): self.head_dim = self.n_embd // self.n_head assert self.n_embd % self.n_head == 0 assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 - self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) - self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) - self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) - self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False) + self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) + self.c_proj = Linear(self.n_embd, self.n_embd, bias=False) self.ve_gate_channels = 32 - self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None + self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() @@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module): class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False) + self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False) def forward(self, x): x = self.c_fc(x) @@ -164,7 +170,7 @@ class GPT(nn.Module): "wte": nn.Embedding(padded_vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) 
for layer_idx in range(config.n_layer)]), }) - self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False) + self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False) # Per-layer learnable scalars (inspired by modded-nanogpt) # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral) # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled) @@ -234,11 +240,13 @@ class GPT(nn.Module): cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast embeddings to bf16: optimizer can tolerate it and it saves memory - if self.transformer.wte.weight.device.type == "cuda": - self.transformer.wte.to(dtype=torch.bfloat16) + # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision + # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings + # because GradScaler cannot unscale fp16 gradients. + if COMPUTE_DTYPE != torch.float16: + self.transformer.wte.to(dtype=COMPUTE_DTYPE) for ve in self.value_embeds.values(): - ve.to(dtype=torch.bfloat16) + ve.to(dtype=COMPUTE_DTYPE) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 
100K is more common more recently @@ -253,7 +261,7 @@ class GPT(nn.Module): # calculate the rotation frequencies at each (time, channel) pair freqs = torch.outer(t, inv_freq) cos, sin = freqs.cos(), freqs.sin() - cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16 + cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE) cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting return cos, sin @@ -391,18 +399,19 @@ class GPT(nn.Module): # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" + assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}" # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache T0 = 0 if kv_cache is None else kv_cache.get_pos() cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length # Forward the trunk of the Transformer x = self.transformer.wte(idx) # embed current token + x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path) x = norm(x) x0 = x # save initial normalized embedding for x0 residual for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None + ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = norm(x) diff --git a/scripts/base_eval.py b/scripts/base_eval.py index e45ae43..a57bbaf 
100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -29,8 +29,6 @@ import random import zipfile import tempfile import argparse -from contextlib import nullcontext - import torch from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock @@ -199,8 +197,6 @@ def main(): # Distributed / precision setup device_type = autodetect_device_type() if args.device_type == '' else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) - autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() - # Load model and tokenizer is_hf_model = args.hf_path is not None if is_hf_model: @@ -244,8 +240,7 @@ def main(): print0("\nConditioned samples:") for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") - with autocast_ctx: - sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) + sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) sample_str = tokenizer.decode(sample[0]) print0("-" * 80) print0(sample_str) @@ -253,8 +248,7 @@ def main(): print0("\nUnconditioned samples:") tokens = tokenizer("", prepend="<|bos|>") - with autocast_ctx: - uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0) + uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0) for sample in uncond: sample_str = tokenizer.decode(sample) print0("-" * 80) @@ -277,8 +271,7 @@ def main(): for split_name in ["train", "val"]: loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) - with autocast_ctx: - bpb = evaluate_bpb(model, loader, steps, token_bytes) + bpb = evaluate_bpb(model, loader, steps, token_bytes) bpb_results[split_name] = bpb print0(f"{split_name} bpb: {bpb:.6f}") @@ -287,8 +280,7 @@ def main(): 
print0("\n" + "="*80) print0("CORE Evaluation") print0("="*80) - with autocast_ctx: - core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) + core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task) # Write CSV output if ddp_rank == 0: diff --git a/scripts/base_train.py b/scripts/base_train.py index 9461e88..4bf7959 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -19,14 +19,15 @@ import time import math import argparse from dataclasses import asdict -from contextlib import nullcontext, contextmanager +from contextlib import contextmanager import wandb import torch +import torch.distributed as dist -from nanochat.gpt import GPT, GPTConfig +from nanochat.gpt import GPT, GPTConfig, Linear from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops +from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb @@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext() synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 if device_type == "cuda": @@ -95,17 +95,23 @@ if device_type == "cuda": print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") else: gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS +print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") # wandb logging init use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) # Flash Attention status -if HAS_FA3: +from nanochat.flash_attention import USE_FA3 +using_fa3 = USE_FA3 +if using_fa3: print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") else: print0("!" * 80) - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") + if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16: + print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback") + else: + print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") print0("WARNING: Training will be less efficient without FA3") if args.window_pattern != "L": print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). 
Your GPU utilization will be terrible.") @@ -213,9 +219,9 @@ def disable_fp8(model): yield # No FP8 modules, nothing to do return - # Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy) + # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype) for parent, attr_name, fp8_module in fp8_locations: - linear = nn.Linear( + linear = Linear( fp8_module.in_features, fp8_module.out_features, bias=fp8_module.bias is not None, @@ -315,6 +321,12 @@ if resuming: optimizer.load_state_dict(optimizer_data) del optimizer_data +# ----------------------------------------------------------------------------- +# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32) +scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None +if scaler is not None: + print0("GradScaler enabled for fp16 training") + # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] @@ -405,7 +417,7 @@ while True: model.eval() val_loader = build_val_loader() eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) - with disable_fp8(model), autocast_ctx: + with disable_fp8(model): val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") if val_bpb < min_val_bpb: @@ -424,7 +436,7 @@ while True: results = {} if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): model.eval() - with disable_fp8(orig_model), autocast_ctx: + with disable_fp8(orig_model): results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ @@ -451,7 +463,7 @@ while True: engine = Engine(orig_model, 
tokenizer) # use orig_model to avoid recompilation for prompt in prompts: tokens = tokenizer(prompt, prepend="<|bos|>") - with disable_fp8(orig_model), autocast_ctx: + with disable_fp8(orig_model): sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) print0(tokenizer.decode(sample[0])) model.train() @@ -491,11 +503,13 @@ while True: synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): - with autocast_ctx: - loss = model(x, y) + loss = model(x, y) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here - loss.backward() + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer lrm = get_lr_multiplier(step) @@ -506,7 +520,18 @@ while True: if group['kind'] == 'muon': group["momentum"] = muon_momentum group["weight_decay"] = muon_weight_decay - optimizer.step() + if scaler is not None: + scaler.unscale_(optimizer) + # In distributed training, all ranks must agree on whether to skip the step. + # Each rank may independently encounter inf/nan gradients, so we all-reduce + # the found_inf flag (MAX = if any rank found inf, all ranks skip). 
+ if is_ddp_initialized(): + for v in scaler._found_inf_per_device(optimizer).values(): + dist.all_reduce(v, op=dist.ReduceOp.MAX) + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() model.zero_grad(set_to_none=True) train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point synchronize() diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index 7de7e10..2bcc8aa 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -7,7 +7,6 @@ python -m scripts.chat_cli import argparse import torch from nanochat.common import compute_init, autodetect_device_type -from contextlib import nullcontext from nanochat.engine import Engine from nanochat.checkpoint_manager import load_model @@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation') parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter') parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. 
empty => autodetect')
-parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 args = parser.parse_args()
 
 # Init the model and tokenizer
 
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
@@ -87,12 +83,11 @@ while True:
     }
     response_tokens = []
     print("\nAssistant: ", end="", flush=True)
-    with autocast_ctx:
-        for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
-            token = token_column[0] # pop the batch dimension (num_samples=1)
-            response_tokens.append(token)
-            token_text = tokenizer.decode([token])
-            print(token_text, end="", flush=True)
+    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
+        token = token_column[0] # pop the batch dimension (num_samples=1)
+        response_tokens.append(token)
+        token_text = tokenizer.decode([token])
+        print(token_text, end="", flush=True)
     print()
     # we have to ensure that the assistant end token is the last token
     # so even if generation ends due to max tokens, we have to append it to the end
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index bc15239..858d4c2 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
 
 import argparse
 from functools import partial
-from contextlib import nullcontext
-
 import torch
 import torch.distributed as dist
 
@@ -185,7 +183,6 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
     parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
-    parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
     parser.add_argument('-t', '--temperature', type=float, default=0.0)
     parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
     parser.add_argument('-n', '--num-samples', type=int, default=1)
@@ -199,8 +196,6 @@ if __name__ == "__main__":
 
     device_type = autodetect_device_type() if args.device_type == "" else args.device_type
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 
     model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
     engine = Engine(model, tokenizer)
@@ -220,19 +215,18 @@ if __name__ == "__main__":
     # Run all the task evaluations sequentially
     results = {}
     for task_name in task_names:
-        with autocast_ctx:
-            acc = run_chat_eval(
-                task_name,
-                model, tokenizer, engine,
-                batch_size=args.batch_size,
-                num_samples=args.num_samples,
-                max_new_tokens=args.max_new_tokens,
-                temperature=args.temperature,
-                top_k=args.top_k,
-                max_problems=args.max_problems,
-            )
-            results[task_name] = acc
-            print0(f"{task_name} accuracy: {100 * acc:.2f}%")
+        acc = run_chat_eval(
+            task_name,
+            model, tokenizer, engine,
+            batch_size=args.batch_size,
+            num_samples=args.num_samples,
+            max_new_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            max_problems=args.max_problems,
+        )
+        results[task_name] = acc
+        print0(f"{task_name} accuracy: {100 * acc:.2f}%")
 
     # Log to report
     from nanochat.report import get_report
diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py
index 20a1a0a..cb2cb0e 100644
--- a/scripts/chat_rl.py
+++ b/scripts/chat_rl.py
@@ -22,8 +22,6 @@ import itertools
 import wandb
 import torch
 import torch.distributed as dist
-from contextlib import nullcontext
-
 from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
 from nanochat.checkpoint_manager import save_checkpoint, load_model
 from nanochat.engine import Engine
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
 parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
 # Runtime
 parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
-parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
 # Model loading
 parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
 parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 
 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
@@ -108,15 +103,14 @@ def get_batch():
         num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
         for sampling_step in range(num_sampling_steps):
             seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
-            with autocast_ctx:
-                generated_token_sequences_batch, masks_batch = engine.generate_batch(
-                    tokens,
-                    num_samples=args.device_batch_size,
-                    max_tokens=args.max_new_tokens,
-                    temperature=args.temperature,
-                    top_k=args.top_k,
-                    seed=seed, # must make sure to change the seed for each sampling step
-                )
+            generated_token_sequences_batch, masks_batch = engine.generate_batch(
+                tokens,
+                num_samples=args.device_batch_size,
+                max_tokens=args.max_new_tokens,
+                temperature=args.temperature,
+                top_k=args.top_k,
+                seed=seed, # must make sure to change the seed for each sampling step
+            )
             generated_token_sequences.extend(generated_token_sequences_batch)
             masks.extend(masks_batch)
 
@@ -231,9 +225,8 @@ for step in range(num_steps):
     if step % args.eval_every == 0:
         model.eval()
         passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
-        with autocast_ctx:
-            records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
-            records = list(records_iter) # collect all records
+        records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
+        records = list(records_iter) # collect all records
         for k in range(1, args.device_batch_size + 1):
             passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
         num_records = torch.tensor(len(records), dtype=torch.long, device=device)
@@ -268,8 +261,7 @@ for step in range(num_steps):
             rewards = rewards_all[b0:b1]
             advantages = advantages_all[b0:b1]
             # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
-            with autocast_ctx:
-                logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
+            logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
             # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
             pg_obj = (logp * advantages.unsqueeze(-1)).sum()
             # normalize by the number of valid tokens, number of passes, and examples_per_rank
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index cb9e078..c1adbb6 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
-from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
 from nanochat.loss_eval import evaluate_bpb
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
 synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 if device_type == "cuda":
@@ -151,6 +150,11 @@ if args.load_optimizer:
     else:
         print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
 
+# GradScaler for fp16 training (bf16/fp32 don't need it)
+scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
+if scaler is not None:
+    print0("GradScaler enabled for fp16 training")
+
 # Override the initial learning rate as a fraction of the base learning rate
 for group in optimizer.param_groups:
     group["lr"] = group["lr"] * args.init_lr_frac
@@ -344,8 +348,7 @@ while True:
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with autocast_ctx:
-            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
+        val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
         print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
         if val_bpb < min_val_bpb:
             min_val_bpb = val_bpb
@@ -373,9 +376,8 @@ while True:
         for task_name in all_tasks:
             limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
             max_problems = None if limit < 0 else limit # -1 means no limit
-            with autocast_ctx:
-                acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
-                                    batch_size=args.device_batch_size, max_problems=max_problems)
+            acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
+                                batch_size=args.device_batch_size, max_problems=max_problems)
             task_results[task_name] = acc
             print0(f" {task_name}: {100*acc:.2f}%")
         # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
@@ -428,11 +430,13 @@ while True:
     synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
-        with autocast_ctx:
-            loss = model(x, y)
+        loss = model(x, y)
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
-        loss.backward()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
         x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
         progress = max(progress, approx_progress) # only increase progress monotonically
     # step the optimizer
@@ -442,7 +446,15 @@ while True:
         group["lr"] = group["initial_lr"] * lrm
         if group['kind'] == 'muon':
             group["momentum"] = muon_momentum
-    optimizer.step()
+    if scaler is not None:
+        scaler.unscale_(optimizer)
+        if is_ddp_initialized():
+            for v in scaler._found_inf_per_device(optimizer).values():
+                dist.all_reduce(v, op=dist.ReduceOp.MAX)
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
     model.zero_grad(set_to_none=True)
     synchronize()
     t1 = time.time()
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 66d7806..ffaf7da 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 from dataclasses import dataclass
-from contextlib import nullcontext
 from nanochat.common import compute_init, autodetect_device_type
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
 parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
-parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
 parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
 args = parser.parse_args()
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
 
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
 
 @dataclass
 class Worker:
@@ -93,7 +90,6 @@ class Worker:
     device: torch.device
     engine: Engine
     tokenizer: object
-    autocast_ctx: torch.amp.autocast
 
 class WorkerPool:
     """Pool of workers, each with a model replica on a different GPU."""
@@ -125,14 +121,11 @@ class WorkerPool:
 
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
-
             worker = Worker(
                 gpu_id=gpu_id,
                 device=device,
                 engine=engine,
                 tokenizer=tokenizer,
-                autocast_ctx=autocast_ctx
             )
             self.workers.append(worker)
             await self.available_workers.put(worker)
@@ -279,34 +272,33 @@ async def generate_stream(
     # Track the last complete UTF-8 string (without replacement characters)
     last_clean_text = ""
 
-    with worker.autocast_ctx:
-        for token_column, token_masks in worker.engine.generate(
-            tokens,
-            num_samples=1,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            seed=random.randint(0, 2**31 - 1)
-        ):
-            token = token_column[0]
+    for token_column, token_masks in worker.engine.generate(
+        tokens,
+        num_samples=1,
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+        top_k=top_k,
+        seed=random.randint(0, 2**31 - 1)
+    ):
+        token = token_column[0]
 
-            # Stopping criteria
-            if token == assistant_end or token == bos:
-                break
+        # Stopping criteria
+        if token == assistant_end or token == bos:
+            break
 
-            # Append the token to sequence
-            accumulated_tokens.append(token)
-            # Decode all accumulated tokens to get proper UTF-8 handling
-            # Note that decode is a quite efficient operation, basically table lookup and string concat
-            current_text = worker.tokenizer.decode(accumulated_tokens)
-            # Only emit text if it doesn't end with a replacement character
-            # This ensures we don't emit incomplete UTF-8 sequences
-            if not current_text.endswith('�'):
-                # Extract only the new text since last clean decode
-                new_text = current_text[len(last_clean_text):]
-                if new_text: # Only yield if there's new content
-                    yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
-                    last_clean_text = current_text
+        # Append the token to sequence
+        accumulated_tokens.append(token)
+        # Decode all accumulated tokens to get proper UTF-8 handling
+        # Note that decode is a quite efficient operation, basically table lookup and string concat
+        current_text = worker.tokenizer.decode(accumulated_tokens)
+        # Only emit text if it doesn't end with a replacement character
+        # This ensures we don't emit incomplete UTF-8 sequences
+        if not current_text.endswith('�'):
+            # Extract only the new text since last clean decode
+            new_text = current_text[len(last_clean_text):]
+            if new_text: # Only yield if there's new content
+                yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+                last_clean_text = current_text
 
     yield f"data: {json.dumps({'done': True})}\n\n"
 
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py
index 9741c7f..3eddc72 100644
--- a/tests/test_attention_fallback.py
+++ b/tests/test_attention_fallback.py
@@ -21,8 +21,9 @@ from nanochat.engine import KVCache
 
 
 def set_impl(impl):
-    """Set the implementation override ('fa3', 'sdpa', or None for auto)."""
+    """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
     fa_module._override_impl = impl
+    fa_module.USE_FA3 = fa_module._resolve_use_fa3()
 
 
 def run_both_impls(fn):
@@ -343,19 +344,19 @@ class TestOverrideMechanism:
     def test_override_fa3(self):
         """Test that override='fa3' uses FA3."""
         set_impl('fa3')
-        assert fa_module._use_fa3() == True
+        assert fa_module.USE_FA3 == True
         set_impl(None)
 
     def test_override_sdpa(self):
         """Test that override='sdpa' uses SDPA."""
         set_impl('sdpa')
-        assert fa_module._use_fa3() == False
+        assert fa_module.USE_FA3 == False
         set_impl(None)
 
     def test_override_auto(self):
         """Test that override=None uses auto-detection."""
         set_impl(None)
-        assert fa_module._use_fa3() == HAS_FA3
+        assert fa_module.USE_FA3 == HAS_FA3
 
 
 if __name__ == "__main__":
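
Reviewer note on the `tests/test_attention_fallback.py` hunks: the tests now read a `USE_FA3` attribute instead of calling `_use_fa3()`, and `set_impl` re-assigns it from `_resolve_use_fa3()`. That suggests the flag is resolved once and cached, so changing the override alone would leave it stale. The following is a minimal sketch of that pattern, not the real module: `fa` is a hypothetical `SimpleNamespace` stand-in, `HAS_FA3` is hard-coded to `False`, and the body of `_resolve_use_fa3` is an assumption inferred from the test expectations.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the attention module; NOT the real nanochat code.
# HAS_FA3 is hard-coded here purely for illustration.
fa = SimpleNamespace(HAS_FA3=False, _override_impl=None)

def _resolve_use_fa3():
    """Assumed logic: an explicit override wins, otherwise fall back to auto-detection."""
    if fa._override_impl == 'fa3':
        return True
    if fa._override_impl == 'sdpa':
        return False
    return fa.HAS_FA3  # override is None => auto

fa._resolve_use_fa3 = _resolve_use_fa3
fa.USE_FA3 = fa._resolve_use_fa3()  # resolved once, like a module-level constant

def set_impl(impl):
    """Mirror of the updated test helper: set the override, then refresh the cached flag."""
    fa._override_impl = impl
    fa.USE_FA3 = fa._resolve_use_fa3()

# Changing only the override leaves the cached flag at its old value...
fa._override_impl = 'fa3'
stale = fa.USE_FA3

# ...whereas set_impl() keeps the override and the cached flag in sync.
set_impl('fa3')
fresh = fa.USE_FA3

set_impl(None)  # restore auto-detection, as the tests do
```

Under these assumptions `stale` still holds the auto-detected value while `fresh` reflects the override, which is the behavior the extra line in `set_impl` appears to guard against.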