[megatron, training_utils] fix: Patch MoEAlltoAllTokenDispatcher.preprocess for router replay (#4986)
Conversation
Code Review
This pull request addresses a RuntimeError during MoE AlltoAll dispatch when using router replay. The fix involves monkey-patching MoEAlltoAllTokenDispatcher.preprocess to correctly calculate the number of output tokens when duplicate indices are present. While the intent is correct, I've found a critical issue in the patch's implementation that renders the fix ineffective and may cause further issues. My review includes a detailed explanation and a code suggestion to correct the patch.
Pull request overview
This PR fixes a runtime error in MoE (Mixture of Experts) router replay functionality when using the all-to-all token dispatcher. When router replay is enabled, duplicate indices in top_indices can cause the routing map sum to be less than expected (num_tokens * topk), leading to split size mismatches during distributed all-to-all operations.
Changes:
- Added import for MoEAlltoAllTokenDispatcher with appropriate error handling
- Patched MoEAlltoAllTokenDispatcher.preprocess to correctly calculate num_out_tokens from routing_map when router replay is enabled
- Updated step numbering in comments from Step 4 to Step 5 for existing patches
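The monkey-patch described above can be illustrated with a minimal sketch. The class below is a hypothetical stand-in for Megatron's `MoEAlltoAllTokenDispatcher` (the real class computes split sizes from expert counts on tensors); only the wrapping pattern and the corrected `num_out_tokens` calculation are the point here, implemented in pure Python so the idea is self-contained:

```python
# Sketch of the patch pattern, assuming a simplified stand-in dispatcher.
# The real patch wraps megatron.core.transformer.moe.token_dispatcher.
# MoEAlltoAllTokenDispatcher.preprocess; names below are illustrative.

class FakeDispatcher:
    """Hypothetical stand-in for MoEAlltoAllTokenDispatcher."""

    def __init__(self, num_tokens, topk):
        self.num_tokens = num_tokens
        self.topk = topk
        self.num_out_tokens = None

    def preprocess(self, routing_map):
        # Original assumption: every token sends exactly topk copies.
        self.num_out_tokens = self.num_tokens * self.topk
        return self.num_out_tokens


_original_preprocess = FakeDispatcher.preprocess


def patched_preprocess(self, routing_map):
    _original_preprocess(self, routing_map)
    # Under router replay, duplicate expert indices collapse into a single
    # True in the boolean routing map, so the true output-token count must
    # be derived from the map itself.
    actual = sum(sum(row) for row in routing_map)
    if actual != self.num_out_tokens:
        self.num_out_tokens = actual
    return self.num_out_tokens


FakeDispatcher.preprocess = patched_preprocess

d = FakeDispatcher(num_tokens=2, topk=2)
# Token 0 replayed expert 1 twice, so its row has only one True entry.
routing_map = [
    [False, True, False, False],
    [True, False, True, False],
]
print(d.preprocess(routing_map))  # 3, not num_tokens * topk == 4
```

With the patch applied, the split sizes handed to `all_to_all_single` agree with the actual number of dispatched rows, which is what the runtime check in the traceback below enforces.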
Patch MoEAlltoAllTokenDispatcher.preprocess for router replay

When router replay is enabled, duplicate indices in top_indices can cause
routing_map.sum() to be less than num_tokens * topk, leading to a split size
mismatch in alltoall. In this case, the number of output tokens should be
derived from the routing map instead.
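The shortfall is easy to reproduce in miniature. A boolean routing map built by scattering `top_indices` cannot represent a token that selects the same expert twice, so any duplicate produced by replay shrinks the map's sum below `num_tokens * topk` (a pure-Python toy, no torch required):

```python
# Toy illustration of how replayed duplicate indices shrink routing_map.sum().
num_tokens, topk, num_experts = 3, 2, 4

# Replayed indices: token 1 lists expert 2 twice (a duplicate).
top_indices = [[0, 3], [2, 2], [1, 3]]

routing_map = [[False] * num_experts for _ in range(num_tokens)]
for t, experts in enumerate(top_indices):
    for e in experts:
        routing_map[t][e] = True  # scattering collapses duplicates

total = sum(sum(row) for row in routing_map)
print(total, num_tokens * topk)  # 5 6 -> split sizes no longer match
```

When the dispatcher still assumes `num_tokens * topk` output rows, the split sizes it passes to the all-to-all disagree with the actual tensor size, producing the runtime error below.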
Fix this error:
```log
File "/root/verl/verl/workers/megatron_workers.py", line 901, in compute_log_prob
output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/workers/actor/megatron_actor.py", line 238, in compute_log_prob
output = self.forward_backward_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/workers/actor/megatron_actor.py", line 699, in forward_backward_batch
losses_reduced = forward_backward_func(
^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/pipeline_parallel/schedules.py", line 629, in forward_backward_no_pipelining
output_tensor, num_tokens = forward_step(
^^^^^^^^^^^^^
File "megatron/core/pipeline_parallel/schedules.py", line 417, in forward_step
output_tensor, loss_func = forward_step_func(data_iterator, model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/workers/actor/megatron_actor.py", line 614, in forward_step
output = forward_fn(
^^^^^^^^^^^
File "/root/verl/verl/models/mcore/model_forward_fused.py", line 120, in fused_forward_model
output_orig: CausalLMOutputForPPO = model(**input_args)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/distributed/data_parallel_base.py", line 22, in forward
return self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/module.py", line 490, in forward
outputs = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/verl/verl/models/mcore/model_forward_fused.py", line 178, in _fused_GPTModel_forward
hidden_states = model.decoder(
^^^^^^^^^^^^^^
File "megatron/core/transformer/transformer_block.py", line 610, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/module.py", line 353, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/transformer_block.py", line 754, in forward
hidden_states, context = layer(
^^^^^^
File "megatron/core/transformer/transformer_layer.py", line 1169, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/module.py", line 353, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/transformer_layer.py", line 491, in forward
output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/transformer_layer.py", line 715, in _forward_mlp
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/moe/moe_layer.py", line 372, in forward
outputs = custom_forward(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/moe/moe_layer.py", line 353, in custom_forward
dispatched_input, probs = self.dispatch(hidden_states, probs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/moe/moe_layer.py", line 247, in dispatch
return self.token_dispatcher.token_dispatch(hidden_states, probs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/transformer/moe/token_dispatcher.py", line 666, in token_dispatch
global_input_tokens = all_to_all(
^^^^^^^^^^^
File "megatron/core/tensor_parallel/mappings.py", line 538, in all_to_all
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 581, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "megatron/core/tensor_parallel/mappings.py", line 444, in forward
torch.distributed.all_to_all_single(
File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4688, in all_to_all_single
work = group.alltoall_base(
^^^^^^^^^^^^^^^^^^^^
RuntimeError: Split sizes doesn't match total dim 0 size
```
Signed-off-by: Hollow Man <hollowman@opensuse.org>
What does this PR do?
When router replay is enabled and `moe_token_dispatcher_type = "alltoall"` is used, duplicate indices in `top_indices` can cause `routing_map.sum()` to be less than `num_tokens * topk`, leading to a split size mismatch in the all-to-all. In this case, the number of output tokens should be derived from the routing map instead. This fixes the RuntimeError shown in the traceback above.
Checklist Before Starting
- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.
API and Usage Example
Demonstrate how the API changes if any, and provide usage example(s) if possible.
Design & Code Changes
Demonstrate the high-level design if this PR is complex, and list the specific changes.
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- [X] Read the Contribute Guide.
- [X] Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update the documentation.
- [X] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- [X] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.