add fp8 training with torchao
This commit is contained in:
+91
-8
@@ -16,7 +16,7 @@ import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from contextlib import nullcontext, contextmanager
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
@@ -39,6 +39,9 @@ parser = argparse.ArgumentParser(description="Pretrain base model")
|
||||
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
|
||||
# Runtime
|
||||
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
|
||||
# FP8 training
|
||||
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)")
|
||||
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
|
||||
# Model architecture
|
||||
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
|
||||
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
|
||||
@@ -65,7 +68,7 @@ parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR a
|
||||
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
|
||||
# Evaluation
|
||||
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
|
||||
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
|
||||
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
|
||||
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
|
||||
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
|
||||
@@ -177,11 +180,11 @@ if resuming:
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
# -----------------------------------------------------------------------------
|
||||
# Determine the length of the training run based on model size
|
||||
|
||||
# Detailed parameter counts
|
||||
param_counts = orig_model.num_scaling_params()
|
||||
param_counts = model.num_scaling_params()
|
||||
print0(f"Parameter counts:")
|
||||
for key, value in param_counts.items():
|
||||
print0(f"{key:24s}: {value:,}")
|
||||
@@ -211,6 +214,85 @@ print0(f"Total number of training tokens: {total_tokens:,}")
|
||||
print0(f"Tokens : Scaling params ratio: {args.total_batch_size * num_iterations / num_scaling_params:.2f}") # Chinchilla is ~20
|
||||
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# FP8 training initialization and management (has to be done before torch.compile)
|
||||
|
||||
# Convert Linear layers to Float8Linear if --fp8 is set
|
||||
if args.fp8:
|
||||
if device_type != "cuda":
|
||||
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
|
||||
else:
|
||||
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
import torch.nn as nn
|
||||
|
||||
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
|
||||
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
||||
if not isinstance(mod, nn.Linear):
|
||||
return False
|
||||
# FP8 requires both in_features and out_features divisible by 16
|
||||
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
||||
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
||||
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
|
||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
|
||||
|
||||
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
|
||||
@contextmanager
|
||||
def disable_fp8(model):
|
||||
"""Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
|
||||
|
||||
CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
|
||||
we swap out Float8Linear modules entirely and restore them after.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
|
||||
# Find all Float8Linear modules and their locations
|
||||
fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
|
||||
for name, module in model.named_modules():
|
||||
if 'Float8' in type(module).__name__:
|
||||
if '.' in name:
|
||||
parent_name, attr_name = name.rsplit('.', 1)
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent = model
|
||||
attr_name = name
|
||||
fp8_locations.append((parent, attr_name, module))
|
||||
|
||||
if not fp8_locations:
|
||||
yield # No FP8 modules, nothing to do
|
||||
return
|
||||
|
||||
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
|
||||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
linear = nn.Linear(
|
||||
fp8_module.in_features,
|
||||
fp8_module.out_features,
|
||||
bias=fp8_module.bias is not None,
|
||||
device=fp8_module.weight.device,
|
||||
dtype=fp8_module.weight.dtype,
|
||||
)
|
||||
linear.weight = fp8_module.weight # share, don't copy
|
||||
if fp8_module.bias is not None:
|
||||
linear.bias = fp8_module.bias
|
||||
setattr(parent, attr_name, linear)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore Float8Linear modules
|
||||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
setattr(parent, attr_name, fp8_module)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compile the model
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
adam_betas = (args.adam_beta1, args.adam_beta2)
|
||||
@@ -287,7 +369,7 @@ while True:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
with disable_fp8(model), autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
@@ -302,10 +384,11 @@ while True:
|
||||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
# disable FP8 for evaluation to use BF16 for more consistent/accurate results
|
||||
results = {}
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
@@ -332,7 +415,7 @@ while True:
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with autocast_ctx:
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
Reference in New Issue
Block a user