# Listing metadata from the file viewer: 273 lines, 12 KiB, Python.
"""Minimal FP8 training for nanochat — tensorwise dynamic scaling only.

Drop-in replacement for torchao's Float8Linear (~2000 lines) in ~150 lines.
We only need the "tensorwise" recipe (one scalar scale per tensor), not the full
generality of torchao (rowwise scaling, FSDP float8 all-gather, DTensor, tensor
subclass dispatch tables, etc.).

How FP8 training works
======================
A standard Linear layer does one matmul in forward and two in backward:
    forward:   output      = input @ weight.T
    backward:  grad_input  = grad_output @ weight
               grad_weight = grad_output.T @ input

FP8 training wraps each of these three matmuls with:
    1. Compute scale = FP8_MAX / max(|tensor|) for each operand
    2. Quantize: fp8_tensor = clamp(tensor * scale, -FP8_MAX, FP8_MAX).to(fp8)
    3. Matmul via torch._scaled_mm (cuBLAS FP8 kernel, ~2x faster than bf16)
    4. Dequantize: _scaled_mm handles this internally using the inverse scales

The key insight: torch._scaled_mm and the float8 dtypes are PyTorch built-ins.
torchao is just orchestration around these primitives, so we can call them
directly.

FP8 dtype choice
================
There are two FP8 formats. We use both, following the standard convention:
  - float8_e4m3fn: 4-bit exponent, 3-bit mantissa, range [-448, 448]
    Higher precision (more mantissa bits), used for input and weight.
  - float8_e5m2:   5-bit exponent, 2-bit mantissa, range [-57344, 57344]
    Wider range (more exponent bits), used for gradients which can be large.

torch._scaled_mm layout requirements
====================================
The cuBLAS FP8 kernel requires specific memory layouts:
  - First argument (A):  must be row-major (contiguous)
  - Second argument (B): must be column-major (e.g. B.t().contiguous().t())
If B is obtained by transposing a contiguous tensor (e.g. weight.t()), it is
already column-major — no copy needed. Otherwise we use _to_col_major().

How this differs from torchao's approach
========================================
torchao uses a "tensor subclass" architecture: Float8TrainingTensor is a subclass
of torch.Tensor that bundles FP8 data + scale + metadata. It implements
__torch_dispatch__ with a dispatch table that intercepts every aten op (mm, t,
reshape, clone, ...) and handles it in an FP8-aware fashion. When you call
    output = input @ weight.T
the @ operator dispatches to aten.mm, which gets intercepted and routed to
torch._scaled_mm behind the scenes. This is ~2000 lines of code because you need
a handler for every tensor operation that might touch an FP8 tensor.

We take a simpler approach: a single autograd.Function (_Float8Matmul) that takes
full-precision inputs, quantizes to FP8 internally, calls _scaled_mm, and returns
full-precision outputs. It is marked @allow_in_graph so that torch.compile treats
it as one opaque node rather than trying to trace inside it.

The trade-off is in how torch.compile sees the two approaches:
  - torchao: compile decomposes the tensor subclass (via __tensor_flatten__) and
    sees every individual op (amax, scale, cast, _scaled_mm) as separate graph
    nodes. Inductor can fuse these with surrounding operations (e.g. fuse the
    amax computation with the preceding layer's activation function).
  - ours: compile sees a single opaque call. It can optimize everything around
    the FP8 linear (attention, norms, etc.) but cannot fuse across the boundary.

Both call the exact same cuBLAS _scaled_mm kernel — the GPU matmul is identical.
The difference is only in the "glue" ops (amax, scale, cast), which are tiny
compared to the matmul. In practice this means our version is slightly faster
(less compilation overhead, no tensor subclass dispatch cost) but can follow
subtly different floating-point rounding paths under torch.compile, since
Inductor generates a different graph. Numerics are bitwise identical in eager
mode.
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Avoid division by zero when computing scale from an all-zeros tensor
|
|
EPS = 1e-12
|
|
|
|
|
|
@torch.no_grad()
|
|
def _to_fp8(x, fp8_dtype):
|
|
"""Dynamically quantize a tensor to FP8 using tensorwise scaling.
|
|
|
|
"Tensorwise" means one scalar scale for the entire tensor (as opposed to
|
|
"rowwise" which computes a separate scale per row). Tensorwise is faster
|
|
because cuBLAS handles the scaling; rowwise needs the CUTLASS kernel.
|
|
|
|
Returns (fp8_data, inverse_scale) for use with torch._scaled_mm.
|
|
"""
|
|
fp8_max = torch.finfo(fp8_dtype).max
|
|
# Compute the max absolute value across the entire tensor
|
|
amax = x.float().abs().max()
|
|
# Scale maps [0, amax] -> [0, fp8_max]. Use float64 for the division to
|
|
# ensure consistent numerics between torch.compile and eager mode.
|
|
# (torchao does the same upcast — without it, compile/eager can diverge)
|
|
scale = fp8_max / amax.double().clamp(min=EPS)
|
|
scale = scale.float()
|
|
# Quantize: scale into FP8 range, saturate (clamp prevents overflow when
|
|
# casting — PyTorch's default is to wrap, not saturate), then cast to FP8
|
|
x_scaled = x.float() * scale
|
|
x_clamped = x_scaled.clamp(-fp8_max, fp8_max)
|
|
x_fp8 = x_clamped.to(fp8_dtype)
|
|
# _scaled_mm expects the *inverse* of our scale (it multiplies by this to
|
|
# convert FP8 values back to the original range during the matmul)
|
|
inv_scale = scale.reciprocal()
|
|
return x_fp8, inv_scale
|
|
|
|
|
|
def _to_col_major(x):
    """Return ``x`` with the same logical shape but column-major strides.

    torch._scaled_mm wants its second operand laid out column-major. Copying
    the transpose into contiguous memory and then transposing the copy back
    yields exactly that: an [M, N] tensor ends up with strides (1, M) rather
    than the row-major (N, 1).
    """
    # Materialize the transpose contiguously (copies unless it already is).
    transposed = x.t().contiguous()
    # Viewing the copy transposed restores the original shape.
    return transposed.t()
|
|
|
|
|
|
# allow_in_graph tells torch.compile to treat this as an opaque operation —
# dynamo won't try to decompose it into smaller ops. See the module docstring
# for how this differs from torchao's tensor subclass approach.
@torch._dynamo.allow_in_graph
class _Float8Matmul(torch.autograd.Function):
    """Custom autograd for the three FP8 GEMMs of a Linear layer.

    The forward saves input and weight in their original precision for the
    backward pass, and every GEMM quantizes its operands to FP8 from those
    full-precision tensors. (The forward's FP8 tensors are not reused in
    backward — the backward might want different precision, and saving FP8
    would lose information.)

    Layout handling: torch._scaled_mm needs a row-major first operand and a
    column-major second operand. Each operand is explicitly brought into the
    required layout; ``.contiguous()`` is a no-op for tensors that are already
    contiguous, so the common case pays nothing extra.
    """

    @staticmethod
    def forward(ctx, input_2d, weight):
        """Compute ``input_2d @ weight.T`` with FP8 operands.

        Args:
            ctx: Autograd context; receives ``input_2d`` and ``weight``.
            input_2d: 2D activation tensor, [B, K].
            weight: 2D weight tensor, [N, K].

        Returns:
            The [B, N] product in ``input_2d.dtype``.
        """
        ctx.save_for_backward(input_2d, weight)

        # Both operands use e4m3 (the higher-precision format).
        input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
        weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)

        # First arg must be row-major. Second arg must be column-major: the
        # transpose of a contiguous [N, K] tensor is [K, N] with strides
        # (1, K), so no copy happens when the weight is already contiguous.
        return torch._scaled_mm(
            input_fp8.contiguous(),
            weight_fp8.contiguous().t(),
            scale_a=input_inv,
            scale_b=weight_inv,
            out_dtype=input_2d.dtype,
            # use_fast_accum=True accumulates the dot products in lower
            # precision: slightly less accurate but measurably faster. The
            # backward uses False for more precise gradients.
            use_fast_accum=True,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``input_2d`` and ``weight`` with FP8 GEMMs.

        A gradient that autograd does not need (per ``ctx.needs_input_grad``)
        is skipped and returned as ``None``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient w.r.t. the forward output, [B, N].

        Returns:
            ``(grad_input, grad_weight)`` in ``grad_output.dtype``; either
            entry may be ``None`` when not required.
        """
        input_2d, weight = ctx.saved_tensors
        need_grad_input, need_grad_weight = ctx.needs_input_grad
        grad_input = None
        grad_weight = None

        # Gradients use e5m2 (wider range). Both GEMMs consume the same
        # quantized gradient, so it is computed once and shared.
        go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)

        if need_grad_input:
            # === GEMM 1: grad_input = grad_output @ weight ===
            # Shapes: [B, N] @ [N, K] -> [B, K]
            # Weights use e4m3 (higher precision).
            w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
            grad_input = torch._scaled_mm(
                go_fp8.contiguous(),      # [B, N] row-major
                _to_col_major(w_fp8),     # [N, K] column-major
                scale_a=go_inv,
                scale_b=w_inv,
                out_dtype=grad_output.dtype,
                use_fast_accum=False,
            )

        if need_grad_weight:
            # === GEMM 2: grad_weight = grad_output.T @ input ===
            # Shapes: [N, B] @ [B, K] -> [N, K]
            in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
            # The transposed gradient must be physically rearranged because
            # the first argument has to be row-major.
            grad_weight = torch._scaled_mm(
                go_fp8.t().contiguous(),  # [N, B] row-major
                _to_col_major(in_fp8),    # [B, K] column-major
                scale_a=go_inv,
                scale_b=in_inv,
                out_dtype=grad_output.dtype,
                use_fast_accum=False,
            )

        return grad_input, grad_weight
