delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, the performance is not only ~identical but actually runs 3% faster. despite of it being significantly simpler and much less code. i don't fully understand why/how atm

This commit is contained in:
Andrej Karpathy
2026-02-10 18:46:39 +00:00
parent 1ec0a34779
commit e569b59f92
4 changed files with 275 additions and 13 deletions
+3 -1
View File
@@ -165,7 +165,9 @@ if args.fp8:
if device_type != "cuda":
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
else:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
# our custom fp8 is simpler than torchao, written for exact API compatibility
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
import torch.nn as nn
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)