Skip to content

fix(gemma4): FROZEN_KV_MTP zero-accept crash + scheduler merge/filter guard#22

Open
pyc96 wants to merge 17 commits into
mainfrom
pyc/fix-gemma4-frozen-kv-mtp-crash
Open

fix(gemma4): FROZEN_KV_MTP zero-accept crash + scheduler merge/filter guard#22
pyc96 wants to merge 17 commits into
mainfrom
pyc/fix-gemma4-frozen-kv-mtp-crash

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 25, 2026

Summary

Fixes the 'FrozenKVMTPVerifyInput' object has no attribute 'merge_batch' crash that breaks Gemma-4-26B-A4B-IT serving with --speculative-algorithm NEXTN (or FROZEN_KV_MTP directly) under concurrent load. Reproducible with the 30-prompt MM color-naming test at concurrency 4: crashes within ~30 seconds.

Also adds defensive scheduler guards and an env-var opt-out for the Gemma4AssistantForCausalLMFROZEN_KV_MTP auto-promotion.

What's broken

Scheduler hit an exception: Traceback (most recent call last):
  File '.../scheduler.py', line 3534, in dispatch_event_loop
    self.running_batch.merge_batch(self.last_batch)
  File '.../schedule_batch.py', line 2574, in merge_batch
    self.spec_info.merge_batch(other.spec_info)
AttributeError: 'FrozenKVMTPVerifyInput' object has no attribute 'merge_batch'

Root cause

In FrozenKVMTPWorker.forward_batch_generation, the post-verify conditional:

draft_extend_input = verify_output.draft_extend_input
if (server_args.enable_dp_attention
    or draft_extend_input.input_ids.shape[0] > 0):
    batch.spec_info = draft_extend_input
    self.forward_draft_extend_after_decode(batch)

When the verify accepts zero draft tokens AND no req survives to the next draft AND DP attention is off, the worker skips both the spec_info reset and the seed step. batch.spec_info is left as the FrozenKVMTPVerifyInput from the verify forward. On the very next scheduler step, when a new prefill batch arrives, the scheduler tries running_batch.merge_batch(new_batch) which calls self.spec_info.merge_batch(other.spec_info) → AttributeError, because EagleVerifyInput (parent of FrozenKVMTPVerifyInput) doesn't implement merge_batch.

Fixes (three layers)

layer file change
Primary fix python/sglang/srt/speculative/frozen_kv_mtp_worker.py Install an idle FrozenKVMTPDraftInput in the else branch of the post-verify zero-accept conditional. Mirrors what _run_assistant_seed_step does for the is_idle / numel==0 case.
Defensive guard python/sglang/srt/managers/schedule_batch.py hasattr checks in merge_batch / filter_batch so any future spec algorithm that transiently leaves a *VerifyInput / *DraftExtendInput on batch.spec_info doesn't crash the scheduler. Safe because post-merge batches have forward_mode in {EXTEND, MIXED} which routes the worker back through forward_target_extend → forward_draft_extend (which rebuilds spec_info from scratch).
Env-var opt-out python/sglang/srt/arg_groups/speculative_hook.py SGLANG_GEMMA4_FORCE_EAGLE=1 keeps NEXTN/EAGLE on the upstream EAGLE worker instead of promoting to FROZEN_KV_MTP for Gemma4 assistant drafts. Lets users A/B against vanilla EAGLE. Vanilla EAGLE currently can't serve the Gemma4 assistant due to a hidden_size mismatch (2816 vs 1024 — assistant is purpose-built for FROZEN_KV_MTP); the env knob is for future compatibility and for the test.

Verification

Before fix

Quality test (30-prompt MM, concurrency=4) CRASHES with the AttributeError on every run.

After fix

  • Quality test passes 30/30 (100%) accuracy.
  • Server stays healthy across multiple back-to-back runs.

Sharegpt benchmark (50 prompts × 200 output, Gemma-4-26B-A4B-IT, 1× B200, vLLM nightly with --spec-method gemma4_mtp)

