From d82622679b4b6742d7b4e95e09f7b36041844ad6 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Fri, 17 Apr 2026 16:27:35 +0100 Subject: [PATCH 1/3] Enable mypy for `vllm/model_executor/layers/` Signed-off-by: Martin Hickey --- tools/pre_commit/mypy.py | 1 - vllm/model_executor/layers/activation.py | 27 ++++---- .../layers/attention/attention.py | 20 ++++-- .../attention/chunked_local_attention.py | 2 +- .../layers/attention/cross_attention.py | 13 ++-- .../attention/encoder_only_attention.py | 4 +- .../layers/attention/mla_attention.py | 26 +++++--- .../layers/fused_moe/all2all_utils.py | 11 ++-- .../model_executor/layers/fused_moe/config.py | 3 +- .../experts/batched_deep_gemm_moe.py | 6 +- .../layers/fused_moe/oracle/mxfp4.py | 13 ++-- .../flashinfer_nvlink_one_sided.py | 24 +++++-- .../flashinfer_nvlink_two_sided.py | 25 ++++--- .../fused_moe/prepare_finalize/naive_dp_ep.py | 4 ++ .../fused_moe/runner/default_moe_runner.py | 4 +- vllm/model_executor/layers/kda.py | 43 +++++++----- vllm/model_executor/layers/layernorm.py | 2 +- vllm/model_executor/layers/mamba/abstract.py | 3 +- .../layers/mamba/gdn_linear_attn.py | 66 ++++++++++--------- .../layers/mamba/linear_attn.py | 9 +-- .../layers/mamba/mamba_mixer.py | 13 ++-- .../layers/mamba/mamba_mixer2.py | 21 ++++-- .../model_executor/layers/mamba/short_conv.py | 9 +-- .../layers/pooler/seqwise/poolers.py | 1 + .../layers/pooler/tokwise/poolers.py | 1 + .../layers/quantization/fp_quant.py | 8 ++- .../layers/quantization/quark/quark.py | 2 +- .../layers/sparse_attn_indexer.py | 26 ++++---- 28 files changed, 241 insertions(+), 146 deletions(-) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 41c05efd201b..7c7b0ada60d6 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -29,7 +29,6 @@ "tests", # v0 related "vllm/lora", - "vllm/model_executor/layers", ] # TODO(woosuk): Include the code from Megatron and HuggingFace. diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 26a771cb7500..e2b70b771a12 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -666,16 +666,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): "gelu": lambda: GELU(), "gelu_fast": lambda: FastGELU(), "gelu_new": lambda: NewGELU(), - "gelu_pytorch_tanh": lambda: ( - # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile - logger.warning_once( - "[ROCm] PyTorch's native GELU with tanh approximation is unstable. " - "Falling back to GELU(approximate='none')." - ), - nn.GELU(approximate="none"), - )[1] - if current_platform.is_rocm() - else nn.GELU(approximate="tanh"), + "gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(), "relu": lambda: nn.ReLU(), "relu2": lambda: ReLUSquaredActivation(), "silu": lambda: nn.SiLU(), @@ -687,6 +678,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): ) +def _get_gelu_pytorch_tanh() -> nn.Module: + """Get PyTorch GELU with tanh approximation, with ROCm fallback.""" + if current_platform.is_rocm(): + # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile + logger.warning_once( + "[ROCm] PyTorch's native GELU with tanh approximation is unstable. " + "Falling back to GELU(approximate='none')." + ) + return nn.GELU(approximate="none") + return nn.GELU(approximate="tanh") + + def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() @@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module: return _ACTIVATION_REGISTRY[act_fn_name] -_ACTIVATION_AND_MUL_REGISTRY = LazyDict( +_ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict( { "gelu": lambda: GeluAndMul(), "silu": lambda: SiluAndMul(), "geglu": lambda: GeluAndMul(), - "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), + "swigluoai": lambda: SwigluOAIAndMul(), } ) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index b2a2295ce461..0f3c50684552 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -33,6 +33,7 @@ ) from vllm.v1.attention.backend import ( AttentionBackend, + AttentionMetadata, AttentionType, ) from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -209,6 +210,7 @@ def __init__( `self.kv_cache`. """ super().__init__() + sliding_window: int | None if per_layer_sliding_window is not None: # per-layer sliding window sliding_window = per_layer_sliding_window @@ -335,7 +337,7 @@ def __init__( cache_config.enable_prefix_caching = False impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls( + self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass num_heads, head_size, scale, @@ -576,7 +578,7 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: # Block size may get updated after model loading, refresh it block_size = vllm_config.cache_config.block_size # Should not be called for enc-dec or encoder-only attention. @@ -680,9 +682,15 @@ def get_attention_context( extracted from the forward context. """ forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[layer_name] + attn_metadata_raw = forward_context.attn_metadata + attn_metadata: AttentionMetadata + if isinstance(attn_metadata_raw, dict): + attn_metadata = attn_metadata_raw[layer_name] + else: + # list[dict[str, AttentionMetadata]]: used in speculative decoding + # where [0] is the base-model (non-speculative) metadata dict. + assert isinstance(attn_metadata_raw, list) + attn_metadata = attn_metadata_raw[0][layer_name] attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] kv_cache = attn_layer.kv_cache slot_mapping = forward_context.slot_mapping @@ -708,7 +716,7 @@ def unified_kv_cache_update( assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( f"{attn_layer.impl.__class__.__name__} does not support kv cache update" ) - attn_layer.impl.do_kv_cache_update( + attn_layer.impl.do_kv_cache_update( # type: ignore[attr-defined] attn_layer, key, value, diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index 136574d97529..cb595438adef 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -29,7 +29,7 @@ @functools.lru_cache def create_chunked_local_attention_backend( - underlying_attn_backend: AttentionBackend, + underlying_attn_backend: type[AttentionBackend], attention_chunk_size: int, ) -> type[AttentionBackend]: prefix = f"ChunkedLocalAttention_{attention_chunk_size}_" diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 61699832a624..312f906abacc 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -72,7 +72,7 @@ def _get_cross_slot_mapping( @functools.lru_cache def create_cross_attention_backend( - underlying_attn_backend: AttentionBackend, + underlying_attn_backend: type[AttentionBackend], ) -> type[AttentionBackend]: prefix = "CrossAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() @@ -87,6 +87,7 @@ def build( ) -> AttentionMetadata: new_metadata = copy(common_attn_metadata) new_metadata.causal = False + assert new_metadata.encoder_seq_lens_cpu is not None max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max()) new_metadata.max_seq_len = max_encoder_len # Any computed tokens indicated decode step>1 (no chunked prefill) @@ -118,7 +119,7 @@ def build( self.device, ) attn_metadata = super().build(common_prefix_len, new_metadata, fast_build) - attn_metadata.slot_mapping = slot_mapping + attn_metadata.slot_mapping = slot_mapping # type: ignore[attr-defined] return attn_metadata # NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by @@ -144,8 +145,12 @@ def forward( and key is not None and value is not None ): - self.do_kv_cache_update( - layer, key, value, kv_cache, attn_metadata.slot_mapping + self.do_kv_cache_update( # type: ignore[attr-defined] + layer, + key, + value, + kv_cache, + attn_metadata.slot_mapping, # type: ignore[attr-defined] ) return super().forward( diff --git a/vllm/model_executor/layers/attention/encoder_only_attention.py b/vllm/model_executor/layers/attention/encoder_only_attention.py index 0897ee45b84d..5805fe2ae1ca 100644 --- a/vllm/model_executor/layers/attention/encoder_only_attention.py +++ b/vllm/model_executor/layers/attention/encoder_only_attention.py @@ -21,7 +21,7 @@ @functools.lru_cache def create_encoder_only_attention_backend( - underlying_attn_backend: AttentionBackend, + underlying_attn_backend: type[AttentionBackend], ) -> type[AttentionBackend]: prefix = "EncoderOnlyAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() @@ -93,6 +93,6 @@ def __init__( **kwargs, ) - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: # Does not need KV cache return None diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index cbbf5f3c3ca3..02f322116075 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -389,7 +389,7 @@ def __init__( cache_config.enable_prefix_caching = False impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) - self.impl = impl_cls( + self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass num_heads=self.num_heads, head_size=self.head_size, scale=self.scale, @@ -485,16 +485,22 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] + attn_metadata_raw = forward_context.attn_metadata + attn_metadata: MLACommonMetadata + if isinstance(attn_metadata_raw, dict): + attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment] + else: + # list[dict[str, AttentionMetadata]]: used in speculative decoding + # where [0] is the base-model (non-speculative) metadata dict. + assert isinstance(attn_metadata_raw, list) + attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment] self_kv_cache = self.kv_cache slot_mapping = forward_context.slot_mapping assert isinstance(slot_mapping, dict), ( f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " ) - self.impl.do_kv_cache_update( + self.impl.do_kv_cache_update( # type: ignore[attr-defined] kv_c_normed, k_pe, self_kv_cache, @@ -612,7 +618,7 @@ def forward_impl( num_mha_tokens = q.size(0) - num_mqa_tokens if num_mha_tokens > 0: - self.impl.forward_mha( + self.impl.forward_mha( # type: ignore[attr-defined] q[num_mqa_tokens:], k_c_normed[num_mqa_tokens:], k_pe[num_mqa_tokens:], @@ -695,7 +701,7 @@ def forward_impl( # call decode attn if not is_sparse_impl: assert attn_metadata.decode is not None - attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) + attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) # type: ignore[attr-defined] # correct dcp attn_out with lse. if self.impl.dcp_world_size > 1: @@ -1053,9 +1059,9 @@ class QueryLenSupport(Enum): "AITER_MLA backends use aiter kernels instead." ) elif current_platform.is_xpu(): - from vllm._xpu_ops import xpu_ops as ops + from vllm import _xpu_ops - flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef] + flash_attn_varlen_func = _xpu_ops.xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment] def dynamic_per_batched_tensor_quant( @@ -1988,7 +1994,7 @@ def build( assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) self._build_fi_prefill_wrappers(attn_metadata.prefill) - return attn_metadata + return attn_metadata # type: ignore[return-value] def reorg_kvcache( diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 62b2602928f1..e8034113983e 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -117,17 +117,20 @@ def maybe_make_prepare_finalize( "Detected DP deployment with no --enable-expert-parallel. " "Falling back to AllGather+ReduceScatter dispatch/combine." ) + device_communicator = get_ep_group().device_communicator + assert device_communicator is not None + assert device_communicator.all2all_manager is not None return make_moe_prepare_and_finalize_naive_dp_ep( is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, - num_dispatchers=( - get_ep_group().device_communicator.all2all_manager.world_size - ), + num_dispatchers=(device_communicator.all2all_manager.world_size), use_monolithic=use_monolithic, ) else: return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic) - all2all_manager = get_ep_group().device_communicator.all2all_manager + device_communicator = get_ep_group().device_communicator + assert device_communicator is not None + all2all_manager = device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: FusedMoEPrepareAndFinalize | None = None diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 1f077ab80be5..04a05069339b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,6 +7,7 @@ import torch from vllm.config import ParallelConfig, SchedulerConfig +from vllm.config.kernel import MoEBackend from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -1192,7 +1193,7 @@ class FusedMoEConfig: # Defaults to intermediate_size_per_partition if not specified. intermediate_size_per_partition_unpadded: int | None = None - moe_backend: str = "auto" + moe_backend: MoEBackend = "auto" max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP has_bias: bool = False is_act_and_mul: bool = True diff --git a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py index 2cb0bd7649f5..fad39b3e9d4a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/batched_deep_gemm_moe.py @@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant( DeepGemmQuantScaleFMT.UE8M0, ] - cuda_arch = current_platform.get_device_capability( - device_id=y.device.index - ).to_int() + device_capability = current_platform.get_device_capability(device_id=y.device.index) + assert device_capability is not None + cuda_arch = device_capability.to_int() if current_platform.is_cuda() and cuda_arch >= 80: torch.ops._C.persistent_masked_m_silu_mul_quant( diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 917d474fc9bd..587954d5267f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs +from vllm.config.kernel import MoEBackend from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoEConfig, @@ -146,7 +147,7 @@ def backend_to_kernel_cls( raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}") -def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: +def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: """Map user's moe_backend string to Mxfp4MoeBackend.""" mapping: dict[str, Mxfp4MoeBackend] = { "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, @@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend( Select the primary MXFP4 MoE backend. Note: Shape-specific fallbacks may still occur at runtime. """ - triton_kernels_supported = has_triton_kernels() and ( - 9, - 0, - ) <= current_platform.get_device_capability() < (11, 0) + device_capability = current_platform.get_device_capability() + triton_kernels_supported = ( + has_triton_kernels() + and device_capability is not None + and (9, 0) <= device_capability < (11, 0) + ) # LoRA: separate experts backend path if config.is_lora_enabled: diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index bdde3da6b3a3..a04ff3b8b68f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -4,6 +4,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_ep_group +from vllm.distributed.device_communicators.base_device_communicator import ( + All2AllManagerBase, +) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input @@ -11,12 +14,16 @@ def get_local_sizes(): - return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + return dp_metadata.get_chunk_sizes_across_dp_rank() class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """FlashInfer implementation using the Moe AlltoAll kernel.""" + all2all_manager: All2AllManagerBase + def __init__( self, max_num_tokens: int, @@ -32,8 +39,12 @@ def __init__( self.hidden_size = hidden_size self.num_dispatchers_ = num_dispatchers - self.all2all_manager = get_ep_group().device_communicator.all2all_manager - self.all2all_manager.initialize( + device_communicator = get_ep_group().device_communicator + assert device_communicator is not None + all2all_manager = device_communicator.all2all_manager + assert all2all_manager is not None + self.all2all_manager = all2all_manager + self.all2all_manager.initialize( # type: ignore[attr-defined] max_num_tokens=self.max_num_tokens, top_k=self.top_k, num_experts=self.num_experts, @@ -97,7 +108,8 @@ def prepare( payloads.append(topk_ids) payloads.append(topk_weights) - recv_payloads = self.all2all_manager.moe_alltoall.dispatch( + assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined] + recv_payloads = self.all2all_manager.moe_alltoall.dispatch( # type: ignore[attr-defined] token_selected_experts=topk_ids, input_payloads=payloads, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, @@ -131,7 +143,7 @@ def finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - assert self.all2all_manager.moe_alltoall is not None + assert self.all2all_manager.moe_alltoall is not None # type: ignore[attr-defined] ep_size = self.all2all_manager.world_size hidden_size = fused_expert_output.shape[-1] @@ -139,7 +151,7 @@ def finalize( ep_size, self.runtime_max_tokens_per_rank, hidden_size ) - combined_output = self.all2all_manager.moe_alltoall.combine( + combined_output = self.all2all_manager.moe_alltoall.combine( # type: ignore[attr-defined] payload=fused_expert_output, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py index be63bd4e3f61..47fe293d511e 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py @@ -15,19 +15,26 @@ def get_local_sizes(): - return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + return dp_metadata.get_chunk_sizes_across_dp_rank() class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """Base class for FlashInfer MoE prepare and finalize operations.""" + all2all_manager: All2AllManagerBase + def __init__( self, num_dispatchers: int = 1, ): super().__init__() self.num_dispatchers_ = num_dispatchers - self.all2all_manager = get_ep_group().device_communicator.all2all_manager + device_communicator = get_ep_group().device_communicator + assert device_communicator is not None + assert device_communicator.all2all_manager is not None + self.all2all_manager = device_communicator.all2all_manager @property def activation_format(self) -> mk.FusedMoEActivationFormat: @@ -129,7 +136,7 @@ def flashinfer_alltoall_dispatch( ): from flashinfer.comm.trtllm_alltoall import MnnvlMoe - assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined] "FlashInfer AllToAll workspace not available" ) @@ -144,7 +151,7 @@ def flashinfer_alltoall_dispatch( topk_ids, topk_weights, None, - all2all_manager.prepare_workspace_tensor, + all2all_manager.prepare_workspace_tensor, # type: ignore[attr-defined] max_num_token, ep_rank, ep_size, @@ -172,7 +179,7 @@ def flashinfer_alltoall_dispatch( x = MnnvlMoe.mnnvl_moe_alltoallv( x, alltoall_info, - all2all_manager.workspace_tensor, + all2all_manager.workspace_tensor, # type: ignore[attr-defined] ep_rank, ep_size, ) @@ -180,7 +187,7 @@ def flashinfer_alltoall_dispatch( x_sf = MnnvlMoe.mnnvl_moe_alltoallv( x_sf, alltoall_info, - all2all_manager.workspace_tensor, + all2all_manager.workspace_tensor, # type: ignore[attr-defined] ep_rank, ep_size, ) @@ -196,7 +203,7 @@ def flashinfer_alltoall_dispatch( x = MnnvlMoe.mnnvl_moe_alltoallv( x, alltoall_info, - all2all_manager.workspace_tensor, + all2all_manager.workspace_tensor, # type: ignore[attr-defined] ep_rank, ep_size, ) @@ -212,13 +219,13 @@ def flashinfer_alltoall_combine( ): from flashinfer.comm.trtllm_alltoall import MnnvlMoe - assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( # type: ignore[attr-defined] "FlashInfer AllToAll workspace not available" ) return MnnvlMoe.mnnvl_moe_alltoallv_combine( output, alltoall_info, - all2all_manager.workspace_tensor, + all2all_manager.workspace_tensor, # type: ignore[attr-defined] ep_rank=all2all_manager.rank, ep_size=all2all_manager.world_size, top_k=top_k, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index 6dc9f6958048..2b21e2db9f68 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -132,9 +132,11 @@ def prepare( ) if scales is None: + assert len(res) == 3 a1q, topk_weights, topk_ids = res a1q_scale = None else: + assert len(res) == 4 a1q, topk_weights, topk_ids, scales = res a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) @@ -217,9 +219,11 @@ def prepare( ) if scales is None: + assert len(res) == 2 a1q, router_logits = res a1q_scale = None else: + assert len(res) == 3 a1q, router_logits, scales = res a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index 85c6563c0849..62689289ed28 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -58,11 +58,13 @@ def _maybe_dispatch( # NOTE: this will be removed once all kernels are migrated into the # MoEKernel framework. if self.do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch_router_logits( + res = get_ep_group().dispatch_router_logits( hidden_states, router_logits, self.moe_config.is_sequence_parallel, ) + assert len(res) == 2 + hidden_states, router_logits = res # NOTE: Similar with DP, PCP also needs dispatch and combine. For # simplicity, AgRsAll2All was added separately for PCP here. Maybe diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index b09f980c7e60..70c67f33f0ac 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -16,7 +16,6 @@ from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.utils.torch_utils import direct_register_custom_op -from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .fla.ops.kda import ( @@ -123,7 +122,7 @@ def __init__( self.cache_config = cache_config if model_config is None: raise ValueError("model_config must be provided") - kda_config = model_config.linear_attn_config + kda_config = model_config.linear_attn_config # type: ignore[attr-defined] self.head_dim = kda_config["head_dim"] self.num_heads = kda_config["num_heads"] self.layer_idx = layer_idx @@ -297,19 +296,21 @@ def _forward( core_attn_out: torch.Tensor, ) -> None: forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata + attn_metadata_raw = forward_context.attn_metadata - if attn_metadata is None: + if attn_metadata_raw is None: # # V1 profile run return - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - num_actual_tokens = attn_metadata.num_actual_tokens + assert isinstance(attn_metadata_raw, dict) + attn_metadata_narrowed = attn_metadata_raw[self.prefix] + assert isinstance(attn_metadata_narrowed, GDNAttentionMetadata) + has_initial_state = attn_metadata_narrowed.has_initial_state + non_spec_query_start_loc = attn_metadata_narrowed.non_spec_query_start_loc + non_spec_state_indices_tensor = ( + attn_metadata_narrowed.non_spec_state_indices_tensor + ) # noqa: E501 + num_actual_tokens = attn_metadata_narrowed.num_actual_tokens constant_caches = self.kv_cache q_proj_states = q_proj_states[:num_actual_tokens] @@ -335,7 +336,7 @@ def _forward( v_conv_weights = self.v_conv1d.weight.view( self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2) ) - if attn_metadata.num_prefills > 0: + if attn_metadata_narrowed.num_prefills > 0: q_proj_states = q_proj_states.transpose(0, 1) k_proj_states = k_proj_states.transpose(0, 1) v_proj_states = v_proj_states.transpose(0, 1) @@ -348,7 +349,7 @@ def _forward( has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, - metadata=attn_metadata, + metadata=attn_metadata_narrowed, ).transpose(0, 1) k = causal_conv1d_fn( k_proj_states, @@ -359,7 +360,7 @@ def _forward( has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, - metadata=attn_metadata, + metadata=attn_metadata_narrowed, ).transpose(0, 1) v = causal_conv1d_fn( v_proj_states, @@ -370,11 +371,12 @@ def _forward( has_initial_state=has_initial_state, cache_indices=non_spec_state_indices_tensor, query_start_loc=non_spec_query_start_loc, - metadata=attn_metadata, + metadata=attn_metadata_narrowed, ).transpose(0, 1) else: + assert non_spec_state_indices_tensor is not None decode_conv_indices = non_spec_state_indices_tensor[ - : attn_metadata.num_actual_tokens + : attn_metadata_narrowed.num_actual_tokens ] q = causal_conv1d_update( q_proj_states, @@ -408,7 +410,9 @@ def _forward( lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v) ) - if attn_metadata.num_prefills > 0: + if attn_metadata_narrowed.num_prefills > 0: + assert non_spec_state_indices_tensor is not None + assert has_initial_state is not None zero_idx = non_spec_state_indices_tensor[~has_initial_state] recurrent_state[zero_idx] = 0 initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous() @@ -429,6 +433,7 @@ def _forward( # Init cache recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state else: + assert non_spec_query_start_loc is not None ( core_attn_out_non_spec, last_recurrent_state, @@ -440,7 +445,9 @@ def _forward( beta=beta, initial_state=recurrent_state, use_qk_l2norm_in_kernel=True, - cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata_narrowed.num_decodes + 1 + ], ssm_state_indices=non_spec_state_indices_tensor, ) core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[ diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e4d2d2be090e..c56f57658377 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -80,7 +80,7 @@ def poly_norm( from vllm import _custom_ops as ops out = torch.empty_like(x) - ops.poly_norm( + ops.poly_norm( # type: ignore[attr-defined] out, x, weight, diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 3c6b0139424d..2c05880c0fe1 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -42,9 +42,10 @@ def get_state_dtype(self) -> tuple[torch.dtype, ...]: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: mamba_block_size = vllm_config.cache_config.mamba_block_size + assert mamba_block_size is not None page_size_padded = vllm_config.cache_config.mamba_page_size_padded return MambaSpec( - shapes=self.get_state_shape(), + shapes=tuple(self.get_state_shape()), dtypes=self.get_state_dtype(), block_size=mamba_block_size, page_size_padded=page_size_padded, diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index 70a4794ad545..c74ca13024a8 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -62,7 +62,6 @@ _resolve_layer_name, direct_register_custom_op, ) -from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata logger = init_logger(__name__) @@ -121,9 +120,9 @@ def fi_chunk_gated_delta_rule( class ChunkGatedDeltaRule(CustomOp): def __init__(self) -> None: super().__init__() - backend_cfg = get_current_vllm_config().additional_config.get( - "gdn_prefill_backend", "auto" - ) + additional_config = get_current_vllm_config().additional_config + assert isinstance(additional_config, dict) + backend_cfg = additional_config.get("gdn_prefill_backend", "auto") backend = str(backend_cfg).strip().lower() supports_flashinfer = ( @@ -621,18 +620,19 @@ def forward_xpu( # Part 2: Core Attention # ============================================================ forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata + attn_metadata_raw = forward_context.attn_metadata core_attn_out = torch.zeros( (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), dtype=hidden_states.dtype, device=hidden_states.device, ) z = torch.empty_like(core_attn_out) - if attn_metadata is not None: - attn_metadata = attn_metadata[self.prefix] + if attn_metadata_raw is not None: + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] # TODO: xpu does not support this param yet - spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined] assert spec_sequence_masks is None conv_weights = self.conv1d.weight.view( @@ -658,12 +658,12 @@ def forward_xpu( activation=self.activation, A_log=self.A_log, dt_bias=self.dt_bias, - num_prefills=attn_metadata.num_prefills, - num_decodes=attn_metadata.num_decodes, - has_initial_state=attn_metadata.has_initial_state, - non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, - non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, - num_actual_tokens=attn_metadata.num_actual_tokens, + num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined] + num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined] + has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined] + non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined] + non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined] + num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined] tp_size=self.tp_size, reorder_input=not self.gqa_interleaved_layout, ) @@ -792,16 +792,16 @@ def _forward_core( core_attn_out: torch.Tensor, ): forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata + attn_metadata_raw = forward_context.attn_metadata - if attn_metadata is None: + if attn_metadata_raw is None: # V1 profile run — warm up prefill kernels so that # autotuning completes before KV cache allocation. self._warmup_prefill_kernels(mixed_qkv) return - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] # type: ignore[index] assert isinstance(attn_metadata, GDNAttentionMetadata) if ( @@ -860,14 +860,16 @@ def _forward_core( # 1.1: Process the multi-query part if spec_sequence_masks is not None: + # spec_state_indices_tensor is always set when spec_sequence_masks is set + assert spec_state_indices_tensor is not None mixed_qkv_spec = causal_conv1d_update( mixed_qkv_spec, conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0][ - : attn_metadata.num_spec_decodes + conv_state_indices=spec_state_indices_tensor[:, 0][ # type: ignore[index] + : attn_metadata.num_spec_decodes # type: ignore[attr-defined] ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, @@ -900,8 +902,8 @@ def _forward_core( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_actual_tokens + conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index] + : attn_metadata.num_actual_tokens # type: ignore[attr-defined] ], validate_data=True, ) @@ -965,8 +967,9 @@ def _forward_core( v=value_spec, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=spec_query_start_loc[ - : attn_metadata.num_spec_decodes + 1 + cu_seqlens=spec_query_start_loc[ # type: ignore[index] + : attn_metadata.num_spec_decodes + + 1 # type: ignore[attr-defined] ], ssm_state_indices=spec_state_indices_tensor, num_accepted_tokens=num_accepted_tokens, @@ -978,8 +981,10 @@ def _forward_core( # 2.2: Process the remaining part if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] = 0 + assert non_spec_state_indices_tensor is not None + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() # type: ignore[index] + assert has_initial_state is not None + initial_state[~has_initial_state, ...] = 0 # type: ignore[operator] ( core_attn_out_non_spec, last_recurrent_state, @@ -1012,8 +1017,9 @@ def _forward_core( v=value_non_spec, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[ - : attn_metadata.num_decodes + 1 + cu_seqlens=non_spec_query_start_loc[ # type: ignore[index] + : attn_metadata.num_decodes + + 1 # type: ignore[attr-defined] ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, @@ -1073,7 +1079,7 @@ def _forward_core_decode_non_spec( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index] validate_data=False, ) out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1) @@ -1086,7 +1092,7 @@ def _forward_core_decode_non_spec( scale=self.head_k_dim**-0.5, initial_state=ssm_state, out=out_buf, - ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], # type: ignore[index] use_qk_l2norm_in_kernel=True, ) return diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 18fcc1426cc5..8e8527aed8af 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -396,10 +396,11 @@ def _forward( self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor ) -> None: forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] + attn_metadata_raw = forward_context.attn_metadata + attn_metadata: AttentionMetadata | None = None + if attn_metadata_raw is not None: + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] assert isinstance(attn_metadata, LinearAttentionMetadata) num_actual_tokens = ( attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 4509a0956280..0e476755201e 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -40,6 +40,7 @@ _resolve_layer_name, direct_register_custom_op, ) +from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -258,15 +259,16 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): """ forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata + attn_metadata_raw = forward_context.attn_metadata assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] + attn_metadata: AttentionMetadata | None = None + if attn_metadata_raw is not None: + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] assert isinstance(attn_metadata, Mamba1AttentionMetadata) query_start_loc_p = attn_metadata.query_start_loc_p state_indices_tensor_p = attn_metadata.state_indices_tensor_p @@ -391,6 +393,9 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): ssm_outputs.append(scan_out_p) if has_decode: + # state_indices_tensor_d is assigned when attn_metadata is not None, + # and has_decode is only True when attn_metadata is not None + assert state_indices_tensor_d is not None if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0518bde2f427..2b4b1934f9b3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -572,14 +572,16 @@ def conv_ssm_forward( # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = forward_context.attn_metadata + attn_metadata_raw = forward_context.attn_metadata assert self.cache_config is not None mamba_block_size = self.cache_config.mamba_block_size is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all" - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] + + attn_metadata: AttentionMetadata | None = None + if attn_metadata_raw is not None: + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] assert isinstance(attn_metadata, Mamba2AttentionMetadata) # conv_state must be (..., dim, width-1) for the conv kernels. # DS layout stores it that way directly; SD layout needs a @@ -708,6 +710,7 @@ def conv_ssm_forward( # 3. State Space Model sequence transformation initial_states = None if has_initial_states_p is not None and prep_initial_states: + assert state_indices_tensor_p is not None kernel_ssm_indices = state_indices_tensor_p if is_mamba_cache_all: kernel_ssm_indices = state_indices_tensor_p.gather( @@ -746,6 +749,13 @@ def conv_ssm_forward( ) if is_mamba_cache_all: + assert mamba_block_size is not None + assert state_indices_tensor_p is not None + assert block_idx_first_scheduled_token_p is not None + assert block_idx_last_scheduled_token_p is not None + assert last_chunk_indices_p is not None + assert num_computed_tokens_p is not None + # The chunk_stride is the number of chunks per mamba block # e.g., if mamba_block_size = 512 and chunk_size = 256, # then chunk_stride = 2 @@ -810,6 +820,7 @@ def conv_ssm_forward( ssm_state[cache_blocks_to_fill] = from_where # For all seqs, store the last state (note: might be partial): + assert state_indices_tensor_p is not None ssm_state[ state_indices_tensor_p.gather( 1, block_idx_last_scheduled_token_p.unsqueeze(1) @@ -820,10 +831,12 @@ def conv_ssm_forward( # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) # tensor + assert state_indices_tensor_p is not None ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests if has_decode: + assert state_indices_tensor_d is not None if is_mamba_cache_all: state_indices_tensor_d_input = state_indices_tensor_d.gather( 1, block_idx_last_computed_token_d.unsqueeze(1) diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 11e9b590f86a..629167acfe52 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -113,10 +113,11 @@ def forward_cuda( # chunked prefill modes; they are computed at top-level model forward # since they stay the same and reused for all mamba layers in the same # iteration. - attn_metadata: AttentionMetadata = forward_context.attn_metadata - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] + attn_metadata_raw = forward_context.attn_metadata + attn_metadata: AttentionMetadata | None = None + if attn_metadata_raw is not None: + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] assert isinstance(attn_metadata, ShortConvAttentionMetadata) conv_state = ( self.kv_cache[0] diff --git a/vllm/model_executor/layers/pooler/seqwise/poolers.py b/vllm/model_executor/layers/pooler/seqwise/poolers.py index 74fa4cdbbe43..762869c67823 100644 --- a/vllm/model_executor/layers/pooler/seqwise/poolers.py +++ b/vllm/model_executor/layers/pooler/seqwise/poolers.py @@ -115,6 +115,7 @@ def pooler_for_classify( vllm_config = get_current_vllm_config() model_config = vllm_config.model_config + assert model_config.pooler_config is not None head = ClassifierPoolerHead( head_dtype=model_config.head_dtype, classifier=classifier, diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py index 6462a5056c55..131074a5556a 100644 --- a/vllm/model_executor/layers/pooler/tokwise/poolers.py +++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py @@ -124,6 +124,7 @@ def pooler_for_token_classify( vllm_config = get_current_vllm_config() model_config = vllm_config.model_config + assert model_config.pooler_config is not None head = TokenClassifierPoolerHead( head_dtype=model_config.head_dtype, classifier=classifier, diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py index 4ed8d57dd430..7d0b6a974d7c 100644 --- a/vllm/model_executor/layers/quantization/fp_quant.py +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -3,7 +3,7 @@ # Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202 -from typing import Any +from typing import Any, Literal, cast import torch from torch.nn.parameter import Parameter @@ -251,7 +251,11 @@ def apply( def fused_quantize_mx( x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str ) -> tuple[torch.Tensor, torch.Tensor]: - return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method) + return fusedQuantizeMx( + x_flat, + hadamard_matrix, + method=cast(Literal["quest", "abs_max"], forward_method), + ) def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method): diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 33bd0cfc22e8..6aaf9a645880 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -114,7 +114,7 @@ def apply_vllm_mapper( # noqa: B027 :param hf_to_vllm_mapper: maps from hf model structure (the assumed structure of the qconfig) to vllm model structure """ - quant_config_with_hf_to_vllm_mapper = {} + quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {} for k, v in self.quant_config.items(): if isinstance(v, list): diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index bdaa6af09456..b84276c16e1c 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -26,7 +26,7 @@ if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops elif current_platform.is_xpu(): - from vllm._xpu_ops import xpu_ops as ops + from vllm import _xpu_ops logger = init_logger(__name__) @@ -84,12 +84,12 @@ def sparse_attn_indexer( total_seq_lens, topk_indices_buffer, ) - attn_metadata = attn_metadata[k_cache_prefix] - assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) - slot_mapping = attn_metadata.slot_mapping - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens + attn_metadata_narrowed = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata_narrowed.slot_mapping + has_decode = attn_metadata_narrowed.num_decodes > 0 + has_prefill = attn_metadata_narrowed.num_prefills > 0 + num_decode_tokens = attn_metadata_narrowed.num_decode_tokens # During speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. Truncate k to avoid @@ -97,6 +97,8 @@ def sparse_attn_indexer( num_tokens = slot_mapping.shape[0] k = k[:num_tokens] + # scale_fmt can be None, but the function expects str + assert scale_fmt is not None ops.indexer_k_quant_and_cache( k, kv_cache, @@ -107,7 +109,7 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: - prefill_metadata = attn_metadata.prefill + prefill_metadata = attn_metadata_narrowed.prefill assert prefill_metadata is not None # Get the full shared workspace buffers once (will allocate on first use) @@ -144,7 +146,7 @@ def sparse_attn_indexer( ] if current_platform.is_xpu(): - ops.top_k_per_row_prefill( + _xpu_ops.xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined] logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, @@ -167,7 +169,7 @@ def sparse_attn_indexer( ) if has_decode: - decode_metadata = attn_metadata.decode + decode_metadata = attn_metadata_narrowed.decode assert decode_metadata is not None # kv_cache shape [ # kv_cache size requirement [num_block, block_size, n_head, head_dim], @@ -217,11 +219,11 @@ def sparse_attn_indexer( topk_indices, topk_workspace, topk_tokens, - attn_metadata.max_seq_len, + attn_metadata_narrowed.max_seq_len, ) else: if current_platform.is_xpu(): - ops.top_k_per_row_decode( + _xpu_ops.xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] logits, next_n, seq_lens, From a43a34f21a57064bf47250c0f47618edfadf00e8 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Mon, 20 Apr 2026 09:13:40 +0100 Subject: [PATCH 2/3] Update after review Review comment: - https://github.com/vllm-project/vllm/pull/40159#discussion_r3105454110 Signed-off-by: Martin Hickey --- vllm/model_executor/layers/attention/mla_attention.py | 4 ++-- vllm/model_executor/layers/sparse_attn_indexer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 02f322116075..74d86a8cb7c7 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1059,9 +1059,9 @@ class QueryLenSupport(Enum): "AITER_MLA backends use aiter kernels instead." ) elif current_platform.is_xpu(): - from vllm import _xpu_ops + from vllm._xpu_ops import xpu_ops - flash_attn_varlen_func = _xpu_ops.xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment] + flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment] def dynamic_per_batched_tensor_quant( diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index b84276c16e1c..1f1f7e7df89f 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -26,7 +26,7 @@ if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops elif current_platform.is_xpu(): - from vllm import _xpu_ops + from vllm._xpu_ops import xpu_ops logger = init_logger(__name__) @@ -146,7 +146,7 @@ def sparse_attn_indexer( ] if current_platform.is_xpu(): - _xpu_ops.xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined] + xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined] logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, @@ -223,7 +223,7 @@ def sparse_attn_indexer( ) else: if current_platform.is_xpu(): - _xpu_ops.xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] logits, next_n, seq_lens, From 4637206d6dae273e47dd6fdab9f0a189d4c52bc1 Mon Sep 17 00:00:00 2001 From: Martin Hickey Date: Mon, 20 Apr 2026 10:12:01 +0100 Subject: [PATCH 3/3] Fix unit tests Signed-off-by: Martin Hickey --- vllm/model_executor/layers/attention/attention.py | 5 +++-- vllm/model_executor/layers/attention/mla_attention.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 0f3c50684552..50650fef6dd2 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -686,11 +686,12 @@ def get_attention_context( attn_metadata: AttentionMetadata if isinstance(attn_metadata_raw, dict): attn_metadata = attn_metadata_raw[layer_name] - else: + elif isinstance(attn_metadata_raw, list): # list[dict[str, AttentionMetadata]]: used in speculative decoding # where [0] is the base-model (non-speculative) metadata dict. - assert isinstance(attn_metadata_raw, list) attn_metadata = attn_metadata_raw[0][layer_name] + else: + attn_metadata = attn_metadata_raw attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] kv_cache = attn_layer.kv_cache slot_mapping = forward_context.slot_mapping diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 74d86a8cb7c7..a1e3921b0c59 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -489,11 +489,12 @@ def forward( attn_metadata: MLACommonMetadata if isinstance(attn_metadata_raw, dict): attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment] - else: + elif isinstance(attn_metadata_raw, list): # list[dict[str, AttentionMetadata]]: used in speculative decoding # where [0] is the base-model (non-speculative) metadata dict. - assert isinstance(attn_metadata_raw, list) attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment] + else: + attn_metadata = attn_metadata_raw self_kv_cache = self.kv_cache slot_mapping = forward_context.slot_mapping