Delete autocast, an unnecessary thorn in my side, and manage dtypes directly
This commit is contained in:
@@ -45,14 +45,22 @@ HAS_FA3 = _fa3 is not None
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _use_fa3():
|
||||
"""Determine whether to use FA3 based on availability and override."""
|
||||
def _resolve_use_fa3():
|
||||
"""Decide once whether to use FA3, based on availability, override, and dtype."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
return HAS_FA3 # auto
|
||||
if HAS_FA3:
|
||||
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
if COMPUTE_DTYPE == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -90,7 +98,7 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
||||
# sliding window (left)
|
||||
if window >= 0 and window < Tk:
|
||||
mask = mask & ((row_idx - col_idx) <= window)
|
||||
|
||||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
# =============================================================================
|
||||
@@ -108,7 +116,7 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
@@ -138,7 +146,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
||||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if _use_fa3():
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
|
||||
Reference in New Issue
Block a user