
[sglang-miles] Cherry-pick #22911: return_routed_experts with overlap scheduling#23854

Closed
ByronHsu wants to merge 1 commit into sgl-project:sglang-miles from ByronHsu:cherry-pick-22911-overlap-routed-experts

Conversation

@ByronHsu
Collaborator

Summary

Cherry-pick of upstream #22911 (c5603268 — "[perf] support return_routed_experts with overlap scheduling") onto sglang-miles.

Resolved two conflicts to keep sglang-miles-specific bits on top of upstream's refactor:

  • routed_experts_capturer.py: kept the DeepEP gather_buffer / all-gather path in capture(). Added upstream's _get_local_range(), guarded with not get_moe_a2a_backend().is_deepep() so we don't slice by DP rank when capture() has already all-gathered.
  • model_runner.py: kept the miles-specific if not self.is_draft_worker: guard and the cuda_graph_num_tokens = bs * num_tokens_per_bs fix for speculative decoding, on top of upstream's no_copy_to_cpu = not server_args.disable_overlap_schedule plumbing and the output.routed_experts_output assignment.

Auto-merged: scheduler_output_processor_mixin.py, tp_worker.py, managers/utils.py, eagle_worker_v2.py, multi_layer_eagle_worker_v2.py.
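The guard resolution above can be sketched as follows. This is an illustrative standalone function, not SGLang's actual _get_local_range signature; the is_deepep flag stands in for the real get_moe_a2a_backend().is_deepep() check:

```python
def get_local_range(total_tokens: int, dp_rank: int, dp_size: int,
                    is_deepep: bool) -> tuple[int, int]:
    """Hypothetical sketch: return the (start, end) slice of the routed-experts
    buffer owned by this DP rank.

    When the DeepEP a2a backend is active, capture() has already all-gathered
    the buffer across DP ranks, so the whole range is local and no per-rank
    slicing should happen.
    """
    if is_deepep:
        # DeepEP path: buffer is already global, skip DP-rank slicing.
        return 0, total_tokens
    # Non-DeepEP path: each DP rank owns an equal contiguous slice.
    tokens_per_rank = total_tokens // dp_size
    start = dp_rank * tokens_per_rank
    return start, start + tokens_per_rank
```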

Verification

Throughput: python -m sglang.bench_one_batch_server --model Qwen/Qwen3-30B-A3B --tp 4 --enable-return-routed-experts --batch-size 64 --input-len 1024 --output-len 512 on 4xH200:

|                            | Output throughput | Latency |
|----------------------------|-------------------|---------|
| Before (parent e0790b54f)  | 7453.48 tok/s     | 5.01 s  |
| After (this commit)        | 8609.21 tok/s     | 4.34 s  |
| Δ                          | +15.5%            | -13.4%  |

Accuracy: python test/registered/rl/test_return_routed_experts.py (TP=2 DP=2, baseline = overlap OFF / sync DtoH, reference = overlap ON / new async path):

test_return_routed_experts (/generate)              ... ok   0 / 8,442,240 mismatches
test_return_routed_experts_chat_completions         ... ok   0 / 9,517,440 mismatches
test_return_routed_experts_completions              ... ok   0 / 8,745,600 mismatches
Ran 3 tests in 176.729s
OK
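The mismatch counts above come from an elementwise comparison of per-token routed expert IDs between the sync baseline and the async path. A minimal sketch of that kind of check (count_mismatches is a hypothetical helper, not the test's actual API):

```python
import numpy as np

def count_mismatches(baseline_ids, candidate_ids) -> int:
    """Count elementwise differences between two arrays of routed
    expert IDs (e.g. shape [num_tokens, top_k])."""
    baseline = np.asarray(baseline_ids)
    candidate = np.asarray(candidate_ids)
    # Shapes must match for a token-by-token, slot-by-slot comparison.
    assert baseline.shape == candidate.shape
    return int((baseline != candidate).sum())
```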

Test plan

  • Bench: throughput recovers above pre-regression baseline on Qwen3-30B-A3B TP=4 with return_routed_experts on
  • Accuracy: test_return_routed_experts passes (3/3 endpoints, 0 mismatches across ~26.7M expert IDs)
  • CI: stage-b-test-2-gpu-large suite green

🤖 Generated with Claude Code

…ject#22911)

Cherry-pick of upstream c560326 onto sglang-miles. Resolved conflicts
in routed_experts_capturer.py and model_runner.py to keep the miles
draft-worker guard, the bs * num_tokens_per_bs cuda_graph fix, and the
DeepEP all-gather path on top of upstream's _get_local_range /
no_copy_to_cpu refactor.

Verified on H200 TP=4 Qwen3-30B-A3B (batch=64, in=1024, out=512):
output throughput 7453.48 -> 8609.21 tok/s (+15.5%). Router replay
accuracy test (test_return_routed_experts) passes 3/3 with 0
mismatches across ~26.7M expert IDs.

Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@ByronHsu ByronHsu closed this Apr 27, 2026