add autodetect of device and related stuff. getting weird warnings/errors still, so wip

This commit is contained in:
karpathy
2025-10-16 10:26:19 -07:00
parent 279b74312c
commit 786119d593
4 changed files with 25 additions and 11 deletions
+10
View File
@@ -89,6 +89,16 @@ def get_dist_info():
else:
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
if torch.backends.mps.is_available():
device_type = "mps"
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""