Fix bug in setting precision (#538)
This commit is contained in:
committed by
Andrej Karpathy
parent
cac43e8511
commit
ad55575326
+1
-1
@@ -170,7 +170,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
 
     # Precision
     if device_type == "cuda":
-        torch.backends.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
+        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
 
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
Reference in New Issue
Block a user