diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bf10bc9d5c4c..7adac0374cf8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -38,8 +38,11 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import ( create_fused_moe_router, ) -from vllm.model_executor.layers.fused_moe.runner.moe_runner_factory import ( - create_moe_runner, +from vllm.model_executor.layers.fused_moe.runner.moe_runner import ( + MoERunner, +) +from vllm.model_executor.layers.fused_moe.runner.moe_runner_interface import ( + MoERunnerInterface, ) from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( SharedExperts, @@ -586,7 +589,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # Storing the runner in the FusedMoE is an intermediate state, eventually # the runner will own the FusedMoE layer and provide the execution interface # for MoE ops. - self.runner = create_moe_runner( + self.runner: MoERunnerInterface = MoERunner( layer_name=self.layer_name, moe_config=self.moe_config, router=self.router, diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py deleted file mode 100644 index df4c0c869248..000000000000 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ /dev/null @@ -1,128 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.distributed import ( - get_ep_group, - get_pcp_group, -) -from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase - - -class DefaultMoERunner(MoERunnerBase): - """ - Standard MoE runner implementation for executing Mixture of Experts layers. - - This is the primary concrete implementation of MoE execution logic, providing - comprehensive support for standard MoE operations. It handles: - - Expert routing and token dispatching using various routing strategies - - Shared experts computation with optional parallel execution using CUDA streams - - Tensor model parallel and expert parallel operations - - Multiple quantization methods and optimized kernel selection - - Both monolithic and decomposed expert execution paths - - Integration with various parallel execution modes (TP, EP, DP) - - The runner orchestrates the complete MoE forward pass including routing tokens - to experts, executing expert computations in parallel, and combining results. - It supports advanced features like overlapped execution of shared experts, - optimized kernels for different parallel configurations, and seamless - integration with vLLM's distributed execution framework. - - This implementation is suitable for most standard MoE use cases. For specialized - scenarios like large batch chunking, alternative runners like ChunkingMoERunner - may be more appropriate. - - Eventually, this class may be split into more specialized implementations - for different configurations (e.g., with/without shared experts, gates, etc.). 
- """ - - @property - def do_naive_dispatch_combine(self) -> bool: - return ( - self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk - ) - - def _maybe_dispatch( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - # For naive dispatch/combine Dp/Ep, dispatch the hidden states and - # router logits to all experts. - # NOTE: this will be removed once all kernels are migrated into the - # MoEKernel framework. - if self.do_naive_dispatch_combine: - res = get_ep_group().dispatch_router_logits( - hidden_states, - router_logits, - self.moe_config.is_sequence_parallel, - ) - assert len(res) == 2 - hidden_states, router_logits = res - - # NOTE: Similar with DP, PCP also needs dispatch and combine. For - # simplicity, AgRsAll2All was added separately for PCP here. Maybe - # we should modify All2AllManager abstraction to better support PCP. - if self.moe_config.pcp_size > 1: - hidden_states = get_pcp_group().all_gather( - hidden_states, - dim=0, - ) - router_logits = get_pcp_group().all_gather( - router_logits, - dim=0, - ) - - return hidden_states, router_logits - - def _maybe_combine( - self, - shared_output: torch.Tensor | None, - hidden_states: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: - if self.do_naive_dispatch_combine: - hidden_states = get_ep_group().combine( - hidden_states, self.moe_config.is_sequence_parallel - ) - - if self.moe_config.pcp_size > 1: - hidden_states = get_pcp_group().reduce_scatter( - hidden_states, - dim=0, - ) - - if self.shared_experts is not None: - assert shared_output is not None - return shared_output, hidden_states - else: - return hidden_states - - def _forward_impl( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # TODO(bnell): parts of the dispatch/combine steps will go away once - # #32567 lands and the remaining kernels are made MKs. 
The PCP - # code will probably remain - hidden_states, router_logits = self._maybe_dispatch( - layer, - hidden_states, - router_logits, - ) - - shared_output, hidden_states = self._apply_quant_method( - layer=layer, - hidden_states=hidden_states, - router_logits=router_logits, - shared_experts_input=shared_experts_input, - ) - - return self._maybe_combine( - shared_output, - hidden_states, - ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index 199ceab0659c..00be12780a16 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -1,44 +1,705 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from collections.abc import Callable +from contextlib import nullcontext +from typing import TYPE_CHECKING import torch +import torch.nn.functional as F +from vllm.distributed import ( + get_ep_group, + get_pcp_group, + tensor_model_parallel_all_reduce, +) +from vllm.forward_context import ( + ForwardContext, + get_forward_context, + is_forward_context_available, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, +) from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) +from vllm.model_executor.layers.fused_moe.router.zero_expert_router import ( + ZeroExpertRouter, +) +from vllm.model_executor.layers.fused_moe.runner.moe_runner_interface import ( + MoERunnerInterface, +) from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( SharedExperts, + SharedExpertsOrder, +) +from vllm.platforms import current_platform +from vllm.utils.torch_utils import ( + _USE_LAYERNAME, + LayerName, + direct_register_custom_op, +) + + +def get_layer_from_name(layer_name: str) -> torch.nn.Module: + forward_context: ForwardContext = get_forward_context() + if not _USE_LAYERNAME and layer_name == "from_forward_context": + all_moe_layers = forward_context.all_moe_layers + assert all_moe_layers is not None + moe_layer_index = forward_context.moe_layer_index + if moe_layer_index >= len(all_moe_layers): + raise AssertionError( + "We expected the number of MOE layers in `all_moe_layers` " + "to be equal to the number of " + "{vllm.moe_forward, vllm.moe_forward_shared} calls." + ) + layer_name = all_moe_layers[moe_layer_index] + forward_context.moe_layer_index += 1 + return forward_context.no_compile_layers[layer_name] + + +# On torch >= 2.11, layer_name is a hoisted LayerName opaque object; +# on older versions it remains a plain str. +if TYPE_CHECKING: + from typing import TypeAlias + + _layer_name_type: TypeAlias = str | LayerName +else: + _layer_name_type = LayerName if _USE_LAYERNAME else str + + +@torch.compiler.assume_constant_result +def _resolve_layer_name(layer_name: str | LayerName) -> str: + from torch._library.fake_class_registry import FakeScriptObject + + if isinstance(layer_name, LayerName): + return layer_name.value + elif isinstance(layer_name, FakeScriptObject): + return layer_name.real_obj.value + return layer_name + + +# Note: _moe_forward and _moe_forward_shared should not contain any +# implementation details; they should merely pass along control to +# the runner's '_forward_impl' method.
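The pass-through contract described in the note above (a custom op that carries no MoE logic and only hands control to the runner) can be illustrated in isolation. The sketch below uses `torch.library.custom_op` directly rather than vLLM's `direct_register_custom_op` helper, and the names `LAYER_REGISTRY`, `ToyMoELayer`, and `toy::moe_forward` are made up for illustration; it only demonstrates the registry-lookup-plus-delegate pattern and the shape-only fake impl that lets `torch.compile` trace through the op.

```python
import torch

# Hypothetical registry standing in for forward_context.no_compile_layers.
LAYER_REGISTRY: dict[str, torch.nn.Module] = {}


class ToyMoELayer(torch.nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def run(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


@torch.library.custom_op("toy::moe_forward", mutates_args=())
def toy_moe_forward(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    # No MoE logic here: resolve the layer by name and delegate,
    # mirroring _moe_forward above.
    return LAYER_REGISTRY[layer_name].run(x)


@toy_moe_forward.register_fake
def _(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Shape-only implementation so the compiler can trace the op.
    return torch.empty_like(x)


if __name__ == "__main__":
    LAYER_REGISTRY["layers.0.mlp"] = ToyMoELayer(16)
    out = torch.ops.toy.moe_forward(torch.randn(4, 16), "layers.0.mlp")
    print(out.shape)  # torch.Size([4, 16])
```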
+# These functions should never be called directly since they do not +# include all the functionality of the MoE layer. +def _moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + layer_name: _layer_name_type, +) -> torch.Tensor: + layer = get_layer_from_name(_resolve_layer_name(layer_name)) + return layer.runner._forward_impl( + layer, + hidden_states, + router_logits, + shared_experts_input, + ) + + +def _moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + layer_name: _layer_name_type, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def _moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + layer_name: _layer_name_type, +) -> tuple[torch.Tensor, torch.Tensor]: + layer = get_layer_from_name(_resolve_layer_name(layer_name)) + return layer.runner._forward_impl( + layer, + hidden_states, + router_logits, + shared_experts_input, + ) + + +def _moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + layer_name: _layer_name_type, +) -> tuple[torch.Tensor, torch.Tensor]: + # Output shapes: + # - fused_out: same as hidden_states (routed experts use transformed size) + # - shared_out: same as shared_experts_input if provided, else same as + # hidden_states + # (For latent MoE: shared experts use original hidden_size, not latent size) + fused_out = torch.empty_like(hidden_states) + if shared_experts_input is not None: + shared_out = torch.empty_like(shared_experts_input) + else: + shared_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward", + op_func=_moe_forward, + mutates_args=["hidden_states"], + fake_impl=_moe_forward_fake, + tags=(torch.Tag.needs_fixed_stride_order,), ) -class MoERunner(ABC): +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=_moe_forward_shared, + fake_impl=_moe_forward_shared_fake, + tags=(torch.Tag.needs_fixed_stride_order,), +) + + +def _unpack( + result: torch.Tensor | tuple[torch.Tensor, torch.Tensor], +) -> tuple[torch.Tensor | None, torch.Tensor]: + if isinstance(result, tuple): + return result + else: + return (None, result) + + +class MoERunner(MoERunnerInterface): """ - Abstract base class for Mixture of Experts (MoE) runners. + Standard MoE runner implementation for executing Mixture of Experts layers. + + This is the primary concrete implementation of MoE execution logic, providing + comprehensive support for standard MoE operations. It handles: + - Expert routing and token dispatching using various routing strategies + - Shared experts computation with optional parallel execution using CUDA streams + - Tensor model parallel and expert parallel operations + - Multiple quantization methods and optimized kernel selection + - Both monolithic and decomposed expert execution paths + - Integration with various parallel execution modes (TP, EP, DP) - This class defines the interface that all MoE runner implementations must follow. - MoE runners are responsible for executing the forward pass of MoE layers, handling - expert routing, and managing tensor parallel operations. + The runner orchestrates the complete MoE forward pass including routing tokens + to experts, executing expert computations in parallel, and combining results. 
+ It supports advanced features like overlapped execution of shared experts, + optimized kernels for different parallel configurations, and seamless + integration with vLLM's distributed execution framework. + + Eventually, this class may be split into more specialized implementations + for different configurations (e.g., with/without shared experts, gates, etc.). """ - @abstractmethod + def __init__( + self, + layer_name: str, + moe_config: FusedMoEConfig, + router: FusedMoERouter, + routed_input_transform: torch.nn.Module | None, + gate: torch.nn.Module | None, + shared_experts: torch.nn.Module | None, + quant_method: FusedMoEMethodBase, + enable_dbo: bool, + routed_output_transform: torch.nn.Module | None = None, + routed_scaling_factor: float = 1.0, + ): + super().__init__() + self.moe_config = moe_config + self.router = router + self.routed_input_transform = routed_input_transform + self.routed_output_transform = routed_output_transform + self.routed_scaling_factor = routed_scaling_factor + self.gate = gate + self.quant_method = quant_method + self.enable_dbo = enable_dbo + + self._shared_experts: SharedExperts | None = None + if shared_experts is not None: + self._shared_experts = SharedExperts( + shared_experts, + moe_config=moe_config, + # Note: For now we must pass quant_method along to SharedExperts so it + # can property determine where the shared experts are supposed to be + # called, i.e. by a MK or by the MoERunner. + # Once the MK can be created upfront, we can just pass in the proper + # flags derived from the quant_method's MK. + quant_method=quant_method, + enable_dbo=enable_dbo, + ) + + # Needed for string -> FusedMoE layer lookup in custom ops. + self.layer_name = layer_name + + self._forward_entry = self._select_forward() + + def _select_forward(self) -> Callable: + if current_platform.is_tpu() or current_platform.is_cpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + # Note: CPU doesn't require wrapped _forward_impl. + return _moe_forward if self._shared_experts is None else _moe_forward_shared + + return ( + torch.ops.vllm.moe_forward + if self._shared_experts is None + else torch.ops.vllm.moe_forward_shared + ) + + @property + def shared_experts(self) -> SharedExperts | None: + return self._shared_experts + + # TODO(bnell): temporary hack, do not call this method. + def _replace_quant_method(self, quant_method: FusedMoEMethodBase): + if self._shared_experts is not None: + self._shared_experts._quant_method = quant_method + self.quant_method = quant_method + + def is_internal_router(self) -> bool: + return self.gate is not None + + def apply_routed_input_transform( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply transform for routed experts (e.g., latent projection). + + This is called by FusedMoE.forward_native. The original hidden_states + is saved separately so shared experts get [S, hidden_size] while + routed experts get the transformed [S, moe_latent_size]. + + Returns (possibly transformed) hidden states and the input for shared + experts (or None if there are no shared experts). + """ + if self.routed_input_transform is not None: + result = self.routed_input_transform(hidden_states) + # ReplicatedLinear returns (output, extra_bias) tuple. + # We only need the output tensor; extra_bias is not used here. 
+ if isinstance(result, tuple): + return result[0], hidden_states + return result, hidden_states + + return ( + hidden_states, + hidden_states if self._shared_experts is not None else None, + ) + + def apply_routed_output_transform( + self, + fused_output: torch.Tensor, + ) -> torch.Tensor: + """Apply transform to routed expert output (e.g., latent to full dim). + + Used by latent MoE models (e.g., NemotronH) where routed experts + operate in a compressed latent space and need projection back to + the full hidden dimension before combining with shared expert output. + """ + if self.routed_output_transform is not None: + r = self.routed_output_transform(fused_output) + fused_output = r[0] if isinstance(r, tuple) else r + return fused_output + + def _maybe_apply_routed_scale_to_output( + self, + shared_output: torch.Tensor | None, + fused_output: torch.Tensor, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + """Apply routed_scaling_factor to the output with FP16 overflow + protection. + + Scale the fused expert output by routed_scaling_factor. For FP16, + avoid overflow by dividing shared_output by the scale instead + (the decoder layer compensates with matching divisions). + """ + if self.routed_scaling_factor != 1.0: + if fused_output.dtype != torch.float16 or shared_output is None: + fused_output *= self.routed_scaling_factor + elif shared_output is not None: + shared_output *= 1.0 / self.routed_scaling_factor + return shared_output, fused_output + + @property + def _fused_output_is_reduced(self) -> bool: + return ( + self.quant_method.moe_kernel is not None + and self.quant_method.moe_kernel.output_is_reduced() + ) + + def _maybe_reduce_shared_expert_output( + self, + shared_output: torch.Tensor | None, + ) -> torch.Tensor | None: + """All-reduce shared expert output when the combine kernel already + reduced fused output. + + This is the "early" all-reduce path. When the combine kernel produces + already-reduced fused output, shared output must be reduced separately + to match. + """ + if shared_output is not None and self._fused_output_is_reduced: + shared_output = tensor_model_parallel_all_reduce(shared_output) + return shared_output + + def _maybe_reduce_final_output( + self, + states: torch.Tensor, + trunc_size: int, + ) -> torch.Tensor: + """Truncate padded dimensions and all-reduce the combined output. + + This is the "late" all-reduce path. When neither fused nor shared + output was individually reduced, the combined sum is all-reduced + here. Skipped when sequence-parallel is active (SP handles its + own reduction) or when the early path already reduced both outputs. + """ + # We don't need to reduce the final output if: + # - We are not running with TP or DP + # - The MK already reduced the fused output itself. 
+ if ( + not self.moe_config.is_sequence_parallel + and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1) + and not self._fused_output_is_reduced + ): + states = tensor_model_parallel_all_reduce(states) + + return states[..., :trunc_size] + + def _encode_layer_name(self) -> str | LayerName: + if _USE_LAYERNAME: + return LayerName(self.layer_name) + # Can be unavailable or None in unittests + if ( + is_forward_context_available() + and get_forward_context().all_moe_layers is not None + ): + return "from_forward_context" + return self.layer_name + + def _maybe_pad_hidden_states( + self, + shared_experts_input: torch.Tensor | None, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, int]: + """Pad hidden_states to moe_config.hidden_dim and compute the + original dimension for later truncation. + + For latent MoE, the routed hidden_states may be smaller than + hidden_dim. Padding ensures uniform tensor sizes through the + fused MoE kernel. The returned trunc_size is used by + _maybe_reduce_final_output to strip the padding from the result. + """ + shared_experts_hidden_dim = ( + shared_experts_input.shape[-1] if shared_experts_input is not None else 0 + ) + transformed_hidden_dim = hidden_states.shape[-1] + if ( + not self.quant_method.skip_forward_padding + and self.moe_config.hidden_dim != transformed_hidden_dim + ): + hidden_states = F.pad( + hidden_states, + (0, self.moe_config.hidden_dim - transformed_hidden_dim), + mode="constant", + value=0.0, + ) + + if self.routed_output_transform is not None and shared_experts_hidden_dim > 0: + orig_hidden_dims = shared_experts_hidden_dim + else: + orig_hidden_dims = transformed_hidden_dim + + return hidden_states, orig_hidden_dims + + def _maybe_apply_shared_experts( + self, + shared_experts_input: torch.Tensor | None, + order: SharedExpertsOrder, + ): + if self._shared_experts is not None: + assert shared_experts_input is not None + self._shared_experts.apply(shared_experts_input, order) + + def _apply_quant_method( + self, + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + """Run expert routing and the fused MoE kernel via the quant method. + + Orchestrates shared expert execution (before/after), expert selection + via the router, and the actual fused MoE computation. Returns + (shared_expert_output, fused_expert_output). + """ + self._maybe_apply_shared_experts( + shared_experts_input, SharedExpertsOrder.NO_OVERLAP + ) + + if self.quant_method.is_monolithic: + fused_out = self.quant_method.apply_monolithic( + layer=layer, + x=hidden_states, + router_logits=router_logits, + ) + else: + topk_weights, topk_ids = self.router.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + + # Passing shared_experts_input in case SharedExpertsOrder is + # MK_INTERNAL_OVERLAPPED. + fused_out = self.quant_method.apply( + layer=layer, + x=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + shared_experts_input=shared_experts_input, + ) + + self._maybe_apply_shared_experts( + shared_experts_input, + SharedExpertsOrder.MULTI_STREAM_OVERLAPPED, + ) + + return ( + self._shared_experts.output if self._shared_experts is not None else None, + fused_out, + ) + + def _sequence_parallel_context(self): + """Return a context manager for sequence-parallel token + redistribution. 
+ + When sequence parallelism is active, returns a context that handles + local size tracking for proper token scatter/gather. Otherwise + returns a no-op context. + """ + ctx = get_forward_context() + return ( + ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size) + if ctx.dp_metadata + else nullcontext() + ) + + def _maybe_sync_shared_experts_stream( + self, + shared_experts_input: torch.Tensor | None, + ): + # If router/gate provided, then apply it here. + # (Note: This code runs only when "overlapped mode" is on to allow + # parallel execution of shared experts with the FusedMoE via + # separate cuda stream) + if self._shared_experts is not None: + assert shared_experts_input is not None + self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input) + + def _maybe_add_zero_expert_output( + self, + result: torch.Tensor, + ) -> torch.Tensor: + """Add the zero expert's contribution to the final result. + + When a ZeroExpertRouter is used, it computes a bias-like output + from the "zero expert" that is added to the combined routed+shared + expert output. + """ + if isinstance(self.router, ZeroExpertRouter): + zero_expert_output = self.router.zero_expert_output + assert zero_expert_output is not None + result = result + zero_expert_output + return result + def forward( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: - raise NotImplementedError + """Invoke the fused moe layer. - @abstractmethod - def is_internal_router(self) -> bool: - raise NotImplementedError + Input: + - hidden_states + - router_logits + + Output: + - The new hidden_states. + + Calling sequence + - forward + - self._forward_entry (_moe_forward or _moe_forward_shared custom op) + - _forward_impl + + Note: The existence of _moe_forward and _moe_forward_shared custom ops are due + to the following reason: + 1. pytorch cannot handle union types in custom op signatures so + _moe_forward and _moe_forward_shared must be split. + """ + + # Apply transform for routed experts (e.g., latent projection + # for latent MoE) + hidden_states, shared_experts_input = self.apply_routed_input_transform( + hidden_states + ) + + hidden_states, og_hidden_dim = self._maybe_pad_hidden_states( + shared_experts_input, + hidden_states, + ) + + result = self._forward_entry( + hidden_states, + router_logits, + shared_experts_input, + self._encode_layer_name(), + ) + + # + # Note: there are two all-reduce points below. They are mutually + # exclusive, controlled by _fused_output_is_reduced + # - When True: the combine kernel already reduced fused_output, + # so we reduce shared_output here to match, then skip the + # all-reduce in _maybe_reduce_final_output. + # - When False: neither output is reduced yet, so we combine + # them first and all-reduce the sum in _maybe_reduce_final_output. + + # Extract outputs from result + shared_output, fused_output = _unpack(result) + + # If combine kernel already reduced fused, reduce shared to match. + # See note above re: the two all-reduce points. + shared_output = self._maybe_reduce_shared_expert_output(shared_output) + + shared_output, fused_output = self._maybe_apply_routed_scale_to_output( + shared_output, fused_output + ) + + # Apply output transform (e.g. 
latent -> full dim) + fused_output = self.apply_routed_output_transform(fused_output) + + if shared_output is not None: + result = shared_output + fused_output + else: + result = fused_output + + result = self._maybe_reduce_final_output(result, og_hidden_dim) + + return self._maybe_add_zero_expert_output(result) @property - @abstractmethod - def shared_experts(self) -> SharedExperts | None: - raise NotImplementedError + def do_naive_dispatch_combine(self) -> bool: + return ( + self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk + ) - # TODO(bnell): temporary hack, do not call this method. - @abstractmethod - def _replace_quant_method(self, quant_method: FusedMoEMethodBase): - raise NotImplementedError + def _maybe_dispatch( + self, + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # For naive dispatch/combine Dp/Ep, dispatch the hidden states and + # router logits to all experts. + # NOTE: this will be removed once all kernels are migrated into the + # MoEKernel framework. + if self.do_naive_dispatch_combine: + result = get_ep_group().dispatch_router_logits( + hidden_states, + router_logits, + self.moe_config.is_sequence_parallel, + ) + assert len(result) == 2 + hidden_states, router_logits = result + + # NOTE: Similar with DP, PCP also needs dispatch and combine. For + # simplicity, AgRsAll2All was added separately for PCP here. Maybe + # we should modify All2AllManager abstraction to better support PCP. + if self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + router_logits, + dim=0, + ) + + return hidden_states, router_logits + + def _maybe_combine( + self, + shared_output: torch.Tensor | None, + hidden_states: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: + if self.do_naive_dispatch_combine: + hidden_states = get_ep_group().combine( + hidden_states, self.moe_config.is_sequence_parallel + ) + + if self.moe_config.pcp_size > 1: + hidden_states = get_pcp_group().reduce_scatter( + hidden_states, + dim=0, + ) + + if self.shared_experts is not None: + assert shared_output is not None + return shared_output, hidden_states + else: + return hidden_states + + def _forward_impl( + self, + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Entry point called by the custom op to run the MoE computation. + + Handles pre-dispatch setup (gate application, external shared expert + triggering, quant config init) then performs the following steps + within the sequence-parallel context. + + - Performs expert routing + - fused MoE kernel execution + - shared expert computation. + + Returns a single tensor of combined fused and shared output (if present). + """ + # TODO(bnell): this can be removed after MK migration is complete. + layer.ensure_moe_quant_config_init() + + # Sync aux and main stream for shared expert multi-stream overlap. + self._maybe_sync_shared_experts_stream(shared_experts_input) + + # If the Runner holds the gate, apply it after the stream sync, + # so it can run overlapped with the + # NOTE: in future PR, MoE runner will always hold the gate. 
+ if self.gate is not None: + router_logits, _ = self.gate(hidden_states) + + with self._sequence_parallel_context(): + # TODO(bnell): parts of the dispatch/combine steps will go away once + # #32567 lands and the remaining kernels are made MKs. The PCP + # code will probably remain + hidden_states, router_logits = self._maybe_dispatch( + layer, + hidden_states, + router_logits, + ) + + shared_output, hidden_states = self._apply_quant_method( + layer=layer, + hidden_states=hidden_states, + router_logits=router_logits, + shared_experts_input=shared_experts_input, + ) + + return self._maybe_combine( + shared_output, + hidden_states, + ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py deleted file mode 100644 index 136e1b1f5b20..000000000000 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_base.py +++ /dev/null @@ -1,639 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import abstractmethod -from collections.abc import Callable -from contextlib import nullcontext -from typing import TYPE_CHECKING - -import torch -import torch.nn.functional as F - -from vllm.distributed import ( - tensor_model_parallel_all_reduce, -) -from vllm.forward_context import ( - ForwardContext, - get_forward_context, - is_forward_context_available, -) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, -) -from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( - FusedMoEMethodBase, -) -from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( - FusedMoERouter, -) -from vllm.model_executor.layers.fused_moe.router.zero_expert_router import ( - ZeroExpertRouter, -) -from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner -from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( - SharedExperts, - SharedExpertsOrder, -) -from vllm.platforms import current_platform -from vllm.utils.torch_utils import ( - _USE_LAYERNAME, - LayerName, - direct_register_custom_op, -) - - -def get_layer_from_name(layer_name: str) -> torch.nn.Module: - forward_context: ForwardContext = get_forward_context() - if not _USE_LAYERNAME and layer_name == "from_forward_context": - all_moe_layers = forward_context.all_moe_layers - assert all_moe_layers is not None - moe_layer_index = forward_context.moe_layer_index - if moe_layer_index >= len(all_moe_layers): - raise AssertionError( - "We expected the number of MOE layers in `all_moe_layers` " - "to be equal to the number of " - "{vllm.moe_forward, vllm.moe_forward_shared} calls." - ) - layer_name = all_moe_layers[moe_layer_index] - forward_context.moe_layer_index += 1 - return forward_context.no_compile_layers[layer_name] - - -# On torch >= 2.11, layer_name is a hoisted LayerName opaque object; -# on older versions it remains a plain str. 
-if TYPE_CHECKING: - from typing import TypeAlias - - _layer_name_type: TypeAlias = str | LayerName -else: - _layer_name_type = LayerName if _USE_LAYERNAME else str - - -@torch.compiler.assume_constant_result -def _resolve_layer_name(layer_name: str | LayerName) -> str: - from torch._library.fake_class_registry import FakeScriptObject - - if isinstance(layer_name, LayerName): - return layer_name.value - elif isinstance(layer_name, FakeScriptObject): - return layer_name.real_obj.value - return layer_name - - -# Note: _moe_forward and _moe_forward_shared should not contain any -# implementation details, They should merely pass along control to -# the runner's '_forward_dispatch' method. -# These functions should never be called directly since they do not -# include all the functionality of the MoE layer. -def _moe_forward( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - layer_name: _layer_name_type, -) -> torch.Tensor: - layer = get_layer_from_name(_resolve_layer_name(layer_name)) - return layer.runner._forward_dispatch( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) - - -def _moe_forward_fake( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - layer_name: _layer_name_type, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -def _moe_forward_shared( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - layer_name: _layer_name_type, -) -> tuple[torch.Tensor, torch.Tensor]: - layer = get_layer_from_name(_resolve_layer_name(layer_name)) - return layer.runner._forward_dispatch( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) - - -def _moe_forward_shared_fake( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - layer_name: _layer_name_type, -) -> tuple[torch.Tensor, torch.Tensor]: - # Output shapes: - # - fused_out: same as hidden_states (routed experts use transformed size) - # - shared_out: same as shared_experts_input if provided, else same as - # hidden_states - # (For latent MoE: shared experts use original hidden_size, not latent size) - fused_out = torch.empty_like(hidden_states) - if shared_experts_input is not None: - shared_out = torch.empty_like(shared_experts_input) - else: - shared_out = torch.empty_like(hidden_states) - return shared_out, fused_out - - -direct_register_custom_op( - op_name="moe_forward", - op_func=_moe_forward, - mutates_args=["hidden_states"], - fake_impl=_moe_forward_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - -direct_register_custom_op( - op_name="moe_forward_shared", - op_func=_moe_forward_shared, - fake_impl=_moe_forward_shared_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - -def _unpack( - result: torch.Tensor | tuple[torch.Tensor, torch.Tensor], -) -> tuple[torch.Tensor | None, torch.Tensor]: - if isinstance(result, tuple): - return result - else: - return (None, result) - - -class MoERunnerBase(MoERunner): - """ - Abstract base class providing common functionality for MoE runner implementations. - - This class serves as the foundation for concrete MoE runner implementations by - providing shared state management and common utilities. 
It handles: - - Common initialization and configuration management - - Shared expert output reduction logic for tensor parallel scenarios - - Base methods for tensor model parallel reductions - - Common properties and utility functions used across different runner types - - Concrete subclasses must implement the abstract methods to define their specific - execution strategies, such as standard execution, chunked processing, or other - specialized approaches. The base class provides the infrastructure while - allowing flexibility in the actual MoE computation implementation. - - Key abstract methods that subclasses must implement: - - _forward_impl: The core MoE computation logic specific to each runner type - """ - - def __init__( - self, - layer_name: str, - moe_config: FusedMoEConfig, - router: FusedMoERouter, - routed_input_transform: torch.nn.Module | None, - gate: torch.nn.Module | None, - shared_experts: torch.nn.Module | None, - quant_method: FusedMoEMethodBase, - enable_dbo: bool, - routed_output_transform: torch.nn.Module | None = None, - routed_scaling_factor: float = 1.0, - ): - super().__init__() - self.moe_config = moe_config - self.router = router - self.routed_input_transform = routed_input_transform - self.routed_output_transform = routed_output_transform - self.routed_scaling_factor = routed_scaling_factor - self.gate = gate - self.quant_method = quant_method - self.enable_dbo = enable_dbo - self._fused_output_is_reduced = ( - self.quant_method.moe_kernel is not None - and self.quant_method.moe_kernel.output_is_reduced() - ) - - self._shared_experts: SharedExperts | None = None - if shared_experts is not None: - self._shared_experts = SharedExperts( - shared_experts, - moe_config=moe_config, - # Note: For now we must pass quant_method along to SharedExperts so it - # can property determine where the shared experts are supposed to be - # called, i.e. by a MK or by the MoERunner. - # Once the MK can be created upfront, we can just pass in the proper - # flags derived from the quant_method's MK. - quant_method=quant_method, - enable_dbo=enable_dbo, - ) - - # Needed for string -> FusedMoE layer lookup in custom ops. - self.layer_name = layer_name - - self._forward_entry = self._select_forward() - - def _select_forward(self) -> Callable: - if current_platform.is_tpu() or current_platform.is_cpu(): - # TODO: Once the OOM issue for the TPU backend is resolved, we - # will switch to using the moe_forward custom op. - # Note: CPU doesn't require wrapped _forward_impl. - return _moe_forward if self._shared_experts is None else _moe_forward_shared - - return ( - torch.ops.vllm.moe_forward - if self._shared_experts is None - else torch.ops.vllm.moe_forward_shared - ) - - @property - def shared_experts(self) -> SharedExperts | None: - return self._shared_experts - - # TODO(bnell): temporary hack, do not call this method. - def _replace_quant_method(self, quant_method: FusedMoEMethodBase): - if self._shared_experts is not None: - self._shared_experts._quant_method = quant_method - self.quant_method = quant_method - - def is_internal_router(self) -> bool: - return self.gate is not None - - def apply_routed_input_transform( - self, hidden_states: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Apply transform for routed experts (e.g., latent projection). - - This is called by FusedMoE.forward_native. The original hidden_states - is saved separately so shared experts get [S, hidden_size] while - routed experts get the transformed [S, moe_latent_size]. 
- - Returns (possibly transformed) hidden states and the input for shared - experts (or None if there are no shared experts). - """ - if self.routed_input_transform is not None: - result = self.routed_input_transform(hidden_states) - # ReplicatedLinear returns (output, extra_bias) tuple. - # We only need the output tensor; extra_bias is not used here. - if isinstance(result, tuple): - return result[0], hidden_states - return result, hidden_states - - return ( - hidden_states, - hidden_states if self._shared_experts is not None else None, - ) - - def apply_routed_output_transform( - self, - fused_output: torch.Tensor, - ) -> torch.Tensor: - """Apply transform to routed expert output (e.g., latent to full dim). - - Used by latent MoE models (e.g., NemotronH) where routed experts - operate in a compressed latent space and need projection back to - the full hidden dimension before combining with shared expert output. - """ - if self.routed_output_transform is not None: - r = self.routed_output_transform(fused_output) - fused_output = r[0] if isinstance(r, tuple) else r - return fused_output - - def _maybe_apply_routed_scale_to_output( - self, - shared_output: torch.Tensor | None, - fused_output: torch.Tensor, - ) -> tuple[torch.Tensor | None, torch.Tensor]: - """Apply routed_scaling_factor to the output with FP16 overflow - protection. - - Scale the fused expert output by routed_scaling_factor. For FP16, - avoid overflow by dividing shared_output by the scale instead - (the decoder layer compensates with matching divisions). - """ - if self.routed_scaling_factor != 1.0: - if fused_output.dtype != torch.float16: - fused_output *= self.routed_scaling_factor - elif shared_output is not None: - shared_output *= 1.0 / self.routed_scaling_factor - return shared_output, fused_output - - def _maybe_reduce_shared_expert_output( - self, - shared_output: torch.Tensor | None, - ) -> torch.Tensor | None: - """All-reduce shared expert output when the combine kernel already - reduced fused output. - - This is the "early" all-reduce path. When the combine kernel produces - already-reduced fused output, shared output must be reduced separately - to match. - """ - if self._fused_output_is_reduced: - assert shared_output is not None - shared_output = tensor_model_parallel_all_reduce(shared_output) - return shared_output - - def _maybe_reduce_final_output( - self, - states: torch.Tensor, - trunc_size: int, - ) -> torch.Tensor: - """Truncate padded dimensions and all-reduce the combined output. - - This is the "late" all-reduce path. When neither fused nor shared - output was individually reduced, the combined sum is all-reduced - here. Skipped when sequence-parallel is active (SP handles its - own reduction) or when the early path already reduced both outputs. - """ - # We don't need to reduce the final output if: - # - We are not running with TP or DP - # - The MK already reduced the fused output itself. 
- if ( - not self.moe_config.is_sequence_parallel - and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1) - and not self._fused_output_is_reduced - ): - states = tensor_model_parallel_all_reduce(states) - - return states[..., :trunc_size] - - def _encode_layer_name(self) -> str | LayerName: - if _USE_LAYERNAME: - return LayerName(self.layer_name) - # Can be unavailable or None in unittests - if ( - is_forward_context_available() - and get_forward_context().all_moe_layers is not None - ): - return "from_forward_context" - return self.layer_name - - def _maybe_pad_hidden_states( - self, - shared_experts_input: torch.Tensor | None, - hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, int]: - """Pad hidden_states to moe_config.hidden_dim and compute the - original dimension for later truncation. - - For latent MoE, the routed hidden_states may be smaller than - hidden_dim. Padding ensures uniform tensor sizes through the - fused MoE kernel. The returned trunc_size is used by - _maybe_reduce_final_output to strip the padding from the result. - """ - shared_experts_hidden_dim = ( - shared_experts_input.shape[-1] if shared_experts_input is not None else 0 - ) - transformed_hidden_dim = hidden_states.shape[-1] - if ( - not self.quant_method.skip_forward_padding - and self.moe_config.hidden_dim != transformed_hidden_dim - ): - hidden_states = F.pad( - hidden_states, - (0, self.moe_config.hidden_dim - transformed_hidden_dim), - mode="constant", - value=0.0, - ) - - if self.routed_output_transform is not None and shared_experts_hidden_dim > 0: - orig_hidden_dims = shared_experts_hidden_dim - else: - orig_hidden_dims = transformed_hidden_dim - - return hidden_states, orig_hidden_dims - - def _maybe_apply_shared_experts( - self, - shared_experts_input: torch.Tensor | None, - order: SharedExpertsOrder, - ): - if self._shared_experts is not None: - assert shared_experts_input is not None - self._shared_experts.apply(shared_experts_input, order) - - def _apply_quant_method( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> tuple[torch.Tensor | None, torch.Tensor]: - """Run expert routing and the fused MoE kernel via the quant method. - - Orchestrates shared expert execution (before/after), expert selection - via the router, and the actual fused MoE computation. Returns - (shared_expert_output, fused_expert_output). - """ - self._maybe_apply_shared_experts( - shared_experts_input, SharedExpertsOrder.NO_OVERLAP - ) - - if self.quant_method.is_monolithic: - fused_out = self.quant_method.apply_monolithic( - layer=layer, - x=hidden_states, - router_logits=router_logits, - ) - else: - topk_weights, topk_ids = self.router.select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) - - # Passing shared_experts_input in case SharedExpertsOrder is - # MK_INTERNAL_OVERLAPPED. - fused_out = self.quant_method.apply( - layer=layer, - x=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - shared_experts_input=shared_experts_input, - ) - - self._maybe_apply_shared_experts( - shared_experts_input, - SharedExpertsOrder.MULTI_STREAM_OVERLAPPED, - ) - - return ( - self._shared_experts.output if self._shared_experts is not None else None, - fused_out, - ) - - def _sequence_parallel_context(self): - """Return a context manager for sequence-parallel token - redistribution. 
- - When sequence parallelism is active, returns a context that handles - local size tracking for proper token scatter/gather. Otherwise - returns a no-op context. - """ - ctx = get_forward_context() - return ( - ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size) - if ctx.dp_metadata - else nullcontext() - ) - - def _maybe_sync_shared_experts_stream( - self, - shared_experts_input: torch.Tensor | None, - ): - # If router/gate provided, then apply it here. - # (Note: This code runs only when "overlapped mode" is on to allow - # parallel execution of shared experts with the FusedMoE via - # separate cuda stream) - if self._shared_experts is not None: - assert shared_experts_input is not None - self._shared_experts.maybe_sync_shared_experts_stream(shared_experts_input) - - def _maybe_add_zero_expert_output( - self, - result: torch.Tensor, - ) -> torch.Tensor: - """Add the zero expert's contribution to the final result. - - When a ZeroExpertRouter is used, it computes a bias-like output - from the "zero expert" that is added to the combined routed+shared - expert output. - """ - if isinstance(self.router, ZeroExpertRouter): - zero_expert_output = self.router.zero_expert_output - assert zero_expert_output is not None - result = result + zero_expert_output - return result - - def forward( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - """Invoke the fused moe layer. - - Input: - - hidden_states - - router_logits - - Output: - - The new hidden_states. - - Calling sequence - - forward - - self._forward_entry (_moe_forward or _moe_forward_shared custom op) - - _forward_dispatch - - _forward_impl - - Note: The existence of _moe_forward and _moe_forward_shared custom ops are due - to the following reasons: - 1. the chunking loop in ChunkingMoERunner._forward_impl cannot be compiled by - torch.compile - 2. pytorch cannot handle union types in custom op signatures so - _moe_forward and _moe_forward_shared must be split. - - If ChunkingMoERunner._forward_impl can be implemented via torch.scan we can - potentially get rid of _moe_forward and _moe_forward_shared and collapse the - whole sequence into the 'forward' method. - """ - - # Apply transform for routed experts (e.g., latent projection - # for latent MoE) - hidden_states, shared_experts_input = self.apply_routed_input_transform( - hidden_states - ) - - hidden_states, og_hidden_dim = self._maybe_pad_hidden_states( - shared_experts_input, - hidden_states, - ) - - result = self._forward_entry( - hidden_states, - router_logits, - shared_experts_input, - self._encode_layer_name(), - ) - - # - # Note: there are two all-reduce points below. They are mutually - # exclusive, controlled by _fused_output_is_reduced - # - When True: the combine kernel already reduced fused_output, - # so we reduce shared_output here to match, then skip the - # all-reduce in _maybe_reduce_final_output. - # - When False: neither output is reduced yet, so we combine - # them first and all-reduce the sum in _maybe_reduce_final_output. - - # Extract outputs from result - shared_output, fused_output = _unpack(result) - - # If combine kernel already reduced fused, reduce shared to match. - # See note above re: the two all-reduce points. - shared_output = self._maybe_reduce_shared_expert_output(shared_output) - - shared_output, fused_output = self._maybe_apply_routed_scale_to_output( - shared_output, fused_output - ) - - # Apply output transform (e.g. 
latent -> full dim) - fused_output = self.apply_routed_output_transform(fused_output) - - if shared_output is not None: - result = shared_output + fused_output - else: - result = fused_output - - result = self._maybe_reduce_final_output(result, og_hidden_dim) - - return self._maybe_add_zero_expert_output(result) - - def _forward_dispatch( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """Entry point called by the custom op to run the MoE computation. - - Handles pre-dispatch setup (gate application, external shared expert - triggering, quant config init) then delegates to _forward_impl within - the sequence-parallel context. - """ - # TODO(bnell): this can be removed after MK migration is complete. - layer.ensure_moe_quant_config_init() - - # Sync aux and main stream for shared expert multi-stream overlap. - self._maybe_sync_shared_experts_stream(shared_experts_input) - - # If the Runner holds the gate, apply it after the stream sync, - # so it can run overlapped with the - # NOTE: in future PR, MoE runner will always hold the gate. - if self.gate is not None: - router_logits, _ = self.gate(hidden_states) - - with self._sequence_parallel_context(): - return self._forward_impl( - layer, - hidden_states, - router_logits, - shared_experts_input, - ) - - @abstractmethod - def _forward_impl( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - shared_experts_input: torch.Tensor | None, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """Core MoE computation to be implemented by subclasses. - - Performs expert routing, fused MoE kernel execution, and shared - expert computation. Returns a single tensor (fused output only) - or a tuple of (shared_output, fused_output) when shared experts - are present. 
- """ - raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py deleted file mode 100644 index feb4614d8372..000000000000 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, -) -from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( - FusedMoEMethodBase, -) -from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( - FusedMoERouter, -) -from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import ( - DefaultMoERunner, -) -from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner -from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( - SharedExperts, -) - - -def create_moe_runner( - layer_name: str, - moe_config: FusedMoEConfig, - router: FusedMoERouter, - routed_input_transform: torch.nn.Module | None, - gate: torch.nn.Module | None, - shared_experts: SharedExperts | None, - quant_method: FusedMoEMethodBase, - enable_dbo: bool, - routed_output_transform: torch.nn.Module | None = None, - routed_scaling_factor: float = 1.0, -) -> MoERunner: - return DefaultMoERunner( - layer_name, - moe_config, - router, - routed_input_transform, - gate, - shared_experts, - quant_method, - enable_dbo, - routed_output_transform=routed_output_transform, - routed_scaling_factor=routed_scaling_factor, - ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py new file mode 100644 index 000000000000..80bd83e3732e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) + + +class MoERunnerInterface(ABC): + """ + Abstract base class for Mixture of Experts (MoE) runners. + + This class defines the interface that all MoE runner implementations must follow. + MoE runners are responsible for executing the forward pass of MoE layers, handling + expert routing, and managing tensor parallel operations. + """ + + @abstractmethod + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_internal_router(self) -> bool: + raise NotImplementedError + + @property + @abstractmethod + def shared_experts(self) -> SharedExperts | None: + raise NotImplementedError + + # TODO(bnell): temporary hack, do not call this method. + @abstractmethod + def _replace_quant_method(self, quant_method: FusedMoEMethodBase): + raise NotImplementedError