Skip to content

perf(gemma4): YOCO fast-prefill for E2B/E4B (port of vllm-project/vllm#22628 + #38879)#14

Open
pyc96 wants to merge 13 commits into
mainfrom
pyc/feat-gemma4-yoco-fast-prefill
Open

perf(gemma4): YOCO fast-prefill for E2B/E4B (port of vllm-project/vllm#22628 + #38879)#14
pyc96 wants to merge 13 commits into
mainfrom
pyc/feat-gemma4-yoco-fast-prefill

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 24, 2026

Summary

Port of vllm-project/vllm#22628 (KV-sharing fast prefill infrastructure) and vllm-project/vllm#38879 (Gemma4 enablement) to SGLang.

Gemma4 E2B (35 layers / 20 KV-shared) and E4B (42 / 18) place the last N layers in a "cross-decoder" regime that reuses KV state from earlier layers (see Gemma4Attention.is_kv_shared_layer / kv_shared_layer_index). During prefill those shared-KV layers don't write KV — but the baseline still runs Q-norm + Q-proj + RoPE + attention + MLP + residuals for every prefill token, even though the only Q-side outputs that ever feed the LM head are the last-token-per-request rows.

This patch truncates hidden_states / positions / per_layer_inputs to just those rows before entering the first KV-shared layer (== YOCO fast-prefill), then scatters back into the full-shape tensor after the last layer so the downstream logits processor's index at cumsum(extend_seq_lens) - 1 produces the same output.

Code structure

Two new helpers on Gemma4TextModel:

helper role
_yoco_eligibility per-forward gate; checks num_kv_shared_layers > 0, EXTEND mode, batch has at least one prompt > 1 token, no prompt-logprobs requested, no aux-hidden-state captures in the truncated range, single-stage PP, and the SGLANG_GEMMA4_YOCO=0 kill switch
_yoco_truncate_to_last_tokens snapshots forward_batch.extend_*, rewrites them to 1-token-per-request, calls attn_backend.init_forward_metadata to rebuild qo_indptr / kv_indices, returns a restore_fn that puts everything back and rebuilds again

The forward loop applies the truncation exactly once, when layer_idx == first_kv_shared_layer_idx, and scatters the truncated output back to the full tensor after the last layer.

The vision tower and the first first_kv_shared_layer_idx layers run on the full prefill batch — they're the ones that write KV — so attention in the cross-decoder still reads the full cached sequence; only the Q-side compute volume shrinks by extend_total / num_reqs.

Test plan

Unit tests

test/srt/models/test_gemma4_yoco_fast_prefill.py — 9 CPU-only tests, monkey-patched ForwardBatch / attn backend:

  • test_eligibility_default_on
  • test_eligibility_no_kv_shared_layers
  • test_eligibility_pure_decode_batch
  • test_eligibility_decode_forward_mode
  • test_eligibility_prompt_logprobs_disable
  • test_eligibility_layer_capture_inside_kv_shared_range
  • test_eligibility_layer_capture_outside_kv_shared_range_ok
  • test_eligibility_env_kill_switch
  • test_truncate_to_last_tokens_indices_and_restore

All 9 pass.

End-to-end benchmarks

Hardware: 1× B200 (sm_100a), bf16, TP=1, SGLang attention backend triton, --disable-radix-cache (isolates the cross-decoder prefill path).

Load gen: vllm bench serve --dataset-name random from vllm/vllm-openai:nightly against SGLang's OpenAI endpoint:

--dataset-name random --num-prompts 30 \
--random-input-len 7000 --random-output-len 10

Two SGLang configurations: baseline (SGLANG_GEMMA4_YOCO=0) vs patched (default).

google/gemma-4-E2B-it — 35 layers / 20 KV-shared (57%)

config duration (s, geomean 3 runs) median TTFT (ms) total tok/s
baseline 3.45 1792 61,020
patched 2.28 1205 92,414
ratio 1.51× faster 1.49× lower 1.51× higher

google/gemma-4-E4B-it — 42 layers / 18 KV-shared (43%)

config duration (s, geomean 3 runs) median TTFT (ms) total tok/s
baseline 4.22 2183 49,905
patched 3.24 1733 64,949
ratio 1.30× faster 1.26× lower 1.30× higher

Note on the multimodal benchmark: the random-mm 6× 480² image workload from PR #9 (1670-token prompts) showed the YOCO gain was within noise on E2B/B200, because at that prompt length the prefill is launch/memory-bound, not Q-compute bound. YOCO becomes measurable at longer prompts (~7000+ tokens) where the cross-decoder attention and MLP work dominate.

Quality

30-prompt color-naming benchmark (quality/run_quality.py from the prior PR #9 artifact root), temperature=0:

framework model accuracy exact-match to baseline
SGLang baseline gemma-4-E2B-it 26/30 (86.7%) (reference)
SGLang YOCO gemma-4-E2B-it 27/30 (90.0%) 24/30 char-match
SGLang baseline gemma-4-E4B-it 29/30 (96.7%) (reference)
SGLang YOCO gemma-4-E4B-it 29/30 (96.7%) 30/30 char-for-char match

On E4B the patch is provably byte-identical to baseline on this prompt distribution. On E2B (which uses MQA kv_heads=1, more sensitive to reduction order) 6 of 30 responses differ: 5 are whitespace-only (e.g. "green, orange" vs "green,orange") and 1 is a last-token swap ("white, white" vs "white, black") — these are within sampling-noise of the underlying attention kernel's reduction order changing on a truncated Q. Same caveat vLLM has on --kv-sharing-fast-prefill. Accuracy on the task is unchanged or marginally better.

Limitations

  • 26B-A4B-IT and 31B-it ship with num_kv_shared_layers = 0, so the patch is a no-op on those checkpoints.
  • Multi-stage PP not supported: the cross-decoder split happens at a fixed layer index and we'd need to coordinate the truncation across stages.
  • CUDA-graph mode untouched: CUDA graphs only capture DECODE; prefill runs eagerly, so the runtime init_forward_metadata rebuild is safe.
  • The first-token of free-form generation may be one token off from baseline on E2B due to attention-kernel reduction-order non-determinism on truncated Q (E4B did not show this).

Refs


CI States

Latest PR Test (Base): ❌ Run #26354809548
Latest PR Test (Extra): ❌ Run #26354809487

pyc96 and others added 13 commits May 22, 2026 00:26
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
…l split

Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`).  SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.

On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:

  full_layer_tokens = 593_956   <-- fits  ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows

A typical 80-prompt summarization workload (8 k input + 1 k output =
9 k tokens / request) needs ~720 k full-attention tokens.  Because the
full pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:

  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.

Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:

  full_layer_tokens = 2_138_243  <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  =   320_736  <-- still ~313 1024-token windows

Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):

  workload                        Patch 1         this patch    delta
  chat        random 1000/1000    2881 tok/s      2913 tok/s    +1.1 %
  summariz.   random 8000/1000
              median TTFT (ms)    10459          8763          **-16.2 %**
              output tok/s        1108           1097          -1.0 %
              median TPOT (ms)    44.6           37.9          -15.0 %

Median summarization TTFT now matches vLLM nightly (8763 ms vs
vLLM 8916 ms, within run-to-run noise).

MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710
-- within MMLU sampling noise; no regression.

User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).

Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke
test skipped on environments that do not have full ModelConfig stubs.

Co-authored-by: Claude
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens.
Turns the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.

See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.

Co-authored-by: Claude
…rap)

Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.

Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check.  Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.

Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:

  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*

That is, the SWA allocator is NOT the source of the OOB.  Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table.  Specifically:

  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]

For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token).  full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).

This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
…A crash

Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).

Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.

Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.

Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events

Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)

MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.

KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization).  Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions.  The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E).  For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.

Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).

See crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has two parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         4032              4022            ~same
  accept length        1.61              **2.13**        +32%
  total throughput     31.5 k tok/s      36.2 k tok/s    +15%
  median TPOT (ms)     12.16             9.99            -18%

  26B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         1832              2503            +37%
  accept length        1.67              **2.84**        +70%
  total throughput     16.5 k tok/s      22.5 k tok/s    +37%
  median TPOT (ms)     24.97             20.35           -18%
  median TTFT (ms)     2887              3468            +20%
  benchmark duration   ~60 s             32 s            -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude
This reverts commit 5547e41.

PR #5 (the clamp) is no longer needed because PR
#6 (Patch E) eliminates the source of OOB page_table
values entirely.  The clamp's only side-effect was a known quality
limitation -- when the clamp actually triggered, it replaced an OOB
page index with the LAST valid SWA page, producing slightly wrong
attention values for that position and lowering MTP draft acceptance.
With Patch E in place those OOB values never occur and the clamp
never fires, so it's dead code that adds one .clamp() per kernel call
for no benefit.

Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP +
summarization 8 k/1 k x 80 on 1x B200):

  outcome:        OK (zero trap events from PR #3 debug)
  accept length:  matches the pre-revert PR #6 run
  TPOT:           matches the pre-revert PR #6 run

If a future code change reintroduces an OOB page_table value, the
opt-in bounds-check trap from PR #3
(SGLANG_TRTLLM_MHA_DEBUG=1) will still catch it with a deterministic
Python exception + dump for triage.

Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every
Gemma-4 model.  That value was tuned for `Gemma-4-26B-A4B-IT`
(MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of
GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer
ratio means the shipped default 0.8 over-provisions the SWA pool.

For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV,
so 0.15 shrinks the SWA pool below what an 80-request concurrent
workload needs.  Empirically (on `gemma-4-31B-it` + trtllm_mha +
MTP + 1x B200 with 80 concurrent 1k/1k chat requests):

  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates
              at 100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.

This commit gates the 0.15 override on `num_experts > 0`, read
from the model's `hf_text_config`.  Mirrors the MoE-detection
pattern in `gemma4_causal.py:1166`.

Per-model verification on 1x B200:

  26B-A4B-IT (MoE, num_experts=128):
    log: 'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
    (unchanged from Patch 2 -- regression-safe)

  31B-it (dense, num_experts=0):
    log: 'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
    (instead of the broken 478720 / 71808 layout from Patch 2)

  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.

Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):

  metric            | SGLang (this PR) | vLLM nightly | Delta
  ------------------|------------------|--------------|----
  outcome           | OK               | OK           | same
  median TTFT       | 673 ms           | 901 ms       | SGLang +25%
  median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
  total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
  accept length     | 3.13             | n/a          | --

Same workload at conc=32 summarization (8k/1k x 40):
  median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
  total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%

MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).

Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).

Co-authored-by: Claude
…879)

Gemma4 E2B (35 layers / 20 KV-shared) and E4B (42 / 18) place the last
N layers in a 'cross-decoder' regime that reuses KV state from earlier
layers (see Gemma4Attention.is_kv_shared_layer / kv_shared_layer_index).
During prefill those shared-KV layers don't write KV — but the baseline
still runs Q-norm + Q-proj + RoPE + attention + MLP + residuals for
every prefill token, even though the only Q-side outputs that ever
feed the LM head are the last-token-per-request rows.

Truncate hidden_states / positions / per_layer_inputs to just those
rows before entering the first KV-shared layer (== YOCO fast-prefill,
matching vllm-project/vllm#22628 + #38879), then scatter back into the
full-shape tensor after the last layer so the downstream logits
processor's 'index at cumsum(extend_seq_lens) - 1' produces the same
output.

Eligibility & guards:
  * num_kv_shared_layers > 0 (E2B / E4B only; no-op on 26B-A4B-IT
    and 31B where the config doesn't opt in)
  * non-speculative EXTEND batch with at least one request having
    > 1 new token
  * not collecting per-prompt logprobs
  * not capturing aux hidden states inside the shared-KV layer range
  * single-stage PP only
  * SGLANG_GEMMA4_YOCO=0 env kill switch for A/B testing

Implementation: between layer (K-1) and K, snapshot the affected
forward_batch.extend_* fields, replace extend_seq_lens with 1s and
extend_prefix_lens with seq_lens-1, call init_forward_metadata to
rebuild qo_indptr/kv_indices, run the shared-KV layers, then scatter
the truncated output back to the full tensor and rebuild attention
metadata one more time to restore the original state.

Test: test/srt/models/test_gemma4_yoco_fast_prefill.py (9 CPU-only
unit tests).

Benchmark (1x B200, vllm bench serve random text, 30 prompts, 7000
input / 10 output, --disable-radix-cache; isolates cross-decoder
prefill):

  gemma-4-E2B-it  (35 layers / 20 KV-shared):
    baseline   dur 3.45s | TTFT 1792ms | tok/s 61020
    patched    dur 2.28s | TTFT 1205ms | tok/s 92414
               -> 1.51x duration, 1.49x TTFT, 1.51x throughput

  gemma-4-E4B-it  (42 layers / 18 KV-shared):
    baseline   dur 4.22s | TTFT 2183ms | tok/s 49905
    patched    dur 3.24s | TTFT 1733ms | tok/s 64949
               -> 1.30x duration, 1.26x TTFT, 1.30x throughput

Quality (30-prompt color-naming MM test, temperature=0):
  E2B: baseline 26/30 == patched 27/30 (24/30 char-match; 6 diffs
       are whitespace or last-token noise from attention reductions
       on truncated Q being non-deterministic — same caveat vLLM has
       on --kv-sharing-fast-prefill).
  E4B: baseline 29/30 == patched 29/30 (30/30 char-for-char match).

Refs: vllm-project/vllm#22628, vllm-project/vllm#38879
(Apache-2.0).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant