fix tf32 warning for deprecated api use
This commit is contained in:
+1
-1
@@ -158,7 +158,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
|
||||
# Precision
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
||||
Reference in New Issue
Block a user