Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,37 @@ def uses_weight_scale_2_pattern(self) -> bool:
"""
return False

def _maybe_add_dp_ep_naive_fallback(
self,
prepare_finalize: FusedMoEPrepareAndFinalize | None,
) -> FusedMoEPrepareAndFinalize | None:
"""
Ensure DP+EP without all2all still gets dispatch/combine via naive
prepare/finalize.
"""
if (
prepare_finalize is None
and not self.moe.moe_parallel_config.use_all2all_kernels
and self.moe.dp_size > 1
and self.moe.use_ep
):
from .naive_prepare_finalize import FusedMoENaivePrepareAndFinalize

return FusedMoENaivePrepareAndFinalize()
return prepare_finalize

def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize

return maybe_make_prepare_finalize(
prepare_finalize = maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)

return self._maybe_add_dp_ep_naive_fallback(prepare_finalize)

def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,18 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
prepare_finalize = self.fused_experts.prepare_finalize
hidden_states, router_logits = prepare_finalize.preprocess_inputs(
x, router_logits, layer
)

topk_weights, topk_ids, zero_expert_result = layer.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)

result = self.fused_experts(
hidden_states=x,
hidden_states=hidden_states,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
Expand All @@ -110,4 +115,14 @@ def apply(
expert_map=None if self.disable_expert_map else layer.expert_map,
)

return result
result = prepare_finalize.postprocess_output(result, layer)

zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
32 changes: 24 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,11 +1931,11 @@ def forward_impl(
else:
hidden_states_combined, router_logits = dispatch_res

# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)

# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
Expand All @@ -1950,6 +1950,12 @@ def forward_impl(
dim=0,
)

# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
Expand All @@ -1958,6 +1964,7 @@ def forward_impl(
else hidden_states,
router_logits=router_logits,
)
zero_expert_result: torch.Tensor | None = None

if has_separate_shared_experts:
assert self.shared_experts is not None
Expand All @@ -1978,8 +1985,14 @@ def forward_impl(
shared_output,
final_hidden_states,
)
elif (
self.zero_expert_num is not None
and self.zero_expert_num > 0
and isinstance(final_hidden_states, tuple)
):
final_hidden_states, zero_expert_result = final_hidden_states

def combine_output(states: torch.Tensor) -> torch.Tensor:
def reduce_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)

Expand All @@ -1994,10 +2007,13 @@ def combine_output(states: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return (reduce_output(final_hidden_states), zero_expert_result)
else:
return combine_output(final_hidden_states)
return reduce_output(final_hidden_states)

@classmethod
def make_expert_params_mapping(
Expand Down
19 changes: 18 additions & 1 deletion vllm/model_executor/layers/fused_moe/modular_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import final
from typing import Any, final

import torch

Expand Down Expand Up @@ -166,6 +166,15 @@ def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
"""
return

def preprocess_inputs(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer: torch.nn.Module,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Optional hook that can modify tensors prior to routing."""
return hidden_states, router_logits

@abstractmethod
def prepare(
self,
Expand Down Expand Up @@ -200,6 +209,14 @@ def prepare(
"""
raise NotImplementedError

def postprocess_output(
self,
result: Any,
layer: torch.nn.Module,
) -> Any:
"""Optional hook that can modify tensors after finalize completes."""
return result

def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async and
Expand Down
49 changes: 49 additions & 0 deletions vllm/model_executor/layers/fused_moe/naive_prepare_finalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Naive prepare/finalize implementation for EP+DP without all2all kernels."""

import torch

from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)


class FusedMoENaivePrepareAndFinalize(MoEPrepareAndFinalizeNoEP):
    """Naive DP+EP prepare/finalize: EP-group dispatch before routing and
    combine after, for configurations without all2all kernels."""

    def preprocess_inputs(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        layer: torch.nn.Module,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Read the attribute directly (no getattr default): a missing
        # is_sequence_parallel should fail loudly rather than silently
        # misroute tokens.
        seq_parallel = layer.is_sequence_parallel
        return get_ep_group().dispatch(hidden_states, router_logits, seq_parallel)

    def postprocess_output(
        self,
        result,
        layer: torch.nn.Module,
    ):
        has_shared = getattr(layer, "shared_experts", None) is not None
        num_zero_experts = getattr(layer, "zero_expert_num", 0) or 0
        if isinstance(result, tuple):
            if has_shared:
                # (shared_output, expert_output): only the expert part
                # participates in the combine collective.
                shared_out, expert_out = result
                return shared_out, self._combine(expert_out, layer)
            if num_zero_experts > 0:
                # (expert_output, zero_expert_result): the auxiliary result
                # passes through uncombined.
                expert_out, aux = result
                return self._combine(expert_out, layer), aux
        return self._combine(result, layer)

    @staticmethod
    def _combine(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
        # NOTE(review): skipping the collective for empty tensors assumes
        # combine() tolerates some ranks not participating — confirm this
        # cannot deadlock when only a subset of ranks has zero tokens.
        if tensor.numel() == 0:
            return tensor
        return get_ep_group().combine(tensor, layer.is_sequence_parallel)
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def maybe_make_prepare_finalize(
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
return self._maybe_add_dp_ep_naive_fallback(None)
else:
return super().maybe_make_prepare_finalize(routing_tables)

Expand Down