delete autocast, an unnecessary thorn in my side, manage dtypes directly
This commit is contained in:
@@ -10,6 +10,26 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from filelock import FileLock
|
||||
|
||||
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
|
||||
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
|
||||
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
|
||||
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
||||
def _detect_compute_dtype():
|
||||
env = os.environ.get("NANOCHAT_DTYPE")
|
||||
if env is not None:
|
||||
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
|
||||
if torch.cuda.is_available():
|
||||
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
|
||||
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability >= (8, 0):
|
||||
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
|
||||
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
||||
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
||||
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
||||
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
||||
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom formatter that adds colors to log messages."""
|
||||
# ANSI color codes
|
||||
|
||||
Reference in New Issue
Block a user