perf(gemma4): default swa_full_tokens_ratio=0.15 for the 25:5 SWA:full split#2
Open
pyc96 wants to merge 6 commits into
Open
perf(gemma4): default swa_full_tokens_ratio=0.15 for the 25:5 SWA:full split#2pyc96 wants to merge 6 commits into
pyc96 wants to merge 6 commits into
Conversation
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
torch.topk -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
+ at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
softmax -> at::native::cunn_SoftMaxForward<4,float,...>
per_expert_scale[] -> at::native::index_elementwise_kernel<bf16,...>
topk_weights * ... -> at::native::elementwise_kernel<MulFunctor<bf16>>
cast to fp32 -> at::native::elementwise_kernel<copy>
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
workload baseline this patch delta
chat random 1000/1000 2729.30 tok/s 2880.94 tok/s +5.6%
summariz. random 8000/1000 1060.98 tok/s 1108.42 tok/s +4.5%
chat median TPOT (ms) 21.11 20.70 -1.9%
chat accept length 2.75 2.80 +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
…l split Gemma-4 textual layers are a 25:5 SWA:full split (see `Gemma4TextConfig.layer_types`). SGLang's default `swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window pool is the binding constraint; for Gemma-4 the **full-attention** pool is binding under any realistic concurrent long-context workload. On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k context, the default pool layout solves to: full_layer_tokens = 593_956 <-- fits ~65 concurrent 9k-token requests swa_layer_tokens = 475_164 <-- fits ~464 concurrent 1024-token windows A typical 80-prompt summarization workload (8 k input + 1 k output = 9 k tokens / request) needs ~720 k full-attention tokens. Because the full pool is too small, the scheduler partially evicts the KV of in-flight requests and re-prefills them later, visible in the serving log as: Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ... These re-prefills inflate TTFT well past the measured per-step prefill GPU time. Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in `apply_deepseek_v4_defaults`) shifts memory from the over-provisioned SWA pool to the under-provisioned full pool: full_layer_tokens = 2_138_243 <-- fits ~237 concurrent 9k-token reqs swa_layer_tokens = 320_736 <-- still ~313 1024-token windows Real-model results on the same B200 (host venv SGLang, baseline = PR #1 on pyc96/sglang head = sota-loop-base + fused router): workload Patch 1 this patch delta chat random 1000/1000 2881 tok/s 2913 tok/s +1.1 % summariz. random 8000/1000 median TTFT (ms) 10459 8763 **-16.2 %** output tok/s 1108 1097 -1.0 % median TPOT (ms) 44.6 37.9 -15.0 % Median summarization TTFT now matches vLLM nightly (8763 ms vs vLLM 8916 ms, within run-to-run noise). MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710 -- within MMLU sampling noise; no regression. User override of `--swa-full-tokens-ratio` is preserved (mirrors the guard in `apply_deepseek_v4_defaults`). Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the override-fires and user-override-preserved paths; 3 passed, 1 smoke test skipped on environments that do not have full ModelConfig stubs. Co-authored-by: Claude
This was referenced May 23, 2026
a4444a7 to
5f717d2
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Override the SGLang default
swa_full_tokens_ratio(0.8) to0.15forthe Gemma-4 model family. Gemma-4 has a 25:5 SWA:full layer split, which
inverts the assumption that motivates the default ratio.
Stacked on top of #1 (fused router). Base = same branch.
Staged on
pyc96/sglangonly — not opening upstream tosgl-project/sglang.Motivation
Default-ratio pool layout on a 180 GB B200 (TP=1, bf16, MTP, 16k ctx):
A typical 80-prompt summarization workload needs ~720k full-attention
tokens, but the full pool only holds ~594k. The scheduler partially
evicts in-flight KV and re-prefills later; visible in the serving log
as:
These re-prefills inflate TTFT well past the measured per-step prefill
GPU time:
Implementation
python/sglang/srt/server_args.py::_handle_model_specific_adjustments,inside the existing
Gemma4ForConditionalGeneration/Gemma4ForCausalLMbranch, after the attention-backend and MoE-runner adjustments:
Same precedent as
apply_deepseek_v4_defaults(sets ratio to0.1forDSV4 for the same reason). User-passed
--swa-full-tokens-ratioisrespected and not clobbered.
Resulting pool layout on the same B200:
Benchmark (1× B200, host venv, baseline = #1 head)
Both scenarios use
python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.Summarization (random 8000/1000, 80 prompts)
Median summarization TTFT now matches vLLM nightly within run-to-run
noise (8763 vs 8916 ms). Median TPOT also improves 15% because the
scheduler can now interleave more prefills with decode.
Chat (random 1000/1000, 80 prompts)
Chat workload doesn't stress the full pool, so the change is neutral
there (good — no regression).
Quality (MMLU, 500 random questions, seed 0, temp 0)
Within MMLU sampling noise. No regression.
Tests
test/srt/test_gemma4_swa_full_tokens_ratio.py:Gemma4ForCausalLMandGemma4ForConditionalGeneration): passesneeds more env stubs)
Existing tests (PR #1 routing tests) still pass:
python -m pytest test/srt/layers/test_gemma4_fused_routing.py # 47 passedProvenance
apply_deepseek_v4_defaultsinpython/sglang/srt/arg_groups/deepseek_v4_hook.pysets the same arg to0.1for the same reason (oversized SWA, undersized full). Thischange follows the same pattern inside the existing Gemma-4 branch.
0.15chosen by solving the cell-size equation for theworkload's
max(input + output) = 9000token requirement under theavailable 121 GB of post-static KV memory on a B200. Anywhere in
[0.10, 0.20]would work for this workload;0.15keeps a smallmargin without leaving too much SWA headroom unused.
CI States
Latest PR Test (Base): ❌ Missing
run-cilabel -- add it to run CI tests.Latest PR Test (Extra): ❌ Blocked --
run-ciis required first.