[megatron, training_utils] fix: Patch MoEAlltoAllTokenDispatcher.preprocess for router replay#4986

Merged
wuxibin89 merged 1 commit intoverl-project:mainfrom
HollowMan6:router_replay_all_to_all
Jan 27, 2026
Conversation

@HollowMan6
Collaborator

@HollowMan6 HollowMan6 commented Jan 19, 2026

What does this PR do?

When router replay is enabled with `moe_token_dispatcher_type = "alltoall"`, duplicate indices in `top_indices` can cause `routing_map.sum() < num_tokens * topk`, leading to a split-size mismatch in the all-to-all. In that case the number of output tokens should be derived from the routing map instead.
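To make the failure mode concrete, here is a small pure-Python illustration (the shapes and index values are made up, and the real code operates on tensors): scattering replayed `top_indices` into a boolean routing map collapses duplicate entries, so the map's sum undercounts the dispatched tokens.

```python
# Illustration with hypothetical shapes: why replayed top-k indices with
# duplicates shrink the routing map.
num_tokens, topk, num_experts = 4, 2, 8

# Replayed top_indices may contain duplicates within a row, e.g. token 0
# "routed" to expert 3 twice.
top_indices = [[3, 3], [1, 5], [0, 0], [2, 7]]

def build_routing_map(top_indices, num_experts):
    # Boolean routing map: routing_map[t][e] is True if token t goes to
    # expert e. Scattering a duplicate index sets the same cell twice,
    # so the duplicate is silently collapsed.
    routing_map = [[False] * num_experts for _ in top_indices]
    for t, experts in enumerate(top_indices):
        for e in experts:
            routing_map[t][e] = True
    return routing_map

routing_map = build_routing_map(top_indices, num_experts)
num_out_tokens = sum(sum(row) for row in routing_map)
print(num_out_tokens, num_tokens * topk)  # 6 8 — sum < num_tokens * topk
```

With duplicates present, sizing the all-to-all from `num_tokens * topk` (8) disagrees with the 6 rows actually produced, which is exactly the split-size mismatch below.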

To fix this error:

  File "/root/verl/verl/workers/megatron_workers.py", line 901, in compute_log_prob
    output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 105, in f
    return self.log(decorated_function, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 118, in log
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 238, in compute_log_prob
    output = self.forward_backward_batch(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 699, in forward_backward_batch
    losses_reduced = forward_backward_func(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/pipeline_parallel/schedules.py", line 629, in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
                                ^^^^^^^^^^^^^
  File "megatron/core/pipeline_parallel/schedules.py", line 417, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 614, in forward_step
    output = forward_fn(
             ^^^^^^^^^^^
  File "/root/verl/verl/models/mcore/model_forward_fused.py", line 120, in fused_forward_model
    output_orig: CausalLMOutputForPPO = model(**input_args)
                                        ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/module.py", line 490, in forward
    outputs = self.module(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/models/mcore/model_forward_fused.py", line 178, in _fused_GPTModel_forward
    hidden_states = model.decoder(
                    ^^^^^^^^^^^^^^
  File "megatron/core/transformer/transformer_block.py", line 610, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/module.py", line 353, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/transformer_block.py", line 754, in forward
    hidden_states, context = layer(
                             ^^^^^^
  File "megatron/core/transformer/transformer_layer.py", line 1169, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/module.py", line 353, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/transformer_layer.py", line 491, in forward
    output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/transformer_layer.py", line 715, in _forward_mlp
    mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/moe/moe_layer.py", line 372, in forward
    outputs = custom_forward(hidden_states)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/moe/moe_layer.py", line 353, in custom_forward
    dispatched_input, probs = self.dispatch(hidden_states, probs)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/moe/moe_layer.py", line 247, in dispatch
    return self.token_dispatcher.token_dispatch(hidden_states, probs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/transformer/moe/token_dispatcher.py", line 666, in token_dispatch
    global_input_tokens = all_to_all(
                          ^^^^^^^^^^^
  File "megatron/core/tensor_parallel/mappings.py", line 538, in all_to_all
    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 581, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "megatron/core/tensor_parallel/mappings.py", line 444, in forward
    torch.distributed.all_to_all_single(
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4688, in all_to_all_single
    work = group.alltoall_base(
           ^^^^^^^^^^^^^^^^^^^^
RuntimeError: Split sizes doesn't match total dim 0 size
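For context, the `RuntimeError` above enforces a basic invariant of `torch.distributed.all_to_all_single`: the input split sizes must sum to the input tensor's dim-0 length. A toy pure-Python reproduction of the mismatch (all numbers here are illustrative, not taken from the run above):

```python
# The all_to_all invariant that fails above, in miniature: split sizes
# must sum to the tensor's dim-0 length.
def check_splits(dim0_size, split_sizes):
    if sum(split_sizes) != dim0_size:
        raise RuntimeError("Split sizes doesn't match total dim 0 size")

num_tokens, topk = 4, 2
expected = num_tokens * topk      # dispatcher assumed 8 dispatched tokens
actual_dispatched = 6             # duplicates collapsed in the routing map
splits_from_assumption = [4, 4]   # sums to 8

check_splits(expected, splits_from_assumption)   # passes: 8 == 8
try:
    check_splits(actual_dispatched, splits_from_assumption)
except RuntimeError as e:
    print(e)  # Split sizes doesn't match total dim 0 size
```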

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

✨ Presented to you with Mind Lab - A Lab for Experiential Intelligence.

Copilot AI review requested due to automatic review settings January 19, 2026 22:22
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses a RuntimeError during MoE AlltoAll dispatch when using router replay. The fix involves monkey-patching MoEAlltoAllTokenDispatcher.preprocess to correctly calculate the number of output tokens when duplicate indices are present. While the intent is correct, I've found a critical issue in the patch's implementation that renders the fix ineffective and may cause further issues. My review includes a detailed explanation and a code suggestion to correct the patch.

Comment thread verl/utils/megatron/router_replay_patch.py
Contributor

Copilot AI left a comment


Pull request overview

This PR fixes a runtime error in MoE (Mixture of Experts) router replay functionality when using the all-to-all token dispatcher. When router replay is enabled, duplicate indices in top_indices can cause the routing map sum to be less than expected (num_tokens * topk), leading to split size mismatches during distributed all-to-all operations.

Changes:

  • Added import for MoEAlltoAllTokenDispatcher with appropriate error handling
  • Patched MoEAlltoAllTokenDispatcher.preprocess to correctly calculate num_out_tokens from routing_map when router replay is enabled
  • Updated step numbering in comments from Step 4 to Step 5 for existing patches
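The patching approach described above can be sketched in pure Python. Everything here is a stand-in for illustration: the real patch targets Megatron's `MoEAlltoAllTokenDispatcher` and lives in `verl/utils/megatron/router_replay_patch.py`; only the monkey-patch pattern itself is shown.

```python
# Hypothetical stand-in for Megatron's dispatcher; the class body is
# invented, only the patching pattern mirrors the PR.
class FakeDispatcher:
    def preprocess(self, routing_map):
        # Pretend the original sizing logic runs here and reads
        # self.num_out_tokens when building all-to-all split sizes.
        return self.num_out_tokens

_orig_preprocess = FakeDispatcher.preprocess

def patched_preprocess(self, routing_map):
    # Under router replay, derive the true dispatched-token count from the
    # routing map (duplicates already collapsed) rather than assuming
    # num_tokens * topk.
    self.num_out_tokens = sum(sum(row) for row in routing_map)
    return _orig_preprocess(self, routing_map)

# Monkey-patch the method, as the PR does for the real dispatcher class.
FakeDispatcher.preprocess = patched_preprocess

d = FakeDispatcher()
print(d.preprocess([[True, False], [True, True]]))  # 3
```

Wrapping rather than replacing `preprocess` keeps the original behavior intact whenever the replayed routing map has no duplicates.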


Comment thread verl/utils/megatron/router_replay_patch.py Outdated
Comment thread verl/utils/megatron/router_replay_patch.py
@HollowMan6 HollowMan6 force-pushed the router_replay_all_to_all branch from 7c05ae3 to cc6f86f Compare January 19, 2026 22:46
@HollowMan6 HollowMan6 force-pushed the router_replay_all_to_all branch 6 times, most recently from 4533096 to f4177c3 Compare January 23, 2026 20:59
…rocess for router replay

When router replay is enabled, duplicate indices in top_indices can cause
routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall
and we shall derive it from the routing map instead in this case.

Fixes the "RuntimeError: Split sizes doesn't match total dim 0 size" traceback shown in the PR description above.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
@wuxibin89 wuxibin89 merged commit a1a35a7 into verl-project:main Jan 27, 2026
70 of 73 checks passed
@HollowMan6 HollowMan6 deleted the router_replay_all_to_all branch January 27, 2026 11:03
DaizeDong pushed a commit to DaizeDong/verl that referenced this pull request Apr 19, 2026
…rocess for router replay (verl-project#4986)

(The referenced commit message repeats the PR description, traceback, and checklists above verbatim.)
3 participants