initial commit
This commit is contained in:
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Utilities for saving and loading model/optim/state checkpoints.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
from nanochat.common import setup_default_logging
|
||||
|
||||
# Set up logging
|
||||
setup_default_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
def log0(message):
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
|
||||
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state (parameters)
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
torch.save(model_data, model_path)
|
||||
log0(f"Saved model file to: {model_path}")
|
||||
# Save the optimizer state (useful for SFT or any other fine-tuning)
|
||||
if optimizer_data is not None:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
torch.save(optimizer_data, optimizer_path)
|
||||
log0(f"Saved optimizer file to: {optimizer_path}")
|
||||
# Save the metadata dict as json
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
log0(f"Saved metadata file to: {meta_path}")
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "r") as f:
|
||||
meta_data = json.load(f)
|
||||
return model_data, optimizer_data, meta_data
|
||||
|
||||
|
||||
def build_model(checkpoint_dir, step, device, phase):
|
||||
"""
|
||||
A bunch of repetitive code to build a model from a given checkpoint.
|
||||
Returns:
|
||||
- base model - uncompiled, not wrapped in DDP
|
||||
- tokenizer
|
||||
- meta data saved during base model training
|
||||
"""
|
||||
assert phase in ["train", "eval"], f"Invalid phase: {phase}"
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
|
||||
# Hack: fix torch compile issue, which prepends all keys with _orig_mod.
|
||||
model_data = {k.lstrip("_orig_mod."): v for k, v in model_data.items()}
|
||||
model_config_kwargs = meta_data["model_config"]
|
||||
log0(f"Building model with config: {model_config_kwargs}")
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
with torch.device("meta"):
|
||||
model = GPT(model_config)
|
||||
# Load the model state
|
||||
model.to_empty(device=device)
|
||||
model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
# Put the model in the right training phase / mode
|
||||
if phase == "eval":
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
# Load the Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
# Sanity check: compatibility between model and tokenizer
|
||||
assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
|
||||
def find_largest_model(checkpoint_dir):
|
||||
# attempt to guess the model tag: take the biggest model available
|
||||
model_tags = [f for f in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, f))]
|
||||
if not model_tags:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
# 1) normally all model tags are of the form d<number>, try that first:
|
||||
candidates = []
|
||||
for model_tag in model_tags:
|
||||
match = re.match(r"d(\d+)", model_tag)
|
||||
if match:
|
||||
model_depth = int(match.group(1))
|
||||
candidates.append((model_depth, model_tag))
|
||||
if candidates:
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
# 2) if that failed, take the most recently updated model:
|
||||
candidates.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x[1])), reverse=True)
|
||||
return candidates[0][1]
|
||||
|
||||
|
||||
def find_last_step(checkpoint_dir):
|
||||
# Look into checkpoint_dir and find model_<step>.pt with the highest step
|
||||
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
|
||||
if not checkpoint_files:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
|
||||
return last_step
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# convenience functions that take into account nanochat's directory structure
|
||||
|
||||
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
|
||||
if model_tag is None:
|
||||
# guess the model tag by defaulting to the largest model
|
||||
model_tag = find_largest_model(checkpoints_dir)
|
||||
log0(f"No model tag provided, guessing model tag: {model_tag}")
|
||||
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
|
||||
if step is None:
|
||||
# guess the step by defaulting to the last step
|
||||
step = find_last_step(checkpoint_dir)
|
||||
assert step is not None, f"No checkpoints found in {checkpoint_dir}"
|
||||
# build the model
|
||||
log0(f"Loading model from {checkpoint_dir} with step {step}")
|
||||
model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
|
||||
return model, tokenizer, meta_data
|
||||
|
||||
def load_model(source, *args, **kwargs):
|
||||
model_dir = {
|
||||
"base": "base_checkpoints",
|
||||
"mid": "mid_checkpoints",
|
||||
"sft": "chatsft_checkpoints",
|
||||
"rl": "chatrl_checkpoints",
|
||||
}[source]
|
||||
base_dir = get_base_dir()
|
||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
||||
Reference in New Issue
Block a user