
Add SwapAB Optimization for triton fused_moe_kernel on SM90 #15712

Merged
Fridge003 merged 4 commits into sgl-project:main from Insideyyy:fused_moe_swapAB
Jan 7, 2026

Conversation

@Insideyyy
Contributor

@Insideyyy Insideyyy commented Dec 24, 2025

Motivation

When the M dimension is small and fp8_w8a8 is used on SM90, SwapAB brings a significant benefit: transposing inputs A and B lets the kernel make better use of WGMMA instructions.

Modifications

SwapAB is enabled when all of the following conditions hold:

  • use_fp8_w8a8 is set
  • the GPU is SM90 (Hopper)
  • config["BLOCK_SIZE_M"] < 64 and config["BLOCK_SIZE_N"] >= 64
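
The enablement check above can be sketched as a small predicate. This is a hypothetical helper for illustration; the actual function name and how SM90 is detected in the kernel launcher are assumptions, not the PR's code.

```python
def should_swap_ab(use_fp8_w8a8: bool, is_sm90: bool, config: dict) -> bool:
    """SwapAB only pays off for small M tiles paired with wide N tiles."""
    return (
        use_fp8_w8a8
        and is_sm90
        and config["BLOCK_SIZE_M"] < 64
        and config["BLOCK_SIZE_N"] >= 64
    )
```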

If SwapAB is enabled, `a`, `b`, `a_scale`, `b_scale`, and the accumulator are transposed before `tl.dot()`, and the accumulator is transposed back after the K iterations.
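
The correctness of the swap rests on the transpose identity (A·B)ᵀ = Bᵀ·Aᵀ: computing the product with swapped, transposed operands and transposing the result back reproduces A·B. A minimal NumPy check of the identity (not the Triton kernel itself), with tile shapes chosen to mimic a small-M case:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((16, 256))   # small M tile: M=16, K=256
b = rng.standard_normal((256, 128))  # K=256, N=128

ref = a @ b                # normal orientation
swapped = (b.T @ a.T).T    # SwapAB: swap + transpose operands, transpose result back

assert np.allclose(ref, swapped)
```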

Accuracy Tests

Before this PR:


$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 8 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:52<00:00,  3.82it/s]
Accuracy: 0.965
Invalid: 0.000
Latency: 52.392 s
Output throughput: 476.100 token/s

$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 32 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:26<00:00,  7.46it/s]
Accuracy: 0.945
Invalid: 0.000
Latency: 26.948 s
Output throughput: 902.068 token/s

$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 128 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:13<00:00, 14.32it/s]
Accuracy: 0.965
Invalid: 0.000
Latency: 14.052 s
Output throughput: 1760.288 token/s

After this PR:

$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 8 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:48<00:00,  4.17it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 48.143 s
Output throughput: 517.128 token/s

$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 32 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:24<00:00,  8.21it/s]
Accuracy: 0.940
Invalid: 0.000
Latency: 24.458 s
Output throughput: 1012.738 token/s

$ python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 128 --num-shots 8 --port 8188
100%|██████████| 200/200 [00:12<00:00, 15.60it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 12.973 s
Output throughput: 1891.012 token/s

Benchmarking and Profiling

We tested GLM-4.6V-FP8 on H20-3e. MoE configs were tuned with the benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py script.

Fused MoE module

Fused MoE kernel latency (ms):

| batch_size | Before this PR | After this PR | Speedup |
|-----------:|---------------:|--------------:|--------:|
| 1          | 0.058496       | 0.051936      | 1.126   |
| 2          | 0.069696       | 0.061696      | 1.130   |
| 4          | 0.104000       | 0.080864      | 1.286   |
| 8          | 0.141856       | 0.123984      | 1.144   |
| 32         | 0.247232       | 0.190592      | 1.297   |
| 128        | 0.270624       | 0.211360      | 1.280   |
| 256        | 0.343616       | 0.250176      | 1.373   |
| 512        | 0.358432       | 0.326848      | 1.097   |
| 1024       | 0.509120       | 0.494720      | 1.029   |
| 2048       | 0.840432       | 0.841968      | 0.998   |
| 4096       | 1.542496       | 1.544416      | 0.999   |
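
For reference, the speedup column is just the ratio of the two latencies. Recomputing the batch_size = 1 row:

```python
# speedup = latency before this PR / latency after this PR
before_ms, after_ms = 0.058496, 0.051936  # batch_size = 1 row from the table
speedup = before_ms / after_ms
print(f"{speedup:.3f}")  # ~1.126, matching the table
```

Note that speedup drops below 1.0 only for the largest batches (2048, 4096), where M is no longer small enough for SwapAB's tile shape to help.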

End to end

Setup

server:

MODEL=/root/GLM-4.6V-FP8
python3 -m sglang.launch_server \
        --model-path $MODEL \
        --host 127.0.0.1 \
        --port 8188 \
        --trust-remote-code \
        --mem-fraction-static 0.8 \
        --attention-backend flashinfer \
        --tp-size 4

client:

python3 -m sglang.bench_serving \
        --backend sglang \
        --host 127.0.0.1 \
        --port 8188 \
        --max-concurrency 8 \
        --dataset-name sharegpt \
        --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
        --num-prompt 100

Performance

Before this PR:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 8         
Successful requests:                     100       
Benchmark duration (s):                  44.23     
Total input tokens:                      33279     
Total input text tokens:                 33279     
Total input vision tokens:               0         
Total generated tokens:                  21392     
Total generated tokens (retokenized):    21366     
Request throughput (req/s):              2.26      
Input token throughput (tok/s):          752.34    
Output token throughput (tok/s):         483.61    
Peak output token throughput (tok/s):    648.00    
Peak concurrent requests:                14        
Total token throughput (tok/s):          1235.95   
Concurrency:                             7.55      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   3341.50   
Median E2E Latency (ms):                 2437.68   
---------------Time to First Token----------------
Mean TTFT (ms):                          124.87    
Median TTFT (ms):                        90.04     
P99 TTFT (ms):                           630.39    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.80     
Median TPOT (ms):                        14.51     
P99 TPOT (ms):                           20.78     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           15.11     
Median ITL (ms):                         12.51     
P95 ITL (ms):                            14.55     
P99 ITL (ms):                            76.87     
Max ITL (ms):                            734.65    
==================================================

After this PR:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 8         
Successful requests:                     100       
Benchmark duration (s):                  38.13     
Total input tokens:                      33279     
Total input text tokens:                 33279     
Total input vision tokens:               0         
Total generated tokens:                  21392     
Total generated tokens (retokenized):    21385     
Request throughput (req/s):              2.62      
Input token throughput (tok/s):          872.76    
Output token throughput (tok/s):         561.01    
Peak output token throughput (tok/s):    720.00    
Peak concurrent requests:                15        
Total token throughput (tok/s):          1433.77   
Concurrency:                             7.53      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2870.34   
Median E2E Latency (ms):                 2057.39   
---------------Time to First Token----------------
Mean TTFT (ms):                          95.47     
Median TTFT (ms):                        86.83     
P99 TTFT (ms):                           149.95    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.94     
Median TPOT (ms):                        12.81     
P99 TPOT (ms):                           18.61     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           13.03     
Median ITL (ms):                         11.16     
P95 ITL (ms):                            13.22     
P99 ITL (ms):                            74.15     
Max ITL (ms):                            80.46     
==================================================



@ClawSeven
Collaborator

/tag-and-rerun-ci

@Fridge003
Collaborator

/rerun-failed-ci

@Fridge003 Fridge003 merged commit ee4d228 into sgl-project:main Jan 7, 2026
317 of 378 checks passed
michaelzhang-ai added a commit to michaelzhang-ai/sglang that referenced this pull request Jan 7, 2026
@sunxxuns
Collaborator

sunxxuns commented Jan 7, 2026

This PR didn't pass AMD CI and caused a failure. Please avoid such cases for community sharing.

@Fridge003
Collaborator

Hi @Insideyyy, it seems this PR breaks some AMD CIs:
https://github.com/sgl-project/sglang/actions/runs/20688902361/job/59402606297
https://github.com/sgl-project/sglang/actions/runs/20791025573/job/59713014224?pr=11349

So we reverted it temporarily. Can you make a fix?

@Insideyyy
Contributor Author

@Fridge003 Sorry for causing trouble. I'll make a fix.
