Skip to content

[FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton#39083

Merged
tjtanaa merged 13 commits into
vllm-project:mainfrom
EmbeddedLLM:gemma4router
Apr 19, 2026
Merged

[FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton#39083
tjtanaa merged 13 commits into
vllm-project:mainfrom
EmbeddedLLM:gemma4router

Conversation

@tjtanaa
Copy link
Copy Markdown
Member

@tjtanaa tjtanaa commented Apr 6, 2026

Purpose

Improve the performance of Gemma4 by introducing triton fused routing function.

The custom routing function introduces many synchronizations point and read write to global memory.

Moreover, the custom routing function is not captured under torch compile.

Test Plan

Perform microbenchmark on the triton kernel versus torch

Perform end to end testing on A100, H100, MI300X, and B60 (TP4, FP8 online quant): Accuracy and performance

NOTE: The accuracy value is to provide reference baseline as to show that there are no notable accuracy deviation from baseline. It does not represent the full capability of the models.

Test Result

End-to-End Serving Benchmarks

Benchmark Configuration

  • Model: google/gemma-4-26B-A4B-it
  • Experts: 128, Top-K: 8
  • Token range: 1 to 262,144 tokens
  • Serving test: 160 requests, 16 max concurrency, 2048 input tokens, 512 output tokens per request
Benchmark commands and lmeval commands

Workload

Server:

#!/bin/bash

rm -rf /root/.cache/vllm

vllm serve google/gemma-4-26B-A4B-it \
-tp 1 \
--attention-backend TRITON_ATTN \
--no-enable-prefix-caching

Workload Command:

#!/bin/bash
vllm bench serve \
--backend vllm \
--host 127.0.0.1 \
--port 8000 \
--endpoint "/v1/completions" \
--model "google/gemma-4-26B-A4B-it" \
--dataset-name random \
--num-prompts 160 \
--max-concurrency 16 \
--random-input-len 2048 \
--random-output-len 512 \
--save-result

Lm_Eval Command:

#!/bin/bash

lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args model=google/gemma-4-26B-A4B-it,base_url=http://127.0.0.1:8000/v1/completions,max_tokens=8192 \
--batch_size 100
GPU Total Throughput TPOT ITL Peak Output Throughput Accuracy (GSM8K)
A100 +8.6% (3,419 → 3,689 tok/s) -8.5% (20.98 → 19.25ms) -8.2% (20.98 → 19.56ms) +1.9% (864 → 880 tok/s avg) (0.4458 → 0.4428)
MI300X +21.4% (5,252 → 6,382 tok/s) -18.9% (14.07 → 11.41ms) -18.9% (14.07 → 11.41ms) +26.9% (1,248 → 1,584 tok/s) (0.4306 → 0.4291)
H100 +4.3% (5,813 → 6,060 tok/s) -4.4% (12.73 → 12.17ms) -4.4% (12.74 → 12.18ms) +4.7% (1,408 → 1,488 tok/s avg) (0.4594 → 0.4481)
B60 (4×TP, FP8) +32.0% (477 → 629 tok/s) -26.0% (157.1 → 116.2ms) -26.0% (157.1 → 116.2ms) +33.3% (144 → 192 tok/s avg) (0.4390 → 0.4511)

Microbenchmark Results

The Triton kernel shows substantial speedups over PyTorch implementations across all token counts and GPU architectures:

Speedup vs PyTorch Variants (Geometric Mean)

GPU vs Eager vs compile-default vs compile-reduce-overhead vs compile-max-autotune
A100 12.68× 4.50× 4.97× 4.45×
H100 10.58× 4.51× 5.02× 4.41×
MI300X 14.37× 6.00× 7.24× 5.81×
B60 6.33× 1.08× 1.07× 1.05×

Integration consideration

I did consider to add it as a new custom routing function however the comment in

class RoutingMethodType(IntEnum):
deter me from adding it there as RoutingMethodType has to be kept insync with flashinfer enum definition, which does not have Gemma4 router enum.

So, the whole function is not explicitly added to the model definition file as it is a model specific routing function and it is pioneered by Gemma4 model.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

tjtanaa added 5 commits April 6, 2026 17:05
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 6, 2026
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a fused Triton routing kernel for the Gemma4 MoE model, optimizing the routing process on CUDA-compatible platforms. The changes include the implementation of the _gemma4_routing_kernel, a corresponding Python wrapper, and a comprehensive test suite to ensure parity with the PyTorch reference implementation. Review feedback recommends using the tl.LOG2E constant for better readability, increasing the default num_warps to 4 to improve GPU resource utilization, and updating the docstring to remove a copy-paste artifact.

Comment thread vllm/model_executor/models/gemma4.py
Comment thread vllm/model_executor/models/gemma4.py
Comment thread vllm/model_executor/models/gemma4.py
@mergify mergify Bot removed the needs-rebase label Apr 6, 2026
tjtanaa added 5 commits April 6, 2026 23:34
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa changed the title [Perf] [Gemma4] Fused Gemma4 Routing Function Triton [FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton Apr 7, 2026
gating_output, topk, per_expert_scale
)

return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you just add torch.compile to this function?

Copy link
Copy Markdown
Member Author

@tjtanaa tjtanaa Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try. However, in the microbenchmark, I found that the triton kernel is still faster than the torch compile implementations.

GPU vs Eager vs compile-default vs compile-reduce-overhead vs compile-max-autotune
A100 12.68× 4.50× 4.97× 4.45×
H100 10.58× 4.51× 5.02× 4.41×
MI300X 14.37× 6.00× 7.24× 5.81×
B60 6.33× 1.08× 1.07× 1.05×

Copy link
Copy Markdown
Member Author

@tjtanaa tjtanaa Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have run the test with two different modes on MI300X

compiled_routing_function_torch = torch.compile(routing_function_torch)

and

compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')

Metric Torch Torch Compile (default) Torch Compile (max-autotune) Triton Kernel
Benchmark Duration (s) 77.99 70.49 69.46 64.18
Request Throughput (req/s) 2.05 2.27 2.30 2.49
Output Token Throughput (tok/s) 1,050 1,162 1,179 1,276
Total Token Throughput (tok/s) 5,252 5,811 5,897 6,382
Mean TTFT (ms) 603 589 594 583
Mean TPOT (ms) 14.07 12.63 12.43 11.41
Mean ITL (ms) 14.07 12.63 12.43 11.41
GSM8K Accuracy (flexible) 43.06% 43.37% 45.26% 42.91%

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It matches the microbenchmark speed up trend

mi300xbenchmark.log

Additional information, the following are the microbenchmark on a100 and h100

a100benchmarkkernel.log

H100benchmarkkernel.log

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ProExpertProg This is the benchmark perf for H100

compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')

vs Triton Kernel

Summary Statistics (Averaged across 3 runs)

Metric Torch Compile Triton Winner
Throughput
Request throughput (req/s) 2.43 2.47 ✅ Triton (+1.6%)
Output token throughput (tok/s) 1244.44 1266.30 ✅ Triton (+1.8%)
Total token throughput (tok/s) 6222.20 6331.49 ✅ Triton (+1.8%)
Peak output token throughput (tok/s) 1552.00 1584.00 ✅ Triton (+2.1%)
Latency
Mean TTFT (ms) 760.64 741.30 ✅ Triton (-2.5%)
Median TTFT (ms) 592.70 584.92 ✅ Triton (-1.3%)
P99 TTFT (ms) 3230.87 3174.35 ✅ Triton (-1.7%)
Mean TPOT (ms) 11.41 11.24 ✅ Triton (-1.5%)
Median TPOT (ms) 11.34 11.17 ✅ Triton (-1.5%)
P99 TPOT (ms) 12.57 12.64 ✅ Torch Compile (+0.6%)
Mean ITL (ms) 11.42 11.26 ✅ Triton (-1.4%)
Median ITL (ms) 10.82 10.63 ✅ Triton (-1.8%)
Duration
Benchmark duration (s) 65.98 64.90 ✅ Triton (-1.6%)

Comment thread tests/kernels/moe/test_gemma4router.py Outdated
return w.gather(1, order), ids.gather(1, order)


# Gemma4 Moe Model has context length of 250K
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Gemma4 Moe Model has context length of 250K
# Gemma4 MoE Model has context length of 250K

expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
topk_weights = topk_weights * expert_scales
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use something like custom op to dispatch?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Op itself is a very model specific op, and all in-tree models are benefited from the same triton kernel. So that's why this PR is shaped it this way.

I will follow up with a vLLM IR Op as CustomOp is going to be replaced by vLLM IR Op.

tri_ws, tri_is = sort_by_id(tri_w, tri_ids)

ids_match = (ref_is == tri_is).all().item()
weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the atol here too large?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the routing function especially in bf16 has larger numerical divergence.

I am currently following the threshold used in existing unit tests of the fused routing methods:

topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat16

T= 1 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2048 E= 128 K=8 bfloat16 ids=True max_Δweight=3.40e-03
T=250000 E= 128 K=8 bfloat16 ids=True max_Δweight=3.89e-03

float16

T= 1 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2048 E= 128 K=8 float16 ids=True max_Δweight=2.81e-04
T=250000 E= 128 K=8 float16 ids=True max_Δweight=7.93e-04

float32

T= 1 E= 128 K=8 float32 ids=True max_Δweight=7.45e-09
T= 2 E= 128 K=8 float32 ids=True max_Δweight=1.49e-08
T= 2048 E= 128 K=8 float32 ids=True max_Δweight=5.96e-08
T=250000 E= 128 K=8 float32 ids=True max_Δweight=1.19e-07

tjtanaa added 2 commits April 19, 2026 08:12
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa enabled auto-merge (squash) April 19, 2026 08:24
@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 19, 2026
@tjtanaa tjtanaa merged commit 45232a4 into vllm-project:main Apr 19, 2026
65 checks passed
@Naist4869
Copy link
Copy Markdown

https://paste.sh/FS6-P1Ld#3EhFXYWTjxoxW0WcQ1QtMKDr

Triton not support Non-causal attention

bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 20, 2026
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…ject#39083)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
Rukhaiya2004 pushed a commit to Rukhaiya2004/vllm that referenced this pull request May 23, 2026
pyc96 added a commit to pyc96/sglang that referenced this pull request May 27, 2026
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
pyc96 added a commit to pyc96/sglang that referenced this pull request May 27, 2026
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants