perf(gemma4): default swa_full_tokens_ratio=0.15 for the 25:5 SWA:full split #2

Open
pyc96 wants to merge 6 commits into pyc/sota-gemma4-mtp-fused-routing from pyc/sota-gemma4-mtp-swa-ratio

@pyc96 pyc96 commented May 22, 2026

Summary

Override the SGLang default swa_full_tokens_ratio (0.8) to 0.15 for
the Gemma-4 model family. Gemma-4 has a 25:5 SWA:full layer split, which
inverts the assumption that motivates the default ratio.

Stacked on top of #1 (fused router). Base = same branch.

Staged on pyc96/sglang only — not opening upstream to
sgl-project/sglang.

Motivation

Default-ratio pool layout on a 180 GB B200 (TP=1, bf16, MTP, 16k ctx):

  pool   tokens    fits at 9k/req        fits at 1024/req
  full   593,956   ~65 reqs ← binding    n/a
  swa    475,164   n/a                   ~464 reqs

A typical 80-prompt summarization workload needs ~720k full-attention
tokens, but the full pool only holds ~594k. The scheduler partially
evicts in-flight KV and re-prefills later; visible in the serving log
as:

Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time:

  metric                      SGLang (ratio=0.8)   vLLM nightly
  summarization median TTFT   10459 ms             8214 ms
  summarization mean TTFT     15074 ms             8466 ms
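To spot these re-prefills in a serving log, the two counters in the quoted line can be pulled out with a small parser. This is a minimal sketch: `parse_prefill_line` is a hypothetical helper, and it assumes only the `#cached-token` and `#new-token` fields shown in the excerpt above; the elided parts of the line are ignored.

```python
import re

# Matches only the two counters visible in the quoted log excerpt.
PREFILL_RE = re.compile(r"#cached-token:\s*(\d+),\s*#new-token:\s*(\d+)")


def parse_prefill_line(line):
    """Return (cached_tokens, new_tokens), or None if the line has no counters."""
    match = PREFILL_RE.search(line)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


line = "Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ..."
cached, new = parse_prefill_line(line)
print(cached, new, cached + new)
```

A prefill that recomputes most of a roughly 8k-token prompt while only a small prefix is still cached, as here, is the pattern the paragraph above attributes to partial eviction.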

Implementation

The override lives in
python/sglang/srt/server_args.py::_handle_model_specific_adjustments,
inside the existing Gemma4ForConditionalGeneration / Gemma4ForCausalLM
branch, after the attention-backend and MoE-runner adjustments:

if (
    self.swa_full_tokens_ratio
    == ServerArgs.swa_full_tokens_ratio
):
    self.swa_full_tokens_ratio = 0.15
    logger.info(...)

This follows the precedent of apply_deepseek_v4_defaults, which sets the
ratio to 0.1 for DSV4 for the same reason. A user-passed
--swa-full-tokens-ratio is respected and not clobbered.
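The guard can be illustrated in isolation. The sketch below uses a stand-in dataclass, `ToyServerArgs`, and a hypothetical method `apply_gemma4_default`; neither is SGLang code, and only the comparison against the class-level default mirrors the snippet above.

```python
from dataclasses import dataclass


@dataclass
class ToyServerArgs:
    # Stand-in for the real ServerArgs; 0.8 is the default quoted in the Summary.
    swa_full_tokens_ratio: float = 0.8

    def apply_gemma4_default(self):
        # Dataclass defaults are also readable as class attributes, so this
        # compares the instance value with the unmodified default.
        if self.swa_full_tokens_ratio == ToyServerArgs.swa_full_tokens_ratio:
            self.swa_full_tokens_ratio = 0.15


untouched = ToyServerArgs()
untouched.apply_gemma4_default()      # default value: override fires

explicit = ToyServerArgs(swa_full_tokens_ratio=0.5)
explicit.apply_gemma4_default()       # user value: preserved

print(untouched.swa_full_tokens_ratio, explicit.swa_full_tokens_ratio)
```

One consequence of comparing against the default value rather than tracking whether the flag was passed: a user who explicitly passes 0.8 is indistinguishable from one who passed nothing, so that value would also be replaced by 0.15.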

Resulting pool layout on the same B200:

  pool   tokens      fits at 9k/req   fits at 1024/req
  full   2,138,243   ~237 reqs        n/a
  swa    320,736     n/a              ~313 reqs
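The "fits" columns in both pool tables are integer divisions of the pool size by the per-request footprint. The sketch below reproduces them from the numbers quoted in this description and checks them against the 80-prompt summarization demand; `max_requests` is a hypothetical helper name.

```python
def max_requests(pool_tokens, tokens_per_request):
    """Number of whole requests that fit in a KV pool."""
    return pool_tokens // tokens_per_request


FULL_TOKENS_PER_REQ = 8000 + 1000   # summarization: input + output
SWA_TOKENS_PER_REQ = 1024           # per-request footprint used in the tables
DEMAND = 80 * FULL_TOKENS_PER_REQ   # full-attention tokens for 80 prompts

layouts = {
    "ratio=0.8": {"full": 593_956, "swa": 475_164},
    "ratio=0.15": {"full": 2_138_243, "swa": 320_736},
}

for name, pools in layouts.items():
    full_fit = max_requests(pools["full"], FULL_TOKENS_PER_REQ)
    swa_fit = max_requests(pools["swa"], SWA_TOKENS_PER_REQ)
    covers = pools["full"] >= DEMAND
    print(f"{name}: full fits {full_fit}, swa fits {swa_fit}, covers 80 prompts: {covers}")
```

With the default ratio the full pool falls short of the 720,000-token demand; with 0.15 it covers the demand with room to spare, and the SWA pool still holds more than 80 windows.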

Benchmark (1× B200, host venv, baseline = #1 head)

Both scenarios use python -m sglang.bench_serving --backend sglang-oai-chat --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.
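For reproduction, the two invocations can be assembled as below. The flags in `COMMON` are the ones quoted above; `--dataset-name`, `--random-input-len` and `--random-output-len` are assumptions inferred from the "random 8000/1000" and "random 1000/1000" labels, so check them against `python -m sglang.bench_serving --help` before relying on them.

```python
import shlex

# Flags quoted verbatim in the PR description.
COMMON = [
    "python", "-m", "sglang.bench_serving",
    "--backend", "sglang-oai-chat",
    "--random-range-ratio", "1.0",
    "--num-prompts", "80",
    "--warmup-requests", "2",
    "--seed", "1",
]


def bench_command(input_len, output_len):
    """Build one benchmark command line (dataset flags are assumed, see above)."""
    dataset_flags = [
        "--dataset-name", "random",
        "--random-input-len", str(input_len),
        "--random-output-len", str(output_len),
    ]
    return shlex.join(COMMON + dataset_flags)


print(bench_command(8000, 1000))  # summarization
print(bench_command(1000, 1000))  # chat
```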

Summarization (random 8000/1000, 80 prompts)

  metric             PR #1 (ratio=0.8)   this PR (ratio=0.15)   Δ vs PR #1   vLLM nightly
  median TTFT (ms)   10459               8763                   −16.2%       8916 (solo rerun)
  output tok/s       1108                1097                   −1.0%        2371
  median TPOT (ms)   44.6                37.9                   −15.0%       18.9
  mean TTFT (ms)     15383               16416                  +6.7%        8841
  accept length      2.73                2.76                   +1.1%        n/a

Median summarization TTFT now matches vLLM nightly within run-to-run
noise (8763 vs 8916 ms). Median TPOT also improves by 15% because the
scheduler can now interleave more prefills with decode.
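The Δ column is the relative change of this PR against PR #1. A short check, using only figures from the summarization table (`pct_change` is a hypothetical helper):

```python
def pct_change(baseline, new):
    """Relative change of `new` versus `baseline`, in percent."""
    return (new - baseline) / baseline * 100.0


summarization = {
    "median TTFT (ms)": (10459, 8763),
    "output tok/s": (1108, 1097),
    "median TPOT (ms)": (44.6, 37.9),
    "mean TTFT (ms)": (15383, 16416),
}

for metric, (baseline, new) in summarization.items():
    print(f"{metric}: {pct_change(baseline, new):+.1f}%")
```

For the latency metrics a negative value is an improvement; for throughput a negative value is a regression.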

Chat (random 1000/1000, 80 prompts)

  metric             PR #1   this PR   Δ
  output tok/s       2881    2913      +1.1%
  median TPOT (ms)   20.70   20.49     −1.0%
  median TTFT (ms)   591     597       +1.0% (noise)
  accept length      2.80    2.78      −0.7%

The chat workload doesn't stress the full pool, so the change is neutral
there, with no regression.

Quality (MMLU, 500 random questions, seed 0, temp 0)

  Server           Accuracy   Parsed
  PR #1 baseline   0.708      396/500
  This PR          0.706      394/500
  vLLM nightly     0.710      392/500

Within MMLU sampling noise. No regression.

Tests

test/srt/test_gemma4_swa_full_tokens_ratio.py:

  • override-fires-on-default: passes
  • user-override-preserved (both Gemma4ForCausalLM and
    Gemma4ForConditionalGeneration): passes
  • full-method smoke test: passes (skips gracefully where ModelConfig
    needs more env stubs)

PATH=.venv/bin:$PATH python -m pytest test/srt/test_gemma4_swa_full_tokens_ratio.py
# 3 passed, 1 skipped

Existing tests (PR #1 routing tests) still pass:

python -m pytest test/srt/layers/test_gemma4_fused_routing.py
# 47 passed

Provenance

  • Precedent: apply_deepseek_v4_defaults in
    python/sglang/srt/arg_groups/deepseek_v4_hook.py sets the same arg to
    0.1 for the same reason (oversized SWA, undersized full). This
    change follows the same pattern inside the existing Gemma-4 branch.
  • Ratio of 0.15 chosen by solving the cell-size equation for the
    workload's max(input + output) = 9000 token requirement under the
    available 121 GB of post-static KV memory on a B200. Anywhere in
    [0.10, 0.20] would work for this workload; 0.15 keeps a small
    margin without leaving too much SWA headroom unused.
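The trade-off behind that choice can be approximated from the two pool layouts quoted in this description. The sketch below is not the cell-size equation the author solved; it assumes a fixed KV budget, a constant per-token cost for each pool, and swa_tokens = ratio × full_tokens (which both quoted layouts satisfy to within rounding). All helper names are hypothetical.

```python
DEFAULT = {"ratio": 0.8, "full": 593_956, "swa": 475_164}
TUNED = {"ratio": 0.15, "full": 2_138_243, "swa": 320_736}


def infer_cell_ratio(a, b):
    """Cost of one SWA-pool token relative to one full-pool token.

    Follows from full_a + c * swa_a == full_b + c * swa_b (same budget).
    """
    return (b["full"] - a["full"]) / (a["swa"] - b["swa"])


def solve_pools(budget, cell_ratio, ratio):
    """Split a budget (in full-pool token units) with swa = ratio * full."""
    full = budget / (1.0 + ratio * cell_ratio)
    return full, ratio * full


cell_ratio = infer_cell_ratio(DEFAULT, TUNED)
budget = DEFAULT["full"] + cell_ratio * DEFAULT["swa"]
DEMAND = 80 * 9000  # full-attention tokens for the summarization workload

for ratio in (0.10, 0.15, 0.20, 0.80):
    full, swa = solve_pools(budget, cell_ratio, ratio)
    print(f"ratio={ratio:.2f}: full={full:,.0f} swa={swa:,.0f} "
          f"covers demand: {full >= DEMAND}")
```

Under these assumptions one SWA-pool token costs about ten times as much memory as one full-pool token, which is why a small change in the ratio moves the full pool so far. Every ratio in [0.10, 0.20] leaves the full pool well above the 720,000-token demand.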

CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

pyc96 added 6 commits May 22, 2026 00:26
Gemma4MoE.routing_function previously ran five torch ops per layer, which
launched six GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.
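For readers who want the math without the Triton details, here is a pure-Python reference for one token. It is a sketch of the steps listed above, not the kernel itself: `sortable_key` is one common order-preserving mapping from float32 bits to unsigned integers (the commit does not spell out its bijection), and ties are broken toward the higher expert id only because of how the value is packed here.

```python
import math
import struct


def sortable_key(x):
    """Map a float32 to a uint32 whose unsigned order matches float order."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Negative floats: flip every bit. Non-negative floats: set the sign bit.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def route_one_token(logits, per_expert_scale, k):
    """Top-k routing: one sort, softmax over the k winners, then scaling."""
    # Pack (key, expert_id) into one integer so a single sort orders both.
    packed = [(sortable_key(l) << 32) | e for e, l in enumerate(logits)]
    packed.sort(reverse=True)
    top_ids = [p & 0xFFFFFFFF for p in packed[:k]]

    # Numerically stable softmax restricted to the selected logits.
    top_logits = [logits[e] for e in top_ids]
    peak = max(top_logits)
    exps = [math.exp(l - peak) for l in top_logits]
    total = sum(exps)
    weights = [ex / total * per_expert_scale[e] for ex, e in zip(exps, top_ids)]
    return weights, top_ids


weights, ids = route_one_token([0.5, -1.0, 2.0, 1.0, -3.0], [1.0, 1.0, 2.0, 1.0, 1.0], k=2)
print(ids, weights)
```

Packing the key and the expert id into a single integer is what lets one sort replace the separate top-k and gather steps; the real kernel does the same with int64 values and `tl.sort`.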

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
…l split

Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`).  SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.

On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:

  full_layer_tokens = 593_956   <-- fits  ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows

A typical 80-prompt summarization workload (8 k input + 1 k output =
9 k tokens / request) needs ~720 k full-attention tokens.  Because the
full pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:

  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.

Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:

  full_layer_tokens = 2_138_243  <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  =   320_736  <-- still ~313 1024-token windows

Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):

  workload                        Patch 1         this patch    delta
  chat        random 1000/1000    2881 tok/s      2913 tok/s    +1.1 %
  summariz.   random 8000/1000
              median TTFT (ms)    10459          8763          **-16.2 %**
              output tok/s        1108           1097          -1.0 %
              median TPOT (ms)    44.6           37.9          -15.0 %

Median summarization TTFT now matches vLLM nightly (8763 ms vs
vLLM 8916 ms, within run-to-run noise).

MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710
-- within MMLU sampling noise; no regression.

User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).

Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke
test skipped on environments that do not have full ModelConfig stubs.

Co-authored-by: Claude