Skip to content

perf(gemma4 MTP H100): tune Triton extend tile for Lq=256 / sm_90#4

Draft
pyc96 wants to merge 2 commits into
pyc/sota-gemma4-mtp-swa-ratiofrom
pyc/sota-gemma4-mtp-h100-extend-tile
Draft

perf(gemma4 MTP H100): tune Triton extend tile for Lq=256 / sm_90#4
pyc96 wants to merge 2 commits into
pyc/sota-gemma4-mtp-swa-ratiofrom
pyc/sota-gemma4-mtp-h100-extend-tile

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 23, 2026

Summary

H100 Triton _fwd_kernel block-size tuning for Gemma-4 head_dim=256.

The Hopper branch in _get_block_sizes_for_extend_attention picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every
Lq<=256. For Gemma-4-26B-A4B-IT (head_dim=256) on H100 TP=2 that
tile is severely oversized and the kernel becomes the dominant
decode/prefill kernel.

This PR adds a Lq=256 / sm_90 branch gated on (batch_size, max_len_extend) that picks:

  • (32, 64, w4, s2) for the chunked-prefill / long-extend
    regime
    (the general default; covers initial prefill, multi-seq
    chunked prefill, and small-bs cases)
  • (16, 64, w4, s2) for the MTP verify regime
    (batch_size >= 8 AND max_len_extend <= 16, i.e. the per-spec
    verify step with num_draft_tokens=4)

Validation

Microbench

Two microbench scripts in this PR's H100 SOTA run artifact dir:

  • patches/bench_extend_attn_gemma4_26b.py — the original 6-shape
    sweep across 12 candidate tiles
  • extended bench (in the run log) — added bs >= 8 chunked-prefill
    shapes after the first end-to-end bench revealed the simpler
    bs>=8 -> verify-tile predicate was wrong for chunked prefill
Shape (bs, ext, prefix, sw) Legacy (128,64,w8,s1) New Δ Tile picked
prefill long bs=1 ext=8192 sw=-1 2656.80 us 1907.64 us -28% (32,64,w4,s2) prefill
prefill chat bs=1 ext=1000 sw=-1 128.21 us 55.98 us -56% (32,64,w4,s2) prefill
verify chat bs=32 ext=4 pf=1000 616.48 us 144.01 us -77% (16,64,w4,s2) verify
verify summ bs=32 ext=4 pf=8000 1075.79 us 191.49 us -82% (16,64,w4,s2) verify
verify burst bs=32 ext=4 pf=64 93.98 us 22.10 us -77% (16,64,w4,s2) verify
chunked-prefill bs=8 ext=512 135.98 us 92.37 us -32% (32,64,w4,s2) prefill
chunked-prefill bs=16 ext=1024 751.80 us 559.18 us -26% (32,64,w4,s2) prefill
chunked-prefill bs=32 ext=512 460.41 us 344.36 us -25% (32,64,w4,s2) prefill

Correctness

Numerical-difference smoke test (patches/test_extend_attn_correctness.py):
relative max-abs / ref-max < 2e-3 (bf16 noise) across all six call
shapes, both for the prefill tile and the verify tile.

End-to-end (Gemma-4-26B-A4B-IT + NEXTN-MTP, 2 x H100 TP=2,

--attention-backend triton, MTP num_steps=3,
num_draft_tokens=4, eagle_topk=1, 80 prompts, seed=1, no
max_running_requests cap)

Scenario Metric SGLang baseline (custom-AR off) SGLang Patch B (custom-AR) SGLang Patch C (this PR) vLLM nightly Δ vs Patch B
chat 1000/1000 output tok/s 1572 1726 2452 5022 +42 %
chat 1000/1000 median TPOT (ms) 23.11 21.20 12.11 9.16 -43 %
chat 1000/1000 median TTFT (ms) 1735 865 717 1014 -17 % (better than vLLM!)
chat 1000/1000 accept length 2.83 2.82 2.70 2.78 -4 %
summ 8000/1000 output tok/s 818 893 1081 2843 +21 %
summ 8000/1000 median TPOT (ms) 44.81 42.65 25.61 16.66 -40 %
summ 8000/1000 mean TTFT (ms) 20047 18973 17430 7691 -8 %
summ 8000/1000 p99 TTFT (ms) 49976 45277 37364 14944 -17 %
summ 8000/1000 median E2E (ms) 65645 66612 48131 24394 -28 %
summ 8000/1000 duration (s) 97.82 89.58 74.01 28.14 -17 %
summ 8000/1000 accept length 2.82 2.85 2.70 2.78 -5 %
MMLU N=500 accuracy 0.674 0.666 0.666 0.666 tied vLLM

(summ median TTFT moves from 9.8s -> 19.8s, but summ mean / p99
TTFT both improve and overall throughput is +21 %. The median-TTFT
"regression" is a benign distribution shift: faster decode means fewer
low-TTFT outliers from early requests, so the median rises while the
worst-case TTFT improves.)

Profile evidence (decode = 1 -> 50 tokens, 30 steps captured)

Kernel Patch B Patch C Δ
_fwd_kernel (Triton extend) 25.6 ms (19.2 %) 11.2 ms (9.3 %) -56 %
all_reduce_one_shot_push_kernel (custom AR) 15.5 ms (11.6 %) 17.6 ms (14.7 %) +14 %
fused_moe_kernel 16.6 ms (12.4 %) 16.2 ms (13.5 %) -2 %

Profile evidence (prefill 8000 -> 1 token, 5 steps captured)

Kernel Patch B Patch C Δ
_fwd_kernel (Triton extend) 574.6 ms (60.1 %) 435.3 ms (50.7 %) -24 %
fused_moe_kernel 113.2 ms (11.8 %) 129.2 ms (15.0 %) +14 %
ncclDevKernel_AllReduce 83.5 ms (8.7 %) 84.4 ms (9.8 %) +1 %

Files changed

  • python/sglang/srt/layers/attention/triton_ops/extend_attention.py
    • Add batch_size (kw-only) and max_len_extend (kw-only)
      parameters to _get_block_sizes_for_extend_attention; return
      num_stages as an extra tuple element.
    • Insert new branch Lq <= 256 / Hopper / sm_90 with the gating
      described above.
    • Update both callers (extend_attention_fwd and
      extend_attention_fwd_unified) to pass the new kwargs and use
      the returned num_stages.

Risk and rollback

  • Lq <= 128 (e.g. Llama / Qwen2 / DeepSeek-style head_dim=128
    models) is untouched -- still uses the legacy (128, 64, w8, s1).
  • Lq > 256 (e.g. Gemma-4-31B with head_dim=512) is untouched --
    still uses the legacy Hopper (32, 64, w8, s1).
  • sm_120 / sm_100a / Ampere / older arches: untouched.
  • If the new tile regresses an unseen shape, the rollback is to remove
    the elif Lq <= 256 branch (or invert the gate). The wider
    back-compat path (batch_size=0 default) returns the new tile,
    which the microbench shows is universally better than the legacy
    one on Lq=256.

CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

pyc96 added 2 commits May 22, 2026 23:57
``std::bit_cast`` is a C++20 library feature added in libstdc++ 3.4.29
(gcc 11.1). On Debian 11's gcc-10 (libstdc++ 3.4.28) the JIT
compilation of these three kernels fails with::

    error: namespace "std" has no member "bit_cast"

making ``--disable-custom-all-reduce`` mandatory on that host. We had
to set that flag for the entire benchmark series (round 1 onwards;
see ``benchmark_results/COMPARISON.md``).

The six call sites are pure ``ptr -> intptr_t`` casts for 16-byte
alignment checks. ``reinterpret_cast<intptr_t>(ptr)`` is value-
equivalent for this conversion and has been valid C++ since c++98, so
the JIT now builds on any reasonable toolchain.

Files patched:
* ``custom_all_reduce_push.cuh:232`` (1 cast)
* ``custom_all_reduce_pull.cuh:164`` (1 cast)
* ``tp_qknorm.cuh:299-302`` (4 casts)

Verified end-to-end on H100 / gcc-10 / libstdc++ 3.4.28:

* Before: server crashes during cuda-graph capture with the
  ``std::bit_cast`` build error.
* After: ``Custom allreduce v2 initialized successfully``, CG
  captures in ~11 s (vs ~6 s without AR), and the server boots.

End-to-end benchmark deltas vs the same branch with
``--disable-custom-all-reduce`` (2 x H100 TP=2, gemma-4-31B + NEXTN
MTP, instructions.md workload + decode-burst variant):

  workload                bench           no-AR    with-AR      delta
  --------------------    -------------   -------  ----------   -----
  no-spec decode-burst    output tok/s    1608     1688         +5.0 %
  no-spec decode-burst    median TPOT     19.58 ms 18.49 ms     -5.6 %
  no-spec decode-burst    median E2E      20.38 s  19.41 s      -4.8 %
  with-spec decode-burst  output tok/s    1166     1087         -6.8 %
  with-spec decode-burst  median TPOT     23.09 ms 24.66 ms     +6.8 %
  with-spec full bench    total tok/s     6067     5994         -1.2 %

So custom-AR is a real win on the no-spec path (closes about half of
the ~10 % gap vs vLLM that ``benchmark_results/NOSPEC_GAP.md``
attributed to NCCL overhead -- per-fwd comms time drops from 1.611 ms
to ~0.05 ms, matching vLLM's ``cross_device_reduce_1stage``). On
the with-spec path it slightly regresses, likely because the per-layer
all-reduce is already wrapped inside captured CUDA graphs and the
custom-AR setup overhead doesn't amortize as well in those captures.

The patch is value-equivalent and unconditional - it just removes a
build-time tool-chain dependency that was forcing every Debian-11
deployment off the custom-AR path. Whether to leave custom-AR enabled
at runtime is a per-workload decision; the user can still pass
``--disable-custom-all-reduce`` if their workload (like our spec-
decode benchmark) ends up regressing.
The Hopper branch in '_get_block_sizes_for_extend_attention' picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every Lq<=256.
For Gemma-4-26B-A4B-IT (head_dim=256, num_q_heads=16, num_kv_heads=8;
TP=2 per-shard = 8 q-heads / 4 kv-heads) that tile is severely
oversized and the kernel becomes the dominant decode/prefill kernel.

Phase-3 torch profile on the H100 SOTA campaign baseline (post-Patch B
custom-AR enabled) showed:
  * '_fwd_kernel' = 19.2% of decode GPU time (25.6 ms / 133 ms)
  * '_fwd_kernel' = 60.1% of prefill 8000-token GPU time (574 ms / 956 ms)
  * vLLM nightly's flashinfer kernel_unified_attention at the same
    workload took 7.2 ms decode and 381 ms prefill 8k.

Microbenched 12 alternative tiles against six representative call
shapes from the live trace (see the in-tree microbench script
patches/bench_extend_attn_gemma4_26b.py in the H100 run artifact
dir).  Winners:

  shape (bs, ext, prefix, sw)         legacy (128,64,w8,s1)  new          delta
  ----------------------------------  --------------------- ------------  -----
  prefill long  bs=1  ext=8192 sw=-1       2656.80 us       1907.64 us   -28.2 %  (32,64,w4,s2)
  prefill chat  bs=1  ext=1000 sw=-1        128.21 us         55.98 us   -56.3 %  (32,64,w4,s2)
  verify chat   bs=32 ext=4 pf=1000 sw=1024 616.48 us        144.01 us   -76.6 %  (16,64,w4,s2)
  verify summ   bs=32 ext=4 pf=8000 sw=1024 1075.79 us       191.49 us   -82.2 %  (16,64,w4,s2)
  verify burst  bs=32 ext=4 pf=64   sw=1024  93.98 us         22.10 us   -76.5 %  (32,32,w4,s2)
  prefill multi bs=4  ext=1000 sw=-1        225.33 us        153.53 us   -31.9 %  (32,64,w4,s2)

The two regimes (single-seq long-extend prefill vs high-bs short-verify
MTP step) want different tiles.  Gate on batch_size >= 8:
  * bs <  8 ('single-seq long-extend prefill'):  (32, 64, w4, s2)
  * bs >= 8 ('MTP verify / chunked-prefill'):    (16, 64, w4, s2)

Plumbing changes:
  * '_get_block_sizes_for_extend_attention' now takes 'batch_size'
    (kw-only) and returns 'num_stages' as well.
  * Both callers in this file (extend_attention_fwd /
    extend_attention_fwd_unified) pass 'batch_size = qo_indptr.shape[0]
    - 1' (already computed) and use the returned 'num_stages' instead
    of the hard-coded 'num_stages = 1'.

Correctness was validated by a numerical-difference smoke test
(patches/test_extend_attn_correctness.py): per-element max-abs / ref-max
< 2e-3 across all six call shapes (bf16 noise).

Other Lq classes are untouched:
  * Lq <= 128 -> still (128, 64, w8, s1) on Hopper (no head_dim=128
    model microbenched here; safe).
  * Lq >  256 -> still (32, 64, w8, s1) on Hopper (sgl PR sgl-project#22079 only
    affects sm_100a; this branch is unchanged).
  * sm120 / sm100a / Ampere / older: unchanged.

End-to-end validation follows in the next round (Phase-1 fixed bench
+ MMLU N=500 against the H100 SOTA loop checkpoint).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant