
[Disagg] Finalize routed_experts_output in process_batch_result_disagg_prefill #23885

Merged

ByronHsu merged 1 commit into sgl-project:main from ByronHsu:fix-disagg-prefill-routed-experts-finalize on Apr 27, 2026

Conversation

ByronHsu (Collaborator) commented on Apr 27, 2026

Motivation

PR #22911 ([perf] support return_routed_experts with overlap scheduling) introduced a deferred-D2H path for the captured routed-expert IDs. After copy_done.synchronize(), callers must invoke RoutedExpertsOutput.finalize() to write the CPU-side tensor into host_cache.buffer. The agg-mode handlers in scheduler_output_processor_mixin.py (process_batch_result_prefill, process_batch_result_decode) were updated to do this; the PD-disagg prefill handler in disaggregation/prefill.py was missed.
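
For orientation, the deferred object behaves roughly like the sketch below. This is a minimal illustration of the mechanism, not the actual sglang RoutedExpertsOutput; every name except finalize() and host_cache.buffer is an assumption.

```python
# Minimal sketch of the deferred-D2H idea (illustrative only; the real
# RoutedExpertsOutput differs, and all field names here are assumptions).
from dataclasses import dataclass
import torch

@dataclass
class DeferredRoutedExperts:
    cpu_tensor: torch.Tensor    # target of the async D2H copy (pinned host memory)
    host_buffer: torch.Tensor   # stands in for host_cache.buffer, pre-filled with zeros
    slots: torch.Tensor         # rows of the host buffer owned by this batch

    def finalize(self) -> None:
        # Only valid after copy_done.synchronize(): publish the copied
        # expert IDs into the shared host buffer.
        self.host_buffer[self.slots] = self.cpu_tensor

# If finalize() is never called (the missed PD-disagg prefill path), the slots
# keep their initial torch.zeros(...) contents, which is exactly what
# maybe_collect_routed_experts(req) later reads back.
```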

As a result, in PD-disagg mode the prefill worker's host_cache.buffer is never written for prefill slots — they stay at the initial torch.zeros(...). maybe_collect_routed_experts(req) then reads zeros into req.routed_experts for every prompt position.

Symptom

PD-disagg + --enable-return-routed-experts + overlap scheduling: every prompt token's top_k row in the response is [0, 0, ..., 0]. Routing replay in downstream trainers (e.g. Megatron) asserts:

AssertionError: [rank 0] Duplicate experts in routing! unique_counts[:10]=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], expected=8

This breaks RL workloads that rely on routing replay (e.g. --enable-routing-replay flows), since the prefill prompt tokens come back as all-zero topk rows that the trainer interprets as 8 duplicate selections of expert 0.
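
To make the failure mode concrete, here is a toy version of the per-row uniqueness check the trainer performs (the real Megatron-side check is paraphrased; only the top_k=8 expectation comes from the log above):

```python
import torch

top_k = 8
good_row = torch.tensor([3, 17, 42, 5, 60, 11, 29, 50])  # 8 distinct expert IDs
bad_row = torch.zeros(top_k, dtype=torch.int64)           # pre-fix prompt row: [0, 0, ..., 0]

def unique_count(row: torch.Tensor) -> int:
    return int(row.unique().numel())

print(unique_count(good_row))  # 8 -> routing replay accepts the row
print(unique_count(bad_row))   # 1 -> read as expert 0 selected 8 times,
                               #      i.e. "Duplicate experts in routing!"
```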

Modifications

Add the same finalize call after copy_done.synchronize() in process_batch_result_disagg_prefill, mirroring the agg-mode handlers:

```python
if copy_done is not None:
    copy_done.synchronize()
if result.routed_experts_output is not None:           # added
    result.routed_experts_output.finalize()            # added
    result.routed_experts_output = None                # added
```

Is this also needed for --disable-overlap-schedule?

No. The added block is a no-op in that mode; overlap scheduling (the default) is the only mode where the fix is required.

When --disable-overlap-schedule is set, model_runner.py:2989 sets no_copy_to_cpu = False, and on_forward_end takes the else branch at routed_experts_capturer.py:269: it calls _sync_fwd_experts_buffer_DtoH(...), which writes straight into host_cache.buffer synchronously and returns None. So result.routed_experts_output is None, and the new if result.routed_experts_output is not None: ... finalize() block is a no-op.

So:

  • Overlap on (default): fix is required — that's the regression.
  • Overlap off (--disable-overlap-schedule): old sync _sync_fwd_experts_buffer_DtoH path writes the host buffer inside the forward; the added block doesn't fire and isn't needed.
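
Put as a toy control-flow sketch (illustrative only, not the real routed_experts_capturer.py; on_forward_end, no_copy_to_cpu, finalize, and host_cache.buffer are the only names taken from the PR text, everything else is assumed):

```python
# Toy contrast of the two capture paths (illustrative only; not the real
# routed_experts_capturer.py).
import torch

host_buffer = torch.zeros(4, 8, dtype=torch.int64)   # stands in for host_cache.buffer

def on_forward_end(no_copy_to_cpu: bool, experts: torch.Tensor, slots: torch.Tensor):
    if no_copy_to_cpu:
        # Overlap scheduling (default): defer the host write. The scheduler must
        # synchronize the copy event and then call the returned finalize().
        cpu_copy = experts.to("cpu", non_blocking=True)

        def finalize():
            host_buffer[slots] = cpu_copy

        return finalize
    # --disable-overlap-schedule: write the host buffer synchronously and return
    # None, so the handler's "if result.routed_experts_output is not None" block
    # never fires and the added code is a no-op.
    host_buffer[slots] = experts.cpu()
    return None
```

The regression lives entirely in the first branch: the deferred finalize() existed, but the PD-disagg prefill handler never called it.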

Verification

Reproduced locally on Qwen3-30B-A3B with 1 prefill GPU + 1 decode GPU + mini-lb router + --enable-return-routed-experts + overlap scheduling on, mini-lb _merge_routed_experts enabled (PR #22916), 16 concurrent generation requests, max_new_tokens=32:

|          | total tokens | bad (token, layer) rows | of which all-zero |
|----------|--------------|-------------------------|-------------------|
| pre-fix  | 714          | 10 464                  | 10 464            |
| post-fix | 714          | 0                       | 0                 |

Pre-fix, the all-zero rows are concentrated on prompt positions (the decode worker's host buffer never held them; the prefill worker's buffer was the one that needed populating). Post-fix, all rows have the expected 8 unique expert IDs per token.
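
For reference, the counts above can be produced by a check along these lines, operating on a (num_tokens, num_layers, top_k) array of captured expert IDs; how the IDs are pulled out of the server response is not shown here and depends on the mini-lb merge path:

```python
import torch

def count_bad_rows(expert_ids: torch.Tensor, top_k: int = 8) -> tuple[int, int]:
    """expert_ids: (num_tokens, num_layers, top_k) routed-expert IDs."""
    rows = expert_ids.reshape(-1, top_k)               # one row per (token, layer)
    all_zero = (rows == 0).all(dim=-1)                 # the pre-fix symptom
    bad = torch.tensor([row.unique().numel() != top_k for row in rows])
    return int(bad.sum()), int(all_zero.sum())

# Pre-fix the two counts match (every bad row is all-zero); post-fix both are 0.
```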

Checklist

🤖 Generated with Claude Code

…g_prefill

PR sgl-project#22911 ("[perf] support return_routed_experts with overlap scheduling")
introduced a deferred-D2H path for the captured routed-expert IDs. The
agg-mode result handlers in scheduler_output_processor_mixin.py call
RoutedExpertsOutput.finalize() after copy_done.synchronize() so the
CPU-side tensor lands in host_cache.buffer. The PD-disagg prefill handler
in disaggregation/prefill.py was missed, so its host buffer is never
written and every prefill slot stays at the initial torch.zeros(...).

Symptom: with PD-disagg + --enable-return-routed-experts + overlap
scheduling, every prompt token's top_k row in the response is
[0, 0, ..., 0]. Routing replay in trainers (e.g. Megatron) then asserts
"Duplicate experts in routing! unique_counts=[1,1,...,1] expected=8".

Fix: add the same finalize() call after copy_done.synchronize() in
process_batch_result_disagg_prefill, matching the agg path.

Verified locally on Qwen3-30B-A3B with 1 prefill + 1 decode + mini-lb +
--enable-return-routed-experts + overlap scheduling. Pre-fix: 10 464 of
34 272 (token, layer) rows are all-zero. Post-fix: 0 bad rows across the
same workload.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@ByronHsu ByronHsu merged commit cb0429f into sgl-project:main Apr 27, 2026
57 of 65 checks passed
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…g_prefill (sgl-project#23885)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
