fix: pass device_type to compute_init in engine.__main__ (#451)
When running engine.py directly on non-GPU devices (CPU, MPS), compute_init() needs the device_type parameter to initialize correctly. This fixes failures on machines without CUDA support.
This commit is contained in:
+1
-1
@@ -306,8 +306,8 @@ if __name__ == "__main__":
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
# init compute
|
# init compute
|
||||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
|
||||||
device_type = autodetect_device_type()
|
device_type = autodetect_device_type()
|
||||||
|
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||||
|
|
||||||
# load the model and tokenizer
|
# load the model and tokenizer
|
||||||
|
|||||||
Reference in New Issue
Block a user