benchmarks/benchmark_attn.py (67 changes: 64 additions & 3 deletions)
@@ -170,6 +170,61 @@ def setup_fa4(ctx):
 ]
 
 
+def get_peak_flops(device_name: str, dtype: torch.dtype = torch.bfloat16) -> float | None:
+    """Return peak FLOPS for known GPUs, scaled to the given dtype, or None if the device is unknown.
+
+    Base values are BF16 dense (without sparsity). Scaling:
+        FP16 / BF16 : 1x (identical hardware throughput)
+        FP8         : 2x
+        FP32        : 0.5x
+    """
+    if "A100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/a100/
+        flops = 312e12
+    elif "H100" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h100/
+        # NOTE: NVIDIA's headline numbers include sparsity; dense throughput is half.
+        if "NVL" in device_name:
+            flops = 835e12
+        elif "PCIe" in device_name:
+            flops = 756e12
+        else:  # H100 SXM and other variants
+            flops = 989e12
+    elif "H200" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/h200/
+        flops = 989e12
+    elif "H20" in device_name:
+        flops = 148e12
+    elif "GB200" in device_name or "GB300" in device_name:
+        # Grace Blackwell Superchips (Grace CPU + Blackwell GPU)
+        # BF16 dense per GPU: 2,500 TFLOPS (half of 5,000 TFLOPS with sparsity)
+        # GB200 data from https://www.nvidia.com/en-us/data-center/dgx-gb200
+        # GB300 data from https://www.nvidia.com/en-us/data-center/dgx-gb300
+        flops = 2.5e15
+    elif "B300" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/b300/
+        # NOTE: NVIDIA's headline numbers include sparsity; dense throughput is half.
+        flops = 3.5e15
+    elif "B200" in device_name:
+        # data from https://www.nvidia.com/en-us/data-center/b200/
+        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
+        # NOTE: NVIDIA's headline numbers include sparsity; dense throughput is half.
+        flops = 2.25e15
+    elif "A6000" in device_name:
+        flops = 309.7e12
+    elif "L40S" in device_name.upper():
+        flops = 362e12
+    else:
+        return None  # unknown device, MFU will be omitted
+
+    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        flops *= 2
+    elif dtype == torch.float32:
+        flops *= 0.5
+    # torch.float16 and torch.bfloat16 have identical throughput, no scaling needed
+    return flops
+
+
 def parse_int_k(s):
     """Parse an integer with optional k/K suffix, e.g. '8k' -> 8192."""
     s = s.strip().lower()
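
A quick usage sketch of the function above (mine, not part of the diff). The device string is an assumption of what torch.cuda.get_device_name(0) returns on an H100 SXM; the expected values follow directly from the dtype scaling in the docstring:

    import torch

    name = "NVIDIA H100 80GB HBM3"  # hypothetical; falls through to the SXM branch
    peak_bf16 = get_peak_flops(name, torch.bfloat16)       # 989e12
    peak_fp8 = get_peak_flops(name, torch.float8_e4m3fn)   # 2x   -> 1.978e15
    peak_fp32 = get_peak_flops(name, torch.float32)        # 0.5x -> 494.5e12
    assert peak_fp8 == 2 * peak_bf16 and peak_fp32 == 0.5 * peak_bf16
    assert get_peak_flops("unknown accelerator") is None   # MFU column is omitted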
@@ -271,6 +326,7 @@ def main():
     dtype = torch.bfloat16
     dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     device = 'cuda'
+    peak_flops = get_peak_flops(torch.cuda.get_device_name(0), dtype=dtype)
     page_size = None
     softcap = 0.0
     deterministic = args.deterministic
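
For reference (not in the diff): the lookup relies on substring matching against whatever torch.cuda.get_device_name reports, so typical SKU strings resolve as expected. The exact names vary by SKU and driver and are assumptions here; note that the "H200" branch is checked before "H20" precisely because "H20" is a substring of "H200":

    assert get_peak_flops("NVIDIA A100-SXM4-80GB") == 312e12
    assert get_peak_flops("NVIDIA H100 PCIe") == 756e12
    assert get_peak_flops("NVIDIA H200") == 989e12   # hits H200, not H20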
@@ -357,7 +413,7 @@ def main():
     if not shown_backends:
         return
 
-    col_w = 16
+    col_w = 20 if peak_flops is not None else 16
 
     for direction, times, flops_mult in [("FWD", time_f, 1.0), ("BWD", time_b, 2.5)]:
         if not times:
@@ -366,11 +422,12 @@ def main():
         if not configs:
             continue
 
+        col_label = "ms / TFLOPS / MFU%" if peak_flops is not None else "ms / TFLOPS"
         header = f"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}"
         for b in shown_backends:
             header += f" {b:>{col_w}}"
         print(f"\n{'=' * len(header)}")
-        print(f" {direction} (ms / TFLOPS)")
+        print(f" {direction} ({col_label})")
         print(f"{'=' * len(header)}")
         print(header)
         print("-" * len(header))
@@ -385,7 +442,11 @@ def main():
                 if t is not None:
                     tflops = flops_mult * nFLOPS / t * 1e-12
                     ms = t * 1e3
-                    cell = f"{ms:.2f}/{tflops:.0f}"
+                    if peak_flops is not None:
+                        mfu = flops_mult * nFLOPS / t / peak_flops * 100
+                        cell = f"{ms:.2f}/{tflops:.0f}/{mfu:.1f}%"
+                    else:
+                        cell = f"{ms:.2f}/{tflops:.0f}"
                     row += f" {cell:>{col_w}}"
                 else:
                     row += f" {'—':>{col_w}}"
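
For reference, the MFU cell reduces to the achieved FLOP rate divided by the hardware peak. A minimal sketch (mine, not part of the diff); the 2.5x BWD multiplier matches the table loop above (backward attention costs roughly 2.5x the forward FLOPs), and the sample numbers are purely illustrative:

    def mfu_percent(nflops, t_seconds, peak_flops, flops_mult=1.0):
        # flops_mult is 1.0 for FWD and 2.5 for BWD, as in the table loop
        achieved = flops_mult * nflops / t_seconds  # FLOP/s
        return achieved / peak_flops * 100.0

    # e.g. 1.1e11 forward FLOPs in 0.125 ms on an H100 SXM (989e12 peak BF16):
    print(f"{mfu_percent(1.1e11, 1.25e-4, 989e12):.1f}%")  # ~89.0%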