delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
@@ -82,6 +82,27 @@ The important thing to note is that nanochat is written and configured around on
|
|||||||
|
|
||||||
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way.
|
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM that is being trained to make things fit into a reasonable time interval of a few ten minutes of training. You will not get strong results in this way.
|
||||||
|
|
||||||
|
## Precision / dtype
|
||||||
|
|
||||||
|
nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:
|
||||||
|
|
||||||
|
| Hardware | Default dtype | Why |
|
||||||
|
|----------|--------------|-----|
|
||||||
|
| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
|
||||||
|
| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
|
||||||
|
| CPU / MPS | `float32` | No reduced-precision tensor cores |
|
||||||
|
|
||||||
|
You can override the default with the `NANOCHAT_DTYPE` environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
|
||||||
|
NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
|
||||||
|
```
|
||||||
|
|
||||||
|
How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision.
|
||||||
|
|
||||||
|
Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT suppors this too but RL currently does not. Inference in fp16 works fine everywhere.
|
||||||
|
|
||||||
## Guides
|
## Guides
|
||||||
|
|
||||||
I've published a number of guides that might contain helpful information, most recent to least recent:
|
I've published a number of guides that might contain helpful information, most recent to least recent:
|
||||||
|
|||||||
+35
@@ -4,6 +4,41 @@ A running summary documenting some experiments and findings. Started ~Jan 7 2026
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler
|
||||||
|
|
||||||
|
Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler.
|
||||||
|
|
||||||
|
### Motivation
|
||||||
|
|
||||||
|
autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.
|
||||||
|
|
||||||
|
### What changed
|
||||||
|
|
||||||
|
**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`):
|
||||||
|
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
|
||||||
|
- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
|
||||||
|
- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
|
||||||
|
- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
|
||||||
|
- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16.
|
||||||
|
|
||||||
|
**Autocast removal** (11 files):
|
||||||
|
- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`.
|
||||||
|
|
||||||
|
**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`):
|
||||||
|
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
|
||||||
|
- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()`
|
||||||
|
- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()`
|
||||||
|
- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels).
|
||||||
|
|
||||||
|
**FP8 fix** (`nanochat/fp8.py`, `base_train.py`):
|
||||||
|
- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast).
|
||||||
|
- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval.
|
||||||
|
|
||||||
|
**Flash Attention** (`flash_attention.py`):
|
||||||
|
- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (module-level constant, resolved once at import) returns False, falling back to SDPA.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B
|
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B
|
||||||
|
|
||||||
Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction.
|
Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction.
|
||||||
|
|||||||
@@ -10,6 +10,26 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
|
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
|
||||||
|
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
|
||||||
|
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
|
||||||
|
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
||||||
|
def _detect_compute_dtype():
|
||||||
|
env = os.environ.get("NANOCHAT_DTYPE")
|
||||||
|
if env is not None:
|
||||||
|
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
|
||||||
|
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
if capability >= (8, 0):
|
||||||
|
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
|
||||||
|
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
||||||
|
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
||||||
|
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
||||||
|
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
||||||
|
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
||||||
|
|
||||||
class ColoredFormatter(logging.Formatter):
|
class ColoredFormatter(logging.Formatter):
|
||||||
"""Custom formatter that adds colors to log messages."""
|
"""Custom formatter that adds colors to log messages."""
|
||||||
# ANSI color codes
|
# ANSI color codes
|
||||||
|
|||||||
+9
-14
@@ -19,7 +19,6 @@ from contextlib import contextmanager
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from nanochat.common import compute_init, autodetect_device_type
|
from nanochat.common import compute_init, autodetect_device_type
|
||||||
from nanochat.checkpoint_manager import load_model
|
from nanochat.checkpoint_manager import load_model
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Calculator tool helpers
|
# Calculator tool helpers
|
||||||
@@ -308,8 +307,6 @@ if __name__ == "__main__":
|
|||||||
# init compute
|
# init compute
|
||||||
device_type = autodetect_device_type()
|
device_type = autodetect_device_type()
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
|
||||||
|
|
||||||
# load the model and tokenizer
|
# load the model and tokenizer
|
||||||
model, tokenizer, meta = load_model("base", device, phase="eval")
|
model, tokenizer, meta = load_model("base", device, phase="eval")
|
||||||
bos_token_id = tokenizer.get_bos_token_id()
|
bos_token_id = tokenizer.get_bos_token_id()
|
||||||
@@ -322,11 +319,10 @@ if __name__ == "__main__":
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
stream = model.generate(prompt_tokens, **kwargs)
|
stream = model.generate(prompt_tokens, **kwargs)
|
||||||
with autocast_ctx:
|
for token in stream:
|
||||||
for token in stream:
|
generated_tokens.append(token)
|
||||||
generated_tokens.append(token)
|
chunk = tokenizer.decode([token])
|
||||||
chunk = tokenizer.decode([token])
|
print(chunk, end="", flush=True)
|
||||||
print(chunk, end="", flush=True)
|
|
||||||
print()
|
print()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
@@ -338,12 +334,11 @@ if __name__ == "__main__":
|
|||||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
with autocast_ctx:
|
for token_column, token_masks in stream:
|
||||||
for token_column, token_masks in stream:
|
token = token_column[0] # only print out the first row
|
||||||
token = token_column[0] # only print out the first row
|
generated_tokens.append(token)
|
||||||
generated_tokens.append(token)
|
chunk = tokenizer.decode([token])
|
||||||
chunk = tokenizer.decode([token])
|
print(chunk, end="", flush=True)
|
||||||
print(chunk, end="", flush=True)
|
|
||||||
print()
|
print()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|||||||
@@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
|
|||||||
_override_impl = None
|
_override_impl = None
|
||||||
|
|
||||||
|
|
||||||
def _use_fa3():
|
def _resolve_use_fa3():
|
||||||
"""Determine whether to use FA3 based on availability and override."""
|
"""Decide once whether to use FA3, based on availability, override, and dtype."""
|
||||||
if _override_impl == 'fa3':
|
if _override_impl == 'fa3':
|
||||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||||
return True
|
return True
|
||||||
if _override_impl == 'sdpa':
|
if _override_impl == 'sdpa':
|
||||||
return False
|
return False
|
||||||
return HAS_FA3 # auto
|
if HAS_FA3:
|
||||||
|
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
|
||||||
|
from nanochat.common import COMPUTE_DTYPE
|
||||||
|
if COMPUTE_DTYPE == torch.bfloat16:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
USE_FA3 = _resolve_use_fa3()
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||||||
Returns:
|
Returns:
|
||||||
Output tensor of shape (B, T, H, D)
|
Output tensor of shape (B, T, H, D)
|
||||||
"""
|
"""
|
||||||
if _use_fa3():
|
if USE_FA3:
|
||||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||||
|
|
||||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||||
@@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||||||
Returns:
|
Returns:
|
||||||
Output tensor of shape (B, T_new, H, D)
|
Output tensor of shape (B, T_new, H, D)
|
||||||
"""
|
"""
|
||||||
if _use_fa3():
|
if USE_FA3:
|
||||||
return _fa3.flash_attn_with_kvcache(
|
return _fa3.flash_attn_with_kvcache(
|
||||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||||
causal=causal, window_size=window_size
|
causal=causal, window_size=window_size
|
||||||
|
|||||||
+5
-5
@@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode.
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from nanochat.common import COMPUTE_DTYPE
|
||||||
|
|
||||||
# Avoid division by zero when computing scale from an all-zeros tensor
|
# Avoid division by zero when computing scale from an all-zeros tensor
|
||||||
EPS = 1e-12
|
EPS = 1e-12
|
||||||
|
|
||||||
@@ -198,11 +200,9 @@ class Float8Linear(nn.Linear):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
# Replicate the autocast behavior of F.linear — when autocast is active,
|
# Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
|
||||||
# we need to manually cast input to the autocast dtype (e.g. bf16),
|
# reduced precision input, and we no longer rely on autocast to do this.
|
||||||
# since we bypass F.linear's built-in autocast handling.
|
input = input.to(COMPUTE_DTYPE)
|
||||||
if torch.is_autocast_enabled():
|
|
||||||
input = input.to(torch.get_autocast_gpu_dtype())
|
|
||||||
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
|
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
|
||||||
orig_shape = input.shape
|
orig_shape = input.shape
|
||||||
input_2d = input.reshape(-1, orig_shape[-1])
|
input_2d = input.reshape(-1, orig_shape[-1])
|
||||||
|
|||||||
+27
-18
@@ -19,7 +19,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from nanochat.common import get_dist_info, print0
|
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
||||||
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
||||||
|
|
||||||
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
||||||
@@ -40,8 +40,14 @@ class GPTConfig:
|
|||||||
|
|
||||||
|
|
||||||
def norm(x):
|
def norm(x):
|
||||||
# Purely functional rmsnorm with no learnable params
|
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
||||||
return F.rms_norm(x, (x.size(-1),))
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
"""nn.Linear that casts weights to match input dtype in forward.
|
||||||
|
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||||
|
but matmuls run in the activation dtype (typically bf16 from embeddings)."""
|
||||||
|
def forward(self, x):
|
||||||
|
return F.linear(x, self.weight.to(dtype=x.dtype))
|
||||||
|
|
||||||
|
|
||||||
def has_ve(layer_idx, n_layer):
|
def has_ve(layer_idx, n_layer):
|
||||||
@@ -66,12 +72,12 @@ class CausalSelfAttention(nn.Module):
|
|||||||
self.head_dim = self.n_embd // self.n_head
|
self.head_dim = self.n_embd // self.n_head
|
||||||
assert self.n_embd % self.n_head == 0
|
assert self.n_embd % self.n_head == 0
|
||||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
|
||||||
self.ve_gate_channels = 32
|
self.ve_gate_channels = 32
|
||||||
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||||
|
|
||||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||||
B, T, C = x.size()
|
B, T, C = x.size()
|
||||||
@@ -121,8 +127,8 @@ class CausalSelfAttention(nn.Module):
|
|||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
||||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.c_fc(x)
|
x = self.c_fc(x)
|
||||||
@@ -164,7 +170,7 @@ class GPT(nn.Module):
|
|||||||
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
"wte": nn.Embedding(padded_vocab_size, config.n_embd),
|
||||||
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
|
||||||
})
|
})
|
||||||
self.lm_head = nn.Linear(config.n_embd, padded_vocab_size, bias=False)
|
self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
|
||||||
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
# Per-layer learnable scalars (inspired by modded-nanogpt)
|
||||||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||||
@@ -234,11 +240,13 @@ class GPT(nn.Module):
|
|||||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||||
self.cos, self.sin = cos, sin
|
self.cos, self.sin = cos, sin
|
||||||
|
|
||||||
# Cast embeddings to bf16: optimizer can tolerate it and it saves memory
|
# Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
|
||||||
if self.transformer.wte.weight.device.type == "cuda":
|
# embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
|
||||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
# because GradScaler cannot unscale fp16 gradients.
|
||||||
|
if COMPUTE_DTYPE != torch.float16:
|
||||||
|
self.transformer.wte.to(dtype=COMPUTE_DTYPE)
|
||||||
for ve in self.value_embeds.values():
|
for ve in self.value_embeds.values():
|
||||||
ve.to(dtype=torch.bfloat16)
|
ve.to(dtype=COMPUTE_DTYPE)
|
||||||
|
|
||||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||||
# TODO: bump base theta more? e.g. 100K is more common more recently
|
# TODO: bump base theta more? e.g. 100K is more common more recently
|
||||||
@@ -253,7 +261,7 @@ class GPT(nn.Module):
|
|||||||
# calculate the rotation frequencies at each (time, channel) pair
|
# calculate the rotation frequencies at each (time, channel) pair
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
cos, sin = freqs.cos(), freqs.sin()
|
cos, sin = freqs.cos(), freqs.sin()
|
||||||
cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
|
cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
|
||||||
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
|
||||||
return cos, sin
|
return cos, sin
|
||||||
|
|
||||||
@@ -391,18 +399,19 @@ class GPT(nn.Module):
|
|||||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
|
||||||
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
|
||||||
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
T0 = 0 if kv_cache is None else kv_cache.get_pos()
|
||||||
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
|
||||||
|
|
||||||
# Forward the trunk of the Transformer
|
# Forward the trunk of the Transformer
|
||||||
x = self.transformer.wte(idx) # embed current token
|
x = self.transformer.wte(idx) # embed current token
|
||||||
|
x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
|
||||||
x = norm(x)
|
x = norm(x)
|
||||||
x0 = x # save initial normalized embedding for x0 residual
|
x0 = x # save initial normalized embedding for x0 residual
|
||||||
for i, block in enumerate(self.transformer.h):
|
for i, block in enumerate(self.transformer.h):
|
||||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||||
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||||
x = norm(x)
|
x = norm(x)
|
||||||
|
|
||||||
|
|||||||
+4
-12
@@ -29,8 +29,6 @@ import random
|
|||||||
import zipfile
|
import zipfile
|
||||||
import tempfile
|
import tempfile
|
||||||
import argparse
|
import argparse
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||||
@@ -199,8 +197,6 @@ def main():
|
|||||||
# Distributed / precision setup
|
# Distributed / precision setup
|
||||||
device_type = autodetect_device_type() if args.device_type == '' else args.device_type
|
device_type = autodetect_device_type() if args.device_type == '' else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
is_hf_model = args.hf_path is not None
|
is_hf_model = args.hf_path is not None
|
||||||
if is_hf_model:
|
if is_hf_model:
|
||||||
@@ -244,8 +240,7 @@ def main():
|
|||||||
print0("\nConditioned samples:")
|
print0("\nConditioned samples:")
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||||
with autocast_ctx:
|
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
|
||||||
sample_str = tokenizer.decode(sample[0])
|
sample_str = tokenizer.decode(sample[0])
|
||||||
print0("-" * 80)
|
print0("-" * 80)
|
||||||
print0(sample_str)
|
print0(sample_str)
|
||||||
@@ -253,8 +248,7 @@ def main():
|
|||||||
|
|
||||||
print0("\nUnconditioned samples:")
|
print0("\nUnconditioned samples:")
|
||||||
tokens = tokenizer("", prepend="<|bos|>")
|
tokens = tokenizer("", prepend="<|bos|>")
|
||||||
with autocast_ctx:
|
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||||
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
|
||||||
for sample in uncond:
|
for sample in uncond:
|
||||||
sample_str = tokenizer.decode(sample)
|
sample_str = tokenizer.decode(sample)
|
||||||
print0("-" * 80)
|
print0("-" * 80)
|
||||||
@@ -277,8 +271,7 @@ def main():
|
|||||||
|
|
||||||
for split_name in ["train", "val"]:
|
for split_name in ["train", "val"]:
|
||||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||||
with autocast_ctx:
|
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
|
||||||
bpb_results[split_name] = bpb
|
bpb_results[split_name] = bpb
|
||||||
print0(f"{split_name} bpb: {bpb:.6f}")
|
print0(f"{split_name} bpb: {bpb:.6f}")
|
||||||
|
|
||||||
@@ -287,8 +280,7 @@ def main():
|
|||||||
print0("\n" + "="*80)
|
print0("\n" + "="*80)
|
||||||
print0("CORE Evaluation")
|
print0("CORE Evaluation")
|
||||||
print0("="*80)
|
print0("="*80)
|
||||||
with autocast_ctx:
|
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
|
||||||
|
|
||||||
# Write CSV output
|
# Write CSV output
|
||||||
if ddp_rank == 0:
|
if ddp_rank == 0:
|
||||||
|
|||||||
+40
-15
@@ -19,14 +19,15 @@ import time
|
|||||||
import math
|
import math
|
||||||
import argparse
|
import argparse
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from contextlib import nullcontext, contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from nanochat.gpt import GPT, GPTConfig
|
from nanochat.gpt import GPT, GPTConfig, Linear
|
||||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||||
from nanochat.loss_eval import evaluate_bpb
|
from nanochat.loss_eval import evaluate_bpb
|
||||||
@@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging
|
|||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
|
||||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
@@ -95,17 +95,23 @@ if device_type == "cuda":
|
|||||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||||
else:
|
else:
|
||||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||||
|
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||||
|
|
||||||
# wandb logging init
|
# wandb logging init
|
||||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||||
|
|
||||||
# Flash Attention status
|
# Flash Attention status
|
||||||
if HAS_FA3:
|
from nanochat.flash_attention import USE_FA3
|
||||||
|
using_fa3 = USE_FA3
|
||||||
|
if using_fa3:
|
||||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||||
else:
|
else:
|
||||||
print0("!" * 80)
|
print0("!" * 80)
|
||||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||||
|
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||||
|
else:
|
||||||
|
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||||
print0("WARNING: Training will be less efficient without FA3")
|
print0("WARNING: Training will be less efficient without FA3")
|
||||||
if args.window_pattern != "L":
|
if args.window_pattern != "L":
|
||||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||||
@@ -213,9 +219,9 @@ def disable_fp8(model):
|
|||||||
yield # No FP8 modules, nothing to do
|
yield # No FP8 modules, nothing to do
|
||||||
return
|
return
|
||||||
|
|
||||||
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
|
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
||||||
for parent, attr_name, fp8_module in fp8_locations:
|
for parent, attr_name, fp8_module in fp8_locations:
|
||||||
linear = nn.Linear(
|
linear = Linear(
|
||||||
fp8_module.in_features,
|
fp8_module.in_features,
|
||||||
fp8_module.out_features,
|
fp8_module.out_features,
|
||||||
bias=fp8_module.bias is not None,
|
bias=fp8_module.bias is not None,
|
||||||
@@ -315,6 +321,12 @@ if resuming:
|
|||||||
optimizer.load_state_dict(optimizer_data)
|
optimizer.load_state_dict(optimizer_data)
|
||||||
del optimizer_data
|
del optimizer_data
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
|
||||||
|
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||||
|
if scaler is not None:
|
||||||
|
print0("GradScaler enabled for fp16 training")
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Initialize the DataLoaders for train/val
|
# Initialize the DataLoaders for train/val
|
||||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||||
@@ -405,7 +417,7 @@ while True:
|
|||||||
model.eval()
|
model.eval()
|
||||||
val_loader = build_val_loader()
|
val_loader = build_val_loader()
|
||||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||||
with disable_fp8(model), autocast_ctx:
|
with disable_fp8(model):
|
||||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||||
if val_bpb < min_val_bpb:
|
if val_bpb < min_val_bpb:
|
||||||
@@ -424,7 +436,7 @@ while True:
|
|||||||
results = {}
|
results = {}
|
||||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||||
model.eval()
|
model.eval()
|
||||||
with disable_fp8(orig_model), autocast_ctx:
|
with disable_fp8(orig_model):
|
||||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||||
wandb_run.log({
|
wandb_run.log({
|
||||||
@@ -451,7 +463,7 @@ while True:
|
|||||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||||
with disable_fp8(orig_model), autocast_ctx:
|
with disable_fp8(orig_model):
|
||||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||||
print0(tokenizer.decode(sample[0]))
|
print0(tokenizer.decode(sample[0]))
|
||||||
model.train()
|
model.train()
|
||||||
@@ -491,11 +503,13 @@ while True:
|
|||||||
synchronize()
|
synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for micro_step in range(grad_accum_steps):
|
for micro_step in range(grad_accum_steps):
|
||||||
with autocast_ctx:
|
loss = model(x, y)
|
||||||
loss = model(x, y)
|
|
||||||
train_loss = loss.detach() # for logging
|
train_loss = loss.detach() # for logging
|
||||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||||
loss.backward()
|
if scaler is not None:
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||||
# step the optimizer
|
# step the optimizer
|
||||||
lrm = get_lr_multiplier(step)
|
lrm = get_lr_multiplier(step)
|
||||||
@@ -506,7 +520,18 @@ while True:
|
|||||||
if group['kind'] == 'muon':
|
if group['kind'] == 'muon':
|
||||||
group["momentum"] = muon_momentum
|
group["momentum"] = muon_momentum
|
||||||
group["weight_decay"] = muon_weight_decay
|
group["weight_decay"] = muon_weight_decay
|
||||||
optimizer.step()
|
if scaler is not None:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
# In distributed training, all ranks must agree on whether to skip the step.
|
||||||
|
# Each rank may independently encounter inf/nan gradients, so we all-reduce
|
||||||
|
# the found_inf flag (MAX = if any rank found inf, all ranks skip).
|
||||||
|
if is_ddp_initialized():
|
||||||
|
for v in scaler._found_inf_per_device(optimizer).values():
|
||||||
|
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||||
synchronize()
|
synchronize()
|
||||||
|
|||||||
+5
-10
@@ -7,7 +7,6 @@ python -m scripts.chat_cli
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from nanochat.common import compute_init, autodetect_device_type
|
from nanochat.common import compute_init, autodetect_device_type
|
||||||
from contextlib import nullcontext
|
|
||||||
from nanochat.engine import Engine
|
from nanochat.engine import Engine
|
||||||
from nanochat.checkpoint_manager import load_model
|
from nanochat.checkpoint_manager import load_model
|
||||||
|
|
||||||
@@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod
|
|||||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Init the model and tokenizer
|
# Init the model and tokenizer
|
||||||
|
|
||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
|
||||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||||
|
|
||||||
# Special tokens for the chat state machine
|
# Special tokens for the chat state machine
|
||||||
@@ -87,12 +83,11 @@ while True:
|
|||||||
}
|
}
|
||||||
response_tokens = []
|
response_tokens = []
|
||||||
print("\nAssistant: ", end="", flush=True)
|
print("\nAssistant: ", end="", flush=True)
|
||||||
with autocast_ctx:
|
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
response_tokens.append(token)
|
||||||
response_tokens.append(token)
|
token_text = tokenizer.decode([token])
|
||||||
token_text = tokenizer.decode([token])
|
print(token_text, end="", flush=True)
|
||||||
print(token_text, end="", flush=True)
|
|
||||||
print()
|
print()
|
||||||
# we have to ensure that the assistant end token is the last token
|
# we have to ensure that the assistant end token is the last token
|
||||||
# so even if generation ends due to max tokens, we have to append it to the end
|
# so even if generation ends due to max tokens, we have to append it to the end
|
||||||
|
|||||||
+12
-18
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
@@ -185,7 +183,6 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
|
||||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||||
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
||||||
@@ -199,8 +196,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
|
||||||
|
|
||||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||||
engine = Engine(model, tokenizer)
|
engine = Engine(model, tokenizer)
|
||||||
@@ -220,19 +215,18 @@ if __name__ == "__main__":
|
|||||||
# Run all the task evaluations sequentially
|
# Run all the task evaluations sequentially
|
||||||
results = {}
|
results = {}
|
||||||
for task_name in task_names:
|
for task_name in task_names:
|
||||||
with autocast_ctx:
|
acc = run_chat_eval(
|
||||||
acc = run_chat_eval(
|
task_name,
|
||||||
task_name,
|
model, tokenizer, engine,
|
||||||
model, tokenizer, engine,
|
batch_size=args.batch_size,
|
||||||
batch_size=args.batch_size,
|
num_samples=args.num_samples,
|
||||||
num_samples=args.num_samples,
|
max_new_tokens=args.max_new_tokens,
|
||||||
max_new_tokens=args.max_new_tokens,
|
temperature=args.temperature,
|
||||||
temperature=args.temperature,
|
top_k=args.top_k,
|
||||||
top_k=args.top_k,
|
max_problems=args.max_problems,
|
||||||
max_problems=args.max_problems,
|
)
|
||||||
)
|
results[task_name] = acc
|
||||||
results[task_name] = acc
|
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
|
||||||
|
|
||||||
# Log to report
|
# Log to report
|
||||||
from nanochat.report import get_report
|
from nanochat.report import get_report
|
||||||
|
|||||||
+11
-19
@@ -22,8 +22,6 @@ import itertools
|
|||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
||||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||||
from nanochat.engine import Engine
|
from nanochat.engine import Engine
|
||||||
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
|||||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||||
# Runtime
|
# Runtime
|
||||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
|
||||||
# Model loading
|
# Model loading
|
||||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||||
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
|
|||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
|
||||||
|
|
||||||
# wandb logging init
|
# wandb logging init
|
||||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||||
@@ -108,15 +103,14 @@ def get_batch():
|
|||||||
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
||||||
for sampling_step in range(num_sampling_steps):
|
for sampling_step in range(num_sampling_steps):
|
||||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||||
with autocast_ctx:
|
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
tokens,
|
||||||
tokens,
|
num_samples=args.device_batch_size,
|
||||||
num_samples=args.device_batch_size,
|
max_tokens=args.max_new_tokens,
|
||||||
max_tokens=args.max_new_tokens,
|
temperature=args.temperature,
|
||||||
temperature=args.temperature,
|
top_k=args.top_k,
|
||||||
top_k=args.top_k,
|
seed=seed, # must make sure to change the seed for each sampling step
|
||||||
seed=seed, # must make sure to change the seed for each sampling step
|
)
|
||||||
)
|
|
||||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||||
masks.extend(masks_batch)
|
masks.extend(masks_batch)
|
||||||
|
|
||||||
@@ -231,9 +225,8 @@ for step in range(num_steps):
|
|||||||
if step % args.eval_every == 0:
|
if step % args.eval_every == 0:
|
||||||
model.eval()
|
model.eval()
|
||||||
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||||
with autocast_ctx:
|
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
records = list(records_iter) # collect all records
|
||||||
records = list(records_iter) # collect all records
|
|
||||||
for k in range(1, args.device_batch_size + 1):
|
for k in range(1, args.device_batch_size + 1):
|
||||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||||
@@ -268,8 +261,7 @@ for step in range(num_steps):
|
|||||||
rewards = rewards_all[b0:b1]
|
rewards = rewards_all[b0:b1]
|
||||||
advantages = advantages_all[b0:b1]
|
advantages = advantages_all[b0:b1]
|
||||||
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
||||||
with autocast_ctx:
|
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
|
||||||
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
||||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||||
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
||||||
|
|||||||
+24
-12
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
|||||||
import time
|
import time
|
||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
from contextlib import nullcontext
|
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
|
||||||
from nanochat.tokenizer import get_token_bytes
|
from nanochat.tokenizer import get_token_bytes
|
||||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||||
from nanochat.loss_eval import evaluate_bpb
|
from nanochat.loss_eval import evaluate_bpb
|
||||||
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
|
|||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
master_process = ddp_rank == 0
|
master_process = ddp_rank == 0
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
@@ -151,6 +150,11 @@ if args.load_optimizer:
|
|||||||
else:
|
else:
|
||||||
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
||||||
|
|
||||||
|
# GradScaler for fp16 training (bf16/fp32 don't need it)
|
||||||
|
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||||
|
if scaler is not None:
|
||||||
|
print0("GradScaler enabled for fp16 training")
|
||||||
|
|
||||||
# Override the initial learning rate as a fraction of the base learning rate
|
# Override the initial learning rate as a fraction of the base learning rate
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group["lr"] = group["lr"] * args.init_lr_frac
|
group["lr"] = group["lr"] * args.init_lr_frac
|
||||||
@@ -344,8 +348,7 @@ while True:
|
|||||||
model.eval()
|
model.eval()
|
||||||
val_loader = build_val_loader()
|
val_loader = build_val_loader()
|
||||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||||
with autocast_ctx:
|
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
|
||||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||||
if val_bpb < min_val_bpb:
|
if val_bpb < min_val_bpb:
|
||||||
min_val_bpb = val_bpb
|
min_val_bpb = val_bpb
|
||||||
@@ -373,9 +376,8 @@ while True:
|
|||||||
for task_name in all_tasks:
|
for task_name in all_tasks:
|
||||||
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
||||||
max_problems = None if limit < 0 else limit # -1 means no limit
|
max_problems = None if limit < 0 else limit # -1 means no limit
|
||||||
with autocast_ctx:
|
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
|
||||||
task_results[task_name] = acc
|
task_results[task_name] = acc
|
||||||
print0(f" {task_name}: {100*acc:.2f}%")
|
print0(f" {task_name}: {100*acc:.2f}%")
|
||||||
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
||||||
@@ -428,11 +430,13 @@ while True:
|
|||||||
synchronize()
|
synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for micro_step in range(grad_accum_steps):
|
for micro_step in range(grad_accum_steps):
|
||||||
with autocast_ctx:
|
loss = model(x, y)
|
||||||
loss = model(x, y)
|
|
||||||
train_loss = loss.detach() # for logging
|
train_loss = loss.detach() # for logging
|
||||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||||
loss.backward()
|
if scaler is not None:
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||||
# step the optimizer
|
# step the optimizer
|
||||||
@@ -442,7 +446,15 @@ while True:
|
|||||||
group["lr"] = group["initial_lr"] * lrm
|
group["lr"] = group["initial_lr"] * lrm
|
||||||
if group['kind'] == 'muon':
|
if group['kind'] == 'muon':
|
||||||
group["momentum"] = muon_momentum
|
group["momentum"] = muon_momentum
|
||||||
optimizer.step()
|
if scaler is not None:
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
if is_ddp_initialized():
|
||||||
|
for v in scaler._found_inf_per_device(optimizer).values():
|
||||||
|
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
optimizer.step()
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
synchronize()
|
synchronize()
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|||||||
+25
-33
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional, AsyncGenerator
|
from typing import List, Optional, AsyncGenerator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from contextlib import nullcontext
|
|
||||||
from nanochat.common import compute_init, autodetect_device_type
|
from nanochat.common import compute_init, autodetect_device_type
|
||||||
from nanochat.checkpoint_manager import load_model
|
from nanochat.checkpoint_manager import load_model
|
||||||
from nanochat.engine import Engine
|
from nanochat.engine import Engine
|
||||||
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
|||||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
|
||||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Worker:
|
class Worker:
|
||||||
@@ -93,7 +90,6 @@ class Worker:
|
|||||||
device: torch.device
|
device: torch.device
|
||||||
engine: Engine
|
engine: Engine
|
||||||
tokenizer: object
|
tokenizer: object
|
||||||
autocast_ctx: torch.amp.autocast
|
|
||||||
|
|
||||||
class WorkerPool:
|
class WorkerPool:
|
||||||
"""Pool of workers, each with a model replica on a different GPU."""
|
"""Pool of workers, each with a model replica on a different GPU."""
|
||||||
@@ -125,14 +121,11 @@ class WorkerPool:
|
|||||||
|
|
||||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||||
engine = Engine(model, tokenizer)
|
engine = Engine(model, tokenizer)
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
|
||||||
|
|
||||||
worker = Worker(
|
worker = Worker(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
device=device,
|
device=device,
|
||||||
engine=engine,
|
engine=engine,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
autocast_ctx=autocast_ctx
|
|
||||||
)
|
)
|
||||||
self.workers.append(worker)
|
self.workers.append(worker)
|
||||||
await self.available_workers.put(worker)
|
await self.available_workers.put(worker)
|
||||||
@@ -279,34 +272,33 @@ async def generate_stream(
|
|||||||
# Track the last complete UTF-8 string (without replacement characters)
|
# Track the last complete UTF-8 string (without replacement characters)
|
||||||
last_clean_text = ""
|
last_clean_text = ""
|
||||||
|
|
||||||
with worker.autocast_ctx:
|
for token_column, token_masks in worker.engine.generate(
|
||||||
for token_column, token_masks in worker.engine.generate(
|
tokens,
|
||||||
tokens,
|
num_samples=1,
|
||||||
num_samples=1,
|
max_tokens=max_new_tokens,
|
||||||
max_tokens=max_new_tokens,
|
temperature=temperature,
|
||||||
temperature=temperature,
|
top_k=top_k,
|
||||||
top_k=top_k,
|
seed=random.randint(0, 2**31 - 1)
|
||||||
seed=random.randint(0, 2**31 - 1)
|
):
|
||||||
):
|
token = token_column[0]
|
||||||
token = token_column[0]
|
|
||||||
|
|
||||||
# Stopping criteria
|
# Stopping criteria
|
||||||
if token == assistant_end or token == bos:
|
if token == assistant_end or token == bos:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Append the token to sequence
|
# Append the token to sequence
|
||||||
accumulated_tokens.append(token)
|
accumulated_tokens.append(token)
|
||||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||||
# Only emit text if it doesn't end with a replacement character
|
# Only emit text if it doesn't end with a replacement character
|
||||||
# This ensures we don't emit incomplete UTF-8 sequences
|
# This ensures we don't emit incomplete UTF-8 sequences
|
||||||
if not current_text.endswith('�'):
|
if not current_text.endswith('�'):
|
||||||
# Extract only the new text since last clean decode
|
# Extract only the new text since last clean decode
|
||||||
new_text = current_text[len(last_clean_text):]
|
new_text = current_text[len(last_clean_text):]
|
||||||
if new_text: # Only yield if there's new content
|
if new_text: # Only yield if there's new content
|
||||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||||
last_clean_text = current_text
|
last_clean_text = current_text
|
||||||
|
|
||||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ from nanochat.engine import KVCache
|
|||||||
|
|
||||||
|
|
||||||
def set_impl(impl):
|
def set_impl(impl):
|
||||||
"""Set the implementation override ('fa3', 'sdpa', or None for auto)."""
|
"""Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
|
||||||
fa_module._override_impl = impl
|
fa_module._override_impl = impl
|
||||||
|
fa_module.USE_FA3 = fa_module._resolve_use_fa3()
|
||||||
|
|
||||||
|
|
||||||
def run_both_impls(fn):
|
def run_both_impls(fn):
|
||||||
@@ -343,19 +344,19 @@ class TestOverrideMechanism:
|
|||||||
def test_override_fa3(self):
|
def test_override_fa3(self):
|
||||||
"""Test that override='fa3' uses FA3."""
|
"""Test that override='fa3' uses FA3."""
|
||||||
set_impl('fa3')
|
set_impl('fa3')
|
||||||
assert fa_module._use_fa3() == True
|
assert fa_module.USE_FA3 == True
|
||||||
set_impl(None)
|
set_impl(None)
|
||||||
|
|
||||||
def test_override_sdpa(self):
|
def test_override_sdpa(self):
|
||||||
"""Test that override='sdpa' uses SDPA."""
|
"""Test that override='sdpa' uses SDPA."""
|
||||||
set_impl('sdpa')
|
set_impl('sdpa')
|
||||||
assert fa_module._use_fa3() == False
|
assert fa_module.USE_FA3 == False
|
||||||
set_impl(None)
|
set_impl(None)
|
||||||
|
|
||||||
def test_override_auto(self):
|
def test_override_auto(self):
|
||||||
"""Test that override=None uses auto-detection."""
|
"""Test that override=None uses auto-detection."""
|
||||||
set_impl(None)
|
set_impl(None)
|
||||||
assert fa_module._use_fa3() == HAS_FA3
|
assert fa_module.USE_FA3 == HAS_FA3
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user