diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 584e0449f436..5beb782d7386 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from collections.abc import Callable import torch @@ -57,6 +58,19 @@ def vllm_topk_sigmoid( return topk_weights, topk_indices +@functools.lru_cache(maxsize=8) +def _aiter_get_num_expert_group(num_experts: int) -> int: + _AITER_MAX_EXPERTS_PER_GROUP = 32 + g = max(1, -(-num_experts // _AITER_MAX_EXPERTS_PER_GROUP)) + while num_experts % g != 0: + g += 1 + assert num_experts % g == 0, f"{num_experts=} not divisible by {g=}" + assert num_experts // g <= _AITER_MAX_EXPERTS_PER_GROUP, ( + f"group size {num_experts // g} exceeds limit {_AITER_MAX_EXPERTS_PER_GROUP}" + ) + return g + + def fused_topk_bias( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +122,30 @@ def fused_topk_bias( return topk_weights, topk_ids else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid": + M = hidden_states.size(0) + num_experts = gating_output.shape[-1] + num_expert_group = _aiter_get_num_expert_group(num_experts) + if topk >= num_expert_group: + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device, + ) + rocm_aiter_ops.biased_grouped_topk( + gating_output, + e_score_correction_bias.to(gating_output.dtype), + topk_weights, + topk_ids, + num_expert_group=num_expert_group, + topk_group=num_expert_group, + need_renorm=renormalize, + ) + return topk_weights, topk_ids n_routed_experts = gating_output.shape[-1] if scoring_func == "softmax":