diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 068bc67cdaae..892dcebaea81 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -129,6 +129,7 @@ def fused_topk_deepseek( if _use_aiter: try: from aiter import biased_grouped_topk as aiter_biased_grouped_topk + from aiter.fused_moe import fused_topk as aiter_fused_topk except ImportError: raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") @@ -511,12 +512,24 @@ def fused_topk( topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) if scoring_func == "softmax": - topk_softmax( - topk_weights, - topk_ids, - gating_output, - renormalize, - ) + if _use_aiter: + + # Use fused_topk instead of topk_softmax to auto dispatch to the correct kernel + topk_weights, topk_ids = aiter_fused_topk( + hidden_states, + gating_output, + topk, + renormalize, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + else: + topk_softmax( + topk_weights, + topk_ids, + gating_output, + renormalize, + ) elif scoring_func == "sigmoid": topk_sigmoid( topk_weights,