use _PEAK_FLOPS_TABLE instead of if-else structure (#479)

This commit is contained in:
Sofie Van Landeghem
2026-02-01 04:45:06 +01:00
committed by GitHub
parent 43078c347e
commit 4d6415b8ef
+41 -59
View File
@@ -207,70 +207,52 @@ class DummyWandb:
def get_peak_flops(device_name: str) -> float: def get_peak_flops(device_name: str) -> float:
name = device_name.lower() name = device_name.lower()
# --- NVIDIA Blackwell --- # Table order matters: more specific patterns first.
if "gb200" in name or "grace blackwell" in name: _PEAK_FLOPS_TABLE = (
return 2.5e15 # NVIDIA Blackwell
if "b200" in name: (["gb200"], 2.5e15),
return 2.25e15 (["grace blackwell"], 2.5e15),
if "b100" in name: (["b200"], 2.25e15),
return 1.8e15 (["b100"], 1.8e15),
# NVIDIA Hopper
# --- NVIDIA Hopper (H100/H200/H800) --- (["h200", "nvl"], 836e12),
if "h200" in name: (["h200", "pcie"], 836e12),
if "nvl" in name or "pcie" in name: (["h200"], 989e12),
return 836e12 (["h100", "nvl"], 835e12),
return 989e12 # H200 SXM (["h100", "pcie"], 756e12),
if "h100" in name: (["h100"], 989e12),
if "nvl" in name: (["h800", "nvl"], 989e12),
return 835e12 (["h800"], 756e12),
if "pcie" in name: # NVIDIA Ampere data center
return 756e12 (["a100"], 312e12),
return 989e12 # H100 SXM (["a800"], 312e12),
if "h800" in name: (["a40"], 149.7e12),
if "nvl" in name: (["a30"], 165e12),
return 989e12 # NVIDIA Ada data center
return 756e12 # H800 PCIe (["l40s"], 362e12),
(["l40-s"], 362e12),
# --- NVIDIA Ampere data center --- (["l40 s"], 362e12),
if "a100" in name or "a800" in name: (["l4"], 121e12),
return 312e12 # AMD CDNA accelerators
if "a40" in name: (["mi355"], 2.5e15),
return 149.7e12 (["mi325"], 1.3074e15),
if "a30" in name: (["mi300x"], 1.3074e15),
return 165e12 (["mi300a"], 980.6e12),
(["mi250x"], 383e12),
# --- NVIDIA Ada data center --- (["mi250"], 362.1e12),
if "l40s" in name or "l40-s" in name or "l40 s" in name: # Consumer RTX
return 362e12 (["5090"], 209.5e12),
if "l4" in name: (["4090"], 165.2e12),
return 121e12 (["3090"], 71e12),
)
# --- AMD CDNA accelerators --- for patterns, flops in _PEAK_FLOPS_TABLE:
if "mi355" in name: if all(p in name for p in patterns):
return 2.5e15 return flops
if "mi325" in name or "mi300x" in name:
return 1.3074e15
if "mi300a" in name:
return 980.6e12
if "mi250x" in name:
return 383e12
if "mi250" in name:
return 362.1e12
# --- Intel ---
if "data center gpu max 1550" in name: if "data center gpu max 1550" in name:
# Ponte Vecchio (PVC) - dynamic based on compute units # Ponte Vecchio (PVC) - dynamic based on compute units
max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
return 512 * max_comp_units * 1300 * 10**6 return 512 * max_comp_units * 1300 * 10**6
# --- Consumer RTX (for hobbyists) ---
if "5090" in name:
return 209.5e12
if "4090" in name:
return 165.2e12
if "3090" in name:
return 71e12
# Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%") logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
return float('inf') return float('inf')