
[DeepSeek V3.1/V3.2] Optimize fused moe configs for H20 & H20-3E based on swapab#17133

Merged
Fridge003 merged 2 commits into sgl-project:main from antgroup:xyf/tune_moe
Jan 16, 2026

Conversation

@xu-yfei
Contributor

@xu-yfei xu-yfei commented Jan 15, 2026

Motivation

  1. Performance tuning on top of the fused MoE swapab rework ([Rework] Add SwapAB Optimization for triton fused_moe_kernel on SM90. #16723). The optimal fused MoE configuration changes once swapab is taken into account.

  2. Optimize the tuning script tuning_fused_moe_triton_sep.py: the kernel is now wrapped in a CUDA Graph to avoid inaccurate performance measurements in small-token scenarios. In addition, timing now uses 100 samples split into 10 iterations of 10 executions each, replacing the previous approach of 10 samples total (1 execution per iteration, repeated 10 times).

  3. Tuned the DeepSeek V3.1/V3.2 TP8 scenarios on H20 and H20-3E devices.
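The revised sampling scheme in item 2 (100 samples split into 10 iterations of 10 executions each) can be sketched as a plain timing harness. This is an illustrative stand-in, not the actual `tuning_fused_moe_triton_sep.py` code; `run_kernel` is a hypothetical callable, and the real script additionally captures the inner loop into a CUDA Graph before timing it.

```python
import time

def benchmark(run_kernel, iters=10, execs_per_iter=10):
    """Time `run_kernel` over `iters` iterations, executing it
    `execs_per_iter` times per iteration, and return the mean
    per-execution latency in milliseconds."""
    latencies = []
    for _ in range(iters):
        start = time.perf_counter()
        for _ in range(execs_per_iter):
            run_kernel()
        end = time.perf_counter()
        # Average over the executions inside this iteration.
        latencies.append((end - start) / execs_per_iter * 1e3)
    return sum(latencies) / len(latencies)
```

Replaying the inner loop from a captured CUDA Graph removes per-launch CPU overhead, which otherwise dominates the measurement when token counts are small.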

Modifications

Accuracy Tests

Benchmarking and Profiling

For token counts ranging from 1 to 256, compare the performance before and after this PR using TPOT (time per output token).

bs=(1 4 8 16 24 32 48 64 96 128 256)
for i in "${bs[@]}"
do
python3 -m sglang.bench_serving --backend sglang --dataset-name random \
--random-input 16 --random-output 1024 --request-rate 1000 --num-prompt ${i} \
--random-range-ratio 1 --max-concurrency 1024 --port 8000 \
--dataset-path /home/ShareGPT_V3_unfiltered_cleaned_split.json
done
TPOT (ms) — DS V3.2 TP8 on H20-3E and DS V3.1 TP8 on H20:

| bs | H20-3E before PR | H20-3E after PR | H20 before PR | H20 after PR |
|---:|---:|---:|---:|---:|
| 1 | 13.09 | 12.3 | 9.84 | 9.08 |
| 4 | 15.64 | 14.71 | 12.8 | 11.43 |
| 8 | 19.89 | 17.84 | 16.34 | 14.46 |
| 16 | 24.85 | 23.58 | 22.62 | 19.95 |
| 24 | 28.25 | 26.78 | 24.24 | 23.76 |
| 32 | 33.05 | 29.65 | 28.47 | 25.98 |
| 48 | 35.82 | 33.96 | 31.59 | 29.52 |
| 64 | 38.36 | 36.66 | 34.12 | 32.74 |
| 96 | 49.15 | 46.54 | 39.51 | 38.02 |
| 128 | 51.64 | 49.4 | 41.99 | 40.97 |
| 256 | 65.85 | 62.58 | 52.96 | 52.4 |

For token counts ranging from 512 to 8192, compare the performance before and after this PR using TTFT (time to first token).

tokens=(512 1024 1536 2048 3072 4096 8192)
for i in "${tokens[@]}"
do
python3 -m sglang.bench_serving --backend sglang --dataset-name random \
--random-input $i --random-output 1 --request-rate 1000 --num-prompt 128 \
--random-range-ratio 1 --max-concurrency 1 --port 8000 \
--dataset-path /home/ShareGPT_V3_unfiltered_cleaned_split.json
done
TTFT (ms) — DS V3.2 TP8 on H20-3E and DS V3.1 TP8 on H20:

| tokens | H20-3E before PR | H20-3E after PR | H20 before PR | H20 after PR |
|---:|---:|---:|---:|---:|
| 512 | 129.29 | 132.34 | 132.41 | 132.59 |
| 1024 | 137.82 | 140.48 | 136.55 | 136.85 |
| 1536 | 162.58 | 144.94 | 142.97 | 135.47 |
| 2048 | 215.83 | 199.33 | 170.72 | 168.73 |
| 3072 | 714.93 | 695.50 | 231.10 | 230.16 |
| 4096 | 968.13 | 940.98 | 302.88 | 302.52 |
| 8192 | 1937.76 | 1888.15 | 629.13 | 628.69 |

It can be observed that TPOT decreases across the board for token counts from 1 to 256, and that TTFT on H20-3E is also reduced.
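As a quick arithmetic check on the TPOT table above, the relative reduction can be computed directly from a few of the H20-3E (DS V3.2 TP8) columns:

```python
# TPOT (ms) before/after this PR, DS V3.2 TP8 on H20-3E (values from the table above).
before = {1: 13.09, 32: 33.05, 256: 65.85}
after = {1: 12.3, 32: 29.65, 256: 62.58}

for bs in before:
    reduction = (before[bs] - after[bs]) / before[bs] * 100
    print(f"bs={bs}: TPOT reduced by {reduction:.1f}%")
    # prints reductions of roughly 6.0%, 10.3%, and 5.0%
```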

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the performance of Mixture-of-Experts (MoE) models, specifically DeepSeek V3.1/V3.2, on H20 and H20-3E hardware. It achieves this by fine-tuning MoE configurations in light of recent architectural changes and by upgrading the underlying tuning methodology. The updated tuning script now leverages CUDA Graphs for more precise measurements and employs a more comprehensive sampling approach, leading to more reliable and efficient MoE operations. Performance benchmarks provided indicate notable improvements in both TPOT for small token counts and TTFT for H20-3E.

Highlights

  • MoE Configuration Optimization: Optimized fused Mixture-of-Experts (MoE) configurations for DeepSeek V3.1/V3.2 TP8 on H20 and H20-3E devices, building upon recent swapab integration to enhance performance.
  • Tuning Script Enhancement: Enhanced the MoE tuning script (tuning_fused_moe_triton_sep.py) by incorporating CUDA Graphs for more accurate performance evaluation, particularly in small-token scenarios, and revised the sampling strategy to use 100 samples over 10 iterations for improved robustness.
  • New Device Configurations: Introduced new configuration files for H20 and H20-3E devices to reflect the newly optimized MoE settings, ensuring better performance on these specific hardware platforms.
  • Kernel Indexing Correction: Corrected an indexing issue in fused_moe_triton_kernels.py related to C.stride (from C.stride(1), C.stride(2) to C.stride(-2), C.stride(-1)) for improved kernel execution logic.
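The motivation for the negative stride indices in the last highlight can be illustrated without Triton: for a contiguous row-major tensor, `stride(-2)`/`stride(-1)` always refer to the last two dimensions regardless of rank, whereas `stride(1)`/`stride(2)` only coincide with them when the tensor is 3-D. A minimal pure-Python sketch (the `strides` helper is illustrative, not SGLang code):

```python
def strides(shape):
    """Row-major (contiguous) strides, in elements, for a given shape."""
    out = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        out[i] = out[i + 1] * shape[i + 1]
    return out

s3 = strides((4, 8, 16))   # 3-D output tensor: [128, 16, 1]
s2 = strides((8, 16))      # 2-D output tensor: [16, 1]

# For 3-D tensors, indices (1, 2) and (-2, -1) pick the same strides...
assert (s3[1], s3[2]) == (s3[-2], s3[-1])
# ...but for 2-D tensors only the negative indices remain valid;
# s2[2] would be out of range, while s2[-2], s2[-1] still work:
assert (s2[-2], s2[-1]) == (16, 1)
```

This is why indexing the output strides from the end makes the kernel launch code robust to the output tensor's rank.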



@xu-yfei
Contributor Author

xu-yfei commented Jan 15, 2026

@BBuf @Fridge003 Could you please review this PR?

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request significantly enhances the MoE kernel tuning script by incorporating CUDA graphs for more precise benchmarking and adopting a more resilient strategy for identifying optimal configurations. It also introduces new, optimized configurations for DeepSeek V3.1/V3.2 on H20 and H20-3E hardware, which, as the benchmarks indicate, should yield performance improvements. The changes are generally well-executed, but I have pointed out a critical issue in the tuning script that requires attention to guarantee the accuracy of the benchmark outcomes. Additionally, I've offered a suggestion to enhance code readability.

@BBuf
Collaborator

BBuf commented Jan 15, 2026

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Jan 16, 2026

/rerun-failed-ci

@Fridge003 Fridge003 merged commit 82a1b64 into sgl-project:main Jan 16, 2026
252 of 285 checks passed
@dongyibo

@xu-yfei Hello, I noticed that in the down-proj config file, USE_TMA is set to true for all batch sizes, which differs from my previous understanding; I thought TMA would only be enabled for larger batch sizes. Could you please explain why this is the case? Thank you

@xu-yfei
Contributor Author

xu-yfei commented Feb 2, 2026

> @xu-yfei Hello, I noticed that in the down config file, USE_TMA is set to true for all batch sizes, which differs from my previous understanding. I thought TMA would only be enabled for larger batch sizes. Could you please explain why this is the case? Thank you

@dongyibo The MoE operator calls are now wrapped in a CUDA Graph for more accurate performance evaluation. Under that measurement, tuning selects USE_TMA=true for all down-proj layers.
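For context, the fused MoE config files are JSON keyed by token count, with USE_TMA stored per entry (as noted in the question above). A hypothetical down-proj entry might look like the following; the tile-parameter names and values are illustrative of typical Triton tuning fields, not copied from the merged files:

```json
{
  "1":   { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128,
           "num_warps": 4, "num_stages": 3, "USE_TMA": true },
  "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
           "num_warps": 8, "num_stages": 4, "USE_TMA": true }
}
```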

