Removed redundant qunatization of gradients

This commit is contained in:
Alan
2026-02-15 15:41:33 +00:00
committed by GitHub
parent d9678ff0f9
commit 124f49be98
+3 -4
View File
@@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
# === GEMM 2: grad_weight = grad_output.T @ input === # === GEMM 2: grad_weight = grad_output.T @ input ===
# Shapes: [N, B] @ [B, K] -> [N, K] # Shapes: [N, B] @ [B, K] -> [N, K]
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2) # go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
# Transposing gives column-major, but first arg needs row-major, # Transposing gives column-major, but first arg needs row-major,
# so we must call .contiguous() to physically rearrange the memory. # so we must call .contiguous() to physically rearrange the memory.
go_T = go_fp8_2.t().contiguous() # [N, B] row-major go_T = go_fp8.t().contiguous() # [N, B] row-major
in_col = _to_col_major(in_fp8) # [B, K] column-major in_col = _to_col_major(in_fp8) # [B, K] column-major
grad_weight = torch._scaled_mm( grad_weight = torch._scaled_mm(
go_T, go_T,
in_col, in_col,
scale_a=go_inv_2, scale_a=go_inv,
scale_b=in_inv, scale_b=in_inv,
out_dtype=grad_output.dtype, out_dtype=grad_output.dtype,
use_fast_accum=False, use_fast_accum=False,