From 2b58e2dd2ae134078c610665bf0811196c62830c Mon Sep 17 00:00:00 2001 From: obxium Date: Sat, 18 Oct 2025 09:31:11 -0400 Subject: [PATCH] Update logo in code as well --- nanochat/common.py | 91 +++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 33 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..bb825ff 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -8,43 +8,58 @@ import logging import torch import torch.distributed as dist + class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" + # ANSI color codes COLORS = { - 'DEBUG': '\033[36m', # Cyan - 'INFO': '\033[32m', # Green - 'WARNING': '\033[33m', # Yellow - 'ERROR': '\033[31m', # Red - 'CRITICAL': '\033[35m', # Magenta + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Magenta } - RESET = '\033[0m' - BOLD = '\033[1m' + RESET = "\033[0m" + BOLD = "\033[1m" + def format(self, record): # Add color to the level name levelname = record.levelname if levelname in self.COLORS: - record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" + record.levelname = ( + f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" + ) # Format the message message = super().format(record) # Add color to specific parts of the message - if levelname == 'INFO': + if levelname == "INFO": # Highlight numbers and percentages - message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) - message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) + message = re.sub( + r"(\d+\.?\d*\s*(?:GB|MB|%|docs))", + rf"{self.BOLD}\1{self.RESET}", + message, + ) + message = re.sub( + r"(Shard \d+)", + rf"{self.COLORS['INFO']}{self.BOLD}\1{self.RESET}", + message, + ) return message + def setup_default_logging(): handler = logging.StreamHandler() - handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) - logging.basicConfig( - level=logging.INFO, - handlers=[handler] + handler.setFormatter( + ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ) + logging.basicConfig(level=logging.INFO, handlers=[handler]) + setup_default_logging() logger = logging.getLogger(__name__) + def get_base_dir(): # co-locate nanochat intermediates with other cached data in ~/.cache (by default) if os.environ.get("NANOCHAT_BASE_DIR"): @@ -56,39 +71,44 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir -def print0(s="",**kwargs): - ddp_rank = int(os.environ.get('RANK', 0)) + +def print0(s="", **kwargs): + ddp_rank = int(os.environ.get("RANK", 0)) if ddp_rank == 0: print(s, **kwargs) + def print_banner(): # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ - █████ █████ - ░░███ ░░███ - ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ -░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░ - ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ - ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ - ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████ -░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ -""" + █████ █████ + ░░███ ░░███ + ████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████ + ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░ + ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███ + ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███ + ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████ + ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░ + """ print0(banner) + def is_ddp(): # TODO is there a proper way - return int(os.environ.get('RANK', -1)) != -1 + return int(os.environ.get("RANK", -1)) != -1 + def get_dist_info(): if is_ddp(): - assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) - ddp_rank = int(os.environ['RANK']) - ddp_local_rank = int(os.environ['LOCAL_RANK']) - ddp_world_size = int(os.environ['WORLD_SIZE']) + assert all(var in os.environ for var in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]) + ddp_rank = int(os.environ["RANK"]) + ddp_local_rank = int(os.environ["LOCAL_RANK"]) + ddp_world_size = int(os.environ["WORLD_SIZE"]) return True, ddp_rank, ddp_local_rank, ddp_world_size else: return False, 0, 0, 1 + def compute_init(): """Basic initialization that we keep doing over and over, so make common.""" @@ -104,13 +124,13 @@ def compute_init(): # torch.backends.cudnn.benchmark = False # Precision - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() if ddp: device = torch.device("cuda", ddp_local_rank) - torch.cuda.set_device(device) # make "cuda" default to this device + torch.cuda.set_device(device) # make "cuda" default to this device dist.init_process_group(backend="nccl", device_id=device) dist.barrier() else: @@ -121,16 +141,21 @@ def compute_init(): return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device + def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" if is_ddp(): dist.destroy_process_group() + class DummyWandb: """Useful if we wish to not use wandb but have all the same signatures""" + def __init__(self): pass + def log(self, *args, **kwargs): pass + def finish(self): pass