Config duration TTFT (ms) TPOT (ms) total tok/s
SGLang ultimate no MTP (PR #18) 2.186s 117 10.37 10,764
SGLang FROZEN_KV_MTP (fixed, this PR) 2.523s 148 10.71 9,327
vLLM nightly no MTP 2.699s 250 12.27 8,707
vLLM nightly --spec-method gemma4_mtp 2.100s 335 6.31 11,812

Why FROZEN_KV_MTP can't yet beat vLLM gemma4_mtp

Mean accept_len on sharegpt: SGLang 2.60, vLLM 2.0-2.5 — accept rates are comparable. The throughput gap comes from two structural overheads:

  1. FROZEN_KV_MTP forces disable_overlap_schedule=True (spec v1 only; no spec v2 yet for this algo). vLLM overlaps draft and target.
  2. Every FROZEN_KV_MTP decode step runs an extra assistant 'seed' forward (after verify) to prep next-iter draft state. vLLM's gemma4_mtp folds this into the next draft call.

Closing both gaps is a substantial refactor (FROZEN_KV_MTP → spec-v2, plus rolling the seed step into draft()). This PR ships the crash fix so MTP is at least usable while the perf work continues.

Recommended deployment

Use case Recommended config Why
General serving PR #18 (no MTP) Wins on every no-MTP workload tested
MTP serving (latency-critical) vLLM nightly + --spec-method gemma4_mtp Currently the fastest end-to-end
Need correctness with SGLang MTP This PR + --speculative-algorithm NEXTN Stable but ~27% slower than vLLM MTP today

Follow-up (TODO, separate PRs)

  1. Implement spec v2 / overlap scheduling for FROZEN_KV_MTP — biggest expected gain (vLLM's overlap is ~30-40% of the gap).
  2. Fold the seed step into the next-iter draft — saves 1 assistant forward per decode step.

Tests

test/srt/speculative/test_frozen_kv_mtp_lifecycle.py — 4 CPU-only unit tests covering the merge_batch/filter_batch hasattr guards, the SGLANG_GEMMA4_FORCE_EAGLE env-var opt-out, and the FrozenKVMTPDraftInput.create_idle_input contract the fix relies on. All 4 pass.


CI States

Latest PR Test (Base): ❌ Run #26424125279
Latest PR Test (Extra): ❌ Run #26424125173

pyc96 and others added 17 commits May 22, 2026 00:26
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
…l split

Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`).  SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.

On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:

  full_layer_tokens = 593_956   <-- fits  ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows

A typical 80-prompt summarization workload (8 k input + 1 k output =
9 k tokens / request) needs ~720 k full-attention tokens.  Because the
full pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:

  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.

Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:

  full_layer_tokens = 2_138_243  <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  =   320_736  <-- still ~313 1024-token windows

Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):

  workload                        Patch 1         this patch    delta
  chat        random 1000/1000    2881 tok/s      2913 tok/s    +1.1 %
  summariz.   random 8000/1000
              median TTFT (ms)    10459          8763          **-16.2 %**
              output tok/s        1108           1097          -1.0 %
              median TPOT (ms)    44.6           37.9          -15.0 %

Median summarization TTFT now matches vLLM nightly (8763 ms vs
vLLM 8916 ms, within run-to-run noise).

MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710
-- within MMLU sampling noise; no regression.

User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).

Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke
test skipped on environments that do not have full ModelConfig stubs.

Co-authored-by: Claude
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens.
Turns the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.

See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.

Co-authored-by: Claude
…rap)

Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.

Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check.  Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.

Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:

  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*

That is, the SWA allocator is NOT the source of the OOB.  Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table.  Specifically:

  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]

For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token).  full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).

This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
…A crash

Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).

Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.

Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.

Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events

Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)

MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.

KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization).  Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions.  The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E).  For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.

Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).

See crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has two parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         4032              4022            ~same
  accept length        1.61              **2.13**        +32%
  total throughput     31.5 k tok/s      36.2 k tok/s    +15%
  median TPOT (ms)     12.16             9.99            -18%

  26B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         1832              2503            +37%
  accept length        1.67              **2.84**        +70%
  total throughput     16.5 k tok/s      22.5 k tok/s    +37%
  median TPOT (ms)     24.97             20.35           -18%
  median TTFT (ms)     2887              3468            +20%
  benchmark duration   ~60 s             32 s            -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude
This reverts commit 5547e41.

PR #5 (the clamp) is no longer needed because PR
#6 (Patch E) eliminates the source of OOB page_table
values entirely.  The clamp's only side-effect was a known quality
limitation -- when the clamp actually triggered, it replaced an OOB
page index with the LAST valid SWA page, producing slightly wrong
attention values for that position and lowering MTP draft acceptance.
With Patch E in place those OOB values never occur and the clamp
never fires, so it's dead code that adds one .clamp() per kernel call
for no benefit.

Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP +
summarization 8 k/1 k x 80 on 1x B200):

  outcome:        OK (zero trap events from PR #3 debug)
  accept length:  matches the pre-revert PR #6 run
  TPOT:           matches the pre-revert PR #6 run

If a future code change reintroduces an OOB page_table value, the
opt-in bounds-check trap from PR #3
(SGLANG_TRTLLM_MHA_DEBUG=1) will still catch it with a deterministic
Python exception + dump for triage.

Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every
Gemma-4 model.  That value was tuned for `Gemma-4-26B-A4B-IT`
(MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of
GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer
ratio means the shipped default 0.8 over-provisions the SWA pool.

For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV,
so 0.15 shrinks the SWA pool below what an 80-request concurrent
workload needs.  Empirically (on `gemma-4-31B-it` + trtllm_mha +
MTP + 1x B200 with 80 concurrent 1k/1k chat requests):

  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates
              at 100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.

This commit gates the 0.15 override on `num_experts > 0`, read
from the model's `hf_text_config`.  Mirrors the MoE-detection
pattern in `gemma4_causal.py:1166`.

Per-model verification on 1x B200:

  26B-A4B-IT (MoE, num_experts=128):
    log: 'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
    (unchanged from Patch 2 -- regression-safe)

  31B-it (dense, num_experts=0):
    log: 'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
    (instead of the broken 478720 / 71808 layout from Patch 2)

  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.

Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):

  metric            | SGLang (this PR) | vLLM nightly | Delta
  ------------------|------------------|--------------|----
  outcome           | OK               | OK           | same
  median TTFT       | 673 ms           | 901 ms       | SGLang +25%
  median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
  total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
  accept length     | 3.13             | n/a          | --

Same workload at conc=32 summarization (8k/1k x 40):
  median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
  total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%

MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).

Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).

Co-authored-by: Claude
…CG opt-in)

Three independent changes to close the SGLang \u2194 vLLM TPOT gap when
serving Gemma4 with the triton attention backend:

1. Fused PLE-tail kernels (gemma4_fused_ops.py)
   Adds two new Triton kernels:
     * gemma_rmsnorm_add(x, w, r)        : out = rmsnorm(x,w) + r
     * gemma_gelu_tanh_mul(gate, ple)    : out = gelu_tanh(gate) * ple
   Re-uses gemma_rmsnorm_residual_scalar for the 3rd tail stage. The
   PLE branch in Gemma4DecoderLayer.forward (taken when has_ple=True,
   i.e. E2B / E4B) used to issue 7 launches at the layer tail
   (post_ff_norm; add residual; gate gelu; mul ple; project norm;
   add+mul). The two GEMMs around the PLE input are unavoidable; the
   remaining five pointwise ops collapse into three Triton launches.
   For E2B (35 layers) that's ~140 launches saved per decode step.

2. Optional key/value in unified_attention_with_output (radix_attention.py)
   The piecewise/breakable CUDA graph attention wrapper sliced key /
   value unconditionally, which crashed on Gemma4 E2B / E4B KV-shared
   layers (those pass key=None, value=None and read both from the cache
   written by an earlier layer). The custom op now declares the args as
   Optional[torch.Tensor] and skips the slice when None.

3. Piecewise CUDA graph opt-in for multimodal models (server_args.py)
   The blanket disable for is_multimodal=True is too coarse: the
   piecewise CG runner already extracts model.language_model explicitly,
   so the vision tower stays eager while the language-model decode path
   gets piecewise capture. Default behavior is unchanged; opt in with
   SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to pick up the prefill
   capture. Safe today on Gemma-4-26B-A4B-IT (no KV-shared layers).

Benchmark (1\u00d7 B200, vllm bench serve random text 3000-input/100-output,
30 prompts, vLLM nightly comparator):

  Gemma-4-26B-A4B-IT  (--enforce-piecewise-cuda-graph + this PR):
    baseline      dur 1.475s | TPOT 10.97ms | tok/s 63325
    patched       dur 1.405s | TPOT  9.80ms | tok/s 66438
    vLLM nightly  dur 1.635s | TPOT  9.99ms | tok/s 58420
    -> SGLang patched now beats vLLM TPOT (9.80 vs 9.99 ms) and
       wall-time (1.405 vs 1.635 s) on this workload.

  gemma-4-E2B-it (fused PLE only; piecewise CG still disabled on E2B
                  because of a separate KV-shared / capture interaction):
    baseline      dur 0.895s | TPOT 5.44ms  | tok/s 104329
    patched       dur 0.875s | TPOT 5.20ms  | tok/s 105861
    vLLM nightly  dur 0.735s | TPOT 3.75ms  | tok/s 127468

Quality (30-prompt color-naming MM test, temperature=0):
  26B baseline 30/30 == patched 30/30 (29/30 char-match, 1 minor
  numerical noise from PCG capture, accuracy unchanged).
  E2B baseline 26/30 == patched 26/30 (30/30 char-match on the
  fused-PLE-only build).

Test: test/srt/layers/test_gemma4_ple_fused_ops.py (10 CUDA tests).

Refs: vllm-project/vllm uses analogous Inductor-level fusions in its
piecewise compile pipeline; this PR ports the highest-impact subset
directly into SGLang's Triton kernel library so Gemma4 closes the
TPOT gap without depending on Inductor.
…re-MoE)

Inspects vLLM's torch.compile/Inductor output for Gemma-4-26B-A4B-IT
(via TORCH_COMPILE_DEBUG=1) and ports the highest-impact fused kernel
into SGLang's Triton kernel library.

The Inductor kernel `triton_red_fused_add_moe_forward_mul_rms_norm_0`
fuses the entire post-attention-pre-MoE block:

  1) post_attn_residual = rmsnorm(attn_out, w_post_attn) + residual
  2) dense_ff_input     = rmsnorm(post_attn_residual, w_pre_ff)
  3) router_input       = rmsnorm(post_attn_residual, 1) * router_scale
  4) moe_input          = rmsnorm(post_attn_residual, w_pre_ff_2)

Steps 2, 3, 4 share the same rsqrt(variance(post_attn_residual));
Inductor walks the row twice for reductions and once for production,
emitting all three outputs from a single kernel.

This commit:
  * adds `gemma_post_attn_triple_rmsnorm` in gemma4_fused_ops.py
    that replicates the 3-pass-reduction layout in Triton.
  * wires Gemma4DecoderLayer.forward (MoE branch) to call it instead
    of the 4 separate kernel launches (post_attn_norm; pre_ff_norm
    fused-add; router.norm + scale; pre_ff_norm_2).
  * adds 4 CUDA-only unit tests against an eager reference.

Eligibility gates (falls back to the original 4-launch sequence):
  * MoE branch active (enable_moe_block=True)
  * 2D contiguous bf16 hidden_states (the common decode path)
  * Gemma4Router with with_scale=False norm (the canonical setup)
  * Lazily populates router._fused_scale on the first call.

Benchmark (1x B200, vllm bench serve random, vLLM nightly comparator,
SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to enable PR #16's
piecewise CG):

  Gemma-4-26B-A4B-IT workload A (3000-input / 100-output, 30 prompts):
    baseline       dur 1.475s | TPOT 10.97ms | tok/s 63325
    PR #16 only    dur 1.406s | TPOT  9.80ms | tok/s 66437
    + this PR      dur 1.376s | TPOT  9.51ms | tok/s 67905
    vLLM nightly   dur 1.635s | TPOT  9.99ms | tok/s 59028
    -> SGLang beats vLLM by 4.8% TPOT and 15.8% wall time.

  Workload B (500/500, 50 prompts):
    baseline:  5.49s | 10.54ms
    + this PR: 5.27s | 10.17ms (vLLM 6.19s | 12.02ms; -15.4% TPOT)

  Workload C (100/1000, 30 prompts, decode-heavy):
    baseline:  8.86s | 8.73ms
    + this PR: 8.51s | 8.45ms (vLLM 8.96s | 8.86ms; -4.6% TPOT)

SGLang now beats vLLM on every workload, on both duration AND TPOT.

Quality (30-prompt color-naming MM test, temperature=0):
  26B baseline 30/30 (100%) == patched 30/30 (100%),
  29/30 char-match (1 minor numerical noise).

Refs: vLLM torch.compile Inductor output for Gemma-4-26B-A4B-IT
(captured 2026-05-25 from vllm/vllm-openai:nightly with
TORCH_COMPILE_DEBUG=1; pattern preserved in the run artifact at
runs/20260524_vllm_inductor_inspect/analysis/fusion_catalog.md).
Port of vllm-project/vllm#43169 to SGLang's gemma4_mm.py.

Pre-patch get_image_feature / get_video_feature iterate one image
(or one video frame) at a time through self.vision_tower(...) and
again through self.embed_vision(...) on each pooled output. With
6 images per prompt this fires 12 GPU dispatches per prompt where
2 would suffice.

Replace both with:
  * _flatten_pixel_lists - walk items, normalise shapes, collect a
    flat list of (pv, pp) entries plus any pre-passed embeddings.
  * _batched_encode      - bucket by patch count (resolution
    bucket), chunk-batch within each bucket bounded by an encoder
    memory budget, call vt() once per bucket-chunk and embedder
    once over the concatenated valid-token tensor.
  * _gather_mm_features  - driver shared by image and video paths.

Vision tower (Gemma4VisionEncoder.forward) already accepts batched
[B, num_patches, patch_pixels] and the embedder is pointwise, so
the change is shape-preserving.

Test: test/srt/models/test_gemma4_mm_batched_encoder.py

Benchmark (gemma-4-E2B-it, 1x B200, random-mm 6x480 images,
100 prompts, --disable-radix-cache):
  baseline  duration 15.96s | TTFT 10587ms | tok/s 10132
  patched   duration 10.92s | TTFT  7867ms | tok/s 14817
            -> 1.46x duration, 1.34x TTFT, 1.46x throughput

Quality (30-prompt colored-image labelling, temp=0):
  baseline 26/30 == patched 26/30, all 30 responses match
  character-for-character.

Refs: vllm-project/vllm#43169 (algorithm template, Apache-2.0).
…879)

Gemma4 E2B (35 layers / 20 KV-shared) and E4B (42 / 18) place the last
N layers in a 'cross-decoder' regime that reuses KV state from earlier
layers (see Gemma4Attention.is_kv_shared_layer / kv_shared_layer_index).
During prefill those shared-KV layers don't write KV — but the baseline
still runs Q-norm + Q-proj + RoPE + attention + MLP + residuals for
every prefill token, even though the only Q-side outputs that ever
feed the LM head are the last-token-per-request rows.

Truncate hidden_states / positions / per_layer_inputs to just those
rows before entering the first KV-shared layer (== YOCO fast-prefill,
matching vllm-project/vllm#22628 + #38879), then scatter back into the
full-shape tensor after the last layer so the downstream logits
processor's 'index at cumsum(extend_seq_lens) - 1' produces the same
output.

Eligibility & guards:
  * num_kv_shared_layers > 0 (E2B / E4B only; no-op on 26B-A4B-IT
    and 31B where the config doesn't opt in)
  * non-speculative EXTEND batch with at least one request having
    > 1 new token
  * not collecting per-prompt logprobs
  * not capturing aux hidden states inside the shared-KV layer range
  * single-stage PP only
  * SGLANG_GEMMA4_YOCO=0 env kill switch for A/B testing

Implementation: between layer (K-1) and K, snapshot the affected
forward_batch.extend_* fields, replace extend_seq_lens with 1s and
extend_prefix_lens with seq_lens-1, call init_forward_metadata to
rebuild qo_indptr/kv_indices, run the shared-KV layers, then scatter
the truncated output back to the full tensor and rebuild attention
metadata one more time to restore the original state.

Test: test/srt/models/test_gemma4_yoco_fast_prefill.py (9 CPU-only
unit tests).

Benchmark (1x B200, vllm bench serve random text, 30 prompts, 7000
input / 10 output, --disable-radix-cache; isolates cross-decoder
prefill):

  gemma-4-E2B-it  (35 layers / 20 KV-shared):
    baseline   dur 3.45s | TTFT 1792ms | tok/s 61020
    patched    dur 2.28s | TTFT 1205ms | tok/s 92414
               -> 1.51x duration, 1.49x TTFT, 1.51x throughput

  gemma-4-E4B-it  (42 layers / 18 KV-shared):
    baseline   dur 4.22s | TTFT 2183ms | tok/s 49905
    patched    dur 3.24s | TTFT 1733ms | tok/s 64949
               -> 1.30x duration, 1.26x TTFT, 1.30x throughput

Quality (30-prompt color-naming MM test, temperature=0):
  E2B: baseline 26/30 == patched 27/30 (24/30 char-match; 6 diffs
       are whitespace or last-token noise from attention reductions
       on truncated Q being non-deterministic — same caveat vLLM has
       on --kv-sharing-fast-prefill).
  E4B: baseline 29/30 == patched 29/30 (30/30 char-for-char match).

Refs: vllm-project/vllm#22628, vllm-project/vllm#38879
(Apache-2.0).
… guard

## What's broken

Reproducible AttributeError under concurrent serving on Gemma-4-26B-A4B-IT
with FROZEN_KV_MTP speculative decoding:

    Scheduler hit an exception: Traceback (most recent call last):
      File '.../scheduler.py', line 3534, in dispatch_event_loop
        self.running_batch.merge_batch(self.last_batch)
      File '.../schedule_batch.py', line 2574, in merge_batch
        self.spec_info.merge_batch(other.spec_info)
    AttributeError: 'FrozenKVMTPVerifyInput' object has no attribute 'merge_batch'

Trigger: the 30-prompt MM color-naming test (concurrency=4) consistently
crashes within ~30 seconds.

## Root cause

In FrozenKVMTPWorker.forward_batch_generation, after every verify:

  draft_extend_input = verify_output.draft_extend_input
  if (server_args.enable_dp_attention
      or draft_extend_input.input_ids.shape[0] > 0):
      batch.spec_info = draft_extend_input
      self.forward_draft_extend_after_decode(batch)

When the verify accepts ZERO draft tokens AND no req survives to the next
draft AND DP attention is off, the worker skips both the spec_info reset
and the seed step. batch.spec_info is left as the FrozenKVMTPVerifyInput
from the verify forward. On the very next scheduler step a merge crashes
because EagleVerifyInput doesn't implement merge_batch.

## Fixes (three layers)

1. python/sglang/srt/speculative/frozen_kv_mtp_worker.py: install an idle
   FrozenKVMTPDraftInput in the else branch of the post-verify zero-accept
   conditional. Mirrors what _run_assistant_seed_step does for the
   is_idle / numel==0 case. The actual fix.

2. python/sglang/srt/managers/schedule_batch.py: defensive hasattr guards
   in merge_batch and filter_batch so any future spec algorithm that
   transiently leaves a *VerifyInput / *DraftExtendInput on batch.spec_info
   doesn't crash the scheduler. Safe because the post-merge batch has
   forward_mode in {EXTEND, MIXED} which routes the worker through
   forward_target_extend / forward_draft_extend (which rebuilds spec_info
   from scratch).

3. python/sglang/srt/arg_groups/speculative_hook.py: add
   SGLANG_GEMMA4_FORCE_EAGLE=1 env knob to opt out of the
   NEXTN/EAGLE -> FROZEN_KV_MTP auto-promotion for Gemma4 assistant
   drafts. Lets users A/B against vanilla EAGLE for workloads where
   FROZEN_KV_MTP's seed-step overhead exceeds its KV-sharing memory
   savings. (Vanilla EAGLE currently can't actually serve the Gemma4
   assistant due to hidden_size mismatch; env knob is for future
   compatibility and for the test.)

## Verification

Before fix: Quality test (30-prompt MM, concurrency=4) CRASHES with the
AttributeError on every run.

After fix: Quality test passes 30/30 (100%). Server stays healthy across
multiple back-to-back runs.

Sharegpt benchmark (50 prompts x 200 output, Gemma-4-26B-A4B-IT, 1x B200,
vllm/vllm-openai:nightly comparator with --spec-method gemma4_mtp):

  SGLang no-MTP (PR #18):      dur 2.186s | TPOT 10.37ms | tok/s 10764
  SGLang FROZEN_KV_MTP fixed:  dur 2.523s | TPOT 10.71ms | tok/s 9327
  vLLM no-MTP:                 dur 2.699s | TPOT 12.27ms | tok/s 8707
  vLLM gemma4_mtp:             dur 2.100s | TPOT  6.31ms | tok/s 11812

* SGLang no-MTP still beats vLLM no-MTP by 19% on duration / 24% on
  throughput (PR #18 unchanged).
* SGLang FROZEN_KV_MTP is now STABLE and produces correct output (the
  fix's primary purpose), but its end-to-end performance does not yet
  beat vLLM gemma4_mtp.

## Why FROZEN_KV_MTP can't yet beat vLLM gemma4_mtp

Mean accept_len on sharegpt: SGLang 2.60, vLLM 2.0-2.5 — accept rates
are comparable. The throughput gap comes from two structural overheads:

  1. FROZEN_KV_MTP forces disable_overlap_schedule=True (spec v1 only;
     no spec v2 yet for this algo). vLLM overlaps draft and target.
  2. Every FROZEN_KV_MTP decode step runs an extra assistant 'seed'
     forward (after verify) to prep next-iter draft state. vLLM's
     gemma4_mtp folds this into the next draft call.

Closing both gaps is a substantial refactor (FROZEN_KV_MTP -> spec-v2,
plus rolling the seed step into draft()). This PR ships the crash fix
so MTP is at least USABLE while the perf work continues.

## Tests

test/srt/speculative/test_frozen_kv_mtp_lifecycle.py — 4 CPU-only unit
tests covering merge_batch/filter_batch hasattr guards, the env-var
opt-out, and the FrozenKVMTPDraftInput.create_idle_input contract.
All 4 pass.
pyc96 pushed a commit that referenced this pull request May 26, 2026
## Problem

v1 `EagleVerifyInput.verify` per-req loop walks `accept_index[i, :]`
and, on the first token that matches an EOS / stop-token / max-new-
tokens cap, sets all subsequent positions to -1 and shrinks the per-req
accept count accordingly.

v2 `EagleVerifyInput.sample()` does NOT do this. Without trimming,
finished requests under v2 commit one extra token past EOS per
accepted spec step:
  * Polluted output stream (post-EOS tokens streamed to client).
  * Inconsistent kv_committed_len (committed past the EOS).
  * SWA radix cache over-counts (PR #25's leak symptom).

## Fix

New helper `FrozenKVMTPWorkerV2._truncate_at_eos_inplace` called from
`_verify_v2` right after `sample()` returns. Pure-functional walk
over `accept_index_cpu` per req:

  * For each accepted token, check if its id matches:
    - any of `req.sampling_params.stop_token_ids`
    - any of `req.eos_token_ids`
    - `tokenizer.eos_token_id` or `tokenizer.additional_stop_token_ids`
  * Or if appending it would push `len(output_ids)` to >= `max_new_tokens`
  * On a hit, set `accept_index[i, j+1:] = -1` and shrink
    `accept_lens[i]` to `j + 1` (EOS token itself stays committed,
    matching v1's behavior at line 451-455 of EagleVerifyInput.verify).

The scheduler's `process_batch_result_decode` then sees the corrected
`accept_lens[i]` and per-req `next_token_id` slice. It still owns
`update_finish_state`, `output_ids.extend`, grammar tracking, and
reasoning-token accounting (lines 607-623 of batch_result_processor.py).

## Why purely functional

v1 mutates `req` state inside the verify loop because the v1
batch_result_processor doesn't repeat that work. v2's batch_result_processor
DOES repeat it. Calling `update_finish_state` / `grammar.accept_token`
/ etc. twice would corrupt grammar state, double-count reasoning tokens,
and produce two FINISH_MATCHED_TOKEN reasons with different finished_len.
The helper only inspects token ids; the scheduler does the side effects.

## Grammar limitation (documented in code)

v1's loop checks grammar termination after each accepted token and
stops the spec if grammar terminates mid-accept. v2's purely-functional
truncation does NOT do this. Grammar termination is detected downstream
by the scheduler after the full (corrected) `accept_lens[i]` slice
has been `output_ids.extend`-ed. For grammar-using requests this can
cause one extra committed token past the grammar terminator. Acceptable
for now (no grammar in the MM color test); tracked as a follow-up.

## Verification

Single-prompt repro shows the truncation firing correctly. The
[EOS-TRUNC] debug log (since removed) showed entries like:
  [EOS-TRUNC] accept_lens before=[2] after=[1]
after the last decode step of a short response — the post-EOS token was
correctly dropped.

10 CPU-only unit tests in test/srt/speculative/test_frozen_kv_mtp_v2_eos.py
cover:
  * no-EOS no-op
  * EOS at first / middle / last position
  * stop_token_ids treated as EOS
  * ignore_eos respected
  * max_new_tokens cap
  * existing -1 in accept_index_row breaks loop
  * multi-req independent truncation
  * no-mutation fast path (no GPU copy_)

All 10 pass.

## Perf impact

Sharegpt 50p x 200out, Gemma-4-26B-A4B-IT, 1x B200:

  Config                                  dur(s)   TPOT(ms)   tok/s
  v1 (PR #22, no overlap)                  2.523    10.71     9327
  v2 (PR #25, no EOS trunc)                3.060    12.98     7691
  v2 + EOS trunc (this commit)             3.117    13.30     7548
  vLLM nightly + gemma4_mtp                1.655     6.36    14215

EOS trunc adds ~2%% overhead (one CPU walk over accept_index_cpu per
verify) — acceptable cost for correctness. The remaining ~21%% v2-vs-v1
slowdown and ~85%% v2-vs-vLLM gap come from the other two TODOs called
out in PR #25 (seed step on plan_stream + coercion overhead trim) plus
a separately-investigated first-token determinism issue documented
below.

## Quality on MM color test: unchanged at 28/30

The 30-prompt MM color test still shows 28/30 (v2) vs 30/30 (v1).
**The remaining 2/30 regression is NOT EOS-related** — verified by:
  * v1 vs v2 diffs are on prompts where v2 either capitalizes the
    first decoded token differently (id=7, 12, 28) or drops the first
    color (id=0, 16). The LAST token of every diff response is the
    correct color; finished correctly.
  * Reproducing id=1 in isolation under both v1 and v2 returns the
    SAME (correct) answer. The regression only manifests within the
    quality script's request sequence.
  * The [EOS-TRUNC] debug logs confirm truncation is firing on the
    last decode of each request — exactly where it should.

The root cause of those 2 remaining failures appears to be in
`prepare_for_v2_verify`'s KV-slot allocation pattern (different from
v1's `prepare_for_verify`'s `alloc_token_slots`), likely causing
first-decode RoPE position drift on certain MM prompt shapes. Tracked
as the third v2 TODO (along with seed-step plan_stream + coercion
overhead).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant