[feat]: enhance fused MoE kernel with TMA support #10854
liusy58 wants to merge 1 commit into sgl-project:main from
Conversation
Summary of Changes

Hello @liusy58, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request integrates Tensor Memory Accelerator (TMA) support into the fused Mixture of Experts (MoE) kernel. The change is designed to optimize memory access patterns on modern GPU architectures. It conditionally uses TMA for loading the B tensors and adjusts the related scaling-factor computations, yielding measurable speedups such as a 5% TTFT reduction for DeepSeek-R1.

Highlights
Code Review
This pull request introduces TMA (Tensor Memory Accelerator) support to the fused MoE kernel, aiming to enhance memory access performance. The changes are mostly confined to the Triton kernel implementation, adding conditional logic to use TMA for loading weights and their scales. The implementation looks solid, but I've identified a few areas for improvement, mainly concerning code duplication and redundancy, which can be refactored for better readability and maintainability.
if even_Ks:
    expert_offset = off_experts.to(tl.int32)
    n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    k_offset = k_start.to(tl.int32)
    b = b_desc.load([expert_offset, n_offset, k_offset])
    b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
else:
    expert_offset = off_experts.to(tl.int32)
    n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    k_offset = k_start.to(tl.int32)
    b = b_desc.load([expert_offset, n_offset, k_offset])
    b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
    k_mask = (k_start + tl.arange(0, BLOCK_SIZE_K)) < K
    b = tl.where(k_mask[None, :], b, 0.0)
b = b.T
There's significant code duplication in the if even_Ks: and else: blocks when USE_TMA_B is true. You can refactor this to remove redundancy and improve readability by calculating the offsets once and then conditionally applying the mask.
expert_offset = off_experts.to(tl.int32)
n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
k_offset = k_start.to(tl.int32)
b = b_desc.load([expert_offset, n_offset, k_offset])
b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
if not even_Ks:
k_mask = (k_start + tl.arange(0, BLOCK_SIZE_K)) < K
b = tl.where(k_mask[None, :], b, 0.0)
b = b.T

b_scale_ptrs = (
    b_scale_ptr
    + off_experts * stride_bse
    + offs_bsn * stride_bsn
    + offs_ks * stride_bsk
)
b_scale = tl.load(b_scale_ptrs)
else:
    k_start = k * BLOCK_SIZE_K

a_scale = tl.load(
    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
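The k-tail masking that the refactor suggestion above centralizes can be made concrete with a plain-Python emulation. This is a hedged sketch, not the Triton kernel: `load_b_block` and its list-based "row" are illustrative stand-ins for the descriptor load and the `tl.where` mask; only the names `BLOCK_SIZE_K`, `k_start`, `K`, and `even_Ks` mirror the kernel.

```python
# Hedged sketch: emulates the even_Ks tail-masking logic from the review
# suggestion using plain Python lists in place of Triton tensors.
def load_b_block(b_row, k_start, BLOCK_SIZE_K, K, even_Ks):
    # Simulated descriptor load: read BLOCK_SIZE_K values starting at
    # k_start; reads past the logical end K are padded with 0.0.
    block = [
        (b_row[k] if k < K else 0.0)
        for k in range(k_start, k_start + BLOCK_SIZE_K)
    ]
    if not even_Ks:
        # Zero out lanes beyond K exactly once, instead of duplicating
        # the entire load path in both branches (the point of the refactor).
        block = [v if (k_start + i) < K else 0.0 for i, v in enumerate(block)]
    return block
```

The load happens unconditionally; only the cheap masking step is branch-dependent, which is why the two original branches can be collapsed.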
For consistency with the TMA path, consider using other=1.0 for loading a_scale. While it may not affect the correctness here since a is zeroed out for masked tokens, using 1.0 as a default for a scale factor is more semantically correct and robust.
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=1.0
)
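The reasoning behind `other=1.0` can be demonstrated with a small emulation of `tl.load`'s masked-load semantics. This is a hedged sketch: `masked_load` is an illustrative stand-in, not the Triton API, and the values are made up.

```python
# Hedged sketch: why 1.0 is the safer default for a masked scale load.
# A masked-out scale of 0.0 silently zeroes whatever it multiplies;
# 1.0 is the multiplicative identity, so downstream math is unaffected
# even if a masked lane is accidentally consumed.
def masked_load(values, mask, other):
    # Mimics tl.load(ptr, mask=mask, other=other) elementwise.
    return [v if m else other for v, m in zip(values, mask)]

a = [2.0, 4.0]  # activations; the second lane is a masked/invalid token
scale_zero = masked_load([0.5, 0.5], [True, False], other=0.0)
scale_one = masked_load([0.5, 0.5], [True, False], other=1.0)

assert [x * s for x, s in zip(a, scale_zero)] == [1.0, 0.0]  # lane forced to 0
assert [x * s for x, s in zip(a, scale_one)] == [1.0, 4.0]   # lane passes through
```

In the kernel the masked lane's activation is already zeroed, so either default is correct; the identity default just removes a hidden coupling between the two loads.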
The solution is similar to that in PR #10567. In addition to modifying B with TMA, we also applied TMA to input A of the second MoE. We found that the performance improvement from applying TMA to input A is more significant. Moreover, the optimal tuning configuration also changes once TMA is enabled.
Thank you for sharing this insight. I have a follow-up question: since you observed a bigger gain from applying TMA to input A alone, why not simply apply TMA to both A and B? In our own experiments we consistently see an additional, non-negligible boost when TMA is enabled on both MoE inputs, and the overhead is minimal.
Yes, we have applied TMA to both inputs A and B of the second MoE (down proj), while no TMA was used for the first one (gateup proj). The performance improvement we observed for the first MoE is relatively small (approximately 1%).
May I ask if there are any plans to merge this? @liusy58
Motivation
This commit introduces TMA (Tensor Memory Accelerator) support for the fused MoE (Mixture of Experts) kernel to improve memory access performance on modern GPU architectures.
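The diff above loads blocks of B through a tensor descriptor (`b_desc.load([...])`), which is how Triton exposes TMA on Hopper-class GPUs. Since the real thing requires a GPU, here is a hedged, GPU-free emulation of what a 3-D descriptor block load conceptually does; `TensorDesc`, its constructor, and the padding behavior are illustrative assumptions, not the Triton API.

```python
# Hedged emulation (plain Python, no GPU): a 3-D tensor-descriptor block
# load fetches a fixed-shape [bn, bk] tile at the given (expert, n, k)
# offsets, with out-of-bounds reads padded by the hardware/runtime.
class TensorDesc:
    def __init__(self, data, shape, block_shape, pad=0.0):
        self.data = data              # nested lists: [E][N][K]
        self.shape = shape            # logical bounds (E, N, K)
        self.block_shape = block_shape  # tile size (bn, bk)
        self.pad = pad                # value substituted for OOB reads

    def load(self, offsets):
        e0, n0, k0 = offsets
        _, N, K = self.shape
        bn, bk = self.block_shape
        # Return a [bn, bk] tile of expert e0; reads past N or K are padded,
        # which is why the even_Ks tail mask in the kernel stays cheap.
        return [
            [
                self.data[e0][n0 + i][k0 + j]
                if (n0 + i < N and k0 + j < K)
                else self.pad
                for j in range(bk)
            ]
            for i in range(bn)
        ]
```

In the kernel, the analogous call is `b_desc.load([expert_offset, n_offset, k_offset])` followed by a reshape to `[BLOCK_SIZE_N, BLOCK_SIZE_K]`; the emulation only illustrates the offset-and-pad semantics.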
Modifications
Accuracy Tests
Benchmarking and Profiling
This optimization delivers measurable performance improvements in production scenarios:
Checklist