diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index e4248819265d..30b7cc5da496 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -72,7 +72,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: - from sgl_kernel import moe_fused_gate + from sgl_kernel import kimi_k2_moe_fused_gate, moe_fused_gate if _is_cuda or _is_hip: from sgl_kernel import topk_softmax @@ -817,16 +817,13 @@ def biased_grouped_topk_gpu( else: # Use optimized path for Kimi K2 (384 experts with num_expert_group=1) num_experts = gating_output.shape[1] - if num_experts == 384 and num_expert_group == 1: - return kimi_k2_biased_topk_impl( - hidden_states, - gating_output, + if _is_cuda and num_experts == 384 and num_expert_group == 1: + return kimi_k2_moe_fused_gate( + gating_output.to(dtype=torch.float32), correction_bias, - topk, - renormalize, + topk=topk, + renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, - num_token_non_padded=num_token_non_padded, - expert_location_dispatch_info=expert_location_dispatch_info, apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) else: