delete autocast, an unnecessary thorn in my side, manage dtypes directly

This commit is contained in:
Andrej Karpathy
2026-03-04 23:55:24 +00:00
parent 752abc836e
commit 1076f97059
15 changed files with 258 additions and 167 deletions
+14 -6
View File
@@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
_override_impl = None
def _use_fa3():
"""Determine whether to use FA3 based on availability and override."""
def _resolve_use_fa3():
"""Decide once whether to use FA3, based on availability, override, and dtype."""
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
return True
if _override_impl == 'sdpa':
return False
return HAS_FA3 # auto
if HAS_FA3:
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
from nanochat.common import COMPUTE_DTYPE
if COMPUTE_DTYPE == torch.bfloat16:
return True
return False
return False
USE_FA3 = _resolve_use_fa3()
# =============================================================================
@@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
@@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
Returns:
Output tensor of shape (B, T, H, D)
"""
if _use_fa3():
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
@@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
Returns:
Output tensor of shape (B, T_new, H, D)
"""
if _use_fa3():
if USE_FA3:
return _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size