9 changes: 3 additions & 6 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -38,11 +38,8 @@
 from vllm.model_executor.layers.fused_moe.router.router_factory import (
     create_fused_moe_router,
 )
-from vllm.model_executor.layers.fused_moe.runner.moe_runner import (
-    MoERunner,
-)
-from vllm.model_executor.layers.fused_moe.runner.moe_runner_interface import (
-    MoERunnerInterface,
+from vllm.model_executor.layers.fused_moe.runner.moe_runner_factory import (
+    create_moe_runner,
 )
 from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
     SharedExperts,
@@ -589,7 +586,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
         # Storing the runner in the FusedMoE is an intermediate state, eventually
         # the runner will own the FusedMoE layer and provide the execution interface
         # for MoE ops.
-        self.runner: MoERunnerInterface = MoERunner(
+        self.runner = create_moe_runner(
             layer_name=self.layer_name,
             moe_config=self.moe_config,
             router=self.router,
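The hunks above replace direct construction of MoERunner with a call to create_moe_runner from the new moe_runner_factory module. The factory itself is not shown in this section; the sketch below only illustrates the shape of the pattern the diff switches to, under the assumption that the factory picks a MoERunnerBase subclass from the MoE configuration. The constructor keyword arguments are taken from the call site above; everything else is illustrative.

# Illustrative sketch only -- not the actual create_moe_runner implementation
# in moe_runner_factory.py. It just shows the pattern: FusedMoE asks a factory
# for a runner instead of constructing MoERunner directly.
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
    DefaultMoERunner,
)


def create_moe_runner(layer_name, moe_config, router, **kwargs):
    # A real factory would inspect moe_config here and could return other
    # MoERunnerBase subclasses (e.g. ChunkingMoERunner for large-batch
    # chunking) when the configuration calls for them.
    return DefaultMoERunner(
        layer_name=layer_name, moe_config=moe_config, router=router, **kwargs
    )

This keeps FusedMoE decoupled from the concrete runner type: the layer only stores whatever the factory returns, as the updated assignment above shows.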
128 changes: 128 additions & 0 deletions vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.distributed import (
    get_ep_group,
    get_pcp_group,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase


class DefaultMoERunner(MoERunnerBase):
"""
Standard MoE runner implementation for executing Mixture of Experts layers.

This is the primary concrete implementation of MoE execution logic, providing
comprehensive support for standard MoE operations. It handles:
- Expert routing and token dispatching using various routing strategies
- Shared experts computation with optional parallel execution using CUDA streams
- Tensor model parallel and expert parallel operations
- Multiple quantization methods and optimized kernel selection
- Both monolithic and decomposed expert execution paths
- Integration with various parallel execution modes (TP, EP, DP)

The runner orchestrates the complete MoE forward pass including routing tokens
to experts, executing expert computations in parallel, and combining results.
It supports advanced features like overlapped execution of shared experts,
optimized kernels for different parallel configurations, and seamless
integration with vLLM's distributed execution framework.

This implementation is suitable for most standard MoE use cases. For specialized
scenarios like large batch chunking, alternative runners like ChunkingMoERunner
may be more appropriate.

Eventually, this class may be split into more specialized implementations
for different configurations (e.g., with/without shared experts, gates, etc.).
"""

    @property
    def do_naive_dispatch_combine(self) -> bool:
        return (
            self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
        )

    def _maybe_dispatch(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # For naive DP/EP dispatch/combine, dispatch the hidden states and
        # router logits to all experts.
        # NOTE: this will be removed once all kernels are migrated into the
        # MoEKernel framework.
        if self.do_naive_dispatch_combine:
            res = get_ep_group().dispatch_router_logits(
                hidden_states,
                router_logits,
                self.moe_config.is_sequence_parallel,
            )
            assert len(res) == 2
            hidden_states, router_logits = res

        # NOTE: Similar to DP, PCP also needs dispatch and combine. For
        # simplicity, AgRsAll2All was added separately for PCP here. Maybe
        # we should modify the All2AllManager abstraction to better support PCP.
        if self.moe_config.pcp_size > 1:
            hidden_states = get_pcp_group().all_gather(
                hidden_states,
                dim=0,
            )
            router_logits = get_pcp_group().all_gather(
                router_logits,
                dim=0,
            )

        return hidden_states, router_logits

    def _maybe_combine(
        self,
        shared_output: torch.Tensor | None,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        if self.do_naive_dispatch_combine:
            hidden_states = get_ep_group().combine(
                hidden_states, self.moe_config.is_sequence_parallel
            )

        if self.moe_config.pcp_size > 1:
            hidden_states = get_pcp_group().reduce_scatter(
                hidden_states,
                dim=0,
            )

        if self.shared_experts is not None:
            assert shared_output is not None
            return shared_output, hidden_states
        else:
            return hidden_states

    def _forward_impl(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # TODO(bnell): parts of the dispatch/combine steps will go away once
        # #32567 lands and the remaining kernels are made MKs. The PCP
        # code will probably remain.
        hidden_states, router_logits = self._maybe_dispatch(
            layer,
            hidden_states,
            router_logits,
        )

        shared_output, hidden_states = self._apply_quant_method(
            layer=layer,
            hidden_states=hidden_states,
            router_logits=router_logits,
            shared_experts_input=shared_experts_input,
        )

        return self._maybe_combine(
            shared_output,
            hidden_states,
        )
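For readers unfamiliar with the collectives above: under PCP, _maybe_dispatch all-gathers each rank's tokens along dim 0 before expert execution, and _maybe_combine reduce-scatters the outputs along the same dim afterwards, so each rank ends up holding results only for its own tokens. The single-process sketch below emulates that round trip with plain torch ops (torch.cat for all_gather, stack-sum-chunk for reduce_scatter); it is not the vLLM get_pcp_group() API, and the world size and shapes are arbitrary.

import torch

# Single-process emulation of the PCP dispatch/combine round trip:
#   dispatch -> all_gather along dim 0 (every rank sees all tokens)
#   combine  -> reduce_scatter along dim 0 (sum partials, return own slice)
world_size = 2
tokens_per_rank = 4
hidden = 8

# Each "rank" starts with only its own slice of the tokens.
per_rank = [torch.randn(tokens_per_rank, hidden) for _ in range(world_size)]

# Dispatch: emulate all_gather(dim=0).
gathered = torch.cat(per_rank, dim=0)  # (world_size * tokens_per_rank, hidden)

# Expert execution would run here. Each rank produces a partial result for all
# tokens (in the real runner only tokens routed to local experts contribute);
# here each rank contributes an equal share of an identity mapping so the
# partials sum back to the inputs exactly.
contributions = [gathered / world_size for _ in range(world_size)]

# Combine: emulate reduce_scatter(dim=0) -- sum across ranks, then hand each
# rank its own slice of the result.
summed = torch.stack(contributions).sum(dim=0)
scattered = summed.chunk(world_size, dim=0)

for rank in range(world_size):
    assert torch.allclose(scattered[rank], per_rank[rank])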