fix: use meta device in disable_fp8 to avoid VRAM spike (#616)
When swapping Float8Linear to Linear in the disable_fp8 context manager, constructing the replacement with device=fp8_module.weight.device allocates new tensors on the GPU, causing an unnecessary VRAM spike (~1GB for large models). This fix constructs the replacement with device='meta', which avoids physical memory allocation, and then swaps in the existing weight tensor by reference. This eliminates the unnecessary VRAM spike during the evaluation phase.
Fixes issue #592
Co-authored-by: RoomWithOutRoof <roomwithoutroof@sparklab.ai>
This commit is contained in:
@@ -218,12 +218,13 @@ def disable_fp8(model):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
# Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
|
||||||
|
# Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
|
||||||
for parent, attr_name, fp8_module in fp8_locations:
|
for parent, attr_name, fp8_module in fp8_locations:
|
||||||
linear = Linear(
|
linear = Linear(
|
||||||
fp8_module.in_features,
|
fp8_module.in_features,
|
||||||
fp8_module.out_features,
|
fp8_module.out_features,
|
||||||
bias=fp8_module.bias is not None,
|
bias=fp8_module.bias is not None,
|
||||||
device=fp8_module.weight.device,
|
device="meta", # Use meta device to avoid unnecessary VRAM allocation
|
||||||
dtype=fp8_module.weight.dtype,
|
dtype=fp8_module.weight.dtype,
|
||||||
)
|
)
|
||||||
linear.weight = fp8_module.weight # share, don't copy
|
linear.weight = fp8_module.weight # share, don't copy
|
||||||
|
|||||||
Reference in New Issue
Block a user