delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
+24
-12
@@ -16,8 +16,7 @@ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
import wandb
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
@@ -75,7 +74,7 @@ user_config = vars(args).copy()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
if device_type == "cuda":
|
||||
@@ -151,6 +150,11 @@ if args.load_optimizer:
|
||||
else:
|
||||
print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")
|
||||
|
||||
# GradScaler for fp16 training (bf16/fp32 don't need it)
|
||||
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
|
||||
if scaler is not None:
|
||||
print0("GradScaler enabled for fp16 training")
|
||||
|
||||
# Override the initial learning rate as a fraction of the base learning rate
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["lr"] * args.init_lr_frac
|
||||
@@ -344,8 +348,7 @@ while True:
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with autocast_ctx:
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
@@ -373,9 +376,8 @@ while True:
|
||||
for task_name in all_tasks:
|
||||
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
|
||||
max_problems = None if limit < 0 else limit # -1 means no limit
|
||||
with autocast_ctx:
|
||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
|
||||
batch_size=args.device_batch_size, max_problems=max_problems)
|
||||
task_results[task_name] = acc
|
||||
print0(f" {task_name}: {100*acc:.2f}%")
|
||||
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
|
||||
@@ -428,11 +430,13 @@ while True:
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
loss = model(x, y)
|
||||
loss = model(x, y)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
# step the optimizer
|
||||
@@ -442,7 +446,15 @@ while True:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
if group['kind'] == 'muon':
|
||||
group["momentum"] = muon_momentum
|
||||
optimizer.step()
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
if is_ddp_initialized():
|
||||
for v in scaler._found_inf_per_device(optimizer).values():
|
||||
dist.all_reduce(v, op=dist.ReduceOp.MAX)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
Reference in New Issue
Block a user