Skip to content

fix(gemma4): only apply swa_full_tokens_ratio=0.15 to MoE variants#8

Open
pyc96 wants to merge 1 commit into
pyc/sota-gemma4-mtp-revert-clampfrom
pyc/sota-gemma4-mtp-swa-ratio-moe-only
Open

fix(gemma4): only apply swa_full_tokens_ratio=0.15 to MoE variants#8
pyc96 wants to merge 1 commit into
pyc/sota-gemma4-mtp-revert-clampfrom
pyc/sota-gemma4-mtp-swa-ratio-moe-only

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 23, 2026

Summary

Refine Patch 2 (#2) so the swa_full_tokens_ratio=0.15
override only applies to MoE Gemma-4 variants
(Gemma-4-26B-A4B-IT), not to dense variants (gemma-4-31B-it,
Gemma-4-E4B-IT).

Stacked on #7 (clamp revert). Staged on pyc96/sglang
only.

Motivation

Patch 2's 0.15 ratio was tuned for the MoE 26B-A4B-IT: MoE sparse
weights leave plenty of GPU memory for KV, so growing the full-attn
pool 3.6x (and shrinking the over-provisioned SWA pool) materially
improves long-context summarization TTFT.

For dense 31B and E4B, the same ratio is harmful:

  • 31B-it (62 GB dense weights) + ratio 0.15 → SWA pool = 71808 tokens
    (~70 windows-worth). Concurrent 80-prompt chat → SWA usage 1.00 →
    scheduler stalls admission → output throughput collapses to 1135 tok/s.

This commit gates the override on num_experts > 0 (read from
hf_text_config), the standard SGLang MoE-detection pattern (mirrors
gemma4_causal.py:1166).

Per-model behavior (verified on 1x B200)

model num_experts branch resulting ratio resulting pool layout
Gemma-4-26B-A4B-IT (MoE) 128 override fires 0.15 full=2138240 swa=320704 (unchanged from Patch 2)
gemma-4-31B-it (dense) 0 override skipped 0.8 (upstream default) full=132992 swa=106368
Gemma-4-E4B-IT (dense) 0 override skipped 0.8 (same skip path)

Log lines:

# 26B-A4B-IT:
Setting swa_full_tokens_ratio to 0.15 for Gemma4ForConditionalGeneration
(MoE Gemma-4 with num_experts=128; the default ratio over-provisions the
SWA pool and under-provisions the full-attention pool, ...).

# 31B-it:
Keeping default swa_full_tokens_ratio=0.8 for Gemma4ForConditionalGeneration
(dense Gemma-4; MoE-specific 0.15 override skipped to avoid SWA pool
starvation).

Benchmark: 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM nightly

random 40 prompts, max-concurrency=32, seed 1:

chat 1k/1k

metric SGLang (this PR) vLLM nightly delta
outcome OK 40/40 OK 40/40 same
median TTFT 673 ms 901 ms SGLang +25%
median TPOT 8.69 ms 9.69 ms SGLang +10%
P99 TPOT 11.10 ms 13.16 ms SGLang +16%
total throughput 4715 tok/s 4077 tok/s SGLang +16%
accept length 3.13 n/a --

summarization 8k/1k

metric SGLang (this PR) vLLM nightly delta
outcome OK 40/40 OK 40/40 same
median TPOT 17.02 ms 27.33 ms SGLang +38%
P99 TPOT 27.13 ms 45.17 ms SGLang +40%
total throughput 7475 tok/s 6468 tok/s SGLang +16%
accept length 3.14 n/a --
median TTFT 13613 ms 10162 ms vLLM +25% (tradeoff)

Quality (MMLU @ N=500, seed=0, temp=0)

accuracy
SGLang (this PR, 31B-it) 0.680
vLLM nightly (31B-it) 0.660

Within sampling noise.

Regression test for 26B-A4B-IT

The 26B path is unchanged. Verified by relaunching 26B and confirming:

  • log: Setting swa_full_tokens_ratio to 0.15 for ... (MoE Gemma-4 with num_experts=128; ...)
  • pool: full_layer_tokens=2138240 swa_layer_tokens=320704 (same as Patch 2)

Tests

test/srt/test_gemma4_swa_full_tokens_ratio.py — 4 cases pass, 2
skipped (full smoke tests skip when env lacks model-config stubs,
same as before):

test_moe_gemma4_default_overridden                       PASSED
test_dense_gemma4_default_preserved                      PASSED
test_user_override_preserved[Gemma4ForCausalLM]          PASSED
test_user_override_preserved[Gemma4ForConditionalGeneration] PASSED
test_full_method_runs_for_moe_gemma4                     SKIPPED (env)
test_full_method_runs_for_dense_gemma4                   SKIPPED (env)

Known limitation (out-of-scope)

The predicate if self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio
still cannot distinguish "user passed 0.8 explicitly" from "user
didn't pass the flag" (same caveat as the upstream
apply_deepseek_v4_defaults pattern). Fixing this would require a
sentinel default (None) and resolution in __post_init__ — wider
surface area than this PR's MoE gating fix. For MoE Gemma-4, a user
passing 0.8 will still get the override; if they want the upstream
default explicitly, they can pass any other value like 0.81.


CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every
Gemma-4 model.  That value was tuned for `Gemma-4-26B-A4B-IT`
(MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of
GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer
ratio means the shipped default 0.8 over-provisions the SWA pool.

For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV,
so 0.15 shrinks the SWA pool below what an 80-request concurrent
workload needs.  Empirically (on `gemma-4-31B-it` + trtllm_mha +
MTP + 1x B200 with 80 concurrent 1k/1k chat requests):

  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates
              at 100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.

This commit gates the 0.15 override on `num_experts > 0`, read
from the model's `hf_text_config`.  Mirrors the MoE-detection
pattern in `gemma4_causal.py:1166`.

Per-model verification on 1x B200:

  26B-A4B-IT (MoE, num_experts=128):
    log: 'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
    (unchanged from Patch 2 -- regression-safe)

  31B-it (dense, num_experts=0):
    log: 'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
    (instead of the broken 478720 / 71808 layout from Patch 2)

  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.

Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):

  metric            | SGLang (this PR) | vLLM nightly | Delta
  ------------------|------------------|--------------|----
  outcome           | OK               | OK           | same
  median TTFT       | 673 ms           | 901 ms       | SGLang +25%
  median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
  total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
  accept length     | 3.13             | n/a          | --

Same workload at conc=32 summarization (8k/1k x 40):
  median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
  total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%

MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).

Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).

Co-authored-by: Claude
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant