diff --git a/benchmarks/attention_benchmarks/README.md b/benchmarks/attention_benchmarks/README.md new file mode 100644 index 000000000000..788ce94f23fb --- /dev/null +++ b/benchmarks/attention_benchmarks/README.md @@ -0,0 +1,266 @@ +# vLLM Attention Benchmarking Suite + +Fast, flexible benchmarking for vLLM attention and MLA backends with an extended batch specification grammar. + +## Quick Start + +```bash +cd benchmarks/attention_benchmarks + +# Run a pre-configured benchmark +python benchmark.py --config configs/mla_decode.yaml +python benchmark.py --config configs/mla_mixed_batch.yaml +python benchmark.py --config configs/speculative_decode.yaml +python benchmark.py --config configs/standard_attention.yaml +python benchmark.py --config configs/reorder_threshold.yaml + +# Or run custom benchmarks +python benchmark.py \ + --backends flash flashinfer \ + --batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \ + --output-csv results.csv +``` + +## Simplified Batch Specification Grammar + +Express workloads concisely using query length and sequence length: + +```python +"q2k" # 2048-token prefill (q_len=2048, seq_len=2048) +"q1s1k" # Decode: 1 token with 1K sequence +"8q1s1k" # 8 decode requests +"q4s1k" # 4-token extend (e.g., spec decode) +"2q2k_32q1s1k" # Mixed: 2 prefills + 32 decodes +"16q4s1k" # 16 spec decode (4 tokens each) +``` + +### Grammar Rule + +```text +Format: (<count>?) q<q_len>(k?) (s<seq_len>(k?))? + +- count: Number of identical requests (optional, default=1) +- q_len: Query length (number of new tokens) +- seq_len: Total sequence length (optional, defaults to q_len for prefill) +- 'k': Multiplies value by 1024 + +Mixed batches: Use _ to combine (e.g., "2q2k_32q1s1k") +``` + +**Note**: Decode, prefill, and spec decode are just different query lengths - no special syntax needed! 
+ +## Pre-configured Benchmarks + +The suite includes several pre-configured YAML benchmark configurations: + +### MLA Decode Benchmark + +Tests pure decode performance across MLA backends with varying batch sizes and sequence lengths. + +```bash +python benchmark.py --config configs/mla_decode.yaml +``` + +### MLA Mixed Batch Benchmark + +Tests chunked prefill performance with mixed prefill + decode batches. + +```bash +python benchmark.py --config configs/mla_mixed_batch.yaml +``` + +### Speculative Decoding Benchmark + +Tests speculative decode scenarios (K-token verification) and reorder_batch_threshold optimization. + +```bash +python benchmark.py --config configs/speculative_decode.yaml +``` + +### Standard Attention Benchmark + +Tests standard attention backends (Flash/Triton/FlashInfer) with pure prefill, decode, and mixed batches. + +```bash +python benchmark.py --config configs/standard_attention.yaml +``` + +### Reorder Threshold Study + +**Question:** At what query length does the prefill pipeline become faster than the decode pipeline? + +Tests query lengths from 1-1024 across 9 batch sizes to find the crossover point. Uses `decode_vs_prefill` mode to compare both pipelines for each query length. + +```bash +python benchmark.py --config configs/reorder_threshold.yaml +``` + +--- + +## Universal Benchmark + +The `benchmark.py` script handles **all** backends - both standard attention and MLA. 
+ +### Standard Attention (Flash/Triton/FlashInfer) + +```bash +python benchmark.py \ + --backends flash triton flashinfer \ + --batch-specs "q2k" "8q1s1k" "2q2k_32q1s1k" \ + --num-layers 10 \ + --repeats 5 \ + --output-csv results.csv +``` + +### MLA Backends + +```bash +# Compare all MLA backends +python benchmark.py \ + --backends cutlass_mla flashinfer_mla flashattn_mla flashmla \ + --batch-specs "64q1s1k" "64q1s4k" \ + --output-csv mla_results.csv +``` + +### Parameter Sweeps + +Use `--sweep-param` and `--sweep-values` to run parameter sweeps from the CLI: + +#### CUTLASS MLA num-splits Optimization + +**Question:** What is the optimal `num_kv_splits` for CUTLASS MLA? + +```bash +python benchmark.py \ + --backend cutlass_mla \ + --batch-specs "64q1s1k" "64q1s4k" "64q1s16k" \ + --sweep-param num_kv_splits \ + --sweep-values 1 2 4 8 16 \ + --output-json optimal_splits.json +``` + +#### Reorder Batch Threshold Optimization + +**Question:** What's the optimal `reorder_batch_threshold` for speculative decoding? + +```bash +python benchmark.py \ + --backend flashmla \ + --batch-specs "q4s1k" "q8s2k" \ + --sweep-param reorder_batch_threshold \ + --sweep-values 1 4 16 64 256 512 \ + --output-csv threshold_sweep.csv +``` + +### All Command-Line Options + +```text +--config CONFIG # Path to YAML config file (overrides other args) +--backends BACKEND [BACKEND ...] # flash, triton, flashinfer, cutlass_mla, + # flashinfer_mla, flashattn_mla, flashmla +--backend BACKEND # Single backend (alternative to --backends) +--batch-specs SPEC [SPEC ...] 
# Batch specifications using extended grammar + +# Model configuration +--num-layers N # Number of layers +--head-dim N # Head dimension +--num-q-heads N # Query heads +--num-kv-heads N # KV heads +--block-size N # Block size + +# Benchmark settings +--device DEVICE # Device (default: cuda:0) +--repeats N # Repetitions +--warmup-iters N # Warmup iterations +--profile-memory # Profile memory usage + +# Parameter sweeps +--sweep-param PARAM # Parameter name to sweep (e.g., num_kv_splits, + # reorder_batch_threshold) +--sweep-values N [N ...] # Values to sweep for the parameter + +# Output +--output-csv FILE # Save to CSV +--output-json FILE # Save to JSON +``` + +## Hardware Requirements + +| Backend | Hardware | +|---------|----------| +| Flash/Triton/FlashInfer | Any CUDA GPU | +| CUTLASS MLA | Blackwell (SM100+) | +| FlashAttn MLA | Hopper (SM90+) | +| FlashMLA | Hopper (SM90+) | +| FlashInfer-MLA | Any CUDA GPU | + +## Using MLA Runner Directly + +All MLA backends are available through `mla_runner.run_mla_benchmark()`: + +```python +from mla_runner import run_mla_benchmark +from common import BenchmarkConfig + +config = BenchmarkConfig( + backend="cutlass_mla", + batch_spec="64q1s4k", + num_layers=10, + head_dim=576, + num_q_heads=128, + num_kv_heads=1, + block_size=128, + device="cuda:0", + repeats=5, + warmup_iters=3, +) + +# CUTLASS MLA with specific num_kv_splits +result = run_mla_benchmark("cutlass_mla", config, num_kv_splits=4) +print(f"Time: {result.mean_time:.6f}s") + +# FlashInfer-MLA +result = run_mla_benchmark("flashinfer_mla", config) + +# FlashAttn MLA (Hopper SM90+) +result = run_mla_benchmark("flashattn_mla", config, reorder_batch_threshold=64) + +# FlashMLA (Hopper SM90+) +result = run_mla_benchmark("flashmla", config, reorder_batch_threshold=64) +``` + +## Python API + +```python +from batch_spec import parse_batch_spec, format_batch_spec, get_batch_stats +from common import BenchmarkConfig, BenchmarkResult, ResultsFormatter + +# Parse batch 
specs +requests = parse_batch_spec("2q2k_q4s1k_32q1s1k") +print(format_batch_spec(requests)) +# "2 prefill (2x2k), 1 extend (1xq4kv1k), 32 decode (32x1k)" + +# Get batch statistics +stats = get_batch_stats(requests) +print(f"Total tokens: {stats['total_tokens']}") +print(f"Num decode: {stats['num_decode']}, Num prefill: {stats['num_prefill']}") + +# Format results +formatter = ResultsFormatter() +formatter.save_csv(results, "output.csv") +formatter.save_json(results, "output.json") +``` + +## Tips + +**1. Warmup matters** - Use `--warmup-iters 10` for stable results + +**2. Multiple repeats** - Use `--repeats 20` for low variance + +**3. Save results** - Always use `--output-csv` or `--output-json` + +**4. Test incrementally** - Start with `--num-layers 1 --repeats 1` + +**5. Extended grammar** - Leverage spec decode, chunked prefill patterns + +**6. Parameter sweeps** - Use `--sweep-param` and `--sweep-values` to find optimal values diff --git a/benchmarks/attention_benchmarks/__init__.py b/benchmarks/attention_benchmarks/__init__.py new file mode 100644 index 000000000000..df7a6328569d --- /dev/null +++ b/benchmarks/attention_benchmarks/__init__.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""vLLM Attention Benchmarking Suite.""" + +from .batch_spec import ( + BatchRequest, + format_batch_spec, + get_batch_stats, + parse_batch_spec, + reorder_for_flashinfer, + split_by_type, +) +from .common import ( + BenchmarkConfig, + BenchmarkResult, + MockLayer, + MockModelConfig, + ResultsFormatter, + get_attention_scale, + is_mla_backend, + setup_mla_dims, +) + +__all__ = [ + # Batch specification + "BatchRequest", + "parse_batch_spec", + "format_batch_spec", + "reorder_for_flashinfer", + "split_by_type", + "get_batch_stats", + # Benchmarking infrastructure + "BenchmarkConfig", + "BenchmarkResult", + "ResultsFormatter", + # Mock objects + "MockLayer", + "MockModelConfig", + # Utilities + 
"setup_mla_dims", + "get_attention_scale", + "is_mla_backend", +] diff --git a/benchmarks/attention_benchmarks/batch_spec.py b/benchmarks/attention_benchmarks/batch_spec.py new file mode 100644 index 000000000000..41681796e2e6 --- /dev/null +++ b/benchmarks/attention_benchmarks/batch_spec.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Simplified batch specification grammar for attention benchmarks. + +Grammar (underscore-separated segments): + Format: (?) q(k?) (s(k?))? + + - count: Number of identical requests (optional, default=1) + - q_len: Query length (number of new tokens) + - seq_len: Total sequence length (optional, defaults to q_len for prefill) + - 'k' suffix: Multiplies value by 1024 + +Common patterns: + - Prefill: q_len == seq_len (e.g., "q2k" → 2048 new tokens, 2048 seq) + - Decode: q_len == 1 (e.g., "q1s1k" → 1 token, 1024 seq length) + - Extend: q_len < seq_len (e.g., "q4s1k" → 4 tokens, 1024 seq length) + +Examples: + q2k -> [(2048, 2048)] # Prefill: 2048 tokens + q1s1k -> [(1, 1024)] # Decode: 1 token, 1K sequence + 8q1s1k -> [(1, 1024)] * 8 # 8 decode requests + q4s1k -> [(4, 1024)] # 4-token extend (spec decode) + 2q1k_32q1s1k -> [(1024, 1024)] * 2 + [(1, 1024)] * 32 # Mixed batch + 16q4s1k -> [(4, 1024)] * 16 # 16 spec decode requests +""" + +from collections import Counter +from dataclasses import dataclass + +import regex as re + + +@dataclass +class BatchRequest: + """Represents a single request in a batch.""" + + q_len: int # Query length (number of new tokens) + kv_len: int # Total KV cache length + + @property + def is_decode(self) -> bool: + """True if this is a decode request (q_len == 1).""" + return self.q_len == 1 + + @property + def is_prefill(self) -> bool: + """True if this is a pure prefill (q_len == kv_len).""" + return self.q_len == self.kv_len + + @property + def is_extend(self) -> bool: + """True if this is context extension (q_len > 1, 
kv_len > q_len).""" + return self.q_len > 1 and self.kv_len > self.q_len + + @property + def context_len(self) -> int: + """Context length (KV cache - query).""" + return self.kv_len - self.q_len + + def as_tuple(self) -> tuple[int, int]: + """Return as (q_len, kv_len) tuple for compatibility.""" + return (self.q_len, self.kv_len) + + +def _parse_size(size_str: str, k_suffix: str) -> int: + """Parse size string with optional 'k' suffix.""" + size = int(size_str) + return size * 1024 if k_suffix == "k" else size + + +def parse_batch_spec(spec: str) -> list[BatchRequest]: + """ + Parse batch specification string into list of BatchRequest objects. + + Grammar: (?) q(k?) (s(k?))? + + Args: + spec: Batch specification string (see module docstring for grammar) + + Returns: + List of BatchRequest objects + + Raises: + ValueError: If spec format is invalid + """ + requests = [] + + for seg in spec.split("_"): + # Unified pattern: (?) q(k?) (s(k?))? + m = re.match(r"^(?:(\d+))?q(\d+)(k?)(?:s(\d+)(k?))?$", seg) + if m: + cnt = int(m.group(1)) if m.group(1) else 1 + q_len = _parse_size(m.group(2), m.group(3)) + kv_len = _parse_size(m.group(4), m.group(5)) if m.group(4) else q_len + requests.extend([BatchRequest(q_len=q_len, kv_len=kv_len)] * cnt) + continue + + raise ValueError(f"Invalid batch spec segment: '{seg}'") + + return requests + + +def format_batch_spec(requests: list[BatchRequest]) -> str: + """ + Format list of BatchRequest into human-readable string. + + Groups requests by type and provides counts and sizes. 
+ + Args: + requests: List of BatchRequest objects + + Returns: + Formatted string describing the batch + """ + kinds = { + "prefill": [], + "extend": [], + "decode": [], + } + + for req in requests: + tup = (req.q_len, req.kv_len) + if req.is_prefill: + kinds["prefill"].append(tup) + elif req.is_extend: + kinds["extend"].append(tup) + elif req.is_decode: + kinds["decode"].append(tup) + + parts = [] + for kind in ["prefill", "extend", "decode"]: + lst = kinds[kind] + if not lst: + continue + + cnt_total = len(lst) + ctr = Counter(lst) + inner = [] + + for (q, kv), cnt in ctr.items(): + if kind == "prefill": + size = f"{q // 1024}k" if q % 1024 == 0 else str(q) + inner.append(f"{cnt}x{size}") + elif kind == "decode": + size = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv) + inner.append(f"{cnt}x{size}") + else: # extend + qstr = f"{q // 1024}k" if q % 1024 == 0 else str(q) + kstr = f"{kv // 1024}k" if kv % 1024 == 0 else str(kv) + inner.append(f"{cnt}xq{qstr}kv{kstr}") + + parts.append(f"{cnt_total} {kind} ({', '.join(inner)})") + + return ", ".join(parts) + + +def reorder_for_flashinfer(requests: list[BatchRequest]) -> list[BatchRequest]: + """ + Reorder requests for FlashInfer: decode first, then prefill. + + FlashInfer expects decode requests before prefill requests for + optimal performance. + + Args: + requests: Original list of BatchRequest + + Returns: + Reordered list with decode requests first + """ + decodes = [r for r in requests if r.is_decode] + non_decodes = [r for r in requests if not r.is_decode] + return decodes + non_decodes + + +def split_by_type( + requests: list[BatchRequest], +) -> dict[str, list[BatchRequest]]: + """ + Split requests by type for analysis. 
+ + Args: + requests: List of BatchRequest + + Returns: + Dict with keys: 'decode', 'prefill', 'extend' + """ + result = { + "decode": [], + "prefill": [], + "extend": [], + } + + for req in requests: + if req.is_decode: + result["decode"].append(req) + elif req.is_prefill: + result["prefill"].append(req) + elif req.is_extend: + result["extend"].append(req) + + return result + + +def get_batch_stats(requests: list[BatchRequest]) -> dict: + """ + Compute statistics about a batch. + + Args: + requests: List of BatchRequest + + Returns: + Dict with batch statistics + """ + by_type = split_by_type(requests) + + return { + "total_requests": len(requests), + "num_decode": len(by_type["decode"]), + "num_prefill": len(by_type["prefill"]), + "num_extend": len(by_type["extend"]), + "total_tokens": sum(r.q_len for r in requests), + "total_kv_cache": sum(r.kv_len for r in requests), + "max_q_len": max((r.q_len for r in requests), default=0), + "max_kv_len": max((r.kv_len for r in requests), default=0), + "avg_q_len": sum(r.q_len for r in requests) / len(requests) if requests else 0, + "avg_kv_len": ( + sum(r.kv_len for r in requests) / len(requests) if requests else 0 + ), + } diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py new file mode 100644 index 000000000000..ba11fca7452f --- /dev/null +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -0,0 +1,886 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Universal vLLM Attention Benchmark + +Benchmark any attention backend with the extended grammar. +Supports standard attention (Flash/Triton/FlashInfer) and MLA backends. 
def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
    """Run standard attention benchmark (Flash/Triton/FlashInfer).

    The ``runner`` module is imported inside the function, so it is only
    loaded when a standard backend is actually benchmarked.
    """
    from runner import run_attention_benchmark

    return run_attention_benchmark(config)


def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
    """Run MLA benchmark with appropriate backend.

    Delegates to ``mla_runner.run_mla_benchmark`` with ``config.backend`` as
    the backend name; ``kwargs`` are forwarded unchanged.
    """
    from mla_runner import run_mla_benchmark as run_mla

    return run_mla(config.backend, config, **kwargs)


def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult:
    """
    Run a single benchmark with proper backend selection.

    Args:
        config: BenchmarkConfig with backend, batch_spec, and model params
        **kwargs: Additional arguments passed to MLA benchmarks. They are
            ignored for standard (non-MLA) backends.

    Returns:
        BenchmarkResult. Any exception raised by the runner is captured and
        reported through the ``error`` field (with infinite timings) instead
        of propagating, so one failing configuration does not abort a sweep.
    """
    try:
        if is_mla_backend(config.backend):
            return run_mla_benchmark(config, **kwargs)
        return run_standard_attention_benchmark(config)
    except Exception as e:
        return BenchmarkResult(
            config=config,
            mean_time=float("inf"),
            std_time=0,
            min_time=float("inf"),
            max_time=float("inf"),
            error=str(e),
        )


def _with_backend(result: BenchmarkResult, backend_name: str) -> BenchmarkResult:
    """Return a copy of ``result`` whose config reports ``backend_name``."""
    return replace(result, config=replace(result.config, backend=backend_name))


def _sweep_sort_key(param_value: str) -> tuple:
    """Sort key for stringified sweep values.

    Pure-digit strings sort numerically and come first; everything else sorts
    lexicographically after them. Keeping both kinds in separate tuple slots
    avoids comparing int with str when a sweep mixes the two.
    """
    if param_value.isdigit():
        return (0, int(param_value), "")
    return (1, 0, param_value)


def run_model_parameter_sweep(
    backends: list[str],
    batch_specs: list[str],
    base_config_args: dict,
    sweep: ModelParameterSweep,
    console: Console,
) -> list[BenchmarkResult]:
    """
    Run model parameter sweep for given backends and batch specs.

    For every (backend, batch spec, sweep value) combination the swept
    parameter is written into the BenchmarkConfig arguments and the benchmark
    is run. Afterwards one results table per sweep value is printed, followed
    by the fastest backend for each (sweep value, batch spec) pair.

    Args:
        backends: List of backend names
        batch_specs: List of batch specifications
        base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
        sweep: ModelParameterSweep configuration
        console: Rich console for output

    Returns:
        List of BenchmarkResult objects whose ``config.backend`` holds the
        sweep label produced by ``sweep.get_label(backend, value)``.
    """
    all_results = []
    # Both lookups are filled in while the results are created, so the
    # (backend, value) of a result never has to be recovered by matching
    # label strings afterwards.
    backend_mapping = {}  # labeled backend -> original backend
    by_param_value = {}  # str(value) -> labeled results, in run order

    console.print(
        f"[yellow]Model sweep mode: testing {sweep.param_name} = {sweep.values}[/]"
    )

    total = len(backends) * len(batch_specs) * len(sweep.values)

    with tqdm(total=total, desc="Benchmarking") as pbar:
        for backend in backends:
            for spec in batch_specs:
                for value in sweep.values:
                    # Create config with modified model parameter
                    config_args = base_config_args.copy()
                    config_args[sweep.param_name] = value

                    # The real backend name is needed to run the benchmark
                    clean_config = BenchmarkConfig(
                        backend=backend, batch_spec=spec, **config_args
                    )
                    result = run_benchmark(clean_config)

                    # Replace backend with labeled version for display
                    backend_label = sweep.get_label(backend, value)
                    result = _with_backend(result, backend_label)
                    all_results.append(result)
                    backend_mapping[backend_label] = backend
                    by_param_value.setdefault(str(value), []).append(result)

                    if not result.success:
                        console.print(
                            f"[red]Error {backend} {spec} {sweep.param_name}="
                            f"{value}: {result.error}[/]"
                        )

                    pbar.update(1)

    # Display sweep results - create separate table for each parameter value
    console.print("\n[bold green]Model Parameter Sweep Results:[/]")
    formatter = ResultsFormatter(console)

    for param_value in sorted(by_param_value, key=_sweep_sort_key):
        console.print(f"\n[bold cyan]{sweep.param_name} = {param_value}[/]")

        # Within one table the parameter value is fixed, so show the
        # original backend names rather than the sweep labels.
        modified_results = [
            _with_backend(r, backend_mapping[r.config.backend])
            for r in by_param_value[param_value]
        ]
        formatter.print_table(modified_results, backends, compare_to_fastest=True)

    # Show optimal backend for each (param_value, batch_spec) combination
    console.print(
        f"\n[bold cyan]Optimal backend for each ({sweep.param_name}, batch_spec):[/]"
    )

    # Group successful results by (param_value, batch_spec)
    by_param_and_spec = {}
    for param_value, param_results in by_param_value.items():
        for r in param_results:
            if r.success:
                key = (param_value, r.config.batch_spec)
                by_param_and_spec.setdefault(key, []).append(r)

    # Sort by param value then spec
    sorted_keys = sorted(
        by_param_and_spec,
        key=lambda k: (_sweep_sort_key(k[0]), k[1]),
    )

    current_param_value = None
    for param_value, spec in sorted_keys:
        # Print header when param value changes
        if param_value != current_param_value:
            console.print(f"\n [bold]{sweep.param_name}={param_value}:[/]")
            current_param_value = param_value

        ranked = sorted(
            by_param_and_spec[(param_value, spec)], key=lambda r: r.mean_time
        )
        backend_name = backend_mapping[ranked[0].config.backend]

        # Show all backends' times for comparison, fastest first
        times_str = " | ".join(
            f"{backend_mapping[r.config.backend]}: {r.mean_time:.6f}s" for r in ranked
        )

        console.print(
            f" {spec:12s} -> [bold green]{backend_name:15s}[/] ({times_str})"
        )

    return all_results


def run_parameter_sweep(
    backends: list[str],
    batch_specs: list[str],
    base_config_args: dict,
    sweep: ParameterSweep,
    console: Console,
) -> list[BenchmarkResult]:
    """
    Run parameter sweep for given backends and batch specs.

    Unlike the model parameter sweep, the swept value is not part of the
    BenchmarkConfig: it is forwarded to the benchmark runner as a keyword
    argument. When ``sweep.include_auto`` is set, an extra "auto" run is
    added in which the keyword argument is omitted.

    Args:
        backends: List of backend names
        batch_specs: List of batch specifications
        base_config_args: Base configuration arguments (num_layers, head_dim, etc.)
        sweep: ParameterSweep configuration
        console: Rich console for output

    Returns:
        List of BenchmarkResult objects whose ``config.backend`` holds the
        sweep label produced by ``sweep.get_label(backend, value)``.
    """
    all_results = []

    # Build list of values to sweep (including auto if requested)
    sweep_values = list(sweep.values)
    if sweep.include_auto:
        sweep_values.append("auto")

    console.print(f"[yellow]Sweep mode: testing {sweep.param_name} = {sweep_values}[/]")

    total = len(backends) * len(batch_specs) * len(sweep_values)

    with tqdm(total=total, desc="Benchmarking") as pbar:
        for backend in backends:
            for spec in batch_specs:
                for value in sweep_values:
                    # The real backend name is needed to run the benchmark
                    config = BenchmarkConfig(
                        backend=backend, batch_spec=spec, **base_config_args
                    )

                    # "auto" means: do not pass the parameter at all
                    kwargs = {} if value == "auto" else {sweep.param_name: value}
                    result = run_benchmark(config, **kwargs)

                    # Replace backend with labeled version for display
                    result = _with_backend(result, sweep.get_label(backend, value))
                    all_results.append(result)

                    if not result.success:
                        console.print(
                            f"[red]Error {backend} {spec} {sweep.param_name}="
                            f"{value}: {result.error}[/]"
                        )

                    pbar.update(1)

    # Display sweep results
    console.print("\n[bold green]Sweep Results:[/]")
    backend_labels = [sweep.get_label(b, v) for b in backends for v in sweep_values]
    formatter = ResultsFormatter(console)
    formatter.print_table(all_results, backend_labels)

    # Show optimal values
    console.print(f"\n[bold cyan]Optimal {sweep.param_name} per batch spec:[/]")
    by_spec = {}
    for r in all_results:
        if r.success:
            by_spec.setdefault(r.config.batch_spec, []).append(r)

    for spec in sorted(by_spec):
        best = min(by_spec[spec], key=lambda r: r.mean_time)
        console.print(
            f" {spec}: [bold green]{best.config.backend}[/] ({best.mean_time:.6f}s)"
        )

    return all_results
def load_config_from_yaml(config_path: str) -> dict:
    """Load configuration from YAML file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The top-level mapping of the file. An empty file yields ``{}``
        (``yaml.safe_load`` returns None for it), so callers can always use
        ``in`` / ``[]`` on the result.

    Raises:
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, encoding="utf-8") as f:
        loaded = yaml.safe_load(f)

    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Config file '{config_path}' must contain a mapping at the top "
            f"level, got {type(loaded).__name__}"
        )
    return loaded


