From ef71a0ebb31620f4f627819f3a62cdb382e94a2d Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 23 Mar 2026 20:38:31 +0000 Subject: [PATCH 01/22] move block_size update to platform update_block_size_for_backend Signed-off-by: Chendi Xue --- vllm/platforms/xpu.py | 91 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5d39dfcebef5..b4a48de1ca94 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -159,12 +159,8 @@ def get_static_graph_wrapper_cls(cls) -> str: @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config - # in V1(or with chunked prefill) block_size is 64 - if cache_config and not cache_config.user_specified_block_size: - cache_config.block_size = 64 # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -228,9 +224,90 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: @classmethod def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - # TODO: XPU still sets block_size in check_and_update_config. - # Move that logic here so block_size is chosen by the backend. - pass + """ + Ensure block_size is compatible with the attention backend. + """ + from vllm.config.cache import CacheConfig + from vllm.v1.attention.backend import AttentionBackend + + _DEFAULT_BLOCK_SIZE = 64 + + cache_config = vllm_config.cache_config + if cache_config.user_specified_block_size: + # User specified --block-size; keep it. + return + + model_config = vllm_config.model_config + if model_config is None: + cache_config.block_size = _DEFAULT_BLOCK_SIZE + return + + from vllm.config.vllm import ( + get_layers_from_vllm_config, + ) + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + from vllm.utils.math_utils import cdiv + + attn_layers = get_layers_from_vllm_config( + vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + if not attn_layers: + logger.info("Update no attn layers block size to %d", _DEFAULT_BLOCK_SIZE) + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + return + + def get_backend_block_size( + backend_cls: type["AttentionBackend"], block_size: int + ): + if backend_cls.get_name() in ( + AttentionBackendEnum.FLASH_ATTN.name, + AttentionBackendEnum.TRITON_ATTN.name, + AttentionBackendEnum.TORCH_SDPA.name, + ): + return _DEFAULT_BLOCK_SIZE + elif backend_cls.get_name() == AttentionBackendEnum.XPU_MLA_SPARSE.name: + return 128 + else: + return block_size + + if model_config.is_hybrid: + backend_cls_list = set([i.get_attn_backend() for i in attn_layers.values()]) + block_size_list = [ + get_backend_block_size(i, cache_config.block_size) + for i in backend_cls_list + ] + new_block_size = cdiv(max(block_size_list), min(block_size_list)) * min( + block_size_list + ) + if cache_config.block_size == new_block_size: + return + logger.info("Update hybrid model block size to %d", new_block_size) + if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = new_block_size + if cache_config.mamba_page_size_padded is not None: + attn_page_size_1_token = ( + cache_config.mamba_page_size_padded // cache_config.block_size + ) + cache_config.mamba_page_size_padded = ( + new_block_size * attn_page_size_1_token + ) + cache_config.block_size = new_block_size + return + + first_layer = next(iter(attn_layers.values())) + backend_cls = first_layer.get_attn_backend() + + new_block_size = get_backend_block_size(backend_cls, cache_config.block_size) + if cache_config.block_size == new_block_size: + return + logger.info( + "Update %s block size to %d", backend_cls.get_name(), new_block_size + ) + cache_config.block_size = new_block_size + return @classmethod def support_hybrid_kv_cache(cls) -> bool: From 52f18b76112fda0763d74434624c4293947088c7 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 18 Mar 2026 21:11:36 +0000 Subject: [PATCH 02/22] enable is_act_and_mul for xpu Signed-off-by: Chendi Xue --- vllm/model_executor/layers/fused_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 85fd1813a363..dba3028b0875 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -567,7 +567,9 @@ def _get_quant_method() -> FusedMoEMethodBase: # for heuristic purposes, so it must be initialized first. self.quant_method: FusedMoEMethodBase = _get_quant_method() - if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): + if not self.moe_config.is_act_and_mul and not ( + current_platform.is_cuda_alike() or current_platform.is_xpu() + ): raise NotImplementedError( "is_act_and_mul=False is supported only for CUDA and ROCm for now" ) From 778cf5f577337342a96a33ec9316d8cd59490b18 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 24 Mar 2026 13:44:13 -0400 Subject: [PATCH 03/22] Refactor: move hybrid block_size alignment into base Platform Move the hybrid model block_size/mamba page alignment logic from HybridAttentionMambaModelConfig.verify_and_update_config (which runs too early, before the backend is known) into Platform's update_block_size_for_backend (which runs after layers are constructed). This fixes the XPU block_size=64 ordering bug without requiring each platform to reimplement alignment logic: - Strip verify_and_update_config to validation-only (calculate_kv_scales, MambaModelConfig setup) - Add _align_hybrid_block_size() to base Platform class, deriving kernel alignment from backend's get_supported_kernel_block_sizes() - Add default_block_size class variable so platforms (e.g. XPU=64) can customize without overriding the entire method - Remove XPU's custom update_block_size_for_backend override Co-Authored-By: Claude Opus 4.6 Signed-off-by: Matthew Bonanni --- vllm/model_executor/models/config.py | 148 ++------------------------- vllm/platforms/interface.py | 146 ++++++++++++++++++++++++-- vllm/platforms/xpu.py | 88 +--------------- 3 files changed, 142 insertions(+), 240 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index a5644a414aee..5dc826dea25f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from math import lcm from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.models import ModelRegistry -from vllm.utils.math_utils import cdiv, round_up -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec +from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ - Ensure that page size of attention layers is greater than or - equal to the mamba layers. If not, automatically set the attention - block size to ensure that it is. If the attention page size is - strictly greater than the mamba page size, we pad the mamba page size - to make them equal. + Perform early validation and setup for hybrid attention/mamba models. + + Block size alignment with mamba page sizes is handled later by + Platform.update_block_size_for_backend(), which runs after model + layers are constructed and the attention backend is known. Args: vllm_config: vLLM Config @@ -129,140 +124,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: ) cache_config.calculate_kv_scales = False - # Save the user input before it gets modified by MambaModelConfig - mamba_block_size = cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) - attention_config = vllm_config.attention_config - cache_config = vllm_config.cache_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - - if cache_config.cache_dtype == "auto": - kv_cache_dtype = model_config.dtype - else: - kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # get attention page size (for 1 token) - # Attention backend constraints: - # - FlashAttention (FA) requires block size to be multiple of 16 - # - MLA (Multi-head Latent Attention) requires larger alignment: - # * CUTLASS_MLA backend: kernel_block_size 128 alignment - # * Other MLA backends: kernel_block_size 64 alignment - if model_config.use_mla: - use_cutlass_mla = ( - attention_config.backend == AttentionBackendEnum.CUTLASS_MLA - ) - kernel_block_alignment_size = 128 if use_cutlass_mla else 64 - attn_page_size_1_token = MLAAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - else: - kernel_block_alignment_size = 16 - attn_page_size_1_token = FullAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - - model_cls, _ = ModelRegistry.resolve_model_cls( - model_config.architecture, - model_config=model_config, - ) - - # get mamba page size - mamba_page_size = MambaSpec( - shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=-1, # block_size doesn't matter for mamba page size - ).page_size_bytes - - # Model may be marked as is_hybrid - # but mamba is skipped via config, - # return directly - if mamba_page_size == 0: - return - - if cache_config.mamba_cache_mode == "all": - # With prefix caching, select attention block size to - # optimize for mamba kernel performance - - # Mamba2 SSD kernel uses a chunk_size, e.g. 256 - # Align the block to the kernel: use lowest multiple of chunk_size - # of attention tokens that would fit mamba_page_size: - # e.g. for mamba page size = 788kB - # attn_1_token = 2kB -> fits ~394 tokens - # then round up to a multiple of 256 -> 512 tokens - # End result: - # attn_block_size = 512 - # mamba_block_size = 512 (aligned to a multiple of chunk_size) - # TODO(tdoublep): this constraint can be relaxed fairly - # easily by changing the way we layout chunks in the - # mamba2 kernels. - - base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() - attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) - chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) - attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) - cache_config.mamba_block_size = attn_block_size - else: - # Without prefix caching, select minimum valid attention block size - # to minimize mamba state padding - - # Calculate minimum attention block size that satisfies both: - # 1. Backend alignment requirements (kernel_block_alignment_size) - # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) - attn_block_size = kernel_block_alignment_size * cdiv( - mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token - ) - - # override attention block size if it is too small, - # even if the user has explicitly set it - if cache_config.block_size < attn_block_size: - cache_config.block_size = attn_block_size - logger.info( - "Setting attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size, - ) - - # By default, mamba block size will be set to max_model_len. - # When enabling prefix caching and using align mamba cache - # mode, we align mamba block size to the block size as the - # basic granularity for prefix caching. - if cache_config.mamba_cache_mode == "align": - cache_config.mamba_block_size = cache_config.block_size - - # compute new attention page size - attn_page_size = cache_config.block_size * attn_page_size_1_token - - assert attn_page_size >= mamba_page_size - - if attn_page_size == mamba_page_size: - # don't need to pad mamba page size - return - - # pad mamba page size to exactly match attention - if ( - cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size - ): - cache_config.mamba_page_size_padded = attn_page_size - mamba_padding_pct = ( - 100 * (attn_page_size - mamba_page_size) / mamba_page_size - ) - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", - mamba_padding_pct, - ) - class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 39688bb8b235..94e00e2d5fa9 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -21,6 +21,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser + from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.selector import AttentionSelectorConfig else: FlexibleArgumentParser = object @@ -135,6 +136,10 @@ class Platform: supported_quantization: list[str] = [] + # Default block size for the KV cache on this platform. + # Backends may override via get_preferred_block_size(). + default_block_size: int = 16 + additional_env_vars: list[str] = [] _global_graph_pool: Any | None = None @@ -427,9 +432,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: """ Ensure block_size is compatible with the attention backend. + For hybrid models, also aligns block_size with mamba page sizes. """ - from vllm.config.cache import CacheConfig - cache_config = vllm_config.cache_config if cache_config.user_specified_block_size: # User specified --block-size; keep it. @@ -437,10 +441,8 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config # model_config may be None during testing. - # Skip hybrid models — their block_size is managed by - # HybridAttentionMambaModelConfig. - if model_config is None or model_config.is_hybrid: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + if model_config is None: + cache_config.block_size = cls.default_block_size return from vllm.config.vllm import ( @@ -456,16 +458,14 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: AttentionLayerBase, # type: ignore[type-abstract] ) if not attn_layers: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + cache_config.block_size = cls.default_block_size return first_layer = next(iter(attn_layers.values())) backend_cls = first_layer.get_attn_backend() with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size( - CacheConfig.DEFAULT_BLOCK_SIZE - ) - if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + preferred = backend_cls.get_preferred_block_size(cls.default_block_size) + if preferred != cls.default_block_size: logger.info( "Setting kv cache block size to %d for %s backend.", preferred, @@ -473,6 +473,130 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: ) cache_config.block_size = preferred + if model_config.is_hybrid: + cls._align_hybrid_block_size(vllm_config, backend_cls) + + @classmethod + def _align_hybrid_block_size( + cls, + vllm_config: "VllmConfig", + backend_cls: "type[AttentionBackend]", + ) -> None: + """ + For hybrid attention/mamba models, ensure that the attention page + size is >= the mamba page size, and pad the mamba page size to match. + """ + from math import lcm + + from vllm.model_executor.models import ModelRegistry + from vllm.utils.math_utils import cdiv + from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + from vllm.v1.attention.backend import MultipleOf + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + MLAAttentionSpec, + ) + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Compute attention page size for 1 token + if model_config.use_mla: + attn_page_size_1_token = MLAAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + else: + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + + # Get kernel block alignment from the backend's supported sizes + supported_sizes = backend_cls.get_supported_kernel_block_sizes() + kernel_block_alignment_size = min( + s.base if isinstance(s, MultipleOf) else s for s in supported_sizes + ) + + # Compute mamba page size + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, + ) + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=-1, + ).page_size_bytes + + if mamba_page_size == 0: + return + + # Save user's mamba_block_size before we potentially overwrite it + mamba_block_size = cache_config.mamba_block_size + + if cache_config.mamba_cache_mode == "all": + # With prefix caching, align to mamba chunk size for kernel perf + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. + base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) + attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, use minimum block size that satisfies + # both backend alignment and mamba page size compatibility + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, + kernel_block_alignment_size * attn_page_size_1_token, + ) + + if cache_config.block_size < attn_block_size: + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size, + ) + + if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = cache_config.block_size + + # Pad mamba page size to exactly match attention page size + attn_page_size = cache_config.block_size * attn_page_size_1_token + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + return + + if ( + cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size + ): + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = ( + 100 * (attn_page_size - mamba_page_size) / mamba_page_size + ) + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", + mamba_padding_pct, + ) + @classmethod def verify_model_arch(cls, model_arch: str) -> None: """ diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index b4a48de1ca94..1d279e921fba 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -37,6 +37,7 @@ class XPUPlatform(Platform): ray_device_key: str = "GPU" dist_backend: str = "xccl" # xccl only device_control_env_var: str = "ZE_AFFINITY_MASK" + default_block_size: int = 64 @classmethod def import_kernels(cls) -> None: @@ -222,93 +223,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # ref. https://openucx.readthedocs.io/en/master/faq.html os.environ["UCX_MEMTYPE_CACHE"] = "n" - @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - """ - Ensure block_size is compatible with the attention backend. - """ - from vllm.config.cache import CacheConfig - from vllm.v1.attention.backend import AttentionBackend - - _DEFAULT_BLOCK_SIZE = 64 - - cache_config = vllm_config.cache_config - if cache_config.user_specified_block_size: - # User specified --block-size; keep it. - return - - model_config = vllm_config.model_config - if model_config is None: - cache_config.block_size = _DEFAULT_BLOCK_SIZE - return - - from vllm.config.vllm import ( - get_layers_from_vllm_config, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, - ) - from vllm.utils.math_utils import cdiv - - attn_layers = get_layers_from_vllm_config( - vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] - ) - if not attn_layers: - logger.info("Update no attn layers block size to %d", _DEFAULT_BLOCK_SIZE) - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE - return - - def get_backend_block_size( - backend_cls: type["AttentionBackend"], block_size: int - ): - if backend_cls.get_name() in ( - AttentionBackendEnum.FLASH_ATTN.name, - AttentionBackendEnum.TRITON_ATTN.name, - AttentionBackendEnum.TORCH_SDPA.name, - ): - return _DEFAULT_BLOCK_SIZE - elif backend_cls.get_name() == AttentionBackendEnum.XPU_MLA_SPARSE.name: - return 128 - else: - return block_size - - if model_config.is_hybrid: - backend_cls_list = set([i.get_attn_backend() for i in attn_layers.values()]) - block_size_list = [ - get_backend_block_size(i, cache_config.block_size) - for i in backend_cls_list - ] - new_block_size = cdiv(max(block_size_list), min(block_size_list)) * min( - block_size_list - ) - if cache_config.block_size == new_block_size: - return - logger.info("Update hybrid model block size to %d", new_block_size) - if cache_config.mamba_cache_mode == "align": - cache_config.mamba_block_size = new_block_size - if cache_config.mamba_page_size_padded is not None: - attn_page_size_1_token = ( - cache_config.mamba_page_size_padded // cache_config.block_size - ) - cache_config.mamba_page_size_padded = ( - new_block_size * attn_page_size_1_token - ) - cache_config.block_size = new_block_size - return - - first_layer = next(iter(attn_layers.values())) - backend_cls = first_layer.get_attn_backend() - - new_block_size = get_backend_block_size(backend_cls, cache_config.block_size) - if cache_config.block_size == new_block_size: - return - logger.info( - "Update %s block size to %d", backend_cls.get_name(), new_block_size - ) - cache_config.block_size = new_block_size - return - @classmethod def support_hybrid_kv_cache(cls) -> bool: return True From d06a35a593fcd930ac2e23c99ba6cb30b1206a7d Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 18 Mar 2026 21:11:36 +0000 Subject: [PATCH 04/22] Revert "enable is_act_and_mul for xpu" This reverts commit 52f18b76112fda0763d74434624c4293947088c7. Signed-off-by: Chendi Xue --- vllm/model_executor/layers/fused_moe/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index dba3028b0875..85fd1813a363 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -567,9 +567,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # for heuristic purposes, so it must be initialized first. self.quant_method: FusedMoEMethodBase = _get_quant_method() - if not self.moe_config.is_act_and_mul and not ( - current_platform.is_cuda_alike() or current_platform.is_xpu() - ): + if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): raise NotImplementedError( "is_act_and_mul=False is supported only for CUDA and ROCm for now" ) From c827335f8f260efdfb0c9ee9a24147202e33bd76 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 18:34:03 +0000 Subject: [PATCH 05/22] Fix kernel_block_alignment_size Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 94e00e2d5fa9..f6e5654831fe 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -491,12 +491,12 @@ def _align_hybrid_block_size( from vllm.model_executor.models import ModelRegistry from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE - from vllm.v1.attention.backend import MultipleOf from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, MLAAttentionSpec, ) + from vllm.v1.worker.utils import select_common_block_size cache_config = vllm_config.cache_config model_config = vllm_config.model_config @@ -524,9 +524,8 @@ def _align_hybrid_block_size( ).page_size_bytes # Get kernel block alignment from the backend's supported sizes - supported_sizes = backend_cls.get_supported_kernel_block_sizes() - kernel_block_alignment_size = min( - s.base if isinstance(s, MultipleOf) else s for s in supported_sizes + kernel_block_alignment_size = select_common_block_size( + cache_config.block_size, [backend_cls] ) # Compute mamba page size @@ -552,6 +551,7 @@ def _align_hybrid_block_size( # easily by changing the way we layout chunks in the # mamba2 kernels. base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + assert base_chunk_size is not None attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) From 2d78e873f791e717a0693a86cb913ed8e75e2810 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 20:18:33 +0000 Subject: [PATCH 06/22] Add get_preferred_block_size to fa and update for xpu Signed-off-by: Chendi Xue --- vllm/v1/attention/backends/flash_attn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f3f19f60c398..793358d6469a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -45,6 +45,7 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv, round_up from vllm.v1.attention.backend import ( @@ -92,6 +93,12 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: forward_includes_kv_cache_update: bool = False + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + if current_platform.is_xpu(): + return max(default_block_size, 64) + return super().get_preferred_block_size(default_block_size) + @staticmethod def get_name() -> str: return "FLASH_ATTN" From 2f66345ba91a5a7376c9b06df9624caacc68844e Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 21:24:42 +0000 Subject: [PATCH 07/22] remove default_block_size and fix multi_backend Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 54 ++++++++++++++++++++++++++++--------- vllm/platforms/xpu.py | 1 - 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f6e5654831fe..a19f9a515367 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -136,10 +136,6 @@ class Platform: supported_quantization: list[str] = [] - # Default block size for the KV cache on this platform. - # Backends may override via get_preferred_block_size(). - default_block_size: int = 16 - additional_env_vars: list[str] = [] _global_graph_pool: Any | None = None @@ -434,6 +430,8 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: Ensure block_size is compatible with the attention backend. For hybrid models, also aligns block_size with mamba page sizes. """ + from vllm.config.cache import CacheConfig + cache_config = vllm_config.cache_config if cache_config.user_specified_block_size: # User specified --block-size; keep it. @@ -442,7 +440,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config # model_config may be None during testing. if model_config is None: - cache_config.block_size = cls.default_block_size + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return from vllm.config.vllm import ( @@ -452,20 +450,48 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: from vllm.model_executor.layers.attention_layer_base import ( AttentionLayerBase, ) + from vllm.v1.attention.backend import AttentionBackend attn_layers = get_layers_from_vllm_config( vllm_config, AttentionLayerBase, # type: ignore[type-abstract] ) if not attn_layers: - cache_config.block_size = cls.default_block_size + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return - first_layer = next(iter(attn_layers.values())) - backend_cls = first_layer.get_attn_backend() + def get_full_attn_backend_cls() -> type[AttentionBackend]: + backend_cls_list = [ + layer.get_attn_backend() for layer in attn_layers.values() + ] + backend_cls_dict = { + backend_cls.get_name(): backend_cls for backend_cls in backend_cls_list + } + SSM_ATTN_BACKEND_NAMES = [ + "MAMBA2_ATTN", + "MAMBA1_ATTN", + "GDN_ATTN", + "LINEAR_ATTN", + "SHORT_CONV_ATTN", + ] + backend_cls_list = [ + backend_cls + for name, backend_cls in backend_cls_dict.items() + if name not in SSM_ATTN_BACKEND_NAMES + ] + if len(backend_cls_list) == 1: + return backend_cls_list[0] + else: + raise ValueError( + f"Multiple attention backends are not supported: {backend_cls_list}" + ) + + backend_cls = get_full_attn_backend_cls() with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size(cls.default_block_size) - if preferred != cls.default_block_size: + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: logger.info( "Setting kv cache block size to %d for %s backend.", preferred, @@ -488,6 +514,7 @@ def _align_hybrid_block_size( """ from math import lcm + from vllm.config.vllm import set_current_vllm_config from vllm.model_executor.models import ModelRegistry from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @@ -524,9 +551,10 @@ def _align_hybrid_block_size( ).page_size_bytes # Get kernel block alignment from the backend's supported sizes - kernel_block_alignment_size = select_common_block_size( - cache_config.block_size, [backend_cls] - ) + with set_current_vllm_config(vllm_config): + kernel_block_alignment_size = select_common_block_size( + cache_config.block_size, [backend_cls] + ) # Compute mamba page size model_cls, _ = ModelRegistry.resolve_model_cls( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 1d279e921fba..0a3681139239 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -37,7 +37,6 @@ class XPUPlatform(Platform): ray_device_key: str = "GPU" dist_backend: str = "xccl" # xccl only device_control_env_var: str = "ZE_AFFINITY_MASK" - default_block_size: int = 64 @classmethod def import_kernels(cls) -> None: From d80ab7645a91b92a500ace1fdd269cd6631c6f83 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 21:51:26 +0000 Subject: [PATCH 08/22] move ssm check to attn_backend Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 22 +++++-------------- vllm/v1/attention/backend.py | 4 ++++ vllm/v1/attention/backends/gdn_attn.py | 4 ++++ vllm/v1/attention/backends/linear_attn.py | 4 ++++ vllm/v1/attention/backends/mamba1_attn.py | 4 ++++ vllm/v1/attention/backends/mamba2_attn.py | 4 ++++ vllm/v1/attention/backends/short_conv_attn.py | 4 ++++ 7 files changed, 30 insertions(+), 16 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index a19f9a515367..a77c543eac1a 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -465,25 +465,15 @@ def get_full_attn_backend_cls() -> type[AttentionBackend]: layer.get_attn_backend() for layer in attn_layers.values() ] backend_cls_dict = { - backend_cls.get_name(): backend_cls for backend_cls in backend_cls_list + backend_cls.get_name(): backend_cls + for backend_cls in backend_cls_list + if not backend_cls.is_ssm() } - SSM_ATTN_BACKEND_NAMES = [ - "MAMBA2_ATTN", - "MAMBA1_ATTN", - "GDN_ATTN", - "LINEAR_ATTN", - "SHORT_CONV_ATTN", - ] - backend_cls_list = [ - backend_cls - for name, backend_cls in backend_cls_dict.items() - if name not in SSM_ATTN_BACKEND_NAMES - ] - if len(backend_cls_list) == 1: - return backend_cls_list[0] + if len(backend_cls_dict) == 1: + return list(backend_cls_dict.values())[0] else: raise ValueError( - f"Multiple attention backends are not supported: {backend_cls_list}" + f"Multiple attention backends are not supported: {backend_cls_dict}" ) backend_cls = get_full_attn_backend_cls() diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index cd49ea30e6f4..9001b23f3d54 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -311,6 +311,10 @@ def validate_configuration( def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return None + @classmethod + def is_ssm(cls) -> bool: + return False + class AttentionMetadata: pass diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 574cc87e7582..f65d9a4b3891 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -31,6 +31,10 @@ def get_name() -> str: def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: return GDNAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class GDNAttentionMetadata: diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index fe27e7a389ac..b2ca151986cc 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -27,6 +27,10 @@ def get_name() -> str: def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class LinearAttentionMetadata: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 8903406200ca..925fceb024f6 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -20,6 +20,10 @@ def get_name() -> str: def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: return Mamba1AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba1AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 5e8abbab565e..fa7d4bd2ec51 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -96,6 +96,10 @@ def get_name() -> str: def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba2AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index c6a8e6eeaa16..9c85ec5efb30 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -18,6 +18,10 @@ def get_name() -> str: def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class ShortConvAttentionMetadata(BaseMambaAttentionMetadata): From c7f188ad62eb04793804a7585d2e1d305ecd273d Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Tue, 24 Mar 2026 17:16:01 -0500 Subject: [PATCH 09/22] Update vllm/platforms/interface.py Co-authored-by: Matthew Bonanni Signed-off-by: Chendi.Xue --- vllm/platforms/interface.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index a77c543eac1a..21f74a665141 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -460,23 +460,16 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return - def get_full_attn_backend_cls() -> type[AttentionBackend]: - backend_cls_list = [ - layer.get_attn_backend() for layer in attn_layers.values() - ] - backend_cls_dict = { - backend_cls.get_name(): backend_cls - for backend_cls in backend_cls_list - if not backend_cls.is_ssm() - } - if len(backend_cls_dict) == 1: - return list(backend_cls_dict.values())[0] - else: - raise ValueError( - f"Multiple attention backends are not supported: {backend_cls_dict}" - ) - - backend_cls = get_full_attn_backend_cls() + backend_cls = None + for layer in attn_layers.values(): + b = layer.get_attn_backend() + if not b.is_ssm(): + backend_cls = b + break + + if backend_cls is None: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + return with set_current_vllm_config(vllm_config): preferred = backend_cls.get_preferred_block_size( CacheConfig.DEFAULT_BLOCK_SIZE From 347bbad186060463d475e9d1e60553382573d6ea Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 22:50:35 +0000 Subject: [PATCH 10/22] clean up codes Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 21f74a665141..f9e52daf36f0 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -450,7 +450,6 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: from vllm.model_executor.layers.attention_layer_base import ( AttentionLayerBase, ) - from vllm.v1.attention.backend import AttentionBackend attn_layers = get_layers_from_vllm_config( vllm_config, From 128ada2479963dfe078112de37cf2bed03fca77f Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 24 Mar 2026 22:55:57 +0000 Subject: [PATCH 11/22] update the way to get kernel_block_alignment_size Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f9e52daf36f0..a655a58c9277 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -500,12 +500,12 @@ def _align_hybrid_block_size( from vllm.model_executor.models import ModelRegistry from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + from vllm.v1.attention.backend import MultipleOf from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, MLAAttentionSpec, ) - from vllm.v1.worker.utils import select_common_block_size cache_config = vllm_config.cache_config model_config = vllm_config.model_config @@ -532,12 +532,6 @@ def _align_hybrid_block_size( dtype=kv_cache_dtype, ).page_size_bytes - # Get kernel block alignment from the backend's supported sizes - with set_current_vllm_config(vllm_config): - kernel_block_alignment_size = select_common_block_size( - cache_config.block_size, [backend_cls] - ) - # Compute mamba page size model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, @@ -555,6 +549,16 @@ def _align_hybrid_block_size( # Save user's mamba_block_size before we potentially overwrite it mamba_block_size = cache_config.mamba_block_size + # Get kernel block alignment from the backend's supported sizes + with set_current_vllm_config(vllm_config): + kernel_block_alignment_size = max( + min( + s.base if isinstance(s, MultipleOf) else s + for s in backend_cls.get_supported_kernel_block_sizes() + ), + cache_config.block_size, + ) + if cache_config.mamba_cache_mode == "all": # With prefix caching, align to mamba chunk size for kernel perf # TODO(tdoublep): this constraint can be relaxed fairly From d61487495da17a8a5b9687608b105ba8010b7344 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 25 Mar 2026 23:53:20 +0000 Subject: [PATCH 12/22] Fix pytest error Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index a655a58c9277..bd560b1783cb 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -434,8 +434,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config if cache_config.user_specified_block_size: - # User specified --block-size; keep it. - return + assert cache_config.block_size, "block_size must be positive." model_config = vllm_config.model_config # model_config may be None during testing. From cb0596fcb2f21574b033a51f08cd75a5b0d16395 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 02:55:33 +0000 Subject: [PATCH 13/22] Fix pytest Signed-off-by: Chendi Xue --- tests/v1/worker/test_gpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 93c5435e817b..fe27b2f27d78 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -846,6 +846,7 @@ def test_hybrid_attention_mamba_tensor_shapes(): # suppress var not used error assert fwd_context is not None vllm_ctx = vllm_config.compilation_config.static_forward_context + current_platform.update_block_size_for_backend(vllm_config) runner = GPUModelRunner(vllm_config, DEVICE) kv_cache_spec = runner.get_kv_cache_spec() @@ -1276,6 +1277,7 @@ def test_cudagraph_sizes_capped_for_mamba_cache(): assert fwd_context is not None runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes From 7362163b2b3b265380aa6487cbbdf18eaff24449 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 05:35:01 +0000 Subject: [PATCH 14/22] Fix mamba_attn_chunk_size ignore issue due to pre mamba_block_size setting done in model_config Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index bd560b1783cb..aac17c173fba 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -563,7 +563,7 @@ def _align_hybrid_block_size( # TODO(tdoublep): this constraint can be relaxed fairly # easily by changing the way we layout chunks in the # mamba2 kernels. - base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + base_chunk_size = model_config.get_mamba_chunk_size() or mamba_block_size assert base_chunk_size is not None attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) From a234a85c0b11e6962d743f2be15d2fa00668c0af Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 06:24:02 +0000 Subject: [PATCH 15/22] skip return default for user_specified_block_size Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index aac17c173fba..632f5dbc6323 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -438,7 +438,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config # model_config may be None during testing. - if model_config is None: + if not cache_config.user_specified_block_size and model_config is None: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return @@ -454,7 +454,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: vllm_config, AttentionLayerBase, # type: ignore[type-abstract] ) - if not attn_layers: + if not cache_config.user_specified_block_size and not attn_layers: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return @@ -465,20 +465,24 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: backend_cls = b break - if backend_cls is None: + if not cache_config.user_specified_block_size and backend_cls is None: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return - with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size( - CacheConfig.DEFAULT_BLOCK_SIZE - ) - if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: - logger.info( - "Setting kv cache block size to %d for %s backend.", - preferred, - backend_cls.get_name(), - ) - cache_config.block_size = preferred + + assert backend_cls is not None + + if not cache_config.user_specified_block_size: + with set_current_vllm_config(vllm_config): + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + logger.info( + "Setting kv cache block size to %d for %s backend.", + preferred, + backend_cls.get_name(), + ) + cache_config.block_size = preferred if model_config.is_hybrid: cls._align_hybrid_block_size(vllm_config, backend_cls) From 67f1195e8b70ba969fb672dcc616b721021b2667 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 15:23:15 +0000 Subject: [PATCH 16/22] fix last fix Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 632f5dbc6323..aa540465cf89 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -438,8 +438,9 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config # model_config may be None during testing. - if not cache_config.user_specified_block_size and model_config is None: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + if model_config is None: + if not cache_config.user_specified_block_size: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return from vllm.config.vllm import ( @@ -454,8 +455,9 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: vllm_config, AttentionLayerBase, # type: ignore[type-abstract] ) - if not cache_config.user_specified_block_size and not attn_layers: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + if not attn_layers: + if not cache_config.user_specified_block_size: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return backend_cls = None @@ -465,8 +467,9 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: backend_cls = b break - if not cache_config.user_specified_block_size and backend_cls is None: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + if backend_cls is None: + if not cache_config.user_specified_block_size: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return assert backend_cls is not None From 04249bb477c23f9eeb00a46ea9a853e8298ce257 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 26 Mar 2026 16:03:48 -0400 Subject: [PATCH 17/22] Refactor update_block_size_for_backend for clarity - Extract _find_non_ssm_backend() helper to separate the backend lookup concern from block size logic - Reduce user_specified_block_size checks from 5 to 1 by structuring the function into two clear phases: 1. Pick block size from backend preference (skipped if user set --block-size) 2. Align for hybrid models (always runs, may increase block_size) - Replace silent no-op with assert for the hybrid model invariant that at least one non-SSM attention backend must exist - Remove redundant `assert backend_cls is not None` after None-guard return Co-Authored-By: Claude Opus 4.6 Signed-off-by: Matthew Bonanni --- vllm/platforms/interface.py | 84 ++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 43 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 961da7b82f43..34c74fd7670b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -425,28 +425,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: pass @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - """ - Ensure block_size is compatible with the attention backend. - For hybrid models, also aligns block_size with mamba page sizes. - """ - from vllm.config.cache import CacheConfig - - cache_config = vllm_config.cache_config - if cache_config.user_specified_block_size: - assert cache_config.block_size, "block_size must be positive." - - model_config = vllm_config.model_config - # model_config may be None during testing. - if model_config is None: - if not cache_config.user_specified_block_size: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE - return - - from vllm.config.vllm import ( - get_layers_from_vllm_config, - set_current_vllm_config, - ) + def _find_non_ssm_backend( + cls, vllm_config: "VllmConfig" + ) -> "type[AttentionBackend] | None": + """Find the first non-SSM attention backend from model layers.""" + from vllm.config.vllm import get_layers_from_vllm_config from vllm.model_executor.layers.attention_layer_base import ( AttentionLayerBase, ) @@ -455,39 +438,54 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: vllm_config, AttentionLayerBase, # type: ignore[type-abstract] ) - if not attn_layers: - if not cache_config.user_specified_block_size: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE - return - - backend_cls = None for layer in attn_layers.values(): b = layer.get_attn_backend() if not b.is_ssm(): - backend_cls = b - break + return b + return None + + @classmethod + def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: + """ + Ensure block_size is compatible with the attention backend. + For hybrid models, also aligns block_size with mamba page sizes. + """ + from vllm.config.cache import CacheConfig + from vllm.config.vllm import set_current_vllm_config - if backend_cls is None: + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + + # model_config may be None during testing. + if not model_config: if not cache_config.user_specified_block_size: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return - assert backend_cls is not None + backend_cls = cls._find_non_ssm_backend(vllm_config) + # Phase 1: Pick block size from backend (skip if user set --block-size) if not cache_config.user_specified_block_size: - with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size( - CacheConfig.DEFAULT_BLOCK_SIZE - ) - if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: - logger.info( - "Setting kv cache block size to %d for %s backend.", - preferred, - backend_cls.get_name(), - ) - cache_config.block_size = preferred + if backend_cls: + with set_current_vllm_config(vllm_config): + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + logger.info( + "Setting kv cache block size to %d for %s backend.", + preferred, + backend_cls.get_name(), + ) + cache_config.block_size = preferred + else: + cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + # Phase 2: Align for hybrid models (always runs, may increase block_size) if model_config.is_hybrid: + assert backend_cls, ( + "Hybrid model must have at least one non-SSM attention backend" + ) cls._align_hybrid_block_size(vllm_config, backend_cls) @classmethod From 26d5bacbd854602cdd4b0da1c2000f6d0e59a1cd Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 26 Mar 2026 16:07:57 -0400 Subject: [PATCH 18/22] Comments Signed-off-by: Matthew Bonanni --- vllm/platforms/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 34c74fd7670b..737b3d300fcd 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -464,7 +464,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: backend_cls = cls._find_non_ssm_backend(vllm_config) - # Phase 1: Pick block size from backend (skip if user set --block-size) + # 1. Pick block size from backend (skip if user set --block-size) if not cache_config.user_specified_block_size: if backend_cls: with set_current_vllm_config(vllm_config): @@ -481,7 +481,7 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: else: cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE - # Phase 2: Align for hybrid models (always runs, may increase block_size) + # 2. Align for hybrid models (may change block_size even if user specified it) if model_config.is_hybrid: assert backend_cls, ( "Hybrid model must have at least one non-SSM attention backend" From 8f1f8d8bbfd505a6e490d3f0c83063e3b844e6f3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 26 Mar 2026 16:24:31 -0400 Subject: [PATCH 19/22] Simplify Signed-off-by: Matthew Bonanni --- vllm/platforms/interface.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 737b3d300fcd..47a9ea60436b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -458,30 +458,25 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: # model_config may be None during testing. if not model_config: - if not cache_config.user_specified_block_size: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE return backend_cls = cls._find_non_ssm_backend(vllm_config) - # 1. Pick block size from backend (skip if user set --block-size) - if not cache_config.user_specified_block_size: - if backend_cls: - with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size( - CacheConfig.DEFAULT_BLOCK_SIZE - ) - if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: - logger.info( - "Setting kv cache block size to %d for %s backend.", - preferred, - backend_cls.get_name(), - ) - cache_config.block_size = preferred - else: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE - - # 2. Align for hybrid models (may change block_size even if user specified it) + # Phase 1: Pick block size from backend (skip if user set --block-size) + if not cache_config.user_specified_block_size and backend_cls: + with set_current_vllm_config(vllm_config): + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + logger.info( + "Setting kv cache block size to %d for %s backend.", + preferred, + backend_cls.get_name(), + ) + cache_config.block_size = preferred + + # Phase 2: Align for hybrid models (always runs, may increase block_size) if model_config.is_hybrid: assert backend_cls, ( "Hybrid model must have at least one non-SSM attention backend" From 6a9ba626ea2fb125f14220f0926145049a8c39fc Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 21:34:35 +0000 Subject: [PATCH 20/22] Add user_specified_mamba_block_size Signed-off-by: Chendi Xue --- vllm/config/cache.py | 5 +++++ vllm/model_executor/models/config.py | 5 +++++ vllm/platforms/interface.py | 10 +++++++--- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 8a9eb484d58a..36dde1ea7c5b 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -38,6 +38,8 @@ class CacheConfig: Accepts None (meaning "use default"). After construction, always int.""" user_specified_block_size: bool = field(default=False, init=False) """Whether block_size was explicitly provided. Derived automatically.""" + user_specified_mamba_block_size: bool = field(default=False, init=False) + """Whether mamba_block_size was explicitly provided. Derived automatically.""" gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory @@ -172,6 +174,7 @@ def compute_hash(self) -> str: "cpu_kvcache_space_bytes", "mamba_page_size_padded", "user_specified_block_size", + "user_specified_mamba_block_size", "_block_size_resolved", # Post-init/derived counters "num_gpu_blocks", @@ -204,6 +207,8 @@ def _apply_block_size_default(self) -> "CacheConfig": object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE) else: object.__setattr__(self, "user_specified_block_size", True) + if self.mamba_block_size is not None: + object.__setattr__(self, "user_specified_mamba_block_size", True) return self @field_validator("calculate_kv_scales", mode="after") diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5dc826dea25f..5682f3bb926e 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -113,6 +113,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # Disable calculate_kv_scales for hybrid models: uninitialized # recurrent state corrupts scales during the calibration pass. # See issue: https://github.com/vllm-project/vllm/issues/37554 + + # Save the user input before it gets modified by MambaModelConfig + cache_config.user_specified_mamba_block_size = ( + cache_config.mamba_block_size is not None + ) if cache_config.calculate_kv_scales: logger.warning( "Disabling calculate_kv_scales for hybrid model '%s'. " diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 47a9ea60436b..4a2bf5229218 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -545,8 +545,12 @@ def _align_hybrid_block_size( if mamba_page_size == 0: return - # Save user's mamba_block_size before we potentially overwrite it - mamba_block_size = cache_config.mamba_block_size + # mamba_block_size here should either be user specified value or None + mamba_block_size = ( + cache_config.mamba_block_size + if cache_config.user_specified_mamba_block_size + else None + ) # Get kernel block alignment from the backend's supported sizes with set_current_vllm_config(vllm_config): @@ -563,7 +567,7 @@ def _align_hybrid_block_size( # TODO(tdoublep): this constraint can be relaxed fairly # easily by changing the way we layout chunks in the # mamba2 kernels. - base_chunk_size = model_config.get_mamba_chunk_size() or mamba_block_size + base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() assert base_chunk_size is not None attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) From e07c3bd26272803217b29c5de13f9d64cfa0fded Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 26 Mar 2026 21:49:04 +0000 Subject: [PATCH 21/22] check backend_cls for non_attn_layer case Signed-off-by: Chendi Xue --- vllm/platforms/interface.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4a2bf5229218..240eed64a1f3 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -461,9 +461,11 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: return backend_cls = cls._find_non_ssm_backend(vllm_config) + if backend_cls is None: + return # Phase 1: Pick block size from backend (skip if user set --block-size) - if not cache_config.user_specified_block_size and backend_cls: + if not cache_config.user_specified_block_size: with set_current_vllm_config(vllm_config): preferred = backend_cls.get_preferred_block_size( CacheConfig.DEFAULT_BLOCK_SIZE @@ -478,9 +480,6 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: # Phase 2: Align for hybrid models (always runs, may increase block_size) if model_config.is_hybrid: - assert backend_cls, ( - "Hybrid model must have at least one non-SSM attention backend" - ) cls._align_hybrid_block_size(vllm_config, backend_cls) @classmethod From 6ca84a4c0267b42844cf6ac4390e639a5b8e9c21 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 27 Mar 2026 17:53:46 +0000 Subject: [PATCH 22/22] fix comments Signed-off-by: Chendi Xue --- tests/v1/worker/test_gpu_model_runner.py | 2 +- vllm/model_executor/models/config.py | 4 ---- vllm/platforms/interface.py | 3 ++- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index fe27b2f27d78..acf25acdd38e 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -846,9 +846,9 @@ def test_hybrid_attention_mamba_tensor_shapes(): # suppress var not used error assert fwd_context is not None vllm_ctx = vllm_config.compilation_config.static_forward_context - current_platform.update_block_size_for_backend(vllm_config) runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5682f3bb926e..03b147e5c257 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -114,10 +114,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # recurrent state corrupts scales during the calibration pass. # See issue: https://github.com/vllm-project/vllm/issues/37554 - # Save the user input before it gets modified by MambaModelConfig - cache_config.user_specified_mamba_block_size = ( - cache_config.mamba_block_size is not None - ) if cache_config.calculate_kv_scales: logger.warning( "Disabling calculate_kv_scales for hybrid model '%s'. " diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 240eed64a1f3..fae37442ec57 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -478,7 +478,8 @@ def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: ) cache_config.block_size = preferred - # Phase 2: Align for hybrid models (always runs, may increase block_size) + # Phase 2: Align block/mamba sizes for hybrid models + # (may override user settings). if model_config.is_hybrid: cls._align_hybrid_block_size(vllm_config, backend_cls)