
Conversation


@elfiegg commented Dec 4, 2025

Initial version to integrate DeepEP into torchtitan.
Currently:

  1. Only tested on the DeepSeek 16B model
  2. Perf remains similar to the baseline for the 16B model
    (Edited from: Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU))
  3. Haven't tested with DualPipeV pipeline scheduling yet

Shout out to my colleagues @gekurian @syed-ahmed @aazzolini for internal support!


meta-cla bot commented Dec 4, 2025

Hi @elfiegg!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@elfiegg force-pushed the loss_bug branch 6 times, most recently from f33b3d9 to a5875e5 on December 4, 2025 at 05:11
@tianyu-l (Contributor) left a comment


Thanks for contributing!

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what the baseline is?

Btw, I think to fully utilize the power of DeepEP, we also need to have node-limited routing, which the current torchtitan DSv3 model doesn't have.

@shuhuayu let's add it? we can refer to HF or deepseek original impl.
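For reference, a hedged sketch of DeepSeek-V3-style node-limited (group-limited) routing, where each token may only pick experts from its top-scoring groups (e.g. nodes); the function and argument names are illustrative, not the torchtitan or HF implementation:

```python
import torch

def node_limited_topk(scores: torch.Tensor, topk: int, n_groups: int, topk_groups: int):
    # Sketch of group-limited routing: each token may only select experts from
    # its `topk_groups` best groups (e.g. nodes), bounding cross-node traffic.
    # scores: [num_tokens, num_experts] router affinities; each group is assumed
    # to hold num_experts // n_groups experts (at least 2).
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, n_groups, num_experts // n_groups)
    # Group score = sum of the top-2 affinities inside the group (as in DeepSeek-V3).
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)       # [num_tokens, n_groups]
    group_idx = group_scores.topk(topk_groups, dim=-1).indices      # [num_tokens, topk_groups]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    return masked.topk(topk, dim=-1)                                # (values, expert indices)
```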

Contributor left a comment

Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g. build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.
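A minimal sketch of such a factory, assuming a small registry keyed by string; MoE, DeepEPMoE, and the constructor arguments are placeholders, not existing torchtitan classes:

```python
import torch.nn as nn

class MoE(nn.Module):
    """Placeholder for the existing all-to-all MoE implementation."""

class DeepEPMoE(nn.Module):
    """Placeholder for a DeepEP-backed MoE implementation."""

_MOE_VARIANTS: dict[str, type[nn.Module]] = {
    "default": MoE,
    "deep_ep": DeepEPMoE,
}

def build_moe(variant: str, *args, **kwargs) -> nn.Module:
    # Dispatch on a string so the model code stays agnostic to the MoE backend.
    try:
        cls = _MOE_VARIANTS[variant]
    except KeyError:
        raise ValueError(f"unknown MoE variant {variant!r}; expected one of {list(_MOE_VARIANTS)}")
    return cls(*args, **kwargs)

# e.g. moe = build_moe("deep_ep") inside the DeepSeek-V3 model builder
```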

@elfiegg (Author) left a comment

Sure, that's a great idea! I'll do that once I confirm this works for larger models and improves perf.

Regarding integrating directly into main: do we need to manage the DeepEP dependency at all, or do we leave it to users to install?

Contributor left a comment

Good point. I prefer

we leave it to the users to install

instead of bundling it by default. We can explicitly note this in a try/except when we do the import.
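A minimal sketch of that guard, assuming the package imports as deep_ep; the helper name and error message are illustrative:

```python
try:
    import deep_ep  # installed separately from https://github.com/deepseek-ai/DeepEP
    HAS_DEEP_EP = True
except ImportError:
    deep_ep = None
    HAS_DEEP_EP = False

def require_deep_ep() -> None:
    # Call this only on the DeepEP code path so the default path stays dependency-free.
    if not HAS_DEEP_EP:
        raise ImportError(
            "The DeepEP MoE variant requires the `deep_ep` package; "
            "please install DeepEP (https://github.com/deepseek-ai/DeepEP) manually."
        )
```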

from torch.distributed.tensor import distribute_module


class DeepEPExpertParallel(ParallelStyle):
Contributor left a comment

Is this used anywhere? I'm guessing that this is not running e2e with torchtitan train.py, which is still WIP.

@yuankaichen-amd left a comment

I think we can also handle the token communication in here -- the design would look cleaner and more aligned with the existing interface.

This is our implementation, and our DeepEP interface is similar to NVIDIA's version.

https://github.com/AMD-AGI/torchtitan-amd/blob/59fe226a46d76a553d4e591411a464390475be02/torchtitan/distributed/expert_parallel.py#L441
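A minimal sketch of that structure, assuming the input_fn/output_fn hooks of distribute_module; dispatch_tokens and combine_tokens below are placeholders standing in for the actual DeepEP dispatch/combine calls, not real APIs:

```python
import torch
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle

class DeepEPExpertParallel(ParallelStyle):
    def __init__(self, dispatch_tokens, combine_tokens):
        super().__init__()
        self.dispatch_tokens = dispatch_tokens  # e.g. wraps the DeepEP dispatch call
        self.combine_tokens = combine_tokens    # e.g. wraps the DeepEP combine call

    def _apply(self, module: torch.nn.Module, device_mesh) -> torch.nn.Module:
        # Ship tokens to their experts right before experts.forward, and
        # un-permute/combine the expert outputs right after it.
        return distribute_module(
            module,
            device_mesh,
            input_fn=lambda mod, inputs, mesh: self.dispatch_tokens(inputs, mesh),
            output_fn=lambda mod, outputs, mesh: self.combine_tokens(outputs, mesh),
        )
```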

Contributor left a comment

Thanks for the PR! I think we should support node-limited routing to make multi-node setups faster.

Contributor left a comment

Added node-limited routing here: #2111. Perhaps it helps make DeepEP faster in multi-node setups.

@elfiegg (Author) commented Dec 4, 2025

Perf is CPU bottlenecked and lags behind baseline by ~41% (~7.9% MFU vs. ~11.4% MFU)

May I ask what is the baseline?

Actually I reran this last night and the perf caught up: the lag was gone once I enabled FSDP for the MoE layers (which I had disabled for debugging). Running the command below, I got ~13% MFU for both the baseline and the DeepEP version:

torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$NPROC_PER_NODE \
    --rdzv_id=deepseek_16b_multinode \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    -m torchtitan.train \
    --parallelism.expert_parallel_degree 16 \
    --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

For the baseline, I used the config here: ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
The DeepEP version additionally overrides --model.name=deepep.deepseek_v3.

@yuankaichen-amd left a comment

Thanks for posting the work!

We had a successful and performant DeepEP integration at: AMD-AGI@59fe226

We borrowed some of the design from Megatron-LM, and we can reuse it here too.

I don't see big differences between our DeepEP interface and yours. Let's work together on this. Feel free to reach out to me or Tianyu for future discussion and collaboration.



if self.score_before_experts:
    recv_x_sorted = (recv_x_sorted.to(torch.float32) * valid_weights_sorted.unsqueeze(-1)).to(recv_x_sorted.dtype)


The code before experts.forward should go into DeepEPExpertParallel as input_fn, and likewise the token un-permute after the experts as output_fn.

Also consider using something like _indices_to_multihot_kernel (https://github.com/NVIDIA/Megatron-LM/blob/f5344166732f45bb0dd825dc875288ea97b15b47/megatron/core/fusions/fused_indices_converter.py#L32C5-L32C32) to preprocess the received DeepEP data.

You are using a lot of index-selecting here, which I suspect would incur significant CPU overhead (and lock/wait among CPU threads).
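For reference, a pure-PyTorch sketch of the indices-to-multihot conversion that the fused Megatron kernel performs (unfused, so only illustrative of the semantics); the shapes and the -1 padding convention are assumptions about the received DeepEP data:

```python
import torch

def indices_to_multihot(indices: torch.Tensor, probs: torch.Tensor, num_experts: int):
    # Assumed layout: indices/probs are [num_local_tokens, topk], with -1
    # marking padded (invalid) expert slots.
    num_tokens = indices.shape[0]
    multihot = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device=indices.device)
    multihot_probs = torch.zeros(num_tokens, num_experts, dtype=probs.dtype, device=probs.device)
    valid = indices >= 0
    token_ids = torch.arange(num_tokens, device=indices.device).unsqueeze(1).expand_as(indices)
    # One vectorized scatter instead of many index-select calls, which keeps the
    # CPU-side launch overhead small.
    multihot[token_ids[valid], indices[valid]] = True
    multihot_probs[token_ids[valid], indices[valid]] = probs[valid]
    return multihot, multihot_probs
```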

@elfiegg (Author) commented Dec 5, 2025

Thanks all for the valuable advice! I'm currently occupied by a deadline, but I will take a closer look and join the discussion tomorrow.
