delete the configurator in favor of argparse and clean up a lot of kwarg details to make them more consistent across all scripts

This commit is contained in:
Andrej Karpathy
2026-01-04 19:14:23 +00:00
parent 507d54224a
commit eb7bbc1b66
9 changed files with 546 additions and 450 deletions
+69 -57
View File
@@ -16,57 +16,69 @@ python -m scripts.chat_rl
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
"""
import argparse
import os
import itertools
import re
import wandb
import torch
import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import save_checkpoint, load_model
from nanochat.engine import Engine
from tasks.gsm8k import GSM8K
# RL hyperparameters
run = "dummy" # wandb run name
source = "sft" # mid|sft
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
device_batch_size = 8 # no forward pass will go above this to not OOM
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
num_samples = 16 # number of samples per example (/question)
max_new_tokens = 256
temperature = 1.0
top_k = 50 # TODO: try None?
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.05
num_epochs = 1 # how many epochs of gsm8k to train on
save_every = 60 # every how many steps to save the model
eval_every = 60 # every how many steps to evaluate the model for val pass@k
eval_examples = 400 # number of examples used for evaluating pass@k
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
# Batch sizes / sampling
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
# Generation
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
# Optimization
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
# Evaluation / checkpointing
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
# Init compute/precision
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
# Init model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
engine = Engine(model, tokenizer) # for sampling rollouts
# -----------------------------------------------------------------------------
@@ -74,7 +86,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
train_task = GSM8K(subset="main", split="train")
val_task = GSM8K(subset="main", split="test")
num_steps = (len(train_task) // examples_per_step) * num_epochs
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
@@ -95,16 +107,16 @@ def get_batch():
model.eval() # ensure the model is in eval mode
generated_token_sequences = []
masks = []
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
for sampling_step in range(num_sampling_steps):
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
generated_token_sequences_batch, masks_batch = engine.generate_batch(
tokens,
num_samples=device_batch_size,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
num_samples=args.device_batch_size,
max_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
seed=seed, # must make sure to change the seed for each sampling step
)
generated_token_sequences.extend(generated_token_sequences_batch)
@@ -191,16 +203,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
# Init the optimizer
optimizers = model.setup_optimizers(
unembedding_lr=unembedding_lr,
embedding_lr=embedding_lr,
matrix_lr=matrix_lr,
weight_decay=weight_decay,
unembedding_lr=args.unembedding_lr,
embedding_lr=args.embedding_lr,
matrix_lr=args.matrix_lr,
weight_decay=args.weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["lr"] * init_lr_frac
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"] # save the initial learning so we can decay easily later
# Learning rate scheduler: simple rampdown to zero over num_steps
@@ -209,9 +221,9 @@ def get_lr_multiplier(it):
return lrm
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = examples_per_step // ddp_world_size # per GPU
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
print0(f"Calculated examples per rank: {examples_per_rank}")
# Kick off the training loop
@@ -219,22 +231,22 @@ batch_iterator = get_batch()
for step in range(num_steps):
# Evaluate the model once in a while and log to wandb
if step % eval_every == 0:
if step % args.eval_every == 0:
model.eval()
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
with autocast_ctx:
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
records = list(records_iter) # collect all records
for k in range(1, device_batch_size + 1):
for k in range(1, args.device_batch_size + 1):
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
if ddp:
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
passk = passk / num_records.item() # normalize by the total number of records
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
print0(f"Step {step} | {', '.join(print_passk)}")
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
wandb_run.log({
"step": step,
**log_passk,
@@ -249,11 +261,11 @@ for step in range(num_steps):
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
assert inputs_all.size(0) % device_batch_size == 0
num_passes = inputs_all.size(0) // device_batch_size
assert inputs_all.size(0) % args.device_batch_size == 0
num_passes = inputs_all.size(0) // args.device_batch_size
for pass_idx in range(num_passes):
# Pluck out the batch for this pass
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
rewards = rewards_all[b0:b1]
@@ -306,10 +318,10 @@ for step in range(num_steps):
})
# Master process saves the model once in a while. Skip first step. Save last step.
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
base_dir = get_base_dir()
depth = model.config.n_layer
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(