
[fully_async] fix: replace routed_experts on partial rollout resume instead of concatenating #6029

Merged

wuxibin89 merged 1 commit into verl-project:main from NoonePauseferg:fix/partial-rollout-routed-experts on Apr 17, 2026

Conversation

NoonePauseferg (Contributor) commented Apr 16, 2026

What does this PR do?

Fixes a bug in FullyAsyncLLMServerManager.generate() where routed_experts was incorrectly concatenated via torch.cat during partial rollout resume, causing duplicated routing data and broken MoE expert replay in the actor.

sglang returns routed_experts for the full sequence (prompt + all generated tokens). Evidence from sglang source:

  1. io_struct.py#L1020 — field definition:

    # The routed experts for each token, including both input and output tokens
    # routed_experts[i] is a tensor of shape (token, layer, top_k) for request i
    routed_experts: List[Optional[torch.Tensor]]
  2. schedule_batch.py — the seqlen used to collect routing covers the full sequence:

    @property
    def seqlen(self) -> int:
        return len(self.origin_input_ids) + len(self.output_ids)
  3. topk.py#L1049-1051 — capture is unconditional (no prefill/decode check):

    get_global_experts_capturer().capture(layer_id=layer_id, topk_ids=topk_ids)
  4. scheduler_output_processor_mixin.py#L105-111 — collection uses full seqlen:

    req.routed_experts = get_global_experts_capturer().get_routed_experts(
        req_pool_idx=req.req_pool_idx,
        seqlen=req.seqlen,  # origin_input_ids + output_ids
        req_to_token_pool=self.req_to_token_pool,
    )
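
To make the quoted semantics concrete, a tiny sketch with made-up sizes (every number here is hypothetical, chosen only for illustration):

    import torch

    # Hypothetical sizes, chosen only for illustration.
    prompt_len, output_len, num_layers, top_k = 5, 3, 4, 2

    # Per schedule_batch.py, seqlen spans the full sequence:
    seqlen = prompt_len + output_len

    # Per io_struct.py, routed_experts[i] is (token, layer, top_k) for ALL tokens:
    routed_experts = torch.zeros(seqlen, num_layers, top_k, dtype=torch.long)
    assert routed_experts.shape == (prompt_len + output_len, num_layers, top_k)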

When partial rollout resumes after abort, the input becomes prompt + already_generated_tokens. sglang re-processes the entire input during prefill and returns routed_experts covering all positions. The old code concatenated this with the previous routed_experts:

old routing:    prompt + A B C
new routing:    prompt + A B C + D E
concat result:  prompt + A B C + prompt + A B C + D E   <-- duplicated!
expected:       prompt + A B C + D E

This shifted the routing and caused incorrect MoE expert replay, leading to actor/ppo_kl spikes.
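
The same duplication, as a toy reproduction in torch (sizes are hypothetical; this mimics only the concatenation, not the real agent loop):

    import torch

    prompt_len, num_layers, top_k = 5, 4, 2  # hypothetical sizes

    # First call: prompt + A B C -> routing for 5 + 3 = 8 positions.
    first = torch.zeros(prompt_len + 3, num_layers, top_k)

    # Resumed call: input is prompt + A B C, new tokens D E.
    # sglang re-prefills everything, so routing covers 5 + 3 + 2 = 10 positions.
    resumed = torch.zeros(prompt_len + 5, num_layers, top_k)

    old = torch.cat([first, resumed], dim=0)  # old code: 18 positions, duplicated
    new = resumed                             # fixed code: 10 positions

    assert old.shape[0] == 18  # prompt + A B C counted twice
    assert new.shape[0] == 10  # prompt + A B C + D E, as expected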

Fix: replace routed_experts instead of concatenating, since the resumed call already covers all positions.

Related: #4348 (partial rollout RFC), #4101 (R3 router replay), #5344 (R3 in fully async)

Test

  • Ran async training with partial_rollout=True and enable_rollout_routing_replay=True (R3 mode)
  • Verified actor/ppo_kl no longer spikes after partial rollout resume
  • Verified routed_experts tensor shape matches (prompt_len + response_len, num_layers, top_k) after resume (a minimal version of that check is sketched below)
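
A minimal version of that shape check (the helper name and its arguments are ours, not the PR's test code):

    def check_routed_experts_shape(routed_experts, prompt_len, response_len,
                                   num_layers, top_k):
        """Assert the routing tensor covers every position exactly once."""
        expected = (prompt_len + response_len, num_layers, top_k)
        assert tuple(routed_experts.shape) == expected, (
            f"got {tuple(routed_experts.shape)}, expected {expected}"
        )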

Design & Code Changes

Single-line change in verl/experimental/fully_async_policy/agent_loop/agent_loop.py:

- if output.routed_experts is not None:
-     if final_output.routed_experts is None:
-         final_output.routed_experts = output.routed_experts
-     else:
-         final_output.routed_experts = torch.cat([final_output.routed_experts, output.routed_experts], dim=0)
+ # sglang returns routed_experts for the full sequence (prompt + all tokens),
+ # so on partial rollout resume the new output already covers all positions.
+ if output.routed_experts is not None:
+     final_output.routed_experts = output.routed_experts
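
A side benefit of plain replacement: it stays correct across repeated abort/resume cycles, because every resumed call's routed_experts again spans the whole sequence generated so far, so the final assignment always carries full coverage.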

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: This is an async distributed training bug that requires multi-node sglang + megatron setup with MoE model and partial rollout enabled. Not feasible to reproduce in CI.
  • Once your PR is ready for CI, send a message in the ci-request channel.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request updates the generate function in agent_loop.py to simplify the handling of routed_experts. The logic was changed from concatenating expert routing data to directly assigning the latest output, as the underlying engine now provides the routing information for the full sequence. I have no feedback to provide.

@ArronHZG (Collaborator)

Please use pre-commit run --all-files to format code.

Commit 18d467f: [fully_async] fix: replace routed_experts on partial rollout resume instead of concatenating

sglang returns routed_experts for the full sequence including both input
and output tokens (see io_struct.py comment on routed_experts field).
When partial rollout resumes, the new call's input is prompt + already
generated tokens, so sglang returns routed_experts covering all positions.

The old code concatenated old and new routed_experts via torch.cat,
which duplicated routing for prompt and previously generated tokens.
This caused incorrect MoE expert replay in the actor and ppo_kl spikes.

Fix: simply replace routed_experts since the resumed call's output
already covers the entire sequence.
NoonePauseferg force-pushed the fix/partial-rollout-routed-experts branch from 788affd to 18d467f on April 16, 2026 at 12:37.
wuxibin89 merged commit d079d72 into verl-project:main on Apr 17, 2026.
69 of 80 checks passed