a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps

This commit is contained in:
Andrej Karpathy
2026-02-16 14:41:53 +00:00
parent 2f09686724
commit 788dadeb88
3 changed files with 159 additions and 45 deletions
+19
View File
@@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
"""Load just the optimizer shard for a given rank, without re-loading the model."""
model_dir = {
"base": "base_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
if model_tag is None:
model_tag = find_largest_model(checkpoints_dir)
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
if step is None:
step = find_last_step(checkpoint_dir)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data