Merge branch 'Chetter2-patch-1'

This commit is contained in:
Andrej Karpathy
2026-02-18 23:17:39 +00:00
2 changed files with 14 additions and 18 deletions
+7 -13
View File
@@ -123,19 +123,16 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function): class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer. """Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward saves input and weight in their original precision for the The forward quantizes input and weight to FP8 and saves
backward pass. Each GEMM independently re-quantizes its operands to FP8. the quantized tensors + scales for backward.
(We don't reuse the forward's FP8 tensors in backward — the backward might
want different precision, and saving FP8 would lose information.)
""" """
@staticmethod @staticmethod
def forward(ctx, input_2d, weight): def forward(ctx, input_2d, weight):
ctx.save_for_backward(input_2d, weight)
# Quantize both operands to e4m3 (higher precision format) # Quantize both operands to e4m3 (higher precision format)
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn) input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn) weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
# output = input @ weight.T # output = input @ weight.T
# input_fp8 is [B, K] contiguous = row-major (good for first arg) # input_fp8 is [B, K] contiguous = row-major (good for first arg)
@@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
input_2d, weight = ctx.saved_tensors in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
# === GEMM 1: grad_input = grad_output @ weight === # === GEMM 1: grad_input = grad_output @ weight ===
# Shapes: [B, N] @ [N, K] -> [B, K] # Shapes: [B, N] @ [N, K] -> [B, K]
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision) # Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2) go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
# go_fp8 is [B, N] contiguous = row-major, good for first arg # go_fp8 is [B, N] contiguous = row-major, good for first arg
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg # w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
w_col = _to_col_major(w_fp8) w_col = _to_col_major(w_fp8)
@@ -177,17 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
# === GEMM 2: grad_weight = grad_output.T @ input === # === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K] # Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major, # Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory. # so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm( grad_weight = torch._scaled_mm(
go_T, go_T,
in_col, in_col,
scale_a=go_inv_2, scale_a=go_inv,
scale_b=in_inv, scale_b=in_inv,
out_dtype=grad_output.dtype, out_dtype=grad_output.dtype,
use_fast_accum=False, use_fast_accum=False,
+7 -5
View File
@@ -170,20 +170,22 @@ if args.fp8:
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement) # Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
if not isinstance(mod, nn.Linear): if not isinstance(mod, nn.Linear):
return False return False
# FP8 requires both in_features and out_features divisible by 16
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False return False
if min(mod.in_features, mod.out_features) < 128:
return False
return True return True
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers num_skipped = num_linear - num_fp8
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)") print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16 # Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager @contextmanager