diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index a46e3972ed8e..5b7dc7542f72 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -49,16 +49,37 @@ def uses_weight_scale_2_pattern(self) -> bool:
         """
         return False
 
+    def _maybe_add_dp_ep_naive_fallback(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize | None,
+    ) -> FusedMoEPrepareAndFinalize | None:
+        """
+        Ensure DP+EP without all2all still gets dispatch/combine via naive
+        prepare/finalize.
+        """
+        if (
+            prepare_finalize is None
+            and not self.moe.moe_parallel_config.use_all2all_kernels
+            and self.moe.dp_size > 1
+            and self.moe.use_ep
+        ):
+            from .naive_prepare_finalize import FusedMoENaivePrepareAndFinalize
+
+            return FusedMoENaivePrepareAndFinalize()
+        return prepare_finalize
+
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> FusedMoEPrepareAndFinalize | None:
         from .all2all_utils import maybe_make_prepare_finalize
 
-        return maybe_make_prepare_finalize(
+        prepare_finalize = maybe_make_prepare_finalize(
             self.moe, self.moe_quant_config, routing_tables
         )
 
+        return self._maybe_add_dp_ep_naive_fallback(prepare_finalize)
+
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 30ff1bf2f008..8a2bfc2f5d99 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -92,13 +92,18 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
-            hidden_states=x,
+        prepare_finalize = self.fused_experts.prepare_finalize
+        hidden_states, router_logits = prepare_finalize.preprocess_inputs(
+            x, router_logits, layer
+        )
+
+        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
         )
 
         result = self.fused_experts(
-            hidden_states=x,
+            hidden_states=hidden_states,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
             topk_weights=topk_weights,
@@ -110,4 +115,14 @@ def apply(
             expert_map=None if self.disable_expert_map else layer.expert_map,
         )
 
-        return result
+        result = prepare_finalize.postprocess_output(result, layer)
+
+        zero_expert_num = getattr(layer, "zero_expert_num", 0)
+        zero_expert_type = getattr(layer, "zero_expert_type", None)
+        if zero_expert_num != 0 and zero_expert_type is not None:
+            assert not isinstance(result, tuple), (
+                "Shared experts + zero experts are mutually exclusive and not yet supported"
+            )
+            return result, zero_expert_result
+        else:
+            return result
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f0d94bfbcaba..941b9479a359 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1931,11 +1931,11 @@ def forward_impl(
             else:
                 hidden_states_combined, router_logits = dispatch_res
 
-        # Run shared experts before matrix multiply.
-        # because matrix multiply maybe modify the hidden_states.
-        if has_separate_shared_experts and not use_shared_experts_stream:
-            assert self.shared_experts is not None
-            shared_output = self.shared_experts(hidden_states)
+        # Run shared experts before matrix multiply.
+        # because matrix multiply maybe modify the hidden_states.
+        if has_separate_shared_experts and not use_shared_experts_stream:
+            assert self.shared_experts is not None
+            shared_output = self.shared_experts(hidden_states)
 
         # NOTE: Similar with DP, PCP also needs dispatch and combine. For
         # simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -1950,6 +1950,12 @@ def forward_impl(
                 dim=0,
             )
 
+        # Run shared experts before matrix multiply.
+        # because matrix multiply maybe modify the hidden_states.
+        if has_separate_shared_experts and not use_shared_experts_stream:
+            assert self.shared_experts is not None
+            shared_output = self.shared_experts(hidden_states)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -1958,6 +1964,7 @@
             else hidden_states,
             router_logits=router_logits,
         )
+        zero_expert_result: torch.Tensor | None = None
 
         if has_separate_shared_experts:
             assert self.shared_experts is not None
@@ -1978,8 +1985,14 @@
                     shared_output,
                     final_hidden_states,
                 )
+        elif (
+            self.zero_expert_num is not None
+            and self.zero_expert_num > 0
+            and isinstance(final_hidden_states, tuple)
+        ):
+            final_hidden_states, zero_expert_result = final_hidden_states
 
-        def combine_output(states: torch.Tensor) -> torch.Tensor:
+        def reduce_output(states: torch.Tensor) -> torch.Tensor:
             if do_naive_dispatch_combine:
                 states = get_ep_group().combine(states, self.is_sequence_parallel)
 
@@ -1994,10 +2007,13 @@ def combine_output(states: torch.Tensor) -> torch.Tensor:
         if self.shared_experts is not None:
             return (
                 final_hidden_states[0],
-                combine_output(final_hidden_states[1]),
+                reduce_output(final_hidden_states[1]),
             )
+        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
+            assert isinstance(final_hidden_states, torch.Tensor)
+            return (reduce_output(final_hidden_states), zero_expert_result)
         else:
-            return combine_output(final_hidden_states)
+            return reduce_output(final_hidden_states)
 
     @classmethod
     def make_expert_params_mapping(
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 25308b3106a4..bed7c2cd488d 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import final
+from typing import Any, final
 
 import torch
 
@@ -166,6 +166,15 @@ def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
         """
         return
 
+    def preprocess_inputs(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        layer: torch.nn.Module,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Optional hook that can modify tensors prior to routing."""
+        return hidden_states, router_logits
+
     @abstractmethod
     def prepare(
         self,
@@ -200,6 +209,14 @@ def prepare(
         """
         raise NotImplementedError
 
+    def postprocess_output(
+        self,
+        result: Any,
+        layer: torch.nn.Module,
+    ) -> Any:
+        """Optional hook that can modify tensors after finalize completes."""
+        return result
+
     def supports_async(self) -> bool:
         """
         Indicates whether or not this class implements prepare_async and
diff --git a/vllm/model_executor/layers/fused_moe/naive_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/naive_prepare_finalize.py
new file mode 100644
index 000000000000..db1912670386
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/naive_prepare_finalize.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Naive prepare/finalize implementation for EP+DP without all2all kernels."""
+
+import torch
+
+from vllm.distributed import get_ep_group
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
+
+
+class FusedMoENaivePrepareAndFinalize(MoEPrepareAndFinalizeNoEP):
+    """Dispatch/combine via prepare/finalize hooks for DP+EP without all2all."""
+
+    def preprocess_inputs(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        layer: torch.nn.Module,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Require is_sequence_parallel to be set to avoid silent misrouting
+        is_sequence_parallel = layer.is_sequence_parallel
+        return get_ep_group().dispatch(
+            hidden_states, router_logits, is_sequence_parallel
+        )
+
+    def postprocess_output(
+        self,
+        result,
+        layer: torch.nn.Module,
+    ):
+        shared_experts = getattr(layer, "shared_experts", None)
+        zero_expert_num = getattr(layer, "zero_expert_num", 0) or 0
+        if isinstance(result, tuple):
+            if shared_experts is not None:
+                shared_output, expert_output = result
+                return shared_output, self._combine(expert_output, layer)
+            if zero_expert_num > 0:
+                expert_output, aux = result
+                return self._combine(expert_output, layer), aux
+        return self._combine(result, layer)
+
+    @staticmethod
+    def _combine(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
+        if tensor.numel() == 0:
+            return tensor
+        is_sequence_parallel = layer.is_sequence_parallel
+        return get_ep_group().combine(tensor, is_sequence_parallel)
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 029edc44cd77..b0d934a40aa6 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -109,7 +109,7 @@ def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> FusedMoEPrepareAndFinalize | None:
         if self.rocm_aiter_moe_enabled:
-            return None
+            return self._maybe_add_dp_ep_naive_fallback(None)
         else:
             return super().maybe_make_prepare_finalize(routing_tables)
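
For reference, here is a minimal sketch of where the two new hooks sit in the call flow. It is not taken from the vLLM sources: only the hook names preprocess_inputs and postprocess_output mirror the diff above; the class, the stand-in expert computation, and the tensor shapes are invented for illustration.

import torch


class NoOpPrepareFinalize:
    """Stand-in with the two optional hooks added to FusedMoEPrepareAndFinalize."""

    def preprocess_inputs(self, hidden_states, router_logits, layer):
        # Default behavior in the diff: pass tensors through unchanged before
        # routing (the naive DP+EP subclass dispatches across ranks here instead).
        return hidden_states, router_logits

    def postprocess_output(self, result, layer):
        # Default behavior in the diff: return the expert output unchanged
        # (the naive DP+EP subclass combines across ranks here instead).
        return result


def apply_like_flow(pf, x, router_logits, layer=None):
    # Mirrors the order used by FusedMoEModularMethod.apply():
    # preprocess -> routing -> expert computation -> postprocess.
    hidden_states, router_logits = pf.preprocess_inputs(x, router_logits, layer)
    expert_output = hidden_states * 2.0  # placeholder for the fused expert GEMMs
    return pf.postprocess_output(expert_output, layer)


out = apply_like_flow(NoOpPrepareFinalize(), torch.randn(4, 8), torch.randn(4, 2))
print(out.shape)  # torch.Size([4, 8])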