3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -730,7 +730,8 @@ steps:
 # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
 - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
-- pytest -v -s tests/kernels/test_cutlass_mla_decode.py
+- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
+- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
 # Quantization
 - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
 - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
123 changes: 123 additions & 0 deletions tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from torch import Tensor

from vllm.platforms import current_platform

FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="FlashInfer MLA requires compute capability 10.0 or above.",
        allow_module_level=True)


def ref_mla(
        out: Tensor,  # (bs, num_heads, v_head_dim)
        query: Tensor,  # (bs, num_heads, head_dim)
        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
        scale: float,
        block_tables: Tensor,  # (bs, max_num_blocks)
        seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
@pytest.mark.parametrize("block_size", [32, 64])
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
torch.set_default_device('cuda')
torch.manual_seed(42)

# Deepseek R1 config
num_heads = 128
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
qk_head_dim = kv_lora_rank + qk_rope_head_dim
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5

MAX_SEQ_LEN = 1024

seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)

# Generate block tables with random but unique block IDs
# From https://github.com/flashinfer-ai/flashinfer/pull/1222
blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
total_blocks_needed = sum(blocks_per_seq)
# Get random unique IDs for all blocks
all_block_ids = torch.randperm(total_blocks_needed)

block_id = 0
block_tables = torch.zeros(
(bs, max_num_blocks_per_seq),
dtype=torch.int32,
)

# Populate block tables and track block assignments
block_id = 0
for i in range(bs):
num_blocks_needed = blocks_per_seq[i]
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
num_blocks_needed]
block_id += num_blocks_needed

kv_cache = torch.randn(block_tables.numel(), block_size,
qk_head_dim).to(dtype)
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)

out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)

workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=q.device,
)
# Flashinfer MLA expects the query to be of shape
# (bs, q_len_per_request, num_heads, qk_head_dim),
# where q_len_per_request is the MTP query length (=1 without MTP)
q = q.unsqueeze(1)

out_ans = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_cache.unsqueeze(1),
workspace_buffer=workspace_buffer,
qk_nope_head_dim=qk_nope_head_dim,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
max_seq_len=max_seq_len,
bmm1_scale=scale,
)
out_ans = out_ans.squeeze(1)
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
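
For orientation, the arithmetic this test encodes (the sketch below is an editor's illustration, not part of the diff): MLA stores one fused vector per token in the KV cache, the 512-dim compressed latent concatenated with the 64-dim RoPE key, so queries and cache entries are 576 wide, while the softmax scale is derived from the uncompressed per-head dimension 128 + 64 = 192. Note that `ref_mla` attends over the full 576-dim entry as the key but reads back only the first `kv_lora_rank` dims as V.

# Editor's sketch: the DeepSeek R1 MLA dimensions used by the test above.
kv_lora_rank = 512       # compressed KV latent, shared by all 128 heads
qk_rope_head_dim = 64    # decoupled RoPE key dim, appended to the latent
qk_nope_head_dim = 128   # per-head non-positional query/key dim

qk_head_dim = kv_lora_rank + qk_rope_head_dim        # 576: query/cache width
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5  # 192**-0.5, not 576**-0.5

assert qk_head_dim == 576
print(f"scale = {scale:.6f}")  # ~0.072169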
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1505,6 +1505,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"FLASH_ATTN_MLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"FLASHINFER_MLA",
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
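
With the name registered in the V1 oracle, the backend can be requested explicitly. A minimal usage sketch (editor's example, not part of the diff; it assumes vLLM's usual `VLLM_ATTENTION_BACKEND` environment-variable convention, and the model name is illustrative only):

import os

# Must be set before engine construction so backend selection sees it.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM

# FLASHINFER_MLA only applies to MLA architectures (DeepSeek-style models)
# on compute-capability-10.0 GPUs; the model choice here is illustrative.
llm = LLM(model="deepseek-ai/DeepSeek-R1")
outputs = llm.generate("The capital of France is")
print(outputs[0].outputs[0].text)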
15 changes: 15 additions & 0 deletions vllm/platforms/cuda.py
@@ -228,6 +228,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
+            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
+                                 and cls.has_device_capability(100))
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
@@ -252,6 +254,19 @@ def _get_version(name, import_suffix) -> str:
                 else:
                     logger.warning(
                         "Cutlass MLA backend is only supported on V1 engine")
+            if use_flashinfermla:
+                if use_v1:
+                    from vllm.v1.attention.backends.utils import (
+                        set_kv_cache_layout)
+                    set_kv_cache_layout("HND")
+                    logger.info_once(
+                        "Using FlashInfer MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "flashinfer_mla.FlashInferMLABackend")
+                else:
+                    logger.warning(
+                        "FlashInfer MLA backend is only supported on V1 engine"
+                    )
             if use_flashmla:
                 if block_size != 64:
                     logger.warning(
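
Two properties of this dispatch are worth noting: unlike CUTLASS MLA, FlashInfer MLA is never auto-selected (there is no `selected_backend is None` arm), and choosing it flips the KV-cache layout to the HND layout the TRT-LLM decode kernels expect. A condensed sketch of the resulting precedence (editor's example, not the real function; strings stand in for the real backend classes and `_Backend` members):

from typing import Optional

def pick_mla_backend(selected: Optional[str], sm100: bool, block_size: int,
                     use_v1: bool) -> str:
    # CUTLASS MLA: explicit, or auto-selected on SM 10.0 with 128-token blocks.
    if selected == "CUTLASS_MLA" or (selected is None and sm100
                                     and block_size == 128):
        return "CutlassMLABackend" if use_v1 else "warn: V1 only"
    # FlashInfer MLA: opt-in only, requires SM 10.0 and the V1 engine.
    if selected == "FLASHINFER_MLA" and sm100:
        return "FlashInferMLABackend" if use_v1 else "warn: V1 only"
    return "FlashMLA / Triton MLA fallback"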
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
     TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
+    FLASHINFER_MLA = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
     TRITON_MLA_VLLM_V1 = enum.auto()
     CUTLASS_MLA = enum.auto()
3 changes: 3 additions & 0 deletions vllm/v1/attention/backends/mla/common.py
@@ -381,6 +381,7 @@ class MLACommonMetadata(Generic[D]):

     num_reqs: int
     max_query_len: int
+    max_seq_len: int

     num_actual_tokens: int  # Number of tokens excluding padding.
     query_start_loc: torch.Tensor
@@ -644,6 +645,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
+        max_seq_len = common_attn_metadata.max_seq_len

         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
@@ -830,6 +832,7 @@ def build(self,
         attn_metadata = self.metadata_cls(
             num_reqs=common_attn_metadata.num_reqs,
             max_query_len=common_attn_metadata.max_query_len,
+            max_seq_len=max_seq_len,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
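
The reason for threading `max_seq_len` through `MLACommonMetadata` shows up in the new backend below: `trtllm_batch_decode_with_kv_cache_mla` takes the batch-wide maximum sequence length as a scalar argument, so it needs to be available host-side at metadata-build time rather than recomputed from the GPU-resident `seq_lens`, which would force the CPU -> GPU sync the `Note(simon)` comment warns about. A toy illustration (editor's sketch with illustrative values):

import torch

# Per-request KV lengths for one decode batch (illustrative values).
seq_lens_cpu = torch.tensor([7, 1024, 130], dtype=torch.int32)

# Computed once from host-side metadata and carried in the attention
# metadata, so the decode hot path never syncs with the device for it.
max_seq_len = int(seq_lens_cpu.max().item())
assert max_seq_len == 1024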
110 changes: 110 additions & 0 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)

logger = init_logger(__name__)

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class FlashInferMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        return FlashInferMLAImpl


g_fi_workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)


class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashInferMLA V1 with FP8 KV cache not yet supported")

        self._workspace_buffer = g_fi_workspace

    def _forward_decode(
        self,
        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
            q = torch.cat([q_nope, q_pe], dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return o, None
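
To make the shape handling in `_forward_decode` concrete, here is a standalone walkthrough under the DeepSeek R1 dimensions (editor's sketch, not part of the diff; the size-1 axis inserted by `unsqueeze(1)` is `q_len_per_request`, which stays 1 until MTP multi-token prediction is in play):

import torch

bs, num_heads = 4, 128
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim  # 576
num_blocks, block_size = 32, 64

# Decode query after q_nope/q_pe concatenation: (bs, num_heads, 576).
q = torch.randn(bs, num_heads, head_dim)
q = q.unsqueeze(1)              # (bs, 1, num_heads, 576)

# MLA cache has no per-head axis: one fused latent+RoPE row per token.
cache = torch.randn(num_blocks, block_size, head_dim)
cache = cache.unsqueeze(1)      # (num_blocks, 1, block_size, 576)

# The kernel returns (bs, 1, num_heads, kv_lora_rank); squeeze(1) restores
# the (bs, num_heads, 512) layout the rest of the MLA pipeline expects.
o = torch.randn(bs, 1, num_heads, kv_lora_rank)
assert o.squeeze(1).shape == (bs, num_heads, kv_lora_rank)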