Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion benchmarks/attention_benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
from vllm.v1.worker.workspace import init_workspace_manager


def _str2bool(v) -> bool:
    """Parse a boolean-like token into a ``bool``.

    Written to be used as an argparse ``type=`` callable, so that a flag can
    accept explicit ``true``/``false`` values on the command line.

    Args:
        v: Either a ``bool`` (returned unchanged) or a value whose string
            form is one of the accepted spellings. Matching ignores case and
            surrounding whitespace. Truthy spellings: ``true``, ``1``,
            ``yes``, ``t``. Falsy spellings: ``false``, ``0``, ``no``, ``f``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not one of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v
    # Normalise once. Going through str() means a non-string value (for
    # example the int 1) is either matched or rejected with the documented
    # ArgumentTypeError, instead of crashing with AttributeError on .lower().
    token = str(v).strip().lower()
    if token in ("true", "1", "yes", "t"):
        return True
    if token in ("false", "0", "no", "f"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")


def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
"""Run standard attention benchmark (Flash/Triton/FlashInfer)."""
from runner import run_attention_benchmark
Expand Down Expand Up @@ -459,6 +469,20 @@ def main():
help="Prefill backends to compare (fa2, fa3, fa4). "
"Uses the first decode backend for impl construction.",
)
parser.add_argument(
"--fp8-output-scale",
type=float,
help="Static per-tensor scale enabling the MLA prefill FP8-output "
"comparison on FA4 (fused write vs standalone post-quant).",
)
parser.add_argument(
"--fuse-quant-op",
nargs="+",
type=_str2bool,
help="FP8-output write path(s) to run: false = bf16 attention + "
"standalone static-FP8 quant, true = FA4 writes FP8 directly. "
"Default: both.",
)

# Batch specifications
parser.add_argument(
Expand Down Expand Up @@ -545,6 +569,12 @@ def main():
# Prefill backends (e.g., ["fa3", "fa4"])
args.prefill_backends = yaml_config.get("prefill_backends", None)

# FP8 output benchmark knobs; CLI wins.
if args.fp8_output_scale is None:
args.fp8_output_scale = yaml_config.get("fp8_output_scale", None)
if args.fuse_quant_op is None:
args.fuse_quant_op = yaml_config.get("fuse_quant_op", None)

# Check for special modes
args.mode = yaml_config.get("mode", None)

Expand Down Expand Up @@ -662,8 +692,59 @@ def main():
# Run benchmarks
all_results = []

# FA4 fused FP8 output vs standalone post-quant, on the same fa4 kernel:
# the delta is the post-quant kernel the fused path removes.
fp8_output_scale = getattr(args, "fp8_output_scale", None)
if fp8_output_scale is not None:
decode_backend = backends[0]
fuse_variants = args.fuse_quant_op or [False, True]
label_of = {False: "post_quant", True: "fused"}
console.print(
f"[yellow]FP8 output comparison @ scale={fp8_output_scale} "
f"(prefill=fa4, decode impl={decode_backend})[/]"
)
fp8_results = []
total = len(fuse_variants) * len(args.batch_specs)
with tqdm(total=total, desc="FP8 output benchmarking") as pbar:
for spec in args.batch_specs:
for fuse in fuse_variants:
config = BenchmarkConfig(
backend=decode_backend,
batch_spec=spec,
num_layers=args.num_layers,
head_dim=args.head_dim,
num_q_heads=args.num_q_heads,
num_kv_heads=args.num_kv_heads,
block_size=args.block_size,
device=args.device,
repeats=args.repeats,
warmup_iters=args.warmup_iters,
profile_memory=args.profile_memory,
kv_cache_dtype=args.kv_cache_dtype,
use_cuda_graphs=args.cuda_graphs,
prefill_backend="fa4",
)
result = run_benchmark(
config, output_scale=fp8_output_scale, fuse_quant_op=fuse
)
label = label_of[fuse]
labeled_config = replace(result.config, backend=label)
result = replace(result, config=labeled_config)
fp8_results.append(result)

if not result.success:
console.print(f"[red]Error {label} {spec}: {result.error}[/]")

pbar.update(1)

console.print("\n[bold green]FP8 Output Results:[/]")
formatter = ResultsFormatter(console)
labels = [label_of[f] for f in fuse_variants]
formatter.print_table(fp8_results, labels, compare_to_fastest=True)
all_results = fp8_results

# Handle special mode: decode_vs_prefill comparison
if hasattr(args, "mode") and args.mode == "decode_vs_prefill":
elif hasattr(args, "mode") and args.mode == "decode_vs_prefill":
console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]")
console.print(
"[dim]For each query length, testing both decode and prefill pipelines[/]"
Expand Down
44 changes: 44 additions & 0 deletions benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# MLA prefill FP8-output microbenchmark (FA4).
# Compares the fused FP8 write against bf16 attention + a standalone static-FP8
# quant; the delta is the post-quant kernel the fused path removes.
# DeepSeek-Coder-V2-Lite dims; FA4 needs SM100/110.
#
# Usage:
# python benchmark.py --config configs/mla_fa4_fp8_output.yaml

# Free-form label for this configuration.
description: "MLA prefill FA4 fused-FP8 output vs post-quant"

# Model dimensions for the benchmarked attention layer.
model:
name: "deepseek-v2-lite"
num_layers: 27
num_q_heads: 16
num_kv_heads: 1
# NOTE(review): 576 equals kv_lora_rank + qk_rope_head_dim (512 + 64);
# presumably intentional for MLA - confirm before changing either value.
head_dim: 576
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
block_size: 128

# Pure prefill (q_len == kv_len) so every token goes through forward_mha.
# Each batch spec is benchmarked once per fuse_quant_op variant listed below.
batch_specs:
- "q512"
- "q1k"
- "q2k"
- "q4k"
- "q8k"
- "2q4k"
- "4q4k"
- "8q4k"

# Only used to construct the MLA impl; the pure-prefill specs skip decode.
# The FP8-output comparison takes the first entry of the backend list.
decode_backends:
- CUTLASS_MLA

# Sweep the two FP8 write paths (prefill backend is fixed to fa4).
# A non-null fp8_output_scale is what switches benchmark.py into the FP8-output
# comparison; it is the static per-tensor scale used for the output quant.
# fuse_quant_op: false = bf16 attention + standalone static-FP8 quant,
# true = FA4 writes FP8 directly. Omitting it runs both.
# The --fp8-output-scale / --fuse-quant-op CLI flags override these values.
fp8_output_scale: 0.1
fuse_quant_op: [false, true]

# Target device plus the timing-loop controls passed through to the benchmark
# config (repeats / warmup_iters).
device: "cuda:0"
repeats: 50
warmup_iters: 10
75 changes: 64 additions & 11 deletions benchmarks/attention_benchmarks/mla_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,8 @@ def _run_single_benchmark(
device: torch.device,
indexer=None,
kv_cache_dtype: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
Expand All @@ -717,6 +719,11 @@ def _run_single_benchmark(
mla_dims: MLA dimension configuration
device: Target device
indexer: Optional MockIndexer for sparse backends
output_scale: Static per-tensor FP8 scale for prefill output. None
keeps the plain bf16 output (no quantization).
fuse_quant_op: With output_scale set, True lets the prefill kernel write
FP8 directly; False runs bf16 attention then a standalone static-FP8
quant. The delta isolates the saved post-quant kernel.

Returns:
BenchmarkResult with timing statistics
Expand Down Expand Up @@ -820,23 +827,55 @@ def _run_single_benchmark(
num_prefill, mla_dims, query_fmt, device, torch.bfloat16
)

# Prefill FP8 output: fused (kernel writes e4m3) vs separate post-quant.
prefill_fp8_output = None
prefill_output_scale = None
prefill_quant_op = None
if has_prefill and output_scale is not None:
from vllm.platforms import current_platform

prefill_output_scale = torch.tensor(
[output_scale], device=device, dtype=torch.float32
)
if fuse_quant_op:
prefill_fp8_output = torch.empty_like(
prefill_inputs["output"], dtype=current_platform.fp8_dtype()
)
else:
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)

prefill_quant_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

fused_output = output_scale is not None and fuse_quant_op

# Build forward function
def forward_fn():
results = []
if has_decode:
results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer))
if has_prefill:
results.append(
impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_inputs["output"],
)
out = impl.forward_mha(
prefill_inputs["q"],
prefill_inputs["k_c_normed"],
prefill_inputs["k_pe"],
kv_cache,
metadata,
prefill_inputs["k_scale"],
prefill_fp8_output if fused_output else prefill_inputs["output"],
prefill_output_scale if fused_output else None,
)
if fused_output:
out = prefill_fp8_output
elif prefill_quant_op is not None:
out, _ = prefill_quant_op(
prefill_inputs["output"], prefill_output_scale
)
results.append(out)
return results[0] if len(results) == 1 else tuple(results)

# Warmup
Expand Down Expand Up @@ -886,6 +925,8 @@ def _run_mla_benchmark_batched(
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
index_topk: int = 2048,
prefill_backend: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Expand Down Expand Up @@ -1025,6 +1066,8 @@ def _run_mla_benchmark_batched(
device,
indexer=indexer,
kv_cache_dtype=kv_cache_dtype,
output_scale=output_scale,
fuse_quant_op=fuse_quant_op,
)
results.append(result)

Expand Down Expand Up @@ -1052,6 +1095,8 @@ def run_mla_benchmark(
num_kv_splits: int | None = None,
index_topk: int = 2048,
prefill_backend: str | None = None,
output_scale: float | None = None,
fuse_quant_op: bool = False,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Expand All @@ -1071,6 +1116,9 @@ def run_mla_benchmark(
index_topk: Topk value for sparse MLA backends (default 2048)
prefill_backend: Prefill backend name (e.g., "fa3", "fa4").
When set, forces the specified FlashAttention version for prefill.
output_scale: Static per-tensor FP8 scale for prefill output (None = bf16).
fuse_quant_op: With output_scale set, fuse the FP8 write into the prefill
kernel vs a standalone post-quant kernel. See _run_single_benchmark.

Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
Expand All @@ -1095,7 +1143,12 @@ def run_mla_benchmark(

# Use unified batched execution
results = _run_mla_benchmark_batched(
backend, configs_with_params, index_topk, prefill_backend=prefill_backend
backend,
configs_with_params,
index_topk,
prefill_backend=prefill_backend,
output_scale=output_scale,
fuse_quant_op=fuse_quant_op,
)

# Return single result or list based on input
Expand Down
Loading
Loading