diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index e7170babb6c9..a44cad027a7b 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -137,8 +137,9 @@ Priority is **1 = highest** (tried first). | 1 | `FLASH_ATTN_MLA` | | 2 | `FLASHMLA` | | 3 | `FLASHINFER_MLA` | -| 4 | `TRITON_MLA` | -| 5 | `FLASHMLA_SPARSE` | +| 4 | `FLASHINFER_FA2_MLA` | +| 5 | `TRITON_MLA` | +| 6 | `FLASHMLA_SPARSE` | > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details. @@ -205,6 +206,7 @@ configuration. | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. | |---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------| | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | +| `FLASHINFER_FA2_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 32c0b9064275..51776d2343be 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -39,6 +39,7 @@ AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.FLASHINFER_FA2_MLA, AttentionBackendEnum.TRITON_MLA, ] @@ -47,6 +48,10 @@ BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA) +# Remove FLASHINFER_FA2_MLA on sm100 (not supported) +if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major == 10: + BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_FA2_MLA) + # Remove FLASH_ATTN_MLA from the list if not supported if not flash_attn_supports_mla(): BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d3d75d883196..d6dcc6361225 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -77,6 +77,7 @@ def _get_backend_priorities( AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASHINFER_MLA, + AttentionBackendEnum.FLASHINFER_FA2_MLA, AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.FLASHMLA_SPARSE, ] diff --git a/vllm/v1/attention/backends/mla/flashinfer_fa2_mla.py b/vllm/v1/attention/backends/mla/flashinfer_fa2_mla.py new file mode 100644 index 000000000000..391a6ffa6ebd --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_fa2_mla.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FlashInfer FA2 MLA Backend. + +Uses the FlashInfer BatchMLAPagedAttentionWrapper with the FA2 backend for +bf16 MLA decode. This backend requires plan()/run() API with CSR-format +page indices and supports returning LSE for DCP. +""" + +from dataclasses import dataclass +from typing import ClassVar + +import torch +from flashinfer import BatchMLAPagedAttentionWrapper + +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + QueryLenSupport, + get_mla_dims, +) +from vllm.platforms.interface import DeviceCapability +from vllm.v1.attention.backend import ( + AttentionCGSupport, + AttentionLayer, + AttentionType, +) +from vllm.v1.attention.backends.flashinfer import _copy_page_indices_kernel +from vllm.v1.attention.backends.utils import ( + KVCacheLayoutType, + get_per_layer_parameters, + infer_global_hyperparameters, +) + +logger = init_logger(__name__) + +FLASHINFER_FA2_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferFA2MLABackend(MLACommonBackend): + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + + @staticmethod + def get_name() -> str: + return "FLASHINFER_FA2_MLA" + + @staticmethod + def get_impl_cls() -> type["FlashInferFA2MLAImpl"]: + return FlashInferFA2MLAImpl + + @staticmethod + def get_builder_cls() -> type["FlashInferFA2MLAMetadataBuilder"]: + return FlashInferFA2MLAMetadataBuilder + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + # Supported on SM 8.x+ except SM 10.x (which has dedicated backends). + return capability.major != 10 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.model_config is not None: + hf_text_config = vllm_config.model_config.hf_text_config + qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) + if qk_nope_head_dim != 128: + return ( + f"FlashInfer FA2 MLA kernel requires " + f"qk_nope_head_dim == 128, but got {qk_nope_head_dim}" + ) + return None + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return None + + +@dataclass +class FlashInferFA2MLADecodeMetadata(MLACommonDecodeMetadata): + wrapper: BatchMLAPagedAttentionWrapper | None = None + + +class FlashInferFA2MLAMetadataBuilder( + MLACommonMetadataBuilder[MLACommonMetadata[FlashInferFA2MLADecodeMetadata]], +): + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + + def __init__(self, kv_cache_spec, layer_names, vllm_config, device, **kwargs): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, **kwargs) + + self._mla_dims = get_mla_dims(self.model_config) + + max_num_pages = ( + self.vllm_config.cache_config.num_gpu_blocks + if self.vllm_config.cache_config.num_gpu_blocks + else self.vllm_config.scheduler_config.max_num_seqs * 1024 + ) + + self._workspace_buffer = torch.zeros( + FLASHINFER_FA2_MLA_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device, + ) + # Max kv_indices buffer shared by all wrappers (kv_indices supports + # variable-length copy in plan(), so one buffer is fine). + self._kv_indices_buf = torch.zeros( + max_num_pages, dtype=torch.int32, device=self.device + ) + + # Per-batch-size wrappers for CUDA graph compatibility. + # Each wrapper has pre-allocated buffers sized exactly for its + # batch size so plan()'s in-place copy works correctly. + self._wrappers: dict[int, BatchMLAPagedAttentionWrapper] = {} + + # Pre-compute constant plan() parameters. + self._num_heads = self.num_heads * self.dcp_world_size + # Derive sm_scale from the model's attention layers so it includes + # any model-specific corrections (e.g. YaRN mscale for DeepSeek). + global_params = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) # type: ignore[type-abstract] + ) + self._sm_scale = global_params.sm_scale + + def _get_wrapper(self, batch_size: int) -> BatchMLAPagedAttentionWrapper: + wrapper = self._wrappers.get(batch_size) + if wrapper is None: + wrapper = BatchMLAPagedAttentionWrapper( + self._workspace_buffer, + use_cuda_graph=True, + qo_indptr=torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ), + kv_indptr=torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ), + kv_indices=self._kv_indices_buf, + kv_len_arr=torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ), + backend="auto", + ) + self._wrappers[batch_size] = wrapper + return wrapper + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_device: torch.Tensor, + max_seq_len: int, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ) -> FlashInferFA2MLADecodeMetadata: + num_decodes = seq_lens_device.shape[0] + + # Compute number of pages per request + num_blocks = ( + seq_lens_device.to(torch.int32) + self.page_size - 1 + ) // self.page_size + + # Build CSR-format kv_indptr + kv_indptr = torch.empty(num_decodes + 1, dtype=torch.int32, device=self.device) + kv_indptr[0] = 0 + torch.cumsum(num_blocks, dim=0, out=kv_indptr[1:]) + + # Flatten block_table into kv_indices via Triton kernel. + # The full pre-allocated buffer is passed rather than slicing to + # total_pages, since kv_indptr delimits the valid range and plan() + # copies only the entries within that range. + _copy_page_indices_kernel[(num_decodes,)]( + self._kv_indices_buf, + block_table_tensor, + block_table_tensor.stride(0), + kv_indptr, + BLOCK_SIZE=1024, + ) + + kv_lens = seq_lens_device.to(torch.int32) + + # Build qo_indptr + tokens_per_req = num_decode_tokens // num_decodes + qo_indptr = torch.arange( + 0, + num_decodes * tokens_per_req + 1, + tokens_per_req, + dtype=torch.int32, + device=self.device, + ) + + # Get or create wrapper for this batch size. + wrapper = self._get_wrapper(num_decodes) + wrapper.plan( + qo_indptr, + kv_indptr, + self._kv_indices_buf, + kv_lens, + self._num_heads, + self._mla_dims.kv_lora_rank, + self._mla_dims.qk_rope_head_dim, + self.page_size, + False, # causal=False for decode + self._sm_scale, + self.model_config.dtype, + self.model_config.dtype, + ) + + return FlashInferFA2MLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, + wrapper=wrapper, + ) + + +class FlashInferFA2MLAImpl( + MLACommonImpl[MLACommonMetadata[FlashInferFA2MLADecodeMetadata]] +): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferFA2MLAImpl does not support one of the " + "following: alibi_slopes, sliding_window, logits_soft_cap" + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferFA2MLAImpl" + ) + + # FA2 uses bf16 queries; no query quantization needed + self.supports_quant_query_input = False + + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata[FlashInferFA2MLADecodeMetadata], + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + assert attn_metadata.decode.wrapper is not None + + # Verify the wrapper's sm_scale (set during plan()) matches the + # model layer's scale. These must agree for correct attention output. + assert attn_metadata.decode.wrapper._sm_scale == self.scale, ( + f"FlashInfer FA2 MLA wrapper sm_scale " + f"({attn_metadata.decode.wrapper._sm_scale}) does not match " + f"model scale ({self.scale})." + ) + + # Split query into nope and rope components + if isinstance(q, tuple): + q_nope, q_pe = q + else: + q_nope = q[..., : self.kv_lora_rank] + q_pe = q[..., self.kv_lora_rank :] + + # Split cache into compressed KV and rope key components + ckv_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] + kpe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] + + output, lse = attn_metadata.decode.wrapper.run( + q_nope, + q_pe, + ckv_cache, + kpe_cache, + return_lse=True, + # Return LSE in base-e to match the DCP output-merging kernel. + return_lse_base_on_e=True, + ) + + return output, lse diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index 8e60551e2662..172ed20a5546 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -73,6 +73,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" ) FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + FLASHINFER_FA2_MLA = ( + "vllm.v1.attention.backends.mla.flashinfer_fa2_mla.FlashInferFA2MLABackend" + ) NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"