From 5ea3788a83cb85eb470fbbcfb8ead82feb57f86f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 28 Feb 2025 07:46:37 +0000 Subject: [PATCH 01/42] sync with amd, support v1 Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ce45b34f6435..42cb34a64bf3 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -234,6 +234,8 @@ except ImportError: flashinfer_available = False +is_hip = current_platform.is_rocm() + logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 From d0486319ece7349f94f55754ecf7b7fbb7feef6a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 21:06:33 +0000 Subject: [PATCH 02/42] fix IMA Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 42cb34a64bf3..b8f92f66eb9b 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -370,10 +370,28 @@ class MLACommonMetadata(Generic[D]): FlashInferPrefillMetadata, CudnnPrefillMetadata]] = None + # Computed in __post_init__ + prefill_query_start_loc: Optional[torch.Tensor] = None + prefill_max_query_len: Optional[int] = None + decode_seq_lens: Optional[torch.Tensor] = None + decode_block_table: Optional[torch.Tensor] = None + def __post_init__(self): if self.head_dim is not None: MLACommonBackend.validate_head_size(self.head_dim) + # Pre-compute prefill/decode tensor slices and other stats + if self.num_prefills is not None and self.num_prefills > 0: + assert self.num_decodes is not None and self.num_decodes > 0 + start = self.num_decodes # prefill_start + self.prefill_query_start_loc = \ + self.query_start_loc[start:] - self.query_start_loc[start] + self.prefill_max_query_len = self.seq_lens[start:].max().item() + + if self.num_decodes is not None and self.num_decodes > 0: + self.decode_seq_lens = self.seq_lens[:self.num_decodes] + self.decode_block_table = self.block_table[:self.num_decodes, ...] + M = TypeVar("M", bound=MLACommonMetadata) From e09841c7d6a0383376e8d88c8a13180d1a10f003 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 21:25:11 +0000 Subject: [PATCH 03/42] bugfix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b8f92f66eb9b..3a221b97116a 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -382,7 +382,7 @@ def __post_init__(self): # Pre-compute prefill/decode tensor slices and other stats if self.num_prefills is not None and self.num_prefills > 0: - assert self.num_decodes is not None and self.num_decodes > 0 + assert self.num_decodes is not None start = self.num_decodes # prefill_start self.prefill_query_start_loc = \ self.query_start_loc[start:] - self.query_start_loc[start] From 52e72346da9298c34b0276f6842fc1ec167e48d7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 22:51:57 +0000 Subject: [PATCH 04/42] working Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 3a221b97116a..ea5ed96a7d2e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -370,9 +370,11 @@ class MLACommonMetadata(Generic[D]): FlashInferPrefillMetadata, CudnnPrefillMetadata]] = None - # Computed in __post_init__ + # New for MLA (compared to FlashAttention) + # For handling prefill decode split prefill_query_start_loc: Optional[torch.Tensor] = None prefill_max_query_len: Optional[int] = None + prefill_block_table: Optional[torch.Tensor] = None decode_seq_lens: Optional[torch.Tensor] = None decode_block_table: Optional[torch.Tensor] = None @@ -387,6 +389,7 @@ def __post_init__(self): self.prefill_query_start_loc = \ self.query_start_loc[start:] - self.query_start_loc[start] self.prefill_max_query_len = self.seq_lens[start:].max().item() + self.prefill_block_table = self.block_table[start:, ...] if self.num_decodes is not None and self.num_decodes > 0: self.decode_seq_lens = self.seq_lens[:self.num_decodes] From 46042018c39c778ebf8c6810c4536228eeeae1e3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 02:14:44 +0000 Subject: [PATCH 05/42] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 48 ++++++++++++++---------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ea5ed96a7d2e..653d2c6dc21e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -370,31 +370,39 @@ class MLACommonMetadata(Generic[D]): FlashInferPrefillMetadata, CudnnPrefillMetadata]] = None - # New for MLA (compared to FlashAttention) - # For handling prefill decode split - prefill_query_start_loc: Optional[torch.Tensor] = None - prefill_max_query_len: Optional[int] = None - prefill_block_table: Optional[torch.Tensor] = None - decode_seq_lens: Optional[torch.Tensor] = None - decode_block_table: Optional[torch.Tensor] = None + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + @dataclass + class DecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + + decode: Optional[DecodeMetadata] = None + prefill: Optional[PrefillMetadata] = None def __post_init__(self): if self.head_dim is not None: MLACommonBackend.validate_head_size(self.head_dim) - # Pre-compute prefill/decode tensor slices and other stats - if self.num_prefills is not None and self.num_prefills > 0: - assert self.num_decodes is not None - start = self.num_decodes # prefill_start - self.prefill_query_start_loc = \ - self.query_start_loc[start:] - self.query_start_loc[start] - self.prefill_max_query_len = self.seq_lens[start:].max().item() - self.prefill_block_table = self.block_table[start:, ...] - - if self.num_decodes is not None and self.num_decodes > 0: - self.decode_seq_lens = self.seq_lens[:self.num_decodes] - self.decode_block_table = self.block_table[:self.num_decodes, ...] - M = TypeVar("M", bound=MLACommonMetadata) From f9f3e3e7e24f2e5e9785741583aadc7028b18a97 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 05:12:07 +0000 Subject: [PATCH 06/42] fa MLA Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/flashattn_mla.py | 100 ++++++++++++++++++ vllm/attention/backends/{ => mla}/flashmla.py | 0 .../backends/{ => mla}/triton_mla.py | 0 vllm/attention/backends/utils.py | 1 - vllm/attention/utils/fa_utils.py | 12 +++ 5 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 vllm/attention/backends/mla/flashattn_mla.py rename vllm/attention/backends/{ => mla}/flashmla.py (100%) rename vllm/attention/backends/{ => mla}/triton_mla.py (100%) diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/mla/flashattn_mla.py new file mode 100644 index 000000000000..ce9abdd04e1f --- /dev/null +++ b/vllm/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) +from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla +from vllm.vllm_flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + pass + + +class FlashAttnMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHATTN_MLA" + + @staticmethod + def get_impl_cls() -> Type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert flash_attn_supports_mla(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + o = flash_attn_varlen_func( + q=q_pe, + k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + seqused_k=decode_meta.seq_lens_tensor, + block_table=decode_meta.block_tables, + softmax_scale=self.scale, + causal=True, + fa_version=3 # only version 3 is supported + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/mla/flashmla.py similarity index 100% rename from vllm/attention/backends/flashmla.py rename to vllm/attention/backends/mla/flashmla.py diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/mla/triton_mla.py similarity index 100% rename from vllm/attention/backends/triton_mla.py rename to vllm/attention/backends/mla/triton_mla.py diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 34e059067d84..fa35b6e736a8 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -600,7 +600,6 @@ class MLADims: def get_mla_dims(model_config: ModelConfig) -> MLADims: hf_text_config = model_config.hf_text_config - return MLADims( q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), kv_lora_rank=hf_text_config.kv_lora_rank, diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f051..2f2804e3b858 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -68,5 +68,17 @@ def flash_attn_supports_fp8() -> bool: current_platform.get_device_capability().major == 9 +def flash_attn_supports_mla(): + if current_platform.is_cuda(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported) + return is_fa_version_supported(3) \ + and current_platform.get_device_capability()[0] == 9 + except (ImportError, AssertionError): + pass + return False + + def is_flash_attn_varlen_func_available() -> bool: return current_platform.is_cuda() or current_platform.is_xpu() From b3b060bb160214ec61da80c6242f848a8fb3b6b9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Mar 2025 19:25:56 +0000 Subject: [PATCH 07/42] commit wip Signed-off-by: Lucas Wilkinson --- .../attention/backends/mla/flashattn_mla.py | 145 ++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 vllm/v1/attention/backends/mla/flashattn_mla.py diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py new file mode 100644 index 000000000000..c36f720b57e7 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.utils import is_flash_attn_mla_supported +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.vllm_flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHATTN_MLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + pass + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder( + MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> FlashAttnMLADecodeMetadata: + # + + return FlashAttnMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flash_attn_mla_supported(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + decode_meta = attn_metadata.decode + assert decode_meta is not None + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + o = flash_attn_varlen_func( + q=q_pe, + k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + seqused_k=decode_meta.seq_lens_tensor, + block_table=decode_meta.block_tables, + softmax_scale=self.scale, + causal=True, + fa_version=3 # only version 3 is supported + ) + + return self._v_up_proj_and_o_proj(o) From e286de85eb44606897f2059c1ec3c9bb41fedee8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 18 Apr 2025 21:03:33 +0000 Subject: [PATCH 08/42] cleanup Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 242 ++++++++++++++++++ .../{mla/flashattn_mla.py => triton_mla.py} | 83 +++--- vllm/attention/utils/fa_utils.py | 1 + vllm/platforms/cuda.py | 47 ++-- vllm/v1/attention/backends/mla/common.py | 49 ++-- .../attention/backends/mla/flashattn_mla.py | 35 ++- vllm/v1/attention/backends/mla/flashmla.py | 11 +- 7 files changed, 371 insertions(+), 97 deletions(-) create mode 100644 vllm/attention/backends/flashmla.py rename vllm/attention/backends/{mla/flashattn_mla.py => triton_mla.py} (52%) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 000000000000..5d0c23093310 --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? + if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashMLA with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/triton_mla.py similarity index 52% rename from vllm/attention/backends/mla/flashattn_mla.py rename to vllm/attention/backends/triton_mla.py index ce9abdd04e1f..61e5c76d9fda 100644 --- a/vllm/attention/backends/mla/flashattn_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,32 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type import torch -from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata) -from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -if TYPE_CHECKING: - pass - -class FlashAttnMLABackend(MLACommonBackend): +class TritonMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: - return "FLASHATTN_MLA" + return "TRITON_MLA" @staticmethod - def get_impl_cls() -> Type["FlashAttnMLAImpl"]: - return FlashAttnMLAImpl + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl -class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( self, @@ -47,15 +44,12 @@ def __init__( blocksparse_params, logits_soft_cap, attn_type, **mla_args) - assert flash_attn_supports_mla(), \ - "FlashAttnMLA is not supported on this device" - unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap ] if any(unsupported_features): raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " + "TritonMLAImpl does not support one of the following: " "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") @@ -63,7 +57,11 @@ def __init__( raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " - "FlashMLAImpl") + "TritonMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "TritonMLA with FP8 KV cache not yet supported") def _forward_decode( self, @@ -73,28 +71,43 @@ def _forward_decode( attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 FlashMLA not yet supported") decode_meta = attn_metadata.decode_metadata assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + num_kv_splits = 4 # TODO: heuristic + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] - - o = flash_attn_varlen_func( - q=q_pe, - k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 - v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 - q_v=q_nope, - max_seqlen_q=decode_meta.max_decode_query_len, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_k=decode_meta.max_decode_seq_len, - seqused_k=decode_meta.seq_lens_tensor, - block_table=decode_meta.block_tables, - softmax_scale=self.scale, - causal=True, - fa_version=3 # only version 3 is supported - ) + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 2f2804e3b858..dc0af7e28e3e 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -69,6 +69,7 @@ def flash_attn_supports_fp8() -> bool: def flash_attn_supports_mla(): + from vllm.platforms import current_platform if current_platform.is_cuda(): try: from vllm.vllm_flash_attn.flash_attn_interface import ( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 134ba36e5e73..c2af930c89ae 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -223,9 +223,21 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA or ( - cls.is_device_capability(100) and selected_backend is None - and block_size == 128): + + from vllm.attention.backends.flashmla import is_flashmla_supported + from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla + + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size == 128) + use_flashattn = selected_backend == _Backend.FLASH_ATTN or ( + selected_backend is None and flash_attn_supports_mla()) + use_flashmla = selected_backend == _Backend.FLASHMLA or ( + selected_backend is None and is_flash_mla_supported()[0]) + use_triton = selected_backend == _Backend.TRITON_MLA or ( + selected_backend is None) + + if use_cutlassmla: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." @@ -233,22 +245,18 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") - if selected_backend == _Backend.TRITON_MLA or block_size != 64: + if use_flashattn: if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") + "flashattn_mla.FlashAttnMLABackend") else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - else: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - if not is_flashmla_supported()[0]: logger.warning( - "FlashMLA backend is not supported due to %s", - is_flashmla_supported()[1]) - elif block_size != 64: + "FlashAttention MLA backend is only supported on " + "V1 engine.") + if use_flashmla: + if block_size != 64: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", @@ -263,6 +271,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using FlashMLA backend.") return ("vllm.attention.backends." "flashmla.FlashMLABackend") + if use_triton: + if use_v1: + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" + if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 653d2c6dc21e..6a3c03ecb28a 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -370,35 +370,6 @@ class MLACommonMetadata(Generic[D]): FlashInferPrefillMetadata, CudnnPrefillMetadata]] = None - @dataclass - class ChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - query_start_loc: torch.Tensor - max_query_len: int - chunked_context: Optional[ChunkedContextMetadata] = None - - @dataclass - class DecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - - decode: Optional[DecodeMetadata] = None - prefill: Optional[PrefillMetadata] = None - def __post_init__(self): if self.head_dim is not None: MLACommonBackend.validate_head_size(self.head_dim) @@ -432,6 +403,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): understand this class """ reorder_batch_threshold: ClassVar[int] = 1 + decode_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, @@ -585,11 +557,17 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + def _build_decode(self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + input_positions: torch.Tensor): return MLACommonDecodeMetadata( + input_positions=input_positions, block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( @@ -621,6 +599,7 @@ def build(self, device = self.device block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + device = self.runner.device query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -733,6 +712,12 @@ def build(self, decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens=seq_lens[:num_decodes], + query_start_loc_device=query_start_loc[:self._num_decodes + 1], + query_start_loc_cpu=query_start_loc_cpu[:self._num_decodes + + 1], + seq_lens_device=seq_lens[:self._num_decodes], + seq_lens_cpu=seq_lens_cpu[:self._num_decodes], + input_positions=input_positions[:self._num_decode_tokens], ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c36f720b57e7..80a1f3b21ac0 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,7 +6,6 @@ import torch from vllm.attention.backends.abstract import AttentionType -from vllm.attention.backends.utils import is_flash_attn_mla_supported from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -14,6 +13,7 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla if TYPE_CHECKING: pass @@ -21,7 +21,7 @@ logger = init_logger(__name__) -class FlashMLABackend(MLACommonBackend): +class FlashAttnMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -42,7 +42,9 @@ def get_impl_cls() -> type["FlashAttnMLAImpl"]: @dataclass class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): - pass + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int @dataclass @@ -59,15 +61,24 @@ def __init__(self, runner): self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, - seq_lens: torch.Tensor) -> FlashAttnMLADecodeMetadata: - # + def _build_decode(self, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + input_positions: torch.Tensor, + block_table: torch.Tensor) -> FlashAttnMLADecodeMetadata: + + query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() return FlashAttnMLADecodeMetadata( input_positions=input_positions, block_table=block_table, - seq_lens=seq_lens, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, ) @@ -92,7 +103,7 @@ def __init__( blocksparse_params, logits_soft_cap, attn_type, **mla_args) - assert is_flash_attn_mla_supported(), \ + assert flash_attn_supports_mla(), \ "FlashAttnMLA is not supported on this device" unsupported_features = [ @@ -132,10 +143,10 @@ def _forward_decode( k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, - max_seqlen_q=decode_meta.max_decode_query_len, + max_seqlen_q=decode_meta.max_query_len, cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_k=decode_meta.max_decode_seq_len, - seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_seq_len, + seqused_k=decode_meta.seq_lens, block_table=decode_meta.block_tables, softmax_scale=self.scale, causal=True, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1c50144d4790..430eac91c21a 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -86,10 +86,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32) def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + input_positions: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( - seq_lens, + seq_lens_device, self.num_q_heads, 1, # MQA for the decode path ) @@ -123,7 +127,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, return FlashMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, + input_positions=input_positions, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, ) From 27a2cd2545c78b9235fd56d991bbb941d990189c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 20 Apr 2025 03:21:14 +0000 Subject: [PATCH 09/42] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 80a1f3b21ac0..4824b1526a81 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -147,7 +147,7 @@ def _forward_decode( cu_seqlens_q=decode_meta.query_start_loc, max_seqlen_k=decode_meta.max_seq_len, seqused_k=decode_meta.seq_lens, - block_table=decode_meta.block_tables, + block_table=decode_meta.block_table, softmax_scale=self.scale, causal=True, fa_version=3 # only version 3 is supported From 9165af3bf8020c3c3783a1f5bb9e5e2ba420e85c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 20 Apr 2025 03:25:27 +0000 Subject: [PATCH 10/42] move files Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 242 -------------------------- vllm/attention/backends/triton_mla.py | 113 ------------ 2 files changed, 355 deletions(-) delete mode 100644 vllm/attention/backends/flashmla.py delete mode 100644 vllm/attention/backends/triton_mla.py diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py deleted file mode 100644 index 5d0c23093310..000000000000 --- a/vllm/attention/backends/flashmla.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - - -class FlashMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "FLASHMLA" - - @staticmethod - def get_impl_cls() -> Type["FlashMLAImpl"]: - return FlashMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["FlashMLAMetadata"]: - return FlashMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: - return FlashMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashMLAState"]: - return FlashMLAState - - -@dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - # TODO: cache assignment? - if decode_metadata is not None: - decode_metadata.decode_tile_scheduler_metadata=\ - self.decode_tile_scheduler_metadata - decode_metadata.decode_num_splits=\ - self.decode_num_splits - return decode_metadata - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - raise NotImplementedError( - "advance_step is not implemented for FlashMLA") - - -class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - m = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - - if m.num_decode_tokens > 0: - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ - get_mla_metadata( - m.seq_lens_tensor[m.num_prefills:], - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m - - -class FlashMLAState(MLACommonState[FlashMLAMetadata]): - - def __init__(self, *args, **kwds): - super().__init__(*args, **kwds) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - @contextmanager - def graph_capture(self, max_batch_size: int): - # Run a dummy `get_mla_metadata` so we can get the right shapes - self._graph_decoder_tile_scheduler_metadata, \ - self._graph_decode_num_splits = get_mla_metadata( - torch.ones( - max_batch_size, dtype=torch.int32, device=self.runner.device), - self.num_q_heads, - 1, # MQA for the decode path - ) - - with super().graph_capture(max_batch_size): - yield - - del self._graph_decoder_tile_scheduler_metadata - del self._graph_decode_num_splits - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - assert metadata.num_decode_tokens > 0 - - decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( - self._graph_seq_lens[:batch_size], - self.num_q_heads, - 1, # MQA for the decode path - ) - - self._graph_decoder_tile_scheduler_metadata.copy_( - decoder_tile_scheduler_metadata) - self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) - - metadata.decode_tile_scheduler_metadata=\ - self._graph_decoder_tile_scheduler_metadata - metadata.decode_num_splits=\ - self._graph_decode_num_splits[:batch_size + 1] - - return metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers["decode_tile_scheduler_metadata"] = \ - attn_metadata.decode_metadata.decode_tile_scheduler_metadata - input_buffers["decode_num_splits"] = \ - attn_metadata.decode_metadata.decode_num_splits - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - input_buffers["decode_tile_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_tile_scheduler_metadata) - input_buffers["decode_num_splits"].copy_( - attn_metadata.decode_metadata.decode_num_splits) - - -class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, - **mla_args) - - assert is_flashmla_supported(), \ - "FlashMLA is not supported on this device" - - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] - if any(unsupported_features): - raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) - - o, _ = flash_mla_with_kvcache( - q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, - num_splits=decode_meta.decode_num_splits, - softmax_scale=self.scale, - causal=True, - ) - - return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py deleted file mode 100644 index 61e5c76d9fda..000000000000 --- a/vllm/attention/backends/triton_mla.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Dict, List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) -from vllm.attention.ops.triton_decode_attention import decode_attention_fwd - - -class TritonMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_impl_cls() -> Type["TritonMLAImpl"]: - return TritonMLAImpl - - -class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, - **mla_args) - - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] - if any(unsupported_features): - raise NotImplementedError( - "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - num_kv_splits = 4 # TODO: heuristic - - # TODO(lucas) Allocate ahead of time - attn_logits = torch.empty( - ( - B, - self.num_heads, - num_kv_splits, - # NOTE(lucas) idk why the +1 is here but sglang has it so we - # just mirror that - self.kv_lora_rank + 1, - ), - dtype=torch.float32, - device=q.device, - ) - - # Add a head dim of 1 - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - - # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) - - return self._v_up_proj_and_o_proj(o) From d056efdbfba15bb3a88352791b787008a7466403 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sun, 20 Apr 2025 19:28:31 +0000 Subject: [PATCH 11/42] fix up Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/flashattn_mla.py | 182 ++++++++++++++++++ vllm/platforms/cuda.py | 4 +- vllm/v1/attention/backends/mla/common.py | 74 ++++++- .../attention/backends/mla/flashattn_mla.py | 35 +++- 4 files changed, 280 insertions(+), 15 deletions(-) create mode 100644 vllm/attention/backends/mla/flashattn_mla.py diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/mla/flashattn_mla.py new file mode 100644 index 000000000000..65064043255c --- /dev/null +++ b/vllm/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.logger import init_logger +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata +from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class FlashAttnMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHATTN_MLA_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder( + MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None + + def _build_decode(self, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + input_positions: torch.Tensor, + block_table: torch.Tensor) -> FlashAttnMLADecodeMetadata: + + query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() + + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + + return FlashAttnMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert flash_attn_supports_mla(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + decode_meta = attn_metadata.decode + assert decode_meta is not None + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + o = flash_attn_varlen_func( + q=q_pe, + k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=decode_meta.max_query_len, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_k=decode_meta.max_seq_len, + seqused_k=decode_meta.seq_lens, + block_table=decode_meta.block_table, + softmax_scale=self.scale, + causal=True, + scheduler_metadata=decode_meta.scheduler_metadata, + fa_version=3 # only version 3 is supported + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c2af930c89ae..bbdc011f6e48 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -224,7 +224,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - from vllm.attention.backends.flashmla import is_flashmla_supported + from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( @@ -233,7 +233,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_flashattn = selected_backend == _Backend.FLASH_ATTN or ( selected_backend is None and flash_attn_supports_mla()) use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flash_mla_supported()[0]) + selected_backend is None and is_flashmla_supported()[0]) use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 6a3c03ecb28a..7108e6412560 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -218,7 +218,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec try: - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -300,10 +301,12 @@ class ChunkedContextMetadata: max_seq_lens: list[int] seq_lens: torch.Tensor workspace: torch.Tensor + scheduler_metatadata: list[Optional[torch.Tensor]] block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int + scheduler_metadata: Optional[torch.Tensor] = None chunked_context: Optional[ChunkedContextMetadata] = None @@ -585,7 +588,27 @@ def build_for_cudagraph_capture( return self.build(0, m) + def _schedule_prefill(self, num_reqs, cu_query_lens, max_query_len, + seqlens, max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads, + num_heads_kv=self.num_heads, + headdim=self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.v_head_dim, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None + def build(self, + num_actual_tokens: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> M: @@ -593,6 +616,8 @@ def build(self, num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + assert self._num_decodes + self._num_prefills == num_reqs + # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. @@ -673,19 +698,28 @@ def build(self, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) - chunked_context_metadata_cls = \ - CudnnPrefillMetadata.ChunkedContextMetadata \ - if self._use_cudnn_prefill else \ - MLACommonPrefillMetadata.ChunkedContextMetadata + chunks_max_seq_lens = chunk_seq_lens.max(dim=1).values + + chunks_scheduler_metadata = [] + for i in range(num_chunks): + chunks_scheduler_metadata.append( + self._schedule_prefill( + self._num_prefills, + cu_seq_lens_cpu[i], + max_query_len, + chunk_seq_lens[i], + chunks_max_seq_lens[i], + causal=False, + )) chunked_context_metadata = \ chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, + max_seq_lens=chunks_max_seq_lens, workspace=self.chunked_prefill_workspace, + scheduler_metatadata=chunks_scheduler_metadata, ) if self._use_cudnn_prefill: @@ -694,11 +728,22 @@ def build(self, assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - prefill_metadata = self.prefill_metadata_cls( - block_table=block_table_tensor[reqs_start:, ...], + scheduler_metadata = self._schedule_prefill( + self._num_prefills, + prefill_query_start_loc, + max_query_len, + prefill_query_start_loc, + max_query_len, + causal=True, + ) + + prefill_metadata = MLACommonPrefillMetadata( + input_positions=input_positions[tokens_start:], + block_table=block_table[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, + scheduler_metadata=scheduler_metadata, ) if self._use_cudnn_prefill: @@ -827,6 +872,7 @@ def _flash_attn_varlen_diff_headdims(self, v, return_softmax_lse=False, softmax_scale=None, + scheduler_metadata=None, **kwargs): maybe_padded_v = v if self._pad_v: @@ -835,6 +881,7 @@ def _flash_attn_varlen_diff_headdims(self, if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse + kwargs["scheduler_metadata"] = scheduler_metadata else: # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse @@ -1065,6 +1112,14 @@ def _compute_prefill_context( q=q, k=k, v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + scheduler_metadata=prefill_metadata.chunked_context.scheduler_metatadata[i], ) if output is None: @@ -1111,6 +1166,7 @@ def _forward_prefill( k=k, v=v, return_softmax_lse=has_context, + scheduler_metadata=attn_metadata.prefill.scheduler_metadata, ) if has_context: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 4824b1526a81..a24e08e0919e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -12,7 +12,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla if TYPE_CHECKING: @@ -45,6 +45,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): query_start_loc: torch.Tensor max_query_len: int max_seq_len: int + scheduler_metadata: Optional[torch.Tensor] = None @dataclass @@ -58,8 +59,23 @@ class FlashAttnMLAMetadataBuilder( def __init__(self, runner): super().__init__(runner) - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None def _build_decode(self, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, @@ -72,6 +88,15 @@ def _build_decode(self, seq_lens_cpu: torch.Tensor, max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + return FlashAttnMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -79,6 +104,7 @@ def _build_decode(self, seq_lens_cpu: torch.Tensor, query_start_loc=query_start_loc_device, max_query_len=max_query_len, max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, ) @@ -150,7 +176,8 @@ def _forward_decode( block_table=decode_meta.block_table, softmax_scale=self.scale, causal=True, - fa_version=3 # only version 3 is supported + fa_version=3, # only version 3 is supported + scheduler_metadata=decode_meta.scheduler_metadata, ) return self._v_up_proj_and_o_proj(o) From ac4c6246e124e0723622eb063c4a85367efcda9a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Apr 2025 19:08:33 +0000 Subject: [PATCH 12/42] v0 support + decode threshold Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 2 + vllm/attention/backends/mla/flashattn_mla.py | 189 ++++++++++++------ vllm/platforms/cuda.py | 44 ++-- .../attention/backends/mla/flashattn_mla.py | 2 + 4 files changed, 154 insertions(+), 83 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index c5ed4c6e4032..a2d024104d95 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -654,6 +654,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): attn_state.context_chunk_workspace_size self.page_size = self.runner.block_size + self.fa_aot_schedule = (get_flash_attn_version() == 3) + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/mla/flashattn_mla.py index 65064043255c..222254394736 100644 --- a/vllm/attention/backends/mla/flashattn_mla.py +++ b/vllm/attention/backends/mla/flashattn_mla.py @@ -1,22 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, List, Optional import torch from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, - MLACommonMetadataBuilder) + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import get_mla_dims from vllm.logger import init_logger from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla +from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) if TYPE_CHECKING: - pass + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata logger = init_logger(__name__) @@ -25,7 +28,7 @@ class FlashAttnMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: - return "FLASHATTN_MLA_VLLM_V1" + return "FLASHATTN_MLA" @staticmethod def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: @@ -41,70 +44,142 @@ def get_impl_cls() -> type["FlashAttnMLAImpl"]: @dataclass -class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): - query_start_loc: torch.Tensor - max_query_len: int - max_seq_len: int - - -@dataclass -class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): - pass +class FlashAttnMLAMetadata(MLACommonMetadata): + decode_scheduler_metadata: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? + if decode_metadata is not None: + decode_metadata.decode_scheduler_metadata=\ + self.decode_scheduler_metadata + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashAttnMLA") class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_heads_q = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.mla_dims = get_mla_dims(self.runner.model_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + decode_cu_seqlens_q = m.query_start_loc[ + m.num_prefills:] - m.query_start_loc[m.num_prefills] + + if m.num_decode_tokens > 0 and self.fa_aot_schedule: + m.decode_scheduler_metadata = get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=m.max_decode_query_len, + max_seqlen_k=m.max_decode_seq_len, + cache_seqlens=m.seq_start_loc[m.num_prefills:], + num_heads_q=self.num_heads_q, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=decode_cu_seqlens_q, + causal=True) + return m + + +class FlashAttnMLAState(MLACommonState[FlashAttnMLAMetadata]): - def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.num_heads_q = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + self.mla_dims = get_mla_dims(self.runner.model_config) + self.page_size = self.runner.block_size + + def _dummy_scheduler_metadata(self, max_batch_size: int): if self.fa_aot_schedule: return get_scheduler_metadata( - batch_size=num_reqs, - max_seqlen_q=max_query_len, - max_seqlen_k=max_seq_len, - cache_seqlens=seqlens, - num_heads_q=self.num_heads, + batch_size=max_batch_size, + max_seqlen_q=1, + max_seqlen_k=1, + cache_seqlens=torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device), + num_heads_q=self.num_heads_q, num_heads_kv=1, headdim=self.mla_dims.qk_rope_head_dim, headdim_v=self.mla_dims.kv_lora_rank, page_size=self.page_size, - cu_seqlens_q=cu_query_lens, - causal=causal, - ) + cu_seqlens_q=torch.arange(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device), + causal=True) return None - def _build_decode(self, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - input_positions: torch.Tensor, - block_table: torch.Tensor) -> FlashAttnMLADecodeMetadata: - - query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) - max_query_len = query_lens_cpu.max().item() - max_seq_len = seq_lens_cpu.max().item() - - scheduler_metadata = self._schedule_decode( - num_reqs=seq_lens_cpu.numel(), - cu_query_lens=query_start_loc_device, - max_query_len=max_query_len, - seqlens=seq_lens_device, - max_seq_len=max_seq_len, - causal=True, - ) + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_scheduler_metadata` so we can get the right shapes + self._graph_scheduler_metadata = self._dummy_scheduler_metadata( + max_batch_size) - return FlashAttnMLADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens_device, - query_start_loc=query_start_loc_device, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - scheduler_metadata=scheduler_metadata, - ) + with super().graph_capture(max_batch_size): + yield + + del self._graph_scheduler_metadata + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_scheduler_metadata = self._dummy_scheduler_metadata(batch_size) + + metadata_size = decoder_scheduler_metadata.numel() + self._graph_scheduler_metadata[:metadata_size].copy_( + decoder_scheduler_metadata) + + metadata.decode_scheduler_metadata=\ + self._graph_scheduler_metadata[:metadata_size] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_scheduler_metadata + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_scheduler_metadata) class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): @@ -136,7 +211,7 @@ def __init__( ] if any(unsupported_features): raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " + "FlashAttnMLAImpl does not support one of the following: " "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") @@ -175,8 +250,8 @@ def _forward_decode( block_table=decode_meta.block_table, softmax_scale=self.scale, causal=True, + fa_version=3, # only version 3 is supported scheduler_metadata=decode_meta.scheduler_metadata, - fa_version=3 # only version 3 is supported ) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bbdc011f6e48..60b0326e68b5 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -237,6 +237,19 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) + print("use_cutlassmla", use_cutlassmla) + print("use_flashmla", use_flashmla) + print("use_flashattn", use_flashattn) + print("use_triton", use_triton) + + def _get_version(name, import_suffix) -> str: + if use_v1: + logger.info_once("Using %s backend on V1 engine.", name) + return f"vllm.v1.attention.backends.mla.{import_suffix}" + else: + logger.info("Using %s backend.", name) + return f"vllm.attention.backends.mla.{import_suffix}" + if use_cutlassmla: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") @@ -246,15 +259,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.warning( "Cutlass MLA backend is only supported on V1 engine") if use_flashattn: - if use_v1: - logger.info_once( - "Using FlashAttention MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashattn_mla.FlashAttnMLABackend") - else: - logger.warning( - "FlashAttention MLA backend is only supported on " - "V1 engine.") + return _get_version("FlashAttention MLA", + "flashattn_mla.FlashAttnMLABackend") if use_flashmla: if block_size != 64: logger.warning( @@ -262,24 +268,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, " (currently only supports block size 64).", block_size) else: - if use_v1: - logger.info_once( - "Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") - else: - logger.info("Using FlashMLA backend.") - return ("vllm.attention.backends." - "flashmla.FlashMLABackend") + return _get_version("FlashMLA", "flashmla.FlashMLABackend") if use_triton: - if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - + return _get_version("Triton MLA", + "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index a24e08e0919e..b77426ad3ae0 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -55,6 +55,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + # TODO(lucas): tune this value + decode_threshold: int = 64 def __init__(self, runner): super().__init__(runner) From 1f6bb3d9260bf17cb47acfa29e1970e3742877ca Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Apr 2025 21:13:16 +0000 Subject: [PATCH 13/42] v0 fix Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 4 +-- vllm/attention/backends/mla/flashattn_mla.py | 26 +++++++++++++++----- vllm/platforms/cuda.py | 5 ---- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index a2d024104d95..3b6b7308d325 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -786,8 +786,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * - cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + query_start_loc.extend([query_start_loc[-1]] * cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/mla/flashattn_mla.py index 222254394736..1f60764a21e3 100644 --- a/vllm/attention/backends/mla/flashattn_mla.py +++ b/vllm/attention/backends/mla/flashattn_mla.py @@ -42,6 +42,10 @@ def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashAttnMLAImpl"]: return FlashAttnMLAImpl + @staticmethod + def get_state_cls() -> type["FlashAttnMLAState"]: + return FlashAttnMLAState + @dataclass class FlashAttnMLAMetadata(MLACommonMetadata): @@ -77,6 +81,7 @@ def __init__(self, *args, **kwargs): self.runner.parallel_config) self.fa_aot_schedule = (get_flash_attn_version() == 3) self.mla_dims = get_mla_dims(self.runner.model_config) + self.page_size = self.runner.block_size def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): @@ -138,6 +143,9 @@ def graph_capture(self, max_batch_size: int): # Run a dummy `get_scheduler_metadata` so we can get the right shapes self._graph_scheduler_metadata = self._dummy_scheduler_metadata( max_batch_size) + self._graph_query_start_loc = torch.arange(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) with super().graph_capture(max_batch_size): yield @@ -158,6 +166,8 @@ def graph_capture_get_metadata_for_batch( metadata.decode_scheduler_metadata=\ self._graph_scheduler_metadata[:metadata_size] + metadata.query_start_loc=\ + self._graph_query_start_loc[:batch_size + 1] return metadata @@ -168,6 +178,8 @@ def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model) input_buffers["decode_scheduler_metadata"] = \ attn_metadata.decode_metadata.decode_scheduler_metadata + input_buffers["query_start_loc"] = \ + attn_metadata.decode_metadata.query_start_loc return input_buffers @@ -180,6 +192,8 @@ def prepare_graph_input_buffers(self, input_buffers["decode_scheduler_metadata"].copy_( attn_metadata.decode_metadata.decode_scheduler_metadata) + input_buffers["query_start_loc"].copy_( + attn_metadata.decode_metadata.query_start_loc) class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): @@ -232,7 +246,7 @@ def _forward_decode( if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 FlashMLA not yet supported") - decode_meta = attn_metadata.decode + decode_meta = attn_metadata.decode_metadata assert decode_meta is not None kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] @@ -243,15 +257,15 @@ def _forward_decode( k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, - max_seqlen_q=decode_meta.max_query_len, + max_seqlen_q=decode_meta.max_decode_query_len, cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_k=decode_meta.max_seq_len, - seqused_k=decode_meta.seq_lens, - block_table=decode_meta.block_table, + max_seqlen_k=decode_meta.max_decode_seq_len, + seqused_k=decode_meta.seq_lens_tensor, + block_table=decode_meta.block_tables, softmax_scale=self.scale, causal=True, fa_version=3, # only version 3 is supported - scheduler_metadata=decode_meta.scheduler_metadata, + scheduler_metadata=decode_meta.decode_scheduler_metadata, ) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 60b0326e68b5..b12b40a68e79 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -237,11 +237,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) - print("use_cutlassmla", use_cutlassmla) - print("use_flashmla", use_flashmla) - print("use_flashattn", use_flashattn) - print("use_triton", use_triton) - def _get_version(name, import_suffix) -> str: if use_v1: logger.info_once("Using %s backend on V1 engine.", name) From 73c87366b8a415422a18ba6981af42a4e64859cf Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Apr 2025 21:22:05 +0000 Subject: [PATCH 14/42] fix Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 3b6b7308d325..536eb52d80cd 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -654,8 +654,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): attn_state.context_chunk_workspace_size self.page_size = self.runner.block_size - self.fa_aot_schedule = (get_flash_attn_version() == 3) - def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] From 0f9ed95b973dafaf24012aa23f85d0dd5d0bdf8b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Apr 2025 22:03:50 +0000 Subject: [PATCH 15/42] fix logs Signed-off-by: Lucas Wilkinson --- vllm/platforms/cuda.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b12b40a68e79..137ff7f5a870 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -239,10 +239,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def _get_version(name, import_suffix) -> str: if use_v1: - logger.info_once("Using %s backend on V1 engine.", name) + logger.info_once(f"Using {name} backend on V1 engine.") return f"vllm.v1.attention.backends.mla.{import_suffix}" else: - logger.info("Using %s backend.", name) + logger.info_once(f"Using {name} backend.") return f"vllm.attention.backends.mla.{import_suffix}" if use_cutlassmla: From f2dc4a33c32f9aefede3ebaa361a4841f91c1a29 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 22 Apr 2025 22:21:50 +0000 Subject: [PATCH 16/42] don't schedule prefills Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 52 +------------------ .../attention/backends/mla/flashattn_mla.py | 5 +- 2 files changed, 6 insertions(+), 51 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 7108e6412560..b15ef2fef927 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -218,8 +218,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec try: - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) + from vllm.vllm_flash_attn import flash_attn_varlen_func is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -301,12 +300,10 @@ class ChunkedContextMetadata: max_seq_lens: list[int] seq_lens: torch.Tensor workspace: torch.Tensor - scheduler_metatadata: list[Optional[torch.Tensor]] block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int - scheduler_metadata: Optional[torch.Tensor] = None chunked_context: Optional[ChunkedContextMetadata] = None @@ -588,25 +585,6 @@ def build_for_cudagraph_capture( return self.build(0, m) - def _schedule_prefill(self, num_reqs, cu_query_lens, max_query_len, - seqlens, max_seq_len, causal): - if self.fa_aot_schedule: - return get_scheduler_metadata( - batch_size=num_reqs, - max_seqlen_q=max_query_len, - max_seqlen_k=max_seq_len, - cache_seqlens=seqlens, - num_heads_q=self.num_heads, - num_heads_kv=self.num_heads, - headdim=self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim, - headdim_v=self.mla_dims.v_head_dim, - page_size=self.page_size, - cu_seqlens_q=cu_query_lens, - causal=causal, - ) - return None - def build(self, num_actual_tokens: int, common_prefix_len: int, @@ -700,26 +678,13 @@ def build(self, chunks_max_seq_lens = chunk_seq_lens.max(dim=1).values - chunks_scheduler_metadata = [] - for i in range(num_chunks): - chunks_scheduler_metadata.append( - self._schedule_prefill( - self._num_prefills, - cu_seq_lens_cpu[i], - max_query_len, - chunk_seq_lens[i], - chunks_max_seq_lens[i], - causal=False, - )) - chunked_context_metadata = \ chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunks_max_seq_lens, - workspace=self.chunked_prefill_workspace, - scheduler_metatadata=chunks_scheduler_metadata, + workspace=self.chunked_prefill_workspace ) if self._use_cudnn_prefill: @@ -728,22 +693,12 @@ def build(self, assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - scheduler_metadata = self._schedule_prefill( - self._num_prefills, - prefill_query_start_loc, - max_query_len, - prefill_query_start_loc, - max_query_len, - causal=True, - ) - prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, - scheduler_metadata=scheduler_metadata, ) if self._use_cudnn_prefill: @@ -872,7 +827,6 @@ def _flash_attn_varlen_diff_headdims(self, v, return_softmax_lse=False, softmax_scale=None, - scheduler_metadata=None, **kwargs): maybe_padded_v = v if self._pad_v: @@ -1119,7 +1073,6 @@ def _compute_prefill_context( softmax_scale=self.scale, causal=False, # Context is unmasked return_softmax_lse=True, - scheduler_metadata=prefill_metadata.chunked_context.scheduler_metatadata[i], ) if output is None: @@ -1166,7 +1119,6 @@ def _forward_prefill( k=k, v=v, return_softmax_lse=has_context, - scheduler_metadata=attn_metadata.prefill.scheduler_metadata, ) if has_context: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index b77426ad3ae0..bb80e667b7f8 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -13,7 +13,8 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla +from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) if TYPE_CHECKING: pass @@ -60,6 +61,8 @@ class FlashAttnMLAMetadataBuilder( def __init__(self, runner): super().__init__(runner) + self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.page_size = self.runner.block_size def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): From d695fdc2a47d5375e6ba48b2713f200685c72d48 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 24 Apr 2025 19:22:32 +0000 Subject: [PATCH 17/42] still default to FlashMLA Signed-off-by: Lucas Wilkinson --- vllm/platforms/cuda.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 137ff7f5a870..c8d2c8c782cb 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -166,8 +166,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") from vllm.attention.ops.flashmla import is_flashmla_supported - if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + use_flashmla = is_flashmla_supported()[0] and \ + (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" or + envs.VLLM_ATTENTION_BACKEND is None) + + if use_flashmla and cache_config.block_size != 64: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") @@ -253,9 +256,6 @@ def _get_version(name, import_suffix) -> str: else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") - if use_flashattn: - return _get_version("FlashAttention MLA", - "flashattn_mla.FlashAttnMLABackend") if use_flashmla: if block_size != 64: logger.warning( @@ -264,6 +264,9 @@ def _get_version(name, import_suffix) -> str: block_size) else: return _get_version("FlashMLA", "flashmla.FlashMLABackend") + if use_flashattn: + return _get_version("FlashAttention MLA", + "flashattn_mla.FlashAttnMLABackend") if use_triton: return _get_version("Triton MLA", "triton_mla.TritonMLABackend") From 82c9393e45d92389d7da9926389cbd52adcaadd1 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 17:51:53 +0000 Subject: [PATCH 18/42] Remove V0 FlashAttention MLA Signed-off-by: Matthew Bonanni --- vllm/attention/backends/mla/flashattn_mla.py | 271 ------------------- 1 file changed, 271 deletions(-) delete mode 100644 vllm/attention/backends/mla/flashattn_mla.py diff --git a/vllm/attention/backends/mla/flashattn_mla.py b/vllm/attention/backends/mla/flashattn_mla.py deleted file mode 100644 index 1f60764a21e3..000000000000 --- a/vllm/attention/backends/mla/flashattn_mla.py +++ /dev/null @@ -1,271 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional - -import torch - -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.backends.utils import get_mla_dims -from vllm.logger import init_logger -from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_mla, - get_flash_attn_version) - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class FlashAttnMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "FLASHATTN_MLA" - - @staticmethod - def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: - return FlashAttnMLAMetadata - - @staticmethod - def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: - return FlashAttnMLAMetadataBuilder - - @staticmethod - def get_impl_cls() -> type["FlashAttnMLAImpl"]: - return FlashAttnMLAImpl - - @staticmethod - def get_state_cls() -> type["FlashAttnMLAState"]: - return FlashAttnMLAState - - -@dataclass -class FlashAttnMLAMetadata(MLACommonMetadata): - decode_scheduler_metadata: Optional[torch.Tensor] = None - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - # TODO: cache assignment? - if decode_metadata is not None: - decode_metadata.decode_scheduler_metadata=\ - self.decode_scheduler_metadata - return decode_metadata - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - raise NotImplementedError( - "advance_step is not implemented for FlashAttnMLA") - - -class FlashAttnMLAMetadataBuilder( - MLACommonMetadataBuilder[FlashAttnMLAMetadata]): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.num_heads_q = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - self.fa_aot_schedule = (get_flash_attn_version() == 3) - self.mla_dims = get_mla_dims(self.runner.model_config) - self.page_size = self.runner.block_size - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - m = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - - decode_cu_seqlens_q = m.query_start_loc[ - m.num_prefills:] - m.query_start_loc[m.num_prefills] - - if m.num_decode_tokens > 0 and self.fa_aot_schedule: - m.decode_scheduler_metadata = get_scheduler_metadata( - batch_size=batch_size, - max_seqlen_q=m.max_decode_query_len, - max_seqlen_k=m.max_decode_seq_len, - cache_seqlens=m.seq_start_loc[m.num_prefills:], - num_heads_q=self.num_heads_q, - num_heads_kv=1, - headdim=self.mla_dims.qk_rope_head_dim, - headdim_v=self.mla_dims.kv_lora_rank, - page_size=self.page_size, - cu_seqlens_q=decode_cu_seqlens_q, - causal=True) - return m - - -class FlashAttnMLAState(MLACommonState[FlashAttnMLAMetadata]): - - def __init__(self, *args, **kwds): - super().__init__(*args, **kwds) - - self.fa_aot_schedule = (get_flash_attn_version() == 3) - self.num_heads_q = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - self.mla_dims = get_mla_dims(self.runner.model_config) - self.page_size = self.runner.block_size - - def _dummy_scheduler_metadata(self, max_batch_size: int): - if self.fa_aot_schedule: - return get_scheduler_metadata( - batch_size=max_batch_size, - max_seqlen_q=1, - max_seqlen_k=1, - cache_seqlens=torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device), - num_heads_q=self.num_heads_q, - num_heads_kv=1, - headdim=self.mla_dims.qk_rope_head_dim, - headdim_v=self.mla_dims.kv_lora_rank, - page_size=self.page_size, - cu_seqlens_q=torch.arange(max_batch_size + 1, - dtype=torch.int32, - device=self.runner.device), - causal=True) - return None - - @contextmanager - def graph_capture(self, max_batch_size: int): - # Run a dummy `get_scheduler_metadata` so we can get the right shapes - self._graph_scheduler_metadata = self._dummy_scheduler_metadata( - max_batch_size) - self._graph_query_start_loc = torch.arange(max_batch_size + 1, - dtype=torch.int32, - device=self.runner.device) - - with super().graph_capture(max_batch_size): - yield - - del self._graph_scheduler_metadata - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - assert metadata.num_decode_tokens > 0 - - decoder_scheduler_metadata = self._dummy_scheduler_metadata(batch_size) - - metadata_size = decoder_scheduler_metadata.numel() - self._graph_scheduler_metadata[:metadata_size].copy_( - decoder_scheduler_metadata) - - metadata.decode_scheduler_metadata=\ - self._graph_scheduler_metadata[:metadata_size] - metadata.query_start_loc=\ - self._graph_query_start_loc[:batch_size + 1] - - return metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers["decode_scheduler_metadata"] = \ - attn_metadata.decode_metadata.decode_scheduler_metadata - input_buffers["query_start_loc"] = \ - attn_metadata.decode_metadata.query_start_loc - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - input_buffers["decode_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_scheduler_metadata) - input_buffers["query_start_loc"].copy_( - attn_metadata.decode_metadata.query_start_loc) - - -class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, - **mla_args) - - assert flash_attn_supports_mla(), \ - "FlashAttnMLA is not supported on this device" - - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] - if any(unsupported_features): - raise NotImplementedError( - "FlashAttnMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 FlashMLA not yet supported") - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] - - o = flash_attn_varlen_func( - q=q_pe, - k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 - v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 - q_v=q_nope, - max_seqlen_q=decode_meta.max_decode_query_len, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_k=decode_meta.max_decode_seq_len, - seqused_k=decode_meta.seq_lens_tensor, - block_table=decode_meta.block_tables, - softmax_scale=self.scale, - causal=True, - fa_version=3, # only version 3 is supported - scheduler_metadata=decode_meta.decode_scheduler_metadata, - ) - - return self._v_up_proj_and_o_proj(o) From dc16bb5ad9187b09e41ddfa08bc7df0bd9d89c3f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 17:54:21 +0000 Subject: [PATCH 19/42] Move back to original location Signed-off-by: Matthew Bonanni --- vllm/attention/backends/{mla => }/flashmla.py | 0 vllm/attention/backends/{mla => }/triton_mla.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename vllm/attention/backends/{mla => }/flashmla.py (100%) rename vllm/attention/backends/{mla => }/triton_mla.py (100%) diff --git a/vllm/attention/backends/mla/flashmla.py b/vllm/attention/backends/flashmla.py similarity index 100% rename from vllm/attention/backends/mla/flashmla.py rename to vllm/attention/backends/flashmla.py diff --git a/vllm/attention/backends/mla/triton_mla.py b/vllm/attention/backends/triton_mla.py similarity index 100% rename from vllm/attention/backends/mla/triton_mla.py rename to vllm/attention/backends/triton_mla.py From 046af0bc55946ac7e671cdc4b18c8340962176ef Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 17:55:37 +0000 Subject: [PATCH 20/42] Undo change Signed-off-by: Matthew Bonanni --- vllm/attention/backends/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index fa35b6e736a8..34e059067d84 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -600,6 +600,7 @@ class MLADims: def get_mla_dims(model_config: ModelConfig) -> MLADims: hf_text_config = model_config.hf_text_config + return MLADims( q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), kv_lora_rank=hf_text_config.kv_lora_rank, From 8a0fe940e85770f3bb6119db88a616cea4a74bb3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 17:58:13 +0000 Subject: [PATCH 21/42] Match main Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 7 ++----- vllm/v1/attention/backends/mla/common.py | 14 +++++++------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c8d2c8c782cb..8754def5ac3f 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -166,11 +166,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") from vllm.attention.ops.flashmla import is_flashmla_supported - use_flashmla = is_flashmla_supported()[0] and \ - (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" or - envs.VLLM_ATTENTION_BACKEND is None) - - if use_flashmla and cache_config.block_size != 64: + if use_flashmla and is_flashmla_supported()[0] \ + and cache_config.block_size != 64: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b15ef2fef927..2fe87ed03616 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -234,8 +234,6 @@ except ImportError: flashinfer_available = False -is_hip = current_platform.is_rocm() - logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -602,7 +600,6 @@ def build(self, device = self.device block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - device = self.runner.device query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -676,15 +673,19 @@ def build(self, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) - chunks_max_seq_lens = chunk_seq_lens.max(dim=1).values + chunked_context_metadata_cls = \ + CudnnPrefillMetadata.ChunkedContextMetadata \ + if self._use_cudnn_prefill else \ + MLACommonPrefillMetadata.ChunkedContextMetadata chunked_context_metadata = \ chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunks_max_seq_lens, - workspace=self.chunked_prefill_workspace + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, ) if self._use_cudnn_prefill: @@ -835,7 +836,6 @@ def _flash_attn_varlen_diff_headdims(self, if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse - kwargs["scheduler_metadata"] = scheduler_metadata else: # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse From 9c5445d89b0472f54cb9a01c7c79229ada181abb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 18:45:39 +0000 Subject: [PATCH 22/42] Use reorder_batch_threshold throughout Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/flashinfer.py | 3 ++- vllm/v1/attention/backends/linear_attn.py | 5 +++-- vllm/v1/attention/backends/mamba1_attn.py | 5 +++-- vllm/v1/attention/backends/mamba2_attn.py | 5 +++-- vllm/v1/attention/backends/mla/common.py | 7 +++---- vllm/v1/attention/backends/mla/flashattn_mla.py | 5 +++-- vllm/v1/attention/backends/short_conv_attn.py | 7 ++++--- vllm/v1/attention/backends/xformers.py | 12 +++++++----- 8 files changed, 28 insertions(+), 21 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 50819bb2bb94..5fa60332f04a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -462,7 +462,8 @@ def build(self, num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index f08b6d7f177c..ac0034b5dcf0 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -52,8 +52,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 97a1aa86dda0..7cbfa2c2c9a5 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -50,8 +50,9 @@ def build( query_start_loc.device) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None padded_decodes = num_decodes diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ed30884fdbc9..f3e6cd7430e0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -115,8 +115,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2fe87ed03616..a64afc90712a 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -401,7 +401,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): understand this class """ reorder_batch_threshold: ClassVar[int] = 1 - decode_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, @@ -555,8 +554,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, - block_table_tensor: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, @@ -611,7 +609,8 @@ def build(self, query_seq_lens_cpu) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index bb80e667b7f8..1527363ea096 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch @@ -57,7 +58,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): # TODO(lucas): tune this value - decode_threshold: int = 64 + reorder_batch_threshold: ClassVar[int] = 64 def __init__(self, runner): super().__init__(runner) diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index d80ced8ec876..fcbf0c7b5356 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -58,8 +58,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None if num_prefills > 0: #[batch,] @@ -78,4 +79,4 @@ def build(self, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, ) - return attn_metadata \ No newline at end of file + return attn_metadata diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index e0eb7d8be974..2162efba30fe 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -212,9 +212,10 @@ def __init__( def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) def build( self, @@ -223,8 +224,9 @@ def build( fast_build: bool = False, ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc From 6b90fd773c0648a2c17195dc2d4bc449106a53ea Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 18:56:24 +0000 Subject: [PATCH 23/42] Remove input_positions Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 6 +----- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 -- vllm/v1/attention/backends/mla/flashmla.py | 4 +--- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a64afc90712a..d03cb85e6f41 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -558,10 +558,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - input_positions: torch.Tensor): + query_start_loc_device: torch.Tensor): return MLACommonDecodeMetadata( - input_positions=input_positions, block_table=block_table_tensor, seq_lens=seq_lens_device, ) @@ -694,7 +692,6 @@ def build(self, self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, @@ -717,7 +714,6 @@ def build(self, 1], seq_lens_device=seq_lens[:self._num_decodes], seq_lens_cpu=seq_lens_cpu[:self._num_decodes], - input_positions=input_positions[:self._num_decode_tokens], ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 1527363ea096..1cc929c0a3b9 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -87,7 +87,6 @@ def _build_decode(self, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, - input_positions: torch.Tensor, block_table: torch.Tensor) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) @@ -104,7 +103,6 @@ def _build_decode(self, seq_lens_cpu: torch.Tensor, ) return FlashAttnMLADecodeMetadata( - input_positions=input_positions, block_table=block_table, seq_lens=seq_lens_device, query_start_loc=query_start_loc_device, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 430eac91c21a..77de7a951f20 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -89,8 +89,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - input_positions: torch.Tensor) -> FlashMLADecodeMetadata: + query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, @@ -128,7 +127,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor, return FlashMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, - input_positions=input_positions, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, ) From 790bde63a3f8cfa26f620386839c4814f7417b18 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 19:03:04 +0000 Subject: [PATCH 24/42] Match main, remove unused arguments Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 14 ++------------ vllm/v1/attention/backends/mla/flashattn_mla.py | 8 ++++---- vllm/v1/attention/backends/mla/flashmla.py | 4 +--- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d03cb85e6f41..44a43982f487 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -580,7 +580,6 @@ def build_for_cudagraph_capture( return self.build(0, m) def build(self, - num_actual_tokens: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> M: @@ -588,8 +587,6 @@ def build(self, num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - assert self._num_decodes + self._num_prefills == num_reqs - # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. @@ -691,8 +688,8 @@ def build(self, assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - prefill_metadata = MLACommonPrefillMetadata( - block_table=block_table[reqs_start:, ...], + prefill_metadata = self.prefill_metadata_cls( + block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -1061,13 +1058,6 @@ def _compute_prefill_context( q=q, k=k, v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, ) if output is None: diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 1cc929c0a3b9..8ef110afe92e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -83,11 +83,11 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, ) return None - def _build_decode(self, seq_lens_cpu: torch.Tensor, + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - block_table: torch.Tensor) -> FlashAttnMLADecodeMetadata: + query_start_loc_device: torch.Tensor) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() @@ -103,7 +103,7 @@ def _build_decode(self, seq_lens_cpu: torch.Tensor, ) return FlashAttnMLADecodeMetadata( - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens_device, query_start_loc=query_start_loc_device, max_query_len=max_query_len, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 77de7a951f20..79cbe400a3a1 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -87,9 +87,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens_device: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, From 161f50e0313f32a454fe60e47234cb0077a5bca7 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 19:15:38 +0000 Subject: [PATCH 25/42] Align _build_decode signature Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 12 +++++------- vllm/v1/attention/backends/mla/flashmla.py | 4 +++- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 13 ++++++++----- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 44a43982f487..63f0e085340e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -558,7 +558,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor): + query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -705,12 +705,10 @@ def build(self, if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens=seq_lens[:num_decodes], - query_start_loc_device=query_start_loc[:self._num_decodes + 1], - query_start_loc_cpu=query_start_loc_cpu[:self._num_decodes + - 1], - seq_lens_device=seq_lens[:self._num_decodes], - seq_lens_cpu=seq_lens_cpu[:self._num_decodes], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], + query_start_loc_device=query_start_loc[:num_decodes + 1], ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 79cbe400a3a1..77de7a951f20 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -87,7 +87,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 870cc600388e..417f8a514ae0 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -105,11 +105,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=device) def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device - num_reqs = seq_lens.size(0) + num_reqs = seq_lens_device.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -117,7 +120,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) @@ -156,7 +159,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, From d87b921fe67fbe0f0d5783cc97a824822332e7b9 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 20:07:40 +0000 Subject: [PATCH 26/42] Fix more arguments Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 2 +- vllm/v1/attention/backends/mla/common.py | 10 +++++----- vllm/v1/attention/backends/mla/flashattn_mla.py | 14 +++++++------- vllm/v1/attention/backends/xformers.py | 4 +++- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8754def5ac3f..b70ed3596001 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -225,7 +225,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, # we should probably consider factoring out V1 here from vllm.attention.ops.flashmla import is_flashmla_supported - from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_mla + from vllm.attention.utils.fa_utils import flash_attn_supports_mla use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 63f0e085340e..92581b0b7a9c 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -554,11 +554,10 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -597,6 +596,7 @@ def build(self, query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 8ef110afe92e..88f94d36c151 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -7,6 +7,8 @@ import torch from vllm.attention.backends.abstract import AttentionType +from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -14,8 +16,6 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_mla, - get_flash_attn_version) if TYPE_CHECKING: pass @@ -83,11 +83,11 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, ) return None - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> FlashAttnMLADecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor + ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 2162efba30fe..7ef7fc877d86 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, ClassVar, Optional import torch @@ -197,6 +197,8 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: class XFormersAttentionMetadataBuilder( AttentionMetadataBuilder[XFormersAttentionMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__( self, kv_cache_spec: AttentionSpec, From 5a32eeb929e4c09538a8c5e74cf2a1eccfffed78 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 20:39:59 +0000 Subject: [PATCH 27/42] More compatibility fixes Signed-off-by: Matthew Bonanni --- .../attention/backends/mla/flashattn_mla.py | 63 ++++++++++--------- vllm/v1/attention/backends/mla/flashmla.py | 9 ++- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 88f94d36c151..436b059c4627 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -2,24 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, get_flash_attn_version) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata -if TYPE_CHECKING: - pass - logger = init_logger(__name__) @@ -60,8 +60,10 @@ class FlashAttnMLAMetadataBuilder( # TODO(lucas): tune this value reorder_batch_threshold: ClassVar[int] = 64 - def __init__(self, runner): - super().__init__(runner) + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashAttnMLAMetadata) self.fa_aot_schedule = (get_flash_attn_version() == 3) self.page_size = self.runner.block_size @@ -112,7 +114,7 @@ def _build_decode( ) -class FlashAttnMLAImpl(MLACommonImpl[MLACommonMetadata]): +class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): def __init__( self, @@ -123,47 +125,48 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) assert flash_attn_supports_mla(), \ "FlashAttnMLA is not supported on this device" - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "FlashAttnMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " - "FlashMLAImpl") + "FlashAttnMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttnMLA V1 with FP8 KV cache not yet supported") def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, + attn_metadata: FlashAttnMLAMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 FlashMLA not yet supported") + assert attn_metadata.decode is not None - decode_meta = attn_metadata.decode - assert decode_meta is not None + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "FP8 FlashAttention MLA not yet supported") kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] @@ -173,15 +176,15 @@ def _forward_decode( k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, - max_seqlen_q=decode_meta.max_query_len, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_k=decode_meta.max_seq_len, - seqused_k=decode_meta.seq_lens, - block_table=decode_meta.block_table, + max_seqlen_q=attn_metadata.decode.max_query_len, + cu_seqlens_q=attn_metadata.decode.query_start_loc, + max_seqlen_k=attn_metadata.decode.max_seq_len, + seqused_k=attn_metadata.decode.seq_lens, + block_table=attn_metadata.decode.block_table, softmax_scale=self.scale, causal=True, fa_version=3, # only version 3 is supported - scheduler_metadata=decode_meta.scheduler_metadata, + scheduler_metadata=attn_metadata.decode.scheduler_metadata, ) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 77de7a951f20..df617ab7a8ea 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -85,11 +85,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=self.device, dtype=torch.int32) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, From 63ec527d3573dc6074da8e76e6e4cd61c51e1ebc Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 21:00:22 +0000 Subject: [PATCH 28/42] Fix backend enum Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 2 +- tests/v1/attention/utils.py | 2 ++ vllm/platforms/interface.py | 2 ++ vllm/v1/attention/backends/mla/flashattn_mla.py | 3 +-- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 24070358799e..d07c33e84c78 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -16,7 +16,7 @@ BACKENDS_TO_TEST = [ _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, - _Backend.TRITON_MLA_VLLM_V1 + _Backend.FLASH_ATTN_MLA_VLLM_V1, _Backend.TRITON_MLA_VLLM_V1 ] # Remove CUTLASS_MLA from the list if not using sm100 diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 6a08cdc56f73..f71e0a550005 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", _Backend.FLASHMLA_VLLM_V1: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.FLASH_ATTN_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", _Backend.TRITON_MLA_VLLM_V1: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 00bc555288e8..324722a2a32c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -55,6 +55,8 @@ class _Backend(enum.Enum): FLASHMLA_VLLM_V1 = enum.auto() FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA = enum.auto() + FLASH_ATTN_MLA_VLLM_V1 = enum.auto() + FLASH_ATTN_MLA = enum.auto() # Supported by V1 PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 436b059c4627..8871e7af269f 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -27,7 +27,7 @@ class FlashAttnMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: - return "FLASHATTN_MLA_VLLM_V1" + return "FLASH_ATTN_MLA_VLLM_V1" @staticmethod def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: @@ -65,7 +65,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata) self.fa_aot_schedule = (get_flash_attn_version() == 3) - self.page_size = self.runner.block_size def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): From da96e280b9f5164db1de226dfbc9f2492c71e3eb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 21 Aug 2025 21:35:18 +0000 Subject: [PATCH 29/42] Remove unused helpers Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_attention_backends.py | 16 ---------------- tests/v1/attention/test_mla_backends.py | 14 -------------- 2 files changed, 30 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index e4c07aae0ebe..1ae8b91c347a 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -70,22 +70,6 @@ def _convert_dtype_to_torch(dtype): } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - 2, # K and V - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( k_contexts: list[torch.Tensor], v_contexts: list[torch.Tensor], diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index d07c33e84c78..e3de1a5f239b 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -69,20 +69,6 @@ def _convert_dtype_to_torch(dtype): } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.head_size, # latent dimension - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( kv_c_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor], From fb091248b8e7e59e6ef4067ee519dcbfc197abdf Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 25 Aug 2025 19:42:54 +0000 Subject: [PATCH 30/42] Rename Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 8871e7af269f..c6c5a70aa234 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -168,11 +168,11 @@ def _forward_decode( "FP8 FlashAttention MLA not yet supported") kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - kv_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] o = flash_attn_varlen_func( q=q_pe, - k=kv_pe_cache.unsqueeze(-2), # Add head dim of 1 + k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 q_v=q_nope, max_seqlen_q=attn_metadata.decode.max_query_len, From 70343e708c014b1a663ca3294b5ab55e652c1898 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 25 Aug 2025 20:05:47 +0000 Subject: [PATCH 31/42] Loosen tolerances for FA MLA backend Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index e3de1a5f239b..e1d7c35512f2 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -494,6 +494,10 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 + if backend_name == _Backend.FLASH_ATTN_MLA_VLLM_V1 and \ + any([q_len > 1 for q_len in query_lens]): + atol = 5e0 + max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_rel_diff = torch.max( torch.abs(backend_output - sdpa_output) / From 5daadfece741a0c4562c3fa6cb89c048bef0ddeb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 25 Aug 2025 20:34:44 +0000 Subject: [PATCH 32/42] Fix _forward_decode signature Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c6c5a70aa234..73ffd1bc8bdb 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,7 +6,7 @@ import torch -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, get_flash_attn_version) @@ -159,6 +159,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashAttnMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None From b8e6e0a95a3f24dce75522d8941b4df5b20557e8 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 26 Aug 2025 01:59:19 +0000 Subject: [PATCH 33/42] Respect each backend's decode threshold Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 189 +++++++++++++----------- 1 file changed, 100 insertions(+), 89 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index e1d7c35512f2..dc639944cce2 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -317,6 +317,10 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): device=device) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + all_sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + all_sdpa_outputs.append([]) + for i in range(batch_size): s_len = seq_lens[i] q_len = query_lens[i] @@ -344,85 +348,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): dtype=dtype, device=device) - # Determine if this is decode (single token) - # or prefill (multiple tokens) - is_decode = q_len == 1 + # Determine if this is decode or prefill + is_decode = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = get_attention_backend(backend) + is_decode.append(q_len <= builder_cls.reorder_batch_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - if is_decode: - # Decode path: MQA-style attention in latent space - # Transform q_nope to latent space: q_nope @ W_UK - # q_nope: [1, num_heads, qk_nope_head_dim] - # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] - ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, - W_UK) # [1, num_heads, kv_lora_rank] - - # Build MQA attention inputs - # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] - q_mqa = torch.cat([ql_nope, q_pe], dim=-1) - # K: [s_len, kv_lora_rank + qk_rope_head_dim] - # (broadcasted to all heads) - k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) - k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) - # V: [s_len, kv_lora_rank] (broadcasted to all heads) - v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) - - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) - - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( - 0) # [1, num_heads, kv_lora_rank] - - # Project back to output space: sdpa_out @ W_UV - sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - else: - # Prefill path: MHA-style attention with full sequence - # Apply kv_b_proj to the full kv_c tensor - kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, - kv_b_proj_weight) - k_nope_full, v_full = kv_nope_full.split( - [qk_nope_head_dim, v_head_dim], dim=-1) - - # Build attention inputs for full sequence - q_mha = torch.cat([q_nope, q_pe], - dim=-1) # [q_len, num_heads, total_dim] - k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) - k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) - - # Create custom attention mask: - # - Query tokens can attend to all context tokens - # - Query tokens can only attend to query tokens up to their pos - attn_mask = torch.ones(q_len, - s_len, - dtype=torch.bool, - device=device) - # Apply causal mask only to the query portion (context_len onwards) - causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) - attn_mask[:, context_len:] = causal_mask - - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) - - # Single attention call with custom mask - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - - all_sdpa_outputs.append(sdpa_out_i) + ####################################################### + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # Create custom attention mask for decode path: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their position + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, + W_UV) + sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2) + + ####################################################### + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) + sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) + + for i, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[i]: + all_sdpa_outputs[i].append(sdpa_out_i_decode) + else: + all_sdpa_outputs[i].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -437,7 +449,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -472,7 +486,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): randomize_blocks=True) # 4. Run vLLM backends and compare - for backend_name in BACKENDS_TO_TEST: + for i, backend_name in enumerate(BACKENDS_TO_TEST): backend_output = run_attention_backend( backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, @@ -480,12 +494,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): mock_kv_b_proj) # Check shape and dtype consistency - assert backend_output.shape == sdpa_output.shape, ( + assert backend_output.shape == sdpa_outputs[i].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") - assert backend_output.dtype == sdpa_output.dtype, ( + f"SDPA shape {sdpa_outputs[i].shape}") + assert backend_output.dtype == sdpa_outputs[i].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_outputs[i].dtype}") assert torch.isfinite(backend_output).all(), ( f"[{backend_name}] produced non-finite values") @@ -494,16 +508,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - if backend_name == _Backend.FLASH_ATTN_MLA_VLLM_V1 and \ - any([q_len > 1 for q_len in query_lens]): - atol = 5e0 - - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_diff = torch.max(torch.abs(backend_output - + sdpa_outputs[i])).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() + torch.abs(backend_output - sdpa_outputs[i]) / + torch.abs(sdpa_outputs[i])).item() all_close = torch.allclose(backend_output, - sdpa_output, + sdpa_outputs[i], rtol=rtol, atol=atol) From 91f01d4049a98bd2674ef8e4ad7083ee8f7c440e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 26 Aug 2025 02:23:35 +0000 Subject: [PATCH 34/42] Fix backend selection logic Signed-off-by: Matthew Bonanni --- vllm/engine/arg_utils.py | 3 +++ vllm/envs.py | 1 + vllm/platforms/cuda.py | 10 ++++++---- vllm/platforms/interface.py | 6 +++--- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3ab1115f1446..2c37d7655e2c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1501,6 +1501,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "TRITON_MLA", "CUTLASS_MLA", "FLASHMLA", + "FLASHMLA_VLLM_V1", + "FLASH_ATTN_MLA", + "FLASH_ATTN_MLA_VLLM_V1", "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA", diff --git a/vllm/envs.py b/vllm/envs.py index 1c9c4cdde800..67e5bf4b5920 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -459,6 +459,7 @@ def get_vllm_port() -> Optional[int]: # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA + # - "FLASH_ATTN_MLA": use FlashAttention for MLA "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b70ed3596001..cf08900785ce 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -230,10 +230,12 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) and block_size == 128) - use_flashattn = selected_backend == _Backend.FLASH_ATTN or ( - selected_backend is None and flash_attn_supports_mla()) - use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_supported()[0]) + use_flashattn = selected_backend in [ + _Backend.FLASH_ATTN_MLA, _Backend.FLASH_ATTN_MLA_VLLM_V1 + ] or (selected_backend is None and flash_attn_supports_mla()) + use_flashmla = selected_backend in [ + _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 + ] or (selected_backend is None and is_flashmla_supported()[0]) use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 324722a2a32c..2f8274fdf1c1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -52,11 +52,11 @@ class _Backend(enum.Enum): FLASHINFER_VLLM_V1 = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA = enum.auto() - FLASH_ATTN_MLA_VLLM_V1 = enum.auto() + FLASHMLA = enum.auto() # Supported by V1 + FLASHMLA_VLLM_V1 = enum.auto() FLASH_ATTN_MLA = enum.auto() # Supported by V1 + FLASH_ATTN_MLA_VLLM_V1 = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() From 4201218833ccaec03bfce5f971cc85912665e823 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 26 Aug 2025 19:33:36 +0000 Subject: [PATCH 35/42] Address pre-commit Signed-off-by: Matthew Bonanni --- tests/v1/attention/test_mla_backends.py | 3 +-- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index dc639944cce2..53bf869932b3 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -301,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # 2. Generate data and compute SDPA reference output for MLA all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] - all_sdpa_outputs = [] + all_sdpa_outputs: list[list[torch.Tensor]] = [] kv_c_contexts, k_pe_contexts = [], [] # Create shared MLA weight matrices for consistency across all sequences @@ -317,7 +317,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): device=device) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) - all_sdpa_outputs = [] for i, backend in enumerate(BACKENDS_TO_TEST): all_sdpa_outputs.append([]) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 417f8a514ae0..42670093daa9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -104,11 +104,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device From fe5ba4182c44fede25353bf71aa0d264448e1ac8 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 27 Aug 2025 02:20:27 +0000 Subject: [PATCH 36/42] Update GIT_TAG Signed-off-by: Matthew Bonanni --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 49defccbb1fa..3d32121f13ac 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f + GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 6455578e9568358b78befff871bba00937c1518c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 27 Aug 2025 14:59:07 +0000 Subject: [PATCH 37/42] Decode threshold tuning Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 73ffd1bc8bdb..b0852e866965 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -57,8 +57,7 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder( MLACommonMetadataBuilder[FlashAttnMLAMetadata]): - # TODO(lucas): tune this value - reorder_batch_threshold: ClassVar[int] = 64 + reorder_batch_threshold: ClassVar[int] = 512 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): From 513fdeb2b592b1feea2ff79c1bec2799dd5656cb Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 27 Aug 2025 17:49:05 +0000 Subject: [PATCH 38/42] Undo V0 change Signed-off-by: Matthew Bonanni --- vllm/attention/backends/mla/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 536eb52d80cd..c5ed4c6e4032 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -784,8 +784,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - query_start_loc.extend([query_start_loc[-1]] * cuda_graph_pad_size) + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens block_tables = self._get_graph_runner_block_tables( From fd256158c95272834bb143f8db108fd91f42ee8c Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 27 Aug 2025 18:51:36 +0000 Subject: [PATCH 39/42] Pass qkv_dtype Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/flashattn_mla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index b0852e866965..da4bc2e72d17 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -72,10 +72,11 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, batch_size=num_reqs, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - cache_seqlens=seqlens, num_heads_q=self.num_heads, num_heads_kv=1, headdim=self.mla_dims.qk_rope_head_dim, + cache_seqlens=seqlens, + qkv_dtype=self.kv_cache_spec.dtype, headdim_v=self.mla_dims.kv_lora_rank, page_size=self.page_size, cu_seqlens_q=cu_query_lens, @@ -88,7 +89,6 @@ def _build_decode( seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor ) -> FlashAttnMLADecodeMetadata: - query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() From 398e55baea6fc8cdeb75671b55135d77c05f2a52 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 28 Aug 2025 14:23:57 +0000 Subject: [PATCH 40/42] increase wheel size Signed-off-by: Lucas Wilkinson --- .buildkite/check-wheel-size.py | 9 ++++----- docker/Dockerfile | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 68aff793ae6a..7e2e06136421 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,10 @@ import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB -# Note that we have 400 MiB quota, please use it wisely. -# See https://github.com/pypi/support/issues/3792 . -# Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Note that we have 800 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/6326 . +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) def print_top_10_largest_files(zip_file): diff --git a/docker/Dockerfile b/docker/Dockerfile index 839ac501dbaf..3df5e1cdebfb 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=400 +ARG VLLM_MAX_SIZE_MB=450 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ From 4f29ce1cf4fb354754d9ec00c5c047ecc7e72c3e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 28 Aug 2025 14:32:39 +0000 Subject: [PATCH 41/42] missing line Signed-off-by: Lucas Wilkinson --- .buildkite/check-wheel-size.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 7e2e06136421..76f6d7aeca0d 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -8,6 +8,7 @@ # Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB # Note that we have 800 MiB quota, please use it wisely. # See https://github.com/pypi/support/issues/6326 . +# Please also sync the value with the one in Dockerfile. VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) From 8672a7f14f307f9c7ae6fb74717ae17cc63cb276 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 28 Aug 2025 18:55:29 +0000 Subject: [PATCH 42/42] Fix backend selector logic and test Signed-off-by: Matthew Bonanni --- .../attention/test_attention_selector.py | 110 +++++++++++++----- tests/v1/attention/test_mla_backends.py | 4 +- tests/v1/attention/utils.py | 2 +- vllm/engine/arg_utils.py | 1 - vllm/platforms/cuda.py | 18 ++- vllm/platforms/interface.py | 1 - .../attention/backends/mla/flashattn_mla.py | 2 +- 7 files changed, 99 insertions(+), 39 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index aea166da3af2..3c2aaabacae8 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -22,7 +22,7 @@ def clear_cache(): # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA"], + "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } @@ -98,21 +98,14 @@ def test_env( with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: - # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) - - if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name - assert backend.get_name() == expected - else: + # ROCm MLA backend logic: + # - TRITON_MLA: supported when block_size != 1 + # - ROCM_AITER_MLA: supported when block_size == 1 + # If backend is forced but doesn't match block_size, + # should raise ValueError + + if name == "TRITON_MLA" and block_size == 1: + # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, @@ -122,6 +115,27 @@ def test_env( use_mla=use_mla) assert f"The selected backend, {name}" in str( exc_info.value) + elif name == "ROCM_AITER_MLA" and block_size != 1: + # ROCM_AITER_MLA only supports block_size == 1 + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + # Valid backend-block_size combination + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: backend = get_attn_backend(16, torch.float16, @@ -136,16 +150,57 @@ def test_env( with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: - if name == "FLASHMLA" and block_size == 64: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - - # only on cuda platforms with specific capability. - is_supported, _ = is_flashmla_supported() - - if not is_supported: - # if platform is not supported then skip this case. - pytest.skip() + # CUDA MLA backend logic: + # - CUTLASS_MLA: only supported with block_size == 128 + # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHMLA: only supported with block_size == 64 + # - FLASH_ATTN_MLA: V1 only + # - TRITON_MLA: fallback for other cases + + if name == "CUTLASS_MLA": + if not use_v1: + # CUTLASS_MLA only supported on V1 engine + pytest.skip( + "CUTLASS_MLA only supported on V1 engine") + elif block_size != 128: + # CUTLASS_MLA only supports block_size == 128 + pytest.skip( + "CUTLASS_MLA only supports block_size 128") + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "CUTLASS_MLA_VLLM_V1" + assert backend.get_name() == expected + elif name == "FLASHMLA": + if block_size != 64: + # FlashMLA only supports block_size == 64 + pytest.skip("FlashMLA only supports block_size 64") + else: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + is_supported, _ = is_flashmla_supported() + if not is_supported: + pytest.skip( + "FlashMLA not supported on this platform") + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + if not use_v1: + # FlashAttention MLA only supported on V1 engine + pytest.skip( + "FlashAttention MLA only supported on V1 engine" + ) else: backend = get_attn_backend(16, torch.float16, @@ -153,9 +208,10 @@ def test_env( block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: + # TRITON_MLA or other fallback backend = get_attn_backend(16, torch.float16, torch.float16, diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 53bf869932b3..e7cd116fdc83 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -15,8 +15,8 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, - _Backend.FLASH_ATTN_MLA_VLLM_V1, _Backend.TRITON_MLA_VLLM_V1 + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA, + _Backend.TRITON_MLA_VLLM_V1 ] # Remove CUTLASS_MLA from the list if not using sm100 diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index f71e0a550005..5c49566240df 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -139,7 +139,7 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", _Backend.FLASHMLA_VLLM_V1: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", - _Backend.FLASH_ATTN_MLA_VLLM_V1: + _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", _Backend.TRITON_MLA_VLLM_V1: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2c37d7655e2c..d464e9ecb1ff 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1503,7 +1503,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHMLA", "FLASHMLA_VLLM_V1", "FLASH_ATTN_MLA", - "FLASH_ATTN_MLA_VLLM_V1", "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA", diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index cf08900785ce..05b4d333a23b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -230,12 +230,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) and block_size == 128) - use_flashattn = selected_backend in [ - _Backend.FLASH_ATTN_MLA, _Backend.FLASH_ATTN_MLA_VLLM_V1 - ] or (selected_backend is None and flash_attn_supports_mla()) use_flashmla = selected_backend in [ _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 ] or (selected_backend is None and is_flashmla_supported()[0]) + use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla()) use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) @@ -245,7 +244,7 @@ def _get_version(name, import_suffix) -> str: return f"vllm.v1.attention.backends.mla.{import_suffix}" else: logger.info_once(f"Using {name} backend.") - return f"vllm.attention.backends.mla.{import_suffix}" + return f"vllm.attention.backends.{import_suffix}" if use_cutlassmla: if use_v1: @@ -264,8 +263,15 @@ def _get_version(name, import_suffix) -> str: else: return _get_version("FlashMLA", "flashmla.FlashMLABackend") if use_flashattn: - return _get_version("FlashAttention MLA", - "flashattn_mla.FlashAttnMLABackend") + if use_v1: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashattn_mla.FlashAttnMLABackend") + else: + logger.warning( + "FlashAttention MLA backend is only supported on V1 " + "engine.") if use_triton: return _get_version("Triton MLA", "triton_mla.TritonMLABackend") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2f8274fdf1c1..9466232873f7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -56,7 +56,6 @@ class _Backend(enum.Enum): FLASHMLA = enum.auto() # Supported by V1 FLASHMLA_VLLM_V1 = enum.auto() FLASH_ATTN_MLA = enum.auto() # Supported by V1 - FLASH_ATTN_MLA_VLLM_V1 = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index da4bc2e72d17..0e08307ddf84 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -27,7 +27,7 @@ class FlashAttnMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: - return "FLASH_ATTN_MLA_VLLM_V1" + return "FLASH_ATTN_MLA" @staticmethod def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: