diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py
index 44cbdeed4507..e37f30755663 100644
--- a/tests/test_routing_simulator.py
+++ b/tests/test_routing_simulator.py
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
         envs.environment_variables[env_name] = lambda s=strategy: s
 
         # Test the select_experts method
-        topk_weights, topk_ids = fused_moe.select_experts(
+        topk_weights, topk_ids = fused_moe.router.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index a482c6f55caf..9fa0dbb51c1e 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -1155,6 +1155,13 @@ def _sync_load_pass(self) -> list[torch.Tensor]:
         return self._allreduce_list(load_pass_list)
 
 
+@dataclass
+class EplbLayerState:
+    expert_load_view: torch.Tensor | None = None
+    logical_to_physical_map: torch.Tensor | None = None
+    logical_replica_count: torch.Tensor | None = None
+
+
 def _node_count_with_rank_mapping(
     pg: ProcessGroup | StatelessProcessGroup,
     rank_mapping: dict[int, int],
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 9f04397e91f7..b1bd580b963a 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -11,6 +11,9 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:
 __all__ = [
     "FusedMoE",
+    "FusedMoERouter",
     "FusedMoEConfig",
     "FusedMoEMethodBase",
     "UnquantizedFusedMoEMethod",
diff --git a/vllm/model_executor/layers/fused_moe/default_fused_moe_router.py b/vllm/model_executor/layers/fused_moe/default_fused_moe_router.py
new file mode 100644
index 000000000000..e72c0d1f672e
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/default_fused_moe_router.py
@@ -0,0 +1,224 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+from functools import partial
+
+import torch
+
+import vllm.envs as envs
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.distributed.eplb.eplb_state import EplbLayerState
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
+from vllm.platforms import current_platform
+
+if current_platform.is_cuda_alike():
+    from .fused_moe import eplb_map_to_physical_and_record
+else:
+
+    def eplb_map_to_physical_and_record(
+        topk_ids: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> torch.Tensor:
+        # CPU fallback: no EPLB, so return topk_ids as-is.
+        return topk_ids
+
+
+from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    rocm_aiter_grouped_topk,
+)
+
+
+class DefaultFusedMoERouter(FusedMoERouter):
+    def __init__(
+        self,
+        top_k: int,
+        global_num_experts: int,
+        eplb_state: EplbLayerState,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: int | None = None,
+        topk_group: int | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        num_fused_shared_experts: int = 0,
+        enable_eplb: bool = False,
+        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
+        routing_method_type: RoutingMethodType | None = None,
+    ):
+        super().__init__()
+        self.top_k = top_k
+        self.global_num_experts = global_num_experts
+        self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
+        self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
+        self.e_score_correction_bias = e_score_correction_bias
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.enable_eplb = enable_eplb
+        self.eplb_state = eplb_state
+        self.indices_type_getter = indices_type_getter
+
+        if self.scoring_func != "softmax" and not self.use_grouped_topk:
+            raise ValueError(
+                "Only softmax scoring function is supported for non-grouped topk."
+            )
+
+        # TODO: Better logic to determine the routing method type
+        if routing_method_type is not None:
+            self._routing_method_type: RoutingMethodType = routing_method_type
+        else:
+            if scoring_func == "sigmoid":
+                if self.use_grouped_topk:
+                    self._routing_method_type = RoutingMethodType.DeepSeekV3
+                elif self.top_k == 1:
+                    self._routing_method_type = RoutingMethodType.Llama4
+            elif self.scoring_func == "softmax":
+                self._routing_method_type = (
+                    RoutingMethodType.Renormalize
+                    if not self.renormalize
+                    else RoutingMethodType.RenormalizeNaive
+                )
+            else:
+                self._routing_method_type = RoutingMethodType.TopK
+
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return self._routing_method_type
+
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Route the input hidden states to the top-k experts based on the
+        router logits.
+
+        Returns:
+            (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
+            The weights and expert ids.
+
+        **Compatibility**: When EPLB is not enabled, the returned ids are
+        equivalent to global logical ids, so they should be compatible with
+        plain MoE implementations without redundant experts.
+        """
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_topk,
+            fused_topk_bias,
+        )
+
+        if self.enable_eplb:
+            if self.eplb_state.expert_load_view is None:
+                raise ValueError("enable_eplb=True requires expert_load_view != None")
+            if self.eplb_state.logical_to_physical_map is None:
+                raise ValueError(
+                    "enable_eplb=True requires logical_to_physical_map != None"
+                )
+            if self.eplb_state.logical_replica_count is None:
+                raise ValueError(
+                    "enable_eplb=True requires logical_replica_count != None"
+                )
+
+        def valid_grouping() -> bool:
+            # Check if num_experts is greater than num_expert_group
+            # and is divisible by num_expert_group
+            assert self.num_expert_group is not None
+            num_experts = router_logits.shape[-1]
+            if num_experts <= self.num_expert_group:
+                return False
+            return num_experts % self.num_expert_group == 0
+
+        indices_type = (
+            self.indices_type_getter() if self.indices_type_getter is not None else None
+        )
+
+        # Check if we should use a routing simulation strategy
+        routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
+        if routing_strategy != "":
+            topk_weights, topk_ids = RoutingSimulator.simulate_routing(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                strategy_name=routing_strategy,
+                top_k=self.top_k,
+                indices_type=indices_type,
+            )
+
+        # DeepSeekV2 uses grouped top-k
+        elif self.use_grouped_topk and valid_grouping():
+            assert self.topk_group is not None
+            assert self.num_expert_group is not None
+            if rocm_aiter_ops.is_fused_moe_enabled():
+                if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
+                    assert self.num_fused_shared_experts == 0
+                grouped_topk_impl = partial(
+                    rocm_aiter_grouped_topk,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                )
+            else:
+                grouped_topk_impl = grouped_topk
+
+            topk_weights, topk_ids = grouped_topk_impl(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                num_expert_group=self.num_expert_group,
+                topk_group=self.topk_group,
+                scoring_func=self.scoring_func,
+                routed_scaling_factor=self.routed_scaling_factor,
+                e_score_correction_bias=self.e_score_correction_bias,
+            )
+        elif self.e_score_correction_bias is not None:
+            topk_weights, topk_ids = fused_topk_bias(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                e_score_correction_bias=self.e_score_correction_bias.data,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+            )
+            if self.routed_scaling_factor != 1.0:
+                topk_weights *= self.routed_scaling_factor
+        elif self.custom_routing_function is None:
+            topk_weights, topk_ids, token_expert_indices = fused_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                indices_type=indices_type,
+            )
+        else:
+            topk_weights, topk_ids = self.custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+            )
+
+        if self.enable_eplb:
+            assert self.eplb_state.expert_load_view is not None
+            assert self.eplb_state.logical_to_physical_map is not None
+            assert self.eplb_state.logical_replica_count is not None
+            topk_ids = eplb_map_to_physical_and_record(
+                topk_ids=topk_ids,
+                expert_load_view=self.eplb_state.expert_load_view,
+                logical_to_physical_map=self.eplb_state.logical_to_physical_map,
+                logical_replica_count=self.eplb_state.logical_replica_count,
+            )
+
+        if (indices_type is not None) and topk_ids.dtype != indices_type:
+            topk_ids = topk_ids.to(dtype=indices_type)
+
+        assert topk_ids.dtype == indices_type or indices_type is None
+
+        return topk_weights, topk_ids
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index a46e3972ed8e..389ccf358c56 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -10,6 +10,9 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
@@ -109,6 +112,7 @@ def method_name(self) -> str:
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 6abefde0763e..10fa0ca7930d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -12,6 +12,7 @@
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel,
     FusedMoEPrepareAndFinalize,
@@ -88,10 +89,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/fused_moe_router.py
new file mode 100644
index 000000000000..c322a8cd4cd6
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_router.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
+
+class FusedMoERouter(ABC):
+    """
+    Abstract base class that provides a `select_experts` method for
+    routing hidden states based on the router logits.
+    """
+
+    @property
+    @abstractmethod
+    def routing_method_type(self) -> RoutingMethodType:
+        raise NotImplementedError
+
+    @abstractmethod
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Route the input hidden states to the top-k experts based on the
+        router logits.
+
+        Returns:
+            (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
+            The weights and expert ids.
+
+        **Compatibility**: When EPLB is not enabled, the returned ids are
+        equivalent to global logical ids, so they should be compatible with
+        plain MoE implementations without redundant experts.
+        """
+        raise NotImplementedError
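Also not part of the diff: a sketch of what the new abstraction buys. Any object implementing the two abstract members can stand in where a `FusedMoERouter` is expected; the `UniformRandomRouter` below is hypothetical and exists only to illustrate the contract:

```python
import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


class UniformRandomRouter(FusedMoERouter):
    """Routes every token to `top_k` experts drawn uniformly at random."""

    def __init__(self, top_k: int, global_num_experts: int):
        self.top_k = top_k
        self.global_num_experts = global_num_experts

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.TopK

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = router_logits.shape[0]
        # Random global logical expert ids, int32 as most kernels expect.
        topk_ids = torch.randint(
            0,
            self.global_num_experts,
            (num_tokens, self.top_k),
            device=router_logits.device,
            dtype=torch.int32,
        )
        # Equal weights that sum to one per token.
        topk_weights = torch.full(
            (num_tokens, self.top_k),
            1.0 / self.top_k,
            device=router_logits.device,
            dtype=torch.float32,
        )
        return topk_weights, topk_ids
```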
+ """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 52f093f62d5a..c86b9a4ae67b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -21,7 +21,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.distributed.eplb.eplb_state import EplbState +from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -31,10 +31,21 @@ FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.default_fused_moe_router import ( + DefaultFusedMoERouter, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, ) -from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) @@ -48,31 +59,6 @@ ) from vllm.v1.worker.ubatching import dbo_current_ubatch_id -if current_platform.is_cuda_alike(): - from .fused_moe import eplb_map_to_physical_and_record -else: - - def _eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> torch.Tensor: - # CPU fallback: no EPLB so just return as is - return topk_ids - - eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record -from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk -from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( - FusedMoEMethodBase, -) -from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( - FusedMoEModularMethod, -) -from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( - UnquantizedFusedMoEMethod, -) - logger = init_logger(__name__) @@ -339,7 +325,7 @@ def __init__( is_sequence_parallel=False, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, - routing_method_type: int | None = None, + routing_method_type: RoutingMethodType | None = None, router_logits_dtype: torch.dtype | None = None, ): super().__init__() @@ -415,14 +401,6 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix - self.enable_eplb = enable_eplb - self.expert_load_view: torch.Tensor | None = None - self.logical_to_physical_map: torch.Tensor | None = None - self.logical_replica_count: torch.Tensor | None = None - self.expert_placement_strategy: ExpertPlacementStrategy = ( - vllm_config.parallel_config.expert_placement_strategy - ) - # ROCm aiter shared experts fusion self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.aiter_fmoe_shared_expert_enabled = ( @@ -434,6 +412,7 @@ def __init__( if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled else 0 ) + if ( not self.aiter_fmoe_shared_expert_enabled and self.num_fused_shared_experts != 0 @@ -443,6 +422,43 @@ def __init__( 
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled" ) + self.enable_eplb = enable_eplb + self.eplb_state = EplbLayerState() + self.expert_placement_strategy: ExpertPlacementStrategy = ( + vllm_config.parallel_config.expert_placement_strategy + ) + + if self.enable_eplb and not self.quant_method.supports_eplb: + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError( + f"EPLB is not supported {self.quant_method.__class__.__name__}. " + "EPLB is only supported for FP8 quantization for now." + ) + + self.router = DefaultFusedMoERouter( + top_k=top_k, + global_num_experts=self.global_num_experts, + eplb_state=self.eplb_state, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=self.num_fused_shared_experts, + enable_eplb=enable_eplb, + indices_type_getter=lambda: self.quant_method.topk_indices_dtype, + routing_method_type=routing_method_type, + ) + # Determine expert maps if self.use_ep: if self.enable_eplb: @@ -509,42 +525,9 @@ def __init__( self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.custom_routing_function = custom_routing_function - self.scoring_func = scoring_func - self.routed_scaling_factor = routed_scaling_factor - self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError( - "Only softmax scoring function is supported for non-grouped topk." - ) - - # ToDo: Better logic to determine the routing method type - if routing_method_type is not None: - self.routing_method_type = routing_method_type - else: - if scoring_func == "sigmoid": - if self.use_grouped_topk: - self.routing_method_type = RoutingMethodType.DeepSeekV3 - elif self.top_k == 1: - self.routing_method_type = RoutingMethodType.Llama4 - elif self.scoring_func == "softmax": - self.routing_method_type = ( - RoutingMethodType.Renormalize - if not self.renormalize - else RoutingMethodType.RenormalizeNaive - ) - else: - self.routing_method_type = RoutingMethodType.TopK - self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, @@ -611,19 +594,6 @@ def _get_quant_method() -> FusedMoEMethodBase: "(when AITER MoE is disabled) for now" ) - if self.enable_eplb and not self.quant_method.supports_eplb: - # TODO: Add support for additional quantization methods. 
- # The implementation for other quantization methods does not - # contain essential differences, but the current quant API - # design causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError( - f"EPLB is not supported {self.quant_method.__class__.__name__}. " - "EPLB is only supported for FP8 quantization for now." - ) - moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -1466,9 +1436,9 @@ def set_eplb_state( This is used later in forward pass, where we get the expert mapping and record the load metrics in `expert_load_view`. """ - self.expert_load_view = expert_load_view[moe_layer_idx] - self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] - self.logical_replica_count = logical_replica_count[moe_layer_idx] + self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx] + self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx] def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: @@ -1509,129 +1479,6 @@ def ensure_dp_chunking_init(self): device=torch.cuda.current_device(), ) - def select_experts( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Route the input hidden states to the top-k experts based on the - router logits. - - Returns: - (topk_weights, topk_ids) - (tuple[torch.Tensor, torch.Tensor]): - The weights and expert ids. - - **Compatibility**: When EPLB is not enabled, the returned ids are - equivalent to global logical ids, so should be compatible with - plain MoE implementations without redundant experts. - """ - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, - fused_topk_bias, - ) - - if self.enable_eplb: - if self.quant_method.supports_eplb: - if self.expert_load_view is None: - raise ValueError( - "enable_eplb=True requiere expert_load_view != None" - ) - if self.logical_to_physical_map is None: - raise ValueError( - "enable_eplb=True requiere logical_to_physical_map != None" - ) - if self.logical_replica_count is None: - raise ValueError( - "enable_eplb=True requiere logical_replica_count != None" - ) - else: - raise NotImplementedError( - f"EPLB is not supported for {self.quant_method.method_name}." 
- ) - - def valid_grouping() -> bool: - # Check if num_experts is greater than num_expert_group - # and is divisible by num_expert_group - num_experts = router_logits.shape[-1] - if num_experts <= self.num_expert_group: - return False - return num_experts % self.num_expert_group == 0 - - indices_type = self.quant_method.topk_indices_dtype - - # Check if we should use a routing simulation strategy - routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY - if routing_strategy != "": - topk_weights, topk_ids = RoutingSimulator.simulate_routing( - hidden_states=hidden_states, - router_logits=router_logits, - strategy_name=routing_strategy, - top_k=self.top_k, - indices_type=indices_type, - ) - - # DeepSeekv2 uses grouped_top_k - elif self.use_grouped_topk and valid_grouping(): - assert self.topk_group is not None - assert self.num_expert_group is not None - grouped_topk_impl = GroupedTopk( - topk=self.top_k, - renormalize=self.renormalize, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - num_fused_shared_experts=self.num_fused_shared_experts, - ) - - topk_weights, topk_ids = grouped_topk_impl( - hidden_states=hidden_states, - gating_output=router_logits, - e_score_correction_bias=self.e_score_correction_bias, - ) - elif self.e_score_correction_bias is not None: - topk_weights, topk_ids = fused_topk_bias( - hidden_states=hidden_states, - gating_output=router_logits, - e_score_correction_bias=self.e_score_correction_bias.data, - topk=self.top_k, - renormalize=self.renormalize, - ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor - elif self.custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - indices_type=indices_type, - ) - else: - topk_weights, topk_ids = self.custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - ) - - if self.enable_eplb: - topk_ids = eplb_map_to_physical_and_record( - topk_ids=topk_ids, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, - ) - - if (indices_type is not None) and topk_ids.dtype != indices_type: - topk_ids = topk_ids.to(dtype=indices_type) - - assert topk_ids.dtype == indices_type or indices_type is None - - return topk_weights, topk_ids - def must_reduce_shared_expert_outputs(self) -> bool: """ The shared_experts are typically computed using the RowParallelLinear @@ -1778,6 +1625,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, + router=self.router, x=staged_hidden_states, router_logits=staged_router_logits, ) @@ -1950,6 +1798,7 @@ def forward_impl( # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, + router=self.router, x=hidden_states_combined if do_naive_dispatch_combine else hidden_states, @@ -2051,15 +1900,8 @@ def extra_repr(self) -> str: f"tp_size={self.tp_size},\n" f"ep_size={self.ep_size}, " f"reduce_results={self.reduce_results}, " - f"renormalize={self.renormalize}, " - f"use_grouped_topk={self.use_grouped_topk}" ) - if self.use_grouped_topk: - s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 - - s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 - return s diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 92ef850205fc..1cdc25135a34 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, @@ -285,10 +286,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self.forward( + router=router, layer=layer, x=x, router_logits=router_logits, @@ -311,10 +314,11 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon def forward_cuda( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -337,6 +341,7 @@ def forward_cuda( def forward_cpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -370,6 +375,7 @@ def forward_cpu( def forward_xpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 602d02d2f15a..5763a41193e8 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -15,6 +15,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -759,12 +760,13 @@ def select_gemm_impl( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert layer.activation == "silu", "Only SiLU activation is supported." 
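One design choice in layer.py worth calling out, again as a reviewer note rather than part of the diff: `FusedMoE` and its router share a single mutable `EplbLayerState`, so `set_eplb_state()` can fill in the per-layer tensors after construction without re-wiring the router. A small sketch of that aliasing; the tensor shapes are illustrative assumptions:

```python
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState

state = EplbLayerState()  # created empty in FusedMoE.__init__
router_view = state       # DefaultFusedMoERouter holds this same object

# Later, EplbState calls layer.set_eplb_state(), which populates the fields
# in place (shapes here are made up for the sketch):
state.expert_load_view = torch.zeros(32, dtype=torch.int64)
state.logical_to_physical_map = torch.arange(32).reshape(16, 2)
state.logical_replica_count = torch.full((16,), 2, dtype=torch.int64)

# The router's reference observes the update, so the EPLB validation at the
# top of DefaultFusedMoERouter.select_experts() now passes.
assert router_view.expert_load_view is not None
```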
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 602d02d2f15a..5763a41193e8 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -759,12 +760,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index efe5677045e4..1d2334f3933a 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -10,7 +10,11 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -495,12 +499,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 509de5dff9c1..86878c84ab83 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -40,6 +40,7 @@
     MarlinExperts,
     fused_marlin_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
     convert_to_fp8_moe_kernel_format,
@@ -458,6 +459,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -484,7 +486,7 @@ def apply(
             x_routing, _ = x
         else:
             x_routing = x
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x_routing,
             router_logits=router_logits,
         )
@@ -926,10 +928,11 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1066,12 +1069,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1426,6 +1430,7 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1433,7 +1438,7 @@ def apply(
             f"{layer.activation} not supported for Marlin MoE."
         )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1677,12 +1682,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1978,6 +1984,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
@@ -2290,6 +2297,7 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ):
@@ -2298,7 +2306,7 @@ def apply(
             "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
         )
         assert self.moe_quant_config is not None
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 56b11b22f7ff..37e6020cb2a9 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
@@ -137,12 +138,13 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 2879315a6886..1c0c35bf6f41 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -29,6 +29,7 @@
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
@@ -997,6 +998,7 @@ def allow_inplace(self) -> bool:
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1051,7 +1053,7 @@ def apply(
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9600bb42295d..1c03e5243a85 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -16,7 +16,11 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -629,6 +633,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -639,7 +644,7 @@ def apply(
             "fused GGUF MoE method."
         )
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 3c958588c78f..68a2c375e353 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -15,6 +15,7 @@
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -895,12 +896,13 @@ def select_gemm_impl(
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."
 
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 9de2924ec71b..475bd853676e 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -9,6 +9,9 @@
 from vllm._ipex_ops import ipex_ops as ops
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -384,6 +387,7 @@ def get_fused_moe_quant_config(
     def apply(
         self,
         layer: torch.nn.Module,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
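A migration note for out-of-tree `FusedMoEMethodBase` subclasses, since the quantization backends above and below all follow the same mechanical pattern: every `apply()` override gains the `router` parameter, and the routing call moves off the layer. A minimal sketch of the new shape; the class is a hypothetical stand-in, and the kernel dispatch is deliberately elided:

```python
import torch

from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


class MyQuantMoEMethod:  # stand-in for a FusedMoEMethodBase subclass
    def apply(
        self,
        layer: torch.nn.Module,
        router: FusedMoERouter,  # new parameter introduced by this PR
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        # Previously layer.select_experts(...); the layer no longer owns routing.
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        # Dispatch to the backend's fused-experts kernel as before.
        raise NotImplementedError("kernel dispatch elided in this sketch")
```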
if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method @@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptFp8Config, - layer: FusedMoE, + moe_config: FusedMoEConfig, ) -> None: - super().__init__(layer.moe_config) + super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized self.fp8_backend = select_fp8_moe_backend( block_quant=False, - tp_size=layer.moe_parallel_config.tp_size, + tp_size=moe_config.moe_parallel_config.tp_size, with_lora_support=self.moe.is_lora_enabled, ) self.kernel: mk.FusedMoEModularKernel | None = None @@ -935,6 +939,7 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -961,7 +966,8 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - topk_weights, topk_ids = layer.select_experts( + # Expert selection + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def __init__( self, quant_config: ModelOptNvFp4Config, - layer: FusedMoE, + moe_config: FusedMoEConfig, ) -> None: - super().__init__(layer.moe_config) + super().__init__(moe_config) self.quant_config = quant_config self.nvfp4_backend = select_nvfp4_moe_backend() # TODO: move this type of check into the oracle. @@ -1597,6 +1603,7 @@ def supports_eplb(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -1621,7 +1628,7 @@ def apply( x_routing, _ = x else: x_routing = x - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 513f6f7b21ab..d5d94082587f 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -11,6 +11,7 @@ int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEConfig, @@ -364,13 +365,14 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts assert layer.activation == "silu", "Only SiLU activation is supported." 
- topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 15edd3e613bf..8e050b795f94 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -27,6 +27,7 @@ MarlinExperts, fused_marlin_moe, ) +from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, UnfusedOAITritonExperts, @@ -891,6 +892,7 @@ def allow_inplace(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -898,7 +900,7 @@ def apply( raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -992,7 +994,7 @@ def apply( ): from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -1119,7 +1121,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def apply( self, - layer: torch.nn.Module, + layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 4ab618dc44ef..6b731314825a 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -13,6 +13,7 @@ FusedMoE, FusedMoEConfig, FusedMoEMethodBase, + FusedMoERouter, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( @@ -350,10 +351,11 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) @@ -542,6 +544,7 @@ def get_fused_moe_quant_config(self, layer): def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -750,10 +753,11 @@ def allow_inplace(self) -> bool: def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index dce9c661ec33..239adb384708 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -15,7 +15,11 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from 
vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, + FusedMoEMethodBase, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -356,10 +360,11 @@ def get_fused_moe_quant_config( def apply( self, layer: FusedMoE, + router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, )
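Finally, an end-to-end sketch of the extracted routing path, again not part of the diff. It assumes a CUDA-alike device, since the default softmax path dispatches to the `fused_topk` kernel; shapes and dtypes are illustrative:

```python
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.default_fused_moe_router import (
    DefaultFusedMoERouter,
)

num_tokens, hidden_size, num_experts, top_k = 8, 64, 16, 2

router = DefaultFusedMoERouter(
    top_k=top_k,
    global_num_experts=num_experts,
    eplb_state=EplbLayerState(),  # empty state: EPLB disabled
    renormalize=True,
)

hidden_states = torch.randn(num_tokens, hidden_size, device="cuda")
router_logits = torch.randn(num_tokens, num_experts, device="cuda")

# With EPLB disabled, topk_ids are global logical expert ids (see the
# compatibility note in the select_experts docstring).
topk_weights, topk_ids = router.select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
)
assert topk_ids.shape == (num_tokens, top_k)
assert topk_weights.shape == (num_tokens, top_k)
```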