fix(trtllm_mha + FROZEN_KV_MTP): swap SWA-aware state with target pool #6
Open
pyc96 wants to merge 1 commit into
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).
The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant), so those
SWA-aware attributes are set to (False, None) at init.
At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.
Definitive evidence captured by the Patch-E investigation:
[Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
_swa_kv_pool is None? True,
layer_id=22, layer.sliding_window_size=512
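The failure mode can be reproduced in miniature. The sketch below is a toy model, not SGLang code: PlainPool, SWAPool, ToyBackend, and every page count except the 8657-page SWA cache are hypothetical stand-ins. Only the attribute names token_to_kv_pool, use_sliding_window_kv_pool, and _swa_kv_pool are taken from the description above.

```python
from dataclasses import dataclass


@dataclass
class PlainPool:
    """Hypothetical stand-in for the draft worker's non-SWA pool."""
    num_pages: int


@dataclass
class SWAPool:
    """Hypothetical stand-in for a hybrid pool with a smaller SWA cache."""
    full_num_pages: int
    swa_num_pages: int


class ToyBackend:
    """Snapshots SWA-related state once, at construction time."""

    def __init__(self, pool):
        self.token_to_kv_pool = pool
        # Cached at init; a later pool swap does not refresh these.
        self.use_sliding_window_kv_pool = isinstance(pool, SWAPool)
        self._swa_kv_pool = pool if isinstance(pool, SWAPool) else None


backend = ToyBackend(PlainPool(num_pages=1024))  # draft pool, size is made up
target_pool = SWAPool(full_num_pages=40_000, swa_num_pages=8657)

# Forward-time swap that only replaces the pool reference.
backend.token_to_kv_pool = target_pool

# The cached attributes still describe the old, non-SWA pool.
stale_state = (backend.use_sliding_window_kv_pool, backend._swa_kv_pool is None)
print(stale_state)

# A page index that is valid in the full pool (value is made up) can lie
# beyond the end of the smaller SWA cache.
full_pool_page = 30_000
is_oob_for_swa_cache = full_pool_page >= target_pool.swa_num_pages
print(is_oob_for_swa_cache)
```

The stale tuple has the same shape as the Patch-E debug line: the pool reference changed, but the flags that decide which page_table to build did not.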
The fix has three parts:
1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
_restore_swa_state helpers and wire them into both
frozen_kv_target_view and target_kv_pool_view so the
backend's use_sliding_window_kv_pool and _swa_kv_pool
attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
computed from model_runner.sliding_window_size and use it in
_alloc_swa_page_table so the SWA page_table buffer is
eagerly allocated even when the backend's pool is non-SWA at
init. This is required for the FROZEN_KV_MTP cuda-graph capture
path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
the cuda-graph capture wrapper (the manual swap there mirrors the
context-manager pattern).
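The lockstep swap in parts 1 and 3 can be sketched as a context manager. This is a minimal illustration under stated assumptions, not the PR's code: swapped_pool_view and its target_is_swa parameter are hypothetical names, and the backend is a plain SimpleNamespace rather than the real TRTLLMHAAttnBackend.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def swapped_pool_view(backend, target_pool, target_is_swa):
    """Swap the pool and the cached SWA attributes together, then restore."""
    saved = (
        backend.token_to_kv_pool,
        backend.use_sliding_window_kv_pool,
        backend._swa_kv_pool,
    )
    backend.token_to_kv_pool = target_pool
    backend.use_sliding_window_kv_pool = target_is_swa
    backend._swa_kv_pool = target_pool if target_is_swa else None
    try:
        yield backend
    finally:
        # Restore all three fields even if the body raises.
        (
            backend.token_to_kv_pool,
            backend.use_sliding_window_kv_pool,
            backend._swa_kv_pool,
        ) = saved


draft_pool, target_pool = object(), object()  # opaque placeholders
backend = SimpleNamespace(
    token_to_kv_pool=draft_pool,
    use_sliding_window_kv_pool=False,
    _swa_kv_pool=None,
)

with swapped_pool_view(backend, target_pool, target_is_swa=True):
    inside = (
        backend.use_sliding_window_kv_pool,
        backend._swa_kv_pool is target_pool,
    )

after = (
    backend.token_to_kv_pool is draft_pool,
    backend.use_sliding_window_kv_pool,
    backend._swa_kv_pool,
)
print(inside, after)
```

Saving and restoring the three fields as one tuple keeps them from drifting apart, which is the property the original pool-only swap lacked.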
Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):
E4B | clamp PR #5 | this PR (proper) | delta
-----|-------------|------------------|-------
outcome | OK | OK | same
output tok/s | 4032 | 4022 | ~same
accept length | 1.61 | **2.13** | +32%
total throughput | 31.5 k tok/s | 36.2 k tok/s | +15%
median TPOT (ms) | 12.16 | 9.99 | -18%

26B | clamp PR #5 | this PR (proper) | delta
-----|-------------|------------------|-------
outcome | OK | OK | same
output tok/s | 1832 | 2503 | +37%
accept length | 1.67 | **2.84** | +70%
total throughput | 16.5 k tok/s | 22.5 k tok/s | +37%
median TPOT (ms) | 24.97 | 20.35 | -18%
median TTFT (ms) | 2887 | 3468 | +20%
benchmark duration | ~60 s | 32 s | -47%
26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively. MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.
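The three percentages in the triton comparison can be rechecked from the quoted figures. The helper pct_delta is a hypothetical name introduced here for illustration; the inputs are the 26B numbers for this PR and the triton baseline as stated above.

```python
def pct_delta(new: float, old: float) -> int:
    """Relative change of new versus old, rounded to a whole percent."""
    return round((new - old) / old * 100)


# (this PR on 26B, triton baseline), as quoted in the text above
output_tok_s = pct_delta(2503, 1097)
median_tpot = pct_delta(20.35, 37.87)
accept_length = pct_delta(2.84, 2.76)

print(output_tok_s, median_tpot, accept_length)  # 128 -46 3
```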
26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).
This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).
Co-authored-by: Claude
Summary
Root-cause fix for the trtllm_mha + FROZEN_KV_MTP + hybrid-SWA crash.
Supersedes the defensive clamp from #5 with a proper fix
that recovers full MTP draft acceptance.
Stacked on #5 (the clamp). Staged on pyc96/sglang only.
Root cause
TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and _swa_kv_pool at __init__ from model_runner.token_to_kv_pool. For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is NOT a SWAKVPool (the draft model is a small assistant), so those attributes are set to (False, None).
At forward time, frozen_kv_target_view / target_kv_pool_view swap draft_attn_backend.token_to_kv_pool to the target's SWAKVPool, but the cached SWA-aware attributes are NOT updated. The backend then builds full-pool page_table values for assistant layers that are remapped to target SWA layers (via Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all point at target physical layer 22 via the KV-shared owner map). The trtllm_mha sm_100a paged-attention kernel (fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those out-of-range page indices from the SWA k_cache (only 8657 pages on E4B) and traps with Warp Illegal Address.
Definitive evidence from the Patch-E instrumentation:
[Patch-E DEBUG] backend has use_sliding_window_kv_pool=False, _swa_kv_pool is None? True, layer_id=22, layer.sliding_window_size=512
The fix (3 parts)
1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state / _restore_swa_state helpers and wire them into both frozen_kv_target_view and target_kv_pool_view. SWA-aware attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window from model_runner.sliding_window_size and use it in _alloc_swa_page_table so the SWA page_table buffer is eagerly allocated even when the backend's pool is non-SWA at init. Needed for the FROZEN_KV_MTP cuda-graph path, which binds the pre-allocated buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: apply the same SWA-state swap inside the cuda-graph capture wrapper (which has its own manual swap, not via the context manager).
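The eager-allocation condition from part 2 might look roughly like the predicate below. This is a sketch under assumptions: should_alloc_swa_page_table and pool_is_swa are hypothetical names, and treating a missing or non-positive sliding_window_size as "no sliding window" is a guess at how model_has_sliding_window is derived, not something the PR text confirms.

```python
from typing import Optional


def should_alloc_swa_page_table(
    pool_is_swa: bool, sliding_window_size: Optional[int]
) -> bool:
    """Decide at init whether to pre-allocate the SWA page_table buffer."""
    # Assumption: the model counts as sliding-window when the size is a
    # positive integer; the real derivation may differ.
    model_has_sliding_window = (
        sliding_window_size is not None and sliding_window_size > 0
    )
    # Allocate when the init-time pool is SWA, or when the model itself
    # uses a sliding window and an SWA pool may be swapped in later.
    return pool_is_swa or model_has_sliding_window


# Draft worker case: non-SWA pool at init, but SWA layers with window 512.
draft_case = should_alloc_swa_page_table(False, 512)
# A model with no sliding window and a plain pool needs no SWA buffer.
plain_case = should_alloc_swa_page_table(False, None)
print(draft_case, plain_case)
```

Keying the allocation on the model rather than on the init-time pool means the buffer already exists when a cuda graph binds it at replay time.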
Benchmark: Gemma-4-E4B-IT, trtllm_mha, MTP, summ 8 k/1 k × 80
Benchmark: Gemma-4-26B-A4B-IT, trtllm_mha, MTP, summ 8 k/1 k × 80
Benchmark: Gemma-4-26B-A4B-IT chat 1000/1000
SGLang now beats vLLM on TTFT (510 ms vs 880 ms), accept length (2.89 vs 2.80), and throughput, and is within about 3% on TPOT (8.72 ms vs 8.46 ms).
Quality
MMLU @ 500 random questions on 26B (seed 0, temp 0): 0.716 with this PR, versus 0.706 for the triton baseline and 0.710 for vLLM.
Within sampling noise; no regression.
Recommendation re. PR #5
The defensive clamp in PR #5 becomes unnecessary with this fix. Two
options for the merge sequence:
1. Revert PR #5 (or open a follow-up PR that drops it). The clamp's quality cost (lower accept rate when it actually triggers) is no longer worth the safety net it provides, because the OOB no longer occurs with the proper SWA-state swap.
2. Keep PR #5 as a belt-and-suspenders safety net (~free in perf, just a clamp_ per layer).
Tests
No new unit tests. Verification relies on the existing crash reproducer, which previously crashed and now runs to completion.
Both bounds-check traps from PR #3 fire ZERO times with this PR
applied (confirmed in server logs).
CI States
Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.