use COMPUTE_DTYPE-aware cast in Muon polar express step
The bf16 cast is intentional for speed on Hopper+ GPUs, but it should be skipped on other platforms rather than applied blindly. fp16 is unstable here because of its limited exponent range, and fp32 platforms do not benefit from the cast. The step now casts to bf16 only when COMPUTE_DTYPE is bf16 and performs no cast otherwise.
Inspired by PR #667.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+3
-1
@@ -10,6 +10,7 @@ Further contributions from @karpathy and @chrisjmccormick.
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from nanochat.common import COMPUTE_DTYPE
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
"""
|
"""
|
||||||
@@ -112,7 +113,8 @@ def muon_step_fused(
|
|||||||
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
||||||
|
|
||||||
# Polar express
|
# Polar express
|
||||||
X = g.bfloat16()
|
# Cast to bf16 for speed only when COMPUTE_DTYPE is bf16; otherwise leave g uncast (fp16 is unstable here due to its limited exponent range)
|
||||||
|
X = g.bfloat16() if COMPUTE_DTYPE == torch.bfloat16 else g
|
||||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
|
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
|
||||||
if g.size(-2) > g.size(-1): # Tall matrix
|
if g.size(-2) > g.size(-1): # Tall matrix
|
||||||
for a, b, c in polar_express_coeffs[:ns_steps]:
|
for a, b, c in polar_express_coeffs[:ns_steps]:
|
||||||
|
|||||||
Reference in New Issue
Block a user