
[algo] feat: support router replay #4101

Merged
ISEEKYAN merged 18 commits into verl-project:main from litianjian:feat/router_replay
Dec 4, 2025

Conversation

@litianjian
Contributor

@litianjian litianjian commented Nov 12, 2025

What does this PR do?

This PR introduces draft Router Replay support into verl.
Inspired by recent research in MoE reinforcement learning (2510.11370, 2507.18071), this implementation supports Router Replay (R2) and Rollout Router Replay (R3).
R2 records the routers' token-to-expert selections during log-probability computation and replays them during the policy update. R3 records during model inference (rollout) and replays during RL post-training.

The initial version supports Router Replay with the Megatron backend, including support for distributed training strategies (DP, TP, EP, ETP, PP) and activation re-compute.

The current implementation uses a patch-based approach. Once upstream PR NVIDIA/Megatron-LM#2101 is merged or exposes the corresponding interfaces, the patches can be removed in favor of the official API.

Usage Tutorial

Basic Configuration

To enable Router Replay, configure it in one of the following ways:

Method 1: Trainer Configuration

Add the following configuration to your trainer config:

router_replay:
  enabled: true
  mode: "R2"  # Options: "R2", "R3"

Method 2: Launch Script Configuration

Add the following parameter to your launch script:

# In your launch script
actor_rollout_ref.actor.router_replay.mode="R2"

R2 Mode Usage

  1. Enable R2 mode in configuration
  2. Record phase: During log probability computation, routing selections are automatically recorded
  3. Replay phase: During policy update, recorded expert selections are replayed (see the sketch below)
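
For orientation, a minimal sketch of how the two phases could be driven. The `RouterReplay.set_global_routing_mode(RoutingMode.RECORD)` call is quoted from the discussion below, while the import path, the `RoutingMode.REPLAY` constant, and the `actor` method names are assumptions for illustration, not the exact verl API:

```python
# Sketch only: RoutingMode.REPLAY and the method names below are assumed.
from verl.utils.megatron.router_replay_patch import RouterReplay, RoutingMode

def r2_step(actor, batch):
    """R2: record routing while computing log-probs, replay it in the update."""
    # Record phase: top-k expert indices are captured during this forward pass.
    RouterReplay.set_global_routing_mode(RoutingMode.RECORD)
    old_log_probs = actor.compute_log_prob(batch)

    # Replay phase: the recorded indices override the router during the update.
    RouterReplay.set_global_routing_mode(RoutingMode.REPLAY)
    metrics = actor.update_policy(batch)
    return old_log_probs, metrics
```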

R3 Mode Usage

  1. Enable R3 mode in configuration
  2. Record phase: During model inference, routing decisions are captured
  3. Replay phase: During RL post-training, recorded routing data is used (see the sketch below)
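
Similarly, a conceptual sketch of the R3 flow. `routed_experts` and `set_replay_data` follow names that appear elsewhere on this page, while the engine API, the `split_per_layer` helper, and the `RoutingMode.REPLAY` constant are hypothetical:

```python
from verl.utils.megatron.router_replay_patch import RouterReplay, RoutingMode  # path assumed

def r3_step(rollout_engine, actor, prompts, batch, split_per_layer):
    """R3: record routing in the rollout engine, replay it during training."""
    # Record phase: the inference engine captures per-token expert choices.
    out = rollout_engine.generate(prompts)   # hypothetical engine API
    routed = out.routed_experts              # (tokens, layers, top_k)

    # Replay phase: hand the recorded indices to the training-side replayer.
    RouterReplay.set_replay_data(split_per_layer(routed))     # helper is hypothetical
    RouterReplay.set_global_routing_mode(RoutingMode.REPLAY)  # constant is assumed
    return actor.update_policy(batch)
```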

In Progress

R2

  • FSDP backend

R3

  • vLLM Rollout
  • Sglang Rollout

@CLAassistant

CLAassistant commented Nov 12, 2025

CLA assistant check
All committers have signed the CLA.

@litianjian litianjian changed the title feat: support router replay [WIP]feat: support router replay Nov 12, 2025
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This PR introduces Router Replay (R2/R3) for MoE models, a significant feature for deterministic training and analysis. The implementation relies on monkey-patching Megatron, which is a reasonable approach for now. My review focuses on ensuring the patch is robust and the new code is maintainable. I've identified a critical issue with global state management in the patching mechanism that could lead to memory leaks and incorrect behavior, along with several high-severity issues related to potential runtime errors and code maintainability. Addressing these points will improve the stability and long-term health of this new feature.

Comment thread verl/utils/megatron/router_replay_patch.py
Comment thread verl/utils/megatron/router_replay_utils.py Outdated
Comment thread verl/workers/actor/megatron_actor.py Outdated
Comment thread verl/workers/actor/megatron_actor.py Outdated
Comment thread verl/workers/megatron_workers.py Outdated
@Cesilina

Hi, a quick question: in `compute_log_prob` in megatron_workers, under R2 mode, this code
if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2":
RouterReplay.set_global_routing_mode(RoutingMode.RECORD)
only sets record mode and never calls `set_replay_data`, right? In the subsequent megatron_actor computation, `record_topk_idx` is None for every object in the `router_instances_list` obtained in `merge_router_topk_indices`.

@ISEEKYAN
Collaborator

Please fix the conflicts so we can merge this.

@litianjian
Contributor Author

Hi, a quick question: in `compute_log_prob` in megatron_workers, under R2 mode, this code `if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2": RouterReplay.set_global_routing_mode(RoutingMode.RECORD)` only sets record mode and never calls `set_replay_data`, right? In the subsequent megatron_actor computation, `record_topk_idx` is None for every object in the `router_instances_list` obtained in `merge_router_topk_indices`.

In record mode, mcore records the router selection results in raw form. These selections are then used in the next update-policy stage.

@Cesilina

records the router selection results in raw form

Got it! Thanks

Comment thread verl/utils/megatron/router_replay_patch.py Outdated
Comment thread verl/utils/megatron/router_replay_patch.py Outdated
Comment thread verl/workers/actor/megatron_actor.py Outdated
@litianjian litianjian changed the title [WIP]feat: support router replay [algo] feat: support router replay Dec 3, 2025
@scut-zx

scut-zx commented Dec 3, 2025

I would like to ask about the latest progress of this project. Does R3 support Megatron+vLLM training? Does Megatron need to use the PR version you submitted: https://github.com/NVIDIA/Megatron-LM/pull/2101/files?

@ISEEKYAN
Collaborator

ISEEKYAN commented Dec 3, 2025

I would like to ask about the latest progress of this project. Does R3 support Megatron+vLLM training? Does Megatron need to use the PR version you submitted: https://github.com/NVIDIA/Megatron-LM/pull/2101/files?

It is ready to merge now. You no longer need NVIDIA/Megatron-LM#2101 to achieve router replay in verl, but once Megatron's PR is merged, we can remove some patches from verl.

@ISEEKYAN ISEEKYAN merged commit cb23607 into verl-project:main Dec 4, 2025
78 of 80 checks passed
@Shadowyuan616

Shadowyuan616 commented Dec 9, 2025

Hi, is DeepEP necessary for routing replay (R2/R3)?
Is there any way to run the script on other kinds of devices? DeepEP seems to be compatible only with SM90 devices for NVSHMEM and full features.

@litianjian
Contributor Author

Hi, is DeepEP necessary for routing replay (R2/R3)? Is there any way to run the script on other kinds of devices? DeepEP seems to be compatible only with SM90 devices for NVSHMEM and full features.

Hi, DeepEP is not necessary for routing replay.

@SuffixAutomata

@litianjian Hello. I have a question about cb23607#diff-25157fd1cdb544dfb3245f630bfbd49cff24e79eac8ef893ce440217eb92b385 - why are we assigning .routed_experts to AgentData? As far as I can tell AgentData does not have this field, nor is this field used downstream.

@litianjian
Contributor Author

@litianjian Hello. I have a question about cb23607#diff-25157fd1cdb544dfb3245f630bfbd49cff24e79eac8ef893ce440217eb92b385 - why are we assigning .routed_experts to AgentData? As far as I can tell AgentData does not have this field, nor is this field used downstream.

You’re right. This was my mistake; I need to add support for AgentData.

TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
@wuxibin89 wuxibin89 mentioned this pull request Jan 12, 2026
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026
SumanthRH pushed a commit to NovaSky-AI/SkyRL that referenced this pull request Mar 14, 2026
# Overview
This PR adds initial support for Rollout Routing Replay (R3) (see [the paper](https://arxiv.org/abs/2510.11370)).

See #815 for tracking of future tasks to fully support routing replay in
all settings.

We add the following flags to enable R3:

```
cfg.generator.inference_engine.enable_return_routed_experts=True
cfg.trainer.policy.megatron_config.moe_enable_routing_replay=True
```
`cfg.generator.inference_engine.enable_return_routed_experts=True` is a
pass-through argument to vLLM, which records expert router indices,
returning a list of shape `(batch_size, seq_len, num_layers, top_k)`.
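
To make the shape concrete, a small self-contained illustration of slicing the recorded indices into per-layer tensors; the per-layer layout fed to `set_replay_data` below is an assumption, and the exact format in SkyRL may differ:

```python
import torch

# Fake recorded indices with the documented shape (values are illustrative).
batch_size, seq_len, num_layers, top_k, num_experts = 2, 8, 4, 2, 64
routed = torch.randint(num_experts, (batch_size, seq_len, num_layers, top_k))

# One (batch_size * seq_len, top_k) tensor per layer.
per_layer_data = [routed[:, :, layer, :].reshape(-1, top_k) for layer in range(num_layers)]
```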

We then pass this `rollout_expert_indices` list through to Megatron's
native `RouterReplay` feature
([link](https://github.com/NVIDIA/Megatron-LM/blob/main/docs/api-guide/router_replay.md)).

When `cfg.trainer.policy.megatron_config.moe_enable_routing_replay` is
set to `true`, Megatron initializes an instance of `RouterReplay` on
each training worker rank.
`RouterReplay.set_replay_data(per_layer_data)` sets the recorded router
decisions, and
`RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)`
/
`RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)`
select replay for the forward or backward pass.
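
Putting those calls together, a sketch of how a replayed training step might look; the import path is an assumption (see the Megatron-LM API guide linked above) and the forward/backward wiring is simplified:

```python
# Import path assumed; only the RouterReplay call names are taken from the text above.
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction

def replayed_train_step(model, batch, per_layer_data):
    # Load the rollout-recorded per-layer expert indices.
    RouterReplay.set_replay_data(per_layer_data)

    # Replay the recorded routing in the forward pass...
    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
    loss = model(batch)  # assume the model returns a scalar loss here

    # ...and in the backward pass (recompute re-runs the routers).
    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)
    loss.backward()
    return loss
```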

# Results
GSM8K training on `moonlight16b-a3b` shows that R3 improves training
stability: this can be seen both in logprob diffs and in
`clip_ratio`, `grad_norm`, and `loss`, which otherwise explode and
collapse training.
<img width="967" height="238" alt="image"
src="https://github.com/user-attachments/assets/7b84a273-ee27-4eca-a2eb-45c833182093"
/>
<img width="1449" height="241" alt="image"
src="https://github.com/user-attachments/assets/edf6cdab-6f16-4280-9312-410b3965c663"
/>


# Supported Settings
Router Replay is supported for the following settings:
#### Generator Settings
- `use_conversation_multi_turn=True` and
`use_conversation_multi_turn=False`
- `batched=False` and `batched=True`
- `async_engine=True` and `async_engine=False`
- NOT `retokenize_chat_history` mode - i.e.
`self.use_conversation_multi_turn and self.custom_chat_template`
- NOT `self.generator_cfg.step_wise_trajectories` - there are some
open questions about how to support this when using step-wise training
and not strictly appending (what should the routing look like for
per-turn observations that the inference engine doesn't see? Do we need
to disable routing overrides for those tokens?)
- fully_async training - technically should work but not tested in this
PR. Tracking in #815
#### Inference Engine Settings
- TP and EP and DP should be supported from the vLLM side
- NOTE: `_SKYRL_USE_NEW_INFERENCE` is not supported - this will be added
in a follow-up PR
- NOTE: `cfg.generator.distributed_executor_backend` must be set to `mp`
- hanging related to a Ray Compiled Graph issue occurs when using the
default `ray` vLLM distributed executor backend. (see
vllm-project/vllm#36237 for details on the
error that comes up)
- NOTE: The above use of `mp` also means that serving must be single
node per engine, until we add support for using the mp backend with
multi-node serving - progress tracked here:
#1309

#### Trainer Settings
- TP, EP, DP are all supported. CP is in progress in this PR but needs
more testing. CP + PP will be added in a follow-up PR.

#### Custom Generator support
- Custom generators using SkyRL's inference engine should just plumb the
router indices through

# Tests
Adds `test_router_replay.py`, which includes:
- `test_logprobs` - integration test that runs a training batch through
vllm, and through megatron with and without R3, to verify that logprob
diffs are lower with routing replay
- `test_forward_backward` - unit test for `forward_backward` that
verifies that a training step can complete successfully when routing
replay indices are passed in

Adds `test_generator_multi_turn_gsm8k_router_replay` to
`test_skyrl_gym_generator` to verify that the `SkyRLGymGenerator` plumbs
through the router indices in an expected format.

# Rollout Routing Replay 

<img width="656" height="408" alt="image"
src="https://github.com/user-attachments/assets/2d22bcf6-64ac-4dd6-97a3-a09d34fdef47"
/>

Relevant resources: 
vLLM PR: vllm-project/vllm#28284
Verl PR: verl-project/verl#4101
Mindlab blog:
https://macaron.im/mindlab/research/router-replay-r3-why-it-failed-and-how-we-fixed-it
Megatron-LM API guide:
https://github.com/NVIDIA/Megatron-LM/blob/main/docs/api-guide/router_replay.md


---------

Co-authored-by: Dev Patel <dev.patel@berkeley.edu>
wuxibin89 pushed a commit that referenced this pull request Apr 17, 2026
#6029)

### What does this PR do?

Fixes a bug in `FullyAsyncLLMServerManager.generate()` where
`routed_experts` was incorrectly concatenated via `torch.cat` during
partial rollout resume, causing duplicated routing data and broken MoE
expert replay in the actor.

sglang returns `routed_experts` for the **full sequence** (prompt + all
generated tokens). Evidence from sglang source:

1. **[`io_struct.py#L1020`](https://github.com/sgl-project/sglang/blob/v0.5.9/python/sglang/srt/managers/io_struct.py#L1020)** — field definition:
   ```python
   # The routed experts for each token, including both input and output tokens
   # routed_experts[i] is a tensor of shape (token, layer, top_k) for request i
   routed_experts: List[Optional[torch.Tensor]]
   ```

2. **`schedule_batch.py`** — `seqlen` used to collect routing covers the full sequence:
   ```python
   @property
   def seqlen(self) -> int:
       return len(self.origin_input_ids) + len(self.output_ids)
   ```

3. **`topk.py#L1049-1051`** — capture is unconditional (no prefill/decode check):
   ```python
   get_global_experts_capturer().capture(layer_id=layer_id, topk_ids=topk_ids)
   ```

4. **`scheduler_output_processor_mixin.py#L105-111`** — collection uses the full `seqlen`:
   ```python
   req.routed_experts = get_global_experts_capturer().get_routed_experts(
       req_pool_idx=req.req_pool_idx,
       seqlen=req.seqlen,  # origin_input_ids + output_ids
       req_to_token_pool=self.req_to_token_pool,
   )
   ```

When partial rollout resumes after abort, the input becomes `prompt +
already_generated_tokens`. sglang re-processes the entire input during
prefill and returns `routed_experts` covering all positions. The old
code concatenated this with the previous `routed_experts`:

```
old routing:    prompt + A B C
new routing:    prompt + A B C + D E
concat result:  prompt + A B C + prompt + A B C + D E   <-- duplicated!
expected:       prompt + A B C + D E
```
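
A toy version of the same off-by-a-prefix bug, with plain Python lists standing in for the routing tensors:

```python
prev_routing    = ["prompt", "A", "B", "C"]            # recorded before the abort
resumed_routing = ["prompt", "A", "B", "C", "D", "E"]  # resume re-covers all positions

buggy = prev_routing + resumed_routing  # concat: the prefix appears twice
fixed = resumed_routing                 # replace: matches the expected coverage
assert fixed == ["prompt", "A", "B", "C", "D", "E"]
```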

This shifted the routing and caused incorrect MoE expert replay, leading
to `actor/ppo_kl` spikes.

**Fix:** replace `routed_experts` instead of concatenating, since the
resumed call already covers all positions.

Related: #4348 (partial rollout RFC), #4101 (R3 router replay), #5344
(R3 in fully async)

### Checklist Before Starting

- [x] Search for similar PRs:
https://github.com/verl-project/verl/pulls?q=routed_experts+partial_rollout
- [x] Format the PR title as `[{modules}] {type}: {description}`

### Test

- Ran async training with `partial_rollout=True` and
`enable_rollout_routing_replay=True` (R3 mode)
- Verified `actor/ppo_kl` no longer spikes after partial rollout resume
- Verified `routed_experts` tensor shape matches `(prompt_len +
response_len, num_layers, top_k)` after resume

### Design & Code Changes

Single-line change in
`verl/experimental/fully_async_policy/agent_loop/agent_loop.py`:

```diff
- if output.routed_experts is not None:
-     if final_output.routed_experts is None:
-         final_output.routed_experts = output.routed_experts
-     else:
-         final_output.routed_experts = torch.cat([final_output.routed_experts, output.routed_experts], dim=0)
+ # sglang returns routed_experts for the full sequence (prompt + all tokens),
+ # so on partial rollout resume the new output already covers all positions.
+ if output.routed_experts is not None:
+     final_output.routed_experts = output.routed_experts
```

### Checklist Before Submitting

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: This is an async
distributed training bug that requires multi-node sglang + megatron
setup with MoE model and partial rollout enabled. Not feasible to
reproduce in CI.
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1).

Co-authored-by: vadim <vadim@mail.ru>