Add support for CPU and MPS. I had to change a few cosmetic things. I also discovered what I think is a bit of a bug: I was casting wte to bfloat16 in the wrong place (the model init) instead of in init_weights.

This commit is contained in:
karpathy
2025-10-16 10:04:43 -07:00
parent 722da4f543
commit 306bc380ab
6 changed files with 68 additions and 46 deletions
+8 -6
View File
@@ -89,11 +89,14 @@ def get_dist_info():
else:
return False, 0, 0, 1
def compute_init(device_type="cuda"): # cuda|cpu
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
# assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
torch.manual_seed(42)
@@ -101,11 +104,10 @@ def compute_init(device_type="cuda"): # cuda|cpu
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()