delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
+40
-15
@@ -19,14 +19,15 @@ import time
|
||||
import math
|
||||
import argparse
|
||||
from dataclasses import asdict
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from contextlib import contextmanager
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.gpt import GPT, GPTConfig, Linear
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
@@ -86,7 +87,6 @@ user_config = vars(args).copy() # for logging
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
@@ -95,17 +95,23 @@ if device_type == "cuda":
|
||||
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
|
||||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if HAS_FA3:
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||
else:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
@@ -213,9 +219,9 @@ def disable_fp8(model):
|
||||
yield # No FP8 modules, nothing to do
|
||||
return
|
||||
|
||||
# Swap Float8Linear -> nn.Linear (shares the same weight tensor, no copy)
|
||||
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
||||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
linear = nn.Linear(
|
||||
linear = Linear(
|
||||
fp8_module.in_features,
|
||||
fp8_module.out_features,
|
||||
bias=fp8_module.bias is not None,
|
||||
@@ -315,6 +321,12 @@ if resuming:
|
||||
optimizer.load_state_dict(optimizer_data)
|
||||
del optimizer_data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
|
||||
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||
if scaler is not None:
|
||||
print0("GradScaler enabled for fp16 training")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
@@ -405,7 +417,7 @@ while True:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with disable_fp8(model), autocast_ctx:
|
||||
with disable_fp8(model):
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
@@ -424,7 +436,7 @@ while True:
|
||||
results = {}
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
with disable_fp8(orig_model):
|
||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
@@ -451,7 +463,7 @@ while True:
|
||||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with disable_fp8(orig_model), autocast_ctx:
|
||||
with disable_fp8(orig_model):
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
@@ -491,11 +503,13 @@ while True:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
@@ -506,7 +520,18 @@ while True:
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
group["weight_decay"] = muon_weight_decay
|
||||
optimizer.step()
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
# In distributed training, all ranks must agree on whether to skip the step.
|
||||
# Each rank may independently encounter inf/nan gradients, so we all-reduce
|
||||
# the found_inf flag (MAX = if any rank found inf, all ranks skip).
|
||||
if is_ddp_initialized():
|
||||
for v in scaler._found_inf_per_device(optimizer).values():
|
||||
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
train_loss_f = train_loss.item() # .item() is a CPU-GPU sync point
|
||||
synchronize()
|
||||
|
||||
Reference in New Issue
Block a user