Skip to content

perf(gemma4): single-launch fused router (topk + softmax + scale)#26502

Open
pyc96 wants to merge 6 commits into
sgl-project:mainfrom
pyc96:pyc/perf-gemma4-fused-router
Open

perf(gemma4): single-launch fused router (topk + softmax + scale)#26502
pyc96 wants to merge 6 commits into
sgl-project:mainfrom
pyc96:pyc/perf-gemma4-fused-router

Conversation

@pyc96
Copy link
Copy Markdown
Collaborator

@pyc96 pyc96 commented May 27, 2026

Motivation

Gemma4MoE.routing_function emits four per-layer GPU kernels for every MoE
forward pass (one decode step touches ~30 routed-MoE layers):

Kernel Decode share (Gemma-4-26B-A4B-IT + MTP, B200)
at::native::sbtopk::gatherTopK<bf16,uint,2,false> (torch.topk) 2.4%
at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...> (torch.topk tie-break) 2.0%
at::native::index_elementwise_kernel<bf16> (per_expert_scale[topk_ids]) 1.4%
MulFunctor<bf16> + softmax + fp32 cast folded
total ~5.8%

vLLM's _gemma4_routing_kernel
(apache-2.0) does the same logical work in a single Triton launch and brings
its share of decode GPU time to ~1.1%. This PR ports the same algorithm
to SGLang, rewritten from scratch in SGLang style.

Modifications

  • New file python/sglang/srt/layers/gemma4_fused_ops.py:
    • _gemma4_routing_kernel — one Triton program per token. Packs
      (bijective(logit_bits), expert_id) into int64; the bijection is
      anti-monotone on the float value and the <<32 shift puts the int32
      key's high bit into the int64 sign position, so signed ascending
      tl.sort yields the original logits in descending float order in one
      pass. Softmaxes in fp32, multiplies by per_expert_scale[topk_ids],
      writes (fp32 weights, int32 ids) matching the SGLang TopK output
      contract.
    • num_warps=1 for Gemma4's E = num_experts = 128 (fits in one warp).
    • Defensive E <= 1024 assert.
  • python/sglang/srt/models/gemma4_causal.py:
    • Gemma4MoE.routing_function calls the fused kernel on CUDA fp16/bf16/
      fp32 router logits; falls back to the original torch path otherwise.

Quantization compatibility

The router projection in Gemma4Router is constructed with
quant_config=None, so router logits are always emitted in the activation
dtype (bf16/fp16/fp32) regardless of expert quantization. The fast path's
dtype guard is therefore safe under any expert quantization that goes
through the standard topk dispatch.

MoE runner backend Gemma4 routing path Fused kernel called?
triton / triton_kernel (BF16, FP16) custom_routing_function via select_experts
FP8 (compressed-tensors, W8A8) custom_routing_function via select_experts
flashinfer_trtllm_routed (NVFP4 opt-in) custom_routing_function via select_experts
flashinfer_trtllm (NVFP4 default on SM100) BypassedTopKOutput → trtllm does routing internally ❌ (no-op)
flashinfer_mxfp4 (is_fp4_experts=False) BypassedTopKOutput → flashinfer mxfp4 kernel ❌ (no-op)

On the BYPASSED paths neither the fused kernel nor the previous torch
fallback runs, so this PR is a strict no-op there.

Accuracy Tests

Unit tests — test/srt/layers/test_gemma4_fused_routing.py (47 cases)

47 shape/dtype combinations against the previous torch routing function:
E ∈ {64, 128, 256}, K ∈ {4, 8}, T ∈ {0, 1, 7, 64, 128, 1024},
dtype ∈ {bf16, fp16, fp32}. Tolerances set to the input dtype eps; the
fused kernel does softmax in fp32 throughout while the torch fallback does
softmax in input dtype, so the fused path is the more accurate one for
bf16/fp16.

$ python -m pytest test/srt/layers/test_gemma4_fused_routing.py
47 passed in 27.75s

Equivalence probe (819,200 random tokens, E=128, K=8)

Metric Value
Top-K set disagreement (different K experts chosen) 0.00%
Top-K order disagreement (same K, different array order) 20.4%
Max id-aligned weight diff when sets agree 7.5e-3 (≤ bf16 eps)

The fused kernel and torch fallback always pick the exact same K experts;
only the position of each (id, weight) pair within topk_ids[t, :K] may
differ. Order doesn't affect the downstream MoE sum.

End-to-end MMLU (1× B200, BF16, 5-shot, temp=0)

Full MMLU: 57 subjects, 14,042 questions per variant. A/B via local
env-toggle in Gemma4MoE.routing_function that forces the torch fallback.
Each variant run twice from a clean server start.

Variant Run 1 Run 2 Mean
Fused router ON 0.6329 0.6335 0.6332
Torch fallback 0.6288 0.6280 0.6284

Gap = +0.5pp (fused vs torch).

McNemar's test on paired per-question outcomes (Run 1 of each, N=14,042):

Count
Fused-only correct 826
Torch-only correct 769
Both correct 8,061
Neither correct 4,386
McNemar χ² (continuity correction) 1.97
Critical χ² at p=0.05 (df=1) 3.84

Conclusion: no statistically significant difference between fused and torch
at p=0.05.

Speed Tests and Profiling

1× B200 (sm_100a), BF16, Gemma-4-26B-A4B-IT + Gemma4 MTP, triton
attention, --speculative-num-steps 3 --speculative-num-draft-tokens 4 --speculative-eagle-topk 1. Both scenarios use
python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.

Workload Baseline This PR Δ
chat random 1000/1000 output tok/s 2729.30 2880.94 +5.6%
chat median TPOT (ms) 21.11 20.70 -1.9%
chat accept length 2.75 2.80 +1.8%
summarization random 8000/1000 output tok/s 1060.98 1108.42 +4.5%

torch.profiler share of the four pre-fusion kernels drops from ~5.8% to
~1.4% of decode GPU time, in line with vLLM's measured share.

Provenance

Checklist


CI States

Latest PR Test (Base): ❌ Run #26703258449
Latest PR Test (Extra): ❌ Run #26703258393

Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@pyc96 pyc96 marked this pull request as ready for review May 27, 2026 21:00
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@pyc96
Copy link
Copy Markdown
Collaborator Author

pyc96 commented May 27, 2026

/tag-and-rerun-ci

@kpham-sgl kpham-sgl self-assigned this May 28, 2026
@pyc96 pyc96 mentioned this pull request May 28, 2026
7 tasks
@pyc96 pyc96 changed the title perf(gemma4 MTP): single-launch fused router (topk + softmax + scale) perf(gemma): sin4gle-launch fused router (topk + softmax + scale) May 28, 2026
@pyc96 pyc96 changed the title perf(gemma): sin4gle-launch fused router (topk + softmax + scale) perf(gemma4): single-launch fused router (topk + softmax + scale) May 28, 2026
Comment thread python/sglang/srt/layers/gemma4_fused_ops.py Outdated
assert per_expert_scale.dim() == 1
assert per_expert_scale.shape[0] == gating_output.shape[1]
T, E = gating_output.shape
assert topk <= E
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a assert failed info?

Comment thread python/sglang/srt/models/gemma4_causal.py Outdated
Copy link
Copy Markdown
Collaborator

@BBuf BBuf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please clean your code and comments

@pyc96 pyc96 force-pushed the pyc/perf-gemma4-fused-router branch from c85e55b to 0aa9869 Compare May 30, 2026 04:51
@pyc96 pyc96 force-pushed the pyc/perf-gemma4-fused-router branch from 0aa9869 to facc1e1 Compare May 30, 2026 04:53
@BBuf
Copy link
Copy Markdown
Collaborator

BBuf commented May 30, 2026

Heads-up: this new test won't actually run in CI as-is, and it's currently what's breaking the build.

test/srt/run_suite.py::_sanity_check_suites() globs every test_*.py under test/srt/ and asserts each one is registered in a suite (or in __not_in_ci__). The new file isn't registered, so build-test (per-commit-cpu-arm64) fails right at startup:

AssertionError: Some test files are not in test suite. If this is intentional, please add the following to `not_in_ci` section:
TestFile("layers/test_gemma4_fused_routing.py"),

That failure then fast-fail cascades into the b200 / h100 / h200 jobs, which is why so many checks are red.

All CUDA tests have been migrated to the test/registered/ registry system (the old suites dict in test/srt/run_suite.py no longer holds CUDA tests). To make this test get auto-discovered and actually run on a GPU runner, move it under test/registered/kernels/ (CUDA kernel correctness) and add a registration call at module top level, e.g.:

from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=60, stage="base-b", runner_config="1-gpu-small")

1-gpu-small runs on a single GPU, which fits this Triton routing kernel test. After that, the registry picks it up automatically — no manual suite editing needed.

(If you intentionally do not want it in CI, the alternative is to add TestFile("layers/test_gemma4_fused_routing.py") to the __not_in_ci__ section of test/srt/run_suite.py — but then it will never execute.)

@pyc96
Copy link
Copy Markdown
Collaborator Author

pyc96 commented May 31, 2026

/rerun-failed-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants