big change: add pretraining resumption logic so that checkpoints can now be approximately resumed and training can continue. this is useful for very long runs when you don't want the anxiety of your run crashing for some reason. alternatively, it's a way to recover training in the event of loss spikes. i mean, this should have been there in v0 but it's ok. the resumption is approximate to control complexity and bloat, but it's possible we want to change that in the future. to use, set --save_every to a step interval to write checkpoints with, and then use --resume_from_step to resume optimization from a given step. only base model training (pretraining) supports this atm, but it's ok because midtraining is comparably quite a bit faster.

This commit is contained in:
Andrej Karpathy
2025-11-13 15:34:40 +00:00
parent 91f09ccd0d
commit c6abcdfe3a
4 changed files with 145 additions and 63 deletions
+17 -18
View File
@@ -20,33 +20,32 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state (parameters)
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
log0(f"Saved model file to: {model_path}")
# Save the optimizer state (useful for SFT or any other fine-tuning)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state parameters
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
logger.info(f"Saved model parameters to: {model_path}")
# Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Saved metadata to: {meta_path}")
# Note that optimizer state is sharded across ranks, so each rank must save its own.
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
log0(f"Saved optimizer file to: {optimizer_path}")
# Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
log0(f"Saved metadata file to: {meta_path}")
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")