def _expand_range(range_def: dict) -> list[int]:
    """Expand a ``{start, stop, step?, end_inclusive?}`` dict into a list of ints."""
    start = range_def["start"]
    stop = range_def["stop"]
    step = range_def.get("step", 1)
    if step == 0:
        raise ValueError("Range 'step' must be non-zero")

    # End is inclusive by default. Move the exclusive bound one past `stop`
    # in the direction of travel, so descending ranges keep their endpoint.
    if range_def.get("end_inclusive", True):
        stop += 1 if step > 0 else -1

    return list(range(start, stop, step))


def generate_batch_specs_from_ranges(ranges: list[dict]) -> list[str]:
    """
    Generate batch specs from range specifications.

    Args:
        ranges: List of range specifications, each containing:
            - template: Batch spec template (e.g., "q{q_len}s1k")
            - any number of template parameters. A parameter whose value is a
              dict with a "start" key is a range (start, stop, optional step
              defaulting to 1, optional end_inclusive defaulting to true);
              any other value is used as a single fixed value.

    Returns:
        List of generated batch spec strings. When a specification has
        several parameters, the Cartesian product of their values is
        generated, with the first parameter varying slowest. A specification
        with no parameters contributes its template unchanged.

    Raises:
        ValueError: If a specification has no 'template' or a range has a
            zero step.
        KeyError: If a range dict has no 'stop', or the template references
            a parameter that is not given.

    Example:
        ranges = [
            {
                "template": "q{q_len}s1k",
                "q_len": {
                    "start": 1,
                    "stop": 16,
                    "step": 1,
                    "end_inclusive": true  # Optional, defaults to true
                }
            }
        ]
        Returns: ["q1s1k", "q2s1k", ..., "q16s1k"]
    """
    import itertools

    all_specs = []

    for range_spec in ranges:
        template = range_spec.get("template")
        if not template:
            raise ValueError("Range specification must include 'template'")

        # Every non-template key becomes a list of candidate values
        range_params = {}
        for key, value in range_spec.items():
            if key == "template":
                continue
            if isinstance(value, dict) and "start" in value:
                range_params[key] = _expand_range(value)
            else:
                # This is a fixed value
                range_params[key] = [value]

        if not range_params:
            # No parameters, just use template as-is
            all_specs.append(template)
            continue

        # Generate all combinations (Cartesian product)
        param_names = list(range_params)
        for combo in itertools.product(*range_params.values()):
            all_specs.append(template.format(**dict(zip(param_names, combo))))

    return all_specs
) + + # Output + parser.add_argument("--output-csv", help="Save to CSV") + parser.add_argument("--output-json", help="Save to JSON") + + args = parser.parse_args() + + console = Console() + console.print("[bold cyan]vLLM Attention Benchmark[/]") + + # Load config from YAML if provided + if args.config: + console.print(f"[yellow]Loading config from: {args.config}[/]") + yaml_config = load_config_from_yaml(args.config) + + # Show description if available + if "description" in yaml_config: + console.print(f"[dim]{yaml_config['description']}[/]") + + # Override args with YAML values + # (YAML takes precedence unless CLI arg was explicitly set) + # Backend(s) + if "backend" in yaml_config: + args.backend = yaml_config["backend"] + args.backends = None + elif "backends" in yaml_config: + args.backends = yaml_config["backends"] + args.backend = None + + # Check for special modes + if "mode" in yaml_config: + args.mode = yaml_config["mode"] + else: + args.mode = None + + # Batch specs and sizes + # Support both explicit batch_specs and generated batch_spec_ranges + if "batch_spec_ranges" in yaml_config: + # Generate batch specs from ranges + generated_specs = generate_batch_specs_from_ranges( + yaml_config["batch_spec_ranges"] + ) + # Combine with any explicit batch_specs + if "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + generated_specs + else: + args.batch_specs = generated_specs + console.print( + f"[dim]Generated {len(generated_specs)} batch specs from ranges[/]" + ) + elif "batch_specs" in yaml_config: + args.batch_specs = yaml_config["batch_specs"] + + if "batch_sizes" in yaml_config: + args.batch_sizes = yaml_config["batch_sizes"] + else: + args.batch_sizes = None + + # Model config + if "model" in yaml_config: + model = yaml_config["model"] + args.num_layers = model.get("num_layers", args.num_layers) + args.head_dim = model.get("head_dim", args.head_dim) + args.num_q_heads = model.get("num_q_heads", args.num_q_heads) + 
args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads) + args.block_size = model.get("block_size", args.block_size) + + # Benchmark settings + if "benchmark" in yaml_config: + bench = yaml_config["benchmark"] + args.device = bench.get("device", args.device) + args.repeats = bench.get("repeats", args.repeats) + args.warmup_iters = bench.get("warmup_iters", args.warmup_iters) + args.profile_memory = bench.get("profile_memory", args.profile_memory) + + # Parameter sweep configuration + if "parameter_sweep" in yaml_config: + sweep_config = yaml_config["parameter_sweep"] + args.parameter_sweep = ParameterSweep( + param_name=sweep_config["param_name"], + values=sweep_config["values"], + include_auto=sweep_config.get("include_auto", False), + label_format=sweep_config.get( + "label_format", "{backend}_{param_name}_{value}" + ), + ) + else: + args.parameter_sweep = None + + # Model parameter sweep configuration + if "model_parameter_sweep" in yaml_config: + sweep_config = yaml_config["model_parameter_sweep"] + args.model_parameter_sweep = ModelParameterSweep( + param_name=sweep_config["param_name"], + values=sweep_config["values"], + label_format=sweep_config.get( + "label_format", "{backend}_{param_name}_{value}" + ), + ) + else: + args.model_parameter_sweep = None + + # Output + if "output" in yaml_config: + output = yaml_config["output"] + if "csv" in output and not args.output_csv: + args.output_csv = output["csv"] + if "json" in output and not args.output_json: + args.output_json = output["json"] + + console.print() + + # Handle CLI-based parameter sweep (if not from YAML) + if ( + (not hasattr(args, "parameter_sweep") or args.parameter_sweep is None) + and args.sweep_param + and args.sweep_values + ): + args.parameter_sweep = ParameterSweep( + param_name=args.sweep_param, + values=args.sweep_values, + include_auto=False, + label_format="{backend}_{param_name}_{value}", + ) + + # Determine backends + backends = args.backends or ([args.backend] if 
args.backend else ["flash"]) + console.print(f"Backends: {', '.join(backends)}") + console.print(f"Batch specs: {', '.join(args.batch_specs)}") + console.print() + + # Run benchmarks + all_results = [] + + # Handle special mode: decode_vs_prefill comparison + if hasattr(args, "mode") and args.mode == "decode_vs_prefill": + console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]") + console.print( + "[dim]For each query length, testing both decode and prefill pipelines[/]" + ) + console.print("[dim]Using batched execution for optimal performance[/]") + + # Extract batch sizes from config + batch_sizes = getattr(args, "batch_sizes", [1]) + backend = backends[0] # Use first backend (should only be one) + + # Calculate total benchmarks + total = len(batch_sizes) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for batch_size in batch_sizes: + # Prepare all configs for this batch size + configs_with_thresholds = [] + + for spec in args.batch_specs: + # Parse the batch spec to get query length + requests = parse_batch_spec(spec) + if not requests: + console.print( + f"[red]Error: Could not parse batch spec '{spec}'[/]" + ) + continue + + # Get query length from first request + query_length = requests[0].q_len + + # Create batch spec for this batch size + # For batch_size > 1, we need to prepend the count + batch_spec = f"{batch_size}{spec}" if batch_size > 1 else spec + + # Create base config (without backend name) + base_config = BenchmarkConfig( + backend=backend, # Will be overridden later + batch_spec=batch_spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) + + # Add decode pipeline config + decode_threshold = query_length + config_decode = replace( + base_config, + 
backend=f"{backend}_decode_qlen{query_length}_bs{batch_size}", + ) + configs_with_thresholds.append((config_decode, decode_threshold)) + + # Add prefill pipeline config if query_length > 1 + if query_length > 1: + prefill_threshold = query_length - 1 + config_prefill = replace( + base_config, + backend=f"{backend}_prefill_qlen{query_length}" + f"_bs{batch_size}", + ) + configs_with_thresholds.append( + (config_prefill, prefill_threshold) + ) + + # Run all benchmarks for this batch size in one go (batched mode) + try: + from mla_runner import run_mla_benchmark as run_mla + + # Use batched API: pass list of (config, threshold) tuples + timing_results = run_mla(backend, configs_with_thresholds) + + # Create BenchmarkResult objects from timing results + for (config, _), timing in zip( + configs_with_thresholds, timing_results + ): + result = BenchmarkResult( + config=config, + mean_time=timing["mean"], + std_time=timing["std"], + min_time=timing["min"], + max_time=timing["max"], + throughput_tokens_per_sec=timing.get("throughput", None), + ) + all_results.append(result) + + except Exception as e: + import traceback + + console.print( + f"[red]Error running batched benchmarks for " + f"batch_size={batch_size}: {e}[/]" + ) + console.print("[red]Traceback:[/]") + traceback.print_exc() + # Add error results for all configs + for config, _ in configs_with_thresholds: + result = BenchmarkResult( + config=config, + mean_time=float("inf"), + std_time=0, + min_time=float("inf"), + max_time=float("inf"), + error=str(e), + ) + all_results.append(result) + + pbar.update(1) + + # Display decode vs prefill results + console.print("\n[bold green]Decode vs Prefill Results:[/]") + + # Group by batch size + by_batch_size = {} + for r in all_results: + if r.success: + # Extract batch size from backend name + parts = r.config.backend.split("_") + bs_part = [p for p in parts if p.startswith("bs")] + if bs_part: + bs = int(bs_part[0][2:]) + if bs not in by_batch_size: + by_batch_size[bs] = 
[] + by_batch_size[bs].append(r) + + # For each batch size, analyze crossover point + for bs in sorted(by_batch_size.keys()): + console.print(f"\n[bold cyan]Batch size: {bs}[/]") + results = by_batch_size[bs] + + # Group by query length + by_qlen = {} + for r in results: + parts = r.config.backend.split("_") + qlen_part = [p for p in parts if p.startswith("qlen")] + if qlen_part: + qlen = int(qlen_part[0][4:]) + if qlen not in by_qlen: + by_qlen[qlen] = {} + + pipeline = "decode" if "decode" in r.config.backend else "prefill" + by_qlen[qlen][pipeline] = r + + # Find crossover point + last_decode_faster = None + for qlen in sorted(by_qlen.keys()): + pipelines = by_qlen[qlen] + if "decode" in pipelines and "prefill" in pipelines: + decode_time = pipelines["decode"].mean_time + prefill_time = pipelines["prefill"].mean_time + faster = "decode" if decode_time < prefill_time else "prefill" + + speedup = ( + prefill_time / decode_time + if decode_time < prefill_time + else decode_time / prefill_time + ) + + console.print( + f" qlen={qlen:3d}: decode={decode_time:.6f}s, " + f"prefill={prefill_time:.6f}s -> " + f"[bold]{faster}[/] ({speedup:.2f}x)" + ) + + if faster == "decode": + last_decode_faster = qlen + + if last_decode_faster is not None: + optimal_threshold = last_decode_faster + console.print( + f"\n [bold green]Optimal threshold for batch_size={bs}: " + f"{optimal_threshold}[/]" + ) + console.print( + f" [dim](Use decode pipeline for query_length <= " + f"{optimal_threshold})[/]" + ) + else: + console.print( + f"\n [yellow]Prefill always faster for batch_size={bs}[/]" + ) + + # Handle model parameter sweep mode + elif hasattr(args, "model_parameter_sweep") and args.model_parameter_sweep: + # Model parameter sweep + base_config_args = { + "num_layers": args.num_layers, + "head_dim": args.head_dim, + "num_q_heads": args.num_q_heads, + "num_kv_heads": args.num_kv_heads, + "block_size": args.block_size, + "device": args.device, + "repeats": args.repeats, + 
"warmup_iters": args.warmup_iters, + "profile_memory": args.profile_memory, + } + all_results = run_model_parameter_sweep( + backends, + args.batch_specs, + base_config_args, + args.model_parameter_sweep, + console, + ) + + # Handle parameter sweep mode (unified) + elif hasattr(args, "parameter_sweep") and args.parameter_sweep: + # Unified parameter sweep + base_config_args = { + "num_layers": args.num_layers, + "head_dim": args.head_dim, + "num_q_heads": args.num_q_heads, + "num_kv_heads": args.num_kv_heads, + "block_size": args.block_size, + "device": args.device, + "repeats": args.repeats, + "warmup_iters": args.warmup_iters, + "profile_memory": args.profile_memory, + } + all_results = run_parameter_sweep( + backends, args.batch_specs, base_config_args, args.parameter_sweep, console + ) + + else: + # Normal mode: compare backends + total = len(backends) * len(args.batch_specs) + + with tqdm(total=total, desc="Benchmarking") as pbar: + for spec in args.batch_specs: + for backend in backends: + config = BenchmarkConfig( + backend=backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) + + result = run_benchmark(config) + all_results.append(result) + + if not result.success: + console.print(f"[red]Error {backend} {spec}: {result.error}[/]") + + pbar.update(1) + + # Display results + console.print("\n[bold green]Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table(all_results, backends) + + # Save results + if all_results: + formatter = ResultsFormatter(console) + if args.output_csv: + formatter.save_csv(all_results, args.output_csv) + if args.output_json: + formatter.save_json(all_results, args.output_json) + + +if __name__ == "__main__": + main() diff --git 
class MockHfConfig:
    """Stand-in for a HuggingFace config object, populated from an MLA dims dict.

    The constructor copies the head counts and MLA dimensions out of
    ``mla_dims`` and derives ``hidden_size`` and ``qk_head_dim`` from them.
    """

    def __init__(self, mla_dims: dict):
        # Read every dimension once, up front; a missing key raises KeyError.
        q_heads = mla_dims["num_q_heads"]
        kv_heads = mla_dims["num_kv_heads"]
        head_dim = mla_dims["head_dim"]
        lora_rank = mla_dims["kv_lora_rank"]
        nope_dim = mla_dims["qk_nope_head_dim"]
        rope_dim = mla_dims["qk_rope_head_dim"]
        v_dim = mla_dims["v_head_dim"]

        self.num_attention_heads = q_heads
        self.num_key_value_heads = kv_heads
        self.hidden_size = head_dim * q_heads
        self.model_type = "deepseek_v2"
        self.is_encoder_decoder = False
        self.kv_lora_rank = lora_rank
        self.qk_nope_head_dim = nope_dim
        self.qk_rope_head_dim = rope_dim
        self.v_head_dim = v_dim
        # Full QK head dim is the sum of the non-rotary and rotary parts.
        self.qk_head_dim = nope_dim + rope_dim

    def get_text_config(self):
        """Return this object itself as the text config."""
        return self
