diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 75a9c4a22cb4..9debfe350eba 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -730,7 +730,8 @@ steps:
     # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
     - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
     - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
-    - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
+    - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
+    - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
     # Quantization
     - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
     - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py
similarity index 100%
rename from tests/kernels/test_cutlass_mla_decode.py
rename to tests/kernels/attention/test_cutlass_mla_decode.py
diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py
new file mode 100644
index 000000000000..02225432f77f
--- /dev/null
+++ b/tests/kernels/attention/test_flashinfer_mla_decode.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+import torch.nn.functional as F
+from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
+from torch import Tensor
+
+from vllm.platforms import current_platform
+
+FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(
+        reason="FlashInfer MLA Requires compute capability of 10 or above.",
+        allow_module_level=True)
+
+
+def ref_mla(
+        out: Tensor,  # (bs, num_heads, v_head_dim)
+        query: Tensor,  # (bs, num_heads, head_dim)
+        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
+        scale: float,
+        block_tables: Tensor,  # (bs, max_num_blocks)
+        seq_lens: Tensor,  # (bs,)
+):
+    bs, num_heads, v_head_dim = out.shape
+    head_dim = query.shape[2]
+
+    for i in range(bs):
+        # gather and flatten KV-cache
+        kv = kv_cache[
+            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
+        kv = kv.view(1, -1,
+                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
+        v = kv[:, :, :v_head_dim]
+
+        q = query[i].view(num_heads, 1, head_dim)
+        o = F.scaled_dot_product_attention(q,
+                                           kv,
+                                           v,
+                                           scale=scale,
+                                           enable_gqa=True)
+        out[i] = o.view(num_heads, v_head_dim)
+
+    return out
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("bs", [1, 2, 4, 16])
+@pytest.mark.parametrize("block_size", [32, 64])
+def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
+    torch.set_default_device('cuda')
+    torch.manual_seed(42)
+
+    # Deepseek R1 config
+    num_heads = 128
+    kv_lora_rank = 512
+    qk_nope_head_dim = 128
+    qk_rope_head_dim = 64
+    qk_head_dim = kv_lora_rank + qk_rope_head_dim
+    scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5
+
+    MAX_SEQ_LEN = 1024
+
+    seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
+    seq_lens[-1] = MAX_SEQ_LEN
+    max_seq_len = max(seq_lens)
+    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
+
+    # Generate block tables with random but unique block IDs
+    # From https://github.com/flashinfer-ai/flashinfer/pull/1222
+    blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
+    max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
+    total_blocks_needed = sum(blocks_per_seq)
+    # Get random unique IDs for all blocks
+    all_block_ids = torch.randperm(total_blocks_needed)
+
+    block_id = 0
+    block_tables = torch.zeros(
+        (bs, max_num_blocks_per_seq),
+        dtype=torch.int32,
+    )
+
+    # Populate block tables and track block assignments
+    block_id = 0
+    for i in range(bs):
+        num_blocks_needed = blocks_per_seq[i]
+        block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
+                                                            num_blocks_needed]
+        block_id += num_blocks_needed
+
+    kv_cache = torch.randn(block_tables.numel(), block_size,
+                           qk_head_dim).to(dtype)
+    q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)
+
+    out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
+    ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)
+
+    workspace_buffer = torch.zeros(
+        FLASHINFER_WORKSPACE_BUFFER_SIZE,
+        dtype=torch.uint8,
+        device=q.device,
+    )
+    # Flashinfer MLA expects the query to be of shape
+    # (bs, q_len_per_request, num_heads, qk_head_dim),
+    # where q_len_per_request is the MTP query length (=1 without MTP)
+    q = q.unsqueeze(1)
+
+    out_ans = trtllm_batch_decode_with_kv_cache_mla(
+        query=q,
+        kv_cache=kv_cache.unsqueeze(1),
+        workspace_buffer=workspace_buffer,
+        qk_nope_head_dim=qk_nope_head_dim,
+        kv_lora_rank=kv_lora_rank,
+        qk_rope_head_dim=qk_rope_head_dim,
+        block_tables=block_tables,
+        seq_lens=seq_lens_tensor,
+        max_seq_len=max_seq_len,
+        bmm1_scale=scale,
+    )
+    out_ans = out_ans.squeeze(1)
+    torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a4b5abbb35a3..f24c4257f396 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1505,6 +1505,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "FLASH_ATTN_MLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
+            "FLASHINFER_MLA",
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index fc1a399d6f43..2b02a27f2a18 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -228,6 +228,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
+            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
+                                 and cls.has_device_capability(100))
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
@@ -252,6 +254,19 @@ def _get_version(name, import_suffix) -> str:
             else:
                 logger.warning(
                     "Cutlass MLA backend is only supported on V1 engine")
