fix minor bug in fp8 application to skip tiny matmuls
This commit is contained in:
@@ -170,20 +170,22 @@ if args.fp8:
|
|||||||
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
# from torchao.float8 import Float8LinearConfig, convert_to_float8_training
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# Filter: only convert layers with dimensions divisible by 16 (FP8 hardware requirement)
|
# Filter: dims must be divisible by 16 (FP8 hardware requirement) large enough
|
||||||
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
|
||||||
if not isinstance(mod, nn.Linear):
|
if not isinstance(mod, nn.Linear):
|
||||||
return False
|
return False
|
||||||
# FP8 requires both in_features and out_features divisible by 16
|
|
||||||
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
|
||||||
return False
|
return False
|
||||||
|
if min(mod.in_features, mod.out_features) < 128:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
|
||||||
|
num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
|
||||||
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
|
||||||
num_fp8_layers = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
|
||||||
num_skipped = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) - num_fp8_layers
|
num_skipped = num_linear - num_fp8
|
||||||
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8_layers} layers, skipped {num_skipped} (dims not divisible by 16)")
|
print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
|
||||||
|
|
||||||
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
|
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
Reference in New Issue
Block a user