Skip to content

perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)#1

Open
pyc96 wants to merge 1 commit into
mainfrom
pyc/sota-gemma4-mtp-fused-routing
Open

perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)#1
pyc96 wants to merge 1 commit into
mainfrom
pyc/sota-gemma4-mtp-fused-routing

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 22, 2026

Summary

Single-launch Triton fused router for Gemma4 MoE replaces the per-layer
four-kernel torch.topk -> softmax -> per_expert_scale[ids] -> mul -> cast
chain in Gemma4MoE.routing_function.

Staged on pyc96/sglang only — not opening upstream to sgl-project/sglang.

Motivation

torch.profiler triage of google/Gemma-4-26B-A4B-IT + Gemma4 MTP on 1× B200
(sm_100a, bf16, --attention-backend triton, MTP --speculative-num-steps 3 --speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to the split routing kernels:

Kernel Decode share
at::native::sbtopk::gatherTopK<bf16,uint,2,false> 2.4%
at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...> 2.0%
at::native::cunn_SoftMaxForward<4,float,...> (folded)
at::native::index_elementwise_kernel<bf16> for per_expert_scale[ids] 1.4%
MulFunctor<bf16> for topk_weights * scale small

vLLM ships an equivalent single-launch kernel (_gemma4_routing_kernel,
PR vllm-project/vllm#39083) that does the same logical work in ~1.1% of vLLM's
decode GPU time. This PR ports the algorithm to SGLang (rewritten from
scratch in SGLang style; algorithm is Apache-2.0).

Implementation

  • python/sglang/srt/layers/gemma4_fused_ops.py — new _gemma4_routing_kernel
    • gemma4_fused_routing.
    • One Triton program per token.
    • Packs (bijective(logit_bits), expert_id) into int64, runs one tl.sort,
      masks to the K largest, softmaxes in fp32, multiplies by
      per_expert_scale[topk_ids], writes (weights, ids) in (fp32, int32).
    • num_warps=1 — for Gemma4 E = num_experts = 128 everything fits in a
      single warp.
  • python/sglang/srt/models/gemma4_causal.pyGemma4MoE.routing_function
    calls the fused kernel on CUDA fp16/bf16/fp32 inputs and falls back to
    the original torch path otherwise.

Benchmark (1× B200, branch baseline = PR sgl-project#26026 head)

Both scenarios use python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.

Workload Baseline This PR Δ
chat random 1000/1000 output tok/s 2729.30 2880.94 +5.6%
chat median TPOT (ms) 21.11 20.70 −1.9%
chat accept length 2.75 2.80 +1.8%
summarization random 8000/1000 output tok/s 1060.98 1108.42 +4.5%

Same workload on vllm/vllm-openai:nightly (gemma4_mtp, spec-tokens=3) for
reference: chat 5309 tok/s, summ 2330 tok/s.

Quality (full MMLU, 57 subjects, N=14,042 questions per variant)

A/B via local SGLANG_GEMMA4_DISABLE_FUSED_ROUTER env-toggle in
Gemma4MoE.routing_function. Each variant run twice from a clean server start.

Variant Run 1 Run 2 Mean
Fused router ON (this PR) 0.6329 0.6335 0.6332
Torch fallback (baseline) 0.6288 0.6280 0.6284

McNemar χ² (continuity correction) = 1.97 (critical 3.84 at p=0.05). No
statistically significant difference between fused and torch at full-MMLU N.

Noise-floor decomposition:

Per-question divergence Aggregate variation
Same-variant (re-run, new random_seed) ~2.0% 0.07pp
Cross-variant (fused vs torch) ~20.0% 0.50pp

Per-subject distribution: fused > torch on 28/57 subjects, torch > fused on
27/57, tie on 2. Bidirectional drift, comparable magnitudes — characteristic
of cascading bf16 round-off through ~30 routed-MoE layers, not systematic bias.

Equivalence probe (819,200 random tokens): top-K set disagreement 0.00%,
order disagreement 20.4%, max id-aligned weight diff when sets agree = 7.5e-3
(within bf16 eps). The fused kernel always picks the exact same K experts as
the torch fallback; only the (id, weight) tuple positions within
topk_ids[t, :K] may differ, which is a no-op for the downstream MoE sum.

Tests

test/srt/layers/test_gemma4_fused_routing.py exercises 47 shape/dtype
combinations (E ∈ {64, 128, 256}, K ∈ {4, 8}, T ∈ {0, 1, 7, 64, 128, 1024},
dtype ∈ {bf16, fp16, fp32}) against the previous torch routing function.
Tolerances are set to the input dtype eps; the fused kernel actually does
softmax in fp32 throughout while the torch fallback does softmax in input
dtype, so the fused path is the more accurate one.

PATH=.venv/bin:$PATH python -m pytest test/srt/layers/test_gemma4_fused_routing.py
# 47 passed

Provenance


CI States

Latest PR Test (Base): ❌ Run #26531141706
Latest PR Test (Extra): ❌ Run #26531139307

Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
@pyc96 pyc96 force-pushed the pyc/sota-gemma4-mtp-fused-routing branch from a4444a7 to 5f717d2 Compare May 27, 2026 18:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant