diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index f2365d70ee9..b372858f7be 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -183,12 +183,15 @@ def forward_cuda(
         *,
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+        sm_first: bool = False,  # only used for triton kernels topk
     ) -> TopKOutput:
         if self.use_triton_kernels:
-            routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+            return triton_kernels_topk(
+                router_logits=router_logits,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                sm_first=sm_first,
             )
-            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
             torch_native = False
             return select_experts(
@@ -644,6 +647,22 @@ def biased_grouped_topk_cpu(
     )
 
 
+def triton_kernels_topk(
+    router_logits: torch.Tensor,
+    topk: int,
+    renormalize: bool = False,
+    sm_first: bool = False,
+) -> TritonKernelTopKOutput:
+    """Top-K routing for Triton kernels MoE."""
+    assert not renormalize, "Triton kernels topk doesn't support renormalize"
+    routing_data, gather_idx, scatter_idx = routing(
+        logits=router_logits,
+        n_expts_act=topk,
+        sm_first=sm_first,
+    )
+    return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+
+
 if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
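
For reference, a minimal usage sketch of the extracted helper. The shapes, device, and dtype here are hypothetical, and it assumes the `routing` dependency used above is importable; it is not part of the patch:

```python
import torch

# Hypothetical routing problem: 8 tokens, 64 experts, 4 active experts per token.
num_tokens, num_experts, top_k = 8, 64, 4
router_logits = torch.randn(
    num_tokens, num_experts, device="cuda", dtype=torch.bfloat16
)

# renormalize must stay False (asserted in the helper); sm_first is forwarded
# to the underlying `routing` call to control whether softmax is applied
# before the top-k selection.
topk_output = triton_kernels_topk(
    router_logits=router_logits,
    topk=top_k,
    sm_first=True,
)
# topk_output is a TritonKernelTopKOutput bundling the routing data plus the
# gather/scatter indices consumed by the downstream Triton-kernels MoE path.
```

The refactor keeps `forward_cuda` a thin dispatcher: the Triton-kernels branch now mirrors the CPU variants at module scope, and the new `sm_first` flag is threaded through without affecting the non-Triton path.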