class MockModelConfig:
    """Minimal model-config stand-in exposing head counts and per-layer queries.

    Stores the head counts, head size, dtype and maximum model length it is
    given and answers the accessor calls below from those stored values.
    """

    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        max_model_len: int = 32768,
    ):
        self.dtype = dtype
        self.max_model_len = max_model_len
        # Backing fields for the accessor methods below.
        self._n_q = num_q_heads
        self._n_kv = num_kv_heads
        self._d = head_dim

    def get_num_attention_heads(self, _=None) -> int:
        """Return the query head count; the optional argument is ignored."""
        return self._n_q

    def get_num_kv_heads(self, _=None) -> int:
        """Return the KV head count; the optional argument is ignored."""
        return self._n_kv

    def get_head_size(self) -> int:
        """Return the per-head dimension."""
        return self._d

    def get_num_layers(self) -> int:
        """Return the layer count, which this mock fixes at 1."""
        return 1

    def get_sliding_window_for_layer(self, _layer_idx: int):
        """Return None for every layer (no sliding window is reported)."""
        return None

    def get_logits_soft_cap_for_layer(self, _layer_idx: int):
        """Return None for every layer (no logits soft cap is reported)."""
        return None

    def get_sm_scale_for_layer(self, _layer_idx: int) -> float:
        """Return 1 / sqrt(head_size), independent of the layer index."""
        head_size = self.get_head_size()
        return 1.0 / (head_size ** 0.5)
@dataclass
class ParameterSweep:
    """Describes a sweep over one backend parameter and how to label each run."""

    # Name of the backend parameter being swept.
    param_name: str
    # Values to test for that parameter.
    values: list[Any]
    # When True, also test with the parameter left unset (auto mode).
    include_auto: bool = False
    # Template for result labels; may use {backend}, {param_name}, {value}.
    label_format: str = "{backend}_{param_name}_{value}"

    def get_label(self, backend: str, value: Any) -> str:
        """Render ``label_format`` for one backend and one swept value."""
        fields = {
            "backend": backend,
            "param_name": self.param_name,
            "value": value,
        }
        return self.label_format.format(**fields)
seconds + throughput_tokens_per_sec: float | None = None + memory_allocated_mb: float | None = None + memory_reserved_mb: float | None = None + error: str | None = None + + @property + def success(self) -> bool: + """Whether benchmark completed successfully.""" + return self.error is None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "config": asdict(self.config), + "mean_time": self.mean_time, + "std_time": self.std_time, + "min_time": self.min_time, + "max_time": self.max_time, + "throughput_tokens_per_sec": self.throughput_tokens_per_sec, + "memory_allocated_mb": self.memory_allocated_mb, + "memory_reserved_mb": self.memory_reserved_mb, + "error": self.error, + } + + +class ResultsFormatter: + """Format and display benchmark results.""" + + def __init__(self, console: Console | None = None): + self.console = console or Console() + + def print_table( + self, + results: list[BenchmarkResult], + backends: list[str], + compare_to_fastest: bool = True, + ): + """ + Print results as a rich table. 
+ + Args: + results: List of BenchmarkResult + backends: List of backend names being compared + compare_to_fastest: Show percentage comparison to fastest + """ + # Group by batch spec + by_spec = {} + for r in results: + spec = r.config.batch_spec + if spec not in by_spec: + by_spec[spec] = {} + by_spec[spec][r.config.backend] = r + + # Create shortened backend names for display + def shorten_backend_name(name: str) -> str: + """Shorten long backend names for table display.""" + # Remove common prefixes + name = name.replace("flashattn_mla", "famla") + name = name.replace("flashinfer_mla", "fimla") + name = name.replace("flashmla", "fmla") + name = name.replace("cutlass_mla", "cmla") + name = name.replace("numsplits", "ns") + return name + + table = Table(title="Attention Benchmark Results") + table.add_column("Batch\nSpec", no_wrap=True) + + multi = len(backends) > 1 + for backend in backends: + short_name = shorten_backend_name(backend) + # Time column + col_time = f"{short_name}\nTime (s)" + table.add_column(col_time, justify="right", no_wrap=False) + if multi and compare_to_fastest: + # Relative performance column + col_rel = f"{short_name}\nvs Best" + table.add_column(col_rel, justify="right", no_wrap=False) + + # Add rows + for spec in sorted(by_spec.keys()): + spec_results = by_spec[spec] + times = {b: r.mean_time for b, r in spec_results.items() if r.success} + best_time = min(times.values()) if times else 0.0 + + row = [spec] + for backend in backends: + if backend in spec_results: + r = spec_results[backend] + if r.success: + row.append(f"{r.mean_time:.6f}") + if multi and compare_to_fastest: + pct = ( + (r.mean_time / best_time * 100) if best_time > 0 else 0 + ) + pct_str = f"{pct:.1f}%" + if r.mean_time == best_time: + pct_str = f"[bold green]{pct_str}[/]" + row.append(pct_str) + else: + row.append("[red]ERROR[/]") + if multi and compare_to_fastest: + row.append("-") + else: + row.append("-") + if multi and compare_to_fastest: + row.append("-") + + 
table.add_row(*row) + + self.console.print(table) + + def save_csv(self, results: list[BenchmarkResult], path: str): + """Save results to CSV file.""" + if not results: + return + + path_obj = Path(path) + path_obj.parent.mkdir(parents=True, exist_ok=True) + + with open(path, "w", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "backend", + "batch_spec", + "num_layers", + "mean_time", + "std_time", + "throughput", + "memory_mb", + ], + ) + writer.writeheader() + for r in results: + writer.writerow( + { + "backend": r.config.backend, + "batch_spec": r.config.batch_spec, + "num_layers": r.config.num_layers, + "mean_time": r.mean_time, + "std_time": r.std_time, + "throughput": r.throughput_tokens_per_sec or 0, + "memory_mb": r.memory_allocated_mb or 0, + } + ) + + self.console.print(f"[green]Saved CSV results to {path}[/]") + + def save_json(self, results: list[BenchmarkResult], path: str): + """Save results to JSON file.""" + path_obj = Path(path) + path_obj.parent.mkdir(parents=True, exist_ok=True) + + data = [r.to_dict() for r in results] + with open(path, "w") as f: + json.dump(data, f, indent=2, default=str) + + self.console.print(f"[green]Saved JSON results to {path}[/]") + + +def setup_mla_dims(model_name: str = "deepseek-v3") -> dict: + """ + Get MLA dimensions for known models. 
def is_mla_backend(backend: str) -> bool:
    """
    Check if backend is an MLA backend using the backend's is_mla() property.

    The backend name is upper-cased and looked up in vLLM's
    ``AttentionBackendEnum``; the resolved backend class is then asked via
    ``is_mla()``.

    Args:
        backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")

    Returns:
        True if the backend is an MLA backend, False otherwise. False is also
        returned when the name is not in the registry or when the registry or
        the backend class cannot be imported.
    """
    try:
        # Import inside the try so a missing/broken registry module is handled
        # by the ImportError clause below instead of propagating to the caller.
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        backend_class = AttentionBackendEnum[backend.upper()].get_class()
        return backend_class.is_mla()
    except (KeyError, ValueError, ImportError):
        return False
+cutlass_mla: + num_kv_splits: auto # or specific value like 4, 8, 16 + +flashattn_mla: + reorder_batch_threshold: 512 + +flashmla: + reorder_batch_threshold: 1 diff --git a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml new file mode 100644 index 000000000000..ad3c0dced6ec --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml @@ -0,0 +1,60 @@ +# MLA mixed batch benchmark (prefill + decode) +# Tests chunked prefill performance + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +batch_specs: + # Small prefill + decode + - "1q1k_8q1s1k" # 1 prefill + 8 decode + - "2q2k_16q1s1k" # 2 prefill + 16 decode + - "4q1k_32q1s2k" # 4 prefill + 32 decode + + # Medium prefill + decode + - "2q4k_32q1s2k" # 2 medium prefill + 32 decode + - "4q4k_64q1s2k" # 4 medium prefill + 64 decode + - "8q2k_64q1s4k" # 8 prefill + 64 decode + + # Large prefill + decode (chunked prefill stress test) + - "2q8k_32q1s1k" # 2 large prefill + 32 decode + - "1q16k_16q1s2k" # 1 very large prefill + 16 decode + - "2q16k_32q1s4k" # 2 very large prefill + 32 decode + + # Context extension + decode + - "2q1kkv2k_16q1s1k" # 2 extend + 16 decode + - "4q2kkv4k_32q1s2k" # 4 extend + 32 decode + - "2q1kkv8k_32q1s2k" # 2 large extend + 32 decode + + # Explicitly chunked prefill + - "q8k" # 8k prefill with chunking hint + - "q16k" # 16k prefill with chunking hint + - "2q8k_32q1s2k" # 2 chunked prefill + 32 decode + + # High decode ratio (realistic serving) + - "1q2k_63q1s1k" # 1 prefill + 63 decode + - "2q2k_62q1s2k" # 2 prefill + 62 decode + - "4q4k_60q1s4k" # 4 prefill + 60 decode + +backends: + - cutlass_mla + - flashinfer_mla + - flashattn_mla # Hopper only + - flashmla # Hopper only + +device: "cuda:0" +repeats: 5 +warmup_iters: 3 +profile_memory: true 
+ +# Analyze chunked prefill workspace size impact +chunked_prefill: + test_workspace_sizes: [4096, 8192, 16384, 32768, 65536] diff --git a/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml new file mode 100644 index 000000000000..1ea0a12b5338 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml @@ -0,0 +1,88 @@ +# Study 4: What is optimal reorder_batch_threshold for MLA backends supporting query length > 1? +# Question: At what query length does prefill pipeline become faster than decode pipeline? +# Methodology: For each query length, compare decode vs prefill performance to find crossover point +# Applies to: FlashAttn MLA, FlashMLA + +description: "Decode vs Prefill pipeline crossover analysis" + +# Test FlashAttn MLA +backend: flashattn_mla + +# Mode: decode_vs_prefill comparison (special sweep mode) +# For each batch spec, we'll test both decode and prefill pipelines +mode: "decode_vs_prefill" + +# Query lengths to test (from old benchmark_mla_threshold.py methodology) +# Each query length will be tested with BOTH decode and prefill pipelines: +# - decode: threshold >= query_length (forces decode pipeline) +# - prefill: threshold < query_length (forces prefill pipeline) +# +# We use qs1k format which creates q_len=N, seq_len=1024 requests +# This tests different query lengths with fixed sequence length context +# +# Using batch_spec_ranges for automatic generation: +batch_spec_ranges: + - template: "q{q_len}s1k" + q_len: + start: 1 + stop: 16 + step: 1 + end_inclusive: false + - template: "q{q_len}s1k" + q_len: + start: 16 + stop: 64 + step: 2 + end_inclusive: false + - template: "q{q_len}s1k" + q_len: + start: 64 + stop: 1024 + step: 4 + end_inclusive: true + +# Batch sizes to test (from old script) +batch_sizes: + - 1 + - 2 + - 4 + - 8 + - 16 + - 32 + - 64 + - 128 + - 256 + +# Model configuration (DeepSeek V2/V3 defaults) +model: + num_layers: 10 + 
head_dim: 576 + num_q_heads: 128 + num_kv_heads: 1 + block_size: 128 + +# Benchmark settings +benchmark: + device: "cuda:0" + repeats: 15 # More repeats for spec decode variance + warmup_iters: 5 + profile_memory: false + +# Output +output: + csv: "reorder_threshold_results.csv" + json: "reorder_threshold_results.json" + +# Expected outcome (reproduces old benchmark_mla_threshold.py study): +# - For each batch size, find the crossover point where prefill becomes faster than decode +# - Show decode vs prefill performance across all query lengths +# - Determine optimal reorder_batch_threshold based on last query length where decode is faster +# - Understand how crossover point varies with batch size +# - Provide data-driven guidance for default threshold value +# +# Methodology (from old script): +# - Each query length tested with BOTH pipelines: +# * decode: threshold >= query_length (forces decode pipeline) +# * prefill: threshold < query_length (forces prefill pipeline) +# - Compare which is faster to find crossover point +# diff --git a/benchmarks/attention_benchmarks/configs/speculative_decode.yaml b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml new file mode 100644 index 000000000000..56d2428fe74f --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml @@ -0,0 +1,62 @@ +# Speculative decoding benchmark configuration +# Tests reorder_batch_threshold optimization + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + +batch_specs: + # Pure speculative decode (K-token verification) + - "q2s1k" # 2-token spec, 1k KV + - "q4s1k" # 4-token spec, 1k KV + - "q8s1k" # 8-token spec, 1k KV + - "q16s1k" # 16-token spec, 1k KV + + # Speculative with different context lengths + - "q4s2k" # 4-token spec, 2k KV + - "q4s4k" # 4-token spec, 4k KV + - "q8s2k" # 8-token spec, 2k KV + - "q8s4k" # 8-token 
spec, 4k KV + + # Mixed: speculative + regular decode + - "32q4s1k" # 32 spec requests + - "16q4s1k_16q1s1k" # 16 spec + 16 regular + - "8q8s2k_24q1s2k" # 8 spec (8-tok) + 24 regular + + # Mixed: speculative + prefill + decode + - "2q1k_16q4s1k_16q1s1k" # 2 prefill + 16 spec + 16 decode + - "4q2k_32q4s2k_32q1s2k" # 4 prefill + 32 spec + 32 decode + + # Large batches with speculation + - "64q4s1k" # 64 spec requests + - "32q8s2k" # 32 spec (8-token) + - "16q16s4k" # 16 spec (16-token) + +# Backends that support query length > 1 +backends: + - flashattn_mla # reorder_batch_threshold = 512 + - flashmla # reorder_batch_threshold = 1 (tunable) + +# FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism +# - flashinfer_mla + +# Benchmark settings +benchmark: + device: "cuda:0" + repeats: 10 # More repeats for statistical significance + warmup_iters: 5 + profile_memory: false + +# Test these threshold values for optimization +parameter_sweep: + param_name: "reorder_batch_threshold" + values: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + include_auto: false + label_format: "{backend}_threshold_{value}" diff --git a/benchmarks/attention_benchmarks/configs/standard_attention.yaml b/benchmarks/attention_benchmarks/configs/standard_attention.yaml new file mode 100644 index 000000000000..c0bdb98fbf62 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/standard_attention.yaml @@ -0,0 +1,40 @@ +# Standard attention backend benchmark configuration + +model: + num_layers: 32 + num_q_heads: 32 + num_kv_heads: 8 # GQA with 4:1 ratio + head_dim: 128 + block_size: 16 + +batch_specs: + # Pure prefill + - "q512" # Small prefill (512 tokens) + - "q2k" # Medium prefill (2048 tokens) + - "q4k" # Large prefill (4096 tokens) + - "q8k" # Very large prefill (8192 tokens) + + # Pure decode + - "8q1s1k" # 8 requests, 1k KV cache each + - "16q1s2k" # 16 requests, 2k KV cache each + - "32q1s1k" # 32 requests, 1k KV cache each + - "64q1s4k" # 64 requests, 4k KV cache 
each + + # Mixed prefill/decode + - "2q2k_8q1s1k" # 2 prefill + 8 decode + - "4q1k_16q1s2k" # 4 prefill + 16 decode + - "2q4k_32q1s1k" # 2 large prefill + 32 decode + + # Context extension + - "q1ks2k" # 1k query, 2k sequence (chunked prefill) + - "2q1ks4k" # 2 requests: 1k query, 4k sequence + +backends: + - flash + - triton + - flashinfer + +device: "cuda:0" +repeats: 5 +warmup_iters: 3 +profile_memory: false diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py new file mode 100644 index 000000000000..2c6c3aaac360 --- /dev/null +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -0,0 +1,836 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +MLA benchmark runner - shared utilities for MLA benchmarks. + +This module provides helpers for running MLA backends without +needing full VllmConfig integration. +""" + +import importlib + +import numpy as np +import torch +from batch_spec import parse_batch_spec +from common import ( + BenchmarkResult, + MockHfConfig, + MockKVBProj, + MockLayer, + setup_mla_dims, +) + +from vllm.config import ( + CacheConfig, + CompilationConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) + +# ============================================================================ +# VllmConfig Creation +# ============================================================================ + + +def _add_mock_methods_to_model_config(model_config: ModelConfig) -> None: + """ + Add mock methods for layer-specific queries to ModelConfig. + + These methods are needed by metadata builders but aren't normally + present on ModelConfig when used in benchmark contexts. 
def create_minimal_vllm_config(
    model_name: str = "deepseek-v3",
    block_size: int = 128,
    max_num_seqs: int = 256,
    mla_dims: dict | None = None,
) -> VllmConfig:
    """
    Create minimal VllmConfig for MLA benchmarks.

    ModelConfig is constructed from a throwaway local directory that holds a
    tiny ``config.json``, so construction points at a local path instead of a
    HuggingFace model id. The HF config objects on the resulting ModelConfig
    are then replaced by a mock built from ``mla_dims``.

    Args:
        model_name: Model name (deepseek-v2, deepseek-v3, etc.) - used only when
                    mla_dims is not provided
        block_size: KV cache block size
        max_num_seqs: Maximum number of sequences
        mla_dims: Optional custom MLA dimensions dict. If not provided, uses
                  setup_mla_dims(model_name)

    Returns:
        VllmConfig for benchmarking
    """
    import json
    import os
    import tempfile

    # Get MLA dimensions - use provided or load from model name
    if mla_dims is None:
        mla_dims = setup_mla_dims(model_name)

    # Create mock HF config first (avoids downloading from HuggingFace)
    mock_hf_config = MockHfConfig(mla_dims)

    # Shared by ModelConfig and SchedulerConfig below so they cannot drift.
    max_model_len = 32768

    minimal_config = {
        "architectures": ["DeepseekV2ForCausalLM"],
        "model_type": "deepseek_v2",
        "num_attention_heads": mla_dims["num_q_heads"],
        "num_key_value_heads": mla_dims["num_kv_heads"],
        "hidden_size": mla_dims["head_dim"] * mla_dims["num_q_heads"],
        "torch_dtype": "bfloat16",
        "max_position_embeddings": 163840,  # DeepSeek V3 default
        "rope_theta": 10000.0,
        "vocab_size": 128256,
    }

    # The context manager owns the directory for its whole lifetime, so it is
    # removed even if writing config.json fails (the previous mkdtemp + late
    # try/finally leaked the directory in that case). Cleanup stays
    # best-effort, matching the former rmtree(ignore_errors=True).
    with tempfile.TemporaryDirectory(
        prefix="vllm_bench_", ignore_cleanup_errors=True
    ) as temp_dir:
        config_path = os.path.join(temp_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(minimal_config, f)

        # Create model config using local path - no HF downloads
        model_config = ModelConfig(
            model=temp_dir,  # Use local temp directory
            tokenizer=None,
            tokenizer_mode="auto",
            trust_remote_code=True,
            dtype="bfloat16",
            seed=0,
            max_model_len=max_model_len,
            quantization=None,
            quantization_param_path=None,
            enforce_eager=False,
            max_context_len_to_capture=None,
            max_seq_len_to_capture=8192,
            max_logprobs=20,
            disable_sliding_window=False,
            skip_tokenizer_init=True,
            served_model_name=None,
            limit_mm_per_prompt=None,
            use_async_output_proc=True,
            config_format="auto",
        )

    # Override with our mock config
    model_config.hf_config = mock_hf_config
    model_config.hf_text_config = mock_hf_config

    # Add mock methods for layer-specific queries
    _add_mock_methods_to_model_config(model_config)

    # Create sub-configs
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=False,
    )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=8192,
        max_model_len=max_model_len,
        is_encoder_decoder=False,
        enable_chunked_prefill=True,
    )

    parallel_config = ParallelConfig(
        tensor_parallel_size=1,
    )

    compilation_config = CompilationConfig()

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        compilation_config=compilation_config,
    )
