"""
Common utilities for nanochat.
"""

import os
import re
import logging
import urllib.request

import torch
import torch.distributed as dist
from filelock import FileLock


class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds ANSI colors to log messages."""

    # ANSI color codes, keyed by the standard logging level names
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """
        Format `record`, coloring the level name and, for INFO records,
        highlighting quantities ("12 GB", "3.5%", "10 docs") and "Shard N".

        The record is handed back with its original `levelname`, so other
        handlers/formatters that share the same record never see the ANSI
        escape codes and formatting the same record twice is stable.
        """
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        try:
            message = super().format(record)
        finally:
            # Undo the temporary mutation even if formatting raised.
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message


def setup_default_logging():
    """Configure the root logger: INFO level, one stream handler using ColoredFormatter."""
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )


setup_default_logging()
logger = logging.getLogger(__name__)


def get_base_dir():
    """
    Return the nanochat working directory, creating it if needed.

    Uses $NANOCHAT_BASE_DIR when it is set to a non-empty value; otherwise
    co-locates nanochat intermediates with other cached data in
    ~/.cache/nanochat.
    """
    nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    if not nanochat_dir:
        home_dir = os.path.expanduser("~")
        nanochat_dir = os.path.join(home_dir, ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir


def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    The payload is first written to a sibling "<file>.tmp" and then moved into
    place with os.replace, so a truncated download is never visible under the
    final name (where a later call would mistake it for a complete file).

    If `postprocess_fn` is given it is called with the final path after the
    file is in place, while the lock is still held. Note that callers taking
    the lock-free fast path below may observe the file before postprocessing
    has finished.

    Returns the local file path.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    # Fast path: nothing to do if the file is already there
    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read()  # bytes

        # Write to a temporary file, then publish it atomically
        tmp_path = file_path + ".tmp"
        with open(tmp_path, 'wb') as f:
            f.write(content)
        os.replace(tmp_path, file_path)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

    return file_path


def print0(s="",**kwargs):
    """print() that only emits when $RANK is 0 or unset; kwargs are passed to print."""
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)


def print_banner():
    """Print the nanochat ASCII banner (via print0, so only on rank 0)."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
                                                       █████                █████
                                                      ░░███                ░░███
     ████████    ██████   ████████    ██████   ██████  ░███████    ██████  ███████
    ░░███░░███  ░░░░░███ ░░███░░███  ███░░███ ███░░███ ░███░░███  ░░░░░███░░░███░
     ░███ ░███   ███████  ░███ ░███ ░███ ░███░███ ░░░  ░███ ░███   ███████  ░███
     ░███ ░███  ███░░███  ░███ ░███ ░███ ░███░███  ███ ░███ ░███  ███░░███  ░███ ███
     ████ █████░░████████ ████ █████░░██████ ░░██████  ████ █████░░███████  ░░█████
    ░░░░ ░░░░░  ░░░░░░░░ ░░░░ ░░░░░  ░░░░░░   ░░░░░░  ░░░░ ░░░░░  ░░░░░░░░   ░░░░░
    """
    print0(banner)


def is_ddp_requested() -> bool:
    """
    True if launched by torchrun (env present), even before init.
    Used to decide whether we *should* initialize a PG.
    """
    return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))


def is_ddp_initialized() -> bool:
    """
    True if torch.distributed is available and the process group is initialized.
    Used at cleanup to avoid destroying a non-existent PG.
    """
    return dist.is_available() and dist.is_initialized()


def get_dist_info():
    """
    Read the distributed layout from the environment.

    Returns (ddp_requested, rank, local_rank, world_size). When the torchrun
    variables are absent this is (False, 0, 0, 1).
    """
    if not is_ddp_requested():
        return False, 0, 0, 1
    # We rely on torchrun's env to decide if we SHOULD init.
    # (Initialization itself happens in compute init.)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    return True, ddp_rank, ddp_local_rank, ddp_world_size


def autodetect_device_type():
    """Return "cuda", "mps" or "cpu" depending on what PyTorch reports as available."""
    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type


def compute_init(device_type="cuda"): # cuda|cpu|mps
    """
    Basic initialization that we keep doing over and over, so make common.

    Seeds the global RNGs, selects tf32 matmuls on CUDA, and, when torchrun
    env vars are present and device_type is "cuda", pins this process to its
    local GPU and initializes the NCCL process group.

    Returns (ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device).
    """
    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if device_type == "cuda":
        torch.backends.cuda.matmul.fp32_precision = "tf32" # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    # (local is named `ddp` so it does not shadow the is_ddp_requested() function)
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info("Distributed world size: %s", ddp_world_size)

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device


def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp_initialized():
        dist.destroy_process_group()


class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""

    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        pass

    def finish(self):
        pass


# hardcoded BF16 peak flops for NVIDIA A100, H100, H200, B200, L40S GPU and AMD MI250X, MI300X, MI325X, MI355X and Intel PVC
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
def get_peak_flops(device_name: str) -> float:
    """
    Return the hardcoded BF16 peak FLOP/s for the accelerator named `device_name`.

    Matching is by substring. Unknown devices log a warning and fall back to
    the A100 figure.
    """
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "B200" in device_name:
        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
        return 2.25e15
    elif "MI355X" in device_name:
        # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
        return 2500e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    elif "l40s" in device_name.lower():
        # Case-insensitive so that names spelled "NVIDIA L40S" match too.
        # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
        return 362e12
    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12