Save FP8 tensors in autograd ctx instead of full-precision inputs
Store quantized input/weight and their inverse scales in _Float8Matmul ctx to avoid re-quantization in backward and reduce saved-activation memory without changing numerics.
This commit is contained in:
+4
-9
@@ -123,19 +123,16 @@ def _to_col_major(x):
|
|||||||
class _Float8Matmul(torch.autograd.Function):
|
class _Float8Matmul(torch.autograd.Function):
|
||||||
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
|
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
|
||||||
|
|
||||||
The forward saves input and weight in their original precision for the
|
The forward quantizes input and weight to FP8 and saves
|
||||||
backward pass. Each GEMM independently re-quantizes its operands to FP8.
|
the quantized tensors + scales for backward.
|
||||||
(We don't reuse the forward's FP8 tensors in backward — the backward might
|
|
||||||
want different precision, and saving FP8 would lose information.)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input_2d, weight):
|
def forward(ctx, input_2d, weight):
|
||||||
ctx.save_for_backward(input_2d, weight)
|
|
||||||
|
|
||||||
# Quantize both operands to e4m3 (higher precision format)
|
# Quantize both operands to e4m3 (higher precision format)
|
||||||
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
|
input_fp8, input_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
|
||||||
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
|
weight_fp8, weight_inv = _to_fp8(weight, torch.float8_e4m3fn)
|
||||||
|
ctx.save_for_backward(input_fp8, input_inv, weight_fp8, weight_inv)
|
||||||
|
|
||||||
# output = input @ weight.T
|
# output = input @ weight.T
|
||||||
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
|
# input_fp8 is [B, K] contiguous = row-major (good for first arg)
|
||||||
@@ -156,13 +153,12 @@ class _Float8Matmul(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_2d, weight = ctx.saved_tensors
|
in_fp8, in_inv, w_fp8, w_inv = ctx.saved_tensors
|
||||||
|
|
||||||
# === GEMM 1: grad_input = grad_output @ weight ===
|
# === GEMM 1: grad_input = grad_output @ weight ===
|
||||||
# Shapes: [B, N] @ [N, K] -> [B, K]
|
# Shapes: [B, N] @ [N, K] -> [B, K]
|
||||||
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
|
# Gradients use e5m2 (wider range), weights use e4m3 (higher precision)
|
||||||
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
|
go_fp8, go_inv = _to_fp8(grad_output, torch.float8_e5m2)
|
||||||
w_fp8, w_inv = _to_fp8(weight, torch.float8_e4m3fn)
|
|
||||||
# go_fp8 is [B, N] contiguous = row-major, good for first arg
|
# go_fp8 is [B, N] contiguous = row-major, good for first arg
|
||||||
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
|
# w_fp8 is [N, K] contiguous = row-major, need column-major for second arg
|
||||||
w_col = _to_col_major(w_fp8)
|
w_col = _to_col_major(w_fp8)
|
||||||
@@ -178,7 +174,6 @@ class _Float8Matmul(torch.autograd.Function):
|
|||||||
# === GEMM 2: grad_weight = grad_output.T @ input ===
|
# === GEMM 2: grad_weight = grad_output.T @ input ===
|
||||||
# Shapes: [N, B] @ [B, K] -> [N, K]
|
# Shapes: [N, B] @ [B, K] -> [N, K]
|
||||||
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
|
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
|
||||||
in_fp8, in_inv = _to_fp8(input_2d, torch.float8_e4m3fn)
|
|
||||||
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
||||||
# Transposing gives column-major, but first arg needs row-major,
|
# Transposing gives column-major, but first arg needs row-major,
|
||||||
# so we must call .contiguous() to physically rearrange the memory.
|
# so we must call .contiguous() to physically rearrange the memory.
|
||||||
|
|||||||
Reference in New Issue
Block a user