
[4/n] DP Enhancement: Optimize communication when dp < tp by using all_gather_into_tensor and reduce_scatter_tensor #8279

Closed

ch-wan wants to merge 4 commits into gh/ch-wam/4/base from gh/ch-wam/4/head

Conversation


@ch-wan ch-wan commented Jul 23, 2025

Stack from ghstack (oldest at bottom):

Motivation and Modifications

#8278 padded the token size to a multiple of attn_tp_size. As a result, each DP rank's hidden states can be evenly scattered across its TP group. This enables the use of reduce_scatter_tensor and all_gather_into_tensor to optimize communication efficiency.
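The precondition matters because torch.distributed's fused collectives require every rank to contribute an equally sized shard. A minimal sketch of the padding idea, with an illustrative helper name (not the PR's actual code):

```python
import torch

def pad_tokens_for_tp(hidden_states: torch.Tensor, attn_tp_size: int) -> torch.Tensor:
    """Pad the token dimension to a multiple of attn_tp_size.

    Illustrative sketch of the padding introduced in #8278. After padding,
    hidden_states splits into attn_tp_size equal shards, which is exactly
    what all_gather_into_tensor / reduce_scatter_tensor require: the full
    tensor's numel must be world_size times each shard's numel.
    """
    num_tokens = hidden_states.shape[0]
    remainder = num_tokens % attn_tp_size
    if remainder:
        pad_len = attn_tp_size - remainder
        # F.pad pads the last dim first; (0, 0, 0, pad_len) pads dim 0 at the end.
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad_len))
    return hidden_states
```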

Benchmark

TP MoE

Benchmark command:

# launch
python3 -m sglang.launch_server --model-path /dev/shm/DeepSeek-V3-0324 --tp-size 8 --enable-dp-attention --dp-size 4 --enable-deepep-moe --trust-remote-code  --chunked-prefill-size 4096 --cuda-graph-max-bs 256 --max-running-requests 1024 --enable-dp-lm-head --moe-dense-tp-size 1 --mem-fraction-static 0.75 --disable-radix-cache
# benchmark
python3 -m sglang.bench_one_batch_server --model /dev/shm/DeepSeek-V3-0324 --base-url http://localhost:30000/ --batch-size 4096 --input-len 128 --output-len 32

Output throughput:

| Main | #8278 | #8279 |
| --- | --- | --- |
| 4154.71 | 4163.29 | 4261.29 |

DeepEP MoE

Benchmark command:

# launch
python3 -m sglang.launch_server --model-path /dev/shm/DeepSeek-V3-0324 --tp-size 8 --enable-dp-attention --dp-size 4 --enable-deepep-moe --trust-remote-code --chunked-prefill-size 4096 --cuda-graph-max-bs 128 --max-running-requests 512 --enable-dp-lm-head --moe-dense-tp-size 1 --mem-fraction-static 0.75 --disable-radix-cache
# benchmark
python3 -m sglang.bench_one_batch_server --model /dev/shm/DeepSeek-V3-0324 --base-url http://localhost:30000 --batch-size 4096 --input-len 128 --output-len 32

Output throughput:

| Main | #8278 | #8279 |
| --- | --- | --- |
| 2405.84 | 2397.82 | 2443.36 |

[ghstack-poisoned]
gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @ch-wan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes inter-process communication within the distributed inference framework, particularly for scenarios where data parallelism is less than tensor parallelism. It achieves this by adopting more efficient tensor-based collective communication primitives, which streamline data transfer and reduce overhead during all-gather and reduce-scatter operations, improving performance.

Highlights

  • Communication Optimization: Replaced list-based all_gather and reduce_scatter calls with new all_gather_into_tensor and reduce_scatter_tensor primitives. These new primitives operate directly on tensors, potentially offering more efficient data movement in distributed settings, especially for dp < tp configurations.
  • New Collective Communication APIs: Introduced attn_tp_reduce_scatter_tensor and attn_tp_all_gather_into_tensor functions in dp_attention.py. These functions provide a more direct and potentially optimized interface for tensor-parallel collective operations.
  • Conditional Logits Gathering: Implemented conditional logic in logits_processor.py to leverage the new attn_tp_all_gather_into_tensor for gathering logits. This optimization is applied when the model's vocabulary size is evenly divisible by the tensor parallelism size, ensuring efficient communication for common scenarios.
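The wrappers highlighted above can be sketched as thin shims over torch.distributed's fused collectives. The function names mirror the PR's new APIs, but the bodies below are illustrative, not the actual implementation; the single-rank shortcut is an assumption for this sketch:

```python
import torch
import torch.distributed as dist

def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor,
                                   group=None) -> None:
    # Gathers equal-sized shards from all ranks directly into a preallocated
    # output tensor, avoiding the per-rank tensor list of dist.all_gather.
    world_size = dist.get_world_size(group) if dist.is_initialized() else 1
    if world_size == 1:
        output.copy_(input)  # single-rank shortcut: nothing to gather
        return
    dist.all_gather_into_tensor(output, input, group=group)

def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor,
                                  group=None) -> None:
    # Sums `input` across ranks and scatters equal slices into `output`;
    # requires input.numel() == world_size * output.numel(), which the
    # padding from #8278 guarantees.
    world_size = dist.get_world_size(group) if dist.is_initialized() else 1
    if world_size == 1:
        output.copy_(input)
        return
    dist.reduce_scatter_tensor(output, input, group=group)
```

Both fused primitives operate on a single contiguous tensor instead of a Python list of per-rank tensors, which is what makes them cheaper than the list-based `all_gather`/`reduce_scatter` they replace.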

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request optimizes communication for dp < tp scenarios by replacing all_gather and reduce_scatter with their more efficient _tensor counterparts (all_gather_into_tensor and reduce_scatter_tensor). The changes are well-implemented, including a fallback mechanism in logits_processor.py for cases where the vocab size is not divisible by the tensor parallel size. I have one suggestion to improve code clarity in communicator.py.
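The fallback condition the review describes reduces to a divisibility check. A hedged sketch of that predicate (the function name is hypothetical, not the PR's actual code):

```python
def can_use_tensor_all_gather(vocab_size: int, attn_tp_size: int) -> bool:
    """Illustrative predicate for the fast path in logits gathering: the
    fused all_gather_into_tensor needs every rank's logits shard to be the
    same size, i.e. the vocabulary must split evenly across the TP group.
    Otherwise the original list-based gather path is kept as a fallback.
    """
    return vocab_size % attn_tp_size == 0
```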

@ch-wan ch-wan changed the title Optimize communication for dp < tp by using all_gather_into_tensor and reduce_scatter_tensor [4/n] DP Enhancement: Optimize communication when dp < tp by using all_gather_into_tensor and reduce_scatter_tensor Jul 23, 2025
@ch-wan ch-wan mentioned this pull request Jul 23, 2025
6 tasks
whybeyoung (Collaborator)

LGTM

[ghstack-poisoned]
[ghstack-poisoned]
ch-wan added a commit that referenced this pull request Jul 23, 2025
…` and `reduce_scatter_tensor`

ghstack-source-id: bef9f4c
Pull-Request: #8279
ch-wan added a commit that referenced this pull request Jul 25, 2025
…` and `reduce_scatter_tensor`

ghstack-source-id: bef9f4c
Pull-Request: #8279
[ghstack-poisoned]
@ch-wan ch-wan closed this Jul 25, 2025
@ch-wan ch-wan deleted the gh/ch-wam/4/head branch July 25, 2025 04:50
miter6 (Contributor) commented Aug 13, 2025

Both the TP MoE and the DeepEP MoE benchmark commands use the --enable-deepep-moe option.
Is this a typo or a bug?

ch-wan (Collaborator, Author) commented Aug 14, 2025

@MiterV1 It should be a typo 😅
