perf(gemma4): single-launch fused router (topk + softmax + scale) #26502
pyc96 wants to merge 6 commits into
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
  torch.topk            -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax               -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]    -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...    -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32          -> at::native::elementwise_kernel<copy>
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
workload                     baseline         this patch       delta
chat random 1000/1000        2729.30 tok/s    2880.94 tok/s    +5.6%
summariz. random 8000/1000   1060.98 tok/s    1108.42 tok/s    +4.5%
chat median TPOT (ms)        21.11            20.70            -1.9%
chat accept length           2.75             2.80             +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
/tag-and-rerun-ci |
    assert per_expert_scale.dim() == 1
    assert per_expert_scale.shape[0] == gating_output.shape[1]
    T, E = gating_output.shape
    assert topk <= E
Can you add a failure message to these asserts?
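One way to address this request is sketched below. `check_routing_inputs` is a hypothetical helper, not code from the PR, and it takes plain shape tuples rather than tensors so that it runs without torch; the message wording is only a suggestion.

```python
def check_routing_inputs(gating_shape, scale_shape, topk):
    """Validate router input shapes, failing with a descriptive message.

    Hypothetical helper for illustration: in the real code the shapes would
    come from gating_output.shape and per_expert_scale.shape.
    """
    assert len(scale_shape) == 1, (
        f"per_expert_scale must be 1-D, got {len(scale_shape)} dims"
    )
    assert len(gating_shape) == 2, (
        f"gating_output must be 2-D (tokens, experts), got shape {gating_shape}"
    )
    num_tokens, num_experts = gating_shape
    assert scale_shape[0] == num_experts, (
        f"per_expert_scale has {scale_shape[0]} entries but gating_output "
        f"has {num_experts} experts"
    )
    assert topk <= num_experts, (
        f"topk ({topk}) must not exceed the number of experts ({num_experts})"
    )
    return num_tokens, num_experts
```

With messages like these, a shape mismatch reports the offending values directly instead of a bare `AssertionError`.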
BBuf left a comment
Please clean up your code and comments.
c85e55b to 0aa9869
0aa9869 to facc1e1
Heads-up: this new test won't actually run in CI as-is, and it's currently what's breaking the build.
That failure then fast-fail cascades into the b200 / h100 / h200 jobs, which is why so many checks are red. All CUDA tests have been migrated to the
    from sglang.test.ci.ci_register import register_cuda_ci
    register_cuda_ci(est_time=60, stage="base-b", runner_config="1-gpu-small")
(If you intentionally do not want it in CI, the alternative is to add
/rerun-failed-ci |
Motivation
`Gemma4MoE.routing_function` emits four per-layer GPU kernels for every MoE forward pass (one decode step touches ~30 routed-MoE layers):
* `at::native::sbtopk::gatherTopK<bf16,uint,2,false>` (torch.topk)
* `at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>` (torch.topk tie-break)
* `at::native::index_elementwise_kernel<bf16>` (per_expert_scale[topk_ids])
* `MulFunctor<bf16>` + softmax + fp32 cast

vLLM's `_gemma4_routing_kernel` (apache-2.0) does the same logical work in a single Triton launch and brings its share of decode GPU time to ~1.1%. This PR ports the same algorithm to SGLang, rewritten from scratch in SGLang style.
Modifications
* `python/sglang/srt/layers/gemma4_fused_ops.py`: `_gemma4_routing_kernel`, one Triton program per token. Packs `(bijective(logit_bits), expert_id)` into int64; the bijection is anti-monotone on the float value and the `<<32` shift puts the int32 key's high bit into the int64 sign position, so signed ascending `tl.sort` yields the original logits in descending float order in one pass. Softmaxes in fp32, multiplies by `per_expert_scale[topk_ids]`, writes `(fp32 weights, int32 ids)` matching the SGLang TopK output contract. `num_warps=1` for Gemma4's `E = num_experts = 128` (fits in one warp). `E <= 1024` assert.
* `python/sglang/srt/models/gemma4_causal.py`: `Gemma4MoE.routing_function` calls the fused kernel on CUDA fp16/bf16/fp32 router logits; falls back to the original torch path otherwise.
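The packed-key sort described above can be illustrated on the CPU in plain Python. This is a minimal sketch under stated assumptions, not the Triton kernel: `float_to_desc_key` and `fused_routing_reference` are hypothetical names, the bijection shown is one standard choice that may differ from the one in the PR, and equal logits are broken toward the lower expert id, which is not guaranteed to match `torch.topk`.

```python
import math
import struct


def float_to_desc_key(x):
    """Map a float32 value to a signed int32 key that is anti-monotone in x."""
    # Reinterpret the float32 bit pattern as a signed 32-bit integer.
    (bits,) = struct.unpack("<i", struct.pack("<f", x))
    if bits < 0:
        # Negative floats: flip the magnitude bits so integer order
        # matches float order.
        bits ^= 0x7FFFFFFF
    # Bitwise NOT reverses signed order: larger float -> smaller key.
    return ~bits


def fused_routing_reference(logits, per_expert_scale, topk):
    """Top-k ids and scaled softmax weights for one token's router logits."""
    # Key in the high 32 bits, expert id in the low 32 bits. Python ints are
    # signed, so an ascending sort mirrors a signed int64 sort.
    packed = sorted(
        (float_to_desc_key(v) << 32) | e for e, v in enumerate(logits)
    )
    ids = [p & 0xFFFFFFFF for p in packed[:topk]]
    top = [logits[e] for e in ids]
    # Numerically stable softmax over the K selected logits only.
    peak = max(top)
    exps = [math.exp(v - peak) for v in top]
    denom = sum(exps)
    weights = [ex / denom * per_expert_scale[e] for ex, e in zip(exps, ids)]
    return weights, ids
```

Because the key is anti-monotone, a single ascending sort already places the K largest logits first, so no separate descending sort or gather pass is needed.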
Quantization compatibility
The router projection in `Gemma4Router` is constructed with `quant_config=None`, so router logits are always emitted in the activation dtype (bf16/fp16/fp32) regardless of expert quantization. The fast path's dtype guard is therefore safe under any expert quantization that goes through the standard topk dispatch.

triton / triton_kernel (BF16, FP16)           custom_routing_function via select_experts
[...]                                         custom_routing_function via select_experts
flashinfer_trtllm_routed (NVFP4 opt-in)       custom_routing_function via select_experts
flashinfer_trtllm (NVFP4 default on SM100)    BypassedTopKOutput → trtllm does routing internally
flashinfer_mxfp4 (is_fp4_experts=False)       BypassedTopKOutput → flashinfer mxfp4 kernel

On the BYPASSED paths neither the fused kernel nor the previous torch fallback runs, so this PR is a strict no-op there.
Accuracy Tests
Unit tests: `test/srt/layers/test_gemma4_fused_routing.py` (47 cases)
47 shape/dtype combinations against the previous torch routing function: `E ∈ {64, 128, 256}`, `K ∈ {4, 8}`, `T ∈ {0, 1, 7, 64, 128, 1024}`, `dtype ∈ {bf16, fp16, fp32}`. Tolerances set to the input dtype eps; the fused kernel does softmax in fp32 throughout while the torch fallback does softmax in input dtype, so the fused path is the more accurate one for bf16/fp16.

    $ python -m pytest test/srt/layers/test_gemma4_fused_routing.py
    47 passed in 27.75s

Equivalence probe (819,200 random tokens, E=128, K=8)
The fused kernel and torch fallback always pick the exact same K experts; only the position of each `(id, weight)` pair within `topk_ids[t, :K]` may differ. Order doesn't affect the downstream MoE sum.
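An order-insensitive comparison of this kind can be sketched as follows. `same_routing` is a hypothetical helper operating on nested Python lists, not the probe used in the PR, and the default tolerance is an assumption.

```python
import math


def same_routing(weights_a, ids_a, weights_b, ids_b, rel_tol=1e-6):
    """Return True if two routing results agree up to slot order per token.

    Each argument is a list with one inner list of length K per token.
    Assumes expert ids are unique within a token, as top-k ids are.
    """
    if len(ids_a) != len(ids_b):
        return False
    for wa, ia, wb, ib in zip(weights_a, ids_a, weights_b, ids_b):
        # Key each weight by its expert id so slot position is ignored.
        by_id_a = dict(zip(ia, wa))
        by_id_b = dict(zip(ib, wb))
        if by_id_a.keys() != by_id_b.keys():
            return False  # different expert sets were selected
        for expert in by_id_a:
            if not math.isclose(by_id_a[expert], by_id_b[expert], rel_tol=rel_tol):
                return False  # same expert, different weight
    return True
```

Comparing by expert id rather than by slot is what makes the check independent of tie-break order between the two paths.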
End-to-end MMLU (1× B200, BF16, 5-shot, temp=0)
Full MMLU: 57 subjects, 14,042 questions per variant. A/B via local env-toggle in `Gemma4MoE.routing_function` that forces the torch fallback. Each variant run twice from a clean server start.
Gap = +0.5pp (fused vs torch).
McNemar's test on paired per-question outcomes (Run 1 of each, N=14,042):
Conclusion: no statistically significant difference between fused and torch
at p=0.05.
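For readers unfamiliar with the test, here is a minimal sketch of an exact two-sided McNemar computation on paired outcomes. It assumes the exact binomial variant (the PR does not say which variant was used), and `mcnemar_exact_p` is a hypothetical name. `b` and `c` are the discordant counts: questions only the fused variant answered correctly and questions only the torch variant answered correctly.

```python
from math import comb


def mcnemar_exact_p(b, c):
    """Two-sided exact McNemar p-value from the two discordant counts."""
    n = b + c
    if n == 0:
        return 1.0  # no discordant pairs, so no evidence of a difference
    k = min(b, c)
    # Under H0 each discordant pair falls on either side with probability
    # 1/2, so the smaller count follows Binomial(n, 0.5).
    tail = sum(comb(n, i) for i in range(k + 1))
    return min(1.0, 2 * tail / 2**n)


# Illustrative counts only, not taken from the PR.
print(mcnemar_exact_p(120, 100))
```

Questions that both variants get right or both get wrong do not enter the statistic; only the discordant pairs carry information about a difference between the two paths.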
Speed Tests and Profiling
1× B200 (sm_100a), BF16, `Gemma-4-26B-A4B-IT` + Gemma4 MTP, triton attention, `--speculative-num-steps 3 --speculative-num-draft-tokens 4 --speculative-eagle-topk 1`. Both scenarios use `python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1`.

torch.profiler share of the four pre-fusion kernels drops from ~5.8% to ~1.4% of decode GPU time, in line with vLLM's measured share.
Provenance
source copy.
Checklist
CI States
Latest PR Test (Base): ❌ Run #26703258449
Latest PR Test (Extra): ❌ Run #26703258393