delete torchao dependency, create our own exact API-matched version of Float8Linear, document it very well. for some poorly understood reason, the performance is not only ~identical but actually runs 3% faster. despite of it being significantly simpler and much less code. i don't fully understand why/how atm
This commit is contained in:
@@ -165,7 +165,9 @@ if args.fp8:
|
||||
if device_type != "cuda":
|
||||
print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
|
||||
else:
|
||||
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
# our custom fp8 is simpler than torchao, written for exact API compatibility
|
||||
from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
|
||||
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||
import torch.nn as nn
|
||||
|
||||
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
|
||||
|
||||
Reference in New Issue
Block a user