delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+6 -6
View File
@@ -72,6 +72,8 @@ generates a different graph. Numerics are bitwise identical in eager mode.
import torch
import torch.nn as nn
from nanochat.common import COMPUTE_DTYPE
# Avoid division by zero when computing scale from an all-zeros tensor
EPS = 1e-12
@@ -123,7 +125,7 @@ def _to_col_major(x):
class _Float8Matmul(torch.autograd.Function):
"""Custom autograd for the three FP8 GEMMs of a Linear layer.
The forward quantizes input and weight to FP8 and saves
The forward quantizes input and weight to FP8 and saves
the quantized tensors + scales for backward.
"""
@@ -198,11 +200,9 @@ class Float8Linear(nn.Linear):
"""
def forward(self, input):
# Replicate the autocast behavior of F.linear — when autocast is active,
# we need to manually cast input to the autocast dtype (e.g. bf16),
# since we bypass F.linear's built-in autocast handling.
if torch.is_autocast_enabled():
input = input.to(torch.get_autocast_gpu_dtype())
# Cast input to COMPUTE_DTYPE (typically bf16) since _scaled_mm expects
# reduced precision input, and we no longer rely on autocast to do this.
input = input.to(COMPUTE_DTYPE)
# _scaled_mm only works on 2D tensors, so flatten batch dimensions
orig_shape = input.shape
input_2d = input.reshape(-1, orig_shape[-1])