tune the data mixture a bit, load optimizer by default when SFT. These were confirmed to be best settings from sweeps of sft
This commit is contained in:
@@ -186,6 +186,9 @@ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
|
|||||||
if step is None:
|
if step is None:
|
||||||
step = find_last_step(checkpoint_dir)
|
step = find_last_step(checkpoint_dir)
|
||||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||||
|
if not os.path.exists(optimizer_path):
|
||||||
|
log0(f"Optimizer checkpoint not found: {optimizer_path}")
|
||||||
|
return None
|
||||||
log0(f"Loading optimizer state from {optimizer_path}")
|
log0(f"Loading optimizer state from {optimizer_path}")
|
||||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||||
return optimizer_data
|
return optimizer_data
|
||||||
|
|||||||
+21
-8
@@ -43,7 +43,7 @@ parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (e
|
|||||||
# Model loading
|
# Model loading
|
||||||
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
|
||||||
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
|
||||||
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
|
parser.add_argument("--load-optimizer", type=int, default=1, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
|
||||||
# Training horizon
|
# Training horizon
|
||||||
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
|
||||||
# Batch sizes (default: inherit from pretrained checkpoint)
|
# Batch sizes (default: inherit from pretrained checkpoint)
|
||||||
@@ -64,6 +64,9 @@ parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number o
|
|||||||
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
|
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
|
||||||
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
|
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
|
||||||
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
|
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
|
||||||
|
# Data mixture
|
||||||
|
parser.add_argument("--mmlu-epochs", type=int, default=3, help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
|
||||||
|
parser.add_argument("--gsm8k-epochs", type=int, default=4, help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user_config = vars(args).copy()
|
user_config = vars(args).copy()
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -132,12 +135,21 @@ token_bytes = get_token_bytes(device=device)
|
|||||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
||||||
|
|
||||||
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
||||||
|
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
|
||||||
|
# pretrained values. Since pretraining warmdown brings LRs to ~0, we must save and
|
||||||
|
# restore our fresh SFT LRs after loading.
|
||||||
base_dir = get_base_dir()
|
base_dir = get_base_dir()
|
||||||
if args.load_optimizer:
|
if args.load_optimizer:
|
||||||
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
|
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
|
||||||
|
if optimizer_data is not None:
|
||||||
|
base_lrs = [group["lr"] for group in optimizer.param_groups]
|
||||||
optimizer.load_state_dict(optimizer_data)
|
optimizer.load_state_dict(optimizer_data)
|
||||||
del optimizer_data
|
del optimizer_data
|
||||||
print0("Loaded optimizer state from pretrained checkpoint")
|
for group, base_lr in zip(optimizer.param_groups, base_lrs):
|
||||||
|
group["lr"] = base_lr
|
||||||
|
print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs reset)")
|
||||||
|
else:
|
||||||
|
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
||||||
|
|
||||||
# Override the initial learning rate as a fraction of the base learning rate
|
# Override the initial learning rate as a fraction of the base learning rate
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
@@ -146,16 +158,17 @@ for group in optimizer.param_groups:
|
|||||||
|
|
||||||
# SFT data mixture and DataLoader
|
# SFT data mixture and DataLoader
|
||||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||||
train_dataset = TaskMixture([
|
train_tasks = [
|
||||||
SmolTalk(split="train"), # 460K rows of general conversations
|
SmolTalk(split="train"), # 460K rows of general conversations
|
||||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
|
||||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
|
||||||
GSM8K(subset="main", split="train"), # 2 epochs of GSM8K
|
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
|
||||||
|
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
|
||||||
|
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
|
||||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||||
]) # total: 460K + 100K + 16K + 200K + 80K = 856K rows
|
]
|
||||||
|
train_dataset = TaskMixture(train_tasks)
|
||||||
|
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
|
||||||
val_dataset = TaskMixture([
|
val_dataset = TaskMixture([
|
||||||
SmolTalk(split="test"), # 24K rows in test set
|
SmolTalk(split="test"), # 24K rows in test set
|
||||||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||||
|
|||||||
Reference in New Issue
Block a user