Removed redundant qunatization of gradients
This commit is contained in:
+3
-4
@@ -173,16 +173,15 @@ class _Float8Matmul(torch.autograd.Function):
|
|||||||
|
|
||||||
# === GEMM 2: grad_weight = grad_output.T @ input ===
|
# === GEMM 2: grad_weight = grad_output.T @ input ===
|
||||||
# Shapes: [N, B] @ [B, K] -> [N, K]
|
# Shapes: [N, B] @ [B, K] -> [N, K]
|
||||||
go_fp8_2, go_inv_2 = _to_fp8(grad_output, torch.float8_e5m2)
|
# go_fp8 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
||||||
# go_fp8_2 is [B, N] contiguous, we need go.T = [N, B] as first arg.
|
|
||||||
# Transposing gives column-major, but first arg needs row-major,
|
# Transposing gives column-major, but first arg needs row-major,
|
||||||
# so we must call .contiguous() to physically rearrange the memory.
|
# so we must call .contiguous() to physically rearrange the memory.
|
||||||
go_T = go_fp8_2.t().contiguous() # [N, B] row-major
|
go_T = go_fp8.t().contiguous() # [N, B] row-major
|
||||||
in_col = _to_col_major(in_fp8) # [B, K] column-major
|
in_col = _to_col_major(in_fp8) # [B, K] column-major
|
||||||
grad_weight = torch._scaled_mm(
|
grad_weight = torch._scaled_mm(
|
||||||
go_T,
|
go_T,
|
||||||
in_col,
|
in_col,
|
||||||
scale_a=go_inv_2,
|
scale_a=go_inv,
|
||||||
scale_b=in_inv,
|
scale_b=in_inv,
|
||||||
out_dtype=grad_output.dtype,
|
out_dtype=grad_output.dtype,
|
||||||
use_fast_accum=False,
|
use_fast_accum=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user