Integrate DeepEP to experimental torchtitan #2107
Conversation
Hi @elfiegg! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Force-pushed from f33b3d9 to a5875e5
tianyu-l left a comment:
Thanks for contributing!

> Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what the baseline is?

Btw, I think that to fully utilize the power of DeepEP, we also need node-limited routing, which the current torchtitan DSv3 model doesn't have.
@shuhuayu let's add it? We can refer to the HF or original DeepSeek impl.
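For reference, a rough sketch of group-limited ("node-limited") routing in the DeepSeek style: experts are split into groups (typically one per node) and only the top-scoring groups are kept before the usual per-token top-k. The names and the per-group score below are illustrative, not the HF/DeepSeek code.

```python
# Illustrative sketch of node-limited (group-limited) routing; not the actual
# HF/DeepSeek implementation. Assumes num_experts is divisible by n_groups.
import torch


def node_limited_topk(scores: torch.Tensor, top_k: int, n_groups: int, topk_groups: int):
    # scores: [num_tokens, num_experts] router outputs (e.g. softmax/sigmoid affinities)
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // n_groups
    # Score each group, e.g. by its best expert (DeepSeek-V3 uses the sum of the top-2).
    group_scores = scores.view(num_tokens, n_groups, experts_per_group).amax(dim=-1)
    top_group_idx = group_scores.topk(topk_groups, dim=-1).indices
    # Mask out experts outside the selected groups, then do the usual per-token top-k.
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(num_tokens, num_experts)
    masked_scores = scores.masked_fill(expert_mask == 0, float("-inf"))
    return masked_scores.topk(top_k, dim=-1)  # (top-k scores, top-k expert indices)
```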
Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g. build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.
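A minimal sketch of what such a factory could look like; the build_moe name follows the suggestion above, but the module paths, class names, and constructor signature are assumptions, not torchtitan's actual API.

```python
# Hypothetical factory sketch; the DeepEPMoE / MoE module paths and signatures are assumed.
def build_moe(moe_impl: str, model_args):
    if moe_impl == "deep_ep":
        # Import lazily so the default path works without DeepEP installed.
        from .deep_ep_moe import DeepEPMoE  # assumed module and class name
        return DeepEPMoE(model_args)
    if moe_impl == "default":
        from .moe import MoE  # assumed module and class name
        return MoE(model_args)
    raise ValueError(f"Unknown MoE implementation: {moe_impl!r}")
```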
Sure, that's a great idea - once I confirm this works for larger models and improves perf.
Regarding integrating directly into main - do we need to manage the DeepEP dependency at all, or do we leave it to users to install?
Good point. I prefer

> we leave it to the users to install

instead of bundling it by default. We can explicitly mention this in the try/except when we do the import.
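A minimal sketch of what that guarded import could look like; the error message and the install pointer are illustrative.

```python
# Sketch of an import guard; `deep_ep` is the upstream DeepEP Python module,
# but the surrounding message is illustrative.
try:
    import deep_ep  # noqa: F401
except ImportError as e:
    raise ImportError(
        "The DeepEP MoE path requires the DeepEP library, which torchtitan does not "
        "bundle. Please install it (see https://github.com/deepseek-ai/DeepEP) or "
        "switch back to the default MoE implementation."
    ) from e
```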
from torch.distributed.tensor import distribute_module
...
class DeepEPExpertParallel(ParallelStyle):
Is this used anywhere? I'm guessing this is not running e2e with torchtitan train.py, which is still WIP.
I think we can also handle the token communication in here -- the design would look cleaner and more aligned with the existing interface.
This is our implementation, and our DeepEP interface is similar to NVIDIA's version.
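A rough sketch, not the PR's actual code, of how DeepEPExpertParallel could route the token communication through distribute_module's input_fn/output_fn hooks; the two hook bodies are identity placeholders standing in for DeepEP's dispatch/combine plus token (un)permutation.

```python
# Rough sketch only: shows where DeepEP dispatch/combine could hook into
# distribute_module via input_fn/output_fn. The hook bodies are placeholders.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle


class DeepEPExpertParallel(ParallelStyle):
    def _token_dispatch(self, mod, inputs, device_mesh):
        # Placeholder: here the tokens would be permuted and sent to their
        # target experts with DeepEP's dispatch (all-to-all) call.
        return inputs

    def _token_combine(self, mod, outputs, device_mesh):
        # Placeholder: here the expert outputs would be combined and unpermuted
        # back to the original token order with DeepEP's combine call.
        return outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=None,  # expert weight sharding handled elsewhere in this sketch
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
```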
Thanks for the PR! I think we should support node-limited routing to make multi-node setups faster.
Added node-limited routing here: #2111. Perhaps it helps make DeepEP faster in multi-node setups.
Actually I reran this last night and the perf caught up - the lagging perf was gone once I enabled FSDP for the MoE layer (which I had disabled for debugging purposes). Running the command below, I got 13% MFU for both the baseline and the DeepEP version. For the baseline I referred to the config here:
Thanks for posting the work! We had a successful and performant DeepEP integration at AMD-AGI@59fe226. We borrowed some of the design from Megatron-LM, and we can use it here too. I don't see big differences between our DeepEP interface and yours. Let's work together on this. Feel free to reach out to me or Tianyu for future discussion and collaboration.
if self.score_before_experts:
    recv_x_sorted = (recv_x_sorted.to(torch.float32) * valid_weights_sorted.unsqueeze(-1)).to(recv_x_sorted.dtype)
The code before experts.forward should go into DeepEPExpertParallel as input_fn; same for the token unpermute after experts as output_fn.
Also consider using something like _indices_to_multihot_kernel (https://github.com/NVIDIA/Megatron-LM/blob/f5344166732f45bb0dd825dc875288ea97b15b47/megatron/core/fusions/fused_indices_converter.py#L32C5-L32C32) to preprocess the received DeepEP data.
You are using a lot of index-selecting here, which I suspect incurs significant CPU overhead (and lock/wait among CPU threads).
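For illustration, a plain-PyTorch stand-in for that indices-to-multihot preprocessing (the fused Triton kernel linked above does this on-device). The shapes and the -1 padding convention are assumptions about the dispatched data, not a description of DeepEP's actual output.

```python
# Assumed convention: `indices` is [num_tokens, topk] with local expert ids in
# [0, num_local_experts) and -1 for slots not routed to this rank; `probs`
# holds the matching routing weights. Vectorized scatter avoids per-token
# index_select calls on the CPU.
import torch


def indices_to_multihot(indices: torch.Tensor, probs: torch.Tensor, num_local_experts: int):
    num_tokens = indices.shape[0]
    valid = indices != -1
    # Map the -1 padding to expert 0 so the scatter stays in-bounds; its
    # contribution is zeroed out via the `valid` mask below.
    safe_idx = indices.clamp(min=0).long()
    counts = torch.zeros(num_tokens, num_local_experts, dtype=torch.int32, device=indices.device)
    counts.scatter_add_(1, safe_idx, valid.to(torch.int32))
    multihot = counts > 0  # [num_tokens, num_local_experts] bool routing mask
    multihot_probs = torch.zeros(num_tokens, num_local_experts, dtype=probs.dtype, device=probs.device)
    multihot_probs.scatter_add_(1, safe_idx, probs * valid)  # routing weights in multihot layout
    return multihot, multihot_probs
```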
Thanks all for the valuable advice! I'm currently occupied by a deadline, but I will take a closer look and join the discussion tomorrow.
Initial version to integrate DeepEP into torchtitan.
Currently:
(Edited from: Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU))
Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for internal support!