From 53236e3cf4243ba4c7b446ec2bb2609c2b302315 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 11 Dec 2025 21:53:37 +0000 Subject: [PATCH 1/7] [Kernels][MoE] Add FusedMoERouter object Signed-off-by: Bill Nell --- tests/test_routing_simulator.py | 2 +- .../layers/fused_moe/__init__.py | 4 ++ .../layers/fused_moe/fused_moe_method_base.py | 4 ++ .../fused_moe/fused_moe_modular_method.py | 4 +- .../layers/fused_moe/fused_moe_router.py | 41 +++++++++++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 28 +++++++++++-- .../fused_moe/unquantized_fused_moe_method.py | 8 +++- .../layers/quantization/awq_marlin.py | 4 +- .../layers/quantization/bitsandbytes.py | 9 +++- .../compressed_tensors_moe.py | 20 ++++++--- .../layers/quantization/experts_int8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 4 +- .../layers/quantization/gguf.py | 9 +++- .../layers/quantization/gptq_marlin.py | 4 +- .../layers/quantization/ipex_quant.py | 4 ++ .../layers/quantization/modelopt.py | 22 ++++++---- .../layers/quantization/moe_wna16.py | 4 +- .../layers/quantization/mxfp4.py | 9 ++-- .../layers/quantization/quark/quark_moe.py | 7 +++- .../model_executor/layers/quantization/rtn.py | 9 +++- 20 files changed, 165 insertions(+), 35 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_router.py diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py index 44cbdeed4507..e37f30755663 100644 --- a/tests/test_routing_simulator.py +++ b/tests/test_routing_simulator.py @@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device): envs.environment_variables[env_name] = lambda s=strategy: s # Test the select_experts method - topk_weights, topk_ids = fused_moe.select_experts( + topk_weights, topk_ids = fused_moe.router.select_experts( hidden_states=hidden_states, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9f04397e91f7..b1bd580b963a 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoeWeightScaleSupported, @@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None: __all__ = [ "FusedMoE", + "FusedMoERouter", "FusedMoEConfig", "FusedMoEMethodBase", "UnquantizedFusedMoEMethod", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index a46e3972ed8e..389ccf358c56 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -10,6 +10,9 @@ FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, @@ -109,6 +112,7 @@ def method_name(self) -> str: def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 6abefde0763e..10fa0ca7930d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel, FusedMoEPrepareAndFinalize, @@ -88,10 +89,11 @@ def get_fused_moe_quant_config( def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py new file mode 100644 index 000000000000..46ac88e7d71e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch + +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType + + +class FusedMoERouter(ABC): + """ + FusedMoERouter is an abstract class that primarily provides a + 'select_experts' method that is used for routing hidden states based + on router logits. + """ + + @property + @abstractmethod + def routing_method_type(self) -> RoutingMethodType: + raise NotImplementedError + + @abstractmethod + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + Returns: + (topk_weights, topk_ids, zero_expert_result) + (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): + The weights, expert ids, and (optional) zero expert computation result. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 52f093f62d5a..f6dfcbe41613 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,6 +31,7 @@ FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, ) @@ -284,6 +285,23 @@ def maybe_roundup_hidden_size( return hidden_size +class FusedMoERouterImpl(FusedMoERouter): + def __init__(self, layer: "FusedMoE"): + super().__init__() + self.layer = layer + + @property + def routing_method_type(self) -> RoutingMethodType: + return self.layer.routing_method_type + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + return self.layer._select_experts(hidden_states, router_logits) + + @CustomOp.register("fused_moe") class FusedMoE(CustomOp): """FusedMoE layer for MoE models. @@ -339,7 +357,7 @@ def __init__( is_sequence_parallel=False, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, - routing_method_type: int | None = None, + routing_method_type: RoutingMethodType | None = None, router_logits_dtype: torch.dtype | None = None, ): super().__init__() @@ -529,7 +547,7 @@ def __init__( # ToDo: Better logic to determine the routing method type if routing_method_type is not None: - self.routing_method_type = routing_method_type + self.routing_method_type: RoutingMethodType = routing_method_type else: if scoring_func == "sigmoid": if self.use_grouped_topk: @@ -646,6 +664,8 @@ def _get_quant_method() -> FusedMoEMethodBase: self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None + self.router = FusedMoERouterImpl(self) + # Note: maybe_init_modular_kernel should only be called by # prepare_communication_buffer_for_model. # This is called after all weight loading and post-processing, so it @@ -1509,7 +1529,7 @@ def ensure_dp_chunking_init(self): device=torch.cuda.current_device(), ) - def select_experts( + def _select_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -1778,6 +1798,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, + router=self.router, x=staged_hidden_states, router_logits=staged_router_logits, ) @@ -1950,6 +1971,7 @@ def forward_impl( # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, + router=self.router, x=hidden_states_combined if do_naive_dispatch_combine else hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 92ef850205fc..1cdc25135a34 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, @@ -285,10 +286,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self.forward( + router=router, layer=layer, x=x, router_logits=router_logits, @@ -311,10 +314,11 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon def forward_cuda( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -337,6 +341,7 @@ def forward_cuda( def forward_cpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -370,6 +375,7 @@ def forward_cpu( def forward_xpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 602d02d2f15a..5763a41193e8 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -15,6 +15,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -759,12 +760,13 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index efe5677045e4..1d2334f3933a 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -10,7 +10,11 @@ FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, + FusedMoEMethodBase, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -495,12 +499,13 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 509de5dff9c1..8a6217c65677 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -55,6 +55,7 @@ make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, @@ -458,6 +459,7 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -484,7 +486,7 @@ def apply( x_routing, _ = x else: x_routing = x - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x_routing, router_logits=router_logits, ) @@ -926,10 +928,11 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1066,12 +1069,13 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1426,6 +1430,7 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -1433,7 +1438,7 @@ def apply( f"{layer.activation} not supported for Marlin MoE." ) - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1677,12 +1682,13 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1978,6 +1984,7 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: @@ -2290,6 +2297,7 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ): @@ -2298,7 +2306,7 @@ def apply( "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." ) assert self.moe_quant_config is not None - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 56b11b22f7ff..37e6020cb2a9 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -15,6 +15,7 @@ FusedMoEQuantConfig, int8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( @@ -137,12 +138,13 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2879315a6886..1c0c35bf6f41 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -29,6 +29,7 @@ FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, @@ -997,6 +998,7 @@ def allow_inplace(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -1051,7 +1053,7 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9600bb42295d..1c03e5243a85 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -16,7 +16,11 @@ FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, + FusedMoEMethodBase, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -629,6 +633,7 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -639,7 +644,7 @@ def apply( "fused GGUF MoE method." ) - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3c958588c78f..68a2c375e353 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -15,6 +15,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -895,12 +896,13 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 9de2924ec71b..475bd853676e 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -9,6 +9,9 @@ from vllm._ipex_ops import ipex_ops as ops from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -384,6 +387,7 @@ def get_fused_moe_quant_config( def apply( self, layer: torch.nn.Module, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2e4f1daf6690..43c34d0708f6 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -14,8 +14,10 @@ from vllm.attention.layer import Attention from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -200,7 +202,9 @@ def get_quant_method( quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method elif isinstance(layer, FusedMoE): - quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer) + quant_method = self.FusedMoEMethodCls( + quant_config=self, moe_config=layer.moe_config + ) if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method @@ -720,9 +724,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptFp8Config, - layer: FusedMoE, + moe_config: FusedMoEConfig, ) -> None: - super().__init__(layer.moe_config) + super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized self.fp8_backend = select_fp8_moe_backend( @@ -935,6 +939,7 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -961,7 +966,8 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - topk_weights, topk_ids = layer.select_experts( + # Expert selection + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptNvFp4Config, - layer: FusedMoE, + moe_config: FusedMoEConfig, ) -> None: - super().__init__(layer.moe_config) + super().__init__(moe_config) self.quant_config = quant_config self.nvfp4_backend = select_nvfp4_moe_backend() # TODO: move this type of check into the oracle. @@ -1597,6 +1603,7 @@ def supports_eplb(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -1616,12 +1623,13 @@ def apply( e_score_correction_bias=layer.e_score_correction_bias, ) + assert not isinstance(x, tuple) # Hidden_states in select_experts is only used to extract metadata if isinstance(x, tuple): x_routing, _ = x else: x_routing = x - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 513f6f7b21ab..d5d94082587f 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -11,6 +11,7 @@ int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEConfig, @@ -364,13 +365,14 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts assert layer.activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 15edd3e613bf..8e050b795f94 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -27,6 +27,7 @@ MarlinExperts, fused_marlin_moe, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, UnfusedOAITritonExperts, @@ -891,6 +892,7 @@ def allow_inplace(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -898,7 +900,7 @@ def apply( raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -992,7 +994,7 @@ def apply( ): from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1119,7 +1121,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply( self, - layer: torch.nn.Module, + layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 4ab618dc44ef..66b8dad5db2e 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -13,6 +13,7 @@ FusedMoE, FusedMoEConfig, FusedMoEMethodBase, + FusedMoERouter, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( @@ -350,10 +351,11 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -750,10 +752,11 @@ def allow_inplace(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index dce9c661ec33..239adb384708 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -15,7 +15,11 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, + FusedMoEMethodBase, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -356,10 +360,11 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) From 53946c4f54bf44e45adb14b0b6f2a60ffbe7ec1f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 Jan 2026 17:53:18 +0000 Subject: [PATCH 2/7] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe_router.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py index 46ac88e7d71e..d1af8ca82773 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py @@ -24,7 +24,7 @@ def select_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the router logits. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f6dfcbe41613..0e0731bb22e5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -298,7 +298,7 @@ def select_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.layer._select_experts(hidden_states, router_logits) From 95e10b9bbf8a014a926158afb7e379df08b06be3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 Jan 2026 18:01:57 +0000 Subject: [PATCH 3/7] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe_router.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py index d1af8ca82773..df4644bbe777 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py @@ -30,9 +30,9 @@ def select_experts( router logits. Returns: - (topk_weights, topk_ids, zero_expert_result) - (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - The weights, expert ids, and (optional) zero expert computation result. + (topk_weights, topk_ids) + (tuple[torch.Tensor, torch.Tensor]): + The weights and expert ids computation result. **Compatibility**: When EPLB is not enabled, the returned ids are equivalent to global logical ids, so should be compatible with From 3010974ff5b861ac3b2f9dda40485884b164412f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 Jan 2026 18:02:46 +0000 Subject: [PATCH 4/7] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe_router.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py index df4644bbe777..c322a8cd4cd6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py @@ -9,9 +9,8 @@ class FusedMoERouter(ABC): """ - FusedMoERouter is an abstract class that primarily provides a - 'select_experts' method that is used for routing hidden states based - on router logits. + FusedMoERouter is an abstract class that provides a 'select_experts' + method that is used for routing hidden states based on router logits. """ @property From 085a02055d42b521a5cb8f48c4aa596b59f8f215 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 Jan 2026 18:07:25 +0000 Subject: [PATCH 5/7] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/quantization/modelopt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 43c34d0708f6..483dac4069a7 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1623,7 +1623,6 @@ def apply( e_score_correction_bias=layer.e_score_correction_bias, ) - assert not isinstance(x, tuple) # Hidden_states in select_experts is only used to extract metadata if isinstance(x, tuple): x_routing, _ = x From a544be287ce213a32e51bf47128d43ebbf333e0b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 Jan 2026 18:51:03 +0000 Subject: [PATCH 6/7] fix lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/quantization/quark/quark_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 66b8dad5db2e..6b731314825a 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -544,6 +544,7 @@ def get_fused_moe_quant_config(self, layer): def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: From a3deb6233dd3dfb0b915504d76a25a7367ad95b3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 8 Jan 2026 17:55:41 +0000 Subject: [PATCH 7/7] fix merge Signed-off-by: Bill Nell --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 +- vllm/model_executor/layers/quantization/modelopt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8a6217c65677..86878c84ab83 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -40,6 +40,7 @@ MarlinExperts, fused_marlin_moe, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, @@ -55,7 +56,6 @@ make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 483dac4069a7..a646012ddd3a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -731,7 +731,7 @@ def __init__( assert self.quant_config.is_checkpoint_fp8_serialized self.fp8_backend = select_fp8_moe_backend( block_quant=False, - tp_size=layer.moe_parallel_config.tp_size, + tp_size=moe_config.moe_parallel_config.tp_size, with_lora_support=self.moe.is_lora_enabled, ) self.kernel: mk.FusedMoEModularKernel | None = None