[FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton #39083
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Code Review
This pull request introduces a fused Triton routing kernel for the Gemma4 MoE model, optimizing the routing process on CUDA-compatible platforms. The changes include the implementation of the _gemma4_routing_kernel, a corresponding Python wrapper, and a comprehensive test suite to ensure parity with the PyTorch reference implementation. Review feedback recommends using the tl.LOG2E constant for better readability, increasing the default num_warps to 4 to improve GPU resource utilization, and updating the docstring to remove a copy-paste artifact.
    gating_output, topk, per_expert_scale
)

return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)
What happens if you just add torch.compile to this function?
Let me try. However, in the microbenchmark, I found that the Triton kernel is still faster than the torch.compile implementations.
| GPU | vs Eager | vs compile-default | vs compile-reduce-overhead | vs compile-max-autotune |
|---|---|---|---|---|
| A100 | 12.68× | 4.50× | 4.97× | 4.45× |
| H100 | 10.58× | 4.51× | 5.02× | 4.41× |
| MI300X | 14.37× | 6.00× | 7.24× | 5.81× |
| B60 | 6.33× | 1.08× | 1.07× | 1.05× |
I ran the test with two different torch.compile modes on MI300X:
compiled_routing_function_torch = torch.compile(routing_function_torch)
and
compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')
| Metric | Torch | Torch Compile (default) | Torch Compile (max-autotune) | Triton Kernel |
|---|---|---|---|---|
| Benchmark Duration (s) | 77.99 | 70.49 | 69.46 | 64.18 |
| Request Throughput (req/s) | 2.05 | 2.27 | 2.30 | 2.49 |
| Output Token Throughput (tok/s) | 1,050 | 1,162 | 1,179 | 1,276 |
| Total Token Throughput (tok/s) | 5,252 | 5,811 | 5,897 | 6,382 |
| Mean TTFT (ms) | 603 | 589 | 594 | 583 |
| Mean TPOT (ms) | 14.07 | 12.63 | 12.43 | 11.41 |
| Mean ITL (ms) | 14.07 | 12.63 | 12.43 | 11.41 |
| GSM8K Accuracy (flexible) | 43.06% | 43.37% | 45.26% | 42.91% |
This matches the speedup trend seen in the microbenchmark.
For additional information, the following are the microbenchmark results on A100 and H100.
@ProExpertProg This is the benchmark perf for H100
compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')
vs Triton Kernel
Summary Statistics (Averaged across 3 runs)
| Metric | Torch Compile | Triton | Winner |
|---|---|---|---|
| **Throughput** | | | |
| Request throughput (req/s) | 2.43 | 2.47 | ✅ Triton (+1.6%) |
| Output token throughput (tok/s) | 1244.44 | 1266.30 | ✅ Triton (+1.8%) |
| Total token throughput (tok/s) | 6222.20 | 6331.49 | ✅ Triton (+1.8%) |
| Peak output token throughput (tok/s) | 1552.00 | 1584.00 | ✅ Triton (+2.1%) |
| **Latency** | | | |
| Mean TTFT (ms) | 760.64 | 741.30 | ✅ Triton (-2.5%) |
| Median TTFT (ms) | 592.70 | 584.92 | ✅ Triton (-1.3%) |
| P99 TTFT (ms) | 3230.87 | 3174.35 | ✅ Triton (-1.7%) |
| Mean TPOT (ms) | 11.41 | 11.24 | ✅ Triton (-1.5%) |
| Median TPOT (ms) | 11.34 | 11.17 | ✅ Triton (-1.5%) |
| P99 TPOT (ms) | 12.57 | 12.64 | ✅ Torch Compile (+0.6%) |
| Mean ITL (ms) | 11.42 | 11.26 | ✅ Triton (-1.4%) |
| Median ITL (ms) | 10.82 | 10.63 | ✅ Triton (-1.8%) |
| **Duration** | | | |
| Benchmark duration (s) | 65.98 | 64.90 | ✅ Triton (-1.6%) |
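The percentages in the Winner column appear to be the relative change of the Triton value against the Torch Compile value. A minimal sketch of that arithmetic, assuming this is how they were derived (the helper name `pct_delta` is hypothetical and not part of the PR):

```python
def pct_delta(baseline: float, candidate: float) -> float:
    """Relative change of candidate versus baseline, in percent."""
    return (candidate - baseline) / baseline * 100.0


# (Torch Compile, Triton) pairs copied from three rows of the H100 table.
rows = {
    "Total token throughput (tok/s)": (6222.20, 6331.49),
    "Mean TTFT (ms)": (760.64, 741.30),
    "P99 TPOT (ms)": (12.57, 12.64),
}

for name, (compiled, triton) in rows.items():
    # A positive delta is better for throughput and worse for latency.
    print(f"{name}: {pct_delta(compiled, triton):+.1f}%")
```

For these three rows the computed deltas agree with the table (+1.8%, -2.5%, +0.6%). The sign has to be read per metric: the P99 TPOT row is the only one where the positive delta means Triton is slower.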
return w.gather(1, order), ids.gather(1, order)


# Gemma4 Moe Model has context length of 250K
Suggested change:
- # Gemma4 Moe Model has context length of 250K
+ # Gemma4 MoE Model has context length of 250K
expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
topk_weights = topk_weights * expert_scales
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
Can we use something like a custom op to dispatch?
The op itself is very model-specific, and all in-tree models benefit from the same Triton kernel, which is why this PR is shaped this way.
I will follow up with a vLLM IR op, as CustomOp is going to be replaced by vLLM IR ops.
tri_ws, tri_is = sort_by_id(tri_w, tri_ids)

ids_match = (ref_is == tri_is).all().item()
weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2)
It seems that the routing function shows larger numerical divergence, especially in bf16.
I am currently following the threshold used in the existing unit tests for the fused routing methods:
vllm/tests/kernels/moe/test_fused_topk.py
Line 135 in 4353c9c
bfloat16
T= 1 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2048 E= 128 K=8 bfloat16 ids=True max_Δweight=3.40e-03
T=250000 E= 128 K=8 bfloat16 ids=True max_Δweight=3.89e-03
float16
T= 1 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2048 E= 128 K=8 float16 ids=True max_Δweight=2.81e-04
T=250000 E= 128 K=8 float16 ids=True max_Δweight=7.93e-04
float32
T= 1 E= 128 K=8 float32 ids=True max_Δweight=7.45e-09
T= 2 E= 128 K=8 float32 ids=True max_Δweight=1.49e-08
T= 2048 E= 128 K=8 float32 ids=True max_Δweight=5.96e-08
T=250000 E= 128 K=8 float32 ids=True max_Δweight=1.19e-07
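The parity check discussed in this thread (sort each side by expert id, require identical ids, then compare weights with atol=rtol=1e-2) can be sketched per token in pure Python. This is an illustration only: the actual test works on batched tensors with `torch.allclose`, and `routing_matches` is a hypothetical helper name.

```python
def sort_by_id(weights, ids):
    """Reorder one token's (weight, id) pairs by ascending expert id."""
    order = sorted(range(len(ids)), key=lambda i: ids[i])
    return [weights[i] for i in order], [ids[i] for i in order]


def allclose(actual, expected, atol=1e-2, rtol=1e-2):
    # Same elementwise rule as torch.allclose:
    # |actual - expected| <= atol + rtol * |expected|
    return all(
        abs(a - b) <= atol + rtol * abs(b) for a, b in zip(actual, expected)
    )


def routing_matches(ref, tri):
    """ref and tri are (weights, ids) pairs for a single token."""
    ref_ws, ref_is = sort_by_id(*ref)
    tri_ws, tri_is = sort_by_id(*tri)
    # Ids must match exactly; weights only within tolerance.
    return ref_is == tri_is and allclose(ref_ws, tri_ws)
```

Sorting by id first makes the comparison independent of the order in which each implementation emits its top-k entries. The largest deviation reported above (3.89e-03 for bfloat16) sits well inside the 1e-2 absolute tolerance.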
https://paste.sh/FS6-P1Ld#3EhFXYWTjxoxW0WcQ1QtMKDr Triton does not support non-causal attention
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                       + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>
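For orientation, the op sequence listed above can be written out for a single token in pure Python. This is a sketch under stated assumptions, not the model code: it takes the softmax over the selected top-k logits (the order the list implies), uses Python floats instead of bf16/fp32 tensors, and `routing_reference` is a hypothetical name.

```python
import math


def routing_reference(logits, topk, per_expert_scale):
    """Pure-Python stand-in for the unfused op sequence, one token."""
    # torch.topk: ids of the `topk` largest logits, largest first.
    ids = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)
    ids = ids[:topk]
    # softmax over the selected logits (max-subtracted for stability).
    top = [logits[e] for e in ids]
    peak = max(top)
    exps = [math.exp(x - peak) for x in top]
    total = sum(exps)
    weights = [x / total for x in exps]
    # per_expert_scale[topk_ids] gather, then the elementwise multiply.
    weights = [w * per_expert_scale[e] for w, e in zip(weights, ids)]
    return weights, ids
```

Each commented step corresponds to one entry in the list, and each entry is a separate kernel launch with its own round trip through global memory, which is what a fused kernel avoids.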
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
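The "pack (bijective(logit_bits), expert_id) and sort once" step in the first bullet can be illustrated in pure Python. This is only a sketch of the general idea: the exact bit layout used by the vLLM or SGLang kernel is not shown here, so the float32 key mapping, the 32/32 bit split, and the helper names below are assumptions.

```python
import struct


def sortable_key(x: float) -> int:
    """Map a float32 to a uint32 whose integer order matches float order."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if bits & 0x80000000:
        # Negative floats: flip every bit so more-negative sorts lower.
        return bits ^ 0xFFFFFFFF
    # Non-negative floats: set the sign bit so they sort above negatives.
    return bits | 0x80000000


def topk_by_packed_sort(logits, k):
    """Return the ids of the k largest logits using one integer sort."""
    # High 32 bits: sortable logit key. Low 32 bits: expert id.
    packed = sorted((sortable_key(v) << 32) | e for e, v in enumerate(logits))
    top = packed[-k:][::-1]  # k largest, in descending order
    return [p & 0xFFFFFFFF for p in top]
```

Because the id travels in the low bits of the key, a single sort yields both the ordering and the indices, with no separate gather. In this sketch ties between equal logits resolve toward the higher expert id, which may differ from `torch.topk`, and NaN inputs are not handled.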
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
  workload                      baseline         this patch       delta
  chat random 1000/1000         2729.30 tok/s    2880.94 tok/s    +5.6%
  summariz. random 8000/1000    1060.98 tok/s    1108.42 tok/s    +4.5%
  chat median TPOT (ms)         21.11            20.70            -1.9%
  chat accept length            2.75             2.80             +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
Purpose
Improve the performance of Gemma4 by introducing a fused Triton routing function.
The existing custom routing function introduces many synchronization points and many reads and writes to global memory.
Moreover, the custom routing function is not captured under torch.compile.
Test Plan
Run a microbenchmark of the Triton kernel versus the torch implementation.
Run end-to-end tests on A100, H100, MI300X, and B60 (TP4, FP8 online quant), covering accuracy and performance.
NOTE: The accuracy values provide a reference baseline to show that there is no notable accuracy deviation from the baseline. They do not represent the full capability of the models.
Test Result
End-to-End Serving Benchmarks
Benchmark Configuration
Benchmark commands and lmeval commands
Workload
Server:
Workload Command:
Lm_Eval Command:
Microbenchmark Results
The Triton kernel shows substantial speedups over PyTorch implementations across all token counts and GPU architectures:
Speedup vs PyTorch Variants (Geometric Mean)
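Since the speedups are summarized as a geometric mean, a short sketch of that aggregation may help when reproducing the numbers. The input values below are hypothetical placeholders, not measurements from this PR, and `geometric_mean` is an illustrative helper.

```python
import math


def geometric_mean(speedups):
    """Geometric mean of positive ratios, computed in log space."""
    if not speedups or any(s <= 0 for s in speedups):
        raise ValueError("speedups must be a non-empty list of positive values")
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))


# Hypothetical per-token-count speedups, NOT measurements from this PR.
example = [2.0, 8.0]
print(geometric_mean(example))  # approximately 4.0 (arithmetic mean is 5.0)
```

The geometric mean is the usual choice for averaging ratios, because a 2× speedup and a 0.5× slowdown cancel out to 1×, which an arithmetic mean would not show.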
Integration consideration
I did consider adding it as a new custom routing function; however, note the comment in
vllm/vllm/model_executor/layers/fused_moe/config.py
Line 102 in 47e6050
So, the whole function is not explicitly added to the model definition file, as it is a model-specific routing function pioneered by the Gemma4 model.