fix: use meta device in disable_fp8 to avoid VRAM spike (#616)

When the disable_fp8 context manager swaps Float8Linear for Linear,
constructing the replacement with device=fp8_module.weight.device
allocates a fresh weight tensor on the GPU, which is discarded as soon
as the shared weight is assigned. This causes an unnecessary VRAM
spike (~1GB for large models).

This fix constructs the replacement with device='meta', which
allocates no physical memory, and then assigns the existing weight
tensor by reference. This removes the VRAM spike during the
evaluation phase.
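
The idea can be sketched outside the patched code as follows. This is a
minimal illustration, assuming plain torch.nn.Linear on CPU in place of
Float8Linear and the project's custom Linear class; make_shell_linear is
a hypothetical helper name and does not appear in the commit.

```python
import torch
import torch.nn as nn


def make_shell_linear(src: nn.Linear) -> nn.Linear:
    """Build a Linear shaped like `src` that shares its parameters.

    Hypothetical helper for illustration only; not part of the patch.
    """
    # device="meta" records shape and dtype only, so the constructor
    # does not allocate storage for the throwaway initial weight.
    shell = nn.Linear(
        src.in_features,
        src.out_features,
        bias=src.bias is not None,
        device="meta",
        dtype=src.weight.dtype,
    )
    # Rebinding the attribute shares the existing Parameter; nothing is copied.
    shell.weight = src.weight
    if src.bias is not None:
        shell.bias = src.bias
    return shell


src = nn.Linear(8, 4, bias=True)  # stands in for a Float8Linear module
shell = make_shell_linear(src)

x = torch.randn(2, 8)
y = shell(x)  # runs on the real device because the parameters are real
```

Any parameter left on the meta device has no data, so every parameter
the module uses has to be rebound before the forward pass.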

Fixes issue #592

Co-authored-by: RoomWithOutRoof <roomwithoutroof@sparklab.ai>
RoomWithOutRoof
2026-03-26 05:24:57 +08:00
committed by GitHub
parent c0dbf1f3ff
commit 47e983eea7
+2 -1
@@ -218,12 +218,13 @@ def disable_fp8(model):
         return
     # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
+    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
     for parent, attr_name, fp8_module in fp8_locations:
         linear = Linear(
             fp8_module.in_features,
             fp8_module.out_features,
             bias=fp8_module.bias is not None,
-            device=fp8_module.weight.device,
+            device="meta",  # Use meta device to avoid unnecessary VRAM allocation
             dtype=fp8_module.weight.dtype,
         )
         linear.weight = fp8_module.weight  # share, don't copy