[Attention] Add FlashInfer FA2 MLA attention backend for SM120 #36322
grimulkan wants to merge 2 commits into vllm-project:main
Conversation
Add a new MLA attention backend using FlashInfer's `BatchMLAPagedAttentionWrapper` with the FA2 (FlashAttention-2) kernel. This provides a native decode path for MLA models on SM120 (consumer Blackwell) GPUs, where the existing FlashInfer MLA backend's trtllm-gen cubins are unavailable.

- Implement `FlashInferFA2MLABackend` using the plan()/run() stateful API with CSR-format page indices (vs the stateless trtllm API)
- Support BF16 KV cache only (FA2's MMA instructions are f16/bf16)
- Enable LSE return for Decode Context Parallel (DCP) compatibility
- Support MTP via `qo_indptr` (uniform query lengths per request)
- Add per-batch-size wrapper caching for CUDA graph compatibility
- Register the backend with SM 12.x capability gating; prioritize it above Triton MLA so FA2 is auto-selected for BF16 and Triton for FP8

Signed-off-by: grimulkan <grimulkan@gmail.com>
Documentation preview: https://vllm--36322.org.readthedocs.build/en/36322/ |
Code Review
This pull request introduces a new MLA attention backend, FLASHINFER_FA2_MLA, specifically for NVIDIA GPUs with SM 12.x compute capability. This backend leverages FlashInfer's FlashAttention-2 kernel to provide a native CUDA path for MLA models, which is particularly useful for consumer Blackwell GPUs where existing trtllm-gen cubins are not available. The changes include the implementation of the backend, its registration, and updates to documentation and backend priority lists.
The implementation is well-structured and demonstrates a good understanding of the vLLM attention backend framework and the FlashInfer library. My main feedback is a critical performance concern in the metadata building path, where a GPU-to-CPU synchronization can be avoided. Addressing this will improve the performance of this new backend.
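The reviewer's concern is that forcing a GPU-to-CPU copy (e.g. `.cpu()` or `.item()`) inside the metadata builder stalls the host on every step. The fix is to build the CSR offsets with vectorized device-side ops instead. A minimal numpy sketch of the shape of that computation (in the real backend this would be `torch.cumsum` on CUDA tensors, so no host sync is triggered; `build_kv_indptr` is an illustrative name, not vLLM's API):

```python
import numpy as np


def build_kv_indptr(seq_lens, page_size):
    """CSR-style indptr: indptr[i+1] - indptr[i] = KV pages used by request i.

    Illustrative sketch only: everything here is a single vectorized pass,
    the pattern that keeps the equivalent torch version on-device.
    """
    pages_per_req = -(-np.asarray(seq_lens) // page_size)  # vectorized ceil-div
    return np.concatenate(([0], np.cumsum(pages_per_req)))


# Three requests of 100, 64, and 1 tokens with 64-token pages.
print(build_kv_indptr([100, 64, 1], page_size=64).tolist())  # [0, 2, 3, 4]
```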
Pull request overview
This PR adds a new MLA attention backend (FLASHINFER_FA2_MLA) that uses FlashInfer's BatchMLAPagedAttentionWrapper with the FA2 (FlashAttention-2) kernel for MLA decode on SM 12.x (consumer Blackwell) GPUs. This fills a gap where the existing FLASHINFER_MLA backend's trtllm-gen cubins are unavailable on SM 12.x, providing a compiled CUDA decode path as an alternative to the Triton MLA backend.
Changes:
- New `FlashInferFA2MLABackend` implementation with CSR-format page indices conversion, per-batch-size wrapper caching for CUDA graph compatibility, and BF16-only KV cache support with LSE return for DCP
- Registration of `FLASHINFER_FA2_MLA` in the backend enum and SM 12.x MLA priority list, positioned above `TRITON_MLA`
- Documentation updates to the attention backends design doc with priority table and feature matrix entries
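The CSR-format page indices conversion takes vLLM's padded dense block table (one row of physical page IDs per request) and flattens it into `kv_indices` with `kv_indptr` offsets, the layout FlashInfer's stateful API expects. A hedged, minimal numpy sketch of the transformation (the real builder does this on-GPU via the existing `_copy_page_indices_kernel`; function and variable names here are illustrative):

```python
import numpy as np


def block_table_to_csr(block_table, seq_lens, page_size):
    """Flatten a padded [num_reqs, max_pages] block table into CSR form.

    Sketch only: the actual backend performs this copy with a Triton
    kernel on the GPU; this just shows the layout transformation.
    """
    num_pages = -(-np.asarray(seq_lens) // page_size)  # pages used per request
    kv_indptr = np.concatenate(([0], np.cumsum(num_pages)))
    kv_indices = np.concatenate(
        [block_table[i, :n] for i, n in enumerate(num_pages)]
    )
    return kv_indptr, kv_indices


# Two requests; request 0 uses 2 pages, request 1 uses 1 page (-1 = padding).
bt = np.array([[7, 3, -1], [5, -1, -1]])
indptr, indices = block_table_to_csr(bt, [70, 10], page_size=64)
print(indptr.tolist(), indices.tolist())  # [0, 2, 3] [7, 3, 5]
```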
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| vllm/v1/attention/backends/mla/flashinfer_fa2_mla.py | New backend implementation: `FlashInferFA2MLABackend`, `FlashInferFA2MLAMetadataBuilder`, and `FlashInferFA2MLAImpl`, using FlashInfer's `BatchMLAPagedAttentionWrapper` plan()/run() API |
| vllm/v1/attention/backends/registry.py | Registers the `FLASHINFER_FA2_MLA` enum entry pointing to the new backend class |
| vllm/platforms/cuda.py | Adds `FLASHINFER_FA2_MLA` to the non-SM 10.x MLA priority list, above `TRITON_MLA` |
| docs/design/attention_backends.md | Updates the priority table and decode backends feature matrix with the new backend |
Signed-off-by: grimulkan <grimulkan@gmail.com>
hi @grimulkan would that also work for gpt-oss-20b on nvidia rtx 6000 pro blackwell?
@geraldstanje I think GPT-OSS-20B already works with the normal (non-MLA) attention backends for decode and doesn't need this, unless I am mistaken. EDIT: Apparently there is a community project to convert the model to use MLA, though it may not yet be complete. In theory, yes, that model should work with this.
hi @grimulkan where can i see more about the community project regarding MLA gpt oss-safeguard-20b? does MLA mean it will speed up inference on nvidia rtx 6000 pro blackwell?
https://www.reddit.com/r/deeplearning/comments/1qcmj31/gptoss_mla_conversion_breakthrough_20b_still/ |
[Attention] Add FlashInfer FA2 MLA attention backend for SM120
Add a new MLA attention backend using FlashInfer's FA2 (FlashAttention-2) kernel via the `BatchMLAPagedAttentionWrapper` plan()/run() API. This provides a native compiled CUDA decode path for MLA models on SM120 (consumer Blackwell), where the existing `FLASHINFER_MLA` backend's trtllm-gen cubins are unavailable.

- `FlashInferFA2MLABackend` with CSR-format page indices and per-batch-size wrapper caching for CUDA graph compatibility
- MTP support via `qo_indptr` (multiple tokens per request)

Purpose

The existing `FLASHINFER_MLA` backend uses FlashInfer's `trtllm_batch_decode_with_kv_cache_mla()` stateless API, which dispatches to pre-compiled trtllm-gen FMHA cubins on SM100 (and the XQA kernel on SM120, though this is not currently used by vllm). This API takes a merged Q/KV tensor and a flat block table, which is a fundamentally different calling convention from FlashInfer's FA2 path, which uses a stateful `BatchMLAPagedAttentionWrapper` with split `q_nope`/`q_pe`/`ckv`/`kpe` tensors and CSR-format page indices via plan()/run().

Rather than modifying the existing `FLASHINFER_MLA` backend (which would risk regressing the well-tested SM100 trtllm-gen path), this PR adds a separate `FLASHINFER_FA2_MLA` backend that uses the FA2 API directly. This mirrors the approach taken by sglang, which uses `BatchMLAPagedAttentionWrapper` as its primary MLA decode path across SM80/90/120.

With this change, SM120 MLA backend selection becomes:

1. `FLASHINFER_FA2_MLA` (compiled CUDA FA2 kernel, auto-selected)
2. `TRITON_MLA` (for FP8 KV cache: FA2's MMA instructions only support f16/bf16, so FA2 is filtered out and Triton takes over)

Changes

- New `FlashInferFA2MLABackend` in flashinfer_fa2_mla.py
    - `FlashInferFA2MLAMetadataBuilder`: converts the block table to CSR-format `kv_indptr`/`kv_indices` via the existing `_copy_page_indices_kernel`, and manages per-batch-size `BatchMLAPagedAttentionWrapper` instances for CUDA graph compatibility
    - `FlashInferFA2MLAImpl`: splits the query into `q_nope`/`q_pe`, splits the cache into `ckv`/`kpe`, and calls `wrapper.run()` with `return_lse=True`
- New `FLASHINFER_FA2_MLA` enum entry in registry.py, added to the SM 12.x MLA priority list above `TRITON_MLA`

Test Plan

- `vllm serve` with `--attention-backend FLASHINFER_FA2_MLA` on SM120
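The per-batch-size wrapper caching exists because CUDA graph replay requires stable shapes and buffer addresses, so the builder keeps one planned wrapper per (padded) batch size rather than re-creating it each step. A hedged sketch of that caching pattern (`DummyWrapper` stands in for FlashInfer's `BatchMLAPagedAttentionWrapper`; the cache class and names are illustrative, not vLLM's actual code):

```python
class DummyWrapper:
    """Stand-in for BatchMLAPagedAttentionWrapper; counts constructions."""

    instances = 0

    def __init__(self, batch_size):
        DummyWrapper.instances += 1
        self.batch_size = batch_size


class WrapperCache:
    """One wrapper per batch size, created lazily and then reused,
    so a captured CUDA graph always replays against the same object."""

    def __init__(self):
        self._cache = {}

    def get(self, batch_size):
        if batch_size not in self._cache:
            self._cache[batch_size] = DummyWrapper(batch_size)
        return self._cache[batch_size]


cache = WrapperCache()
cache.get(8)
cache.get(8)   # cache hit: same wrapper object, no new construction
cache.get(16)  # new padded batch size -> one more wrapper
print(DummyWrapper.instances)  # 2
```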
Inference Tests:
Support Matrix:
Kimi K2.5 on sm120 RTX 6000 Pro (native int4 experts, Marlin gemm, dense MLA):
*Requires #34597
**Requires #34795
End-to-end test:
GSM8K
`lm_eval` with `--kv-cache-dtype auto --attention-backend FLASHINFER_FA2_MLA` (this is actually the new default) on 8 x RTX 6000 Pro with Kimi K2.5:

Reference with `--attention-backend TRITON_MLA`:

Known Limitations
- BF16/FP16 only: the FA2 kernel uses `mma_sync_m16n16k16_row_col_f16f16f32`, which only supports half/bfloat16 MMA operands. There is no in-kernel FP8 dequant path (unlike Triton MLA). For FP8, `TRITON_MLA` will work if [Kernel] Add FP8 KV cache support to Triton MLA decode attention #34597 is merged (auto-selected when FA2 is filtered out).
- Requires `qk_nope_head_dim == 128` (validated in `supports_combination`).
- An alternative design would route through the `BatchMLAPagedAttentionWrapper` call directly, letting FlashInfer handle the sm100, sm120, etc., dispatch instead. Creating a separate backend to handle the FA2 path (which is useful for sm89 also) seems clunky, but I'm not sure there's a better way without disrupting the current sm90/100 approach.
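The impl's tensor splits can be sketched with plain numpy. The dims below (`kv_lora_rank=512`, `qk_rope_head_dim=64`) are DeepSeek-style MLA values assumed for illustration: the decode query arrives with the compressed-KV and RoPE parts concatenated, and each paged cache entry stores `ckv` and `kpe` side by side in the same 576-wide slot.

```python
import numpy as np

# Assumed DeepSeek-style MLA dims, for illustration only.
kv_lora_rank, rope_dim = 512, 64
num_tokens, num_heads = 4, 16
num_pages, page_size = 8, 64

# Decode query: nope and rope parts concatenated along the last dim.
q = np.zeros((num_tokens, num_heads, kv_lora_rank + rope_dim), dtype=np.float32)
# Paged cache: each slot holds the compressed KV (ckv) plus rope key (kpe).
kv_cache = np.zeros((num_pages, page_size, kv_lora_rank + rope_dim), dtype=np.float32)

# The splits the impl feeds to wrapper.run():
q_nope, q_pe = q[..., :kv_lora_rank], q[..., kv_lora_rank:]
ckv, kpe = kv_cache[..., :kv_lora_rank], kv_cache[..., kv_lora_rank:]

print(q_nope.shape, q_pe.shape, ckv.shape, kpe.shape)
# (4, 16, 512) (4, 16, 64) (8, 64, 512) (8, 64, 64)
```

These are views, not copies, so the split itself adds no memory traffic before the kernel launch.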