diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index c4c331f7f8ef..438eb78c8096 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -50,6 +50,16 @@ from vllm.v1.worker.workspace import init_workspace_manager +def _str2bool(v) -> bool: + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes", "t"): + return True + if v.lower() in ("false", "0", "no", "f"): + return False + raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}") + + def run_standard_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: """Run standard attention benchmark (Flash/Triton/FlashInfer).""" from runner import run_attention_benchmark @@ -459,6 +469,20 @@ def main(): help="Prefill backends to compare (fa2, fa3, fa4). " "Uses the first decode backend for impl construction.", ) + parser.add_argument( + "--fp8-output-scale", + type=float, + help="Static per-tensor scale enabling the MLA prefill FP8-output " + "comparison on FA4 (fused write vs standalone post-quant).", + ) + parser.add_argument( + "--fuse-quant-op", + nargs="+", + type=_str2bool, + help="FP8-output write path(s) to run: false = bf16 attention + " + "standalone static-FP8 quant, true = FA4 writes FP8 directly. " + "Default: both.", + ) # Batch specifications parser.add_argument( @@ -545,6 +569,12 @@ def main(): # Prefill backends (e.g., ["fa3", "fa4"]) args.prefill_backends = yaml_config.get("prefill_backends", None) + # FP8 output benchmark knobs; CLI wins. + if args.fp8_output_scale is None: + args.fp8_output_scale = yaml_config.get("fp8_output_scale", None) + if args.fuse_quant_op is None: + args.fuse_quant_op = yaml_config.get("fuse_quant_op", None) + # Check for special modes args.mode = yaml_config.get("mode", None) @@ -662,8 +692,59 @@ def main(): # Run benchmarks all_results = [] + # FA4 fused FP8 output vs standalone post-quant, on the same fa4 kernel: + # the delta is the post-quant kernel the fused path removes. + fp8_output_scale = getattr(args, "fp8_output_scale", None) + if fp8_output_scale is not None: + decode_backend = backends[0] + fuse_variants = args.fuse_quant_op or [False, True] + label_of = {False: "post_quant", True: "fused"} + console.print( + f"[yellow]FP8 output comparison @ scale={fp8_output_scale} " + f"(prefill=fa4, decode impl={decode_backend})[/]" + ) + fp8_results = [] + total = len(fuse_variants) * len(args.batch_specs) + with tqdm(total=total, desc="FP8 output benchmarking") as pbar: + for spec in args.batch_specs: + for fuse in fuse_variants: + config = BenchmarkConfig( + backend=decode_backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + kv_cache_dtype=args.kv_cache_dtype, + use_cuda_graphs=args.cuda_graphs, + prefill_backend="fa4", + ) + result = run_benchmark( + config, output_scale=fp8_output_scale, fuse_quant_op=fuse + ) + label = label_of[fuse] + labeled_config = replace(result.config, backend=label) + result = replace(result, config=labeled_config) + fp8_results.append(result) + + if not result.success: + console.print(f"[red]Error {label} {spec}: {result.error}[/]") + + pbar.update(1) + + console.print("\n[bold green]FP8 Output Results:[/]") + formatter = ResultsFormatter(console) + labels = [label_of[f] for f in fuse_variants] + formatter.print_table(fp8_results, labels, compare_to_fastest=True) + all_results = fp8_results + # Handle special mode: decode_vs_prefill comparison - if hasattr(args, "mode") and args.mode == "decode_vs_prefill": + elif hasattr(args, "mode") and args.mode == "decode_vs_prefill": console.print("[yellow]Mode: Decode vs Prefill pipeline comparison[/]") console.print( "[dim]For each query length, testing both decode and prefill pipelines[/]" diff --git a/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output.yaml b/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output.yaml new file mode 100644 index 000000000000..85588fcf9584 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_fa4_fp8_output.yaml @@ -0,0 +1,44 @@ +# MLA prefill FP8-output microbenchmark (FA4). +# Compares the fused FP8 write against bf16 attention + a standalone static-FP8 +# quant; the delta is the post-quant kernel the fused path removes. +# DeepSeek-Coder-V2-Lite dims; FA4 needs SM100/110. +# +# Usage: +# python benchmark.py --config configs/mla_fa4_fp8_output.yaml + +description: "MLA prefill FA4 fused-FP8 output vs post-quant" + +model: + name: "deepseek-v2-lite" + num_layers: 27 + num_q_heads: 16 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Pure prefill (q_len == kv_len) so every token goes through forward_mha. +batch_specs: + - "q512" + - "q1k" + - "q2k" + - "q4k" + - "q8k" + - "2q4k" + - "4q4k" + - "8q4k" + +# Only used to construct the MLA impl; the pure-prefill specs skip decode. +decode_backends: + - CUTLASS_MLA + +# Sweep the two FP8 write paths (prefill backend is fixed to fa4). +fp8_output_scale: 0.1 +fuse_quant_op: [false, true] + +device: "cuda:0" +repeats: 50 +warmup_iters: 10 diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index abab1e2edbac..3a2078117462 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -704,6 +704,8 @@ def _run_single_benchmark( device: torch.device, indexer=None, kv_cache_dtype: str | None = None, + output_scale: float | None = None, + fuse_quant_op: bool = False, ) -> BenchmarkResult: """ Run a single benchmark iteration. @@ -717,6 +719,11 @@ def _run_single_benchmark( mla_dims: MLA dimension configuration device: Target device indexer: Optional MockIndexer for sparse backends + output_scale: Static per-tensor FP8 scale for prefill output. None + keeps the plain bf16 output (no quantization). + fuse_quant_op: With output_scale set, True lets the prefill kernel write + FP8 directly; False runs bf16 attention then a standalone static-FP8 + quant. The delta isolates the saved post-quant kernel. Returns: BenchmarkResult with timing statistics @@ -820,23 +827,55 @@ def _run_single_benchmark( num_prefill, mla_dims, query_fmt, device, torch.bfloat16 ) + # Prefill FP8 output: fused (kernel writes e4m3) vs separate post-quant. + prefill_fp8_output = None + prefill_output_scale = None + prefill_quant_op = None + if has_prefill and output_scale is not None: + from vllm.platforms import current_platform + + prefill_output_scale = torch.tensor( + [output_scale], device=device, dtype=torch.float32 + ) + if fuse_quant_op: + prefill_fp8_output = torch.empty_like( + prefill_inputs["output"], dtype=current_platform.fp8_dtype() + ) + else: + from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, + ) + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + + prefill_quant_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + + fused_output = output_scale is not None and fuse_quant_op + # Build forward function def forward_fn(): results = [] if has_decode: results.append(impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)) if has_prefill: - results.append( - impl.forward_mha( - prefill_inputs["q"], - prefill_inputs["k_c_normed"], - prefill_inputs["k_pe"], - kv_cache, - metadata, - prefill_inputs["k_scale"], - prefill_inputs["output"], - ) + out = impl.forward_mha( + prefill_inputs["q"], + prefill_inputs["k_c_normed"], + prefill_inputs["k_pe"], + kv_cache, + metadata, + prefill_inputs["k_scale"], + prefill_fp8_output if fused_output else prefill_inputs["output"], + prefill_output_scale if fused_output else None, ) + if fused_output: + out = prefill_fp8_output + elif prefill_quant_op is not None: + out, _ = prefill_quant_op( + prefill_inputs["output"], prefill_output_scale + ) + results.append(out) return results[0] if len(results) == 1 else tuple(results) # Warmup @@ -886,6 +925,8 @@ def _run_mla_benchmark_batched( configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] index_topk: int = 2048, prefill_backend: str | None = None, + output_scale: float | None = None, + fuse_quant_op: bool = False, ) -> list[BenchmarkResult]: """ Unified batched MLA benchmark runner for all backends. @@ -1025,6 +1066,8 @@ def _run_mla_benchmark_batched( device, indexer=indexer, kv_cache_dtype=kv_cache_dtype, + output_scale=output_scale, + fuse_quant_op=fuse_quant_op, ) results.append(result) @@ -1052,6 +1095,8 @@ def run_mla_benchmark( num_kv_splits: int | None = None, index_topk: int = 2048, prefill_backend: str | None = None, + output_scale: float | None = None, + fuse_quant_op: bool = False, ) -> BenchmarkResult | list[BenchmarkResult]: """ Unified MLA benchmark runner for all backends. @@ -1071,6 +1116,9 @@ def run_mla_benchmark( index_topk: Topk value for sparse MLA backends (default 2048) prefill_backend: Prefill backend name (e.g., "fa3", "fa4"). When set, forces the specified FlashAttention version for prefill. + output_scale: Static per-tensor FP8 scale for prefill output (None = bf16). + fuse_quant_op: With output_scale set, fuse the FP8 write into the prefill + kernel vs a standalone post-quant kernel. See _run_single_benchmark. Returns: BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) @@ -1095,7 +1143,12 @@ def run_mla_benchmark( # Use unified batched execution results = _run_mla_benchmark_batched( - backend, configs_with_params, index_topk, prefill_backend=prefill_backend + backend, + configs_with_params, + index_topk, + prefill_backend=prefill_backend, + output_scale=output_scale, + fuse_quant_op=fuse_quant_op, ) # Return single result or list based on input diff --git a/tests/v1/attention/test_mla_prefill_quant_output.py b/tests/v1/attention/test_mla_prefill_quant_output.py new file mode 100644 index 000000000000..d7659485aa9f --- /dev/null +++ b/tests/v1/attention/test_mla_prefill_quant_output.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MLA prefill backend fused-quant-output support. + +Covers two things: + * `MLAPrefillBackend.supports_quant_output`, the capability gate that decides + whether the prefill kernel writes quantized output directly (FA4 native + fused FP8, see flash-attention#135) instead of the post-quant path. + * The numerical equivalence of that fused FP8 write versus the bf16-attention + + standalone static-FP8-quant path it replaces (GPU-only, SM100/SM110). +""" + +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8Dynamic128Sym, + kFp8StaticTensorSym, + kNvfp4Dynamic, +) +from vllm.platforms.interface import DeviceCapability +from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend +from vllm.v1.attention.backends.mla.prefill.flash_attn import ( + FlashAttnPrefillBackend, +) + +_FA_MODULE = "vllm.v1.attention.backends.mla.prefill.flash_attn" + + +class _DummyPrefillBackend(MLAPrefillBackend): + """Concrete backend that does NOT override supports_quant_output.""" + + @staticmethod + def get_name() -> str: + return "DUMMY" + + def run_prefill_new_tokens(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError + + def run_prefill_context_chunk(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError + + +@pytest.mark.parametrize( + "quant_key", [kFp8StaticTensorSym, kFp8Dynamic128Sym, kNvfp4Dynamic, None] +) +def test_base_backend_never_supports_quant_output(quant_key): + """The base default opts every backend out unless it overrides.""" + backend = object.__new__(_DummyPrefillBackend) + assert backend.supports_quant_output(quant_key) is False + + +def _make_fa_backend(version: int | None, is_vllm_fa: bool): + """Build a FlashAttnPrefillBackend without running its heavy __init__.""" + backend = object.__new__(FlashAttnPrefillBackend) + backend.vllm_flash_attn_version = version + backend._is_vllm_fa = is_vllm_fa + return backend + + +@pytest.mark.parametrize( + ("version", "is_vllm_fa", "dc_major", "quant_key", "expected"), + [ + # FA4 + vLLM-FA + Blackwell SM100/SM110 + static FP8 -> fused. + (4, True, 10, kFp8StaticTensorSym, True), + (4, True, 11, kFp8StaticTensorSym, True), + # Wrong compute capability (SM90 / SM120) -> not supported (#135). + (4, True, 9, kFp8StaticTensorSym, False), + (4, True, 12, kFp8StaticTensorSym, False), + # Not FA4. + (3, True, 10, kFp8StaticTensorSym, False), + (2, True, 10, kFp8StaticTensorSym, False), + (None, True, 10, kFp8StaticTensorSym, False), + # Upstream (ROCm) flash-attn, not vLLM-FA. + (4, False, 10, kFp8StaticTensorSym, False), + # Quant keys not wired through FA4 yet. + (4, True, 10, kFp8Dynamic128Sym, False), + (4, True, 10, kNvfp4Dynamic, False), + ], +) +def test_flash_attn_supports_quant_output( + version, is_vllm_fa, dc_major, quant_key, expected +): + backend = _make_fa_backend(version, is_vllm_fa) + with patch(f"{_FA_MODULE}.current_platform") as plat: + plat.get_device_capability.return_value = DeviceCapability( + major=dc_major, minor=0 + ) + assert backend.supports_quant_output(quant_key) is expected + + +def test_flash_attn_supports_quant_output_unknown_device(): + """A None device capability (e.g. capability probe failed) is safe.""" + backend = _make_fa_backend(version=4, is_vllm_fa=True) + with patch(f"{_FA_MODULE}.current_platform") as plat: + plat.get_device_capability.return_value = None + assert backend.supports_quant_output(kFp8StaticTensorSym) is False + + +def test_flash_attn_prefill_backend_signature_accepts_fused_kwargs(): + """run_prefill_new_tokens must accept out/output_scale so the direct + (non-**kwargs) call in forward_mha type- and runtime-checks.""" + import inspect + + params = inspect.signature( + FlashAttnPrefillBackend.run_prefill_new_tokens + ).parameters + assert "out" in params + assert "output_scale" in params + # The base contract must expose them too (Liskov / direct call site). + base_params = inspect.signature(MLAPrefillBackend.run_prefill_new_tokens).parameters + assert "out" in base_params + assert "output_scale" in base_params + + +def test_mla_impl_forward_mha_accepts_output_scale(): + """The abstract MLA impl forward_mha must carry output_scale so every + override (and the unconditional forward_impl call) stays compatible.""" + import inspect + + from vllm.v1.attention.backend import MLAAttentionImpl + + params = inspect.signature(MLAAttentionImpl.forward_mha).parameters + assert "output_scale" in params + assert params["output_scale"].default is None + + +def _fused_fp8_skip_reason() -> str | None: + """FA4 fused FP8 output needs a real Blackwell SM100/SM110 GPU.""" + if not torch.cuda.is_available(): + return "requires CUDA" + major = torch.cuda.get_device_capability()[0] + if major not in (10, 11): + return f"FA4 fused FP8 output requires SM100/SM110, got SM{major}x" + return None + + +_FUSED_FP8_SKIP = _fused_fp8_skip_reason() + + +@pytest.mark.skipif(_FUSED_FP8_SKIP is not None, reason=_FUSED_FP8_SKIP or "") +def test_fa4_fused_fp8_output_matches_post_quant(default_vllm_config): + """FA4's fused FP8 write (output_scale, flash-attention#135) must match the + bf16-attention + standalone static-FP8-quant path it replaces, since + production uses the same output_scale for both.""" + from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 + from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape + from vllm.platforms import current_platform + from vllm.vllm_flash_attn import flash_attn_varlen_func + + torch.manual_seed(0) + device = torch.device("cuda") + fp8_dtype = current_platform.fp8_dtype() + + # MLA prefill head dims (post kv_b_proj): q/k = qk_nope(128)+qk_rope(64), + # v = v_head_dim(128); DeepSeek-V2-Lite has 16 query heads. + num_heads, qk_head_dim, v_head_dim, seqlen = 16, 192, 128, 512 + cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device=device) + q = torch.randn(seqlen, num_heads, qk_head_dim, dtype=torch.bfloat16, device=device) + k = torch.randn(seqlen, num_heads, qk_head_dim, dtype=torch.bfloat16, device=device) + v = torch.randn(seqlen, num_heads, v_head_dim, dtype=torch.bfloat16, device=device) + + fa_kwargs = dict( + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + causal=True, + fa_version=4, + ) + + # Reference: bf16 attention, then standalone static per-tensor FP8 quant. + out_bf16 = flash_attn_varlen_func(q=q, k=k, v=v, **fa_kwargs) + out_2d = out_bf16.reshape(seqlen, num_heads * v_head_dim) + # Scale the amax near e4m3 max so the check uses the representable range. + finfo = torch.finfo(fp8_dtype) + scale = (out_2d.abs().max() / finfo.max).to(torch.float32).reshape(1) + quant_op = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + ref_fp8, _ = quant_op(out_2d, scale) + + # Feature: FA4 writes e4m3 into the (tokens, heads*dim) buffer directly. + fused_fp8 = torch.empty( + seqlen, num_heads * v_head_dim, dtype=fp8_dtype, device=device + ) + flash_attn_varlen_func( + q=q, + k=k, + v=v, + out=fused_fp8.view(seqlen, num_heads, v_head_dim), + output_scale=scale, + **fa_kwargs, + ) + + # Non-degenerate (catches a no-op / all-zero write). + assert torch.isfinite(fused_fp8.float()).all() + assert fused_fp8.float().abs().any() + + # e4m3 has 3 mantissa bits, so allow ~1 mantissa step of rounding slack. + ref = ref_fp8.float() * scale + got = fused_fp8.float() * scale + torch.testing.assert_close(got, ref, rtol=0.125, atol=float(scale) * 2) + + # ...and most elements land in the exact same fp8 bucket. + exact = (fused_fp8.view(torch.uint8) == ref_fp8.view(torch.uint8)).float().mean() + assert exact > 0.9, f"only {exact:.1%} of fused FP8 outputs matched the baseline" diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 140e071c7465..bae66bfc2242 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -687,7 +687,23 @@ def forward_impl( num_mqa_tokens = attn_metadata.num_decode_tokens num_mha_tokens = q.size(0) - num_mqa_tokens + mha_use_quant_output = ( + quant_key is not None + and self.prefill_backend.supports_quant_output(quant_key) + and attn_metadata is not None + and attn_metadata.prefill is not None + and attn_metadata.prefill.chunked_context is None + and self.impl.dcp_world_size <= 1 + ) + if num_mha_tokens > 0: + if mha_use_quant_output: + mha_output = quant_output + mha_output_scale = output_scale + else: + mha_output = output + mha_output_scale = None + self.impl.forward_mha( # type: ignore[attr-defined] q[num_mqa_tokens:], k_c_normed[num_mqa_tokens:], @@ -695,7 +711,8 @@ def forward_impl( kv_cache, attn_metadata, self._k_scale, - output=output[num_mqa_tokens:], + output=mha_output[num_mqa_tokens:num_actual_toks], + output_scale=mha_output_scale, ) if num_mqa_tokens > 0: @@ -794,13 +811,15 @@ def forward_impl( self._v_up_proj(attn_out, out=mqa_output_slice) if quant_key is not None: - # Quantize the BF16 computation result into the quantized output - actual = output[:num_actual_toks] + quant_idx = num_mqa_tokens if mha_use_quant_output else num_actual_toks + if quant_idx == 0: + return quant_output + actual = output[:quant_idx] if quant_key == kNvfp4Dynamic: # NVFP4: two FP4 values packed into one uint8 assert output_block_scale is not None fp4_data, fp4_scales = ops.scaled_fp4_quant(actual, output_scale) - quant_output[:num_actual_toks].copy_(fp4_data) + quant_output[:quant_idx].copy_(fp4_data) output_block_scale[: fp4_scales.shape[0]].copy_(fp4_scales) elif quant_key in (kFp8Dynamic128Sym, kFp8Dynamic64Sym): # Per-group FP8 @@ -812,8 +831,8 @@ def forward_impl( finfo = torch.finfo(_FP8_DTYPE) torch.ops._C.per_token_group_fp8_quant( actual, - quant_output[:num_actual_toks], - output_block_scale[:num_actual_toks], + quant_output[:quant_idx], + output_block_scale[:quant_idx], quant_group_size, 1e-10, # eps finfo.min, @@ -825,7 +844,7 @@ def forward_impl( elif quant_key == kFp8StaticTensorSym: # Static FP8 quantization fp8_data, _ = self._quant_fp8_op(actual, output_scale) - quant_output[:num_actual_toks].copy_(fp8_data) + quant_output[:quant_idx].copy_(fp8_data) else: raise ValueError(f"Unsupported quant_key: {quant_key}") return quant_output @@ -2252,6 +2271,7 @@ def forward_mha( attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, output: torch.Tensor, + output_scale: torch.Tensor | None = None, ) -> None: assert attn_metadata.prefill is not None assert self.dcp_world_size != -1 @@ -2265,6 +2285,9 @@ def forward_mha( q = q.to(prefill_metadata.q_data_type) has_context = prefill_metadata.chunked_context is not None + assert output_scale is None or not has_context, ( + "Fused FP8 output is only wired for the non-chunked-context path" + ) kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim @@ -2281,6 +2304,12 @@ def forward_mha( k=k, v=v, return_softmax_lse=has_context, + out=( + output.view(-1, self.num_heads, self.v_head_dim) + if output_scale is not None + else None + ), + output_scale=output_scale, ) if has_context: @@ -2310,7 +2339,8 @@ def forward_mha( suffix_lse=suffix_lse, prefill_tokens_with_context=prefill_metadata.chunked_context.prefill_tokens_with_context, ) - else: + elif output_scale is None: + # With output_scale set, backend already wrote into `output` in place. assert isinstance(output_prefill, torch.Tensor) output_prefill = output_prefill.flatten(start_dim=-2) output.copy_(output_prefill) diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index af58bfd31a57..3daec8026f05 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -886,6 +886,7 @@ def forward_mha( attn_metadata: T, k_scale: torch.Tensor, output: torch.Tensor, + output_scale: torch.Tensor | None = None, ) -> None: """MHA-style prefill forward pass.""" raise NotImplementedError diff --git a/vllm/v1/attention/backends/mla/prefill/base.py b/vllm/v1/attention/backends/mla/prefill/base.py index 91d668826fd9..3a73fd907c26 100644 --- a/vllm/v1/attention/backends/mla/prefill/base.py +++ b/vllm/v1/attention/backends/mla/prefill/base.py @@ -12,6 +12,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( MLACommonPrefillMetadata, ) + from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.mla.prefill.selector import ( MLAPrefillSelectorConfig, @@ -44,6 +45,12 @@ def supports_dtype(cls, dtype: torch.dtype) -> bool: def is_available(cls) -> bool: return True + def supports_quant_output(self, quant_key: "QuantKey") -> bool: + """Whether `run_prefill_new_tokens` can write quantized output + directly (fused) for the given quant key, skipping the post-quant + pass. Overridden by backends that support it.""" + return False + @classmethod def validate_configuration( cls, @@ -107,6 +114,8 @@ def run_prefill_new_tokens( k: torch.Tensor, v: torch.Tensor, return_softmax_lse: bool, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/v1/attention/backends/mla/prefill/flash_attn.py b/vllm/v1/attention/backends/mla/prefill/flash_attn.py index 029bd8ec9560..24763378e66b 100644 --- a/vllm/v1/attention/backends/mla/prefill/flash_attn.py +++ b/vllm/v1/attention/backends/mla/prefill/flash_attn.py @@ -8,6 +8,9 @@ import torch import vllm.envs as envs +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.v1.attention.backends.fa_utils import ( get_flash_attn_version, @@ -17,6 +20,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey if is_flash_attn_varlen_func_available(): from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func @@ -87,6 +91,16 @@ def __init__( # Track whether we're using vllm's FA or upstream (for ROCm) self._is_vllm_fa = current_platform.is_cuda() or current_platform.is_xpu() + def supports_quant_output(self, quant_key: "QuantKey") -> bool: + device_capability = current_platform.get_device_capability() + return ( + self.vllm_flash_attn_version == 4 + and self._is_vllm_fa + and device_capability is not None + and device_capability[0] in (10, 11) + and quant_key == kFp8StaticTensorSym + ) + def _flash_attn_varlen_diff_headdims( self, q: torch.Tensor, @@ -94,6 +108,8 @@ def _flash_attn_varlen_diff_headdims( v: torch.Tensor, return_softmax_lse: bool = False, softmax_scale: float | None = None, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: maybe_padded_v = v @@ -104,10 +120,13 @@ def _flash_attn_varlen_diff_headdims( if self._is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse + kwargs["out"] = out + kwargs["output_scale"] = output_scale else: # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse + assert out is None and output_scale is None if envs.VLLM_BATCH_INVARIANT: kwargs["num_splits"] = 1 @@ -140,6 +159,8 @@ def run_prefill_new_tokens( k: torch.Tensor, v: torch.Tensor, return_softmax_lse: bool, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self._flash_attn_varlen_diff_headdims( q=q, @@ -152,6 +173,8 @@ def run_prefill_new_tokens( softmax_scale=self.scale, causal=True, return_softmax_lse=return_softmax_lse, + out=out, + output_scale=output_scale, ) def run_prefill_context_chunk( diff --git a/vllm/v1/attention/backends/mla/prefill/flashinfer.py b/vllm/v1/attention/backends/mla/prefill/flashinfer.py index 0204f6ee1a02..92e26a0768e6 100644 --- a/vllm/v1/attention/backends/mla/prefill/flashinfer.py +++ b/vllm/v1/attention/backends/mla/prefill/flashinfer.py @@ -188,6 +188,8 @@ def run_prefill_new_tokens( k: torch.Tensor, v: torch.Tensor, return_softmax_lse: bool, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self._prefill_main is not None diff --git a/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py b/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py index d6e4fca172ad..21f1c25be773 100644 --- a/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py +++ b/vllm/v1/attention/backends/mla/prefill/tokenspeed_mla.py @@ -115,6 +115,8 @@ def run_prefill_new_tokens( k: torch.Tensor, v: torch.Tensor, return_softmax_lse: bool, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from tokenspeed_mla import tokenspeed_mla_prefill diff --git a/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py b/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py index afb0444a3148..1d4102ef3c28 100644 --- a/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py +++ b/vllm/v1/attention/backends/mla/prefill/trtllm_ragged.py @@ -83,6 +83,8 @@ def run_prefill_new_tokens( k: torch.Tensor, v: torch.Tensor, return_softmax_lse: bool, + out: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from flashinfer.prefill import trtllm_ragged_attention_deepseek diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index e0a5730f5fd8..452625b6202b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -812,6 +812,7 @@ def forward_mha( attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, output: torch.Tensor, + output_scale: torch.Tensor | None = None, ) -> None: """Dispatch prefill to the FP8 ASM kernel when available. @@ -837,6 +838,7 @@ def forward_mha( attn_metadata, k_scale, output, + output_scale, ) assert attn_metadata.prefill is not None @@ -852,8 +854,13 @@ def forward_mha( attn_metadata, k_scale, output, + output_scale, ) + assert output_scale is None, ( + "fused FP8 output not supported by the AITER FP8 MLA prefill path" + ) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) diff --git a/vllm/vllm_flash_attn/flash_attn_interface.py b/vllm/vllm_flash_attn/flash_attn_interface.py index 33955bb239ef..cb1974bf9af8 100644 --- a/vllm/vllm_flash_attn/flash_attn_interface.py +++ b/vllm/vllm_flash_attn/flash_attn_interface.py @@ -200,6 +200,8 @@ def flash_attn_varlen_func( k_descale=None, v_descale=None, num_splits: int = 0, + # FA4 Only + output_scale=None, # Version selector fa_version: int = DEFAULT_FA_VERSION, s_aux=None, @@ -269,6 +271,11 @@ def flash_attn_varlen_func( "seqused_k must be provided if block_table is provided" ) + assert output_scale is None or fa_version == 4, ( + f"Fused FP8 output (output_scale) is only supported by FA4, " + f"got fa_version={fa_version}" + ) + if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) # custom op does not support non-tuple input @@ -388,6 +395,7 @@ def flash_attn_varlen_func( return_lse=return_softmax_lse, out=out, learnable_sink=s_aux, + output_scale=output_scale, ) else: raise ValueError(f"Unsupported FA version: {fa_version}")