delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
+4
-12
@@ -29,8 +29,6 @@ import random
|
||||
import zipfile
|
||||
import tempfile
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
|
||||
@@ -199,8 +197,6 @@ def main():
|
||||
# Distributed / precision setup
|
||||
device_type = autodetect_device_type() if args.device_type == '' else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Load model and tokenizer
|
||||
is_hf_model = args.hf_path is not None
|
||||
if is_hf_model:
|
||||
@@ -244,8 +240,7 @@ def main():
|
||||
print0("\nConditioned samples:")
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
sample_str = tokenizer.decode(sample[0])
|
||||
print0("-" * 80)
|
||||
print0(sample_str)
|
||||
@@ -253,8 +248,7 @@ def main():
|
||||
|
||||
print0("\nUnconditioned samples:")
|
||||
tokens = tokenizer("", prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
|
||||
for sample in uncond:
|
||||
sample_str = tokenizer.decode(sample)
|
||||
print0("-" * 80)
|
||||
@@ -277,8 +271,7 @@ def main():
|
||||
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"{split_name} bpb: {bpb:.6f}")
|
||||
|
||||
@@ -287,8 +280,7 @@ def main():
|
||||
print0("\n" + "="*80)
|
||||
print0("CORE Evaluation")
|
||||
print0("="*80)
|
||||
with autocast_ctx:
|
||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
core_results = evaluate_core(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write CSV output
|
||||
if ddp_rank == 0:
|
||||
|
||||
+40
-15
@@ -19,14 +19,15 @@ import time
|
||||
import math
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from contextlib import contextmanager
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.gpt import GPT, GPTConfig, Linear
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
@@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
@@ -95,17 +95,23 @@ if device_type == "cuda":
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||
else:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
@@ -213,9 +219,9 @@ def disable_fp8(model):
|
||||
yield # No FP8 modules, nothing to do
|
||||
return
|
||||
|
||||
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
|
||||
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
||||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
linear = nn.Linear(
|
||||
linear = Linear(
|
||||
fp8_module.in_features,
|
||||
fp8_module.out_features,
|
||||
bias=fp8_module.bias is not None,
|
||||
@@ -315,6 +321,12 @@ if resuming:
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
|
||||
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||
if scaler is not None:
|
||||
print0("GradScaler enabled for fp16 training")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
@@ -405,7 +417,7 @@ while True:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with disable_fp8(model), autocast_ctx:
|
||||
with disable_fp8(model):
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
@@ -424,7 +436,7 @@ while True:
|
||||
results = {}
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
with disable_fp8(orig_model):
|
||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
@@ -451,7 +463,7 @@ while True:
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
with disable_fp8(orig_model):
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
@@ -491,11 +503,13 @@ while True:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
@@ -506,7 +520,18 @@ while True:
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
# In distributed training, all ranks must agree on whether to skip the step.
|
||||
# Each rank may independently encounter inf/nan gradients, so we all-reduce
|
||||
# the found_inf flag (MAX = if any rank found inf, all ranks skip).
|
||||
if is_ddp_initialized():
|
||||
for v in scaler._found_inf_per_device(optimizer).values():
|
||||
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
|
||||
+5
-10
@@ -7,7 +7,6 @@ python -m scripts.chat_cli
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from contextlib import nullcontext
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
@@ -19,15 +18,12 @@ parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the mod
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
@@ -87,12 +83,11 @@ while True:
|
||||
}
|
||||
response_tokens = []
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
with autocast_ctx:
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
print(token_text, end="", flush=True)
|
||||
for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
|
||||
token = token_column[0] # pop the batch dimension (num_samples=1)
|
||||
response_tokens.append(token)
|
||||
token_text = tokenizer.decode([token])
|
||||
print(token_text, end="", flush=True)
|
||||
print()
|
||||
# we have to ensure that the assistant end token is the last token
|
||||
# so even if generation ends due to max tokens, we have to append it to the end
|
||||
|
||||
+12
-18
@@ -10,8 +10,6 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -185,7 +183,6 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
|
||||
parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.0)
|
||||
parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
|
||||
parser.add_argument('-n', '--num-samples', type=int, default=1)
|
||||
@@ -199,8 +196,6 @@ if __name__ == "__main__":
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
@@ -220,19 +215,18 @@ if __name__ == "__main__":
|
||||
# Run all the task evaluations sequentially
|
||||
results = {}
|
||||
for task_name in task_names:
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_problems=args.max_problems,
|
||||
)
|
||||
results[task_name] = acc
|
||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||
acc = run_chat_eval(
|
||||
task_name,
|
||||
model, tokenizer, engine,
|
||||
batch_size=args.batch_size,
|
||||
num_samples=args.num_samples,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_problems=args.max_problems,
|
||||
)
|
||||
results[task_name] = acc
|
||||
print0(f"{task_name} accuracy: {100 * acc:.2f}%")
|
||||
|
||||
# Log to report
|
||||
from nanochat.report import get_report
|
||||
|
||||
+11
-19
@@ -22,8 +22,6 @@ import itertools
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
@@ -108,15 +103,14 @@ def get_batch():
|
||||
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
masks.extend(masks_batch)
|
||||
|
||||
@@ -231,9 +225,8 @@ for step in range(num_steps):
|
||||
if step % args.eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, args.device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
@@ -268,8 +261,7 @@ for step in range(num_steps):
|
||||
rewards = rewards_all[b0:b1]
|
||||
advantages = advantages_all[b0:b1]
|
||||
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
||||
with autocast_ctx:
|
||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
||||
|
||||
+24
-12
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
@@ -151,6 +150,11 @@ if args.load_optimizer:
|
||||
else:
|
||||
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
||||
|
||||
# GradScaler for fp16 training (bf16/fp32 don't need it)
|
||||
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||
if scaler is not None:
|
||||
print0("GradScaler enabled for fp16 training")
|
||||
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
@@ -344,8 +348,7 @@ while True:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
@@ -373,9 +376,8 @@ while True:
|
||||
for task_name in all_tasks:
|
||||
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
||||
max_problems = None if limit < 0 else limit # -1 means no limit
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||
task_results[task_name] = acc
|
||||
print0(f" {task_name}: {100*acc:.2f}%")
|
||||
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
||||
@@ -428,11 +430,13 @@ while True:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizer
|
||||
@@ -442,7 +446,15 @@ while True:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
optimizer.step()
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
if is_ddp_initialized():
|
||||
for v in scaler._found_inf_per_device(optimizer).values():
|
||||
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
+25
-33
@@ -44,7 +44,6 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
@@ -69,7 +68,6 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
||||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
@@ -84,7 +82,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
@@ -93,7 +90,6 @@ class Worker:
|
||||
device: torch.device
|
||||
engine: Engine
|
||||
tokenizer: object
|
||||
autocast_ctx: torch.amp.autocast
|
||||
|
||||
class WorkerPool:
|
||||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
@@ -125,14 +121,11 @@ class WorkerPool:
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
device=device,
|
||||
engine=engine,
|
||||
tokenizer=tokenizer,
|
||||
autocast_ctx=autocast_ctx
|
||||
)
|
||||
self.workers.append(worker)
|
||||
await self.available_workers.put(worker)
|
||||
@@ -279,34 +272,33 @@ async def generate_stream(
|
||||
# Track the last complete UTF-8 string (without replacement characters)
|
||||
last_clean_text = ""
|
||||
|
||||
with worker.autocast_ctx:
|
||||
for token_column, token_masks in worker.engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
for token_column, token_masks in worker.engine.generate(
|
||||
tokens,
|
||||
num_samples=1,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
seed=random.randint(0, 2**31 - 1)
|
||||
):
|
||||
token = token_column[0]
|
||||
|
||||
# Stopping criteria
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
# Stopping criteria
|
||||
if token == assistant_end or token == bos:
|
||||
break
|
||||
|
||||
# Append the token to sequence
|
||||
accumulated_tokens.append(token)
|
||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
# Append the token to sequence
|
||||
accumulated_tokens.append(token)
|
||||
# Decode all accumulated tokens to get proper UTF-8 handling
|
||||
# Note that decode is a quite efficient operation, basically table lookup and string concat
|
||||
current_text = worker.tokenizer.decode(accumulated_tokens)
|
||||
# Only emit text if it doesn't end with a replacement character
|
||||
# This ensures we don't emit incomplete UTF-8 sequences
|
||||
if not current_text.endswith('�'):
|
||||
# Extract only the new text since last clean decode
|
||||
new_text = current_text[len(last_clean_text):]
|
||||
if new_text: # Only yield if there's new content
|
||||
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
|
||||
last_clean_text = current_text
|
||||
|
||||
yield f"data: {json.dumps({'done': True})}\n\n"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user