Changes from all commits (33 commits)
e4228dc
Change default from CUTLASS MLA to FlashInfer MLA
MatthewBonanni Jan 19, 2026
30cd686
Change log lines from debug to info
MatthewBonanni Jan 19, 2026
0b4c746
Merge remote-tracking branch 'upstream/main'
MatthewBonanni Jan 19, 2026
2cfba3b
First pass
MatthewBonanni Jan 19, 2026
5417186
Fix typo
MatthewBonanni Jan 19, 2026
c427ff8
Fix pre-commit
MatthewBonanni Jan 19, 2026
b80db7e
Merge branch 'main' into mla_prefill_abstraction
MatthewBonanni Jan 29, 2026
370d66d
Use device_config
MatthewBonanni Jan 29, 2026
5151bd9
Remove dead code
MatthewBonanni Jan 30, 2026
6922185
Bump deprecation version
MatthewBonanni Jan 30, 2026
f0a0acd
Update docs
MatthewBonanni Jan 30, 2026
3dc0fba
Cleanup
MatthewBonanni Jan 30, 2026
0225644
Format name
MatthewBonanni Jan 30, 2026
8013037
Fix dagger inside quotes
MatthewBonanni Jan 30, 2026
6688a67
Add selector test
MatthewBonanni Jan 30, 2026
a962bee
Update table
MatthewBonanni Jan 30, 2026
25d976f
Comment
MatthewBonanni Jan 30, 2026
f30cb13
Add model dtype support
MatthewBonanni Jan 30, 2026
eb27ec8
Add type annotation to fix docs build
MatthewBonanni Jan 30, 2026
17c9a9e
Merge branch 'main' into mla_prefill_abstraction
MatthewBonanni Feb 2, 2026
0f42a95
Fix rebase issue
MatthewBonanni Feb 2, 2026
97c9aa7
Introduce hashable config
MatthewBonanni Feb 2, 2026
fb7bead
Fix test
MatthewBonanni Feb 2, 2026
3b25d20
Pass device capability directly
MatthewBonanni Feb 2, 2026
0086b94
Fix hashing
MatthewBonanni Feb 2, 2026
5051737
Fix tests
MatthewBonanni Feb 3, 2026
115a75b
Merge branch 'main' into mla_prefill_abstraction
MatthewBonanni Mar 16, 2026
88d0be8
Fix pre-commit
MatthewBonanni Mar 16, 2026
0cd22de
Clean up FA import
MatthewBonanni Mar 16, 2026
7bf5ae8
Merge branch 'main' into mla_prefill_abstraction
MatthewBonanni Apr 2, 2026
d4c77ab
wip
LucasWilkinson Apr 7, 2026
2f0b2ad
cleanup
LucasWilkinson Apr 8, 2026
1d04d07
cleanup
LucasWilkinson Apr 9, 2026
23 changes: 14 additions & 9 deletions docs/design/attention_backends.md
@@ -189,21 +189,26 @@ MLA uses separate backends for prefill and decode phases.

 ### Prefill Backends

-The prefill backend is selected at runtime based on hardware and
-configuration.
-
-| Backend | Description | Compute Cap. | Enable | Disable | Notes |
-| ------- | ----------- | ------------ | ------ | ------- | ----- |
-| TRT-LLM Ragged‡ | TensorRT-LLM ragged attention | 10.x | Default on SM100 | `-ac.use_trtllm_ragged_deepseek_prefill=0` | DeepSeek R1 dims only |
-| FlashInfer | FlashInfer CUTLASS backend | 10.x | `-ac.disable_flashinfer_prefill=0` | `-ac.disable_flashinfer_prefill=1` | DeepSeek R1 dims only |
-| cuDNN | cuDNN-based attention | 10.x | `-ac.use_cudnn_prefill=1` | `-ac.use_cudnn_prefill=0` | |
-| FlashAttention | FlashAttention varlen (FA2/FA3) | Any | Default fallback | Use other backends | FA3 on SM90, FA2 otherwise |
+To explicitly select a prefill backend, use
+`-ac.mla_prefill_backend=<BACKEND>` (e.g., `FLASH_ATTN`, `FLASHINFER`).
+Otherwise, the prefill backend is selected automatically at runtime based on
+hardware and configuration.

+| Backend | Description | Dtypes | Compute Cap. | Notes |
+| ------- | ----------- | ------ | ------------ | ----- |
+| `TRTLLM_RAGGED_PREFILL`‡ | TensorRT-LLM ragged attention | fp16, bf16 | 10.x | DeepSeek R1 dims only |
+| `FLASHINFER_PREFILL` | FlashInfer CUTLASS backend | fp16, bf16 | 10.x | DeepSeek R1 dims only |
+| `CUDNN_PREFILL` | cuDNN-based attention | fp16, bf16 | 10.x | DeepSeek R1 dims only |
+| `FLASH_ATTN_PREFILL` | FlashAttention varlen (FA2/FA3) | fp16, bf16 | Any | FA3 on SM90, FA2 otherwise |

+> **‡** TRT-LLM Ragged is the default on Blackwell (SM100).
+> On other GPUs, FlashAttention is used as the default.

 ### Decode Backends

+MLA decode backends are selected using the standard
+`-ac.backend=<BACKEND>` argument (e.g., `FLASHMLA`, `TRITON_MLA`).

 | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
 | ------- | ------ | --------- | ----------- | ---------- | ---- | ------ | --------- | --- | --------------- | ------------ |
 | `CUTLASS_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
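Putting the two selection mechanisms from the updated docs together, an invocation pinning both phases might look like the sketch below. This is an illustrative fragment, not a verified command line: the model name is a placeholder, and only the flag spellings shown in the tables above are assumed.

```bash
# Hypothetical example: pin the MLA prefill and decode backends explicitly.
vllm serve deepseek-ai/DeepSeek-R1 \
  -ac.mla_prefill_backend=FLASH_ATTN \
  -ac.backend=FLASHMLA
```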
103 changes: 103 additions & 0 deletions tests/kernels/test_flashinfer_mla_padding.py
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.platforms import current_platform

FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="FlashInfer MLA requires compute capability of 10 or above.",
        allow_module_level=True,
    )
else:
    from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla


_FP8_DTYPES = {
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
}


def _finite_check_tensor(t: torch.Tensor) -> torch.Tensor:
    # FP8 tensors must be upcast before finiteness can be checked.
    return t.to(torch.float16) if t.dtype in _FP8_DTYPES else t


@pytest.mark.parametrize("max_seq_len", [128, 256, 512, 1024, 2048, 4096])
def test_flashinfer_mla_decode_padding_rows_not_updated(max_seq_len: int):
    """Regression test: the kernel must not write into padding rows."""
    torch.set_default_device("cuda")
    torch.manual_seed(42)

    dtype = torch.float8_e4m3fn
    block_size = 64
    num_heads = 128
    kv_lora_rank = 512
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    qk_head_dim = kv_lora_rank + qk_rope_head_dim

    bs = 8
    q_len_per_request = 1
    seq_lens_tensor = torch.tensor([3, 3, 3, 3, 3, 0, 0, 0], dtype=torch.int32)

    # Build a realistic block-table layout:
    # - table width follows max_seq_len,
    # - active rows get unique page IDs for required blocks,
    # - unused slots stay -1.
    max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.full((bs, max_blocks_per_seq), -1, dtype=torch.int32)
    num_pages = bs * max_blocks_per_seq
    page_ids = torch.randperm(num_pages, dtype=torch.int32)
    cursor = 0
    for req_idx, seq_len in enumerate(seq_lens_tensor.tolist()):
        num_blocks = (seq_len + block_size - 1) // block_size
        if num_blocks > 0:
            block_tables[req_idx, :num_blocks] = page_ids[cursor : cursor + num_blocks]
            cursor += num_blocks

    kv_cache = torch.randn(num_pages, block_size, qk_head_dim).to(dtype)
    q = torch.randn(bs, q_len_per_request, num_heads, qk_head_dim).to(dtype)
    assert torch.isfinite(_finite_check_tensor(kv_cache)).all(), (
        "kv_cache contains NaN/Inf before test."
    )
    assert torch.isfinite(_finite_check_tensor(q)).all(), (
        "q contains NaN/Inf before test."
    )

    workspace_buffer = torch.zeros(
        FLASHINFER_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device=q.device,
    )

    out = torch.zeros(
        (bs, num_heads, kv_lora_rank),
        dtype=torch.bfloat16,
        device=q.device,
    )
    padding_rows = seq_lens_tensor == 0
    padding_expected = torch.ones_like(out[padding_rows])

    for i in range(1000):
        # Poison the padding rows with a sentinel, run the kernel, and
        # verify the sentinel survives: the kernel must never write them.
        out[padding_rows] = 1
        out_ans = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_cache.unsqueeze(1),
            workspace_buffer=workspace_buffer,
            qk_nope_head_dim=qk_nope_head_dim,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_seq_len=max_seq_len,
            out=out,
        )
        assert torch.equal(out_ans[padding_rows], padding_expected), (
            f"Kernel updated padding rows (seq_lens == 0) at iteration {i}."
        )