diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 44cbdeed4507..e37f30755663 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
         envs.environment_variables[env_name] = lambda s=strategy: s
 
         # Test the select_experts method
-        topk_weights, topk_ids = fused_moe.select_experts(
+        topk_weights, topk_ids = fused_moe.router.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 9f04397e91f7..b1bd580b963a 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -11,6 +11,9 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:
 
 __all__ = [
     "FusedMoE",
+    "FusedMoERouter",
     "FusedMoEConfig",
     "FusedMoEMethodBase",
     "UnquantizedFusedMoEMethod",
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index a46e3972ed8e..389ccf358c56 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -10,6 +10,9 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
@@ -109,6 +112,7 @@ def method_name(self) -> str:
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 6abefde0763e..10fa0ca7930d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -12,6 +12,7 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel,
     FusedMoEPrepareAndFinalize,
@@ -88,10 +89,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py
new file mode 100644
index 000000000000..c322a8cd4cd6
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
+
+class FusedMoERouter(ABC):
+    """
+    FusedMoERouter is an abstract class that provides a 'select_experts'
+    method that is used for routing hidden states based on router logits.
+    """
+
+    @property
+    @abstractmethod
+    def routing_method_type(self) -> RoutingMethodType:
+        raise NotImplementedError
+
+    @abstractmethod
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Route the input hidden states to the top-k experts based on the
+        router logits.
+
+        Returns:
+            (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
+                The weights and expert ids computation result.
+
+        **Compatibility**: When EPLB is not enabled, the returned ids are
+        equivalent to global logical ids, so should be compatible with
+        plain MoE implementations without redundant experts.
+        """
+        raise NotImplementedError
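
For orientation, here is what a concrete implementation of this interface can look like. This is a minimal sketch, not part of the diff: the class name SoftmaxTopKRouter is invented, and it assumes RoutingMethodType exposes a TopK member.

    import torch

    from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
    from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


    class SoftmaxTopKRouter(FusedMoERouter):
        """Hypothetical router: softmax over the logits, then top-k."""

        def __init__(self, top_k: int):
            super().__init__()
            self.top_k = top_k

        @property
        def routing_method_type(self) -> RoutingMethodType:
            # Assumption: RoutingMethodType has a TopK member; pick
            # whichever member matches the routing scheme being modeled.
            return RoutingMethodType.TopK

        def select_experts(
            self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # router_logits: [num_tokens, num_experts] -> per-token probabilities.
            scores = torch.softmax(router_logits, dim=-1)
            topk_weights, topk_ids = torch.topk(scores, self.top_k, dim=-1)
            return topk_weights, topk_ids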
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 52f093f62d5a..0e0731bb22e5 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -31,6 +31,7 @@
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     init_aiter_topK_meta_data,
 )
@@ -284,6 +285,23 @@ def maybe_roundup_hidden_size(
     return hidden_size
 
 
+class FusedMoERouterImpl(FusedMoERouter):
+    def __init__(self, layer: "FusedMoE"):
+        super().__init__()
+        self.layer = layer
+
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return self.layer.routing_method_type
+
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.layer._select_experts(hidden_states, router_logits)
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -339,7 +357,7 @@ def __init__(
         is_sequence_parallel=False,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
-        routing_method_type: int | None = None,
+        routing_method_type: RoutingMethodType | None = None,
         router_logits_dtype: torch.dtype | None = None,
     ):
         super().__init__()
@@ -529,7 +547,7 @@ def __init__(
 
         # ToDo: Better logic to determine the routing method type
         if routing_method_type is not None:
-            self.routing_method_type = routing_method_type
+            self.routing_method_type: RoutingMethodType = routing_method_type
         else:
             if scoring_func == "sigmoid":
                 if self.use_grouped_topk:
@@ -646,6 +664,8 @@ def _get_quant_method() -> FusedMoEMethodBase:
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None
 
+        self.router = FusedMoERouterImpl(self)
+
         # Note: maybe_init_modular_kernel should only be called by
         # prepare_communication_buffer_for_model.
         # This is called after all weight loading and post-processing, so it
@@ -1509,7 +1529,7 @@ def ensure_dp_chunking_init(self):
             device=torch.cuda.current_device(),
         )
 
-    def select_experts(
+    def _select_experts(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1778,6 +1798,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
+                router=self.router,
                 x=staged_hidden_states,
                 router_logits=staged_router_logits,
             )
@@ -1950,6 +1971,7 @@ def forward_impl(
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
+            router=self.router,
             x=hidden_states_combined
             if do_naive_dispatch_combine
             else hidden_states,
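
With this change, FusedMoE.select_experts becomes the private _select_experts, and FusedMoERouterImpl re-exposes it through the layer's new router attribute, which is what the test update at the top of the diff exercises. A usage sketch, assuming layer is an already-constructed FusedMoE instance and the dimensions are illustrative:

    import torch

    # Assumptions: `layer` is an initialized FusedMoE; num_tokens,
    # hidden_size, and num_experts are illustrative placeholders.
    num_tokens, hidden_size, num_experts = 4, 1024, 8
    hidden_states = torch.randn(num_tokens, hidden_size)
    router_logits = torch.randn(num_tokens, num_experts)

    # Previously: layer.select_experts(...). The router now fronts the
    # same computation, which lives on as FusedMoE._select_experts.
    topk_weights, topk_ids = layer.router.select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
    )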
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 92ef850205fc..1cdc25135a34 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -22,6 +22,7 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEActivationFormat,
     FusedMoEPermuteExpertsUnpermute,
@@ -285,10 +286,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         return self.forward(
+            router=router,
             layer=layer,
             x=x,
             router_logits=router_logits,
@@ -311,10 +314,11 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon
     def forward_cuda(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -337,6 +341,7 @@ def forward_cuda(
     def forward_cpu(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -370,6 +375,7 @@ def forward_cpu(
     def forward_xpu(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 602d02d2f15a..5763a41193e8 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -759,12 +760,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index efe5677045e4..1d2334f3933a 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -10,7 +10,11 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -495,12 +499,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 509de5dff9c1..86878c84ab83 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -40,6 +40,7 @@
     MarlinExperts,
     fused_marlin_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
     convert_to_fp8_moe_kernel_format,
@@ -458,6 +459,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -484,7 +486,7 @@ def apply(
             x_routing, _ = x
         else:
             x_routing = x
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x_routing,
             router_logits=router_logits,
         )
@@ -926,10 +928,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1066,12 +1069,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1426,6 +1430,7 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1433,7 +1438,7 @@ def apply(
             f"{layer.activation} not supported for Marlin MoE."
         )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1677,12 +1682,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1978,6 +1984,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
@@ -2290,6 +2297,7 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ):
@@ -2298,7 +2306,7 @@ def apply(
                 "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
             )
         assert self.moe_quant_config is not None
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 56b11b22f7ff..37e6020cb2a9 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
@@ -137,12 +138,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2879315a6886..1c0c35bf6f41 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -29,6 +29,7 @@
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
@@ -997,6 +998,7 @@ def allow_inplace(self) -> bool:
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1051,7 +1053,7 @@ def apply(
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9600bb42295d..1c03e5243a85 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -16,7 +16,11 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -629,6 +633,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -639,7 +644,7 @@ def apply(
             "fused GGUF MoE method."
         )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 3c958588c78f..68a2c375e353 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -895,12 +896,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 9de2924ec71b..475bd853676e 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -9,6 +9,9 @@
 
 from vllm._ipex_ops import ipex_ops as ops
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -384,6 +387,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: torch.nn.Module,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 2e4f1daf6690..a646012ddd3a 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -14,8 +14,10 @@
 from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -200,7 +202,9 @@ def get_quant_method(
             quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
             return quant_method
         elif isinstance(layer, FusedMoE):
-            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
+            quant_method = self.FusedMoEMethodCls(
+                quant_config=self, moe_config=layer.moe_config
+            )
             if getattr(quant_method, "backend", "") == "marlin":
                 quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
             return quant_method
@@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptFp8Config,
-        layer: FusedMoE,
+        moe_config: FusedMoEConfig,
     ) -> None:
-        super().__init__(layer.moe_config)
+        super().__init__(moe_config)
         self.quant_config = quant_config
         assert self.quant_config.is_checkpoint_fp8_serialized
         self.fp8_backend = select_fp8_moe_backend(
             block_quant=False,
-            tp_size=layer.moe_parallel_config.tp_size,
+            tp_size=moe_config.moe_parallel_config.tp_size,
             with_lora_support=self.moe.is_lora_enabled,
         )
         self.kernel: mk.FusedMoEModularKernel | None = None
@@ -935,6 +939,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -961,7 +966,8 @@ def apply(
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
 
-        topk_weights, topk_ids = layer.select_experts(
+        # Expert selection
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptNvFp4Config,
-        layer: FusedMoE,
+        moe_config: FusedMoEConfig,
     ) -> None:
-        super().__init__(layer.moe_config)
+        super().__init__(moe_config)
         self.quant_config = quant_config
         self.nvfp4_backend = select_nvfp4_moe_backend()
         # TODO: move this type of check into the oracle.
@@ -1597,6 +1603,7 @@ def supports_eplb(self) -> bool:
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1621,7 +1628,7 @@ def apply(
             x_routing, _ = x
         else:
             x_routing = x
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x_routing,
             router_logits=router_logits,
         )
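
Note that beyond the router threading, the ModelOpt methods are also decoupled from the layer at construction time: they now receive a FusedMoEConfig rather than the FusedMoE instance, so anything layer-specific reaches them only through apply(layer=..., router=...). A construction sketch mirroring the get_quant_method hunk above (quant_config and layer are assumed to be in scope):

    # Hypothetical call site: the method keeps only the config, never the layer.
    quant_method = ModelOptFp8MoEMethod(
        quant_config=quant_config,   # e.g. a ModelOptFp8Config
        moe_config=layer.moe_config,
    )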
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 513f6f7b21ab..d5d94082587f 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -11,6 +11,7 @@
     int4_w4a16_moe_quant_config,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEConfig,
@@ -364,13 +365,14 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         assert layer.activation == "silu", "Only SiLU activation is supported."
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 15edd3e613bf..8e050b795f94 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -27,6 +27,7 @@
     MarlinExperts,
     fused_marlin_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
     UnfusedOAITritonExperts,
@@ -891,6 +892,7 @@ def allow_inplace(self) -> bool:
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -898,7 +900,7 @@ def apply(
             raise NotImplementedError("EPLB is not supported for mxfp4")
 
         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            topk_weights, topk_ids = layer.select_experts(
+            topk_weights, topk_ids = router.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
             )
@@ -992,7 +994,7 @@ def apply(
         ):
             from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
 
-            topk_weights, topk_ids = layer.select_experts(
+            topk_weights, topk_ids = router.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
             )
@@ -1119,7 +1121,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 4ab618dc44ef..6b731314825a 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -13,6 +13,7 @@
     FusedMoE,
     FusedMoEConfig,
     FusedMoEMethodBase,
+    FusedMoERouter,
     FusedMoeWeightScaleSupported,
 )
 from vllm.model_executor.layers.fused_moe.config import (
@@ -350,10 +351,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -542,6 +544,7 @@ def get_fused_moe_quant_config(self, layer):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -750,10 +753,11 @@ def allow_inplace(self) -> bool:
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index dce9c661ec33..239adb384708 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -15,7 +15,11 @@
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -356,10 +360,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
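
Taken together, the per-backend edits all follow one pattern: FusedMoEMethodBase.apply gains a router parameter and delegates expert selection to it instead of calling layer.select_experts. A condensed sketch of that pattern for a hypothetical backend — the other abstract methods a real subclass must implement are omitted, and the fused_experts call mirrors the helper several backends above import inside apply, with a signature assumed from its common usage:

    import torch

    from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, FusedMoERouter


    class MyMoEMethod(FusedMoEMethodBase):
        def apply(
            self,
            layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
            router: FusedMoERouter,
            x: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
            from vllm.model_executor.layers.fused_moe import fused_experts

            # Expert selection is delegated to the shared router abstraction.
            topk_weights, topk_ids = router.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )
            # The expert computation itself stays backend-specific; the
            # argument layout below is an assumption, not this PR's code.
            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
            )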