def _build_attention_metadata(
    requests: list,
    block_size: int,
    device: torch.device,
    builder_instance,
) -> tuple:
    """
    Build attention metadata from batch requests.

    Every request gets its own run of consecutive KV-cache blocks; the slot
    mapping places each new (query) token at its position inside the
    request's sequence, i.e. after the ``kv_len - q_len`` context tokens.

    Args:
        requests: List of BatchRequest objects (only ``q_len`` and ``kv_len``
            are read)
        block_size: KV cache block size
        device: Target device
        builder_instance: Metadata builder instance; its ``build()`` method
            produces the returned metadata

    Returns:
        Tuple of (metadata, kv_cache_num_blocks), where kv_cache_num_blocks is
        the total number of blocks referenced by the block table

    Raises:
        ValueError: If ``requests`` is empty
    """
    from itertools import accumulate

    if not requests:
        raise ValueError("Cannot build attention metadata for an empty batch")

    q_lens = [r.q_len for r in requests]
    kv_lens = [r.kv_len for r in requests]
    total_q = sum(q_lens)
    max_kv = max(kv_lens)

    # Query start locations: running sum of the query lengths with a leading 0
    # (single pass instead of re-summing a prefix for every request).
    q_start_cpu = torch.tensor([0, *accumulate(q_lens)], dtype=torch.int32)
    q_start_gpu = q_start_cpu.to(device)

    # Build sequence lengths
    seq_lens_cpu = torch.tensor(kv_lens, dtype=torch.int32)
    seq_lens_gpu = seq_lens_cpu.to(device)

    # Build num_computed_tokens (context length for each request)
    context_lens = [kv_len - q_len for q_len, kv_len in zip(q_lens, kv_lens)]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Build block table: request i owns the next num_blocks_per_req[i] block ids
    num_blocks_per_req = [(kv + block_size - 1) // block_size for kv in kv_lens]
    max_num_blocks = max(num_blocks_per_req)

    block_table_cpu = np.zeros((len(requests), max_num_blocks), dtype=np.int32)
    next_block = 0
    for row, num_blocks in enumerate(num_blocks_per_req):
        block_table_cpu[row, :num_blocks] = np.arange(
            next_block, next_block + num_blocks, dtype=np.int32
        )
        next_block += num_blocks
    total_blocks = next_block

    block_table_gpu = torch.from_numpy(block_table_cpu).to(device)

    # Build slot mapping: slot = block_id * block_size + offset_in_block for
    # each new token. Computed in int64 to match the tensor dtype below.
    slot_chunks = []
    for row, (context_len, kv_len) in enumerate(zip(context_lens, kv_lens)):
        token_positions = np.arange(context_len, kv_len, dtype=np.int64)
        block_ids = block_table_cpu[row, token_positions // block_size].astype(
            np.int64
        )
        slot_chunks.append(block_ids * block_size + token_positions % block_size)

    slot_mapping = torch.from_numpy(np.concatenate(slot_chunks)).to(device)

    # Create CommonAttentionMetadata
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    common_attn_metadata = CommonAttentionMetadata(
        num_reqs=len(requests),
        max_query_len=max(q_lens),
        max_seq_len=max_kv,
        num_actual_tokens=total_q,
        query_start_loc=q_start_gpu,
        query_start_loc_cpu=q_start_cpu,
        seq_lens=seq_lens_gpu,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        slot_mapping=slot_mapping,
        block_table_tensor=block_table_gpu,
        dcp_local_seq_lens=None,
    )

    # Use the production build() method
    metadata = builder_instance.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
        fast_build=False,
    )

    return metadata, total_blocks
