Enable AG/RS overlap with explicit process group passing#3249
jeffnvidia wants to merge 4 commits into NVIDIA:main
Conversation
Hi, @jeffnvidia could you resolve the merge conflict, please?
Force-pushed: cfdfe8b to f1aa8e4
Hi, yes, resolved
/ok to test f1aa8e4
Force-pushed: 58c2d96 to 1757e42
shjwudp left a comment:
Looks good to me for the Megatron-FSDP part.
megatron/training/arguments.py (Outdated)

```python
group.add_argument('--fsdp-manual-registration', action='store_true', dest='fsdp_manual_registration',
                   default=False, help='Manually register the FSDP communication buffers to NCCL user buffer.'
                   'This option is only effective when use-megatron-fsdp and use-nccl-ub is set.')
group.add_argument('--use-sharp', action='store_true',
```
These three arguments should be initialized from DistributedInitConfig.
For reference, see this Megatron-LM commit: adce147#diff-42fa19ec8893eabf951a5bb21edfc0dfe7c9c8949d5087ab133f5897ea0e3213R2203
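The pattern the reviewer suggests, sourcing argparse defaults from a config object instead of hardcoding them, might look like this minimal sketch. The `DistributedInitConfig` class and its field names here are hypothetical stand-ins, not Megatron-LM's actual definition.

```python
import argparse
from dataclasses import dataclass

@dataclass
class DistributedInitConfig:
    # Hypothetical fields; the real config class lives in Megatron-LM.
    fsdp_manual_registration: bool = False
    use_sharp: bool = False

def add_distributed_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Defaults come from one instance of the config dataclass, so the CLI
    # and the programmatic config can never silently disagree.
    defaults = DistributedInitConfig()
    group = parser.add_argument_group('distributed')
    group.add_argument('--fsdp-manual-registration', action='store_true',
                       default=defaults.fsdp_manual_registration,
                       help='Manually register the FSDP communication buffers '
                            'to the NCCL user buffer.')
    group.add_argument('--use-sharp', action='store_true',
                       default=defaults.use_sharp)
    return parser

parser = add_distributed_args(argparse.ArgumentParser())
args = parser.parse_args([])
```

With no flags passed, `args` reproduces the dataclass defaults; passing `--use-sharp` flips only that field.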
not sure why I'm suddenly having a secrets detector error, I just deleted lines I myself added
thanks for the approval, can people from @NVIDIA/mcore-oncall @NVIDIA/core-nemo review it? Thanks a lot
Force-pushed: f59e385 to 98eecc7
hey @cspades, I edited the PR according to the commit, can you review it? Thanks a lot
Force-pushed: 98eecc7 to 00cbcf2
cspades left a comment:
LGTM, thanks for jumping through some hoops on this one.
TODO: Benchmark and migrate to pg_collection.
megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
/ok to test f9b5b47
Signed-off-by: jeffnvidia <jmahou@nvidia.com>
Force-pushed: f9b5b47 to d1c1c62
/ok to test d1c1c62
/ok to test 56ecd09
```python
                         'to overlap reduce-scatter and all-gather operations.')
                    help='Enable AG/RS overlap optimization by creating separate '
                         'all-gather communicators.')
group.add_argument('--megatron-fsdp-pg-collection', action='store_true',
```
Why are we exposing this knob supporting both behaviors? Is there something stopping us that I'm missing?
+1, this seems like an implementation detail, not something the user should have to worry about.
Context here should explain the status quo: #3249 (comment)
I don't like this either.
/claude review
```python
decoder_rank_gen = RankGenerator(
    tp=tp_size, ep=1, dp=dp_size, pp=pp_size, cp=cp_size, order='tp-cp-ep-dp-pp', rank_offset=0
)
```
Bug: order and rank_offset are hardcoded here, but initialize_model_parallel() accepts them as parameters. If a user passes a different order (e.g., 'tp-dp-pp-cp-ep') or non-zero rank_offset, this RankGenerator will produce different rank lists than the actual DP groups, silently creating AG groups with wrong membership.
The same issue applies to the expert RankGenerator below (line 1423-1431).
A simpler and more robust approach would be to get the ranks directly from the already-created groups (which is what the tests in this PR already do):
```python
# Regular DP AG group
dp_cp_group = get_data_parallel_group(with_context_parallel=True)
all_dp_cp_ranks = get_data_parallel_group_ranks(with_context_parallel=True)
dp_cp_ag_group = None
for ranks_with_cp in all_dp_cp_ranks:  # or iterate all groups
    group_with_cp_ag = create_group(
        ranks_with_cp,
        timeout=timeout,
        pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs or {}),
        group_desc='DATA_PARALLEL_GROUP_WITH_CP_AG',
    )
    if rank in ranks_with_cp:
        dp_cp_ag_group = group_with_cp_ag
```

Or even simpler: since create_group is a collective, you could collect the ranks from _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP (the global that stores all dp-cp rank lists). This avoids re-deriving ranks entirely and guarantees consistency with the initialized state.
Overall the approach of moving from global state to explicit ProcessGroupCollection passing looks good. One bug flagged inline:
create_all_gather_groups() hardcodes order and rank_offset — the RankGenerator in this function uses order='tp-cp-ep-dp-pp' and rank_offset=0, but initialize_model_parallel() accepts these as configurable parameters. For non-default configurations, the AG groups will silently get wrong rank membership. The fix is straightforward: retrieve ranks from the already-initialized groups instead of re-deriving them.
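A self-contained toy (not Megatron's actual `RankGenerator`) illustrates why the hardcoded order matters: grouping ranks into DP groups depends on which axis varies fastest in the grid layout, so re-deriving groups with a fixed `order` can disagree with the order the user actually initialized with.

```python
import itertools

def dp_groups(tp: int, dp: int, pp: int, order: str) -> list[list[int]]:
    """Lay ranks on a (tp, dp, pp) grid, first-listed axis varying fastest,
    and collect the DP group for each fixed (tp, pp) coordinate."""
    sizes = {'tp': tp, 'dp': dp, 'pp': pp}
    axes = order.split('-')
    groups: dict[tuple, list[int]] = {}
    # itertools.product varies its last argument fastest, so reverse the axes.
    for rank, coords in enumerate(
        itertools.product(*(range(sizes[ax]) for ax in reversed(axes)))
    ):
        coord = dict(zip(reversed(axes), coords))
        key = (coord['tp'], coord['pp'])
        groups.setdefault(key, []).append(rank)
    return list(groups.values())

a = dp_groups(tp=2, dp=2, pp=1, order='tp-dp-pp')
b = dp_groups(tp=2, dp=2, pp=1, order='dp-tp-pp')
print(a, b)   # [[0, 2], [1, 3]] vs [[0, 1], [2, 3]]
print(a == b) # False: same world size, different group membership
```

With 4 ranks, the two orders produce disjoint DP group memberships, which is exactly the silent mismatch the review flags when the AG groups are rebuilt with a hardcoded order.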
TODO for me to benchmark this PR on PG collection backend, and fully deprecate the global parallel_state for Megatron-FSDP. cc @shjwudp |
What does this PR do?
This PR enables all-gather (AG) / reduce-scatter (RS) communication overlap for both regular data parallelism and Expert Parallelism (MoE models) by migrating from global process group management to explicit argument passing via `ProcessGroupCollection`.

Motivation
Problem: PR #2663 (merged 2026-01-27) implemented AG/RS overlap for regular DP using global state in `parallel_state.py`, but:
Solution: This PR refactors the implementation to use explicit process group passing while adding MoE support.
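The explicit-passing pattern described above might look like this minimal sketch. The field names follow the PR text, but the class body and `build_buffers` helper are illustrative, not the actual Megatron-Core code.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ProcessGroupCollection:
    # Illustrative subset of fields; real collection holds many more groups.
    dp_cp: Optional[Any] = None      # regular data-parallel (+CP) group
    dp_cp_ag: Optional[Any] = None   # separate all-gather group (opt-in)
    expt_dp_ag: Optional[Any] = None # expert-DP all-gather group (opt-in)

def build_buffers(pg_collection: ProcessGroupCollection):
    # Explicit data flow: the AG group arrives as an argument; nothing is
    # fetched from parallel_state globals. Falling back to the main group
    # preserves the old (non-overlapped) behavior when the field is None.
    ag_group = pg_collection.dp_cp_ag or pg_collection.dp_cp
    return ag_group

pgs = ProcessGroupCollection(dp_cp="dp_group")
print(build_buffers(pgs))  # dp_group: defaults unchanged without opt-in
```

Because the AG fields default to `None`, callers that never create separate AG groups see identical behavior to before the refactor.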
Key Changes
1. ProcessGroupCollection Extension (`process_groups_config.py`)
   - `dp_cp_ag` field for regular data parallel all-gather group
   - `expt_dp_ag` field for expert data parallel all-gather group (NEW for MoE)
   - Default to `None`; users must create them explicitly (opt-in feature)
2. FSDPDistributedIndex Refactor (`utils.py`)
   - No longer reads `parallel_state` globals
   - `fsdp_group_ag` and `expt_fsdp_group_ag` as explicit constructor parameters
   - `get_fsdp_group(..., independent_all_gather=True)` returns the appropriate AG group based on parameter type
3. Explicit Group Extraction (`mcore_fsdp_adapter.py`)
   - AG groups extracted from `pg_collection` using `getattr()` and passed to `FSDPDistributedIndex`
4. Expert Parameter Support (`param_and_grad_buffer.py`)
   - `group.is_expert_param` flag
5. Cleanup of PR #2663 Global State
   Removed:
   - `_DATA_PARALLEL_GROUP_WITH_CP_AG` global variable
   - `has_separate_all_gather_group()` function
   - `independent_all_gather` parameter from `get_data_parallel_group()`
   - `create_all_gather_group` parameter from `initialize_model_parallel()`
   - `--create-all-gather-group` CLI argument

Benefits
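The selection logic described for `get_fsdp_group` could be sketched as below. Names mirror the PR summary; the real `FSDPDistributedIndex` in Megatron-FSDP differs in detail, so treat this as a hypothetical illustration only.

```python
from typing import Any, Optional

class FSDPDistributedIndex:
    """Toy sketch: AG groups arrive as explicit constructor parameters."""

    def __init__(self, fsdp_group: Any,
                 fsdp_group_ag: Optional[Any] = None,
                 expt_fsdp_group_ag: Optional[Any] = None):
        self.fsdp_group = fsdp_group
        self.fsdp_group_ag = fsdp_group_ag
        self.expt_fsdp_group_ag = expt_fsdp_group_ag

    def get_fsdp_group(self, is_expert_param: bool = False,
                       independent_all_gather: bool = False):
        # Pick the AG group matching the parameter type; fall back to the
        # main FSDP group when no separate AG group was created (opt-in).
        if independent_all_gather:
            ag = self.expt_fsdp_group_ag if is_expert_param else self.fsdp_group_ag
            if ag is not None:
                return ag
        return self.fsdp_group

idx = FSDPDistributedIndex('dp', fsdp_group_ag='dp_ag', expt_fsdp_group_ag='edp_ag')
```

Here expert parameters route to the expert-DP AG group and everything else to the regular AG group, while callers that never opt in keep the single shared group.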
✅ Clean Architecture: No new globals in `parallel_state.py`
✅ Explicit Data Flow: Process groups passed as arguments, not accessed globally
✅ Expert Parallelism Support: AG/RS overlap now works for MoE models
✅ Testability: Easier to mock and test with dependency injection
✅ Backward Compatible: Opt-in feature (defaults to `None`, same behavior as before)

Migration Guide for Users
Before (PR #1 - no longer supported):