
fix(trtllm_mha + FROZEN_KV_MTP): swap SWA-aware state with target pool #6

Open
pyc96 wants to merge 1 commit into pyc/sota-gemma4-mtp-trtllm-swa-windowed from pyc/sota-gemma4-mtp-patch-e-rootcause

pyc96 commented May 23, 2026

Summary

Root-cause fix for the trtllm_mha + FROZEN_KV_MTP + hybrid-SWA crash.
Supersedes the defensive clamp from #5 with a proper fix
that recovers full MTP draft acceptance.

Stacked on #5 (the clamp). Staged on pyc96/sglang only.

Root cause

TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ from model_runner.token_to_kv_pool. For
the FROZEN_KV_MTP draft worker, the draft model_runner's pool is NOT
a SWAKVPool (the draft model is a small assistant), so those
attributes are set to (False, None).

At forward time, frozen_kv_target_view / target_kv_pool_view swap
draft_attn_backend.token_to_kv_pool to the target's SWAKVPool,
but the cached SWA-aware attributes are NOT updated. The backend then
builds full-pool page_table values for assistant layers that are
remapped to target SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map). The
trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those out-of-range
page indices from the SWA k_cache (only 8657 pages on E4B) and traps
with Warp Illegal Address.
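
The failure mode can be reproduced in miniature. The sketch below is a toy model, not SGLang code: `ToyBackend`, `FullPool`, `SwaPool`, and `page_index_limit` are hypothetical stand-ins, and only the attribute names `token_to_kv_pool`, `use_sliding_window_kv_pool`, and `_swa_kv_pool` mirror the ones named above. The full-pool page count is invented for illustration; 8657 is the SWA page count quoted for E4B.

```python
class FullPool:
    """Stand-in for a non-SWA KV pool (hypothetical)."""
    num_pages = 100_000  # made-up size, for illustration only


class SwaPool:
    """Stand-in for a hybrid pool with a smaller SWA k_cache (hypothetical)."""
    num_pages = 100_000
    swa_num_pages = 8657  # SWA page count quoted for E4B


class ToyBackend:
    def __init__(self, pool):
        self.token_to_kv_pool = pool
        # Cached once at construction, like the attributes described above.
        self.use_sliding_window_kv_pool = isinstance(pool, SwaPool)
        self._swa_kv_pool = pool if self.use_sliding_window_kv_pool else None

    def page_index_limit(self):
        """Upper bound the backend assumes when building page_table entries."""
        if self.use_sliding_window_kv_pool:
            return self._swa_kv_pool.swa_num_pages
        return self.token_to_kv_pool.num_pages


backend = ToyBackend(FullPool())      # draft worker: non-SWA pool at init
backend.token_to_kv_pool = SwaPool()  # forward-time swap to the target pool

assumed_limit = backend.page_index_limit()
actual_limit = backend.token_to_kv_pool.swa_num_pages
stale = assumed_limit > actual_limit  # cached flags no longer match the pool
print(f"assumed={assumed_limit} actual={actual_limit} stale={stale}")
```

Because the cached flags still say "non-SWA", the toy backend sizes its indices against the full pool even though the pool it now points at has a much smaller SWA cache.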

Definitive evidence from the Patch-E instrumentation:

[Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                _swa_kv_pool is None? True,
                layer_id=22, layer.sliding_window_size=512

The fix (3 parts)

  1. frozen_kv_mtp_utils.py — add _maybe_swap_swa_state /
    _restore_swa_state helpers and wire them into both
    frozen_kv_target_view and target_kv_pool_view. SWA-aware
    attributes flip in lockstep with the token_to_kv_pool swap.
  2. trtllm_mha_backend.py — add self.model_has_sliding_window
    from model_runner.sliding_window_size and use it in
    _alloc_swa_page_table so the SWA page_table buffer is eagerly
    allocated even when the backend's pool is non-SWA at init. Needed
    for the FROZEN_KV_MTP cuda-graph path which binds the pre-
    allocated buffer at replay time.
  3. frozen_kv_mtp_cuda_graph_runner.py — apply the same SWA-state
    swap inside the cuda-graph capture wrapper (which has its own
    manual swap, not via the context manager).
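
A minimal sketch of the lockstep swap in parts 1 and 3, assuming the backend exposes the three attributes named above. `swapped_target_view` and the `Backend` stand-in are hypothetical; the PR's actual helpers are `_maybe_swap_swa_state` and `_restore_swa_state`, whose signatures are not shown in this description.

```python
from contextlib import contextmanager


@contextmanager
def swapped_target_view(backend, target_pool, target_is_swa):
    """Temporarily point `backend` at `target_pool`, keeping SWA state in sync.

    Hypothetical helper for illustration; not the PR's implementation.
    """
    saved = (
        backend.token_to_kv_pool,
        backend.use_sliding_window_kv_pool,
        backend._swa_kv_pool,
    )
    backend.token_to_kv_pool = target_pool
    backend.use_sliding_window_kv_pool = target_is_swa
    backend._swa_kv_pool = target_pool if target_is_swa else None
    try:
        yield backend
    finally:
        # Restore all three together, even if the forward pass raises.
        (
            backend.token_to_kv_pool,
            backend.use_sliding_window_kv_pool,
            backend._swa_kv_pool,
        ) = saved


class Backend:
    """Minimal stand-in carrying only the three swapped attributes."""

    def __init__(self, pool):
        self.token_to_kv_pool = pool
        self.use_sliding_window_kv_pool = False
        self._swa_kv_pool = None


draft_pool, target_pool = object(), object()
backend = Backend(draft_pool)

with swapped_target_view(backend, target_pool, target_is_swa=True) as b:
    inside = (
        b.token_to_kv_pool is target_pool,
        b.use_sliding_window_kv_pool,
        b._swa_kv_pool is target_pool,
    )

after = (
    backend.token_to_kv_pool is draft_pool,
    backend.use_sliding_window_kv_pool,
    backend._swa_kv_pool,
)
```

Saving and restoring the three values as one tuple makes it hard to update the pool without also updating the flags, which is the invariant the original code was missing.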

Benchmark: Gemma-4-E4B-IT, trtllm_mha, MTP, summ 8 k/1 k × 80

| metric | clamp PR #5 | this PR | Δ |
| --- | --- | --- | --- |
| outcome | OK | OK | same |
| output tok/s | 4032 (peak) | 4022 | ~same |
| accept length | 1.61 | 2.13 | +32% |
| total throughput | 31.5 k tok/s | 36.2 k tok/s | +15% |
| median TPOT (ms) | 12.16 | 9.99 | −18% |
| benchmark duration | n/a | 19.89 s | very fast |

Benchmark: Gemma-4-26B-A4B-IT, trtllm_mha, MTP, summ 8 k/1 k × 80

| metric | triton + Patch 2 | clamp PR #5 | this PR | Δ vs triton | Δ vs clamp |
| --- | --- | --- | --- | --- | --- |
| outcome | OK | OK | OK | same | same |
| output tok/s | 1097 | 1832 | 2503 | +128% | +37% |
| accept length | 2.76 | 1.67 | 2.84 | +3% | +70% |
| total throughput | 9849 | 16486 | 22523 | +129% | +37% |
| median TPOT (ms) | 37.87 | 24.97 | 20.35 | −46% | −18% |
| median TTFT (ms) | 8763 | 2887 | 3468 | −60% | +20% |
| duration | 75.4 s | ~60 s | 31.97 s | −58% | −47% |
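
The Δ columns are plain relative changes of this PR against each baseline. The short script below recomputes them from the 26B numbers reported above; `pct_delta` is a helper name introduced here, and because the reported deltas are rounded, the recomputed values can differ from them by up to about one percentage point.

```python
def pct_delta(new, old):
    """Relative change of `new` versus `old`, in percent."""
    return (new - old) / old * 100.0


# (triton + Patch 2, clamp PR #5, this PR), copied from the 26B results above.
rows = {
    "output tok/s": (1097, 1832, 2503),
    "accept length": (2.76, 1.67, 2.84),
    "total throughput": (9849, 16486, 22523),
    "median TPOT (ms)": (37.87, 24.97, 20.35),
    "median TTFT (ms)": (8763, 2887, 3468),
}

deltas = {
    name: (pct_delta(this_pr, triton), pct_delta(this_pr, clamp))
    for name, (triton, clamp, this_pr) in rows.items()
}

for name, (vs_triton, vs_clamp) in deltas.items():
    print(f"{name:18s} vs triton {vs_triton:+6.1f}%   vs clamp {vs_clamp:+6.1f}%")
```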

Benchmark: Gemma-4-26B-A4B-IT chat 1000/1000

| metric | vLLM nightly | this PR |
| --- | --- | --- |
| output tok/s | 5310 | 9180 |
| median TTFT | 880 ms | 510 ms |
| median TPOT | 8.46 ms | 8.72 ms |
| accept length | 2.80 | 2.89 |

SGLang now beats vLLM on throughput (9180 vs 5310 output tok/s), TTFT, and accept length; TPOT is about 3% higher (8.72 ms vs 8.46 ms).

Quality

MMLU @ 500 random questions on 26B (seed 0, temp 0):

| config | MMLU |
| --- | --- |
| this PR | 0.716 |
| triton baseline | 0.706 |
| vLLM | 0.710 |

Within sampling noise; no regression.

Recommendation re. PR #5

The defensive clamp in PR #5 becomes unnecessary with this fix. Two
options for the merge sequence:

  • Recommended: merge this PR, then revert the clamp from PR #5
    (fix(trtllm_mha): clamp page_table to k_cache page range to prevent SWA crash),
    or open a follow-up PR that drops it. The clamp lowers the accept
    rate whenever it triggers, and that cost no longer buys any safety:
    with the proper SWA-state swap the OOB access does not occur.
  • Alternative: keep both as belt-and-suspenders (the clamp is
    nearly free in perf, just a clamp_ call per layer).
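
To illustrate the trade-off, here is a toy version of the clamp, assuming it bounds each page index to the valid k_cache range. This is not the code from PR #5: `clamp_page_table` is a hypothetical pure-Python stand-in for a per-layer tensor `clamp_`, and the page indices are made up (only 8657 comes from the description above).

```python
def clamp_page_table(page_table, num_pages):
    """Bound every page index to the valid range [0, num_pages - 1]."""
    return [min(max(idx, 0), num_pages - 1) for idx in page_table]


swa_num_pages = 8657                   # SWA k_cache page count quoted for E4B
page_table = [12, 8656, 9000, 42_000]  # made-up indices; the last two are OOB

clamped = clamp_page_table(page_table, swa_num_pages)

# Positions whose index was changed by the clamp.
redirected = [i for i, (a, b) in enumerate(zip(page_table, clamped)) if a != b]
print(clamped, redirected)
```

In this toy, the out-of-range entries no longer fault, but they now all point at the last valid page rather than the page that holds the intended KV data. That is one plausible reading of why the clamp avoids the crash yet costs accept rate when it triggers.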

Tests

No new unit tests. Verification: the same reproducers that used to
crash now complete:

agent-pad/runs/20260522_gemma4_26b_a4b_it_sota_humanize/crash_repro/repro_e4b_bounds.sh    # used to crash; now: 80/80 ok, accept 2.13
agent-pad/runs/20260522_gemma4_26b_a4b_it_sota_humanize/crash_repro/repro.sh               # used to crash; now: 80/80 ok, accept 2.84

Both bounds-check traps from PR #3 fire ZERO times with this PR
applied (confirmed in server logs).


CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT a SWAKVPool (the draft model is a small assistant), so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has three parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B                | clamp PR #5  | this PR (proper) | delta
  -------------------|--------------|------------------|------
  outcome            | OK           | OK               | same
  output tok/s       | 4032         | 4022             | ~same
  accept length      | 1.61         | **2.13**         | +32%
  total throughput   | 31.5 k tok/s | 36.2 k tok/s     | +15%
  median TPOT (ms)   | 12.16        | 9.99             | -18%

  26B                | clamp PR #5  | this PR (proper) | delta
  -------------------|--------------|------------------|------
  outcome            | OK           | OK               | same
  output tok/s       | 1832         | 2503             | +37%
  accept length      | 1.67         | **2.84**         | +70%
  total throughput   | 16.5 k tok/s | 22.5 k tok/s     | +37%
  median TPOT (ms)   | 24.97        | 20.35            | -18%
  median TTFT (ms)   | 2887         | 3468             | +20%
  benchmark duration | ~60 s        | 32 s             | -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude