perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)#1
Open
pyc96 wants to merge 1 commit into
Open
perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)#1pyc96 wants to merge 1 commit into
pyc96 wants to merge 1 commit into
Conversation
This was referenced May 22, 2026
Draft
0ea98c6 to
a4444a7
Compare
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
torch.topk -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
+ at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
softmax -> at::native::cunn_SoftMaxForward<4,float,...>
per_expert_scale[] -> at::native::index_elementwise_kernel<bf16,...>
topk_weights * ... -> at::native::elementwise_kernel<MulFunctor<bf16>>
cast to fp32 -> at::native::elementwise_kernel<copy>
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
workload baseline this patch delta
chat random 1000/1000 2729.30 tok/s 2880.94 tok/s +5.6%
summariz. random 8000/1000 1060.98 tok/s 1108.42 tok/s +4.5%
chat median TPOT (ms) 21.11 20.70 -1.9%
chat accept length 2.75 2.80 +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
a4444a7 to
5f717d2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Single-launch Triton fused router for Gemma4 MoE replaces the per-layer
four-kernel
torch.topk -> softmax -> per_expert_scale[ids] -> mul -> castchain in
Gemma4MoE.routing_function.Staged on
pyc96/sglangonly — not opening upstream tosgl-project/sglang.Motivation
torch.profiler triage of
google/Gemma-4-26B-A4B-IT+ Gemma4 MTP on 1× B200(sm_100a, bf16,
--attention-backend triton, MTP--speculative-num-steps 3 --speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed~5.8% of decode GPU time to the split routing kernels:
at::native::sbtopk::gatherTopK<bf16,uint,2,false>at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>at::native::cunn_SoftMaxForward<4,float,...>at::native::index_elementwise_kernel<bf16>forper_expert_scale[ids]MulFunctor<bf16>fortopk_weights * scalevLLM ships an equivalent single-launch kernel (
_gemma4_routing_kernel,PR vllm-project/vllm#39083) that does the same logical work in ~1.1% of vLLM's
decode GPU time. This PR ports the algorithm to SGLang (rewritten from
scratch in SGLang style; algorithm is Apache-2.0).
Implementation
python/sglang/srt/layers/gemma4_fused_ops.py— new_gemma4_routing_kernelgemma4_fused_routing.(bijective(logit_bits), expert_id)into int64, runs onetl.sort,masks to the K largest, softmaxes in fp32, multiplies by
per_expert_scale[topk_ids], writes(weights, ids)in(fp32, int32).num_warps=1— for Gemma4E = num_experts = 128everything fits in asingle warp.
python/sglang/srt/models/gemma4_causal.py—Gemma4MoE.routing_functioncalls the fused kernel on CUDA fp16/bf16/fp32 inputs and falls back to
the original torch path otherwise.
Benchmark (1× B200, branch baseline = PR sgl-project#26026 head)
Both scenarios use
python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.1000/1000output tok/s8000/1000output tok/sSame workload on
vllm/vllm-openai:nightly(gemma4_mtp, spec-tokens=3) forreference: chat 5309 tok/s, summ 2330 tok/s.
Quality (full MMLU, 57 subjects, N=14,042 questions per variant)
A/B via local
SGLANG_GEMMA4_DISABLE_FUSED_ROUTERenv-toggle inGemma4MoE.routing_function. Each variant run twice from a clean server start.McNemar χ² (continuity correction) = 1.97 (critical 3.84 at p=0.05). No
statistically significant difference between fused and torch at full-MMLU N.
Noise-floor decomposition:
Per-subject distribution: fused > torch on 28/57 subjects, torch > fused on
27/57, tie on 2. Bidirectional drift, comparable magnitudes — characteristic
of cascading bf16 round-off through ~30 routed-MoE layers, not systematic bias.
Equivalence probe (819,200 random tokens): top-K set disagreement 0.00%,
order disagreement 20.4%, max id-aligned weight diff when sets agree = 7.5e-3
(within bf16 eps). The fused kernel always picks the exact same K experts as
the torch fallback; only the (id, weight) tuple positions within
topk_ids[t, :K]may differ, which is a no-op for the downstream MoE sum.Tests
test/srt/layers/test_gemma4_fused_routing.pyexercises 47 shape/dtypecombinations (E ∈ {64, 128, 256}, K ∈ {4, 8}, T ∈ {0, 1, 7, 64, 128, 1024},
dtype ∈ {bf16, fp16, fp32}) against the previous torch routing function.
Tolerances are set to the input dtype eps; the fused kernel actually does
softmax in fp32 throughout while the torch fallback does softmax in input
dtype, so the fused path is the more accurate one.
Provenance
source copy.
Co-authored-by: Claudeon the commit.CI States
Latest PR Test (Base): ❌ Run #26531141706
Latest PR Test (Extra): ❌ Run #26531139307