+            if use_flashinfermla:
+                if use_v1:
+                    from vllm.v1.attention.backends.utils import (
+                        set_kv_cache_layout)
+                    set_kv_cache_layout("HND")
+                    logger.info_once(
+                        "Using FlashInfer MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "flashinfer_mla.FlashInferMLABackend")
+                else:
+                    logger.warning(
+                        "FlashInfer MLA backend is only supported on V1 engine"
+                    )
             if use_flashmla:
                 if block_size != 64:
                     logger.warning(
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 0cea49eece42..ab0eaa82ef20 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
     TORCH_SDPA_VLLM_V1 = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
+    FLASHINFER_MLA = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
     TRITON_MLA_VLLM_V1 = enum.auto()
     CUTLASS_MLA = enum.auto()
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 226bc436058d..440a206eb485 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -381,6 +381,7 @@ class MLACommonMetadata(Generic[D]):
     num_reqs: int
     max_query_len: int
+    max_seq_len: int
     num_actual_tokens: int  # Number of tokens excluding padding.
     query_start_loc: torch.Tensor
@@ -644,6 +645,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
+        max_seq_len = common_attn_metadata.max_seq_len
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
@@ -830,6 +832,7 @@ def build(self,
         attn_metadata = self.metadata_cls(
             num_reqs=common_attn_metadata.num_reqs,
             max_query_len=common_attn_metadata.max_query_len,
+            max_seq_len=max_seq_len,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
new file mode 100644
index 000000000000..71eb9e0ce70e
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional, Union
+
+import torch
+from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
+
+from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
+                                              is_quantized_kv_cache)
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                   MLACommonImpl,
+                                                   MLACommonMetadata)
+
+logger = init_logger(__name__)
+
+FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+
+
+class FlashInferMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "FLASHINFER_MLA"
+
+    @staticmethod
+    def get_impl_cls() -> type["FlashInferMLAImpl"]:
+        return FlashInferMLAImpl
+
+
+g_fi_workspace = torch.zeros(
+    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
+    dtype=torch.uint8,
+    device="cuda",
+)
+
+
+class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[list[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         logits_soft_cap, attn_type,
+                         kv_sharing_target_layer_name, **mla_args)
+
+        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "FlashInferMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashInferMLAImpl")
+
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashInferMLA V1 with FP8 KV cache not yet supported")
+
+        self._workspace_buffer = g_fi_workspace
+
+    def _forward_decode(
+        self,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: MLACommonMetadata,
+        layer: AttentionLayer,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
+        if isinstance(q, tuple):
+            q_nope, q_pe = q
+            q = torch.cat([q_nope, q_pe], dim=-1)
+
+        # trtllm API requires extra dimension q_len_per_request for MTP
+        q = q.unsqueeze(1)
+
+        o = trtllm_batch_decode_with_kv_cache_mla(
+            query=q,
+            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
+            workspace_buffer=self._workspace_buffer,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            block_tables=attn_metadata.decode.block_table,
+            seq_lens=attn_metadata.decode.seq_lens,
+            max_seq_len=attn_metadata.max_seq_len,
+            bmm1_scale=self.scale,
+        )
+
+        # TODO: Return LSE pending support from Flashinfer API:
+        # https://github.com/flashinfer-ai/flashinfer/pull/1566
+        return o, None