a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps
This commit is contained in:
@@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
|
|||||||
base_dir = get_base_dir()
|
base_dir = get_base_dir()
|
||||||
checkpoints_dir = os.path.join(base_dir, model_dir)
|
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||||
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
||||||
|
|
||||||
|
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
||||||
|
"""Load just the optimizer shard for a given rank, without re-loading the model."""
|
||||||
|
model_dir = {
|
||||||
|
"base": "base_checkpoints",
|
||||||
|
"sft": "chatsft_checkpoints",
|
||||||
|
"rl": "chatrl_checkpoints",
|
||||||
|
}[source]
|
||||||
|
base_dir = get_base_dir()
|
||||||
|
checkpoints_dir = os.path.join(base_dir, model_dir)
|
||||||
|
if model_tag is None:
|
||||||
|
model_tag = find_largest_model(checkpoints_dir)
|
||||||
|
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
|
||||||
|
if step is None:
|
||||||
|
step = find_last_step(checkpoint_dir)
|
||||||
|
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||||
|
log0(f"Loading optimizer state from {optimizer_path}")
|
||||||
|
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||||
|
return optimizer_data
|
||||||
|
|||||||
@@ -468,6 +468,7 @@ while True:
|
|||||||
"user_config": user_config, # inputs to the training script
|
"user_config": user_config, # inputs to the training script
|
||||||
"device_batch_size": args.device_batch_size,
|
"device_batch_size": args.device_batch_size,
|
||||||
"max_seq_len": args.max_seq_len,
|
"max_seq_len": args.max_seq_len,
|
||||||
|
"total_batch_size": total_batch_size,
|
||||||
"dataloader_state_dict": dataloader_state_dict,
|
"dataloader_state_dict": dataloader_state_dict,
|
||||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||||
"min_val_bpb": min_val_bpb,
|
"min_val_bpb": min_val_bpb,
|
||||||
|
|||||||
+128
-34
@@ -9,6 +9,7 @@ Or torchrun for training:
|
|||||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
|
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
@@ -16,12 +17,14 @@ import time
|
|||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
||||||
from nanochat.tokenizer import get_token_bytes
|
from nanochat.tokenizer import get_token_bytes
|
||||||
from nanochat.checkpoint_manager import save_checkpoint
|
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||||
from nanochat.loss_eval import evaluate_bpb
|
from nanochat.loss_eval import evaluate_bpb
|
||||||
from nanochat.checkpoint_manager import load_model
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from nanochat.flash_attention import HAS_FA3
|
||||||
|
from nanochat.engine import Engine
|
||||||
|
from scripts.chat_eval import run_chat_eval
|
||||||
|
|
||||||
from tasks.common import TaskMixture
|
from tasks.common import TaskMixture
|
||||||
from tasks.gsm8k import GSM8K
|
from tasks.gsm8k import GSM8K
|
||||||
@@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m
|
|||||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||||
# Runtime
|
# Runtime
|
||||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
|
|
||||||
# Model loading
|
# Model loading
|
||||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||||
|
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
|
||||||
# Training horizon
|
# Training horizon
|
||||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||||
# Batch sizes
|
# Batch sizes (default: inherit from pretrained checkpoint)
|
||||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
|
||||||
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
|
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
|
||||||
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
|
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
|
||||||
# Optimization
|
# Optimization (default: inherit from pretrained checkpoint)
|
||||||
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
|
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
|
||||||
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
|
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
|
||||||
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
|
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
|
||||||
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
|
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
|
||||||
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
|
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
|
||||||
|
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
|
||||||
|
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
|
||||||
# Evaluation
|
# Evaluation
|
||||||
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
|
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
|
||||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
|
||||||
# Output
|
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
|
||||||
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
|
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
|
||||||
|
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user_config = vars(args).copy()
|
user_config = vars(args).copy()
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -66,20 +72,48 @@ user_config = vars(args).copy()
|
|||||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
master_process = ddp_rank == 0
|
master_process = ddp_rank == 0
|
||||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
|
||||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||||
|
if device_type == "cuda":
|
||||||
|
gpu_device_name = torch.cuda.get_device_name(0)
|
||||||
|
gpu_peak_flops = get_peak_flops(gpu_device_name)
|
||||||
|
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||||
|
else:
|
||||||
|
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||||
|
|
||||||
# wandb logging init
|
# wandb logging init
|
||||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||||
|
|
||||||
|
# Flash Attention status
|
||||||
|
if not HAS_FA3:
|
||||||
|
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||||
pretrain_batch_size = meta.get("device_batch_size", None)
|
|
||||||
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
|
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
|
||||||
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
|
pretrain_user_config = meta.get("user_config", {})
|
||||||
|
for name, fallback, source in [
|
||||||
|
("max_seq_len", 2048, meta),
|
||||||
|
("device_batch_size", 32, meta),
|
||||||
|
("total_batch_size", 524288, meta),
|
||||||
|
("embedding_lr", 0.3, pretrain_user_config),
|
||||||
|
("unembedding_lr", 0.004, pretrain_user_config),
|
||||||
|
("matrix_lr", 0.02, pretrain_user_config),
|
||||||
|
]:
|
||||||
|
arg_val = getattr(args, name)
|
||||||
|
pretrain_val = source.get(name)
|
||||||
|
if arg_val is None:
|
||||||
|
resolved = pretrain_val if pretrain_val is not None else fallback
|
||||||
|
setattr(args, name, resolved)
|
||||||
|
print0(f"Inherited {name}={resolved} from pretrained checkpoint")
|
||||||
|
elif pretrain_val is not None and arg_val != pretrain_val:
|
||||||
|
print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
|
||||||
|
else:
|
||||||
|
print0(f"Using {name}={arg_val}")
|
||||||
|
|
||||||
orig_model = model
|
orig_model = model
|
||||||
model = torch.compile(model, dynamic=False)
|
model = torch.compile(model, dynamic=False)
|
||||||
depth = model.config.n_layer
|
depth = model.config.n_layer
|
||||||
@@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste
|
|||||||
token_bytes = get_token_bytes(device=device)
|
token_bytes = get_token_bytes(device=device)
|
||||||
|
|
||||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
|
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
|
||||||
|
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
||||||
|
|
||||||
|
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
||||||
|
base_dir = get_base_dir()
|
||||||
|
if args.load_optimizer:
|
||||||
|
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
|
||||||
|
optimizer.load_state_dict(optimizer_data)
|
||||||
|
del optimizer_data
|
||||||
|
print0("Loaded optimizer state from pretrained checkpoint")
|
||||||
|
|
||||||
# Override the initial learning rate as a fraction of the base learning rate
|
# Override the initial learning rate as a fraction of the base learning rate
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group["lr"] = group["lr"] * args.init_lr_frac
|
group["lr"] = group["lr"] * args.init_lr_frac
|
||||||
group["initial_lr"] = group["lr"]
|
group["initial_lr"] = group["lr"]
|
||||||
|
|
||||||
# SFT data mixture and DataLoader
|
# SFT data mixture and DataLoader
|
||||||
base_dir = get_base_dir()
|
|
||||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||||
train_dataset = TaskMixture([
|
train_dataset = TaskMixture([
|
||||||
SmolTalk(split="train"), # 460K rows of general conversations
|
SmolTalk(split="train"), # 460K rows of general conversations
|
||||||
@@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train")
|
|||||||
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
|
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
|
||||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||||
|
|
||||||
# Learning rate scheduler
|
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||||
|
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
|
||||||
|
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
|
||||||
def get_lr_multiplier(progress):
|
def get_lr_multiplier(progress):
|
||||||
# first 80% of training: no decay, then linearly ramp down to 0.
|
if progress < args.warmup_ratio:
|
||||||
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
|
return (progress + 1e-8) / args.warmup_ratio
|
||||||
|
elif progress <= 1.0 - args.warmdown_ratio:
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
|
||||||
|
return (1 - decay) * 1.0 + decay * args.final_lr_frac
|
||||||
|
|
||||||
# Momentum scheduler for Muon optimizer
|
# Momentum scheduler for Muon optimizer
|
||||||
def get_muon_momentum(it):
|
def get_muon_momentum(it):
|
||||||
@@ -282,8 +332,44 @@ while True:
|
|||||||
})
|
})
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
# save checkpoint at the end of the run (only on master process)
|
# once in a while: estimate the ChatCORE metric (all ranks participate)
|
||||||
if master_process and last_step and not args.dry_run:
|
# use the original uncompiled model because the inputs keep changing shape
|
||||||
|
chatcore_results = {}
|
||||||
|
if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
|
||||||
|
model.eval()
|
||||||
|
engine = Engine(orig_model, tokenizer)
|
||||||
|
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
|
||||||
|
categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
|
||||||
|
baseline_accuracies = {
|
||||||
|
'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
|
||||||
|
'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
|
||||||
|
}
|
||||||
|
task_results = {}
|
||||||
|
for task_name in all_tasks:
|
||||||
|
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
||||||
|
max_problems = None if limit < 0 else limit # -1 means no limit
|
||||||
|
with autocast_ctx:
|
||||||
|
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||||
|
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||||
|
task_results[task_name] = acc
|
||||||
|
print0(f" {task_name}: {100*acc:.2f}%")
|
||||||
|
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
||||||
|
def centered_mean(tasks):
|
||||||
|
return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
|
||||||
|
chatcore = centered_mean(all_tasks)
|
||||||
|
chatcore_cat = centered_mean(categorical_tasks)
|
||||||
|
print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
|
||||||
|
wandb_run.log({
|
||||||
|
"step": step,
|
||||||
|
"total_training_flops": flops_so_far,
|
||||||
|
"chatcore_metric": chatcore,
|
||||||
|
"chatcore_cat": chatcore_cat,
|
||||||
|
**{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
|
||||||
|
})
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
|
||||||
|
if last_step:
|
||||||
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
|
||||||
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
@@ -304,7 +390,8 @@ while True:
|
|||||||
"window_pattern": model.config.window_pattern,
|
"window_pattern": model.config.window_pattern,
|
||||||
},
|
},
|
||||||
"user_config": user_config, # inputs to the training script
|
"user_config": user_config, # inputs to the training script
|
||||||
}
|
},
|
||||||
|
rank=ddp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if last_step:
|
if last_step:
|
||||||
@@ -346,8 +433,7 @@ while True:
|
|||||||
pct_done = 100 * progress
|
pct_done = 100 * progress
|
||||||
tok_per_sec = int(args.total_batch_size / dt)
|
tok_per_sec = int(args.total_batch_size / dt)
|
||||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||||
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
|
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||||
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
|
|
||||||
if step > 10:
|
if step > 10:
|
||||||
total_training_time += dt # only count the time after the first 10 steps
|
total_training_time += dt # only count the time after the first 10 steps
|
||||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||||
@@ -364,13 +450,21 @@ while True:
|
|||||||
"train/epoch": current_epoch,
|
"train/epoch": current_epoch,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# The garbage collector spends ~500ms scanning for cycles quite frequently.
|
||||||
|
# We manually manage it to avoid these pauses during training.
|
||||||
|
if step == 1:
|
||||||
|
gc.collect() # manually collect a lot of garbage from setup
|
||||||
|
gc.freeze() # freeze all currently surviving objects and exclude them from GC
|
||||||
|
gc.disable() # disable GC entirely except:
|
||||||
|
elif step % 5000 == 0: # every 5000 steps...
|
||||||
|
gc.collect() # manually collect, just to be safe for very long runs
|
||||||
|
|
||||||
# print a few more stats
|
# print a few more stats
|
||||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||||
|
|
||||||
# Log to report
|
# Log to report
|
||||||
if not args.dry_run:
|
|
||||||
from nanochat.report import get_report
|
from nanochat.report import get_report
|
||||||
get_report().log(section="SFT", data=[
|
get_report().log(section="SFT", data=[
|
||||||
user_config, # CLI args
|
user_config, # CLI args
|
||||||
|
|||||||
Reference in New Issue
Block a user