From c5c5606d0b0433ca2ca0c267cb2a835dba60e548 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 30 Sep 2025 20:08:48 +0000 Subject: [PATCH 1/3] enable spec compat for flashinfer-mla Signed-off-by: Benjamin Chislett --- vllm/v1/attention/backends/mla/common.py | 33 +++++++++++++++++-- .../attention/backends/mla/flashinfer_mla.py | 27 +++++++++++++-- vllm/v1/attention/backends/utils.py | 3 +- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 561793b6a377..4df4e57375da 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import Generic, Optional, TypeVar, Union +from typing import ClassVar, Generic, Optional, TypeVar, Union import torch from tqdm import tqdm @@ -436,6 +436,26 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + + # Whether the backend supports reordering the batch such that + # short sequences (i.e. verification for speculative decoding) are + # classified as decode requests. + # If True, this will increase `reorder_batch_threshold` (below) when + # speculative decoding is enabled. + supports_spec_as_decode: ClassVar[bool] = False + + # Whether the backend supports grouping decode requests with + # different query lengths in the same batch. If False, when + # `reorder_batch_threshold > 1`, any decode requests which do not + # have the same query length as the first decode request will + # fall back to the prefill kernel. + supports_nonuniform_decode: ClassVar[bool] = False + + # The threshold for reordering the batch into decode and prefill requests. 
+ # If > 1, the batch will be reordered such that requests with + # query length <= threshold are classified as decode requests. + # Use `supports_spec_as_decode` (above) to set this automatically + # when speculative decoding is enabled. reorder_batch_threshold: int = 1 @staticmethod @@ -479,6 +499,7 @@ def __init__(self, self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config self.device = device self.num_heads = self.model_config.get_num_attention_heads( @@ -551,6 +572,10 @@ def __init__(self, device=device, ) + supports_spec_as_decode = self.supports_spec_as_decode + self._init_reorder_batch_threshold(self.reorder_batch_threshold, + supports_spec_as_decode) + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -680,8 +705,10 @@ def build(self, query_seq_lens_cpu) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=not self.supports_nonuniform_decode) # Note(hc): update seq_lens of decode reqs under DCP. 
if self.dcp_world_size > 1: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 701248670f72..465cf64c6ac7 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import ClassVar, Optional, Union import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla @@ -10,13 +10,20 @@ from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, - MLACommonMetadata) + MLACommonMetadata, + MLACommonMetadataBuilder) logger = init_logger(__name__) FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata] + ): + # enable spec-as-decode optimization + supports_spec_as_decode: ClassVar[bool] = True + + class FlashInferMLABackend(MLACommonBackend): @staticmethod @@ -27,6 +34,10 @@ def get_name() -> str: def get_impl_cls() -> type["FlashInferMLAImpl"]: return FlashInferMLAImpl + @staticmethod + def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: + return FlashInferMLAMetadataBuilder + g_fi_workspace = torch.zeros( FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, @@ -87,7 +98,14 @@ def _forward_decode( q = torch.cat([q_nope, q_pe], dim=-1) # trtllm API requires extra dimension q_len_per_request for MTP - q = q.unsqueeze(1) + if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0: + logger.warning_once( + """FlashInferMLAImpl got a query of uneven length. 
+ This usually indicates an issue in batch reordering + or incorrect setup in dummy_run.""") + q = q.unsqueeze(1) + else: + q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1]) if self.bmm1_scale is None: self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * @@ -109,6 +127,9 @@ def _forward_decode( bmm2_scale=self.bmm2_scale, ) + # Flatten the output for consistent shape + o = o.view(-1, o.shape[-2], o.shape[-1]) + # TODO: Return LSE pending support from Flashinfer API: # https://github.com/flashinfer-ai/flashinfer/pull/1566 return o, None diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f401c..9e29faa9c8e8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -260,7 +260,8 @@ def _init_reorder_batch_threshold( if (speculative_config is not None and speculative_config.num_speculative_tokens is not None): self.reorder_batch_threshold = \ - 1 + speculative_config.num_speculative_tokens + max(self.reorder_batch_threshold, + 1 + speculative_config.num_speculative_tokens) @abstractmethod def build(self, From d0c7b3681d5fbef29adf0e2e3eebb7350de54995 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Fri, 3 Oct 2025 23:02:10 +0000 Subject: [PATCH 2/3] update defaults for flashattn_mla Signed-off-by: Benjamin Chislett --- vllm/v1/attention/backends/mla/flashattn_mla.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 652b1cdb6b76..34d0867e0f0d 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -64,6 +64,9 @@ class FlashAttnMLAMetadataBuilder( cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.UNIFORM_BATCH + supports_spec_as_decode: ClassVar[bool] = True + supports_nonuniform_decode: ClassVar[bool] = True + reorder_batch_threshold: int = 512 def __init__(self, 
kv_cache_spec: AttentionSpec, layer_names: list[str], From e1ac9d8141c65e820d72a0db252d7bb12cb7570a Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 7 Oct 2025 02:59:18 +0000 Subject: [PATCH 3/3] simplify flags Signed-off-by: Benjamin Chislett --- vllm/v1/attention/backends/mla/common.py | 19 +++++++------------ .../attention/backends/mla/flashattn_mla.py | 3 --- .../attention/backends/mla/flashinfer_mla.py | 2 +- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 3af6d940004a..3fb00f5917ea 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -458,20 +458,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # short sequences (i.e. verification for speculative decoding) are # classified as decode requests. # If True, this will increase `reorder_batch_threshold` (below) when - # speculative decoding is enabled. - supports_spec_as_decode: ClassVar[bool] = False - - # Whether the backend supports grouping decode requests with - # different query lengths in the same batch. If False, when - # `reorder_batch_threshold > 1`, any decode requests which do not - # have the same query length as the first decode request will - # fall back to the prefill kernel. - supports_nonuniform_decode: ClassVar[bool] = False + # speculative decoding is enabled, and set `require_uniform=True` + # when reordering the batch. Non-uniform decode requests will + # fall back to prefill in this case. + supports_uniform_spec_as_decode: ClassVar[bool] = False # The threshold for reordering the batch into decode and prefill requests. # If > 1, the batch will be reordered such that requests with # query length <= threshold are classified as decode requests. 
- # Use `supports_spec_as_decode` (above) to set this automatically + # Use `supports_uniform_spec_as_decode` (above) to set this automatically # when speculative decoding is enabled. reorder_batch_threshold: int = 1 @@ -598,7 +593,7 @@ def __init__( device=device, ) - supports_spec_as_decode = self.supports_spec_as_decode + supports_spec_as_decode = self.supports_uniform_spec_as_decode self._init_reorder_batch_threshold( self.reorder_batch_threshold, supports_spec_as_decode ) @@ -741,7 +736,7 @@ def build( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, - require_uniform=not self.supports_nonuniform_decode, + require_uniform=self.supports_uniform_spec_as_decode, ) ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index be0560073dc1..c0c2dbe1f961 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -68,9 +68,6 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - supports_spec_as_decode: ClassVar[bool] = True - supports_nonuniform_decode: ClassVar[bool] = True - reorder_batch_threshold: int = 512 def __init__( diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index ba2964aeaff6..c337e8929004 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -22,7 +22,7 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): # enable spec-as-decode optimization - supports_spec_as_decode: ClassVar[bool] = True + supports_uniform_spec_as_decode: ClassVar[bool] = True class FlashInferMLABackend(MLACommonBackend):