Skip to content

[5/n] DP Enhancement: Correct num_token_non_padded#9107

Merged
zhyncs merged 3 commits intomainfrom
cheng/dev/correct-num-non-padded-token
Aug 12, 2025
Merged

[5/n] DP Enhancement: Correct num_token_non_padded#9107
zhyncs merged 3 commits intomainfrom
cheng/dev/correct-num-non-padded-token

Conversation

@ch-wan
Copy link
Collaborator

@ch-wan ch-wan commented Aug 12, 2025

Motivation

num_token_non_padded is used for masking padded tokens in MoE compute. When dp<tp, the current implementation overestimates non padded tokens because its value is token size in the current DP rank but tokens are sharded across DP group under DeepEP. This PR fixes this issue.

Modifications

Accuracy Tests

gsm8k's accuracy is correct under different concurrency.

Benchmarking and Profiling

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --trust-remote-code --tp 8 --enable-dp-attention --dp 2 --moe-dense-tp-size 1 --moe-a2a-backend deepep --cuda-graph-max-bs 32 --max-running-requests 32 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-V3-0324-NextN --speculative-num-steps 2 --speculative-eagle-topk 4 --speculative-num-draft-tokens 4 --disable-radix-cache --stream-output
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 32 --random-input 512 --random-output 32 --random-range-ratio 1 --host 127.0.0.1 --port 30000 --max-concurrency 1

Before this PR, the output throughput is 53.05 tok/s. With this PR, the output throughput is 57.07 tok/s.

Checklist

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @ch-wan, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an issue where num_token_non_padded was incorrectly calculated in MoE (Mixture of Experts) computations when data parallelism (dp) was less than tensor parallelism (tp). This overestimation led to incorrect masking of padded tokens. The fix involves adjusting the calculation of num_token_non_padded to account for tokens being sharded across the DeepEP DP group, ensuring accurate local token counts. This correction also yields a performance improvement, increasing throughput from 37.49 tok/s to 39.36 tok/s.

Highlights

  • Problem Identification: Corrects an overestimation of num_token_non_padded in MoE computations when dp < tp, which previously led to incorrect masking of padded tokens.
  • Corrected Calculation: Introduces a refined calculation for num_token_non_padded by considering attn_tp_size and attn_tp_rank to accurately determine num_local_token_non_padded for sharded tokens.
  • Performance Improvement: Benchmarking shows a throughput increase from 37.49 tok/s to 39.36 tok/s after applying the fix.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to correct the num_token_non_padded value for MoE computation when data parallelism is enabled. The changes involve adding logic to calculate the local number of non-padded tokens for each attention tensor-parallel rank. While the overall direction is correct, I've found a potential issue in the calculation logic that could lead to incorrect masking. My review includes a suggestion to fix this.

@ch-wan ch-wan enabled auto-merge (squash) August 12, 2025 18:50
@zhyncs zhyncs disabled auto-merge August 12, 2025 19:23
@zhyncs zhyncs merged commit 5f5b3b2 into main Aug 12, 2025
64 of 66 checks passed
@zhyncs zhyncs deleted the cheng/dev/correct-num-non-padded-token branch August 12, 2025 19:23
@miter6
Copy link
Contributor

miter6 commented Aug 13, 2025

thanks a lot.
I got this issue yesterday.

@miter6
Copy link
Contributor

miter6 commented Aug 13, 2025

Capture cuda graph bs [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64] for dp-attention=16.
Capture cuda graph bs [16, 32, 48, 64] for tp=16.
If running batchsize is below 16, it must padded to 16. Is this a bug??

@ch-wan
Copy link
Collaborator Author

ch-wan commented Aug 13, 2025

@MiterV1 This is expected when attn tp size is 16. We have this constraint to keep reduce-scatter and all-gather efficient (see #8279). In addition, when the running request is small, EP + DP attention is not an optimal approach because its workload is imbalanced.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants