fix: safe DDP cleanup (check initialized PG, not just env) (#256)

This commit is contained in:
Dipesh Babu
2025-12-27 23:27:40 -05:00
committed by GitHub
parent 91d76cc690
commit 2f2d7ab80c
+20 -8
View File
@@ -113,12 +113,24 @@ def print_banner():
""" """
print0(banner) print0(banner)
def is_ddp_requested() -> bool:
    """Return True if the launcher environment asks for distributed training.

    The check looks only at the environment: it is True when all of
    ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` are present (as set by
    torchrun), even before any process group has been initialized.
    Use it to decide whether a process group *should* be initialized;
    it says nothing about whether one currently exists.
    """
    required_vars = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    return all(name in os.environ for name in required_vars)
def is_ddp_initialized() -> bool:
    """Return True if a torch.distributed process group is currently initialized.

    ``dist.is_available()`` is checked first so that the
    ``dist.is_initialized()`` query is only made when the distributed
    package is available. Used at cleanup time to avoid destroying a
    process group that was never created.
    """
    return dist.is_available() and dist.is_initialized()
def get_dist_info(): def get_dist_info():
if is_ddp(): if is_ddp_requested():
# We rely on torchrun's env to decide if we SHOULD init.
# (Initialization itself happens in compute init.)
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
ddp_rank = int(os.environ['RANK']) ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK'])
@@ -161,8 +173,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp and device_type == "cuda": if is_ddp_requested and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank) device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device) dist.init_process_group(backend="nccl", device_id=device)
@@ -173,11 +185,11 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
if ddp_rank == 0: if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}") logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit.

    The process group is destroyed only if one is actually initialized.
    Checking the launcher environment instead is not sufficient: the env
    vars can be present while no process group was ever created, and
    destroying a non-existent group would fail.
    """
    if is_ddp_initialized():
        dist.destroy_process_group()
class DummyWandb: class DummyWandb: