
[train][2/N] Support for Megatron PP + CP for R3 #1327

Closed
devpatelio wants to merge 19 commits into main from r3-pp

Conversation

Collaborator

@devpatelio devpatelio commented Mar 16, 2026

Summary

Extending #1273, this PR provides support for pipeline parallelism and context parallelism for R3. See #815 for tracking of future tasks to fully support routing replay in all settings.

Implementation

Pipeline Parallelism
For pipeline parallelism, we create a helper function _get_current_pp_stage_layer_range(model_config) that maps the current PP rank's local layers to their global layer offset across all model layers. This offset lets us select the corresponding replay instances from RouterReplay.global_router_replay_instances.

First, we get the number of pipeline stages from the PP world size, along with the total number of model layers. For models containing dense layers / unequal pipeline stages, Megatron supports setting a custom number of layers for the first and last PP rank. We capture these values from the model config and check that the remaining layers can be evenly distributed across the remaining PP ranks. Finally, we return the transformer-layer range owned by the current PP rank as (s_p, n_p), where:

  • s_p is the global starting layer index for rank p
  • n_p is the number of transformer layers assigned to that stage

For an even partition with L total layers and P pipeline stages:

  • next_n_pp_layers = L // P, start_index = next_n_pp_layers * pp_rank
  • the offset thus spans (next_n_pp_layers * pp_rank) : (next_n_pp_layers * (pp_rank + 1))

For uneven partitioning, if the first and/or last stage is assigned a custom layer count, we subtract those counts from L, split the remaining layers evenly among the remaining stages, and shift the start index accordingly. This lets us support cases like Moonlight-16B models, which have 27 layers: for PP=2 we can pass num_layers_in_first_pipeline_stage as 13.
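The partitioning described above can be sketched as a minimal standalone function. This is not the PR's code: the name and signature are simplified from _get_current_pp_stage_layer_range, and the keyword arguments stand in for Megatron's num_layers_in_first_pipeline_stage / num_layers_in_last_pipeline_stage config fields.

```python
# Hypothetical sketch of the PP layer-range computation described above.
def get_pp_stage_layer_range(num_layers, pp_size, pp_rank,
                             first_stage_layers=None, last_stage_layers=None):
    """Return (s_p, n_p): global start index and layer count for pp_rank."""
    if first_stage_layers is None and last_stage_layers is None:
        # Even partition: L // P layers per stage.
        per_stage = num_layers // pp_size
        return per_stage * pp_rank, per_stage

    # Uneven partition: subtract the custom first/last counts, then split
    # the remainder evenly across the middle stages.
    first = first_stage_layers or 0
    last = last_stage_layers or 0
    middle_stages = (pp_size
                     - (first_stage_layers is not None)
                     - (last_stage_layers is not None))
    per_middle = (num_layers - first - last) // middle_stages if middle_stages else 0

    if pp_rank == 0 and first_stage_layers is not None:
        return 0, first
    if pp_rank == pp_size - 1 and last_stage_layers is not None:
        return num_layers - last, last
    # Shift the start index past the custom first stage.
    middle_rank = pp_rank - (first_stage_layers is not None)
    return first + per_middle * middle_rank, per_middle
```

With the Moonlight-16B-style numbers from the PR (27 layers, PP=2, 13 layers on the first stage), stage 0 would own layers [0, 13) and stage 1 layers [13, 27).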

Context Parallelism

When using sample_packing, our Megatron worker pre- and post-processes packed sequences. When CP is enabled, each packed sequence is split into cp_size * 2 chunks, so each GPU gets 2 CP chunks of half the size (see NVIDIA/TransformerEngine#1368). To account for this extra chunking, the setup_per_microbatch_replay_forward method is updated so that the effective_seq_len is aligned to cp_size * 2 (matching the alignment in preprocess_packed_seqs in megatron_utils.py) and the per-chunk length is seqlen_per_cp // 2. We then index the front and back halves of these CP chunks from the aligned indices across the CP ranks and concatenate them. This ensures that the router replay indices see the correct tokens under Megatron's CP chunking.
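The front/back chunk selection can be illustrated with a standalone sketch (an assumed helper, not the PR's code), following TransformerEngine's load-balanced scheme in which CP rank i owns chunks i and 2 * cp_size - 1 - i:

```python
import torch

# Illustrative sketch of the cp_size * 2 chunk selection described above.
def select_cp_chunk_indices(seq_len, cp_size, cp_rank):
    """Return the token indices owned by cp_rank for a sequence of seq_len."""
    assert seq_len % (2 * cp_size) == 0, "seq_len must align to cp_size * 2"
    chunk = seq_len // (2 * cp_size)  # half of seqlen_per_cp
    idx = torch.arange(seq_len)
    # Front chunk: the i-th chunk from the start.
    front = idx[cp_rank * chunk:(cp_rank + 1) * chunk]
    # Back chunk: the mirrored chunk from the end, for load balancing.
    back_chunk = 2 * cp_size - 1 - cp_rank
    back = idx[back_chunk * chunk:(back_chunk + 1) * chunk]
    # Concatenate the front and back halves, as described above.
    return torch.cat([front, back])
```

For example, with seq_len=8 and cp_size=2, rank 0 owns tokens [0, 1, 6, 7] and rank 1 owns tokens [2, 3, 4, 5].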

Testing

You can test with CP and/or PP configs from the test_router_replay file.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces comprehensive support for pipeline and context parallelism for R3 (Router Replay). The changes involve propagating "rollout_expert_indices" through the inference engine outputs, generator outputs, and training batches. A new utility file, "replay_utils.py", has been added to manage the complex logic of patching Megatron components and aligning these indices across various parallelism configurations (CP, TP, PP). The implementation appears robust, handling scenarios like dense-layer mismatches and uneven pipeline partitioning. Configuration options have been updated to enable this feature, and new tests have been added to validate its correctness and impact on log probabilities and training loss.

Comment on lines +337 to +343
# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
    logger.info(
        f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
    )
else:
    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")


Severity: medium

These DEBUG log messages are helpful during development but should ideally be removed or made configurable (e.g., tied to a verbose flag) before merging to production. Excessive logging can clutter output and potentially impact performance.

Suggested change
# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
    logger.info(
        f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
    )
else:
    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
# Log if enable_return_routed_experts is being passed
# Consider making this logging configurable or removing it for production builds.
# if "enable_return_routed_experts" in kwargs:
#     logger.info(
#         f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
#     )
# else:
#     logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")

Contributor

@devin-ai-integration devin-ai-integration Bot left a comment


Devin Review found 1 potential issue.


Comment on lines +171 to +172
if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
    rollout_expert_indices = None  # hack: assume uniform sampling params


🔴 Wrong comparison operator (== 0 instead of > 0) prevents rollout_expert_indices from being set to None

In _postprocess_outputs, line 171 uses len(rollout_expert_indices) == 0 but should use len(rollout_expert_indices) > 0 (or truthiness, like the logprobs check on line 168). With == 0: (1) if the list is empty, rollout_expert_indices[0] raises an IndexError; (2) if the list is non-empty (the normal case), the condition is always False, so a list of all None values (e.g. [None, None, ...]) is never collapsed to None. This means when enable_return_routed_experts is disabled (the default), downstream code in inference_engine_client.py:169-171 sees a truthy list of Nones, sets add_rollout_expert_indices = True, and propagates [None, None, ...] instead of None through the pipeline.

Comparison with correct pattern on line 168

Line 168 (correct): if len(response_logprobs) and response_logprobs[0] is None:
Line 171 (broken): if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
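The difference between the two conditions can be demonstrated with a minimal standalone snippet (hypothetical helper functions, not the PR's code):

```python
# Broken variant: `== 0` means a non-empty all-None list is never collapsed,
# and an empty list would raise IndexError on the second operand.
def collapse_broken(rollout_expert_indices):
    if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
        rollout_expert_indices = None
    return rollout_expert_indices

# Fixed variant: truthiness check, matching the response_logprobs pattern.
def collapse_fixed(rollout_expert_indices):
    if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
        rollout_expert_indices = None
    return rollout_expert_indices

collapse_broken([None, None])  # -> [None, None]: the all-None list leaks through
collapse_fixed([None, None])   # -> None, as intended
collapse_fixed([])             # -> [] without raising IndexError
```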

Suggested change
if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
    rollout_expert_indices = None  # hack: assume uniform sampling params
if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
    rollout_expert_indices = None  # hack: assume uniform sampling params

@erictang000 erictang000 changed the title R3 pp [train][2/N] Support for Megatron PP + CP for R3 Mar 17, 2026
@erictang000
Collaborator

closing in favor of newer PR