def _create_backend_impl(
    backend_cfg: dict,
    mla_dims: dict,
    vllm_config: VllmConfig,
    device: torch.device,
):
    """
    Create backend implementation instance.

    Args:
        backend_cfg: Backend configuration dict (module and class names, plus
            an optional fixed ``block_size``)
        mla_dims: MLA dimension configuration
        vllm_config: VllmConfig instance; its
            ``compilation_config.static_forward_context`` is overwritten so
            the metadata builder can look up the mock layer
        device: Target device

    Returns:
        Tuple of (impl, layer, builder_instance)
    """
    # Import backend classes
    backend_module = importlib.import_module(backend_cfg["module"])
    impl_class = getattr(backend_module, backend_cfg["impl_class"])

    qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]

    # Softmax scale 1/sqrt(qk_head_dim); converted to a plain Python float so
    # the impl does not receive a NumPy scalar.
    scale = float(1.0 / np.sqrt(qk_head_dim))

    # Create mock kv_b_proj layer for prefill mode
    mock_kv_b_proj = MockKVBProj(
        num_heads=mla_dims["num_q_heads"],
        qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
        v_head_dim=mla_dims["v_head_dim"],
    )

    # Create impl
    impl = impl_class(
        num_heads=mla_dims["num_q_heads"],
        head_size=mla_dims["head_dim"],
        scale=scale,
        num_kv_heads=mla_dims["num_kv_heads"],
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
        logits_soft_cap=None,
        attn_type="decoder",
        kv_sharing_target_layer_name=None,
        q_lora_rank=None,
        kv_lora_rank=mla_dims["kv_lora_rank"],
        qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
        qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
        qk_head_dim=qk_head_dim,
        v_head_dim=mla_dims["v_head_dim"],
        kv_b_proj=mock_kv_b_proj,
    )

    # Initialize DCP attributes when the impl left them unset
    if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
        impl.dcp_world_size = 1
        impl.dcp_rank = 0

    # Create KV cache spec for MockLayer
    from vllm.v1.kv_cache_interface import FullAttentionSpec

    # The cache entry width is derived from mla_dims instead of the previous
    # hard-coded 576 (= 512 + 64, the DeepSeek defaults), so custom MLA
    # dimensions stay consistent with the KV cache tensor that
    # _run_single_benchmark allocates with the same sum.
    kv_cache_spec = FullAttentionSpec(
        block_size=backend_cfg["block_size"] or vllm_config.cache_config.block_size,
        num_kv_heads=1,  # MLA uses 1 KV head
        head_size=mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"],
        dtype=torch.bfloat16,
    )

    # Create mock layer
    layer = MockLayer(device, impl=impl, kv_cache_spec=kv_cache_spec)

    # Create builder instance if needed
    builder_instance = None
    if backend_cfg["builder_class"]:
        builder_class = getattr(backend_module, backend_cfg["builder_class"])

        # Populate static_forward_context so builder can find the layer
        # MockLayer inherits from AttentionLayerBase, so isinstance checks pass
        vllm_config.compilation_config.static_forward_context = {"placeholder": layer}

        builder_instance = builder_class(
            kv_cache_spec=kv_cache_spec,
            layer_names=["placeholder"],
            vllm_config=vllm_config,
            device=device,
        )

    return impl, layer, builder_instance
