Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 103 additions & 29 deletions benchmarks/attention_benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
"""Run MLA benchmark with appropriate backend."""
from mla_runner import run_mla_benchmark as run_mla

return run_mla(config.backend, config, **kwargs)
return run_mla(
config.backend, config, prefill_backend=config.prefill_backend, **kwargs
)


def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
Expand Down Expand Up @@ -440,14 +442,21 @@ def main():
# Backend selection
parser.add_argument(
"--backends",
"--decode-backends",
nargs="+",
help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, "
"flashinfer_mla, flashattn_mla, flashmla)",
)
parser.add_argument(
"--backend",
help="Single backend (alternative to --backends)",
)
parser.add_argument(
"--prefill-backends",
nargs="+",
help="Prefill backends to compare (fa2, fa3, fa4). "
"Uses the first decode backend for impl construction.",
)

# Batch specifications
parser.add_argument(
Expand Down Expand Up @@ -502,7 +511,7 @@ def main():

# Override args with YAML values, but CLI args take precedence
# Check if CLI provided backends (they would be non-None and not default)
cli_backends_provided = args.backends is not None or args.backend is not None
cli_backends_provided = args.backend is not None or args.backends is not None

# Backend(s) - only use YAML if CLI didn't specify
if not cli_backends_provided:
Expand All @@ -512,6 +521,12 @@ def main():
elif "backends" in yaml_config:
args.backends = yaml_config["backends"]
args.backend = None
elif "decode_backends" in yaml_config:
args.backends = yaml_config["decode_backends"]
args.backend = None

# Prefill backends (e.g., ["fa3", "fa4"])
args.prefill_backends = yaml_config.get("prefill_backends", None)

# Check for special modes
if "mode" in yaml_config:
Expand Down Expand Up @@ -613,7 +628,10 @@ def main():

# Determine backends
backends = args.backends or ([args.backend] if args.backend else ["flash"])
prefill_backends = getattr(args, "prefill_backends", None)
console.print(f"Backends: {', '.join(backends)}")
if prefill_backends:
console.print(f"Prefill backends: {', '.join(prefill_backends)}")
console.print(f"Batch specs: {', '.join(args.batch_specs)}")
console.print()

Expand Down Expand Up @@ -850,37 +868,93 @@ def main():

else:
# Normal mode: compare backends
total = len(backends) * len(args.batch_specs)
decode_results = []
prefill_results = []

with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
config = BenchmarkConfig(
backend=backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)
# Run decode backend comparison
if not prefill_backends:
# No prefill backends specified: compare decode backends as before
total = len(backends) * len(args.batch_specs)

result = run_benchmark(config)
all_results.append(result)
with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
config = BenchmarkConfig(
backend=backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
)

if not result.success:
console.print(f"[red]Error {backend} {spec}: {result.error}[/]")
result = run_benchmark(config)
decode_results.append(result)

pbar.update(1)
if not result.success:
console.print(
f"[red]Error {backend} {spec}: {result.error}[/]"
)

# Display results
console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(all_results, backends)
pbar.update(1)

console.print("\n[bold green]Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(decode_results, backends)

# Run prefill backend comparison
if prefill_backends:
# Use first decode backend for impl construction
decode_backend = backends[0]
total = len(prefill_backends) * len(args.batch_specs)

console.print(
f"[yellow]Prefill comparison mode: "
f"using {decode_backend} for decode impl[/]"
)

with tqdm(total=total, desc="Prefill benchmarking") as pbar:
for spec in args.batch_specs:
for pb in prefill_backends:
config = BenchmarkConfig(
backend=decode_backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
prefill_backend=pb,
)

result = run_benchmark(config)

# Label result with prefill backend name for display
labeled_config = replace(result.config, backend=pb)
result = replace(result, config=labeled_config)
prefill_results.append(result)

if not result.success:
console.print(f"[red]Error {pb} {spec}: {result.error}[/]")

pbar.update(1)

console.print("\n[bold green]Prefill Backend Results:[/]")
formatter = ResultsFormatter(console)
formatter.print_table(
prefill_results, prefill_backends, compare_to_fastest=True
)

all_results = decode_results + prefill_results

# Save results
if all_results:
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/attention_benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int):
self.qk_nope_head_dim = qk_nope_head_dim
self.v_head_dim = v_head_dim
self.out_dim = qk_nope_head_dim + v_head_dim
self.weight = torch.empty(0, dtype=torch.bfloat16)

def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]:
"""
Expand Down Expand Up @@ -213,6 +214,7 @@ class BenchmarkConfig:
use_cuda_graphs: bool = False

# MLA-specific
prefill_backend: str | None = None
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None
Expand Down
114 changes: 89 additions & 25 deletions benchmarks/attention_benchmarks/configs/mla_prefill.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
# MLA prefill-only benchmark configuration for sparse backends
# MLA prefill backend comparison
#
# Compares all available MLA prefill backends:
# FA backends: fa2, fa3, fa4 (FlashAttention versions)
# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer)
#
# Uses cutlass_mla as the decode backend for impl construction
# (only the prefill path is exercised).
#
# Backends that aren't available on the current platform will report errors
# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory).
#
# Usage:
# python benchmark.py --config configs/mla_prefill.yaml

description: "MLA prefill backend comparison"

model:
name: "deepseek-v3"
Expand All @@ -12,20 +27,25 @@ model:
v_head_dim: 128
block_size: 128

# Model parameter sweep: simulate tensor parallelism by varying num_q_heads
# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads
model_parameter_sweep:
param_name: "num_q_heads"
values: [128, 64, 32, 16]
label_format: "{backend}_{value}h"
# model:
# name: "deepseek-v2-lite"
# num_layers: 27
# num_q_heads: 16
# num_kv_heads: 1
# head_dim: 576
# kv_lora_rank: 512
# qk_nope_head_dim: 128
# qk_rope_head_dim: 64
# v_head_dim: 128
# block_size: 128

batch_specs:
# Pure prefill
- "1q512"
- "1q1k"
- "1q2k"
- "1q4k"
- "1q8k"
- "q512"
- "q1k"
- "q2k"
- "q4k"
- "q8k"

# Batched pure prefill
- "2q512"
Expand All @@ -44,19 +64,63 @@ batch_specs:
- "8q4k"
- "8q8k"

# Extend
- "1q512s4k"
- "1q512s8k"
- "1q1ks8k"
- "1q2ks8k"
- "1q2ks16k"
- "1q4ks16k"
# Chunked prefill / extend
# Short context
- "q128s1k"
- "q256s2k"
- "q512s4k"
- "q1ks4k"
- "q2ks8k"
- "2q128s1k"
- "2q256s2k"
- "2q512s4k"
- "2q1ks4k"
- "2q2ks8k"
- "4q128s1k"
- "4q256s2k"
- "4q512s4k"
- "4q1ks4k"
- "4q2ks8k"
- "8q128s1k"
- "8q256s2k"
- "8q512s4k"
- "8q1ks4k"

# Medium context
- "q128s16k"
- "q512s16k"
- "q1ks16k"
- "q2ks16k"
- "2q128s16k"
- "2q512s16k"
- "2q1ks16k"
- "2q2ks16k"
- "4q128s16k"
- "4q512s16k"
- "4q1ks16k"
- "4q2ks16k"

# Long context
- "q128s64k"
- "q512s64k"
- "q1ks64k"
- "q2ks64k"
- "2q128s64k"
- "2q512s64k"
- "2q1ks64k"
- "2q2ks64k"

decode_backends:
- CUTLASS_MLA

backends:
- FLASHMLA_SPARSE
- FLASHINFER_MLA_SPARSE
prefill_backends:
- fa2
- fa3
- fa4
- flashinfer
- cudnn
- trtllm

device: "cuda:0"
repeats: 10
warmup_iters: 3
profile_memory: true
repeats: 20
warmup_iters: 5
Loading