Opt fused triton moe: add tma for down proj kernel #10567
huangtingwei9988 merged 8 commits into sgl-project:main
Conversation
Summary of Changes
Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request delivers a significant performance upgrade for the Fused Triton Mixture of Experts (MoE) path, focusing on the second (down-projection) MoE layer. By using the Tensor Memory Accelerator (TMA) for efficient data movement and tuning configurations against real-world expert distributions, the changes substantially improve computational utilization and reduce inference latency, bringing the second MoE's latency in line with its computational cost.
Highlights
- Performance Optimization for Second MoE: The pull request significantly optimizes the performance of the second Mixture of Experts (MoE) layer, the downsampled Fused Triton MoE, which previously exhibited disproportionately high latency.
- TMA Integration: Tensor Memory Accelerator (TMA) has been integrated for input A and weight B access in the second MoE, enhancing data movement efficiency.
- Improved Computational Utilization: The computational utilization of the second MoE has increased from 45.20% to 81.12% due to these optimizations.
- Latency Reduction: Average latency for 100 samples in an 8K tokens scenario decreased by 0.995 ms (from 2.430 ms to 1.435 ms), demonstrating a substantial speedup.
- Refined Tuning Process: The tuning process now uses real-world expert distribution (topk_ids from actual inference) and independently tunes and loads optimal configurations for both downsampling and upsampling MoEs.
@BBuf Could you have a review?
Code Review
This pull request introduces a significant performance optimization for the fused MoE down-projection kernel by leveraging TMA (Tensor Memory Accelerator), resulting in a substantial latency reduction. The changes are well-structured, including a new benchmark script for separate tuning of the MoE kernels, updated Triton kernel configurations, and modifications to the MoE implementation to integrate the new TMA kernel. While the performance gains are impressive, I've identified a couple of critical correctness issues in the new Triton kernels and some areas for improvement in the benchmark script to enhance accuracy and maintainability.
```python
if BLOCK_SIZE_N > group_n:
    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
```
When BLOCK_SIZE_N <= group_n, b_scale is loaded as a scalar for each K-group. Applying [None, :] indexing to a scalar is invalid in Triton. It seems it should be just b_scale, which would broadcast correctly.
Suggested change:
```diff
 if BLOCK_SIZE_N > group_n:
     accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
 else:
-    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
+    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
```
```python
if BLOCK_SIZE_N > group_n:
    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
```
Similar to a previous comment, when BLOCK_SIZE_N <= group_n, b_scale is loaded as a scalar. Applying [None, :] indexing to a scalar is invalid in Triton. It should probably be just b_scale, which would broadcast correctly.
Suggested change:
```diff
 if BLOCK_SIZE_N > group_n:
     accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
 else:
-    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
+    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
```
```python
def prepare(i: int):
    input_gating = gating_output[i]
    topk_ids = torch.load(f"{topk_ids_dir}/topk_ids_layer{i%58+3}_idx{i//58}.pt")
```

```python
for i in range(num_iters):
    prepare(i)
    torch.cuda.synchronize()
    up_moe_use_tma = False
    t0, t1 = run()
    torch.cuda.synchronize()
    up_moe_use_tma = True
    t0_x, t_tma = run()
    torch.cuda.synchronize()
    latencies.append(t0)
    latencies1.append(t1)
    latencies_tma.append(t_tma)
```
The latency of the first kernel (invoke_fused_moe_kernel) when up_moe_use_tma is true (t0_x at line 278) is discarded. The tuning for this kernel is always based on t0 (from up_moe_use_tma = False), which corresponds to c_sorted=False. However, at runtime, if TMA is used for the up-kernel, the first kernel will be called with c_sorted=True. This means the configuration used for the first kernel might be suboptimal as it was tuned for c_sorted=False.
To fix this, you should also record t0_x and use it for tuning the first kernel when TMA is chosen for the up-kernel.
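One way to implement this suggestion is sketched below, assuming the loop structure of the benchmark snippet above; all function and variable names here are illustrative, not the PR's actual code:

```python
# Sketch of the suggested fix: record the first kernel's latency under both
# TMA modes, so each mode is tuned against the c_sorted layout it will
# actually run with at runtime. Names are illustrative, not the PR's code.
def benchmark_config(run, prepare, num_iters):
    t0s, t1s, t0s_sorted, t_tmas = [], [], [], []
    for i in range(num_iters):
        prepare(i)
        t0, t1 = run(up_moe_use_tma=False)      # kernel0 timed with c_sorted=False
        t0_x, t_tma = run(up_moe_use_tma=True)  # kernel0 timed with c_sorted=True
        t0s.append(t0)
        t1s.append(t1)
        t0s_sorted.append(t0_x)                 # keep t0_x instead of discarding it
        t_tmas.append(t_tma)
    use_tma = sum(t_tmas) < sum(t1s)
    # Tune kernel0 on the timings that match the chosen c_sorted layout.
    kernel0_cost = sum(t0s_sorted) if use_tma else sum(t0s)
    return use_tma, kernel0_cost

# Dummy timings: TMA makes kernel1 faster but kernel0 slightly slower.
use_tma, cost = benchmark_config(
    run=lambda up_moe_use_tma: (2.5, 1.0) if up_moe_use_tma else (2.0, 3.0),
    prepare=lambda i: None,
    num_iters=4,
)
```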
```python
        self.time_cost_all = None  # kernel0, kernel1 without tma, kernel1 with tma

    def update(self, config, time_cost, time_cost_all):
        if time_cost < self.time_cost:
```
Using < might cause the tuning to be sensitive to the order of configurations in the search space. If two configurations have the same performance, only the first one encountered will be kept. Using <= is more robust and makes the choice of the best config more stable, preferring later configurations in case of a tie.
Suggested change:
```diff
-        if time_cost < self.time_cost:
+        if time_cost <= self.time_cost:
```
Great work! I noticed that the Qwen3-MoE model also uses the same MoE operator. Would this modification bring benefits to Qwen3-MoE as well?
What is the maximum token length and TP size in your test? We found that when the token length is large (e.g., 8K), the performance of the second MOE is similar to that of the first one. A possible reason is that after TP partitioning, the K dimension of the second MOE is relatively small (2048/tp8 = 256), while the N dimension is quite large (7168).
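The shape argument above can be made concrete with a rough arithmetic-intensity estimate. This is a back-of-envelope sketch only: the byte sizes (fp8 weights, bf16 activations) and the M dimension are assumptions, not measurements from this PR:

```python
# Back-of-envelope arithmetic intensity (FLOPs per byte moved) for the
# GEMM shapes discussed above. Purely illustrative.
def gemm_intensity(m, n, k, a_bytes=2, b_bytes=1, c_bytes=2):
    flops = 2 * m * n * k
    traffic = m * k * a_bytes + k * n * b_bytes + m * n * c_bytes
    return flops / traffic

# Second MOE after TP8: K = 2048 // 8 = 256 is small, N = 7168 is large.
thin_k = gemm_intensity(m=8192, n=7168, k=256)
# The same FLOP count with a more balanced K/N split:
balanced = gemm_intensity(m=8192, n=1792, k=1024)
# The thin-K shape has markedly lower intensity, so it is harder to keep
# the tensor cores fed, consistent with the observation above.
```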
Thanks for your reply. I was testing the Qwen3-30B-A3B model with an input length of only 1024 and tp-size=1, which might be the reason for the issue. Next, I will test the Qwen3-235B model with tp-size=4 to see if there will be any differences.
Is there any possibility of merging this kernel into the original fused MoE kernel?
@xu-yfei Why is the TMA optimization applied only to the down_proj? Can the up_proj achieve similar acceleration effects as well?
The weights of up_proj can use TMA, but my local verification shows that the improvement is not significant (1~2%). For down_proj, I have modified the distribution of input A: when calculating the output of up_proj, values are written in blocks corresponding to experts, which enables down_proj to fetch values using TMA.
I'm currently attempting to merge them. After merging, the performance drops significantly, though the precision remains normal. I'm investigating the cause; it may be that the added branches lead to the performance degradation. After removing some branches, the performance has recovered considerably.
@Alcanderian Done. The performance degradation in the previously merged code stems from two factors:
@xu-yfei Hello, a question about the `if a_desc is not None:` branch in fused_moe_kernel:
@dongyibo Currently, only the fp8_w8a8 scenario has been verified, while other scenarios remain unverified. |
@zhyncs Ready to merge. Currently, only H20 has the new tuning configuration. This feature is enabled only when the relevant configuration exists. The failed AMD test cases should be unrelated to this feature. |
Hi! May I ask whether you have measured the time consumption of fused_moe_kernel after applying TMA to input A? I found that it takes more time than without the TMA optimization.
What is your device? How many tokens do you use, and what's the configuration? |
My device is H200, testing num_tokens from 1 to 16384, and the model config comes from Qwen3-30B-A3B. However, what is confusing is that although the kernel takes longer, using TMA still shows a throughput improvement in end-to-end inference.
Which script did you use to verify the kernel performance? The optimal configuration when using TMA is actually different from that when not using TMA, and the performance of the optimal TMA configuration can be much better. In H20(96GB) use
Hi, @xu-yfei, thank you for sharing this great work. When I want to generate topk_ids based on real data, is one query enough, or do I need to compute top-k statistics over my dataset?
@xu-yfei I have one question: if topk_ids are tightly coupled with specific scenarios for better perf, how can we save kernel configurations for different scenarios within a single kernel config file? Thanks.
@xu-yfei From the tuned config file, it seems that USE_TMA is only applied for very large batch sizes (e.g. >= 1024), so could a long query with a normal batch size (such as 4 or 8) benefit from this optimization?
The current solution uses topk_ids data with uneven expert distribution, capturing only one long input and selecting 100 topk_ids entries from 58 layers. While a dataset-based approach might be more reasonable, the volume of data required for tuning is excessively large.
Based on the current implementation, the tuning results indicate that TMA yields no benefit in this scenario.
@xu-yfei thanks for your detailed explanation. |
Hello, may I ask why the following is required: |
In the TMA implementation, BLOCK_SIZE_M is related to the data layout of the first MoE's output and the second MoE's input. Within the same contiguous block of length BLOCK_SIZE_M, the experts are identical and can be loaded via TMA. Strictly speaking, it suffices that the BLOCK_SIZE_M of the first MoE is a multiple of that of the second MoE: |
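The constraint described above can be illustrated with a small check (purely illustrative, not the kernel's code): if the first MoE writes contiguous blocks of BLOCK_SIZE_M tokens that each belong to one expert, then any block size that divides it also yields single-expert blocks, which is what lets the second MoE load them via TMA.

```python
# Check that every contiguous block of `block_m` rows touches one expert.
def blocks_are_single_expert(expert_ids, block_m):
    assert len(expert_ids) % block_m == 0
    return all(
        len(set(expert_ids[i : i + block_m])) == 1
        for i in range(0, len(expert_ids), block_m)
    )

# Two blocks of BLOCK_SIZE_M = 4, each written for a single expert:
layout = [0, 0, 0, 0, 5, 5, 5, 5]
assert blocks_are_single_expert(layout, 4)  # first MoE's block size
assert blocks_are_single_expert(layout, 2)  # any divisor still works
```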
Thank you for the explanation; I understand now. |
Motivation
In H20(96GB) TP8 prefill, during performance analysis, we observed that the latency of the second MOE in each layer (i.e., the downsampled Fused Triton MOE) was comparable to that of the first MOE (the upsampled Fused Triton MOE), even though its weight data volume and computational cost are only half those of the first MOE. This latency was clearly disproportionate.
For the second MOE (downsampling MOE), we use TMA to encapsulate input A and weight B, and perform independent tuning and configuration loading for it (separate from the first MOE). As a result, the computation utilization of the second MOE has increased from 45.20% to 81.12%, and the average latency for 100 samples in an 8K tokens scenario decreased by 0.995 ms (from 2.430 ms to 1.435 ms).
Detail
As noted in the Motivation, the latency of the second MOE (the downsampled Fused Triton MOE) was comparable to that of the first MOE, despite having only half its weight data volume and computational cost.
To optimize its performance, we implemented the following measures:
Optimized the acquisition and calculation of b_scale: when BLOCK_SIZE_N is less than or equal to group_n (the block quantization parameter), only one b_scale element needs to be read each time instead of BLOCK_SIZE_N elements. In this case, a_scale (the input quantization scaling factor) can be multiplied by b_scale first and then applied to the dot-product result, reducing computational overhead.
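The b_scale optimization above can be sketched in NumPy (shapes are illustrative; in the Triton kernel these are per-tile blocks): when BLOCK_SIZE_N <= group_n, all columns of the output tile share one b_scale value, so a single scalar read suffices and it can be folded into a_scale before scaling the dot product.

```python
import numpy as np

# NumPy sketch of the b_scale optimisation (shapes illustrative).
BLOCK_M, BLOCK_N, BLOCK_K = 4, 8, 16
rng = np.random.default_rng(0)
a = rng.random((BLOCK_M, BLOCK_K), dtype=np.float32)
b = rng.random((BLOCK_K, BLOCK_N), dtype=np.float32)
a_scale = rng.random(BLOCK_M).astype(np.float32)  # per-row input scale
b_scale = np.float32(0.5)                         # one scalar per K-group

# General path: a vector of BLOCK_SIZE_N scales, here all equal.
b_scale_vec = np.full(BLOCK_N, b_scale, dtype=np.float32)
general = (a @ b) * a_scale[:, None] * b_scale_vec[None, :]

# Optimised path: fold the scalar into a_scale first, one read, one multiply.
optimised = (a @ b) * (a_scale[:, None] * b_scale)

assert np.allclose(general, optimised)
```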
Refactored the second MOE based on TMA: Changed the access pattern of input A from discrete fetching to a continuous organization, adjusting the first dimension from (num_tokens * top_k) to (num_blocks * block_size_m), and encapsulated the access processes for input A and weight B using TMA.
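The layout change above can be sketched as follows. This is an illustration of the idea, not the kernel's actual code: rows for the same expert are placed in contiguous blocks padded to block_size_m, so the down-projection reads a regular (num_blocks * block_size_m) layout that TMA can fetch.

```python
# Illustrative sketch: arrange (token, top-k) rows contiguously by expert,
# padding each expert's rows to a multiple of block_size_m (pad = -1).
def sorted_padded_layout(flat_expert_ids, num_experts, block_size_m, pad=-1):
    layout = []
    for e in range(num_experts):
        rows = [i for i, eid in enumerate(flat_expert_ids) if eid == e]
        while len(rows) % block_size_m:  # pad up to a full block
            rows.append(pad)
        layout.extend(rows)
    return layout  # length is num_blocks * block_size_m

# 5 (token, k) pairs routed to 2 experts, block_size_m = 4:
layout = sorted_padded_layout([1, 0, 1, 1, 0], num_experts=2, block_size_m=4)
# Each block of 4 rows now belongs to a single expert (plus padding),
# so it can be loaded as one contiguous region.
assert len(layout) % 4 == 0
```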
Optimized configuration using real-world expert distribution: Since expert workloads are typically unevenly distributed, the original tuning configuration based on randomly generated experts did not reflect real-world conditions. We instead used topk_ids generated during actual inference as input to more accurately match real workloads.
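Why this matters can be illustrated with a small simulation (numbers and the skew model are invented for illustration): uniformly random expert ids, as used in the original tuning, give a near-flat expert load, while real topk_ids are typically skewed, so the hottest expert processes far more tokens and configs tuned on random ids can be suboptimal.

```python
import collections
import random

# Illustrative only: compare the hottest expert's load under a uniform
# routing versus a skewed routing (skew weights are invented).
random.seed(0)
NUM_EXPERTS, NUM_ROWS = 256, 8192 * 8

uniform = [random.randrange(NUM_EXPERTS) for _ in range(NUM_ROWS)]
skewed = random.choices(
    range(NUM_EXPERTS),
    weights=[(i % 16) + 1 for i in range(NUM_EXPERTS)],  # uneven popularity
    k=NUM_ROWS,
)

def max_load(ids):
    """Number of rows routed to the busiest expert."""
    return max(collections.Counter(ids).values())

# The skewed (real-world-like) routing concentrates far more work on a few
# experts than the uniform routing used by the original tuning.
assert max_load(skewed) > max_load(uniform)
```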
Tuned the optimal configurations for both MOEs, independently loading different configurations for each.
Tuning Process:
benchmark one config:
Modifications
Accuracy Tests
Benchmarking and Profiling
16372 chunked size, sending one request with a context length of 16372×3 + 502:
Before PR
After PR
TTFT of a single request with different lengths
Checklist