Delete the configurator in favor of argparse, and clean up many kwarg details to make them more consistent across all scripts
This commit is contained in:
+69
-57
@@ -16,57 +16,69 @@ python -m scripts.chat_rl
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
||||
# RL hyperparameters
|
||||
run = "dummy" # wandb run name
|
||||
source = "sft" # mid|sft
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 8 # no forward pass will go above this to not OOM
|
||||
examples_per_step = 16 # in total and across all ranks (note: examples, not samples/completions!)
|
||||
num_samples = 16 # number of samples per example (/question)
|
||||
max_new_tokens = 256
|
||||
temperature = 1.0
|
||||
top_k = 50 # TODO: try None?
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
matrix_lr = 0.02
|
||||
weight_decay = 0.0
|
||||
init_lr_frac = 0.05
|
||||
num_epochs = 1 # how many epochs of gsm8k to train on
|
||||
save_every = 60 # every how many steps to save the model
|
||||
eval_every = 60 # every how many steps to evaluate the model for val pass@k
|
||||
eval_examples = 400 # number of examples used for evaluating pass@k
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# CLI arguments
|
||||
parser = argparse.ArgumentParser(description="Reinforcement learning on GSM8K")
|
||||
# Logging
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device_type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
||||
# Model loading
|
||||
parser.add_argument("--source", type=str, default="sft", help="mid|sft - which checkpoint to load from")
|
||||
parser.add_argument("--model_tag", type=str, default=None, help="model tag to load from")
|
||||
parser.add_argument("--model_step", type=int, default=None, help="model step to load from")
|
||||
# Training horizon
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="number of epochs over GSM8K")
|
||||
# Batch sizes / sampling
|
||||
parser.add_argument("--device_batch_size", type=int, default=8, help="max batch size per forward pass")
|
||||
parser.add_argument("--examples_per_step", type=int, default=16, help="total examples per optimization step across all ranks")
|
||||
parser.add_argument("--num_samples", type=int, default=16, help="number of samples per example/question")
|
||||
# Generation
|
||||
parser.add_argument("--max_new_tokens", type=int, default=256, help="max tokens to generate per sample")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="sampling temperature")
|
||||
parser.add_argument("--top_k", type=int, default=50, help="top-k sampling (0 = disabled)")
|
||||
# Optimization
|
||||
parser.add_argument("--embedding_lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
|
||||
parser.add_argument("--unembedding_lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
||||
parser.add_argument("--matrix_lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
||||
parser.add_argument("--init_lr_frac", type=float, default=0.05, help="initial LR as fraction of base LR")
|
||||
# Evaluation / checkpointing
|
||||
parser.add_argument("--eval_every", type=int, default=60, help="evaluate pass@k every N steps")
|
||||
parser.add_argument("--eval_examples", type=int, default=400, help="number of examples for pass@k evaluation")
|
||||
parser.add_argument("--save_every", type=int, default=60, help="save checkpoint every N steps")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy()
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Init compute/precision
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=args.run, config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.model_step)
|
||||
engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -74,7 +86,7 @@ engine = Engine(model, tokenizer) # for sampling rollouts
|
||||
|
||||
train_task = GSM8K(subset="main", split="train")
|
||||
val_task = GSM8K(subset="main", split="test")
|
||||
num_steps = (len(train_task) // examples_per_step) * num_epochs
|
||||
num_steps = (len(train_task) // args.examples_per_step) * args.num_epochs
|
||||
print0(f"Calculated number of steps: {num_steps}")
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -95,16 +107,16 @@ def get_batch():
|
||||
model.eval() # ensure the model is in eval mode
|
||||
generated_token_sequences = []
|
||||
masks = []
|
||||
num_sampling_steps = num_samples // device_batch_size # go sequentially to prevent OOMs
|
||||
num_sampling_steps = args.num_samples // args.device_batch_size # go sequentially to prevent OOMs
|
||||
for sampling_step in range(num_sampling_steps):
|
||||
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
|
||||
with autocast_ctx:
|
||||
generated_token_sequences_batch, masks_batch = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=device_batch_size,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
num_samples=args.device_batch_size,
|
||||
max_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
seed=seed, # must make sure to change the seed for each sampling step
|
||||
)
|
||||
generated_token_sequences.extend(generated_token_sequences_batch)
|
||||
@@ -191,16 +203,16 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
||||
|
||||
# Init the optimizer
|
||||
optimizers = model.setup_optimizers(
|
||||
unembedding_lr=unembedding_lr,
|
||||
embedding_lr=embedding_lr,
|
||||
matrix_lr=matrix_lr,
|
||||
weight_decay=weight_decay,
|
||||
unembedding_lr=args.unembedding_lr,
|
||||
embedding_lr=args.embedding_lr,
|
||||
matrix_lr=args.matrix_lr,
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
|
||||
# Set the initial learning rate as a fraction of the base learning rate
|
||||
for opt in optimizers:
|
||||
for group in opt.param_groups:
|
||||
group["lr"] = group["lr"] * init_lr_frac
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay easily later
|
||||
|
||||
# Learning rate scheduler: simple rampdown to zero over num_steps
|
||||
@@ -209,9 +221,9 @@ def get_lr_multiplier(it):
|
||||
return lrm
|
||||
|
||||
# Calculate the number of examples each rank handles to achieve the desired examples_per_step
|
||||
print0(f"Total sequences per step: {examples_per_step * num_samples}") # total batch size in sequences/step
|
||||
assert examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Total sequences per step: {args.examples_per_step * args.num_samples}") # total batch size in sequences/step
|
||||
assert args.examples_per_step % ddp_world_size == 0, "Desired examples per step must be divisible by the number of ranks"
|
||||
examples_per_rank = args.examples_per_step // ddp_world_size # per GPU
|
||||
print0(f"Calculated examples per rank: {examples_per_rank}")
|
||||
|
||||
# Kick off the training loop
|
||||
@@ -219,22 +231,22 @@ batch_iterator = get_batch()
|
||||
for step in range(num_steps):
|
||||
|
||||
# Evaluate the model once in a while and log to wandb
|
||||
if step % eval_every == 0:
|
||||
if step % args.eval_every == 0:
|
||||
model.eval()
|
||||
passk = torch.zeros(device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
passk = torch.zeros(args.device_batch_size, device=device) # pass@k for k=1..device_batch_size
|
||||
with autocast_ctx:
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=device_batch_size, max_examples=eval_examples, temperature=1.0)
|
||||
records_iter = run_gsm8k_eval(val_task, tokenizer, engine, num_samples=args.device_batch_size, max_examples=args.eval_examples, temperature=1.0)
|
||||
records = list(records_iter) # collect all records
|
||||
for k in range(1, device_batch_size + 1):
|
||||
for k in range(1, args.device_batch_size + 1):
|
||||
passk[k - 1] = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
|
||||
num_records = torch.tensor(len(records), dtype=torch.long, device=device)
|
||||
if ddp:
|
||||
dist.all_reduce(num_records, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(passk, op=dist.ReduceOp.SUM)
|
||||
passk = passk / num_records.item() # normalize by the total number of records
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, device_batch_size + 1)]
|
||||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, device_batch_size + 1)}
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
|
||||
@@ -249,11 +261,11 @@ for step in range(num_steps):
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // device_batch_size
|
||||
assert inputs_all.size(0) % args.device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // args.device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
# Pluck out the batch for this pass
|
||||
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
|
||||
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
rewards = rewards_all[b0:b1]
|
||||
@@ -306,10 +318,10 @@ for step in range(num_steps):
|
||||
})
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % save_every == 0) or step == num_steps - 1):
|
||||
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
|
||||
base_dir = get_base_dir()
|
||||
depth = model.config.n_layer
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # base the model tag on the depth of the base model
|
||||
checkpoint_dir = os.path.join(base_dir, "chatrl_checkpoints", output_dirname)
|
||||
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
|
||||
save_checkpoint(
|
||||
|
||||
Reference in New Issue
Block a user