
[Attention] Add FlashInfer FA2 MLA attention backend for SM120#36322

Open
grimulkan wants to merge 2 commits into vllm-project:main from grimulkan:flashinfer-fa2-mla
Conversation

@grimulkan
Contributor

@grimulkan grimulkan commented Mar 7, 2026

[Attention] Add FlashInfer FA2 MLA attention backend for SM120

Add a new MLA attention backend using FlashInfer's FA2 (FlashAttention-2) kernel via the BatchMLAPagedAttentionWrapper plan()/run() API. This provides a native compiled CUDA decode path for MLA models on SM120 (consumer Blackwell), where the existing FLASHINFER_MLA backend's trtllm-gen cubins are unavailable.

  • Implement FlashInferFA2MLABackend with CSR-format page indices and per-batch-size wrapper caching for CUDA graph compatibility
  • Support BF16 KV cache with LSE return for DCP compatibility
  • Support MTP via uniform qo_indptr (multiple tokens per request)
  • Register with SM 12.x capability gating; prioritize above Triton MLA for automatic BF16 selection

Purpose

The existing FLASHINFER_MLA backend uses FlashInfer's trtllm_batch_decode_with_kv_cache_mla() stateless API, which dispatches to pre-compiled trtllm-gen FMHA cubins on SM100 (and the XQA kernel on SM120, though this is not currently used by vllm). This API takes a merged Q/KV tensor and a flat block table, a fundamentally different calling convention from FlashInfer's FA2 path, which uses a stateful BatchMLAPagedAttentionWrapper with split q_nope/q_pe/ckv/kpe tensors and CSR-format page indices via plan()/run().

Rather than modifying the existing FLASHINFER_MLA backend (which would risk regressing the well-tested SM100 trtllm-gen path), this PR adds a separate FLASHINFER_FA2_MLA backend that uses the FA2 API directly. This mirrors the approach taken by sglang, which uses BatchMLAPagedAttentionWrapper as its primary MLA decode path across SM80/90/120.

With this change, SM120 MLA backend selection becomes:

  • BF16 KV cache → FLASHINFER_FA2_MLA (compiled CUDA FA2 kernel, auto-selected)
  • FP8 KV cache → TRITON_MLA (FA2's MMA instructions only support f16/bf16, so FA2 is filtered out and Triton takes over)
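The selection logic above is a priority list with per-backend filtering: the first backend that supports the configuration wins. A minimal sketch of that idea (function and list names here are illustrative, not vLLM's actual code in cuda.py):

```python
# Hypothetical sketch of priority-list backend selection on SM 12.x;
# names and the support predicate are illustrative, not vLLM's code.

SM120_MLA_PRIORITY = ["FLASHINFER_FA2_MLA", "TRITON_MLA"]


def supports(backend: str, kv_cache_dtype: str) -> bool:
    if backend == "FLASHINFER_FA2_MLA":
        # FA2's MMA instructions only handle f16/bf16 KV caches.
        return kv_cache_dtype in ("bf16", "fp16")
    # Triton MLA has an in-kernel FP8 dequant path.
    return True


def select_backend(kv_cache_dtype: str) -> str:
    for backend in SM120_MLA_PRIORITY:
        if supports(backend, kv_cache_dtype):
            return backend
    raise RuntimeError("no MLA backend supports this configuration")
```

With this ordering, `select_backend("bf16")` picks FLASHINFER_FA2_MLA while `select_backend("fp8")` falls through to TRITON_MLA.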

Changes

  • Add FlashInferFA2MLABackend in flashinfer_fa2_mla.py
    • FlashInferFA2MLAMetadataBuilder: converts block table to CSR-format kv_indptr/kv_indices via existing _copy_page_indices_kernel, manages per-batch-size BatchMLAPagedAttentionWrapper instances for CUDA graph compatibility
    • FlashInferFA2MLAImpl: splits query into q_nope/q_pe, splits cache into ckv/kpe, calls wrapper.run() with return_lse=True
  • Register FLASHINFER_FA2_MLA enum in registry.py
  • Add to SM120 MLA priority list in cuda.py, positioned above TRITON_MLA
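The CSR conversion and uniform qo_indptr construction described above can be illustrated in pure Python. This is a sketch of the data layout only: the real implementation builds kv_indptr/kv_indices on-GPU via the existing _copy_page_indices_kernel, and the function names below are hypothetical.

```python
# Pure-Python illustration of the CSR layout the metadata builder
# produces; the real code does this on-GPU with a Triton kernel.

def build_csr_indices(
    block_table: list[list[int]], seq_lens: list[int], page_size: int
) -> tuple[list[int], list[int]]:
    """Flatten a per-request block table into kv_indptr/kv_indices (CSR)."""
    kv_indptr = [0]
    kv_indices: list[int] = []
    for row, seq_len in zip(block_table, seq_lens):
        # Ceil-divide: how many pages this request's KV cache occupies.
        num_pages = (seq_len + page_size - 1) // page_size
        kv_indices.extend(row[:num_pages])
        kv_indptr.append(len(kv_indices))
    return kv_indptr, kv_indices


def build_uniform_qo_indptr(num_reqs: int, tokens_per_req: int) -> list[int]:
    """MTP case: every request contributes the same number of query tokens."""
    return [i * tokens_per_req for i in range(num_reqs + 1)]
```

For example, with page_size=16, a request of length 40 occupies 3 pages and one of length 17 occupies 2, so two rows [3, 7, 9] and [2, 5, 0] flatten to kv_indices=[3, 7, 9, 2, 5] with kv_indptr=[0, 3, 5].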

Test Plan

  • Test inference using vllm serve with --attention-backend FLASHINFER_FA2_MLA on SM120
  • End-to-end lm_eval (GSM8K) on Kimi K2.5

Test Results

Inference Tests:

Support Matrix:

| Attention backend | bf16 DCP=1 | bf16 DCP>1 | fp8 DCP=1 | fp8 DCP>1 |
|-------------------|------------|------------|-----------|-----------|
| Triton MLA        | True       | True       | True*     | True**    |
| Flashinfer FA2    | True       | True       | False     | False     |

Kimi K2.5 on sm120 RTX 6000 Pro (native int4 experts, Marlin gemm, dense MLA):

| Cards | TP | DCP | PP | KV Cache | Total KV Cache Space | Triton MLA Peak | Flashinfer FA2 Peak |
|-------|----|-----|----|----------|----------------------|-----------------|---------------------|
| 8     | 8  | 8   | 1  | fp8      | 3M tok               | 68 tok/s**      | N/A                 |
| 8     | 8  | 1   | 1  | fp8      | 380K tok             | 79 tok/s*       | N/A                 |
| 8     | 8  | 8   | 1  | bf16     | 1.5M tok             | 67 tok/s        | 72 tok/s            |
| 8     | 8  | 1   | 1  | bf16     | 190K tok             | 78 tok/s        | 90 tok/s            |

*Requires #34597
**Requires #34795

End-to-end test:
GSM8K lm_eval with --kv-cache-dtype auto --attention-backend FLASHINFER_FA2_MLA (this is actually the new default) on 8 x RTX 6000 Pro with Kimi K2.5:

lm_eval   --model local-completions   --model_args "model=Kimi-K2.5-Thinking,base_url=http://localhost:5000/v1/completions,tokenizer=/models/moonshotai-Kimi-K2.5,num_concurrent=384,trust_remote_code=True,max_length=8192"   --tasks gsm8k   --num_fewshot 5
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9265|±  |0.0072|
|     |       |strict-match    |     5|exact_match|↑  |0.9249|±  |0.0073|

Reference with --attention-backend TRITON_MLA:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9378|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.9378|±  |0.0067|

Known Limitations

  • BF16 KV cache only. FA2's compiled CUDA kernel uses mma_sync_m16n16k16_row_col_f16f16f32 which only supports half/bfloat16 MMA operands. There is no in-kernel FP8 dequant path (unlike Triton MLA). For FP8, TRITON_MLA will work if [Kernel] Add FP8 KV cache support to Triton MLA decode attention #34597 is merged (auto-selected when FA2 is filtered out).
  • Requires qk_nope_head_dim == 128 (validated in supports_combination).
  • In the future, this backend could perhaps be merged with the FlashInfer MLA backend via a unified BatchMLAPagedAttentionWrapper call, letting FlashInfer handle the SM100/SM120 dispatch. A separate backend for the FA2 path (which is also useful on SM89) is somewhat clunky, but I'm not sure there's a better way without disrupting the current SM90/SM100 approach.
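The limitations above amount to a small gating predicate. A hypothetical sketch of the check implied by supports_combination (the function name below is illustrative, not the PR's actual signature):

```python
# Hypothetical sketch of the support check implied by the limitations
# above (f16/bf16 KV cache only, qk_nope_head_dim == 128); not vLLM's
# actual supports_combination code.

def fa2_mla_supported(kv_cache_dtype: str, qk_nope_head_dim: int) -> bool:
    # mma_sync_m16n16k16_row_col_f16f16f32 only takes half/bfloat16
    # operands, and FA2 has no in-kernel FP8 dequant path.
    if kv_cache_dtype not in ("bf16", "fp16"):
        return False
    # The backend additionally requires a 128-wide nope head dim.
    return qk_nope_head_dim == 128
```

Any configuration failing this check is filtered out of the priority list, which is how TRITON_MLA ends up auto-selected for FP8 KV caches.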

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Add a new MLA attention backend using FlashInfer's BatchMLAPagedAttentionWrapper
with the FA2 (FlashAttention-2) kernel. This provides a native
decode path for MLA models on SM120 (consumer Blackwell) GPUs, where the
existing FlashInfer MLA backend's trtllm-gen cubins are unavailable.

- Implement FlashInferFA2MLABackend using the plan()/run() stateful API
  with CSR-format page indices (vs the stateless trtllm API)
- Support BF16 KV cache only (FA2's MMA instructions are f16/bf16)
- Enable LSE return for Decode Context Parallel (DCP) compatibility
- Support MTP via qo_indptr (uniform query lengths per request)
- Add per-batch-size wrapper caching for CUDA graph compatibility
- Register backend with SM 12.x capability gating; prioritize above
  Triton MLA so FA2 is auto-selected for BF16, Triton for FP8

Signed-off-by: grimulkan <grimulkan@gmail.com>
@grimulkan grimulkan requested a review from pavanimajety as a code owner March 7, 2026 10:15
Copilot AI review requested due to automatic review settings March 7, 2026 10:15
@mergify

mergify bot commented Mar 7, 2026

Documentation preview: https://vllm--36322.org.readthedocs.build/en/36322/

@mergify mergify bot added documentation Improvements or additions to documentation nvidia v1 labels Mar 7, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new MLA attention backend, FLASHINFER_FA2_MLA, specifically for NVIDIA GPUs with SM 12.x compute capability. This backend leverages FlashInfer's FlashAttention-2 kernel to provide a native CUDA path for MLA models, which is particularly useful for consumer Blackwell GPUs where existing trtllm-gen cubins are not available. The changes include the implementation of the backend, its registration, and updates to documentation and backend priority lists.

The implementation is well-structured and demonstrates a good understanding of the vLLM attention backend framework and the FlashInfer library. My main feedback is a critical performance concern in the metadata building path, where a GPU-to-CPU synchronization can be avoided. Addressing this will improve the performance of this new backend.

Note: Security Review is unavailable for this PR.

Contributor

Copilot AI left a comment

Pull request overview

This PR adds a new MLA attention backend (FLASHINFER_FA2_MLA) that uses FlashInfer's BatchMLAPagedAttentionWrapper with the FA2 (FlashAttention-2) kernel for MLA decode on SM 12.x (consumer Blackwell) GPUs. This fills a gap where the existing FLASHINFER_MLA backend's trtllm-gen cubins are unavailable on SM 12.x, providing a compiled CUDA decode path as an alternative to the Triton MLA backend.

Changes:

  • New FlashInferFA2MLABackend implementation with CSR-format page indices conversion, per-batch-size wrapper caching for CUDA graph compatibility, and BF16-only KV cache support with LSE return for DCP
  • Registration of FLASHINFER_FA2_MLA in the backend enum and SM 12.x MLA priority list, positioned above TRITON_MLA
  • Documentation updates to the attention backends design doc with priority table and feature matrix entries

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
vllm/v1/attention/backends/mla/flashinfer_fa2_mla.py New backend implementation: FlashInferFA2MLABackend, FlashInferFA2MLAMetadataBuilder, FlashInferFA2MLAImpl using FlashInfer's BatchMLAPagedAttentionWrapper plan()/run() API
vllm/v1/attention/backends/registry.py Registers FLASHINFER_FA2_MLA enum entry pointing to the new backend class
vllm/platforms/cuda.py Adds FLASHINFER_FA2_MLA to the non-SM10.x MLA priority list, above TRITON_MLA
docs/design/attention_backends.md Updates priority table and decode backends feature matrix with the new backend

Signed-off-by: grimulkan <grimulkan@gmail.com>
@geraldstanje

hi @grimulkan would that also work for gpt-oss-20b on nvidia rtx 6000 pro blackwell?

@grimulkan
Contributor Author

grimulkan commented Mar 9, 2026

@geraldstanje I think GPT-OSS-20B can work with normal (non-MLA) attention backends for decode already and doesn't need this, unless I am mistaken

EDIT: Apparently there is a community project to convert the model to use MLA, maybe not yet complete. In theory, yes, that model should work with this.

@geraldstanje

geraldstanje commented Mar 9, 2026

@geraldstanje I think GPT-OSS-20B can work with normal (non-MLA) attention backends for decode already and doesn't need this, unless I am mistaken

EDIT: Apparently there is a community project to convert the model to use MLA, maybe not yet complete. In theory, yes, that model should work with this.

hi @grimulkan where can i see more about the community project regarding MLA gpt oss-safeguard-20b? does MLA mean it will speedup inference on nvidia rtx 6000 pro blackwell?

