delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+4 -12
View File
@@ -29,8 +29,6 @@ import random
import zipfile
import tempfile
import argparse
from contextlib import nullcontext
import torch
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
@@ -199,8 +197,6 @@ def main():
# Distributed / precision setup
device_type = autodetect_device_type() if args.device_type == '' else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Load model and tokenizer
is_hf_model = args.hf_path is not None
if is_hf_model:
@@ -244,8 +240,7 @@ def main():
print0("\nConditioned samples:")
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
print0("-" * 80)
print0(sample_str)
@@ -253,8 +248,7 @@ def main():
print0("\nUnconditioned samples:")
tokens = tokenizer("", prepend="<|bos|>")
with autocast_ctx:
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
for sample in uncond:
sample_str = tokenizer.decode(sample)
print0("-" * 80)
@@ -277,8 +271,7 @@ def main():
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb_results[split_name] = bpb
print0(f"{split_name} bpb: {bpb:.6f}")
@@ -287,8 +280,7 @@ def main():
print0("\n" + "="*80)
print0("CORE Evaluation")
print0("="*80)
with autocast_ctx:
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
# Write CSV output
if ddp_rank == 0:
+40 -15
View File
@@ -19,14 +19,15 @@ import time
import math
import argparse
from dataclasses import asdict
from contextlib import nullcontext, contextmanager
from contextlib import contextmanager
import wandb
import torch
import torch.distributed as dist
from nanochat.gpt import GPT, GPTConfig
from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
@@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
@@ -95,17 +95,23 @@ if device_type == "cuda":
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
if HAS_FA3:
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0("!" * 80)
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
else:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
@@ -213,9 +219,9 @@ def disable_fp8(model):
yield # No FP8 modules, nothing to do
return
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
for parent, attr_name, fp8_module in fp8_locations:
linear = nn.Linear(
linear = Linear(
fp8_module.in_features,
fp8_module.out_features,
bias=fp8_module.bias is not None,
@@ -315,6 +321,12 @@ if resuming:
optimizer.load_state_dict(optimizer_data)
del optimizer_data
# -----------------------------------------------------------------------------
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
print0("GradScaler enabled for fp16 training")
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
@@ -405,7 +417,7 @@ while True:
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with disable_fp8(model), autocast_ctx:
with disable_fp8(model):
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb:
@@ -424,7 +436,7 @@ while True:
results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with disable_fp8(orig_model), autocast_ctx:
with disable_fp8(orig_model):
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
@@ -451,7 +463,7 @@ while True:
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with disable_fp8(orig_model), autocast_ctx:
with disable_fp8(orig_model):
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
@@ -491,11 +503,13 @@ while True:
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer
lrm = get_lr_multiplier(step)
@@ -506,7 +520,18 @@ while True:
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
group["weight_decay"] = muon_weight_decay
optimizer.step()
if scaler is not None:
scaler.unscale_(optimizer)
# In distributed training, all ranks must agree on whether to skip the step.
# Each rank may independently encounter inf/nan gradients, so we all-reduce
# the found_inf flag (MAX = if any rank found inf, all ranks skip).
if is_ddp_initialized():
for v in scaler._found_inf_per_device(optimizer).values():
dist.all_reduce(v, op=dist.ReduceOp.MAX)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
synchronize()
+5 -10
View File
@@ -7,7 +7,6 @@ python -m scripts.chat_cli
import argparse
import torch
from nanochat.common import compute_init, autodetect_device_type
from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
# Init the model and tokenizer
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine
@@ -87,12 +83,11 @@ while True:
}
response_tokens = []
print("\nAssistant: ", end="", flush=True)
with autocast_ctx:
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
token = token_column[0] # pop the batch dimension (num_samples=1)
response_tokens.append(token)
token_text = tokenizer.decode([token])
print(token_text, end="", flush=True)
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
token = token_column[0] # pop the batch dimension (num_samples=1)
response_tokens.append(token)
token_text = tokenizer.decode([token])
print(token_text, end="", flush=True)
print()
# we have to ensure that the assistant end token is the last token
# so even if generation ends due to max tokens, we have to append it to the end
+12 -18
View File
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
import argparse
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
@@ -185,7 +183,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('-t', '--temperature', type=float, default=0.0)
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
parser.add_argument('-n', '--num-samples', type=int, default=1)
@@ -199,8 +196,6 @@ if __name__ == "__main__":
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)
@@ -220,19 +215,18 @@ if __name__ == "__main__":
# Run all the task evaluations sequentially
results = {}
for task_name in task_names:
with autocast_ctx:
acc = run_chat_eval(
task_name,
model, tokenizer, engine,
batch_size=args.batch_size,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
max_problems=args.max_problems,
)
results[task_name] = acc
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
acc = run_chat_eval(
task_name,
model, tokenizer, engine,
batch_size=args.batch_size,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
max_problems=args.max_problems,
)
results[task_name] = acc
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
# Log to report
from nanochat.report import get_report
+11 -19
View File
@@ -22,8 +22,6 @@ import itertools
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
@@ -108,15 +103,14 @@ def get_batch():
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
masks.extend(masks_batch)
@@ -231,9 +225,8 @@ for step in range(num_steps):
if step % args.eval_every == 0:
model.eval()
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
@@ -268,8 +261,7 @@ for step in range(num_steps):
rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx:
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank
+24 -12
View File
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
@@ -151,6 +150,11 @@ if args.load_optimizer:
else:
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
# GradScaler for fp16 training (bf16/fp32 don't need it)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
print0("GradScaler enabled for fp16 training")
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
@@ -344,8 +348,7 @@ while True:
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with autocast_ctx:
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
if val_bpb < min_val_bpb:
min_val_bpb = val_bpb
@@ -373,9 +376,8 @@ while True:
for task_name in all_tasks:
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
max_problems = None if limit < 0 else limit # -1 means no limit
with autocast_ctx:
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
batch_size=args.device_batch_size, max_problems=max_problems)
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
batch_size=args.device_batch_size, max_problems=max_problems)
task_results[task_name] = acc
print0(f" {task_name}: {100*acc:.2f}%")
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
@@ -428,11 +430,13 @@ while True:
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
loss = model(x, y)
loss = model(x, y)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
# step the optimizer
@@ -442,7 +446,15 @@ while True:
group["lr"] = group["initial_lr"] * lrm
if group['kind'] == 'muon':
group["momentum"] = muon_momentum
optimizer.step()
if scaler is not None:
scaler.unscale_(optimizer)
if is_ddp_initialized():
for v in scaler._found_inf_per_device(optimizer).values():
dist.all_reduce(v, op=dist.ReduceOp.MAX)
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
model.zero_grad(set_to_none=True)
synchronize()
t1 = time.time()
+25 -33
View File
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from contextlib import nullcontext
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
class Worker:
@@ -93,7 +90,6 @@ class Worker:
device: torch.device
engine: Engine
tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
@@ -125,14 +121,11 @@ class WorkerPool:
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
)
self.workers.append(worker)
await self.available_workers.put(worker)
@@ -279,34 +272,33 @@ async def generate_stream(
# Track the last complete UTF-8 string (without replacement characters)
last_clean_text = ""
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]
for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]
# Stopping criteria
if token == assistant_end or token == bos:
break
# Stopping criteria
if token == assistant_end or token == bos:
break
# Append the token to sequence
accumulated_tokens.append(token)
# Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is a quite efficient operation, basically table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith(''):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text
# Append the token to sequence
accumulated_tokens.append(token)
# Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is a quite efficient operation, basically table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith(''):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"