Opt fused triton moe: add tma for down proj kernel #10567

Merged: huangtingwei9988 merged 8 commits into sgl-project:main from antgroup:xyf/tune_moe on Oct 28, 2025
Conversation

@xu-yfei (Contributor) commented Sep 17, 2025

Motivation

In H20 (96GB) TP8 prefill, performance analysis showed that the latency of the second MoE in each layer (i.e., the down-projection fused Triton MoE) was comparable to that of the first MoE (the up-projection fused Triton MoE), even though its weight volume and computational cost are only half those of the first MoE. This latency was unreasonable.

For the second MoE (the down-projection MoE), we wrap the accesses to input A and weight B with TMA, and tune and load its configuration independently of the first MoE. As a result, compute utilization of the second MoE increased from 45.20% to 81.12%, and the average latency over 100 samples in an 8K-token scenario decreased by 0.995 ms (from 2.430 ms to 1.435 ms).


Detail

As noted above, the latency of the second MoE in each layer (the down-projection fused Triton MoE) was comparable to that of the first (up-projection) MoE, despite half the weight volume and computational cost, which was unreasonable.
To optimize its performance, we implemented the following measures:

  • Optimized the acquisition and use of b_scale: When BLOCK_SIZE_N is less than or equal to group_n (the block quantization parameter), only one b_scale element needs to be read each time instead of BLOCK_SIZE_N elements. In this case, a_scale (the input quantization scale) can be multiplied by b_scale first, and the combined scale then applied to the dot-product result, reducing computational overhead.

  • Refactored the second MoE based on TMA: Changed the access pattern of input A from discrete gathering to a contiguous layout, adjusting the first dimension from (num_tokens * top_k) to (num_blocks * block_size_m), and wrapped the accesses to input A and weight B with TMA.

  • Optimized configuration using real-world expert distribution: Since expert workloads are typically unevenly distributed, the original tuning configuration based on randomly generated experts did not reflect real-world conditions. We instead used topk_ids generated during actual inference as input to more accurately match real workloads.

  • Tuned the optimal configurations for both MOEs, independently loading different configurations for each.
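The b_scale folding described in the first bullet can be sanity-checked with a small NumPy sketch (an editor-added illustration; tile sizes and values are arbitrary, not tuned kernel configs): when b_scale is a single scalar, folding it into a_scale before scaling the dot product is algebraically identical to broadcasting both scales over the output tile.

```python
import numpy as np

# Illustrative tile sizes only; not the kernel's tuned BLOCK_SIZE_* values.
M, K, N = 4, 8, 16
rng = np.random.default_rng(0)

a = rng.standard_normal((M, K))   # input tile A
b = rng.standard_normal((K, N))   # weight tile B
a_scale = rng.random(M)           # per-row input quantization scales
b_scale = 0.5                     # single scalar weight scale (BLOCK_SIZE_N <= group_n case)

# Reference: broadcast both scales over the full M x N dot-product tile.
ref = np.dot(a, b) * a_scale[:, None] * b_scale

# Optimization: fold the scalar b_scale into a_scale once (an M-vector multiply),
# then apply one combined per-row scale to the dot result.
opt = np.dot(a, b) * (a_scale * b_scale)[:, None]

assert np.allclose(ref, opt)
```

The saving comes from replacing a BLOCK_SIZE_N-wide scale load and broadcast with a single scalar read plus one extra multiply on a_scale.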

Tuning Process:

  1. Edit the code file srt/models/deepseek_v2.py in the Python site package and add the logic for saving topk_ids:
# import get_tensor_model_parallel_rank
# DeepseekV2MoE::forward_normal
if hidden_states.shape[0] == 16384 and get_tensor_model_parallel_rank() == 0:
    topk_ids_dir = xxxx
    if not hasattr(self, "save_idx"):
        self.save_idx = 0
    if self.save_idx <= 1:
        # filename must match what the tuning script loads (topk_ids_layer...)
        torch.save(topk_output.topk_ids, f"{topk_ids_dir}/topk_ids_layer{self.layer_id}_idx{self.save_idx}.pt")
    self.save_idx += 1
  2. Set the chunked-prefill size to 16384 and send a request with a longer context length to the server;
  3. Stop the server and perform tuning. This will generate two files: one for upsampling and the other for downsampling (ending with _down);
model_path=/home/deepseek-ai__DeepSeek-R1
topk_ids_dir=xxxxx

python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model $model_path \
    --tp-size 8 \
    --dtype fp8_w8a8 \
    --topk-ids-dir ${topk_ids_dir} \
    --tune
  4. Replace the configurations in the directory srt/layers/moe/fused_moe_triton/configs/triton_3_4_0.

Benchmark a single config:

model_path=/home/deepseek-ai__DeepSeek-R1
topk_ids_dir=xxxxx
cfg=$1
bs=$2

python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py \
    --model $model_path \
    --tp-size 8 \
    --dtype fp8_w8a8 \
    --topk-ids-dir ${topk_ids_dir} \
    --configs ${cfg} \
    --batch-size ${bs}

Modifications

Accuracy Tests

# gsm8k
Accuracy: 0.951
Invalid: 0.000
Latency: 241.414 s
Output throughput: 531.216 token/s

Accuracy: 0.953
Invalid: 0.000
Latency: 241.295 s
Output throughput: 531.077 token/s
# mmlu
subject: abstract_algebra, #q:100, acc: 0.740
subject: anatomy, #q:135, acc: 0.852
subject: astronomy, #q:152, acc: 0.941
subject: business_ethics, #q:100, acc: 0.880
subject: clinical_knowledge, #q:265, acc: 0.928
subject: college_biology, #q:144, acc: 0.979
subject: college_chemistry, #q:100, acc: 0.630
subject: college_computer_science, #q:100, acc: 0.830
subject: college_mathematics, #q:100, acc: 0.770
subject: college_medicine, #q:173, acc: 0.879
subject: college_physics, #q:102, acc: 0.843
subject: computer_security, #q:100, acc: 0.880
subject: conceptual_physics, #q:235, acc: 0.932
subject: econometrics, #q:114, acc: 0.737
subject: electrical_engineering, #q:145, acc: 0.876
subject: elementary_mathematics, #q:378, acc: 0.937
subject: formal_logic, #q:126, acc: 0.802
subject: global_facts, #q:100, acc: 0.670
subject: high_school_biology, #q:310, acc: 0.955
subject: high_school_chemistry, #q:203, acc: 0.852
subject: high_school_computer_science, #q:100, acc: 0.950
subject: high_school_european_history, #q:165, acc: 0.891
subject: high_school_geography, #q:198, acc: 0.960
subject: high_school_government_and_politics, #q:193, acc: 0.984
subject: high_school_macroeconomics, #q:390, acc: 0.918
subject: high_school_mathematics, #q:270, acc: 0.763
subject: high_school_microeconomics, #q:238, acc: 0.971
subject: high_school_physics, #q:151, acc: 0.841
subject: high_school_psychology, #q:545, acc: 0.972
subject: high_school_statistics, #q:216, acc: 0.861
subject: high_school_us_history, #q:204, acc: 0.956
subject: high_school_world_history, #q:237, acc: 0.945
subject: human_aging, #q:223, acc: 0.848
subject: human_sexuality, #q:131, acc: 0.931
subject: international_law, #q:121, acc: 0.950
subject: jurisprudence, #q:108, acc: 0.907
subject: logical_fallacies, #q:163, acc: 0.920
subject: machine_learning, #q:112, acc: 0.804
subject: management, #q:103, acc: 0.922
subject: marketing, #q:234, acc: 0.949
subject: medical_genetics, #q:100, acc: 0.950
subject: miscellaneous, #q:783, acc: 0.950
subject: moral_disputes, #q:346, acc: 0.882
subject: moral_scenarios, #q:895, acc: 0.774
subject: nutrition, #q:306, acc: 0.915
subject: philosophy, #q:311, acc: 0.897
subject: prehistory, #q:324, acc: 0.941
subject: professional_accounting, #q:282, acc: 0.865
subject: professional_law, #q:1534, acc: 0.698
subject: professional_medicine, #q:272, acc: 0.960
subject: professional_psychology, #q:612, acc: 0.913
subject: public_relations, #q:110, acc: 0.845
subject: security_studies, #q:245, acc: 0.890
subject: sociology, #q:201, acc: 0.965
subject: us_foreign_policy, #q:100, acc: 0.930
subject: virology, #q:166, acc: 0.590
subject: world_religions, #q:171, acc: 0.936
Total latency: 778.113
Average accuracy: 0.871

Benchmarking and Profiling

With a chunked-prefill size of 16372, sending one request with a context length of 16372 × 3 + 502:
Before PR

| Name | Wall duration | Avg wall duration | Occurrences |
| --- | --- | --- | --- |
| fused_moe_kernel | 1.525s | 4.383ms | 348 |

After PR

| Name | Wall duration | Avg wall duration | Occurrences |
| --- | --- | --- | --- |
| fused_moe_kernel | 706.7ms | 4.061ms | 174 |
| fused_moe_down_tma_kernel | 430.0ms | 2.471ms | 174 |

TTFT of a single request with different lengths

export SGL_ENABLE_JIT_DEEPGEMM=1
export TORCHINDUCTOR_CACHE_DIR=/home/admin/inductor_root_cache
export SGLANG_TORCH_PROFILER_DIR=/home/admin/torch_profiler
export SGL_CHUNKED_PREFIX_CACHE_USE_TUNED=1
model_path=/home/deepseek-ai__DeepSeek-R1

nohup python3 -m sglang.launch_server --model-path $model_path \
--host 0.0.0.0 --port 8000 --trust-remote-code \
--enable-cache-report --quantization fp8 --log-level info \
--max-running-requests 16 \
--mem-fraction-static 0.92 --chunked-prefill-size 16372 \
--context-length 65535 --chat-template /home/r1.jinja \
--attention-backend flashinfer \
--tp-size 8 --enable-metrics --cuda-graph-max-bs 16 \
--disable-radix-cache
for((i=1;i<=16;i++)); do
input_len=$((i*512))
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input $input_len --random-output 1 --request-rate 1000 --num-prompt 100 --random-range-ratio 1 --max-concurrency 1  --port 8000 --dataset-path /home/ShareGPT_V3_unfiltered_cleaned_split.json
done
| Input Len | Before PR Throughput (req/s) | Before PR Mean TTFT (ms) | After PR Throughput (req/s) | After PR Mean TTFT (ms) | TTFT Reduction |
| --- | --- | --- | --- | --- | --- |
| 512 | 8.6 | 115.78 | 8.82 | 112.86 | 2.52% |
| 1024 | 7.19 | 138.63 | 7.8 | 127.31 | 8.17% |
| 1536 | 5.59 | 178.45 | 6.1 | 163.26 | 8.51% |
| 2048 | 4.55 | 219.31 | 4.96 | 200.5 | 8.58% |
| 2560 | 3.91 | 255.08 | 4.3 | 232.13 | 9.00% |
| 3072 | 3.32 | 300.59 | 3.64 | 274.11 | 8.81% |
| 3584 | 2.89 | 341.83 | 3.17 | 314.86 | 7.89% |
| 4096 | 2.54 | 393.37 | 2.78 | 358.95 | 8.75% |
| 4608 | 2.29 | 436.66 | 2.51 | 398.52 | 8.73% |
| 5120 | 2.05 | 487.74 | 2.24 | 446.47 | 8.46% |
| 5632 | 1.87 | 530.16 | 2.03 | 491.93 | 7.21% |
| 6144 | 1.71 | 585.58 | 1.86 | 537.3 | 8.24% |
| 6656 | 1.57 | 635.68 | 1.71 | 583.92 | 8.14% |
| 7168 | 1.45 | 690.08 | 1.58 | 634.02 | 8.12% |
| 7680 | 1.35 | 727.69 | 1.46 | 670.21 | 7.90% |
| 8192 | 1.25 | 801.73 | 1.35 | 738.52 | 7.88% |

Checklist

@gemini-code-assist bot left a comment:

Summary of Changes

Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers a significant performance upgrade for the fused Triton Mixture of Experts (MoE) architecture, focusing on the second (down-projection) MoE layer. By implementing Tensor Memory Accelerator (TMA) for efficient data movement and refining configuration tuning with real-world data, the changes markedly improve compute utilization and reduce inference latency. This brings the second MoE's latency in line with its computational cost, leading to more efficient overall model execution.

Highlights

  • Performance Optimization for Second MoE: The pull request significantly optimizes the performance of the second Mixture of Experts (MoE) layer, the down-projection fused Triton MoE, which previously exhibited disproportionately high latency.
  • TMA Integration: Tensor Memory Accelerator (TMA) has been integrated for input A and weight B access in the second MoE, enhancing data movement efficiency.
  • Improved Computational Utilization: The computational utilization of the second MoE has increased from 45.20% to 81.12% due to these optimizations.
  • Latency Reduction: Average latency for 100 samples in an 8K tokens scenario decreased by 0.995 ms (from 2.430 ms to 1.435 ms), demonstrating a substantial speedup.
  • Refined Tuning Process: The tuning process now uses real-world expert distribution (topk_ids from actual inference) and independently tunes and loads optimal configurations for both downsampling and upsampling MoEs.

@xu-yfei (Author) commented Sep 17, 2025

@BBuf Could you review this?

@gemini-code-assist bot left a comment:

Code Review

This pull request introduces a significant performance optimization for the fused MoE down-projection kernel by leveraging TMA (Tensor Memory Accelerator), resulting in a substantial latency reduction. The changes are well-structured, including a new benchmark script for separate tuning of MoE kernels, updated Triton kernel configurations, and modifications to the MoE implementation to integrate the new TMA kernel. While the performance gains are impressive, I've identified a couple of critical correctness issues in the new Triton kernels and some areas for improvement in the benchmark script to enhance accuracy and maintainability.

Comment on lines +516 to +519
    if BLOCK_SIZE_N > group_n:
        accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
    else:
        accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])

critical

When BLOCK_SIZE_N <= group_n, b_scale is loaded as a scalar for each K-group. Applying [None, :] indexing to a scalar is invalid in Triton. It seems it should be just b_scale, which would broadcast correctly.

Suggested change:

    if BLOCK_SIZE_N > group_n:
        accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
    else:
-       accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
+       accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
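For intuition, here is an analogous NumPy illustration of the broadcasting point (an editor-added sketch; Triton's semantics for 0-d values are similar but not identical): indexing a 0-d scalar with [None, :] raises, while multiplying by the scalar broadcasts correctly.

```python
import numpy as np

acc = np.ones((2, 3))        # stands in for the accumulator tile
b_scale = np.array(0.5)      # 0-d "scalar", like a single loaded scale element

# Indexing a 0-d value with [None, :] is invalid: there is no axis to slice.
try:
    _ = acc * b_scale[None, :]
    index_ok = True
except IndexError:
    index_ok = False
assert not index_ok

# Plain scalar multiplication broadcasts over the whole tile as intended.
assert np.allclose(acc * b_scale, np.full((2, 3), 0.5))
```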

Comment on lines +869 to +872
    if BLOCK_SIZE_N > group_n:
        accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
    else:
        accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])

critical

Similar to a previous comment, when BLOCK_SIZE_N <= group_n, b_scale is loaded as a scalar. Applying [None, :] indexing to a scalar is invalid in Triton. It should probably be just b_scale, which would broadcast correctly.

Suggested change:

    if BLOCK_SIZE_N > group_n:
        accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
    else:
-       accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale[None, :])
+       accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)


    def prepare(i: int):
        input_gating = gating_output[i]
        topk_ids = torch.load(f"{topk_ids_dir}/topk_ids_layer{i%58+3}_idx{i//58}.pt")

medium

The numbers 58 and 3 are magic numbers. It would be better to define them as constants with explanatory names (e.g., NUM_MOE_LAYERS and FIRST_MOE_LAYER_IDX) at the top of the file. This improves readability and maintainability.

Comment on lines +271 to +282
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()
        up_moe_use_tma = False
        t0, t1 = run()
        torch.cuda.synchronize()
        up_moe_use_tma = True
        t0_x, t_tma = run()
        torch.cuda.synchronize()
        latencies.append(t0)
        latencies1.append(t1)
        latencies_tma.append(t_tma)

medium

The latency of the first kernel (invoke_fused_moe_kernel) when up_moe_use_tma is true (t0_x at line 278) is discarded. The tuning for this kernel is always based on t0 (from up_moe_use_tma = False), which corresponds to c_sorted=False. However, at runtime, if TMA is used for the up-kernel, the first kernel will be called with c_sorted=True. This means the configuration used for the first kernel might be suboptimal as it was tuned for c_sorted=False.

To fix this, you should also record t0_x and use it for tuning the first kernel when TMA is chosen for the up-kernel.

    self.time_cost_all = None  # kernel0, kernel1 without tma, kernel1 with tma

def update(self, config, time_cost, time_cost_all):
    if time_cost < self.time_cost:

medium

Using < might cause the tuning to be sensitive to the order of configurations in the search space. If two configurations have the same performance, only the first one encountered will be kept. Using <= is more robust and makes the choice of the best config more stable, preferring later configurations in case of a tie.

Suggested change:

-   if time_cost < self.time_cost:
+   if time_cost <= self.time_cost:

@soyail commented Sep 18, 2025

Great work! I noticed that the Qwen3-MoE model also uses the same MoE operator. Would this modification bring benefits to Qwen3-MoE as well?
Additionally, when profiling Qwen3-MoE, I found that the execution time of the first fused_moe_kernel and the second fused_moe_kernel scales proportionally with their computation size, which seems inconsistent with the phenomenon you described in DeepSeek-R1. I am very curious about the reason behind this discrepancy and would be happy to discuss further.

@xu-yfei (Author) commented Sep 18, 2025

> Great work! I noticed that the Qwen3-MoE model also uses the same MoE operator. Would this modification bring benefits to Qwen3-MoE as well? Additionally, when profiling Qwen3-MoE, I found that the execution time of the first fused_moe_kernel and the second fused_moe_kernel scales proportionally with their computation size, which seems inconsistent with the phenomenon you described in DeepSeek-R1. I am very curious about the reason behind this discrepancy and would be happy to discuss further.

What is the maximum token length and TP size in your test? We found that when the token length is large (e.g., 8K), the performance of the second MOE is similar to that of the first one. A possible reason is that after TP partitioning, the K dimension of the second MOE is relatively small (2048/tp8 = 256), while the N dimension is quite large (7168).

@soyail commented Sep 18, 2025

> > Great work! I noticed that the Qwen3-MoE model also uses the same MoE operator. Would this modification bring benefits to Qwen3-MoE as well? Additionally, when profiling Qwen3-MoE, I found that the execution time of the first fused_moe_kernel and the second fused_moe_kernel scales proportionally with their computation size, which seems inconsistent with the phenomenon you described in DeepSeek-R1. I am very curious about the reason behind this discrepancy and would be happy to discuss further.
>
> What is the maximum token length and TP size in your test? We found that when the token length is large (e.g., 8K), the performance of the second MOE is similar to that of the first one. A possible reason is that after TP partitioning, the K dimension of the second MOE is relatively small (2048/tp8 = 256), while the N dimension is quite large (7168).

Thanks for your reply. I was testing the Qwen3-30B-A3B model with an input length of only 1024 and tp-size=1, which might be the reason for the issue. Next, I will test the Qwen3-235B model with tp-size=4 to see if there will be any differences.

@xu-yfei xu-yfei changed the title Fused triton moe opt: add tma for fused moe up kernel Fused triton moe opt: add tma for fused moe down kernel Sep 24, 2025
@xu-yfei xu-yfei changed the title Fused triton moe opt: add tma for fused moe down kernel Opt fused triton moe: add tma for down proj kernel Sep 25, 2025
@Alcanderian (Collaborator) commented:

Is there any possibility of merging this kernel into the original fused MoE kernel?

@Zhiy-Zhang (Contributor) commented:

@xu-yfei Why is the TMA optimization applied only to the down_proj? Can the up_proj achieve similar acceleration effects as well?

@xu-yfei (Author) commented Sep 28, 2025

> @xu-yfei Why is the TMA optimization applied only to the down_proj? Can the up_proj achieve similar acceleration effects as well?

The weights of up_proj can use TMA, but my local verification shows the improvement is not significant (roughly 1~2%, well under 10%). Since input A is hidden_states, its values are gathered discretely by expert during computation, so TMA may not currently be applicable to input A.

For down_proj, I have modified the distribution of input A: when calculating the output of up_proj, values are written in blocks corresponding to experts, which enables down_proj to fetch values using TMA.
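The layout change can be sketched in plain Python (an editor-added sketch; the function name and return format are illustrative, not the actual moe_align_block_size implementation): token-expert assignments are grouped by expert and padded to a multiple of BLOCK_SIZE_M, so every contiguous block of BLOCK_SIZE_M rows in the up_proj output belongs to a single expert and can be fetched by down_proj as one contiguous TMA tile.

```python
from collections import defaultdict

def align_block_size(topk_ids, block_size_m, num_experts):
    """Group flat (token, top-k) row indices by expert, padding each expert's
    rows to a multiple of block_size_m; -1 marks padding rows."""
    per_expert = defaultdict(list)
    for flat_idx, expert in enumerate(topk_ids):
        per_expert[expert].append(flat_idx)
    sorted_rows = []
    for e in range(num_experts):
        rows = per_expert.get(e, [])
        pad = (-len(rows)) % block_size_m  # pad each expert to a block multiple
        sorted_rows.extend(rows + [-1] * pad)
    return sorted_rows

rows = align_block_size([2, 0, 2, 1, 0, 2], block_size_m=4, num_experts=3)
# Each 4-row block now holds rows of exactly one expert (plus -1 padding),
# giving the (num_blocks * block_size_m) first dimension mentioned above.
assert rows == [1, 4, -1, -1, 3, -1, -1, -1, 0, 2, 5, -1]
```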

@xu-yfei (Author) commented Sep 29, 2025

> Are there any possibility to merge the kernel into the origin fused moe kernel?

I'm currently attempting to merge them. After merging, performance drops significantly, though precision remains normal. I'm investigating the cause; the extra branches may be responsible for the degradation. After removing some branches, performance has recovered considerably.

@zhyncs (Collaborator) commented Sep 29, 2025

@xu-yfei
Copy link
Contributor Author

xu-yfei commented Oct 9, 2025

> Are there any possibility to merge the kernel into the origin fused moe kernel?

@Alcanderian Done. The performance degradation in the previously merged code stems from two factors:

  1. The uncertainty introduced by `if off_experts == -1` (used only in the EP scenario), which is now skipped in the TP scenario via the `filter_expert` input parameter;

  2. The loop `for k in range(0, tl.cdiv(K, BLOCK_SIZE_K))`, which has been changed to `for k_start in range(0, K, BLOCK_SIZE_K)`.
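The loop rewrite in point 2 can be checked in plain Python (an editor-added sketch, not the Triton kernel itself): both forms visit exactly the same K-tile start offsets, including when K is not a multiple of BLOCK_SIZE_K.

```python
import math

def tile_starts_cdiv(K, block_k):
    # Original form: iterate a tile counter derived from ceil-division.
    return [k * block_k for k in range(0, math.ceil(K / block_k))]

def tile_starts_range(K, block_k):
    # Rewritten form: step directly over the start offsets.
    return list(range(0, K, block_k))

# Divisible and ragged cases both produce identical tile starts.
for K, block_k in [(7168, 128), (1000, 128)]:
    assert tile_starts_cdiv(K, block_k) == tile_starts_range(K, block_k)
```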

@dongyibo commented:

    if a_desc is not None:
        assert use_fp8_w8a8 and group_n > 0 and group_k > 0
    start_offs_m = pid_m * BLOCK_SIZE_M

@xu-yfei Hello, in fused_moe_kernel: is it necessary to have fp8_w8a8 enabled in order to use a_desc?

@xu-yfei (Author) commented Oct 13, 2025

> if a_desc is not None: assert use_fp8_w8a8 and group_n > 0 and group_k > 0 start_offs_m = pid_m * BLOCK_SIZE_M
>
> @xu-yfei hello, in fused_moe_kernel: Is it necessary to have fp8_w8a8 enabled in order to use a_desc?

@dongyibo Currently, only the fp8_w8a8 scenario has been verified, while other scenarios remain unverified.

@xu-yfei (Author) commented Oct 13, 2025

@zhyncs Ready to merge. Currently, only H20 has the new tuning configuration. This feature is enabled only when the relevant configuration exists. The failed AMD test cases should be unrelated to this feature.

@Jacki1223 (Contributor) commented:

Hi! May I ask if you have measured the time consumption of fused_moe_kernel after applying TMA to input A? I found that it takes more time than without the TMA optimization.

@xu-yfei (Author) commented Oct 21, 2025

> Hi! may I ask if you have measured the time consumption of fused_moe_kernel after applying TMA to input A? I found that it takes more time than without using TMA optimization.

What is your device? How many tokens do you use, and what's the configuration?

@Jacki1223 (Contributor) commented:

> > Hi! may I ask if you have measured the time consumption of fused_moe_kernel after applying TMA to input A? I found that it takes more time than without using TMA optimization.
>
> What is your device? How many tokens do you use, and what's the configuration?

My device is an H200, testing num_tokens from 1 to 16384, with the model config from Qwen3-30B-A3B. Confusingly, although the kernel takes longer, using TMA still shows a throughput improvement in end-to-end inference.

@xu-yfei (Author) commented Oct 21, 2025

> > > Hi! may I ask if you have measured the time consumption of fused_moe_kernel after applying TMA to input A? I found that it takes more time than without using TMA optimization.
> >
> > What is your device? How many tokens do you use, and what's the configuration?
>
> My device is H200, testing num_tokens from 1-16384, and the model config come from Qwen3-30B-A3B. However, what is confusing is that although the kernel takes longer, using TMA can still show throughput improvement in end-to-end inference.

Which script did you use to verify the kernel performance? The optimal configuration when using TMA is actually different from that when not using TMA, and the performance of the optimal TMA configuration can be much better.

On H20 (96GB), using benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py:

| configs | tokens | gateup proj (us) | gateup proj with TMA (us) | down proj (us) | down proj with TMA (us) |
| --- | --- | --- | --- | --- | --- |
| "64 128 128 1 4 3" | 8192 | 2311 | 2290 | 2164 | 1456 |
| "64 128 128 32 4 3" | 8192 | 2311 | 2292 | 2204 | 1391 |
| "64 128 128 16 4 3" | 8192 | 2309 | 2289 | 2172 | 1377 |
model_path=xxx/deepseek-ai__DeepSeek-R1

cfg=$1

#python tuning_fused_moe_triton.py \
python tuning_fused_moe_triton_sep.py \
    --model $model_path \
    --tp-size 8 \
    --seed 128 \
    --dtype fp8_w8a8 \
    --batch-size $2 \
    --topk-ids-dir $topk_ids_dir/ \
    --configs ${cfg}

@huangtingwei9988 huangtingwei9988 merged commit d2b8c41 into sgl-project:main Oct 28, 2025
127 of 143 checks passed
@lmu97 commented Nov 3, 2025

Hi @xu-yfei, thank you for sharing this great work. When I want to generate topk_ids based on real data, is one query enough, or do I need to compute statistical top-k over my dataset?

@xu-yfei (Author) commented Nov 3, 2025

> Hi, @xu-yfei, thank you for sharing this great work. When I want to generate topk_ids based on real data, is one query enough? Or I need to calculate statistic top-k in my dataset?

@lmu97 My approach is simply to use a single long query.

@thincal commented Nov 23, 2025

@xu-yfei I have one question, if topk_ids are tightly coupled with specific scenarios for better perf, how can we save a kernel configuration for different scenarios within a single kernel config file? thanks.

@thincal commented Nov 24, 2025

@xu-yfei From the tuned config file, it seems USE_TMA is only applied for very large batch sizes (e.g. >= 1024). For a long query with a normal batch size (such as 4 or 8), could it benefit from this optimization?

@xu-yfei (Author) commented Nov 24, 2025

> @xu-yfei I have one question, if topk_ids are tightly coupled with specific scenarios for better perf, how can we save a kernel configuration for different scenarios within a single kernel config file? thanks.

The current solution uses topk_ids data with uneven expert distribution, capturing only one long input and selecting 100 topk_ids entries from 58 layers. While a dataset-based approach might be more reasonable, the volume of data required for tuning is excessively large.

> @xu-yfei from the tuned config file, it seems that USE_TMA only applied for very large batch size (e.g. >= 1024), so for the long query with normal batch size (such as 4, 8) could it benefit from this optimization ?

Based on the current implementation, the tuning results indicate that TMA yields no benefits in this scenario.

@thincal commented Nov 24, 2025

@xu-yfei thanks for your detailed explanation.

@dongyibo commented:
Hello, may I ask why the following is required: `config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"]`.
Is this necessary? @xu-yfei

@xu-yfei (Author) commented Nov 25, 2025

> Hello, may I ask why the following is required: config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"] Is this necessary? @xu-yfei

In the TMA implementation, BLOCK_SIZE_M is tied to the data layout of the first MoE's output and the second MoE's input. Within the same contiguous block of length BLOCK_SIZE_M, the experts are identical and can be loaded via TMA. Strictly speaking, it suffices that the BLOCK_SIZE_M of the first MoE is a multiple of that of the second MoE: `config["BLOCK_SIZE_M"] % down_config["BLOCK_SIZE_M"] == 0`.
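The relaxed constraint amounts to a one-line divisibility check (an editor-added sketch with illustrative names, not code from the PR): if the up kernel's BLOCK_SIZE_M is a multiple of the down kernel's, each up block splits cleanly into sub-blocks that still contain a single expert.

```python
def block_sizes_compatible(up_block_m, down_block_m):
    # Each contiguous up-block must split into whole down-blocks so that a
    # down-block never straddles rows belonging to two different experts.
    return up_block_m % down_block_m == 0

assert block_sizes_compatible(64, 64)       # equal sizes: the check currently enforced
assert block_sizes_compatible(128, 64)      # a multiple also preserves single-expert blocks
assert not block_sizes_compatible(64, 48)   # misaligned sub-blocks could straddle experts
```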

@dongyibo commented:
> > Hello, may I ask why the following is required: config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"] Is this necessary? @xu-yfei
>
> In the TMA implementation, BLOCK_SIZE_M is related to the data layout of the first MoE's output and the second MoE's input. Within the same contiguous block of length BLOCK_SIZE_M, the experts are identical and can be loaded via TMA. Strictly speaking, it suffices that the BLOCK_SIZE_M of the first MoE is a multiple of that of the second MoE: config["BLOCK_SIZE_M"] % down_config["BLOCK_SIZE_M"] == 0.


Thank you for the explanation; I understand now.
I also found that the moe_align_block_size function requires BLOCK_SIZE_M, so up and down must use the same BLOCK_SIZE_M due to this constraint. @xu-yfei
