A number of upgrades to the SFT script to bring it up to date with pretraining, plus tuning of some of its kwargs based on sweeps.
This commit is contained in:
@@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
|
||||
base_dir = get_base_dir()
|
||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
||||
|
||||
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """Load one rank's optimizer shard from disk, leaving the model untouched.

    The shard is read from
    ``<base_dir>/<source checkpoints dir>/<model_tag>/optim_<step:06d>_rank<rank>.pt``.

    Args:
        source: One of "base", "sft" or "rl"; selects the checkpoints
            sub-directory under the base directory.
        device: Forwarded to ``torch.load`` as ``map_location``.
        rank: Integer rank whose shard file should be read.
        model_tag: Name of the model sub-directory. When None, the value
            returned by ``find_largest_model`` for the checkpoints
            directory is used.
        step: Integer step of the checkpoint. When None, the value returned
            by ``find_last_step`` for the model directory is used.

    Returns:
        Whatever ``torch.load`` deserializes from the shard file.

    Raises:
        KeyError: If ``source`` is not one of the supported keys.
    """
    # Resolve the source first so an unknown key fails before any other work.
    source_subdir = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }[source]
    checkpoints_root = os.path.join(get_base_dir(), source_subdir)

    # Fill in whichever of tag/step the caller left unspecified.
    tag = find_largest_model(checkpoints_root) if model_tag is None else model_tag
    tag_dir = os.path.join(checkpoints_root, tag)
    resolved_step = find_last_step(tag_dir) if step is None else step

    shard_name = f"optim_{resolved_step:06d}_rank{rank:d}.pt"
    shard_path = os.path.join(tag_dir, shard_name)
    log0(f"Loading optimizer state from {shard_path}")
    return torch.load(shard_path, map_location=device)
|
||||
|
||||
Reference in New Issue
Block a user