[train][2/N] Support for Megatron PP + CP for R3 #1327
devpatelio wants to merge 19 commits into main from
Conversation
Code Review
This pull request introduces comprehensive support for pipeline and context parallelism for R3 (Router Replay). The changes involve propagating "rollout_expert_indices" through the inference engine outputs, generator outputs, and training batches. A new utility file, "replay_utils.py", has been added to manage the complex logic of patching Megatron components and aligning these indices across various parallelism configurations (CP, TP, PP). The implementation appears robust, handling scenarios like dense-layer mismatches and uneven pipeline partitioning. Configuration options have been updated to enable this feature, and new tests have been added to validate its correctness and impact on log probabilities and training loss.
```python
# Log if enable_return_routed_experts is being passed
if "enable_return_routed_experts" in kwargs:
    logger.info(
        f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
    )
else:
    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
```
These DEBUG log messages are helpful during development but should ideally be removed or made configurable (e.g., tied to a verbose flag) before merging to production. Excessive logging can clutter output and potentially impact performance.
Suggested change:

```diff
 # Log if enable_return_routed_experts is being passed
-if "enable_return_routed_experts" in kwargs:
-    logger.info(
-        f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
-    )
-else:
-    logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
+# Consider making this logging configurable or removing it for production builds.
+# if "enable_return_routed_experts" in kwargs:
+#     logger.info(
+#         f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
+#     )
+# else:
+#     logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")
```
```python
if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
    rollout_expert_indices = None  # hack: assume uniform sampling params
```
🔴 Wrong comparison operator (== 0 instead of > 0) prevents rollout_expert_indices from being set to None
In _postprocess_outputs, line 171 uses len(rollout_expert_indices) == 0 but should use len(rollout_expert_indices) > 0 (or truthiness, like the logprobs check on line 168). With == 0: (1) if the list is empty, rollout_expert_indices[0] raises an IndexError; (2) if the list is non-empty (the normal case), the condition is always False, so a list of all None values (e.g. [None, None, ...]) is never collapsed to None. This means when enable_return_routed_experts is disabled (the default), downstream code in inference_engine_client.py:169-171 sees a truthy list of Nones, sets add_rollout_expert_indices = True, and propagates [None, None, ...] instead of None through the pipeline.
Comparison with correct pattern on line 168
Line 168 (correct): if len(response_logprobs) and response_logprobs[0] is None:
Line 171 (broken): if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
Suggested change:

```diff
-if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
+if len(rollout_expert_indices) and rollout_expert_indices[0] is None:
     rollout_expert_indices = None  # hack: assume uniform sampling params
```
closing in favor of newer PR
Summary
Extending #1273, this PR provides support for pipeline parallelism and context parallelism for R3. See #815 for tracking of future tasks to fully support routing replay in all settings.
Implementation
Pipeline Parallelism
For pipeline parallelism, we create a helper function
_get_current_pp_stage_layer_range(model_config), which maps the current PP rank and its layers to a global layer offset across all of the model's layers, so that we can use this offset to select the corresponding replay instances from RouterReplay.global_router_replay_instances.

First, we get the number of pipeline stages from the PP world size, along with the total number of model layers. For models containing dense layers or unequal pipeline stages, Megatron supports setting a custom number of layers for the first and last PP rank. We capture these values from the model config and check whether the remaining layers can be evenly distributed across the remaining PP ranks. Finally, we return the transformer-layer range owned by the current PP rank as (s_p, n_p), where:
For an even partition with $L$ total layers and $P$ pipeline stages, each stage owns $n_p = L / P$ layers and stage $p$ starts at layer $s_p = p \cdot n_p$.
For uneven partitioning, if the first and/or last stages are assigned custom layer counts, we subtract those from $L$, split the remaining layers evenly among the remaining stages, and then shift the start index accordingly. This lets us support cases like the Moonlight-16B model, which has 27 layers, where we can pass num_layers_in_first_pipeline_stage as 13 for PP=2.
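Below is a minimal sketch of this layer-range computation, assuming Megatron's parallel-state helpers and the num_layers / num_layers_in_first_pipeline_stage / num_layers_in_last_pipeline_stage config fields; it is illustrative rather than the PR's exact code.

```python
# Illustrative sketch of mapping the current PP rank to its global transformer-layer range.
from megatron.core import parallel_state


def get_pp_stage_layer_range(model_config):
    """Return (start, count): the global layer range owned by this PP rank."""
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
    pp_rank = parallel_state.get_pipeline_model_parallel_rank()
    total_layers = model_config.num_layers

    first = getattr(model_config, "num_layers_in_first_pipeline_stage", None) or 0
    last = getattr(model_config, "num_layers_in_last_pipeline_stage", None) or 0

    if first == 0 and last == 0:
        # Even partition: every stage owns L / P layers.
        assert total_layers % pp_size == 0
        n_p = total_layers // pp_size
        return pp_rank * n_p, n_p

    # Uneven partition: first/last stages keep their custom sizes, the rest split evenly.
    if pp_rank == 0 and first > 0:
        return 0, first
    if pp_rank == pp_size - 1 and last > 0:
        return total_layers - last, last

    middle_stages = pp_size - (first > 0) - (last > 0)
    remaining = total_layers - first - last
    assert remaining % middle_stages == 0
    per_middle = remaining // middle_stages
    middle_rank = pp_rank - (1 if first > 0 else 0)
    return first + middle_rank * per_middle, per_middle
```

For the Moonlight-16B example (27 layers, PP=2, num_layers_in_first_pipeline_stage=13), this returns layers [0, 13) for rank 0 and [13, 27) for rank 1.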
Context Parallelism

When using sample_packing, our Megatron worker pre-processes and post-processes packed sequences. When CP is enabled, each packed sequence is split into cp_size * 2 chunks, so each GPU effectively gets 2 chunks of half the size; see NVIDIA/TransformerEngine#1368. To account for this extra chunking, the setup_per_microbatch_replay_forward method is updated so that effective_seq_len is aligned to cp_size * 2 (matching the alignment in preprocess_packed_seqs in megatron_utils.py) and each of a rank's two chunks has length seqlen_per_cp // 2. We then index the front and back halves of these CP chunks from the aligned indices across the CP ranks and concatenate them. This ensures the router replay indices see the correct tokens under Megatron's CP chunking.
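As a rough illustration of this front/back-half selection (the helper name and exact indexing here are assumptions, not the PR's code):

```python
# With sample packing + CP, each padded sequence is split into cp_size * 2 chunks and
# CP rank r owns chunk r (front half) and chunk 2 * cp_size - 1 - r (back half),
# matching TransformerEngine's load-balanced causal-attention layout.
import torch


def select_cp_chunks_for_rank(aligned_indices: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Pick the two half-chunks owned by cp_rank from indices padded to a multiple of cp_size * 2."""
    effective_seq_len = aligned_indices.shape[0]
    assert effective_seq_len % (cp_size * 2) == 0, "indices must be padded to a multiple of cp_size * 2"

    seqlen_per_cp = effective_seq_len // cp_size  # tokens owned by one CP rank
    half = seqlen_per_cp // 2                     # size of each of the rank's two chunks

    front = aligned_indices[cp_rank * half : (cp_rank + 1) * half]
    back_chunk = 2 * cp_size - 1 - cp_rank
    back = aligned_indices[back_chunk * half : (back_chunk + 1) * half]
    return torch.cat([front, back], dim=0)
```

For cp_size=2, rank 0 ends up with chunks 0 and 3 while rank 1 gets chunks 1 and 2.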
Testing

You can test with the CP and/or PP configs in the test_router_replay file.