Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion benchmarks/attention_benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
from mla_runner import run_mla_benchmark as run_mla

return run_mla(
config.backend, config, prefill_backend=config.prefill_backend, **kwargs
config.backend,
config,
prefill_backend=config.prefill_backend,
sparse_mla_force_mqa=config.sparse_mla_force_mqa,
**kwargs,
)


Expand Down Expand Up @@ -858,6 +862,71 @@ def main():
f"\n [yellow]Prefill always faster for batch_size={bs}[/]"
)

# Handle MHA vs MQA comparison mode for sparse MLA
elif hasattr(args, "mode") and args.mode == "mha_vs_mqa":
console.print("[yellow]Mode: MHA vs MQA comparison for sparse MLA[/]")

# Two variants: MHA (default) and MQA (forced)
variants = [
("mha", False),
("mqa", True),
]
total = len(backends) * len(args.batch_specs) * len(variants)

with tqdm(total=total, desc="Benchmarking") as pbar:
for spec in args.batch_specs:
for backend in backends:
for variant_label, force_mqa in variants:
config = BenchmarkConfig(
backend=f"{backend}_{variant_label}",
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
sparse_mla_force_mqa=force_mqa,
)

# run_mla_benchmark needs the real backend name
from mla_runner import run_mla_benchmark as run_mla

try:
result = run_mla(
backend,
config,
sparse_mla_force_mqa=force_mqa,
)
except Exception as e:
result = BenchmarkResult(
config=config,
mean_time=float("inf"),
std_time=0,
min_time=float("inf"),
max_time=float("inf"),
error=str(e),
)

all_results.append(result)

if not result.success:
console.print(
f"[red]Error {backend}_{variant_label} "
f"{spec}: {result.error}[/]"
)

pbar.update(1)

# Display results with variant labels as separate "backends"
console.print("\n[bold green]MHA vs MQA Results:[/]")
formatter = ResultsFormatter(console)
variant_backends = [f"{b}_{v}" for b in backends for v, _ in variants]
formatter.print_table(all_results, variant_backends)

# Handle model parameter sweep mode
elif hasattr(args, "model_parameter_sweep") and args.model_parameter_sweep:
# Model parameter sweep
Expand Down
1 change: 1 addition & 0 deletions benchmarks/attention_benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class BenchmarkConfig:
# Backend-specific tuning
num_kv_splits: int | None = None # CUTLASS MLA
reorder_batch_threshold: int | None = None # FlashAttn MLA, FlashMLA
sparse_mla_force_mqa: bool = False # Force MQA path for sparse MLA


@dataclass
Expand Down
64 changes: 64 additions & 0 deletions benchmarks/attention_benchmarks/configs/mla_sparse_mha_vs_mqa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Sparse MLA benchmark: forward_mha vs forward_mqa crossover plot.
#
# Dense pure-prefill sweep (query-only batch specs) to generate a smooth
# speedup curve. Two crossover points are expected:
#   - Low end: MQA -> MHA around q128-q160
#   - High end: MHA -> MQA around q6k-q8k
#
# Usage:
#   python benchmark.py --config configs/mla_sparse_mha_vs_mqa.yaml

# Selects the MHA-vs-MQA comparison branch in benchmark.py, which runs every
# (backend, batch_spec) pair twice: once with sparse_mla_force_mqa=False
# ("mha") and once with True ("mqa").
mode: mha_vs_mqa

# DeepSeek-V3 MLA geometry.
# head_dim (576) = kv_lora_rank (512) + qk_rope_head_dim (64).
model:
  name: "deepseek-v3"
  num_layers: 60
  num_q_heads: 128
  num_kv_heads: 1
  head_dim: 576
  kv_lora_rank: 512
  qk_nope_head_dim: 128
  qk_rope_head_dim: 64
  v_head_dim: 128
  block_size: 128

# Pure-prefill specs ("qN" = one request with N query tokens), densely
# sampled so both crossover regions are resolved.
batch_specs:
  - "q64"
  - "q96"
  - "q128"
  - "q160"
  - "q192"
  - "q224"
  - "q256"
  - "q320"
  - "q384"
  - "q448"
  - "q512"
  - "q640"
  - "q768"
  - "q896"
  - "q1k"
  - "q1280"
  - "q1536"
  - "q1792"
  - "q2k"
  - "q2560"
  - "q3k"
  - "q3584"
  - "q4k"
  - "q5k"
  - "q6k"
  - "q7k"
  - "q8k"
  - "q10k"
  - "q12k"
  - "q14k"
  - "q16k"

# Only the sparse FlashMLA backend supports the forced-MQA path.
backends:
  - FLASHMLA_SPARSE

