Current State: For small-scale tests, routing replay seems to be working as shown above - tested only for TP=8 serving and TP=4, EP=8 training.
What's not working / TODOs:
Solved! The issue was a setting in our test. Before the fix:
- vLLM logprobs - mean: -2.564655
- Megatron (replay) - mean: -9.231623
- Megatron (no replay) - mean: -9.647593

After the fix:
- vLLM logprobs - mean: -0.223607, std: 0.674102
- Megatron (replay) - mean: -0.223626, std: 0.674850
- Megatron (no replay) - mean: -0.224379, std: 0.677036
- With replay - logprob diff mean: 0.006648, std: 0.021737
- Without replay - logprob diff mean: 0.011115, std: 0.035957
Verified that cherry-picking the changes from #1300 to use the mp backend allows us to work around the compiled graph timeout.
@pytest.mark.megatron
@pytest.mark.skip(reason="Skipping router replay test for now due to size constraints")
Skipping these for now - need to test some smaller models out. Tracking as a follow-up task in #815; will do ASAP along with PP + CP support.
return prompts
def _ensure_chat_template(tokenizer):
This is needed for "allenai/OLMoE-1B-7B-0924", which is used in the skyrl_gym_generator test. I plan to port this model over to the other router replay tests (and maybe the Megatron MoE tests in general) since it's supported in Megatron-Bridge and is a 7B MoE with 1B activated parameters, which makes it nice for CI.
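For reference, a minimal sketch of what a helper like this might do for a tokenizer that doesn't ship a chat template - the fallback template string below is purely illustrative, not necessarily what the test suite uses:

```python
# A minimal sketch of what a helper like _ensure_chat_template might do; the
# fallback Jinja template below is illustrative, not the one used in the tests.
def _ensure_chat_template(tokenizer):
    """Attach a simple chat template if the tokenizer doesn't already have one."""
    if getattr(tokenizer, "chat_template", None):
        return tokenizer  # already has a template (most instruct models)
    # Fallback: a bare-bones template that just concatenates role/content pairs.
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant: {% endif %}"
    )
    return tokenizer
```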
# Summary

Extending #1273, this PR provides support for pipeline parallelism and context parallelism for R3. See #815 for tracking of future tasks to fully support routing replay in all settings.

# Implementation

**Pipeline Parallelism**

For pipeline parallelism, we create a helper function `_get_current_pp_stage_layer_range(model_config)` which maps the current PP rank and its layers to a global layer offset across all the model layers, so that we can use this offset to select the corresponding replay instances from `RouterReplay.global_router_replay_instances`.

First, we get the number of pipeline stages from the PP world size along with the total number of model layers. For models containing dense layers / unequal pipeline stages, Megatron supports setting a custom number of layers for the first and last PP rank. We capture these values from the model config and check that the remaining layers can be evenly distributed across the remaining PP ranks. Finally, we return the transformer-layer range owned by the current PP rank as $(s_p, n_p)$, where:
- $s_p$ is the global starting layer index for rank $p$
- $n_p$ is the number of transformer layers assigned to that stage

For an even partition with $L$ total layers and $P$ pipeline stages:
- `next_n_pp_layers = L // P`, `start_index = next_n_pp_layers * pp_rank`
- the offset thus spans `(next_n_pp_layers * pp_rank) : (next_n_pp_layers * (pp_rank + 1))`

For uneven partitioning, if the first and/or last stages are assigned custom layer counts, we subtract those from $L$, split the remaining layers evenly among the remaining stages, and then shift the start index accordingly. This lets us support cases like the Moonlight-16B models, which have 27 layers: for PP=2 we can pass `num_layers_in_first_pipeline_stage` as 13.

**Context Parallelism**

When using sample packing, our Megatron worker pre-processes and post-processes packed sequences. When CP is enabled, each packed sequence is split into `cp_size * 2` chunks, so each GPU effectively gets 2 CP chunks of half the size (see NVIDIA/TransformerEngine#1368). To account for this extra chunking, the `setup_per_microbatch_replay_forward` method is updated so that the effective sequence length is aligned to `cp_size * 2` (matching the alignment in `preprocess_packed_seqs` in `megatron_utils.py`) and `seqlen_per_cp` becomes `seqlen_per_cp // 2`. We then index the front and back halves of these CP chunks from the aligned indices across the CP ranks and concatenate them. This ensures the router replay indices see the correct tokens under Megatron's CP chunking.

**Testing**

You can test with CP and/or PP configs from the `test_router_replay` file.
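For illustration, here is a rough sketch of the two pieces described above. The function and argument names are assumptions for this example (the real helpers read these values from the Megatron model config and parallel state), and the CP selection assumes TransformerEngine's load-balanced split where rank `r` owns chunks `r` and `2 * cp_size - 1 - r`:

```python
# A sketch of the PP layer-range and CP chunk-selection logic described above.
# Names and signatures are hypothetical; this is not SkyRL's actual implementation.

def get_pp_stage_layer_range(
    pp_rank: int,
    pp_size: int,
    num_layers: int,
    first_stage_layers: int | None = None,
    last_stage_layers: int | None = None,
) -> tuple[int, int]:
    """Return (s_p, n_p): global start layer index and layer count for this PP rank."""
    custom_first = first_stage_layers is not None
    custom_last = last_stage_layers is not None

    if not custom_first and not custom_last:
        # Even partition: L // P layers per stage.
        assert num_layers % pp_size == 0, "layers must divide evenly across PP ranks"
        n_p = num_layers // pp_size
        return n_p * pp_rank, n_p

    # Uneven partition: subtract the custom first/last stage sizes, then split the
    # remaining layers evenly across the remaining stages.
    first = first_stage_layers if custom_first else 0
    last = last_stage_layers if custom_last else 0
    middle_stages = pp_size - int(custom_first) - int(custom_last)
    remaining = num_layers - first - last
    assert middle_stages == 0 or remaining % middle_stages == 0, "uneven middle split"
    per_middle = remaining // middle_stages if middle_stages else 0

    if custom_first and pp_rank == 0:
        return 0, first
    if custom_last and pp_rank == pp_size - 1:
        return num_layers - last, last
    # Middle stage: shift the start index past the (possibly custom) first stage.
    middle_rank = pp_rank - int(custom_first)
    return first + per_middle * middle_rank, per_middle


def select_cp_token_indices(global_indices: list[int], cp_rank: int, cp_size: int) -> list[int]:
    """Pick the token positions this CP rank sees from a packed, cp_size*2-aligned sequence.

    The load-balanced CP split divides the sequence into 2 * cp_size chunks; rank r
    owns chunk r (front half) and chunk 2 * cp_size - 1 - r (back half), concatenated.
    """
    chunk = len(global_indices) // (2 * cp_size)
    front = global_indices[cp_rank * chunk : (cp_rank + 1) * chunk]
    back_chunk = 2 * cp_size - 1 - cp_rank
    back = global_indices[back_chunk * chunk : (back_chunk + 1) * chunk]
    return front + back


# Example: Moonlight-16B with 27 layers, PP=2, num_layers_in_first_pipeline_stage=13
assert get_pp_stage_layer_range(0, 2, 27, first_stage_layers=13) == (0, 13)
assert get_pp_stage_layer_range(1, 2, 27, first_stage_layers=13) == (13, 14)
```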









# Overview
This PR adds support for Rollout Routing Replay (R3) (see the R3 paper).
See #815 for tracking of future tasks to fully support routing replay in all settings.
We add the following flags to enable R3:
- `cfg.generator.inference_engine.enable_return_routed_experts=True` is a pass-through argument to vLLM, which records expert router indices (returning a list of dimension `(batch_size, seq_len, num_layers, top_k)`). We then pass this `rollout_expert_indices` list through to Megatron's native `RouterReplay` feature (link).
- When `cfg.trainer.policy.megatron_config.moe_enable_routing_replay` is set to `true`, Megatron initializes an instance of `RouterReplay` on each training worker rank. `RouterReplay.set_replay_data(per_layer_data)` can be used to set router decisions, and `RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)` / `RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)` can be used to set the routing mode to forward or backward. A rough sketch of this training-side wiring is shown below.
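For concreteness, a hypothetical sketch of how a training worker might wire the rollout indices into these replay calls. Only `set_replay_data` and `set_global_router_replay_action` come from the `RouterReplay` API described above; the import path, reshaping step, and `forward_backward_fn` are assumptions for this example:

```python
# Hypothetical sketch of the training-side wiring; not SkyRL's actual implementation.
import torch
from megatron.core.transformer.moe.router_replay import (  # import path may differ by Megatron version
    RouterReplay,
    RouterReplayAction,
)


def run_train_step_with_replay(rollout_expert_indices, forward_backward_fn):
    # rollout_expert_indices: the (batch_size, seq_len, num_layers, top_k) routing
    # indices returned by vLLM, assumed here to be convertible to a tensor.
    indices = torch.as_tensor(rollout_expert_indices)

    # RouterReplay expects routing decisions grouped per MoE layer.
    per_layer_data = [indices[:, :, layer, :] for layer in range(indices.shape[2])]
    RouterReplay.set_replay_data(per_layer_data)

    # Replay the recorded routing decisions in the forward pass...
    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
    loss = forward_backward_fn()

    # ...and keep them fixed for the backward pass as well.
    RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)
    return loss
```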
# Results

GSM8K training on `moonlight16b-a3b` shows R3 improves training stability - this can be seen both in the logprob diffs as well as in `clip_ratio`, `grad_norm`, and `loss`, which otherwise explode and collapse training.

# Supported Settings
Router Replay is supported for the following settings:
**Generator Settings**
- `use_conversation_multi_turn=True` and `use_conversation_multi_turn=False`
- `batched=False` and `batched=True`
- `async_engine=True` and `async_engine=False`
- `retokenize_chat_history` mode - i.e. `self.use_conversation_multi_turn and self.custom_chat_template`
- `self.generator_cfg.step_wise_trajectories` - there are some question marks about how to support this when using step-wise training and not strictly appending (what should the routing look like for per-turn observations that the inference engine doesn't see? do we need to disable routing overrides for those tokens?)

**Inference Engine Settings**
- `_SKYRL_USE_NEW_INFERENCE` is not supported - this will be added in a follow-up PR
- `cfg.generator.distributed_executor_backend` must be set to `mp` - hanging related to a Ray Compiled Graph issue occurs when using the default `ray` vLLM distributed executor backend (see vllm-project/vllm#36237, "[Bug]: Generation hangs until RAY_CGRAPH_get_timeout (300s) with Ray compiled DAG executor", for details on the error that comes up). Using `mp` also means that serving must be single-node per engine until we add support for using the mp backend with multi-node serving - progress tracked in #1309 ([vllm] Fully enable mp distributed executor backend in vLLM).

**Trainer Settings**
**Custom Generator support**
# Tests
Adds `test_router_replay.py`, which includes:
- `test_logprobs` - integration test that runs a training batch through vLLM, and through Megatron with and without R3, to verify that logprob diffs are lower with routing replay (a rough sketch of this check is shown below)
- `test_forward_backward` - unit test for `forward_backward` that verifies that a training step can complete successfully when routing replay indices are passed in

Adds `test_generator_multi_turn_gsm8k_router_replay` to `test_skyrl_gym_generator` to verify that the `SkyRLGymGenerator` plumbs through the router indices in an expected format.
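As a rough illustration of what the `test_logprobs` comparison boils down to (the actual test drives vLLM and Megatron end-to-end; the helper and names below are hypothetical and only show the final check):

```python
# Hypothetical sketch of the final assertion in a logprob-comparison test; the
# real test obtains these arrays by running a batch through vLLM and Megatron.
import numpy as np


def assert_replay_reduces_logprob_diff(vllm_logprobs, mcore_replay_logprobs, mcore_noreplay_logprobs):
    vllm = np.asarray(vllm_logprobs)
    diff_replay = np.abs(np.asarray(mcore_replay_logprobs) - vllm)
    diff_noreplay = np.abs(np.asarray(mcore_noreplay_logprobs) - vllm)

    print(f"With replay    - logprob diff mean: {diff_replay.mean():.6f}, std: {diff_replay.std():.6f}")
    print(f"Without replay - logprob diff mean: {diff_noreplay.mean():.6f}, std: {diff_noreplay.std():.6f}")

    # Routing replay should bring the training-side logprobs closer to vLLM's.
    assert diff_replay.mean() < diff_noreplay.mean()
```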
# Rollout Routing Replay

Relevant resources:
- vLLM PR: vllm-project/vllm#28284
- Verl PR: verl-project/verl#4101
- Mindlab blog: https://macaron.im/mindlab/research/router-replay-r3-why-it-failed-and-how-we-fixed-it
- Megatron-LM API guide: https://github.com/NVIDIA/Megatron-LM/blob/main/docs/api-guide/router_replay.md