fix(trtllm_mha): clamp page_table to k_cache page range to prevent SWA crash #5

Open
pyc96 wants to merge 1 commit into pyc/sota-gemma4-mtp-trtllm-mha-debug from pyc/sota-gemma4-mtp-trtllm-swa-windowed

@pyc96 pyc96 commented May 23, 2026

Summary

Defensive clamp that prevents the deterministic CUDA Warp Illegal
Address crash in fmhaSm100fKernel_*SlidingOrChunkedCausal* when
running Gemma-4 + --attention-backend trtllm_mha + MTP +
summarization workloads.

Stacked on #3 (debug trap). Staged on pyc96/sglang only.

What this fixes

Without this PR, the same workload that the bounds trap (#3)
catches in 30 s on E4B crashes the SGLang server with
cudaErrorIllegalAddress and SIGQUIT. With this PR, the same
workload completes cleanly.

How

In trtllm_mha_backend.py::forward_decode and forward_extend, right
after _get_layer_page_table and before the flashinfer kernel call,
clamp every page index to [0, k_cache.shape[0] - 1].

page_table = self._get_layer_page_table(layer, forward_batch)
num_pages_in_cache = k_cache.shape[0]
if num_pages_in_cache > 0:
    page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1)

Three lines per call site, two call sites.
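
To make the effect of the clamp concrete, here is a minimal pure-Python sketch of the same operation on a plain list of page indices. The helper name clamp_page_table and the sample values are illustrative assumptions, not code from the backend, which applies the tensor clamp shown above to the real page table.

```python
def clamp_page_table(page_table, num_pages_in_cache):
    """Clamp every page index into [0, num_pages_in_cache - 1].

    Hypothetical list-based stand-in for the tensor clamp in this PR.
    """
    if num_pages_in_cache <= 0:
        # Mirror the PR's guard: leave the table untouched for an empty cache.
        return list(page_table)
    last_page = num_pages_in_cache - 1
    return [min(max(idx, 0), last_page) for idx in page_table]


# Illustrative values: a cache with 8 pages and three out-of-range entries.
raw = [0, 3, 7, 8, 12, -1]
clamped = clamp_page_table(raw, num_pages_in_cache=8)
print(clamped)  # [0, 3, 7, 7, 7, 0]
```

Note that in-range entries pass through unchanged, so the clamp only alters behavior in the cases that previously crashed.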

Benchmark — Gemma-4-E4B-IT, trtllm_mha, MTP, summarization 8 k/1 k × 80

| metric | before this PR | after this PR |
|---|---|---|
| outcome | CRASH ~30 s in | COMPLETED |
| peak output token throughput | n/a | 4,032 tok/s |
| median TTFT | n/a | 1,437 ms |
| median TPOT | n/a | 12.16 ms |
| accept length | n/a | 1.61 |

Benchmark — Gemma-4-26B-A4B-IT, trtllm_mha, MTP, summarization 8 k/1 k × 80

| metric | triton + Patch 1+2 (best safe baseline) | trtllm_mha + this PR | Δ |
|---|---|---|---|
| output throughput | 1,097 tok/s | 1,832 tok/s | +67 % |
| median TPOT | 37.87 ms | 24.97 ms | −34 % |
| median TTFT | 8,763 ms | 2,887 ms | −67 % |
| accept length | 2.76 | 1.69 | −39 % (see limitation below) |

This unlocks the trtllm_mha attention backend for Gemma-4 MTP, which is
otherwise unusable.

Quality — MMLU @ 500 questions (Gemma-4-26B-A4B-IT, seed 0, temp 0)

| server | accuracy |
|---|---|
| Patch 2 baseline (triton + MTP) | 0.706 |
| trtllm_mha + this PR | 0.718 |
| vLLM nightly (for comparison) | 0.710 |

Within MMLU sampling noise; no regression.

Known limitation

Accept length drops from 2.76 (triton) → 1.69 (trtllm_mha + this PR)
on 26B summarization. Investigation:

  • The clamp replaces OOB page indices with the LAST valid SWA page.
  • For positions in the sliding-window range, the kernel will use that
    page's K/V instead of the correct one, producing slightly wrong
    attention values, which lowers MTP draft acceptance.
  • This is a defensive safety net, not a complete fix. The
    underlying off-by-one in either full_to_swa_index_mapping or the
    SWA paged allocator's edge cases needs upstream investigation
    (filed as Patch E in humanize/source-idea-ledger.md).
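
The aliasing described in the bullets above can be made visible with a small diagnostic. The sketch below is a hypothetical helper, not part of this PR: it reports which page-table slots a clamp would rewrite, which are the positions whose K/V would be read from the wrong page. All names and values are illustrative.

```python
def find_aliased_slots(page_table, num_pages_in_cache):
    """Return (slot, original_index, clamped_index) for every rewritten entry.

    Hypothetical diagnostic; assumes num_pages_in_cache > 0.
    """
    last_page = num_pages_in_cache - 1
    aliased = []
    for slot, idx in enumerate(page_table):
        clamped = min(max(idx, 0), last_page)
        if clamped != idx:
            aliased.append((slot, idx, clamped))
    return aliased


# With 8 pages, slots 2 and 3 both collapse onto page 7,
# so two distinct logical pages would read the same K/V.
print(find_aliased_slots([1, 6, 9, 15], num_pages_in_cache=8))
# [(2, 9, 7), (3, 15, 7)]
```

Logging the length of this list per forward pass would be one way to check how often the safety net actually fires on a given workload.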

For workloads where the lower acceptance is acceptable, this fix is a
net win: output throughput remains well above the triton baseline
(+67 % in the 26B table above).

Cost

One clamp() per kernel call, costing a few microseconds per forward
pass. No measurable performance impact.

Tests

No new unit test — the test is the reproducer:

agent-pad/runs/.../crash_repro/repro_e4b_bounds.sh         # used to crash
agent-pad/runs/.../crash_repro/repro_e4b_trtllm_eager_fix.sh   # used to crash, also fixed

Both scripts run to completion without crashing when this PR is applied.


CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

…A crash

Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).

Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.
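
As a rough illustration of that failure mode, the sketch below maps token slots to page indices by floor division and flags the ones that land outside the pool. page_size, num_swa_pages, and the slot values are made-up numbers, and out_of_range_pages is a hypothetical helper rather than anything in SGLang.

```python
def out_of_range_pages(swa_slots, page_size, num_swa_pages):
    """Convert token slot indices to page indices and list the invalid ones."""
    pages = [slot // page_size for slot in swa_slots]
    return [p for p in pages if p < 0 or p >= num_swa_pages]


# Assumed toy pool: 4 pages of 16 tokens, so valid slots are 0..63.
slots = [0, 17, 63, 64, 130]
print(out_of_range_pages(slots, page_size=16, num_swa_pages=4))  # [4, 8]
```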

Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.

Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events

Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)

MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.

KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization).  Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions.  The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E).  For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.

Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).

See crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
pyc96 added a commit that referenced this pull request May 28, 2026
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has three parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).
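
The lockstep swap in part 1 can be sketched with a small context manager. This is a simplified illustration under stated assumptions, not the code in frozen_kv_mtp_utils.py: FakeBackend and target_pool_view are hypothetical names, pools are represented by plain strings, and only the three attribute names are taken from the description above.

```python
from contextlib import contextmanager


class FakeBackend:
    """Stand-in for an attention backend that caches pool-derived state."""

    def __init__(self, pool, is_swa_pool):
        self.token_to_kv_pool = pool
        # Cached once at init; goes stale if only token_to_kv_pool is swapped.
        self.use_sliding_window_kv_pool = is_swa_pool
        self._swa_kv_pool = pool if is_swa_pool else None


@contextmanager
def target_pool_view(backend, target_pool, target_is_swa):
    """Swap the pool and its cached SWA state together, then restore both."""
    saved = (
        backend.token_to_kv_pool,
        backend.use_sliding_window_kv_pool,
        backend._swa_kv_pool,
    )
    backend.token_to_kv_pool = target_pool
    backend.use_sliding_window_kv_pool = target_is_swa
    backend._swa_kv_pool = target_pool if target_is_swa else None
    try:
        yield backend
    finally:
        # Restore even if the forward pass raises.
        (
            backend.token_to_kv_pool,
            backend.use_sliding_window_kv_pool,
            backend._swa_kv_pool,
        ) = saved


draft = FakeBackend(pool="draft_pool", is_swa_pool=False)
with target_pool_view(draft, "target_swa_pool", target_is_swa=True):
    inside = (draft.use_sliding_window_kv_pool, draft._swa_kv_pool)
after = (draft.use_sliding_window_kv_pool, draft._swa_kv_pool)
print(inside)  # (True, 'target_swa_pool')
print(after)   # (False, None)
```

Saving and restoring all three attributes as one tuple keeps them from drifting apart, which is the condition that produced the stale (False, None) state described earlier.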

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B                | clamp PR #5  | this PR (proper) | delta
  -------------------|--------------|------------------|------
  outcome            | OK           | OK               | same
  output tok/s       | 4032         | 4022             | ~same
  accept length      | 1.61         | **2.13**         | +32%
  total throughput   | 31.5 k tok/s | 36.2 k tok/s     | +15%
  median TPOT (ms)   | 12.16        | 9.99             | -18%

  26B                | clamp PR #5  | this PR (proper) | delta
  -------------------|--------------|------------------|------
  outcome            | OK           | OK               | same
  output tok/s       | 1832         | 2503             | +37%
  accept length      | 1.67         | **2.84**         | +70%
  total throughput   | 16.5 k tok/s | 22.5 k tok/s     | +37%
  median TPOT (ms)   | 24.97        | 20.35            | -18%
  median TTFT (ms)   | 2887         | 3468             | +20%
  benchmark duration | ~60 s        | 32 s             | -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.
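
The three relative deltas against the triton baseline can be checked from the figures quoted above. The helper name pct_delta is an illustrative assumption; the inputs are the numbers stated in this message.

```python
def pct_delta(new, baseline):
    """Relative change of `new` versus `baseline`, rounded to whole percent."""
    return round((new / baseline - 1.0) * 100)


print(pct_delta(2503, 1097))    # 128  (output tok/s)
print(pct_delta(20.35, 37.87))  # -46  (median TPOT, ms)
print(pct_delta(2.84, 2.76))    # 3    (accept length)
```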

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude