diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bf5155434160..7a0a3718cb80 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -35,6 +35,9 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_grouped_topk, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -1295,6 +1298,7 @@ def __init__( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, + num_fused_shared_experts: int = 0, ) -> None: super().__init__() self.native_impl = grouped_topk @@ -1304,6 +1308,7 @@ def __init__( self.topk_group = topk_group self.scoring_func = scoring_func self.routed_scaling_factor = routed_scaling_factor + self.num_fused_shared_experts = num_fused_shared_experts def forward_native( self, @@ -1333,6 +1338,32 @@ def forward_cuda( hidden_states, gating_output, e_score_correction_bias ) + def forward_hip( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): + assert self.num_fused_shared_experts == 0 + return rocm_aiter_grouped_topk( + hidden_states, + gating_output, + self.topk, + self.renormalize, + self.num_expert_group, + self.topk_group, + self.scoring_func, + self.routed_scaling_factor, + e_score_correction_bias, + self.num_fused_shared_experts, + ) + else: + return self.forward_native( + hidden_states, gating_output, e_score_correction_bias + ) + @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def eplb_map_to_physical_and_record( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2e7267d56d83..f0d94bfbcaba 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,6 @@ from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum -from functools import partial from typing import Literal, cast, get_args, overload import torch @@ -67,9 +66,6 @@ def _eplb_map_to_physical_and_record( eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk, -) if current_platform.is_tpu(): from .moe_pallas import fused_moe as fused_moe_pallas @@ -1583,28 +1579,15 @@ def valid_grouping() -> bool: elif self.use_grouped_topk and valid_grouping(): assert self.topk_group is not None assert self.num_expert_group is not None - if rocm_aiter_ops.is_fused_moe_enabled(): - if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): - assert self.num_fused_shared_experts == 0 - grouped_topk_impl = partial( - rocm_aiter_grouped_topk, - num_fused_shared_experts=self.num_fused_shared_experts, - topk=self.top_k, - renormalize=self.renormalize, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - ) - else: - grouped_topk_impl = GroupedTopk( - topk=self.top_k, - renormalize=self.renormalize, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - ) + grouped_topk_impl = GroupedTopk( + topk=self.top_k, + renormalize=self.renormalize, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + num_fused_shared_experts=self.num_fused_shared_experts, + ) topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states,