device: "cuda:0"
repeats: 10
warmup_iters: 3
profile_memory: true
47 changes: 43 additions & 4 deletions benchmarks/attention_benchmarks/mla_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def create_minimal_vllm_config(
index_topk: int | None = None,
prefill_backend: str | None = None,
kv_cache_dtype: str = "auto",
sparse_mla_force_mqa: bool = False,
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
Expand All @@ -81,6 +82,8 @@ def create_minimal_vllm_config(
prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer",
"cudnn", "trtllm"). Configures the attention config to
force the specified prefill backend.
sparse_mla_force_mqa: If True, forces all sparse MLA tokens through
forward_mqa (even prefill tokens).

Returns:
VllmConfig for benchmarking
Expand Down Expand Up @@ -193,6 +196,9 @@ def create_minimal_vllm_config(
"use_trtllm_ragged_deepseek_prefill"
]

if sparse_mla_force_mqa:
vllm_config.attention_config.sparse_mla_force_mqa = True

return vllm_config


Expand Down Expand Up @@ -778,9 +784,22 @@ def _run_single_benchmark(
indexer.fill_random_indices(total_q, max_kv_len)

# Determine which forward methods to use based on metadata.
# Sparse MLA backends always use forward_mqa
has_decode = is_sparse or getattr(metadata, "decode", None) is not None
has_prefill = not is_sparse and getattr(metadata, "prefill", None) is not None
# Non-sparse backends use .decode/.prefill sub-objects.
# Sparse backends use num_decode_tokens/num_prefills directly.
#
# sparse_mla_force_mqa overrides: even for prefill metadata, use MQA.
force_mqa = getattr(config, "sparse_mla_force_mqa", False)
if force_mqa:
has_decode = True
has_prefill = False
elif is_sparse:
num_decode_tokens = getattr(metadata, "num_decode_tokens", 0)
num_prefills = getattr(metadata, "num_prefills", 0)
has_decode = num_decode_tokens > 0
has_prefill = num_prefills > 0
else:
has_decode = getattr(metadata, "decode", None) is not None
has_prefill = getattr(metadata, "prefill", None) is not None
if not has_decode and not has_prefill:
raise RuntimeError("Metadata has neither decode nor prefill metadata")

Expand Down Expand Up @@ -887,6 +906,7 @@ def _run_mla_benchmark_batched(
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
index_topk: int = 2048,
prefill_backend: str | None = None,
sparse_mla_force_mqa: bool = False,
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Expand All @@ -905,6 +925,8 @@ def _run_mla_benchmark_batched(
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
sparse_mla_force_mqa: If True, forces all sparse MLA tokens through
forward_mqa (even prefill tokens).

Returns:
List of BenchmarkResult objects
Expand Down Expand Up @@ -955,10 +977,20 @@ def _run_mla_benchmark_batched(
index_topk=index_topk if is_sparse else None,
prefill_backend=prefill_backend,
kv_cache_dtype=kv_cache_dtype,
sparse_mla_force_mqa=sparse_mla_force_mqa,
)

results = []

# Initialize workspace manager (needed by metadata builders)
from vllm.v1.worker.workspace import (
init_workspace_manager,
is_workspace_manager_initialized,
)

if not is_workspace_manager_initialized():
init_workspace_manager(device)

with set_current_vllm_config(vllm_config):
# Clear cached prefill backend detection functions so they re-evaluate
# with the current VllmConfig. These are @functools.cache decorated and
Expand Down Expand Up @@ -1068,6 +1100,7 @@ def run_mla_benchmark(
num_kv_splits: int | None = None,
index_topk: int = 2048,
prefill_backend: str | None = None,
sparse_mla_force_mqa: bool = False,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Expand All @@ -1087,6 +1120,8 @@ def run_mla_benchmark(
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
sparse_mla_force_mqa: If True, forces all sparse MLA tokens through
forward_mqa (even prefill tokens).

Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
Expand All @@ -1111,7 +1146,11 @@ def run_mla_benchmark(

# Use unified batched execution
results = _run_mla_benchmark_batched(
backend, configs_with_params, index_topk, prefill_backend=prefill_backend
backend,
configs_with_params,
index_topk,
prefill_backend=prefill_backend,
sparse_mla_force_mqa=sparse_mla_force_mqa,
)

# Return single result or list based on input
Expand Down
4 changes: 2 additions & 2 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 29210221863736a08f71a866459e368ad1ac4a95
GIT_REPOSITORY https://github.com/MatthewBonanni/flash-attention.git
GIT_TAG 906d6f92089ccb1f2ea288d18ef10db39ce4c752
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
Loading
Loading