perf(gemma4 MM): batch vision encoder and embed_vision calls (port of vllm-project/vllm#43169) #9
Open
pyc96 wants to merge 13 commits into
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:
  torch.topk         -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                        + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax            -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[] -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ... -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32       -> at::native::elementwise_kernel<copy>
torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels. vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.
This commit ports the algorithm to SGLang:
* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
python/sglang/srt/layers/gemma4_fused_ops.py. One Triton program per
token loads all E logits, packs (bijective(logit_bits), expert_id) into
int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
ids) in (fp32, int32). num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
bf16/fp32 inputs and falls back to the torch path otherwise. Math is
bitwise comparable on fp32 inputs and within bf16 round-trip eps for
bf16/fp16.
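For readers who want to check the math by hand, the following is a minimal pure-Python sketch of what one routing program computes for a single token. It assumes the order of operations described above (sort packed keys, keep the K largest, softmax over those K logits in full precision, then scale). The names `sortable_key` and `fused_routing_reference` are illustrative and are not SGLang APIs; the exact bit mapping and tie-breaking order used by the real Triton kernel may differ.

```python
import math
import struct


def sortable_key(x: float) -> int:
    """Map a float32 value to a uint32 whose unsigned order matches float order.

    Assumption: this is one common order-preserving mapping, used here only
    to illustrate the "bijective(logit_bits)" idea.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Negative floats: flip every bit. Non-negative floats: set the sign bit.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def fused_routing_reference(logits, per_expert_scale, k):
    """Reference for one token: returns (weights, ids) for the top-k experts."""
    # Pack (key, expert_id) into one integer so a single sort orders both.
    packed = [(sortable_key(v) << 32) | e for e, v in enumerate(logits)]
    packed.sort(reverse=True)
    topk_ids = [p & 0xFFFFFFFF for p in packed[:k]]

    # Softmax over the K selected logits only.
    top_logits = [logits[e] for e in topk_ids]
    m = max(top_logits)
    exps = [math.exp(v - m) for v in top_logits]
    total = sum(exps)

    # Scale each routing weight by its expert's scale.
    weights = [ex / total * per_expert_scale[e] for ex, e in zip(exps, topk_ids)]
    return weights, topk_ids


weights, ids = fused_routing_reference([0.5, -1.0, 2.0, 1.0], [1.0, 1.0, 2.0, 1.0], k=2)
print(ids, weights)
```

With this packing, ties between equal logits resolve toward the larger expert id, which is a property of the sketch rather than a documented guarantee of the kernel.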
Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):
  workload                      baseline        this patch      delta
  chat random 1000/1000         2729.30 tok/s   2880.94 tok/s   +5.6%
  summariz. random 8000/1000    1060.98 tok/s   1108.42 tok/s   +4.5%
  chat median TPOT (ms)         21.11           20.70           -1.9%
  chat accept length            2.75            2.80            +1.8%
MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.
Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.
Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.
Co-authored-by: Claude
…l split
Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`). SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.
On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:
  full_layer_tokens = 593_956   <-- fits ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows
A typical 80-prompt summarization workload (8 k input + 1 k output = 9 k
tokens / request) needs ~720 k full-attention tokens. Because the full
pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:
  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...
These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.
Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:
  full_layer_tokens = 2_138_243 <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  = 320_736   <-- still ~313 1024-token windows
Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):
  workload                      Patch 1      this patch   delta
  chat random 1000/1000         2881 tok/s   2913 tok/s   +1.1 %
  summariz. random 8000/1000
    median TTFT (ms)            10459        8763         **-16.2 %**
    output tok/s                1108         1097         -1.0 %
    median TPOT (ms)            44.6         37.9         -15.0 %
Median summarization TTFT now matches vLLM nightly (8763 ms vs vLLM
8916 ms, within run-to-run noise).
MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM
0.710 -- within MMLU sampling noise; no regression.
User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).
Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke test
skipped on environments that do not have full ModelConfig stubs.
Co-authored-by: Claude
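The capacity figures in this commit message follow from simple integer arithmetic on the two pool sizes. The sketch below reproduces them from the numbers quoted above; the function name `summarize` and the constant names are made up for illustration, and nothing here models how SGLang actually solves for the pool sizes.

```python
TOKENS_PER_REQUEST = 8_000 + 1_000  # 8 k input + 1 k output
SWA_WINDOW = 1_024                  # window size used in the message's estimates
NUM_REQUESTS = 80

# Pool sizes quoted in the commit message for each ratio setting.
LAYOUTS = {
    0.8: {"full": 593_956, "swa": 475_164},
    0.15: {"full": 2_138_243, "swa": 320_736},
}


def summarize(ratio):
    """Derive the capacity estimates for one pool layout."""
    pool = LAYOUTS[ratio]
    return {
        # swa_full_tokens_ratio is the SWA pool size relative to the full pool.
        "observed_ratio": round(pool["swa"] / pool["full"], 2),
        "full_requests": pool["full"] // TOKENS_PER_REQUEST,
        "swa_windows": pool["swa"] // SWA_WINDOW,
        # Does the full pool hold all 80 requests at once (~720 k tokens)?
        "fits_workload": pool["full"] >= NUM_REQUESTS * TOKENS_PER_REQUEST,
    }


for ratio in LAYOUTS:
    print(ratio, summarize(ratio))
```

Under the default layout the full pool holds 65 of the 80 requests, which is consistent with the partial eviction and re-prefill behaviour described above.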
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens. Turns
the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.
See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.
Co-authored-by: Claude
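As a rough illustration of this kind of opt-in trap, the sketch below checks a page table held in plain Python lists and raises a Python exception carrying a JSON dump. Only the environment variable name `SGLANG_TRTLLM_MHA_DEBUG` comes from this PR; the function name `check_page_table`, its `enabled` parameter, and the dump layout are assumptions, and the real check operates on GPU tensors.

```python
import json
import os


def check_page_table(page_table, cache_seqlens, num_pages, enabled=None):
    """Raise IndexError if any page index falls outside [0, num_pages)."""
    if enabled is None:
        # Opt-in: do nothing unless the debug environment variable is set.
        enabled = os.environ.get("SGLANG_TRTLLM_MHA_DEBUG") == "1"
    if not enabled:
        return None

    bad = [
        (row_id, col, idx)
        for row_id, row in enumerate(page_table)
        for col, idx in enumerate(row)
        if not 0 <= idx < num_pages
    ]
    if bad:
        dump = json.dumps(
            {
                "num_pages": num_pages,
                "bad_entries": bad,
                "page_table": page_table,
                "cache_seqlens": cache_seqlens,
            }
        )
        raise IndexError(f"OOB page index before kernel launch: {dump}")
    return None
```

Failing here, before the kernel launch, gives a stack trace that points at the offending batch instead of an asynchronous illegal-address error reported at a later synchronisation point.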
…rap)
Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.
Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check. Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.
Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:
  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*
That is, the SWA allocator is NOT the source of the OOB. Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table. Specifically:
  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]
For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token), so they translate
through full_to_swa_index_mapping[0]. That entry is the swa slot most
recently bound to full slot 0, which can address any swa page. The
value is in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it.
This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.
Co-authored-by: Claude
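The first fix path named in this commit message (masking trailing page_table entries before translation) can be sketched with plain lists as follows. This only illustrates the hypothesis stated in this message. The function names, the `-1` sentinel, and the dict-based mapping are all assumptions; a real kernel may require a valid dummy page rather than a negative sentinel.

```python
def build_masked_page_table(req_to_token, req_pool_indices, cache_seqlens,
                            max_seq_len_k, pad=-1):
    """Slice each request row, then overwrite positions past its real length."""
    table = []
    for req, seqlen in zip(req_pool_indices, cache_seqlens):
        row = list(req_to_token[req][:max_seq_len_k])
        # Positions >= seqlen were never written and hold stale zeros.
        row[seqlen:] = [pad] * (len(row) - seqlen)
        table.append(row)
    return table


def translate_to_swa(table, full_to_swa, pad=-1):
    """Translate full-pool indices to SWA indices, leaving padding untouched."""
    return [[pad if idx == pad else full_to_swa[idx] for idx in row] for row in table]


req_to_token = [[5, 6, 7, 0], [8, 0, 0, 0]]       # trailing zeros are unwritten
full_to_swa = {0: 42, 5: 1, 6: 2, 7: 3, 8: 4}     # slot 0 maps to some live page

masked = build_masked_page_table(req_to_token, [0, 1], [3, 1], max_seq_len_k=4)
print(translate_to_swa(masked, full_to_swa))
```

Without the masking step, every trailing zero would translate to SWA slot 42 in this toy mapping, which is the behaviour the message describes.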
…A crash
Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).
Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.
Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.
Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events
Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)
MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.
KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization). Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions. The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E). For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.
Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).
See crash_repro/TRIAGE_REPORT.md.
Co-authored-by: Claude
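The clamp and its known limitation can both be seen in a few lines. This is a list-based stand-in for the tensor clamp described above, with a made-up function name and a toy page count; it is not the code from the commit.

```python
def clamp_page_table(page_table, num_pages):
    """Force every page index into the valid range [0, num_pages - 1]."""
    last = num_pages - 1
    return [[min(max(idx, 0), last) for idx in row] for row in page_table]


# Toy example: 8 valid pages, two out-of-range entries (12 and 8).
table = [[0, 3, 12], [1, 7, 8]]
print(clamp_page_table(table, num_pages=8))
```

Both out-of-range entries end up pointing at page 7. The kernel no longer traps, but for those positions it reads the last valid page's K/V instead of the intended one, which matches the drop in accept length reported above.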
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).
The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.
At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.
Definitive evidence captured by the Patch-E investigation:
[Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
_swa_kv_pool is None? True,
layer_id=22, layer.sliding_window_size=512
The fix has three parts:
1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
_restore_swa_state helpers and wire them into both
frozen_kv_target_view and target_kv_pool_view so the
backend's use_sliding_window_kv_pool and _swa_kv_pool
attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
computed from model_runner.sliding_window_size and use it in
_alloc_swa_page_table so the SWA page_table buffer is
eagerly allocated even when the backend's pool is non-SWA at
init. This is required for the FROZEN_KV_MTP cuda-graph capture
path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
the cuda-graph capture wrapper (the manual swap there mirrors the
context-manager pattern).
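The core of the fix is that the cached SWA attributes must change together with the pool they were derived from. The sketch below shows that pattern with stand-in classes. `FakePool`, `FakeBackend`, and `swapped_pool_view` are hypothetical names; only the three attribute names are taken from this commit message, and the real helpers in frozen_kv_mtp_utils.py may be structured differently.

```python
from contextlib import contextmanager


class FakePool:
    """Stand-in for a KV pool; is_swa marks a hybrid sliding-window pool."""

    def __init__(self, is_swa):
        self.is_swa = is_swa


class FakeBackend:
    """Stand-in backend that caches SWA state at construction time."""

    def __init__(self, pool):
        self.token_to_kv_pool = pool
        self.use_sliding_window_kv_pool = pool.is_swa
        self._swa_kv_pool = pool if pool.is_swa else None


@contextmanager
def swapped_pool_view(backend, target_pool):
    """Swap the pool and the cached SWA attributes in lockstep, then restore."""
    saved = (
        backend.token_to_kv_pool,
        backend.use_sliding_window_kv_pool,
        backend._swa_kv_pool,
    )
    backend.token_to_kv_pool = target_pool
    backend.use_sliding_window_kv_pool = target_pool.is_swa
    backend._swa_kv_pool = target_pool if target_pool.is_swa else None
    try:
        yield backend
    finally:
        # Restore all three together, even if the forward pass raised.
        (
            backend.token_to_kv_pool,
            backend.use_sliding_window_kv_pool,
            backend._swa_kv_pool,
        ) = saved
```

Saving and restoring the three values as one tuple makes it hard to reintroduce the original bug, where the pool was swapped but the derived flags were left stale.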
Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):
E4B                | clamp PR #5  | this PR (proper) | delta
-------------------|--------------|------------------|------
outcome            | OK           | OK               | same
output tok/s       | 4032         | 4022             | ~same
accept length      | 1.61         | **2.13**         | +32%
total throughput   | 31.5 k tok/s | 36.2 k tok/s     | +15%
median TPOT (ms)   | 12.16        | 9.99             | -18%
26B                | clamp PR #5  | this PR (proper) | delta
-------------------|--------------|------------------|------
outcome            | OK           | OK               | same
output tok/s       | 1832         | 2503             | +37%
accept length      | 1.67         | **2.84**         | +70%
total throughput   | 16.5 k tok/s | 22.5 k tok/s     | +37%
median TPOT (ms)   | 24.97        | 20.35            | -18%
median TTFT (ms)   | 2887         | 3468             | +20%
benchmark duration | ~60 s        | 32 s             | -47%
26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively. MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.
26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).
This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).
Co-authored-by: Claude
This reverts commit 5547e41.
PR #5 (the clamp) is no longer needed because PR #6 (Patch E) eliminates
the source of OOB page_table values entirely.
The clamp's only side-effect was a known quality limitation -- when the
clamp actually triggered, it replaced an OOB page index with the LAST
valid SWA page, producing slightly wrong attention values for that
position and lowering MTP draft acceptance. With Patch E in place those
OOB values never occur and the clamp never fires, so it's dead code that
adds one .clamp() per kernel call for no benefit.
Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP +
summarization 8 k/1 k x 80 on 1x B200):
  outcome:       OK (zero trap events from PR #3 debug)
  accept length: matches the pre-revert PR #6 run
  TPOT:          matches the pre-revert PR #6 run
If a future code change reintroduces an OOB page_table value, the opt-in
bounds-check trap from PR #3 (SGLANG_TRTLLM_MHA_DEBUG=1) will still
catch it with a deterministic Python exception + dump for triage.
Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every Gemma-4 model.
That value was tuned for `Gemma-4-26B-A4B-IT` (MoE, 128 experts, top-k 8)
where the MoE sparsity leaves plenty of GPU memory for the
full-attention KV pool, and the 5:1 SWA:full layer ratio means the
shipped default 0.8 over-provisions the SWA pool.
For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV, so
0.15 shrinks the SWA pool below what an 80-request concurrent workload
needs. Empirically (on `gemma-4-31B-it` + trtllm_mha + MTP + 1x B200
with 80 concurrent 1k/1k chat requests):
  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates at
              100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.
This commit gates the 0.15 override on `num_experts > 0`, read from the
model's `hf_text_config`. Mirrors the MoE-detection pattern in
`gemma4_causal.py:1166`.
Per-model verification on 1x B200:
  26B-A4B-IT (MoE, num_experts=128):
    log:  'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
          (unchanged from Patch 2 -- regression-safe)
  31B-it (dense, num_experts=0):
    log:  'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
          (instead of the broken 478720 / 71808 layout from Patch 2)
  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.
Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):
metric            | SGLang (this PR) | vLLM nightly | Delta
------------------|------------------|--------------|------------
outcome           | OK               | OK           | same
median TTFT       | 673 ms           | 901 ms       | SGLang +25%
median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
accept length     | 3.13             | n/a          | --
Same workload at conc=32 summarization (8k/1k x 40):
median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%
MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).
Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).
Co-authored-by: Claude
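The gating rule described in this commit reduces to a small decision function. The sketch below is an assumption-laden illustration: `pick_swa_full_tokens_ratio` is a made-up name, and representing "the user passed `--swa-full-tokens-ratio`" as a non-`None` argument is a simplification that may not match how SGLang detects a user override.

```python
from types import SimpleNamespace

DEFAULT_RATIO = 0.8   # shipped default, kept for dense Gemma-4 variants
MOE_RATIO = 0.15      # override applied only to MoE variants


def pick_swa_full_tokens_ratio(hf_text_config, user_ratio=None):
    """Choose the ratio: user value wins, then MoE override, then default."""
    if user_ratio is not None:
        return user_ratio
    # Treat a missing or None num_experts as a dense model.
    num_experts = getattr(hf_text_config, "num_experts", 0) or 0
    return MOE_RATIO if num_experts > 0 else DEFAULT_RATIO


moe_cfg = SimpleNamespace(num_experts=128)   # e.g. 26B-A4B-IT
dense_cfg = SimpleNamespace(num_experts=0)   # e.g. 31B-it, E4B-IT

print(pick_swa_full_tokens_ratio(moe_cfg))          # MoE override applies
print(pick_swa_full_tokens_ratio(dense_cfg))        # default preserved
print(pick_swa_full_tokens_ratio(moe_cfg, 0.5))     # user override preserved
```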
Port of vllm-project/vllm#43169 to SGLang's gemma4_mm.py.
Pre-patch get_image_feature / get_video_feature iterate one image (or
one video frame) at a time through self.vision_tower(...) and again
through self.embed_vision(...) on each pooled output. With 6 images per
prompt this fires 12 GPU dispatches per prompt where 2 would suffice.
Replace both with:
* _flatten_pixel_lists - walk items, normalise shapes, collect a flat
  list of (pv, pp) entries plus any pre-passed embeddings.
* _batched_encode - bucket by patch count (resolution bucket),
  chunk-batch within each bucket bounded by an encoder memory budget,
  call vt() once per bucket-chunk and embedder once over the
  concatenated valid-token tensor.
* _gather_mm_features - driver shared by image and video paths.
Vision tower (Gemma4VisionEncoder.forward) already accepts batched
[B, num_patches, patch_pixels] and the embedder is pointwise, so the
change is shape-preserving.
Test: test/srt/models/test_gemma4_mm_batched_encoder.py
Benchmark (gemma-4-E2B-it, 1x B200, random-mm 6x480 images, 100 prompts,
--disable-radix-cache):
  baseline  duration 15.96s | TTFT 10587ms | tok/s 10132
  patched   duration 10.92s | TTFT  7867ms | tok/s 14817
  -> 1.46x duration, 1.34x TTFT, 1.46x throughput
Quality (30-prompt colored-image labelling, temp=0): baseline 26/30 ==
patched 26/30, all 30 responses match character-for-character.
Refs: vllm-project/vllm#43169 (algorithm template, Apache-2.0).
This was referenced May 24, 2026
Summary
Port of vllm-project/vllm#43169 to SGLang. Removes the per-image (and per-video-frame) Python loop inside `Gemma4ForConditionalGeneration.get_image_feature` / `get_video_feature` that fires one `vision_tower(...)` call and one `embed_vision(...)` call per image.
The SGLang vision tower (`Gemma4VisionEncoder.forward`) already accepts batched `[B, num_patches, patch_pixels]` input, and the embedder is pointwise (RMSNorm + Linear), so the change is shape-preserving (outputs match the pre-patch path).
Code structure
Three helpers replace the two duplicated loops:
* `_flatten_pixel_lists`: walks the items, normalises shapes, and collects a flat list of `(pv, pp)` entries plus any caller-passed-through embeddings.
* `_batched_encode`: buckets by patch count, chunk-batches within each bucket under the encoder memory budget, calls `vt()` once per bucket-chunk, and calls the embedder once over the concatenated valid tokens.
* `_gather_mm_features`: driver shared by `get_image_feature` and `get_video_feature`.
Two new attributes on `Gemma4ForConditionalGeneration`:
* `_encoder_budget_bytes`: lazy-initialised to 5% of total device memory on first encode call.
* `_encoder_bytes_per_patch`: populated at the end of `load_weights` from `vision_config.hidden_size * 2 * num_hidden_layers` (matches the vLLM PR formula).
Test plan
Unit tests
`test/srt/models/test_gemma4_mm_batched_encoder.py` (4 CPU-only tests, monkey-patched fake vision tower):
* `test_single_resolution_single_call`: 6 same-shape images → exactly 1 encoder forward + 1 embedder forward.
* `test_mixed_resolution_bucketing`: small + big images → 2 encoder forwards (one per resolution bucket), still 1 embedder forward.
* `test_chunking_when_max_batch_set`: forced low budget → bucket of 6 splits into 3 encoder forwards, 1 embedder forward.
* `test_empty_returns_empty_tensor`: items=[] → `(0, hidden)` tensor.
All 4 pass.
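The bucketing and chunking behaviour that these tests count encoder forwards for can be sketched without any tensors. The function below only plans which image indices would share one encoder call. `plan_encoder_calls` is a hypothetical name, and the per-bucket batch-size rule (`budget_bytes // (num_patches * bytes_per_patch)`, with a minimum of 1) is an assumption about how the budget is applied, not something stated in this PR.

```python
from collections import defaultdict


def plan_encoder_calls(patch_counts, budget_bytes, bytes_per_patch):
    """Group image indices into encoder calls: one bucket per patch count,
    split into chunks that fit the memory budget.

    Assumes every patch count and bytes_per_patch are positive.
    """
    buckets = defaultdict(list)
    for index, num_patches in enumerate(patch_counts):
        buckets[num_patches].append(index)

    calls = []
    for num_patches, indices in buckets.items():
        # Assumed rule: how many same-shape images fit in the budget at once.
        max_batch = max(1, budget_bytes // (num_patches * bytes_per_patch))
        for start in range(0, len(indices), max_batch):
            calls.append(indices[start:start + max_batch])
    return calls


# Six same-shape images with a generous budget: a single encoder call.
print(plan_encoder_calls([256] * 6, budget_bytes=10**9, bytes_per_patch=1000))
# Budget that fits only two such images at a time: three encoder calls.
print(plan_encoder_calls([256] * 6, budget_bytes=2 * 256 * 1000, bytes_per_patch=1000))
```

In each case the embedder would still run once, over the concatenation of all encoder outputs, which is why the tests expect a single embedder forward regardless of how many encoder chunks there are.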
End-to-end benchmarks
Hardware: 1× B200 (sm_100a), bf16, TP=1, SGLang attention backend `triton`.
Load gen: `vllm bench serve --dataset-name random-mm` from the `vllm/vllm-openai:nightly` docker image hitting SGLang's OpenAI endpoint:
Model 1: `google/gemma-4-E2B-it` (vision tower 768×16)
With radix cache enabled (default): parity within noise on this distribution because `random-mm` cycles RNG-seeded images, so the cache hit rate is > 60% on both frameworks.
Model 2: `google/Gemma-4-26B-A4B-IT` (vision tower 1152×27)
Cache-off (encoder runs every request):
Cache-on (default):
On 26B the cache-on TTFT improvement is real (−16.2%) because the heavier vision tower has uncached fragments that benefit from batching even when the prefix cache hits.
Quality
30-prompt color-naming benchmark (`runs/.../quality/run_quality.py` in the artifact root), temperature=0, seed=0:
The patch produced output identical to baseline on both model sizes in this benchmark.
Run artifacts
`runs/20260523_gemma_mm_batched_encoder/` in the gemma-mtp workspace.
`runs/20260523_gemma_26b_mm_batched_encoder/` in the gemma-mtp workspace.
Refs
`python/sglang/srt/models/gemma4_mm.py`.