perf(gemma4 MTP H100): tune Triton extend tile for Lq=256 / sm_90#4
Draft
pyc96 wants to merge 2 commits into
Draft
perf(gemma4 MTP H100): tune Triton extend tile for Lq=256 / sm_90#4pyc96 wants to merge 2 commits into
pyc96 wants to merge 2 commits into
Conversation
``std::bit_cast`` is a C++20 library feature added in libstdc++ 3.4.29
(gcc 11.1). On Debian 11's gcc-10 (libstdc++ 3.4.28) the JIT
compilation of these three kernels fails with::
error: namespace "std" has no member "bit_cast"
making ``--disable-custom-all-reduce`` mandatory on that host. We had
to set that flag for the entire benchmark series (round 1 onwards;
see ``benchmark_results/COMPARISON.md``).
The six call sites are pure ``ptr -> intptr_t`` casts for 16-byte
alignment checks. ``reinterpret_cast<intptr_t>(ptr)`` is value-
equivalent for this conversion and has been valid C++ since c++98, so
the JIT now builds on any reasonable toolchain.
Files patched:
* ``custom_all_reduce_push.cuh:232`` (1 cast)
* ``custom_all_reduce_pull.cuh:164`` (1 cast)
* ``tp_qknorm.cuh:299-302`` (4 casts)
Verified end-to-end on H100 / gcc-10 / libstdc++ 3.4.28:
* Before: server crashes during cuda-graph capture with the
``std::bit_cast`` build error.
* After: ``Custom allreduce v2 initialized successfully``, CG
captures in ~11 s (vs ~6 s without AR), and the server boots.
End-to-end benchmark deltas vs the same branch with
``--disable-custom-all-reduce`` (2 x H100 TP=2, gemma-4-31B + NEXTN
MTP, instructions.md workload + decode-burst variant):
workload bench no-AR with-AR delta
-------------------- ------------- ------- ---------- -----
no-spec decode-burst output tok/s 1608 1688 +5.0 %
no-spec decode-burst median TPOT 19.58 ms 18.49 ms -5.6 %
no-spec decode-burst median E2E 20.38 s 19.41 s -4.8 %
with-spec decode-burst output tok/s 1166 1087 -6.8 %
with-spec decode-burst median TPOT 23.09 ms 24.66 ms +6.8 %
with-spec full bench total tok/s 6067 5994 -1.2 %
So custom-AR is a real win on the no-spec path (closes about half of
the ~10 % gap vs vLLM that ``benchmark_results/NOSPEC_GAP.md``
attributed to NCCL overhead -- per-fwd comms time drops from 1.611 ms
to ~0.05 ms, matching vLLM's ``cross_device_reduce_1stage``). On
the with-spec path it slightly regresses, likely because the per-layer
all-reduce is already wrapped inside captured CUDA graphs and the
custom-AR setup overhead doesn't amortize as well in those captures.
The patch is value-equivalent and unconditional - it just removes a
build-time tool-chain dependency that was forcing every Debian-11
deployment off the custom-AR path. Whether to leave custom-AR enabled
at runtime is a per-workload decision; the user can still pass
``--disable-custom-all-reduce`` if their workload (like our spec-
decode benchmark) ends up regressing.
The Hopper branch in '_get_block_sizes_for_extend_attention' picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every Lq<=256.
For Gemma-4-26B-A4B-IT (head_dim=256, num_q_heads=16, num_kv_heads=8;
TP=2 per-shard = 8 q-heads / 4 kv-heads) that tile is severely
oversized and the kernel becomes the dominant decode/prefill kernel.
Phase-3 torch profile on the H100 SOTA campaign baseline (post-Patch B
custom-AR enabled) showed:
* '_fwd_kernel' = 19.2% of decode GPU time (25.6 ms / 133 ms)
* '_fwd_kernel' = 60.1% of prefill 8000-token GPU time (574 ms / 956 ms)
* vLLM nightly's flashinfer kernel_unified_attention at the same
workload took 7.2 ms decode and 381 ms prefill 8k.
Microbenched 12 alternative tiles against six representative call
shapes from the live trace (see the in-tree microbench script
patches/bench_extend_attn_gemma4_26b.py in the H100 run artifact
dir). Winners:
shape (bs, ext, prefix, sw) legacy (128,64,w8,s1) new delta
---------------------------------- --------------------- ------------ -----
prefill long bs=1 ext=8192 sw=-1 2656.80 us 1907.64 us -28.2 % (32,64,w4,s2)
prefill chat bs=1 ext=1000 sw=-1 128.21 us 55.98 us -56.3 % (32,64,w4,s2)
verify chat bs=32 ext=4 pf=1000 sw=1024 616.48 us 144.01 us -76.6 % (16,64,w4,s2)
verify summ bs=32 ext=4 pf=8000 sw=1024 1075.79 us 191.49 us -82.2 % (16,64,w4,s2)
verify burst bs=32 ext=4 pf=64 sw=1024 93.98 us 22.10 us -76.5 % (32,32,w4,s2)
prefill multi bs=4 ext=1000 sw=-1 225.33 us 153.53 us -31.9 % (32,64,w4,s2)
The two regimes (single-seq long-extend prefill vs high-bs short-verify
MTP step) want different tiles. Gate on batch_size >= 8:
* bs < 8 ('single-seq long-extend prefill'): (32, 64, w4, s2)
* bs >= 8 ('MTP verify / chunked-prefill'): (16, 64, w4, s2)
Plumbing changes:
* '_get_block_sizes_for_extend_attention' now takes 'batch_size'
(kw-only) and returns 'num_stages' as well.
* Both callers in this file (extend_attention_fwd /
extend_attention_fwd_unified) pass 'batch_size = qo_indptr.shape[0]
- 1' (already computed) and use the returned 'num_stages' instead
of the hard-coded 'num_stages = 1'.
Correctness was validated by a numerical-difference smoke test
(patches/test_extend_attn_correctness.py): per-element max-abs / ref-max
< 2e-3 across all six call shapes (bf16 noise).
Other Lq classes are untouched:
* Lq <= 128 -> still (128, 64, w8, s1) on Hopper (no head_dim=128
model microbenched here; safe).
* Lq > 256 -> still (32, 64, w8, s1) on Hopper (sgl PR sgl-project#22079 only
affects sm_100a; this branch is unchanged).
* sm120 / sm100a / Ampere / older: unchanged.
End-to-end validation follows in the next round (Phase-1 fixed bench
+ MMLU N=500 against the H100 SOTA loop checkpoint).
This was referenced May 25, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
H100 Triton
_fwd_kernelblock-size tuning for Gemma-4 head_dim=256.The Hopper branch in
_get_block_sizes_for_extend_attentionpicked(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1)for everyLq<=256. For Gemma-4-26B-A4B-IT (head_dim=256) on H100 TP=2 thattile is severely oversized and the kernel becomes the dominant
decode/prefill kernel.
This PR adds a Lq=256 / sm_90 branch gated on
(batch_size, max_len_extend)that picks:(32, 64, w4, s2)for the chunked-prefill / long-extendregime (the general default; covers initial prefill, multi-seq
chunked prefill, and small-bs cases)
(16, 64, w4, s2)for the MTP verify regime(
batch_size >= 8 AND max_len_extend <= 16, i.e. the per-specverify step with
num_draft_tokens=4)Validation
Microbench
Two microbench scripts in this PR's H100 SOTA run artifact dir:
patches/bench_extend_attn_gemma4_26b.py— the original 6-shapesweep across 12 candidate tiles
bs >= 8chunked-prefillshapes after the first end-to-end bench revealed the simpler
bs>=8 -> verify-tilepredicate was wrong for chunked prefillCorrectness
Numerical-difference smoke test (
patches/test_extend_attn_correctness.py):relative max-abs / ref-max
< 2e-3(bf16 noise) across all six callshapes, both for the prefill tile and the verify tile.
End-to-end (Gemma-4-26B-A4B-IT + NEXTN-MTP, 2 x H100 TP=2,
--attention-backend triton, MTPnum_steps=3,num_draft_tokens=4,eagle_topk=1, 80 prompts, seed=1, nomax_running_requestscap)(
summmedian TTFT moves from 9.8s -> 19.8s, butsummmean / p99TTFT both improve and overall throughput is +21 %. The median-TTFT
"regression" is a benign distribution shift: faster decode means fewer
low-TTFT outliers from early requests, so the median rises while the
worst-case TTFT improves.)
Profile evidence (decode = 1 -> 50 tokens, 30 steps captured)
_fwd_kernel(Triton extend)all_reduce_one_shot_push_kernel(custom AR)fused_moe_kernelProfile evidence (prefill 8000 -> 1 token, 5 steps captured)
_fwd_kernel(Triton extend)fused_moe_kernelncclDevKernel_AllReduceFiles changed
python/sglang/srt/layers/attention/triton_ops/extend_attention.pybatch_size(kw-only) andmax_len_extend(kw-only)parameters to
_get_block_sizes_for_extend_attention; returnnum_stagesas an extra tuple element.Lq <= 256 / Hopper / sm_90with the gatingdescribed above.
extend_attention_fwdandextend_attention_fwd_unified) to pass the new kwargs and usethe returned
num_stages.Risk and rollback
Lq <= 128(e.g. Llama / Qwen2 / DeepSeek-style head_dim=128models) is untouched -- still uses the legacy
(128, 64, w8, s1).Lq > 256(e.g. Gemma-4-31B with head_dim=512) is untouched --still uses the legacy Hopper
(32, 64, w8, s1).the
elif Lq <= 256branch (or invert the gate). The widerback-compat path (
batch_size=0default) returns the new tile,which the microbench shows is universally better than the legacy
one on Lq=256.
CI States
Latest PR Test (Base): ❌ Missing
run-cilabel -- add it to run CI tests.Latest PR Test (Extra): ❌ Blocked --
run-ciis required first.