
[Session R3] Add routed_experts_start_len for absolute routing slice control#24851

Merged
ByronHsu merged 5 commits into
sgl-project:mainfrom
ByronHsu:byron/routed-experts-start-len
May 10, 2026

Conversation


@ByronHsu ByronHsu commented May 9, 2026

Motivation

In multi-turn RL rollouts with MoE models (e.g. Kimi-K2, Qwen3-30B-A3B), return_routed_experts returns the full conversation-length routing data on every request, including the prefix-cached range. As conversations grow, the host gather and ZMQ payload both scale as O(seqlen), producing ~1s ITL spikes on long-context requests with high cache-hit ratios. This compounds with DP attention: since all DP ranks synchronize on every decode step, a single rank stalled on a long routed-experts gather blocks the entire batch across all ranks. This makes it the dominant decode bottleneck in our production RL training loop.

The problem

Turn 1:  prompt=[P1]  → generate → routings for [P1 + O1]        # 100 rows, OK
Turn 2:  prompt=[P1+O1+P2] → generate → routings for [P1+O1+P2+O2]  # 500 rows, but only ~100 are new
Turn N:  prompt=[...8K tokens...] → generate → routings for [8K+ON]   # 8K rows, ~100 new — 1s stall

The caller (RL rollout client) only needs the new routing rows, not the full prefix it already accumulated from prior turns.

The solution

Add routed_experts_start_len: Optional[int] that lets callers specify an absolute start position. The server returns routings covering [start_len, seqlen - 1) instead of the full [0, seqlen - 1). The caller sets start_len = len(accumulated_prompt) and gets back only the output-token routings.

Turn 2 with start_len=len(P1+O1):
  → routings for [P2+O2] only  # ~100 rows instead of 500
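
To make the calling pattern concrete, a minimal rollout-client sketch follows. It is not taken from this PR: only routed_experts_start_len and the --enable-return-routed-experts flag come from the description above; the per-request return_routed_experts argument and the response keys ("meta_info", "routed_experts", "output_ids") are assumed names for illustration.

```python
# Illustrative multi-turn rollout loop, not verbatim from this PR.
# Assumes an engine started with --enable-return-routed-experts; response keys
# ("meta_info", "routed_experts", "output_ids") are hypothetical names.
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen3-30B-A3B",
                    enable_return_routed_experts=True)

conversation_turns = [[1, 2, 3], [4, 5]]  # placeholder token ids for two user turns
prompt_ids: list[int] = []                # accumulated conversation tokens (P1, O1, P2, ...)
routings = []                             # routing rows accumulated across turns

for new_turn_ids in conversation_turns:
    prompt_ids += new_turn_ids
    out = engine.generate(
        input_ids=prompt_ids,
        sampling_params={"max_new_tokens": 100},
        return_routed_experts=True,
        # Only return routing rows for positions >= len(prompt_ids),
        # i.e. the newly generated output tokens:
        routed_experts_start_len=len(prompt_ids),
    )
    routings += list(out["meta_info"]["routed_experts"])  # ~output_len rows, not O(seqlen)
    prompt_ids += out["output_ids"]                       # carry generated tokens into the next turn
```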

Experiment results

Kimi-K2-Instruct, pure TP=8 (8× H200), --enable-return-routed-experts --load-format dummy --moe-runner-backend triton --disable-piecewise-cuda-graph --cuda-graph-max-bs 4. 95% prefix-cache hit, 100 output tokens. Timer placed after GPU sync.

Headline numbers

 in_len   cached     full_ms  sliced_ms   speedup
   2048     2047        7.78      0.65    11.96×
   4096     4095       11.37      0.73    15.59×
   8192     8191       14.16      0.86    16.54×
  16384    16383       49.80      1.17    42.46×
  32768    32767      121.23      1.90    63.77×

Per-component breakdown

 in_len   mode      elapsed   routed_experts   release_kv   other (stream_output etc.)
   2048   full      7.78 ms          2.19 ms      0.20 ms      5.39 ms
   2048   sliced    0.65 ms          0.10 ms      0.16 ms      0.39 ms
   8192   full     14.16 ms          2.98 ms      0.40 ms     10.78 ms
   8192   sliced    0.86 ms          0.09 ms      0.39 ms      0.38 ms
  32768   full    121.23 ms         30.12 ms      1.51 ms     89.60 ms
  32768   sliced    1.90 ms          0.10 ms      1.43 ms      0.37 ms

Key findings:

  • routed_experts host gather (aten::index): grows linearly in full mode (2.2 ms at 2k → 30 ms at 32k); flat ~0.1 ms with start_len at all sizes. This PR eliminates this cost.
  • stream_output (ZMQ/pickle serialization): 5 ms → 90 ms in full mode (proportional to the 60 MB payload); flat ~0.4 ms with start_len (tiny payload).
  • release_kv: grows linearly with seqlen in both modes (0.16 ms → 1.43 ms). This is the radix tree walk — unrelated to routed experts and the sole source of small residual growth.

Changes

Request lifecycle plumbing: `routed_experts_start_len: Optional[int] = None` added to (a simplified plumbing sketch follows the list):

  • GenerateReqInput, TokenizedGenerateReqInput (io_struct.py)
  • CompletionRequest, ChatCompletionRequest (protocol.py)
  • Engine.generate, Engine.async_generate (engine.py)
  • Req (schedule_batch.py)
  • Pass-through in serving_chat, serving_completions, tokenizer_manager, scheduler, session_controller, encode_receiver
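
Not verbatim from the diff, but a simplified sketch of the plumbing pattern: the optional field defaults to None on the request dataclass and is copied onto the per-request object, so omitting it preserves the existing full-sequence behavior. Class bodies are reduced stand-ins for the real io_struct.GenerateReqInput and schedule_batch.Req.

```python
# Illustrative only; field names mirror the PR, class bodies are simplified stand-ins.
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerateReqInput:              # stand-in for io_struct.GenerateReqInput
    text: Optional[str] = None
    return_routed_experts: bool = False
    routed_experts_start_len: Optional[int] = None   # new field; default keeps old behavior

class Req:                           # stand-in for schedule_batch.Req
    def __init__(self, inp: GenerateReqInput):
        self.return_routed_experts = inp.return_routed_experts
        self.routed_experts_start_len = inp.routed_experts_start_len
```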

Server-side logic:

  • BaseTopkCapturer.get_topk() gains a start_len parameter with defensive clamping
  • maybe_collect_routed_experts() honors req.routed_experts_start_len, early-returns when return_routed_experts is False, and logs row-count mismatches as soft warnings
  • Scheduler validates start_len <= prompt_tokens and aborts with a clear error otherwise (see the sketch after this list)
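
A sketch of the server-side behavior described above. Function names mirror the PR description (get_topk, maybe_collect_routed_experts), but tensor shapes, attributes such as req.seqlen, and the abort path are simplified assumptions rather than the actual sglang code.

```python
# Illustrative sketch of the slicing, clamping, and validation; not the real implementation.
import logging
import torch

logger = logging.getLogger(__name__)

def get_topk(routed_experts: torch.Tensor, start_len: int = 0) -> torch.Tensor:
    """Return routing rows from start_len onward, with defensive clamping."""
    start_len = max(0, min(start_len, routed_experts.shape[0]))
    return routed_experts[start_len:]

def maybe_collect_routed_experts(req, captured: torch.Tensor):
    if not req.return_routed_experts:
        return None                                    # early return: nothing to gather
    start_len = req.routed_experts_start_len or 0
    rows = get_topk(captured, start_len)
    expected = req.seqlen - 1 - start_len              # rows cover [start_len, seqlen - 1)
    if rows.shape[0] != expected:
        logger.warning("routed experts row count %d != expected %d",
                       rows.shape[0], expected)        # soft warning, not a hard failure
    return rows.cpu()                                  # host gather is now O(seqlen - start_len)

def validate_start_len(recv_req, prompt_tokens: int) -> bool:
    """Scheduler-side check: reject requests whose start_len exceeds the prompt."""
    s = recv_req.routed_experts_start_len
    return s is None or s <= prompt_tokens             # caller aborts with a clear error if False
```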

Tests:

  • TestRoutedExpertsStartLen: 4 test cases covering default behavior, row-count correctness, bounds checking (abort on too-large start_len), and cache-hit interaction (radix prefix extends past start_len)

Test plan

  • Existing TestReturnRoutedExperts passes (no regression to full-sequence return)
  • New TestRoutedExpertsStartLen passes with TP=2 on Qwen3-30B-A3B
  • start_len=None (default) produces identical output to omitting the field
  • start_len=N returns exactly seqlen - 1 - N rows, matching the tail of the full return (see the check sketched after this list)
  • start_len > prompt_tokens aborts with clear error message
  • Cache-hit case: radix prefix allowed to extend past start_len
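
The row-count invariant referenced above can be expressed as a small check; how the full and sliced routings are fetched depends on the client API, so this is only a sketch over plain Python lists.

```python
# Illustrative assertion of the slicing contract; inputs are lists of routing rows.
def check_slice_matches_tail(full_rows, sliced_rows, start_len, seqlen):
    assert len(full_rows) == seqlen - 1                 # full return covers [0, seqlen - 1)
    assert len(sliced_rows) == seqlen - 1 - start_len   # sliced return covers [start_len, seqlen - 1)
    assert sliced_rows == full_rows[start_len:]         # sliced result is the tail of the full result
```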

…trol

Add `routed_experts_start_len` parameter that lets callers specify an
absolute start position for returned routed-expert data, covering
`[start_len, seqlen-1)`. This gives RL rollout callers explicit
control over which routing rows are returned — useful when the caller
already knows the prompt prefix length and wants output-only or
partial-prompt routings without relying on cache heuristics.

Motivation: In multi-turn RL rollouts, the accumulated routed-experts
data grows with the full conversation length. Without slicing control,
every request returns the full sequence's routing data including the
prefix-cached range, causing O(seqlen) host gathers and ZMQ payloads
that produce ~1s ITL spikes on long-context requests with high
cache-hit ratios. With `routed_experts_start_len`, callers can request
only the new tokens' routings, reducing the per-finish cost to
O(seqlen - start_len).

Changes:
- Add `routed_experts_start_len: Optional[int] = None` field across the
  full request lifecycle: GenerateReqInput, TokenizedGenerateReqInput,
  OpenAI CompletionRequest/ChatCompletionRequest, Engine.generate/
  async_generate, Req, tokenizer_manager, scheduler, session_controller,
  encode_receiver, and serving_chat/serving_completions.
- Add validation in scheduler: abort if start_len > prompt_tokens.
- Update BaseTopkCapturer.get_topk() with start_len parameter and
  defensive clamping.
- Update maybe_collect_routed_experts to honor start_len, early-return
  when return_routed_experts is False, and log row-count mismatches.
- Add comprehensive TestRoutedExpertsStartLen test class covering
  default behavior, row-count correctness, bounds checking, and
  cache-hit interaction.

Co-authored-by: Cursor <cursoragent@cursor.com>


ByronHsu commented May 9, 2026

cc @zyzshishui to take over

@ByronHsu ByronHsu changed the title from "[Feature] Add routed_experts_start_len for absolute routing slice control" to "[Session R3] Add routed_experts_start_len for absolute routing slice control" on May 9, 2026
@ispobock

/tag-and-rerun-ci

Comment thread: python/sglang/srt/managers/scheduler.py (Outdated)
return

if (
recv_req.routed_experts_start_len is not None
A contributor commented:

Maybe add if recv_req.return_routed_experts and


@Qiaolin-Yu Qiaolin-Yu left a comment


better to add some doc

@zyzshishui zyzshishui requested a review from wisclmy0611 as a code owner May 10, 2026 06:33
@github-actions bot added the documentation label (Improvements or additions to documentation) on May 10, 2026
@zyzshishui

better to add some doc

added, plz check again

@ByronHsu ByronHsu merged commit d82e339 into sgl-project:main May 10, 2026
125 of 135 checks passed
ByronHsu added a commit that referenced this pull request May 10, 2026
…bsolute routing slice control (#24904)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: zyzshishui <zyzshishui@gmail.com>
Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com>
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 11, 2026
* main: (87 commits)
  [Fix] Disable FlashInfer allreduce fusion under deterministic inference (sgl-project#24629)
  fix: STANDALONE spec-decode hidden-size mismatch crash (sgl-project#24217)
  Followup fix for Custom AR V2 in non NVL scenarios (sgl-project#24742)
  Fix reduce_scatterv producer contract for SUM_LEN (sgl-project#24785)
  [NPU]Documentation update for communications quantization feature (sgl-project#24668)
  [Session R3] Add routed_experts_start_len for absolute routing slice control (sgl-project#24851)
  [Model] Add MiniCPM-V 4.6 support (sgl-project#24855)
  Support Intern-S2-Preview (sgl-project#24875)
  [PD] Unify dsv4 dispatch with swa (sgl-project#24888)
  Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head (sgl-project#24775)
  Fix PD bootstrap failure handling (sgl-project#24772)
  [Spec] Cleanup idle stub and shape-check patterns (sgl-project#24881)
  [Bug] Add dsv4 state_type branch to mooncake disaggregation (sgl-project#24878)
  [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput` (sgl-project#24859)
  [Gemma4] Optimize Gemm4 with fused Q/K/V RMSNorm + per-expert FP8 ckpt loader (sgl-project#24696)
  [spec decoding] support kimi-k2.5-eagle3-mla (sgl-project#24826)
  [SPEC V2] fix: skip stale state updates in spec-v2 overlap (sgl-project#23456)
  [RL] Call torch.cuda.empty_cache() for `in-place` pause mode to avoid OOM (sgl-project#24854)
  [diffusion] CI: add cache-dit CI tests (sgl-project#19213)
  [Utils] Make request dump robust to unpicklable server_args and large meta_info (sgl-project#24767)
  ...

# Conflicts:
#	python/sglang/srt/utils/common.py

Labels

documentation (Improvements or additions to documentation), run-ci
