add all_gather process-group for overlapping in fsdp distributed training #2663

Merged
shjwudp merged 5 commits into NVIDIA:main from jeffnvidia:separate_AG_RS_streams
Jan 27, 2026

Conversation

@jeffnvidia
Contributor

@jeffnvidia jeffnvidia commented Dec 15, 2025

What does this PR do?

This PR separates the all_gather process group from reduce-scatter and the other operations. The goal is to overlap these two collectives, which, when combined with SHARP, has been shown to greatly increase performance (~15%).
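A minimal sketch of the idea, assuming PyTorch with an initialized process group (function names here are hypothetical illustrations, not the PR's actual code): giving all-gather its own process group means NCCL uses a distinct communicator for it, so the weight all-gather and the gradient reduce-scatter can be in flight at the same time instead of serializing on one communicator.

```python
import torch
import torch.distributed as dist

def make_fsdp_groups(ranks):
    """Create two process groups over the same ranks: one dedicated to
    all-gather, one for reduce-scatter.  Distinct groups get distinct
    communicators (and internal streams), letting the two collectives
    overlap instead of queuing on a single communicator."""
    ag_group = dist.new_group(ranks=ranks)  # used only for all_gather
    rs_group = dist.new_group(ranks=ranks)  # used only for reduce_scatter
    return ag_group, rs_group

def overlapped_step(param_shard, grad_bucket, ag_group, rs_group, world_size):
    """Launch the weight all-gather and the gradient reduce-scatter
    concurrently, each on its own group, then wait on both."""
    gathered = torch.empty(param_shard.numel() * world_size,
                           dtype=param_shard.dtype, device=param_shard.device)
    ag_work = dist.all_gather_into_tensor(gathered, param_shard,
                                          group=ag_group, async_op=True)
    grad_shard = torch.empty(grad_bucket.numel() // world_size,
                             dtype=grad_bucket.dtype, device=grad_bucket.device)
    rs_work = dist.reduce_scatter_tensor(grad_shard, grad_bucket,
                                         group=rs_group, async_op=True)
    ag_work.wait()
    rs_work.wait()
    return gathered, grad_shard
```

With a single shared group, the second collective would wait for the first on the same communicator; the split is what makes the ~15% overlap win (with SHARP) possible.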

⚠️ For major changes (either in lines of code or in impact), please make sure to first share and discuss a design doc with the team.

Contribution process

```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@copy-pr-bot

copy-pr-bot bot commented Dec 15, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch from d924e10 to 5b61c81 on December 17, 2025 at 14:41
@jeffnvidia jeffnvidia marked this pull request as ready for review December 24, 2025 17:25
@jeffnvidia jeffnvidia requested review from a team as code owners December 24, 2025 17:25
@github-actions github-actions bot requested a review from Phlip79 December 24, 2025 17:25
@jeffnvidia
Contributor Author

I am getting this warning on the PR; any idea why?

[MCORE][MultiGroupMemPoolAllocator] Failed to deregister mem pool from <torch.distributed.distributed_c10d.ProcessGroup object at 0x400228ee6170> (DATA_PARALLEL_GROUP_WITH_CP_AG) group!!

Cursor said it's normal behavior, but I'm suspicious.

@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch 4 times, most recently from c5d7a3f to 859859e on January 4, 2026 at 15:51
Contributor

@shjwudp shjwudp left a comment


LGTM, thanks.


```python
all_gather_ops = []
if self.dist_index.get_fsdp_group(is_expert_parallel=False, all_gather=True) is not None:
    # All-gather group when overlapping
```
Contributor


Can we have a more accurate description? For example:
"All-gather group used when overlapping all-gather and gradient reduction."

Contributor Author


No problem, changed to your description.

Member

@youngeunkwon0405 youngeunkwon0405 left a comment


Logic-wise LGTM. I left some comments regarding variable names and comments.

```python
# All-gather the module weights in each buffer shard into the allocated bucket.
# Now each rank will have a copy of this FSDP unit module's weights.
if self.buffer.dist_index.get_fsdp_group(is_expert_parallel=False, all_gather=True) is not None:
    # All-gather group when overlapping
```
Member


Could you please write a more specific comment that anyone without the context could understand what you are trying to do?

Contributor Author


Sure, let me know if it's good now.

@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch from 859859e to 453f26a on January 7, 2026 at 15:17
@shengf-nv
Contributor

@jeffnvidia, thanks for creating this PR to add the all_gather process group. I have two comments:

  1. Only the PG for regular data parallelism has been split, not for EP data parallelism. Do you plan to add it in another PR?

  2. The PR decides which PG to use for all-gather based on the PG returned from `self.buffer.dist_index.get_fsdp_group(is_expert_parallel=False, all_gather=True)`. What is your plan when EP data parallelism also exists? My suggestion is: instead of picking the PG when all-gather is called, we can set the appropriate data-parallel process group for the `DataParallelBuffer` of `model_weight_buffer` in `_init_each_parameter_group_buffers`. Each parameter group has a flag to indicate if it is a regular weight or an EP weight. We can use that flag to choose regular DP or EP DP. Then we do not need to make any change when all-gather is called.

Let me know if it makes sense to you.

Thanks,
Sheng

@jeffnvidia
Contributor Author

Hey Sheng, thanks for your comments.

> 1. Only the PG for regular data parallelism has been split, not for EP data parallelism. Do you plan to add it in another PR?

I can add it in a separate PR; I don't think it would require much work. The reason I didn't do it now is that, as far as I know, we have never tested it yet. Once we move on to testing EP, I can create a new PR.

> 2. The PR decides which PG to use for all-gather based on the PG returned from `self.buffer.dist_index.get_fsdp_group(is_expert_parallel=False, all_gather=True)`. What is your plan when EP data parallelism also exists? My suggestion is: instead of picking the PG when all-gather is called, we can set the appropriate data-parallel process group for the `DataParallelBuffer` of `model_weight_buffer` in `_init_each_parameter_group_buffers`. Each parameter group has a flag to indicate if it is a regular weight or an EP weight. We can use that flag to choose regular DP or EP DP. Then we do not need to make any change when all-gather is called.

I've refactored per your recommendation:

  • PG selection now happens in `_init_each_parameter_group_buffers()` based on parameter-group flags
  • All-gather call sites are simplified to use the pre-configured group

Much cleaner and easier to extend, thanks!

Let me know if it fits what you had in mind.
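The refactor described above can be sketched roughly as follows (a toy stand-in with hypothetical names; the real Megatron-FSDP `FSDPDistributedIndex` and buffer classes differ): the process group is chosen once at buffer-init time from the parameter group's expert flag, so all-gather call sites never branch.

```python
class DistributedIndex:
    """Toy stand-in for an FSDP distributed index: maps
    (is_expert_parallel, all_gather) to a process-group handle.
    Handles are plain strings here for illustration."""
    def __init__(self, has_ag_group=True):
        self.groups = {
            (False, False): "dp_group",
            (True,  False): "expert_dp_group",
        }
        if has_ag_group:
            # Dedicated all-gather group exists only for regular DP in this POC.
            self.groups[(False, True)] = "dp_group_ag"

    def get_fsdp_group(self, is_expert_parallel, all_gather=False):
        # Fall back to the regular group when no dedicated AG group exists.
        key = (is_expert_parallel, all_gather)
        return self.groups.get(key, self.groups[(is_expert_parallel, False)])

def init_buffer_group(dist_index, is_expert_param):
    """Pick the all-gather group once, at buffer-init time, based on the
    parameter group's expert flag, so the all-gather call sites can just
    use the pre-configured group."""
    return dist_index.get_fsdp_group(is_expert_param, all_gather=True)
```

For example, `init_buffer_group(DistributedIndex(), False)` yields the dedicated `dp_group_ag`, while an expert parameter group falls back to `expert_dp_group` until the EP split lands in a follow-up PR.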

@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch from dff91d4 to d6dd2a6 on January 8, 2026 at 13:09
@youngeunkwon0405
Member

I just realized that if EP is enabled, there will be a different DP group for MoE layers. I think if you run this with EP, it will not show the expected behavior. I think you have to modify the code for this, or at least insert an assertion about this.

@shjwudp, please correct me if I am wrong.

@jeffnvidia
Contributor Author

> I just realized that if EP is enabled, there will be a different DP group for MoE layers. I think if you run this with EP, it will not show the expected behavior. I think you have to modify the code for this, or at least insert an assertion about this.
>
> @shjwudp, please correct me if I am wrong.

In the code, I am checking whether expert parallelism is on and only applying the changes if it is not, which is the scope of the POC.

@shengf-nv
Contributor

@jeffnvidia, thanks for the update. The updated code looks good to me.

It is your call whether you want another PR to support EP. However, most recent models all have EP. I already have a Llama4 setup, and I plan to try overlapped AG+RS with Llama4 soon.

@youngeunkwon0405
Member

> > I just realized that if EP is enabled, there will be a different DP group for MoE layers. I think if you run this with EP, it will not show the expected behavior. I think you have to modify the code for this, or at least insert an assertion about this.
> > @shjwudp, please correct me if I am wrong.
>
> in the code, I am checking if expert parallelism is on and only doing the changes if not which is the scope of the POC

I don't see the code part that checks `if expert_parallel` and makes a decision. Can you point out which part of your code is doing that?

@jeffnvidia
Contributor Author

jeffnvidia commented Jan 11, 2026

> > > I just realized that if EP is enabled, there will be a different DP group for MoE layers. I think if you run this with EP, it will not show the expected behavior. I think you have to modify the code for this, or at least insert an assertion about this.
> > > @shjwudp, please correct me if I am wrong.
> >
> > in the code, I am checking if expert parallelism is on and only doing the changes if not which is the scope of the POC
>
> I don't see the code part that checks if expert_parallel and makes a decision. Can you point out which part of your code is doing that?

It's line 1890: https://github.com/NVIDIA/Megatron-LM/pull/2663/files#diff-da62f73a7a6a4ac7815ed316a147ba348a7915e35a2f4885ceaf1678e5f650fbR1890

If it is expert parallelism, it won't even try to create the separate group.
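The guard described above amounts to the following sketch (hypothetical names, not the PR's literal code at line 1890): with expert parallelism enabled, the dedicated all-gather group is simply never created, and callers fall back to the regular group.

```python
def maybe_create_all_gather_group(expert_parallel_enabled, create_group):
    """Only split out a dedicated all-gather process group when expert
    parallelism is off; with EP enabled, MoE layers use a different DP
    group, and the split is out of scope for this POC."""
    if expert_parallel_enabled:
        return None  # callers fall back to the regular DP group
    return create_group()
```

Later, call sites can test the returned handle for `None`, matching the `get_fsdp_group(..., all_gather=True) is not None` checks shown in the review snippets.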

@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch from b8dcd85 to 852aa57 on January 22, 2026 at 14:52
@jeffnvidia jeffnvidia force-pushed the separate_AG_RS_streams branch from 852aa57 to caf8de5 on January 22, 2026 at 14:58
@jeffnvidia
Contributor Author

I rebased on main to fix the merge error. @youngeunkwon0405, could you re-run the CI? Thanks a lot.

Could you approve the PR, @jaredcasper, as we discussed? I'll start working immediately on a follow-up PR that puts everything in order. Thanks a lot.

@youngeunkwon0405
Member

/ok to test caf8de5

@jeffnvidia
Contributor Author

> /ok to test caf8de5

Thanks Youngeun, we're hitting the same CI/CD bug again; it seems to be random, and I don't know why or whether it's a blocker.

@youngeunkwon0405
Member

Got it, I will reinitiate the failed runs.

@jeffnvidia
Contributor Author

Hi guys, who do I need to get approval from to finally merge this PR?

@ericharper @jaredcasper @NVIDIA/core-adlr ?

@cspades
Member

cspades commented Jan 25, 2026

Needs an approval from @NVIDIA/core-adlr or else it won't merge. It looks like concerns pertaining to MCore PG management have not been resolved.

@cspades
Member

cspades commented Jan 25, 2026

M-FSDP is using pg_collection (mcore_fsdp_adapter.py#L193). However, as I understand it, the current MCore pretraining still initializes through the global parallel state (training.py#L1156), right?

And just to chime in here, FSDPDistributedIndex is a one-stop shop to hack in PG groups by argument, so really there shouldn't be any blocker for a fully-argument-based PG management system. And add the same args to fully_shard and you're good to go.

@shjwudp shjwudp added this pull request to the merge queue Jan 26, 2026
Merged via the queue into NVIDIA:main with commit 528cb2e Jan 27, 2026
75 of 83 checks passed
daiyaanarfeen pushed a commit to daiyaanarfeen/Megatron-LM that referenced this pull request Feb 23, 2026