delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+11 -19
View File
@@ -22,8 +22,6 @@ import itertools
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
@@ -36,7 +34,6 @@ parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
@@ -68,8 +65,6 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
@@ -108,15 +103,14 @@ def get_batch():
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
masks.extend(masks_batch)
@@ -231,9 +225,8 @@ for step in range(num_steps):
if step % args.eval_every == 0:
model.eval()
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
@@ -268,8 +261,7 @@ for step in range(num_steps):
rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx:
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank