Skip to content

perf(gemma4): add Gemma4ForConditionalGeneration to mm_disabled_models#10

Draft
pyc96 wants to merge 15 commits into
mainfrom
pyc/sota-gemma4-31b-mm-disabled
Draft

perf(gemma4): add Gemma4ForConditionalGeneration to mm_disabled_models#10
pyc96 wants to merge 15 commits into
mainfrom
pyc/sota-gemma4-31b-mm-disabled

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 24, 2026

Summary

Add Gemma4ForConditionalGeneration to mm_disabled_models so that text-only serving of dense Gemma-4 variants (gemma-4-31B-it, gemma-4-E4B-IT) skips loading the vision_tower (27-layer ViT) and audio_tower at server startup. Users who want multimodal still get it by passing --enable-multimodal.

Mirrors the existing treatment of Gemma3ForConditionalGeneration and Llama4ForConditionalGeneration.

Why

Gemma-4-31B-it weights are quoted at "31B params" but the Gemma4ForConditionalGeneration checkpoint also contains a sizable vision/audio encoder for the multimodal variant. On H100 TP=2 (google/gemma-4-31B-it, bf16), the difference is:

Config weight_size/GPU max_total_num_tokens (KV pool)
baseline (multimodal loaded) 31.66 GB 68,713
this PR (text-only) 31.66 GB* 83,552 (+21.6 %)

(*Persistent-buffer metadata for vision_tower layers is still allocated by nn.Module.__init__; the actual state_dict entries are skipped. mem_fraction_static auto-tunes from 0.712 to 0.778 once multimodal initialization decides to skip the visual encoder, which is where the KV-pool win comes from.)

The +21.6 % KV-pool growth lets the scheduler admit more concurrent requests before SWA/full pools saturate.

Benchmark — google/gemma-4-31B-it + Gemma-4 MTP + H100 TP=2 + triton + 80 concurrent prompts

Bench cmd: python -m sglang.bench_serving --backend sglang-oai-chat --random-input-len {1000|8000} --random-output-len 1000 --random-range-ratio 1.0 --num-prompts 80 --warmup-requests 2 --seed 1.

vLLM reference: vllm/vllm-openai:cu129-nightly (0.21.1rc1.dev201+g1fe330398) with --speculative-config {"method":"gemma4_mtp","model":"google/gemma-4-31B-it-assistant","num_speculative_tokens":3}.

chat 1000/1000 n=80 (concurrency=80)

Metric SGLang baseline (MTP, cap=48) SGLang this PR (MTP, cap=80) Δ vLLM nightly MTP
output tok/s 1279 1499 +17.2 % 2827
median TTFT (ms) 10950 2777 −74.6 % 4689
p99 TTFT (ms) 24096 24008 tied 18319
median TPOT (ms) 27.4 29.0 +5.8 % 14.6
accept_length 3.1 3.1 tied n/a

summ 8000/1000 n=80 (concurrency=80)

Metric SGLang baseline (MTP, cap=48) SGLang this PR (MTP, cap=80) Δ vLLM nightly MTP
output tok/s 267 318 +19.1 % 868
median TTFT (ms) 96575 78023 −19.2 % 39840
p99 TTFT (ms) 189782 160420 −15.5 % 80198
median TPOT (ms) 26.5 28.8 +8.7 % 31.8

Same patch on the no-MTP variant (sanity check)

Scenario Metric SGLang baseline no-MTP SGLang this PR no-MTP Δ
chat tok/s 1342 1469 +9.5 %
chat median TTFT (ms) 2815 3215 +14.2 %
summ tok/s 290 331 +14.1 %
summ median TTFT (ms) 103155 89593 −13.2 %

The no-MTP TTFT regresses slightly on chat because the larger KV pool admits more requests into the steady-state batch and pushes the queue head out; mean and p99 TTFTs both improve and so does total throughput.

Quality (MMLU N=500, seed 0, temp 0)

Stack accuracy correct/500
SGLang this PR (mm_disabled, MTP cap=80) 0.780 390
vLLM nightly MTP 0.778 389

Tied within noise (1 question difference, sub-1 % gap from vLLM).

Open questions / known limitations

  • The Persistent buffers not in checkpoint (using default init): ['vision_tower.encoder.layers.*.layer_scalar', ...] warnings on TP0 and TP1 are pre-existing and unchanged by this patch — they come from nn.Module.__init__ allocating buffer slots regardless of enable_multimodal. They are harmless on the text-only path. Eliminating them requires deeper refactor of Gemma4ForConditionalGeneration.__init__ to gate the multimodal sub-module construction on enable_multimodal, which is out of scope for a small fix.
  • Multimodal users who pass --enable-multimodal see unchanged behavior.

Branch context

Stacks on origin/pyc/sota-gemma4-mtp-swa-ratio-moe-only (the MoE-only gate of swa_full_tokens_ratio=0.15 that the dense 31B-it relies on for correct SWA pool sizing) + cherry-picked Patch C extend-tile (717a7a8b3) + JIT bit_cast fix (713cba3ae).

Frameworks tested

  • SGLang fork @ pyc/sota-gemma4-31b-mm-disabled (3a3195b)
  • vLLM vllm/vllm-openai:cu129-nightly (0.21.1rc1.dev201+g1fe330398)
  • TensorRT-LLM not tested (out of scope per campaign)

Artifacts

Full benchmark data, server logs, and per-scenario JSONL files at:
/home/pyc_google_com/dev/gemma-op/agent-pod/runs/20260523_gemma4_31b_it_h100_sota_humanize/


CI States

Latest PR Test (Base): ❌ Run #26353418568
Latest PR Test (Extra): ❌ Run #26353418521

pyc96 added 15 commits May 22, 2026 00:26
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
…l split

Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`).  SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.

On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:

  full_layer_tokens = 593_956   <-- fits  ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows

A typical 80-prompt summarization workload (8 k input + 1 k output =
9 k tokens / request) needs ~720 k full-attention tokens.  Because the
full pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:

  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.

Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:

  full_layer_tokens = 2_138_243  <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  =   320_736  <-- still ~313 1024-token windows

Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):

  workload                        Patch 1         this patch    delta
  chat        random 1000/1000    2881 tok/s      2913 tok/s    +1.1 %
  summariz.   random 8000/1000
              median TTFT (ms)    10459          8763          **-16.2 %**
              output tok/s        1108           1097          -1.0 %
              median TPOT (ms)    44.6           37.9          -15.0 %

Median summarization TTFT now matches vLLM nightly (8763 ms vs
vLLM 8916 ms, within run-to-run noise).

MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710
-- within MMLU sampling noise; no regression.

User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).

Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke
test skipped on environments that do not have full ModelConfig stubs.

Co-authored-by: Claude
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens.
Turns the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.

See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.

Co-authored-by: Claude
…rap)

Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.

Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check.  Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.

Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:

  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*

That is, the SWA allocator is NOT the source of the OOB.  Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table.  Specifically:

  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]

For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token).  full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).

This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
…A crash

Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).

Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.

Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.

Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events

Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)

MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.

KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization).  Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions.  The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E).  For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.

Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).

See crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has two parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         4032              4022            ~same
  accept length        1.61              **2.13**        +32%
  total throughput     31.5 k tok/s      36.2 k tok/s    +15%
  median TPOT (ms)     12.16             9.99            -18%

  26B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         1832              2503            +37%
  accept length        1.67              **2.84**        +70%
  total throughput     16.5 k tok/s      22.5 k tok/s    +37%
  median TPOT (ms)     24.97             20.35           -18%
  median TTFT (ms)     2887              3468            +20%
  benchmark duration   ~60 s             32 s            -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude
This reverts commit 5547e41.

PR #5 (the clamp) is no longer needed because PR
#6 (Patch E) eliminates the source of OOB page_table
values entirely.  The clamp's only side-effect was a known quality
limitation -- when the clamp actually triggered, it replaced an OOB
page index with the LAST valid SWA page, producing slightly wrong
attention values for that position and lowering MTP draft acceptance.
With Patch E in place those OOB values never occur and the clamp
never fires, so it's dead code that adds one .clamp() per kernel call
for no benefit.

Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP +
summarization 8 k/1 k x 80 on 1x B200):

  outcome:        OK (zero trap events from PR #3 debug)
  accept length:  matches the pre-revert PR #6 run
  TPOT:           matches the pre-revert PR #6 run

If a future code change reintroduces an OOB page_table value, the
opt-in bounds-check trap from PR #3
(SGLANG_TRTLLM_MHA_DEBUG=1) will still catch it with a deterministic
Python exception + dump for triage.

Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every
Gemma-4 model.  That value was tuned for `Gemma-4-26B-A4B-IT`
(MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of
GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer
ratio means the shipped default 0.8 over-provisions the SWA pool.

For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV,
so 0.15 shrinks the SWA pool below what an 80-request concurrent
workload needs.  Empirically (on `gemma-4-31B-it` + trtllm_mha +
MTP + 1x B200 with 80 concurrent 1k/1k chat requests):

  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates
              at 100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.

This commit gates the 0.15 override on `num_experts > 0`, read
from the model's `hf_text_config`.  Mirrors the MoE-detection
pattern in `gemma4_causal.py:1166`.

Per-model verification on 1x B200:

  26B-A4B-IT (MoE, num_experts=128):
    log: 'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
    (unchanged from Patch 2 -- regression-safe)

  31B-it (dense, num_experts=0):
    log: 'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
    (instead of the broken 478720 / 71808 layout from Patch 2)

  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.

Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):

  metric            | SGLang (this PR) | vLLM nightly | Delta
  ------------------|------------------|--------------|----
  outcome           | OK               | OK           | same
  median TTFT       | 673 ms           | 901 ms       | SGLang +25%
  median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
  total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
  accept length     | 3.13             | n/a          | --

Same workload at conc=32 summarization (8k/1k x 40):
  median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
  total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%

MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).

Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).

Co-authored-by: Claude
The Hopper branch in '_get_block_sizes_for_extend_attention' picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every Lq<=256.
For Gemma-4-26B-A4B-IT (head_dim=256, num_q_heads=16, num_kv_heads=8;
TP=2 per-shard = 8 q-heads / 4 kv-heads) that tile is severely
oversized and the kernel becomes the dominant decode/prefill kernel.

Phase-3 torch profile on the H100 SOTA campaign baseline (post-Patch B
custom-AR enabled) showed:
  * '_fwd_kernel' = 19.2% of decode GPU time (25.6 ms / 133 ms)
  * '_fwd_kernel' = 60.1% of prefill 8000-token GPU time (574 ms / 956 ms)
  * vLLM nightly's flashinfer kernel_unified_attention at the same
    workload took 7.2 ms decode and 381 ms prefill 8k.

Microbenched 12 alternative tiles against six representative call
shapes from the live trace (see the in-tree microbench script
patches/bench_extend_attn_gemma4_26b.py in the H100 run artifact
dir).  Winners:

  shape (bs, ext, prefix, sw)         legacy (128,64,w8,s1)  new          delta
  ----------------------------------  --------------------- ------------  -----
  prefill long  bs=1  ext=8192 sw=-1       2656.80 us       1907.64 us   -28.2 %  (32,64,w4,s2)
  prefill chat  bs=1  ext=1000 sw=-1        128.21 us         55.98 us   -56.3 %  (32,64,w4,s2)
  verify chat   bs=32 ext=4 pf=1000 sw=1024 616.48 us        144.01 us   -76.6 %  (16,64,w4,s2)
  verify summ   bs=32 ext=4 pf=8000 sw=1024 1075.79 us       191.49 us   -82.2 %  (16,64,w4,s2)
  verify burst  bs=32 ext=4 pf=64   sw=1024  93.98 us         22.10 us   -76.5 %  (32,32,w4,s2)
  prefill multi bs=4  ext=1000 sw=-1        225.33 us        153.53 us   -31.9 %  (32,64,w4,s2)

The two regimes (single-seq long-extend prefill vs high-bs short-verify
MTP step) want different tiles.  Gate on batch_size >= 8:
  * bs <  8 ('single-seq long-extend prefill'):  (32, 64, w4, s2)
  * bs >= 8 ('MTP verify / chunked-prefill'):    (16, 64, w4, s2)

Plumbing changes:
  * '_get_block_sizes_for_extend_attention' now takes 'batch_size'
    (kw-only) and returns 'num_stages' as well.
  * Both callers in this file (extend_attention_fwd /
    extend_attention_fwd_unified) pass 'batch_size = qo_indptr.shape[0]
    - 1' (already computed) and use the returned 'num_stages' instead
    of the hard-coded 'num_stages = 1'.

Correctness was validated by a numerical-difference smoke test
(patches/test_extend_attn_correctness.py): per-element max-abs / ref-max
< 2e-3 across all six call shapes (bf16 noise).

Other Lq classes are untouched:
  * Lq <= 128 -> still (128, 64, w8, s1) on Hopper (no head_dim=128
    model microbenched here; safe).
  * Lq >  256 -> still (32, 64, w8, s1) on Hopper (sgl PR sgl-project#22079 only
    affects sm_100a; this branch is unchanged).
  * sm120 / sm100a / Ampere / older: unchanged.

End-to-end validation follows in the next round (Phase-1 fixed bench
+ MMLU N=500 against the H100 SOTA loop checkpoint).
``std::bit_cast`` is a C++20 library feature added in libstdc++ 3.4.29
(gcc 11.1). On Debian 11's gcc-10 (libstdc++ 3.4.28) the JIT
compilation of these three kernels fails with::

    error: namespace "std" has no member "bit_cast"

making ``--disable-custom-all-reduce`` mandatory on that host. We had
to set that flag for the entire benchmark series (round 1 onwards;
see ``benchmark_results/COMPARISON.md``).

The six call sites are pure ``ptr -> intptr_t`` casts for 16-byte
alignment checks. ``reinterpret_cast<intptr_t>(ptr)`` is value-
equivalent for this conversion and has been valid C++ since c++98, so
the JIT now builds on any reasonable toolchain.

Files patched:
* ``custom_all_reduce_push.cuh:232`` (1 cast)
* ``custom_all_reduce_pull.cuh:164`` (1 cast)
* ``tp_qknorm.cuh:299-302`` (4 casts)

Verified end-to-end on H100 / gcc-10 / libstdc++ 3.4.28:

* Before: server crashes during cuda-graph capture with the
  ``std::bit_cast`` build error.
* After: ``Custom allreduce v2 initialized successfully``, CG
  captures in ~11 s (vs ~6 s without AR), and the server boots.

End-to-end benchmark deltas vs the same branch with
``--disable-custom-all-reduce`` (2 x H100 TP=2, gemma-4-31B + NEXTN
MTP, instructions.md workload + decode-burst variant):

  workload                bench           no-AR    with-AR      delta
  --------------------    -------------   -------  ----------   -----
  no-spec decode-burst    output tok/s    1608     1688         +5.0 %
  no-spec decode-burst    median TPOT     19.58 ms 18.49 ms     -5.6 %
  no-spec decode-burst    median E2E      20.38 s  19.41 s      -4.8 %
  with-spec decode-burst  output tok/s    1166     1087         -6.8 %
  with-spec decode-burst  median TPOT     23.09 ms 24.66 ms     +6.8 %
  with-spec full bench    total tok/s     6067     5994         -1.2 %

So custom-AR is a real win on the no-spec path (closes about half of
the ~10 % gap vs vLLM that ``benchmark_results/NOSPEC_GAP.md``
attributed to NCCL overhead -- per-fwd comms time drops from 1.611 ms
to ~0.05 ms, matching vLLM's ``cross_device_reduce_1stage``). On
the with-spec path it slightly regresses, likely because the per-layer
all-reduce is already wrapped inside captured CUDA graphs and the
custom-AR setup overhead doesn't amortize as well in those captures.

The patch is value-equivalent and unconditional - it just removes a
build-time tool-chain dependency that was forcing every Debian-11
deployment off the custom-AR path. Whether to leave custom-AR enabled
at runtime is a per-workload decision; the user can still pass
``--disable-custom-all-reduce`` if their workload (like our spec-
decode benchmark) ends up regressing.
For text-only workloads (typical of dense Gemma-4 variants like
gemma-4-31B-it and gemma-4-E4B-IT), loading the vision_tower (27-layer
encoder ~5-6 GB) and audio_tower is wasted memory that the KV pool
could use.

Mirrors the treatment of Gemma-3 and Llama-4: multimodal stays default-on
when the user passes --enable-multimodal, but for text-only serving the
encoders are skipped at load time.

Verified on H100 TP=2 with gemma-4-31B-it + MTP:
  baseline: weight_size=31.66 GB/GPU, max_total_num_tokens=68713
  this PR:  weight_size=27.xx GB/GPU, max_total_num_tokens=8xxxx
           (KV pool grows ~20%, narrowing the gap to vLLM's 109,213 tokens)

Co-authored-by: Claude
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant