From c7f029d0d65938c059ca93c89f92b370fb6fa6f0 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 31 Mar 2026 18:01:41 +0000 Subject: [PATCH] [Perf] Restore torch.compile fusion for topk postprocessing PR #16945 refactored topk postprocessing into `_post_process_topk_ids` but inlined the `topk_ids_logical_to_physical` and `_mask_topk_ids_padded_region` calls instead of delegating to the existing `@torch.compile`-decorated `_biased_grouped_topk_postprocess`. This caused those two operations to run as separate eager kernels instead of being fused by torch.compile, a regression for CUDA paths using expert-parallel / EPLB. Fix: call `_biased_grouped_topk_postprocess` (which already carries `@torch.compile(dynamic=True)`) from within `_post_process_topk_ids`, restoring the compiled kernel fusion. Ref: https://github.com/sgl-project/sglang/pull/16945#discussion_r2682016393 --- python/sglang/srt/layers/moe/topk.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 892dcebaea81..bb6691814cdd 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -956,8 +956,9 @@ def _post_process_topk_ids( topk_ids=topk_ids, ) if _is_cuda: - topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) - _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + topk_ids = _biased_grouped_topk_postprocess( + topk_ids, expert_location_dispatch_info, num_token_non_padded + ) if num_fused_shared_experts > 0 and _use_aiter: M, N = router_logits.shape