Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_routing_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s

# Test the select_experts method
topk_weights, topk_ids = fused_moe.select_experts(
topk_weights, topk_ids = fused_moe.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
Expand Down Expand Up @@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:

__all__ = [
"FusedMoE",
"FusedMoERouter",
"FusedMoEConfig",
"FusedMoEMethodBase",
"UnquantizedFusedMoEMethod",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
Expand Down Expand Up @@ -109,6 +112,7 @@ def method_name(self) -> str:
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
Expand Down Expand Up @@ -88,10 +89,11 @@ def get_fused_moe_quant_config(
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
Expand Down
40 changes: 40 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType


class FusedMoERouter(ABC):
    """Abstract interface for MoE expert routing.

    A concrete router converts per-token router logits into a top-k
    expert assignment via :meth:`select_experts`.
    """

    @property
    @abstractmethod
    def routing_method_type(self) -> RoutingMethodType:
        """The routing method this router implements."""
        raise NotImplementedError

    @abstractmethod
    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts for each token from the router logits.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                ``(topk_weights, topk_ids)`` — the routing weights and the
                chosen expert ids.

        **Compatibility**: When EPLB is not enabled, the returned ids are
        equivalent to global logical ids, so should be compatible with
        plain MoE implementations without redundant experts.
        """
        raise NotImplementedError
28 changes: 25 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
Expand Down Expand Up @@ -284,6 +285,23 @@ def maybe_roundup_hidden_size(
return hidden_size


class FusedMoERouterImpl(FusedMoERouter):
Copy link
Copy Markdown
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that eventually we will have N of these, 1 for each routing_method_type?

And then that FusedMoERouter would not have the layer as an attr?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, see #30623

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And also #30573 (which comes before #30623)

def __init__(self, layer: "FusedMoE") -> None:
    """Bind this router implementation to its owning FusedMoE layer."""
    super().__init__()
    # Back-reference to the layer: routing is delegated to the layer's
    # private `_select_experts` and its `routing_method_type` attribute.
    self.layer = layer

@property
def routing_method_type(self) -> RoutingMethodType:
    """Routing method type, delegated to the underlying layer."""
    return self.layer.routing_method_type

def select_experts(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Delegate top-k expert selection to the layer's `_select_experts`.

    Returns the `(topk_weights, topk_ids)` pair produced by the layer.
    """
    return self.layer._select_experts(hidden_states, router_logits)


@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
Expand Down Expand Up @@ -339,7 +357,7 @@ def __init__(
is_sequence_parallel=False,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
routing_method_type: RoutingMethodType | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
Expand Down Expand Up @@ -529,7 +547,7 @@ def __init__(

# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type = routing_method_type
self.routing_method_type: RoutingMethodType = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
Expand Down Expand Up @@ -646,6 +664,8 @@ def _get_quant_method() -> FusedMoEMethodBase:
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None

self.router = FusedMoERouterImpl(self)

# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
Expand Down Expand Up @@ -1509,7 +1529,7 @@ def ensure_dp_chunking_init(self):
device=torch.cuda.current_device(),
)

def select_experts(
def _select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand Down Expand Up @@ -1778,6 +1798,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
router=self.router,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
Expand Down Expand Up @@ -1950,6 +1971,7 @@ def forward_impl(
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
router=self.router,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
Expand Down Expand Up @@ -285,10 +286,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
router=router,
layer=layer,
x=x,
router_logits=router_logits,
Expand All @@ -311,10 +314,11 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
Expand All @@ -337,6 +341,7 @@ def forward_cuda(
def forward_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -370,6 +375,7 @@ def forward_cpu(
def forward_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
Expand Down Expand Up @@ -759,12 +760,13 @@ def select_gemm_impl(
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."

topk_weights, topk_ids = layer.select_experts(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
Expand Down
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
Expand Down Expand Up @@ -495,12 +499,13 @@ def get_fused_moe_quant_config(
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids = layer.select_experts(
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
Expand Down
Loading