def _run_single_benchmark(
    config,
    impl,
    layer,
    builder_instance,
    backend_cfg: dict,
    mla_dims: dict,
    device: torch.device,
) -> BenchmarkResult:
    """
    Time one batch spec on an already-initialised MLA backend.

    The batch spec is parsed, metadata and tensors are built once, and the
    same forward call is then replayed ``config.num_layers`` times per timed
    repeat. The decode path is used whenever the metadata carries decode
    information; otherwise the prefill path is used.

    Args:
        config: BenchmarkConfig instance
        impl: Backend implementation instance
        layer: MockLayer instance
        builder_instance: Metadata builder instance
        backend_cfg: Backend configuration dict
        mla_dims: MLA dimension configuration
        device: Target device

    Returns:
        BenchmarkResult with per-layer timing statistics (seconds)

    Raises:
        RuntimeError: If the built metadata has neither decode nor prefill
            information
    """
    batch = parse_batch_spec(config.batch_spec)
    total_q = sum(req.q_len for req in batch)

    # A backend-mandated block size takes precedence over the configured one.
    block_size = backend_cfg["block_size"] or config.block_size

    metadata, num_blocks = _build_attention_metadata(
        batch, block_size, device, builder_instance
    )

    cache_width = mla_dims["kv_lora_rank"] + mla_dims["qk_rope_head_dim"]
    kv_cache = torch.zeros(
        num_blocks,
        block_size,
        cache_width,
        device=device,
        dtype=torch.bfloat16,
    )

    # Inputs for both paths are created up front; only one set is used below.
    decode_inputs, prefill_inputs = _create_input_tensors(
        total_q,
        mla_dims,
        backend_cfg["query_format"],
        device,
        torch.bfloat16,
    )

    if metadata.decode is not None:

        def forward_fn():
            return impl._forward_decode(decode_inputs, kv_cache, metadata, layer)

    elif metadata.prefill is not None:

        def forward_fn():
            return impl._forward_prefill(
                prefill_inputs["q"],
                prefill_inputs["k_c_normed"],
                prefill_inputs["k_pe"],
                kv_cache,
                metadata,
                prefill_inputs["k_scale"],
                prefill_inputs["output"],
            )

    else:
        raise RuntimeError("Metadata has neither decode nor prefill metadata")

    # Warmup (untimed)
    for _ in range(config.warmup_iters):
        forward_fn()
    torch.cuda.synchronize()

    # Timed repeats: one CUDA-event pair brackets num_layers forward calls.
    per_layer_seconds = []
    for _repeat in range(config.repeats):
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        start_evt.record()
        for _layer_idx in range(config.num_layers):
            forward_fn()
        end_evt.record()

        torch.cuda.synchronize()
        elapsed_ms = start_evt.elapsed_time(end_evt)
        per_layer_seconds.append(elapsed_ms / 1000.0 / config.num_layers)

    mean_time = float(np.mean(per_layer_seconds))
    std_time = float(np.std(per_layer_seconds))
    fastest = float(np.min(per_layer_seconds))
    slowest = float(np.max(per_layer_seconds))

    if mean_time > 0:
        throughput = total_q / mean_time
    else:
        throughput = 0

    return BenchmarkResult(
        config=config,
        mean_time=mean_time,
        std_time=std_time,
        min_time=fastest,
        max_time=slowest,
        throughput_tokens_per_sec=throughput,
    )
+ + Args: + backend: Backend name + configs_with_params: List of (config, threshold, num_splits) tuples + - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) + - num_splits: num_kv_splits (CUTLASS only) + + Returns: + List of BenchmarkResult objects + """ + if not configs_with_params: + return [] + + backend_cfg = _get_backend_config(backend) + device = torch.device(configs_with_params[0][0].device) + torch.cuda.set_device(device) + + # Determine block size + config_block_size = configs_with_params[0][0].block_size + block_size = backend_cfg["block_size"] or config_block_size + + # Extract MLA dimensions from the first config + first_config = configs_with_params[0][0] + mla_dims = _extract_mla_dims_from_config(first_config) + + # If config didn't provide MLA dims, fall back to default model + if mla_dims is None: + mla_dims = setup_mla_dims("deepseek-v3") + + # Create and set vLLM config for MLA (reused across all benchmarks) + vllm_config = create_minimal_vllm_config( + model_name="deepseek-v3", # Used only for model path + block_size=block_size, + mla_dims=mla_dims, # Use custom dims from config or default + ) + + results = [] + + with set_current_vllm_config(vllm_config): + # Create backend impl, layer, and builder (reused across benchmarks) + impl, layer, builder_instance = _create_backend_impl( + backend_cfg, mla_dims, vllm_config, device + ) + + # Run each benchmark with the shared impl + for config, threshold, num_splits in configs_with_params: + # Set threshold for this benchmark (FlashAttn/FlashMLA only) + original_threshold = None + if threshold is not None and builder_instance: + original_threshold = builder_instance.reorder_batch_threshold + builder_instance.reorder_batch_threshold = threshold + + # Set num_splits for CUTLASS + original_num_splits = None + if num_splits is not None and hasattr(impl, "_num_kv_splits"): + original_num_splits = impl._num_kv_splits + impl._num_kv_splits = num_splits + + try: + result = _run_single_benchmark( + 
def run_mla_benchmark(
    backend: str,
    config,
    reorder_batch_threshold: int | None = None,
    num_kv_splits: int | None = None,
) -> BenchmarkResult | list[BenchmarkResult]:
    """
    Public entry point for MLA benchmarks on any supported backend.

    Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla

    Whatever the input shape, the work is normalised to a list of
    ``(config, threshold, num_splits)`` triples and handed to the batched
    runner.

    Args:
        backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
        config: One of
            - a single BenchmarkConfig,
            - a list of BenchmarkConfig,
            - a list of (BenchmarkConfig, param) tuples, where param is read as
              the reorder threshold for flashattn_mla / flashmla and as the
              number of KV splits for every other backend.
        reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
                                 (single config mode only)
        num_kv_splits: Number of KV splits for CUTLASS (single config mode only)

    Returns:
        BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
    """
    # Single config: wrap it, run it, unwrap the only result.
    if not isinstance(config, list):
        single_job = [(config, reorder_batch_threshold, num_kv_splits)]
        return _run_mla_benchmark_batched(backend, single_job)[0]

    # List input: the keyword overrides are not consulted in this mode.
    if config and isinstance(config[0], tuple):
        param_is_threshold = backend in ("flashattn_mla", "flashmla")
        jobs = []
        for cfg, param in config:
            if param_is_threshold:
                jobs.append((cfg, param, None))
            else:  # cutlass_mla or flashinfer_mla
                jobs.append((cfg, None, param))
    else:
        jobs = [(cfg, None, None) for cfg in config]

    return _run_mla_benchmark_batched(backend, jobs)
+""" + +import types + +import numpy as np +import torch +from batch_spec import parse_batch_spec, reorder_for_flashinfer +from common import BenchmarkConfig, BenchmarkResult, MockLayer, get_attention_scale + +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +# ============================================================================ +# Backend Configuration +# ============================================================================ + + +_BACKEND_CONFIG = { + "flash": { + "module": "vllm.v1.attention.backends.flash_attn", + "backend_class": "FlashAttentionBackend", + "dtype": torch.float16, + "cache_layout": "standard", + # ^ [2, num_blocks, block_size, num_kv_heads, head_dim] + }, + "triton": { + "module": "vllm.v1.attention.backends.triton_attn", + "backend_class": "TritonAttentionBackend", + "dtype": torch.float32, + "cache_layout": "standard", + }, + "flashinfer": { + "module": "vllm.v1.attention.backends.flashinfer", + "backend_class": "FlashInferBackend", + "dtype": torch.float16, + "cache_layout": "flashinfer", + # ^ [num_blocks, 2, block_size, num_kv_heads, head_dim] + }, +} + + +def _get_backend_config(backend: str) -> dict: + if backend not in _BACKEND_CONFIG: + raise ValueError( + f"Unknown backend: {backend}. 
def _build_common_attn_metadata(
    q_lens: list[int],
    kv_lens: list[int],
    block_size: int,
    device: torch.device,
) -> CommonAttentionMetadata:
    """
    Assemble CommonAttentionMetadata for a batch described by its lengths.

    Args:
        q_lens: Number of new (query) tokens per request
        kv_lens: Total sequence length per request
        block_size: KV cache block size
        device: Device for the GPU-side tensors

    Returns:
        CommonAttentionMetadata with ``causal=True``. Every request is given
        the same number of block-table columns (enough for the longest
        sequence) and block ids are numbered consecutively row by row, so the
        table references ``len(q_lens) * max_blocks`` distinct blocks.
    """
    num_reqs = len(q_lens)
    num_new_tokens = sum(q_lens)

    # Query start offsets: [0, q0, q0+q1, ...] held as int32 on the device.
    q_lens_dev = torch.tensor(q_lens, dtype=torch.int32, device=device)
    query_start_loc = torch.cat(
        (
            torch.zeros(1, dtype=torch.int32, device=device),
            torch.cumsum(q_lens_dev, dim=0, dtype=torch.int32),
        )
    )
    query_start_loc_cpu = query_start_loc.cpu()

    seq_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

    # Tokens already in the cache before this step, per request.
    num_computed_tokens_cpu = torch.tensor(
        [kv - q for kv, q in zip(kv_lens, q_lens)], dtype=torch.int32
    )

    # Dense block table: row i covers block ids [i * cols, (i + 1) * cols).
    blocks_per_row = (max(kv_lens) + block_size - 1) // block_size
    block_ids = torch.arange(
        num_reqs * blocks_per_row, dtype=torch.int32, device=device
    )
    block_table_tensor = block_ids.view(num_reqs, blocks_per_row)

    # Slot ids are simply 0..num_new_tokens-1; they are not derived from the
    # block table above.
    slot_mapping = torch.arange(num_new_tokens, dtype=torch.int64, device=device)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=num_new_tokens,
        max_query_len=max(q_lens),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        causal=True,
    )
+ model_config = ModelConfig( + model="meta-llama/Meta-Llama-3-8B", + tokenizer="meta-llama/Meta-Llama-3-8B", + trust_remote_code=False, + dtype=dtype, + seed=0, + max_model_len=1024, + ) + + cache_config = CacheConfig( + block_size=config.block_size, + cache_dtype="auto", + swap_space=0, + ) + cache_config.num_gpu_blocks = max_num_blocks + cache_config.num_cpu_blocks = 0 + + parallel_config = ParallelConfig(tensor_parallel_size=1) + scheduler_config = SchedulerConfig( + max_num_seqs=256, + max_num_batched_tokens=8192, + max_model_len=8192, + is_encoder_decoder=False, + enable_chunked_prefill=True, + ) + device_config = DeviceConfig() + load_config = LoadConfig() + compilation_config = CompilationConfig() + + # Add mock methods for benchmark config values + model_config.get_num_layers = types.MethodType( + lambda self: config.num_layers, model_config + ) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, i: None, model_config + ) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, i: 0.0, model_config + ) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, i: 1.0 / config.head_dim**0.5, model_config + ) + model_config.get_num_attention_heads = types.MethodType( + lambda self, parallel_config=None: config.num_q_heads, model_config + ) + model_config.get_num_kv_heads = types.MethodType( + lambda self, parallel_config=None: config.num_kv_heads, model_config + ) + model_config.get_head_size = types.MethodType( + lambda self: config.head_dim, model_config + ) + model_config.get_sliding_window = types.MethodType(lambda self: None, model_config) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + compilation_config=compilation_config, + ) + + +# ============================================================================ +# Backend 
def _create_backend_impl(
    backend_cfg: dict,
    config: BenchmarkConfig,
    device: torch.device,
):
    """
    Instantiate the attention implementation and a mock layer for a backend.

    Args:
        backend_cfg: Entry from the backend table; ``module``,
            ``backend_class`` and ``dtype`` are read here
        config: Benchmark configuration (head counts, head dim, block size)
        device: Device handed to the mock layer

    Returns:
        Tuple of (backend_class, impl, layer, dtype)
    """
    from importlib import import_module

    # Resolve the backend class by name from its module.
    backend_class = getattr(
        import_module(backend_cfg["module"]), backend_cfg["backend_class"]
    )

    dtype = backend_cfg["dtype"]
    scale = get_attention_scale(config.head_dim)

    impl_cls = backend_class.get_impl_cls()
    impl = impl_cls(
        num_heads=config.num_q_heads,
        head_size=config.head_dim,
        scale=scale,
        num_kv_heads=config.num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # The mock layer only needs to know the KV cache geometry.
    layer = MockLayer(
        device,
        kv_cache_spec=FullAttentionSpec(
            block_size=config.block_size,
            num_kv_heads=config.num_kv_heads,
            head_size=config.head_dim,
            dtype=dtype,
        ),
    )

    return backend_class, impl, layer, dtype
def _create_kv_cache(
    config: BenchmarkConfig,
    max_num_blocks: int,
    cache_layout: str,
    device: torch.device,
    dtype: torch.dtype,
) -> list:
    """
    Allocate one zero-filled KV cache tensor per layer.

    Args:
        config: Benchmark configuration (block size, KV heads, head dim,
            number of layers)
        max_num_blocks: Number of cache blocks to allocate per layer
        cache_layout: "flashinfer" selects the FlashInfer layout; any other
            value selects the standard layout
        device: Device to allocate on
        dtype: Element dtype of the cache

    Returns:
        List of ``config.num_layers`` independent tensors shaped
        [num_blocks, 2, block_size, num_kv_heads, head_dim] for FlashInfer, or
        [2, num_blocks, block_size, num_kv_heads, head_dim] otherwise
    """
    per_block = (config.block_size, config.num_kv_heads, config.head_dim)

    # The two layouts differ only in the order of the leading two axes.
    if cache_layout == "flashinfer":
        shape = (max_num_blocks, 2, *per_block)
    else:
        shape = (2, max_num_blocks, *per_block)

    caches = []
    for _ in range(config.num_layers):
        caches.append(torch.zeros(shape, device=device, dtype=dtype))
    return caches
range(config.num_layers): + impl.forward( + layer, + q_list[i], + k_list[i], + v_list[i], + cache_list[i], + attn_metadata, + output=out, + ) + end.record() + + torch.cuda.synchronize() + elapsed_ms = start.elapsed_time(end) + times.append(elapsed_ms / 1000.0 / config.num_layers) # seconds per layer + + mem_stats = {} + if config.profile_memory: + mem_stats = { + "allocated_mb": torch.cuda.memory_allocated(device) / 1024**2, + "reserved_mb": torch.cuda.memory_reserved(device) / 1024**2, + } + + return times, mem_stats + + +# ============================================================================ +# Public API +# ============================================================================ + + +def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: + """ + Run standard attention benchmark with real kernels. + + Supports: flash, triton, flashinfer + + Args: + config: Benchmark configuration + + Returns: + BenchmarkResult with timing and memory statistics + """ + device = torch.device(config.device) + torch.cuda.set_device(device) + + backend_cfg = _get_backend_config(config.backend) + + requests = parse_batch_spec(config.batch_spec) + + if config.backend == "flashinfer": + requests = reorder_for_flashinfer(requests) + + q_lens = [r.q_len for r in requests] + kv_lens = [r.kv_len for r in requests] + total_q = sum(q_lens) + max_kv = max(kv_lens) + + max_num_blocks = (max_kv + config.block_size - 1) // config.block_size + + backend_class, impl, layer, dtype = _create_backend_impl( + backend_cfg, config, device + ) + + common_metadata = _build_common_attn_metadata( + q_lens, kv_lens, config.block_size, device + ) + + kv_cache_spec = FullAttentionSpec( + block_size=config.block_size, + num_kv_heads=config.num_kv_heads, + head_size=config.head_dim, + dtype=dtype, + ) + + vllm_config = _create_vllm_config(config, dtype, max_num_blocks) + + builder = _create_metadata_builder( + backend_class, kv_cache_spec, vllm_config, device + ) + + attn_metadata = 
builder.build( + common_prefix_len=0, + common_attn_metadata=common_metadata, + ) + + q_list, k_list, v_list = _create_input_tensors(config, total_q, device, dtype) + + cache_list = _create_kv_cache( + config, max_num_blocks, backend_cfg["cache_layout"], device, dtype + ) + + times, mem_stats = _run_single_benchmark( + config, + impl, + layer, + q_list, + k_list, + v_list, + cache_list, + attn_metadata, + device, + dtype, + ) + + mean_time = np.mean(times) + throughput = total_q / mean_time if mean_time > 0 else 0 + + return BenchmarkResult( + config=config, + mean_time=mean_time, + std_time=np.std(times), + min_time=np.min(times), + max_time=np.max(times), + throughput_tokens_per_sec=throughput, + memory_allocated_mb=mem_stats.get("allocated_mb"), + memory_reserved_mb=mem_stats.get("reserved_mb"), + )