
[feat]: enhance fused MOE kernel with TMA support #10854

Open
liusy58 wants to merge 1 commit into sgl-project:main from liusy58:tma

Conversation

@liusy58 (Collaborator) commented Sep 24, 2025

Motivation

This commit introduces TMA (Tensor Memory Accelerator) support for the fused MoE (Mixture of Experts) kernel to improve memory access performance on modern GPU architectures.
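For context, a TMA descriptor load fetches a fixed-shape tile from global memory in a single hardware transaction, zero-filling elements that fall outside the tensor's bounds. The plain-Python model below (the helper name is made up for illustration; this is not the Triton API) sketches that tile-load semantic:

```python
def load_tile(matrix, row_off, col_off, block_rows, block_cols):
    """Model of a TMA-style tile load: always returns a fixed-shape
    block_rows x block_cols tile, zero-filling any element whose
    coordinates fall outside the matrix (as the hardware does)."""
    rows = len(matrix)
    cols = len(matrix[0]) if rows else 0
    tile = [[0.0] * block_cols for _ in range(block_rows)]
    for i in range(block_rows):
        for j in range(block_cols):
            r, c = row_off + i, col_off + j
            if r < rows and c < cols:
                tile[i][j] = matrix[r][c]
    return tile

# 3x3 matrix; request a 2x2 tile that hangs off the bottom-right corner.
m = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
tile = load_tile(m, 2, 2, 2, 2)
# tile == [[9.0, 0.0], [0.0, 0.0]]
```

Because the tile shape is fixed and out-of-bounds elements come back as zero, a TMA-based kernel can often skip the per-element bounds masks that a pointer-arithmetic load would need.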

Modifications

Accuracy Tests

Benchmarking and Profiling

This optimization delivers measurable performance improvements in production scenarios:

  • 5% TTFT reduction for DeepSeek-R1 deployment
  • Configuration: 4000 input tokens, 1500 output tokens, TP=8

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello @liusy58, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Tensor Memory Accelerator (TMA) capabilities into the fused Mixture of Experts (MoE) kernel. This enhancement is specifically designed to optimize memory access patterns on modern GPU architectures, leading to improved overall performance. The changes enable conditional use of TMA for loading B tensors and adjust related scaling factor computations, resulting in measurable speedups, such as a 5% TTFT reduction for DeepSeek-R1.

Highlights

  • TMA Support: Introduces Tensor Memory Accelerator (TMA) support for the fused Mixture of Experts (MoE) kernel.
  • Performance Improvement: Aims to enhance memory access performance on modern GPU architectures by leveraging TMA.
  • Benchmarked Gains: Achieves a 5% reduction in Time-to-First-Token (TTFT) for DeepSeek-R1 deployment under specific configurations (4000 input tokens, 1500 output tokens, TP=8).

@gemini-code-assist bot left a comment
Code Review

This pull request introduces TMA (Tensor Memory Accelerator) support to the fused MoE kernel, aiming to enhance memory access performance. The changes are mostly confined to the Triton kernel implementation, adding conditional logic to use TMA for loading weights and their scales. The implementation looks solid, but I've identified a few areas for improvement, mainly concerning code duplication and redundancy, which can be refactored for better readability and maintainability.

Comment on lines +505 to +521
            if even_Ks:
                expert_offset = off_experts.to(tl.int32)
                n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
                k_offset = k_start.to(tl.int32)

                b = b_desc.load([expert_offset, n_offset, k_offset])
                b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
            else:
                expert_offset = off_experts.to(tl.int32)
                n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
                k_offset = k_start.to(tl.int32)

                b = b_desc.load([expert_offset, n_offset, k_offset])
                b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
                k_mask = (k_start + tl.arange(0, BLOCK_SIZE_K)) < K
                b = tl.where(k_mask[None, :], b, 0.0)
            b = b.T
Severity: medium

There's significant code duplication in the if even_Ks: and else: blocks when USE_TMA_B is true. You can refactor this to remove redundancy and improve readability by calculating the offsets once and then conditionally applying the mask.

            expert_offset = off_experts.to(tl.int32)
            n_offset = (pid_n * BLOCK_SIZE_N).to(tl.int32)
            k_offset = k_start.to(tl.int32)

            b = b_desc.load([expert_offset, n_offset, k_offset])
            b = b.reshape([BLOCK_SIZE_N, BLOCK_SIZE_K])
            if not even_Ks:
                k_mask = (k_start + tl.arange(0, BLOCK_SIZE_K)) < K
                b = tl.where(k_mask[None, :], b, 0.0)
            b = b.T
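If it helps to sanity-check the suggestion, the two variants can be modeled in plain Python (list-based stand-ins for the tile ops; the helper names are made up) and compared for both the even and uneven K cases:

```python
def load_b_original(tile, k_start, K, even_Ks):
    # Mirrors the duplicated if/else: both branches perform the same
    # load; only the uneven-K path masks the tail of the K dimension.
    if even_Ks:
        b = list(tile)
    else:
        b = list(tile)
        b = [v if (k_start + i) < K else 0.0 for i, v in enumerate(b)]
    return b

def load_b_refactored(tile, k_start, K, even_Ks):
    # Single load; the mask is applied only when K is not a
    # multiple of the block size.
    b = list(tile)
    if not even_Ks:
        b = [v if (k_start + i) < K else 0.0 for i, v in enumerate(b)]
    return b

tile = [1.0, 2.0, 3.0, 4.0]
for even_Ks, K in [(True, 8), (False, 6)]:
    assert load_b_original(tile, 4, K, even_Ks) == \
           load_b_refactored(tile, 4, K, even_Ks)
```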

            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse
                + offs_bsn * stride_bsn + offs_ks * stride_bsk
            )
            b_scale = tl.load(b_scale_ptrs)
        else:
            k_start = k * BLOCK_SIZE_K
Severity: medium

The variable k_start is calculated at the beginning of the loop (line 491). This recalculation on line 542 is redundant and can be removed to improve clarity.

Comment on lines +544 to +546
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
Severity: medium

For consistency with the TMA path, consider using other=1.0 for loading a_scale. While it may not affect the correctness here since a is zeroed out for masked tokens, using 1.0 as a default for a scale factor is more semantically correct and robust.

                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=1.0
                    )
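As a side note, the reason `other=1.0` is the safer neutral default for a scale factor can be seen with a plain-Python model of the masked-load semantics (the helper name is made up):

```python
def masked_load(values, mask, other):
    """Model of tl.load(ptr, mask=mask, other=other): masked-out
    lanes receive `other` instead of the memory value."""
    return [v if m else other for v, m in zip(values, mask)]

a      = [2.0, 3.0, 5.0]
scales = [0.5, 0.25, 0.1]
mask   = [True, True, False]          # last token is masked out

a_masked = masked_load(a, mask, other=0.0)       # a is zeroed for masked tokens
s_zero   = masked_load(scales, mask, other=0.0)
s_one    = masked_load(scales, mask, other=1.0)

# Products agree here only because a is already zeroed for masked tokens...
assert [x * s for x, s in zip(a_masked, s_zero)] == \
       [x * s for x, s in zip(a_masked, s_one)]

# ...but with other=1.0 the scale is the multiplicative identity, so
# correctness no longer depends on a having been zeroed first:
assert [x * s for x, s in zip(a, s_one)][2] == a[2]
```

In other words, a default of 0.0 silently zeroes anything it multiplies, while 1.0 leaves it unchanged, which matches the intent of a scale factor.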

@xu-yfei (Contributor) commented Sep 24, 2025

The solution is similar to that in PR #10567. In addition to modifying B with TMA, we also applied TMA modifications to input A of the second MoE. We found that the performance improvement after TMA modification of input A is more significant. Moreover, the optimal configuration of TMA will also change.

@Alcanderian (Collaborator) commented Sep 27, 2025

Let's compare the performance with #10567 @liusy58

@Jacki1223 (Contributor)

> The solution is similar to that in PR #10567. In addition to modifying B with TMA, we also applied TMA modifications to input A of the second MoE. We found that the performance improvement after TMA modification of input A is more significant. Moreover, the optimal configuration of TMA will also change.

Thank you for sharing this insight. I have a follow-up question: since you observed a bigger gain from applying TMA to input A alone, why not simply apply TMA to both A and B? In our own experiments we consistently see an additional, non-negligible boost when TMA is enabled on both MoE inputs, and the overhead is minimal.

@xu-yfei (Contributor) commented Oct 10, 2025

>> The solution is similar to that in PR #10567. In addition to modifying B with TMA, we also applied TMA modifications to input A of the second MoE. We found that the performance improvement after TMA modification of input A is more significant. Moreover, the optimal configuration of TMA will also change.
>
> Thank you for sharing this insight. I have a follow-up question: since you observed a bigger gain from applying TMA to input A alone, why not simply apply TMA to both A and B? In our own experiments we consistently see an additional, non-negligible boost when TMA is enabled on both MoE inputs, and the overhead is minimal.

Yes, we have applied TMA to both inputs A and B of the second MoE (down proj), while no TMA was used for the first one (gateup proj). The performance improvement we observed for the first MoE is relatively small (approximately 1%).

@dongyibo

May I ask if there are any plans to merge this? @liusy58

5 participants