|
|
|
|
|
|
class Float8Linear(nn.Linear):
    """Drop-in nn.Linear replacement that does FP8 compute.

    Weights and biases remain in their original precision (e.g. fp32/bf16).
    Only the matmul is performed in FP8 via the _Float8Matmul autograd function.
    """

    def forward(self, input):
        """Apply the linear transform, running the matmul through FP8.

        Args:
            input: Tensor whose last dimension is the feature dimension; any
                leading dimensions are treated as batch dimensions.

        Returns:
            Tensor with the same leading dimensions as ``input`` and the
            matmul's output size as the last dimension.
        """
        # Replicate the autocast behavior of F.linear — when autocast is active,
        # we need to manually cast input to the autocast dtype (e.g. bf16),
        # since we bypass F.linear's built-in autocast handling.
        if torch.is_autocast_enabled():
            input = input.to(torch.get_autocast_gpu_dtype())

        # _scaled_mm only works on 2D tensors, so flatten batch dimensions
        # and restore them on the way out.
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = _Float8Matmul.apply(input_2d, self.weight)
        output = output.reshape(*orig_shape[:-1], output.shape[-1])

        if self.bias is not None:
            # The bias is added outside the FP8 matmul, in the output dtype.
            output = output + self.bias.to(output.dtype)
        return output

    @classmethod
    def from_float(cls, mod):
        """Create Float8Linear from nn.Linear, sharing the same weight and bias.

        Uses meta device to avoid allocating a temporary weight tensor — we
        create the module shell on meta (shapes/dtypes only, no memory), then
        point .weight and .bias to the original module's parameters. The
        train/eval flag of ``mod`` is carried over as well, so converting a
        model that is in eval mode does not leave the new layer in train mode.

        Args:
            mod: The nn.Linear to wrap.

        Returns:
            A new instance of ``cls`` whose ``weight`` and ``bias`` are the
            very same objects as ``mod.weight`` and ``mod.bias``.
        """
        with torch.device("meta"):
            new_mod = cls(mod.in_features, mod.out_features, bias=False)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        # A freshly constructed module always starts in training mode; mirror
        # the source module's mode instead.
        new_mod.training = mod.training
        return new_mod
|
|
|
|
|
|
class Float8LinearConfig:
    """Bare-bones stand-in for torchao's config class; tensorwise recipe only."""

    @staticmethod
    def from_recipe_name(recipe_name):
        """Build a config for the named recipe.

        Args:
            recipe_name: Name of the scaling recipe; must be "tensorwise".

        Returns:
            A new Float8LinearConfig.

        Raises:
            ValueError: If any recipe other than "tensorwise" is requested.
        """
        if recipe_name == "tensorwise":
            return Float8LinearConfig()

        message = (
            f"Only 'tensorwise' recipe is supported, got '{recipe_name}'. "
            "Rowwise/axiswise recipes require the full torchao library."
        )
        raise ValueError(message)
|
|
|
|
|
|
def convert_to_float8_training(module, *, config=None, module_filter_fn=None):
    """Replace nn.Linear layers with Float8Linear throughout a module.

    Walks the module tree in post-order (children before parents) and swaps
    each nn.Linear that passes the optional filter. The new Float8Linear shares
    the original weight and bias tensors — no copies, no extra memory.

    If ``module`` itself is an nn.Linear it cannot be swapped in place (there
    is no parent to reassign), so a converted replacement is returned instead.
    Callers should therefore always use the return value.

    Args:
        module: Root module to convert.
        config: Float8LinearConfig (accepted for API compat, only tensorwise supported).
        module_filter_fn: Optional filter(module, fqn) -> bool. Only matching Linears
            are converted. Common use: skip layers with dims not divisible by 16
            (hardware requirement for FP8 matmuls on H100). The root module is
            passed with an empty fqn.

    Returns:
        ``module`` with its eligible descendants swapped in place, or a new
        Float8Linear when ``module`` itself is an eligible nn.Linear.
    """
    def _should_convert(mod, fqn):
        # Plain nn.Linear only; layers that are already Float8Linear are left alone.
        if not isinstance(mod, nn.Linear) or isinstance(mod, Float8Linear):
            return False
        return module_filter_fn is None or module_filter_fn(mod, fqn)

    def _convert(mod, prefix=""):
        # Snapshot the children so reassigning attributes cannot disturb the walk.
        for name, child in list(mod.named_children()):
            fqn = f"{prefix}.{name}" if prefix else name
            _convert(child, fqn)
            if _should_convert(child, fqn):
                setattr(mod, name, Float8Linear.from_float(child))

    _convert(module)
    if _should_convert(module, ""):
        return Float8Linear.from_float(module)
    return module
|