
[megatron] fix: support hybrid dense/MoE models in router replay with PP/VPP#5452

Merged
PeterSH6 merged 3 commits into verl-project:main from xhx1022:fix/vpp_hybrid_model
Mar 3, 2026

Conversation

@xhx1022 (Collaborator) commented Mar 2, 2026

What does this PR do?

Router replay previously assumed all transformer layers are MoE layers, which caused incorrect layer indexing for hybrid models (e.g., models with both dense and MoE layers determined by moe_layer_freq).
This led to bugs when using pipeline parallelism (PP) and virtual pipeline parallelism (VPP), as layer offset calculations did not account for dense layers.

Although #5037 introduced the router replay mechanism by patching Megatron's TopKRouter, it did not fully handle hybrid (dense + MoE) models under VPP.
Specifically:

  • Bug 1 — Incorrect VPP offset (root cause): In num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage), get_num_layers_to_build() was used to compute the offset across prior VPP stages. It returns the count of all transformer layers (including dense layers), but RouterReplay instances exist only on MoE layers. For hybrid models this over-counts the offset, so the wrong slice of router instances is selected.
  • Bug 2 — Replay data not set correctly (consequence): Because Bug 1 yields the wrong router instance list, router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) either assigns target_indices to the wrong router or goes out of bounds, so replay data is never dispatched to the corresponding MoE layers.

The same issue also exists in pp_gather(), where VPP offset calculation must slice gathered data by MoE layer count rather than total layer count.
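To make the over-counting concrete, here is a minimal numeric sketch (the mask and stage sizes are illustrative, not verl's actual configuration): a prior VPP stage contributes only as many RouterReplay instances as it has MoE layers, so the offset into the router-instance list must count MoE layers, not all transformer layers.

```python
# Illustrative hybrid model: per-layer mask, 1 = MoE layer, 0 = dense layer.
moe_layer_freq = [0, 1, 0, 1, 0, 1, 0, 1]  # 8 layers, MoE on odd indices
layers_per_vp_stage = 4                    # 2 VPP stages of 4 layers each

# Buggy offset for VPP stage 1: counts ALL transformer layers
# built by the prior stage.
buggy_offset = layers_per_vp_stage

# Correct offset: count only the prior stage's MoE layers, since
# RouterReplay instances exist only on MoE layers.
correct_offset = sum(moe_layer_freq[:layers_per_vp_stage])

print(buggy_offset, correct_offset)  # 4 2 -- the buggy slice starts 2 routers too far
```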

Key changes:

  • Add is_moe_layer() and get_moe_num_layers_to_build() helpers to distinguish MoE layers from dense layers based on moe_layer_freq
  • Rewrite set_router_replay_data() to correctly index router replay data by MoE-layer ordinal for R2 mode with mixed dense/MoE models
  • Fix VPP offset calculation in pp_gather() and RouterReplayHelper to count only MoE layers instead of all transformer layers
  • Remove unnecessary layer_number tracking from RouterReplay patch to minimize intrusive changes to Megatron.
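As a sketch of what the two new helpers might look like (the names come from the PR description; the exact signatures and the mask semantics below are assumptions, not verl's actual code):

```python
from types import SimpleNamespace


def is_moe_layer(tf_config, layer_idx):
    """Whether global layer layer_idx is an MoE layer, per moe_layer_freq."""
    moe_layer_freq = getattr(tf_config, "moe_layer_freq", 1)
    if isinstance(moe_layer_freq, int):
        # Interpreting an int N as "every N-th layer is MoE" (N=1: all MoE).
        return moe_layer_freq > 0 and layer_idx % moe_layer_freq == 0
    if isinstance(moe_layer_freq, list):
        # A list is a per-layer mask: 1 = MoE, 0 = dense.
        return moe_layer_freq[layer_idx] == 1
    raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")


def get_moe_num_layers_to_build(tf_config, first_layer, num_layers):
    """Count MoE layers among the num_layers layers starting at first_layer."""
    return sum(
        1
        for i in range(first_layer, first_layer + num_layers)
        if is_moe_layer(tf_config, i)
    )


# SimpleNamespace stands in for the real transformer config object.
cfg = SimpleNamespace(moe_layer_freq=[0, 1, 0, 1, 0, 1])
print(get_moe_num_layers_to_build(cfg, 0, 4))  # 2 (layers 1 and 3)
print(get_moe_num_layers_to_build(cfg, 4, 2))  # 1 (layer 5)
```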

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

… PP/VP

Signed-off-by: xhx1022 <1737006628@qq.com>
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@gemini-code-assist Bot (Contributor) left a comment
Code Review

This pull request addresses a critical issue with router replay in hybrid dense/MoE models, particularly when pipeline parallelism is enabled. The changes correctly distinguish MoE layers from dense layers and adjust layer indexing and offset calculations accordingly. The refactoring to remove unnecessary layer number tracking is also a good cleanup. Overall, the changes are well-implemented and address the core problem described. I've identified one critical issue in a new helper function that could lead to a crash and have provided a suggestion to fix it.

Comment on lines +174 to +182
```python
def is_moe_layer(tf_config, layer_idx):
    moe_layer_freq = getattr(tf_config, "moe_layer_freq", None)

    if isinstance(moe_layer_freq, int):
        return layer_idx % moe_layer_freq == 0
    elif isinstance(moe_layer_freq, list):
        return moe_layer_freq[layer_idx] == 1
    else:
        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
```
critical

The is_moe_layer function has two potential issues that could lead to runtime errors:

  1. Unset moe_layer_freq: The docstring for get_moe_num_layers_to_build specifies that if moe_layer_freq is unset, all layers should be treated as MoE layers. However, the current implementation getattr(tf_config, "moe_layer_freq", None) will cause moe_layer_freq to be None, which then leads to a ValueError. This will cause a crash when router replay is used with a model that doesn't explicitly define moe_layer_freq.
  2. moe_layer_freq is zero: If moe_layer_freq is set to 0, the expression layer_idx % moe_layer_freq will cause a ZeroDivisionError.

To make the function more robust, it should default to 1 for an unset moe_layer_freq and explicitly handle non-positive values.

Suggested change

Before:

```python
def is_moe_layer(tf_config, layer_idx):
    moe_layer_freq = getattr(tf_config, "moe_layer_freq", None)
    if isinstance(moe_layer_freq, int):
        return layer_idx % moe_layer_freq == 0
    elif isinstance(moe_layer_freq, list):
        return moe_layer_freq[layer_idx] == 1
    else:
        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
```

After:

```python
def is_moe_layer(tf_config, layer_idx):
    moe_layer_freq = getattr(tf_config, "moe_layer_freq", 1)
    if isinstance(moe_layer_freq, int):
        if moe_layer_freq <= 0:
            return False
        return layer_idx % moe_layer_freq == 0
    elif isinstance(moe_layer_freq, list):
        return moe_layer_freq[layer_idx] == 1
    else:
        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
```
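A quick check of the suggested version against the two failure modes flagged in the review (a hedged sketch; SimpleNamespace stands in for the real transformer config object):

```python
from types import SimpleNamespace


def is_moe_layer(tf_config, layer_idx):
    # Suggested version: default to 1 (all MoE) when unset, guard freq <= 0.
    moe_layer_freq = getattr(tf_config, "moe_layer_freq", 1)
    if isinstance(moe_layer_freq, int):
        if moe_layer_freq <= 0:
            return False
        return layer_idx % moe_layer_freq == 0
    elif isinstance(moe_layer_freq, list):
        return moe_layer_freq[layer_idx] == 1
    else:
        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")


# Unset moe_layer_freq: defaults to 1 (all MoE) instead of raising ValueError.
print(is_moe_layer(SimpleNamespace(), 3))                  # True
# moe_layer_freq == 0: returns False instead of raising ZeroDivisionError.
print(is_moe_layer(SimpleNamespace(moe_layer_freq=0), 3))  # False
```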

xhx1022 added 2 commits March 2, 2026 20:49
Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: xhx1022 <1737006628@qq.com>
@PeterSH6 (Collaborator) left a comment
Nice work!

@PeterSH6 PeterSH6 merged commit 5ab49ec into verl-project:main Mar 3, 2026
100 of 117 checks passed
JasonWei05 referenced this pull request in JasonWei05/eca Mar 3, 2026
… PP/VPP (#5452)
guillemgt pushed a commit to guillemgt/verl that referenced this pull request Mar 9, 2026
… PP/VPP (verl-project#5452)
guillemgt added a commit to guillemgt/verl that referenced this pull request Mar 9, 2026
… PP/VPP (verl-project#5452)
DearFishi pushed a commit to KunlunxinAD/verl that referenced this pull request Mar 20, 2026
… PP/VPP (verl-project#5452)
@xhx1022 deleted the fix/vpp_hybrid_model branch March 21, 2026 02:44
sijyang pushed a commit to sijyang/verl that referenced this pull request Apr 1, 2026
… PP/VPP (verl-project#5452)
DaizeDong pushed a commit to DaizeDong/verl that referenced this pull request Apr 19, 2026
… PP/VPP (verl-project#5452)
