Conversation

siddharth9820 commented Jul 6, 2022

Note: This is in conjunction with this PR on the deepspeed repo.

This PR adds tensor parallelism for the non-expert parameters. Combined with ZeRO-2, this lets us scale to roughly 2x larger base models than ZeRO-2 alone. When tensor parallelism is enabled only for the non-experts, duplicate copies of each token arrive at every gate, and it is important to drop these duplicates before they reach the experts; otherwise we run into convergence issues. In megatron/mpu/mappings.py, we provide new autograd functions that drop these tokens and gather them back, namely _DropTokens and _AllGatherFromModelParallelRegion.
In the current implementation, tokens are dropped right before the AlltoAll and gathered right after it; these calls are made from the DeepSpeed codebase.
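
For readers skimming the PR, here is a minimal sketch of what such a conjugate drop/gather pair of autograd functions looks like. This is not the exact code in megatron/mpu/mappings.py; the function names, the `dim`/`group` arguments, and the helper functions are illustrative assumptions. The idea is that the drop keeps only this tensor-parallel rank's slice of the tokens and all-gathers the gradient, while the gather does the reverse.

```python
# Sketch only: a drop/gather pair over the tensor-parallel group.
import torch
import torch.distributed as dist


def _split_along_dim(x, dim, group):
    """Keep only this rank's contiguous chunk of `x` along `dim`."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    chunks = torch.chunk(x, world_size, dim=dim)
    return chunks[rank].contiguous()


def _gather_along_dim(x, dim, group):
    """All-gather every rank's chunk and concatenate along `dim`."""
    world_size = dist.get_world_size(group=group)
    tensors = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(tensors, x.contiguous(), group=group)
    return torch.cat(tensors, dim=dim)


class DropTokens(torch.autograd.Function):
    """Forward: drop the duplicate tokens (keep 1/tp_size of them).
    Backward: all-gather the gradient so every rank sees all token grads."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        return _split_along_dim(x, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_dim(grad_output, ctx.dim, ctx.group), None, None


class GatherTokens(torch.autograd.Function):
    """Forward: all-gather tokens across the tensor-parallel group.
    Backward: keep only this rank's slice of the gradient."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        return _gather_along_dim(x, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_dim(grad_output, ctx.dim, ctx.group), None, None
```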

Update: This PR now supports tensor parallelism for experts as well. This can be enabled by passing the --enable-expert-tensor-parallelism argument.
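
As a usage sketch, enabling this would look roughly like the command below. Only `--enable-expert-tensor-parallelism` comes from this PR; the script name, flag values, and remaining arguments are illustrative placeholders for a standard Megatron-DeepSpeed launch.

```bash
# Illustrative launch sketch; only --enable-expert-tensor-parallelism is added by this PR.
# --tensor-model-parallel-size and --num-experts are standard Megatron-DeepSpeed flags;
# the script name and the remaining arguments are placeholders.
deepspeed pretrain_gpt.py \
    --tensor-model-parallel-size 2 \
    --num-experts 8 \
    --enable-expert-tensor-parallelism \
    ...  # remaining model, data, and DeepSpeed config arguments
```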

siddharth9820 (Author) commented Jul 6, 2022

Comparing loss curves against the baseline with no tensor parallelism:

[image: loss curves, tensor parallelism vs. no tensor parallelism]

siddharth9820 changed the title from "Tensor parallelism for Non-Experts" to "Tensor parallelism for Mixture of Experts" on Jul 20, 2022
siddharth9820 (Author) commented Jul 20, 2022

Here are the loss curves when tensor parallelism is used for both the experts and the non-experts.
[image: loss curves with tensor parallelism for experts and non-experts]

Note: these runs were buggy and were fixed subsequently.

siddharth9820 requested a review from awan-10 on July 26, 2022 18:57
siddharth9820 merged commit 222d899 into main on Aug 1, 2022
siddharth9820 deleted the moe-tensor-parallelism branch on August 1, 2022 00:42
hyoo pushed a commit to hyoo/Megatron-DeepSpeed that referenced this pull request Apr 21, 2023
saforem2 added a commit to saforem2/Megatron-DeepSpeed that referenced this pull request Nov 15, 2024
[merge]: into `microsoft-main` $\leftarrow$ from `hzheng-data-fix`