18 changes: 17 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
@@ -89,6 +89,23 @@ def fused_topk(
    )
    del token_expert_indicies

    return _fused_topk_postprocess(
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        renormalize=renormalize,
        expert_location_dispatch_info=expert_location_dispatch_info,
        num_token_non_padded=num_token_non_padded,
    )

@torch.compile(dynamic=True, backend=get_compiler_backend())
Contributor review comment (severity: medium):
The use of @torch.compile(dynamic=True) for _fused_topk_postprocess is noted. This is often necessary when tensor shapes, particularly batch dimensions, can vary.

Could you share any insights or benchmarks regarding this compilation strategy?

  • Was dynamic=True chosen because specific shapes (e.g., the first dimension of topk_weights and topk_ids) vary at runtime in a way that defeats static compilation or fullgraph=True?
  • Have other options been considered or benchmarked, such as fullgraph=True, or constraining compilation with specific shape guards when the dynamism is limited and predictable?

Understanding the trade-offs and performance impact here would be valuable, as dynamic=True can sometimes introduce compilation overhead or prevent certain optimizations compared to a fully static graph.
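
For context, one rough way to probe this trade-off (purely illustrative, not part of this PR; _postprocess and the token counts below are hypothetical stand-ins) is to compile the same elementwise step with dynamic=True and with the default heuristics, then time both across varying batch sizes:

import time

import torch


def _postprocess(topk_weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _fused_topk_postprocess: just the
    # renormalization step over the last dimension.
    return topk_weights / topk_weights.sum(dim=-1, keepdim=True)


dynamic_fn = torch.compile(_postprocess, dynamic=True)
# Default heuristics: typically specializes on the first shapes seen,
# then recompiles when new shapes appear.
default_fn = torch.compile(_postprocess)


def bench(fn, label):
    start = time.perf_counter()
    for num_tokens in (1, 7, 64, 513, 4096):  # varying first (token) dimension
        fn(torch.rand(num_tokens, 8))
    print(f"{label}: {time.perf_counter() - start:.3f}s (includes compile time)")


bench(dynamic_fn, "dynamic=True")
bench(default_fn, "default")

With dynamic=True the first compilation traces symbolic shapes up front, while the default path may pay a recompilation for each new shape specialization; which approach wins depends on how many distinct token counts show up in practice.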

def _fused_topk_postprocess(
    topk_weights,
    topk_ids,
    renormalize,
    expert_location_dispatch_info,
    num_token_non_padded,
):
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -313,7 +330,6 @@ def select_experts(
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):

    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,