debug: trtllm_mha page_table bounds-check (SGLANG_TRTLLM_MHA_DEBUG=1)#3

Open
pyc96 wants to merge 2 commits into pyc/sota-gemma4-mtp-swa-ratio from pyc/sota-gemma4-mtp-trtllm-mha-debug

pyc96 commented May 23, 2026

Summary

Adds an opt-in bounds check that runs right before
flashinfer.decode.trtllm_batch_decode_with_kv_cache. It turns the
otherwise-async CUDA error: an illegal memory access was encountered
from fmhaSm100fKernel_*SlidingOrChunkedCausal* into a deterministic
Python exception and dumps page_table + cache_seqlens_int32 for
post-mortem.

Stacked on top of #2 (SWA ratio). Base = same branch.

Staged on pyc96/sglang only — not opening upstream against
sgl-project/sglang.

Motivation

There is a real off-by-one in the trtllm_mha + hybrid-SWA + MTP path
that crashes the SGLang serving process under load. Reproducer:

agent-pad/runs/20260522_gemma4_26b_a4b_it_sota_humanize/crash_repro/repro_e4b_bounds.sh

(Same crash reproduces on the 26B-A4B-IT model too; E4B picks it up
faster — ~30 s instead of ~90 s.)

cuda-gdb on the coredump:

CUDA Exception: Warp Illegal Address
fmhaSm100fKernel_QkvBfloat16OBfloat16H256PagedKvSlidingOrChunkedCausalP64MultiCtasKvCgaVarSeqQ8Kv128StaticSwapsAbForGen

With this debug flag on, the trap fires before the kernel reads OOB:

[trtllm_mha DEBUG] OOB page_table @ layer 22:
    page_table.max=8657  num_pages_in_cache=8657
Dumped to /tmp/trtllm_mha_debug/page_table_oob_layer22_*.pt

Off-by-one root cause (best evidence, see triage report):
the SWA alloc_extend_kernel can emit token indices up to
num_pages * page_size + page_size - 1 = 554111, which after
dividing by page_size gives page index 8657, exactly one page
past the K-cache buffer's 8657-page capacity
(size + page_size = 554048 rows = pages [0, 8656] only).
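
The arithmetic above can be checked directly. This is a minimal sketch that assumes page_size = 64, a value inferred from 554048 rows / 8657 pages rather than stated in this PR:

```python
# Worked check of the off-by-one arithmetic quoted above.
# Assumption: page_size = 64, inferred from 554048 rows / 8657 pages.
num_pages = 8657
page_size = 64

# Rows actually present in the K-cache buffer.
buffer_rows = num_pages * page_size

# Largest token index the allocator can emit, per the formula above.
max_token_index = num_pages * page_size + page_size - 1

# Page that this token index lands in.
max_page_index = max_token_index // page_size

# Highest page index the buffer can serve.
last_valid_page = buffer_rows // page_size - 1

print(buffer_rows, max_token_index, max_page_index, last_valid_page)
# prints: 554048 554111 8657 8656
```

The largest emitted page index (8657) is one greater than the last valid page (8656), which matches the page_table.max == num_pages_in_cache condition in the trap output.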

The proper fix is structural (either drop the last page from the free
list and reduce available_size consistently, or bump the K-cache
buffer by one page). Both options need wider audit — see triage
report. This PR ships the trap+dump separately so the bug becomes
debuggable even before the structural fix lands.

Implementation

python/sglang/srt/layers/attention/trtllm_mha_backend.py::forward_decode:

  • When SGLANG_TRTLLM_MHA_DEBUG=1 and the current CUDA stream is
    not capturing a graph, compute page_table.max().item() /
    .min().item() and compare against k_cache.shape[0].
  • On OOB: torch.save the page table + cache_seqlens + shape info to
    $SGLANG_TRTLLM_MHA_DEBUG_DIR (default /tmp/trtllm_mha_debug),
    then logger.error() + raise RuntimeError(...).
  • torch.cuda.is_current_stream_capturing() guard ensures the
    .item() sync does not break cuda-graph capture/replay.

python/sglang/srt/environ.py: adds the SGLANG_TRTLLM_MHA_DEBUG
env var (default False).
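
The control flow described above can be sketched without torch. This is an illustration under stated assumptions, not the code in this PR: check_page_table and debug_enabled are hypothetical names, plain Python lists stand in for tensors, pickle stands in for torch.save, and the capturing argument stands in for torch.cuda.is_current_stream_capturing().

```python
import os
import pickle
import time


def debug_enabled() -> bool:
    # Assumption: the flag counts as on only when the variable equals "1".
    return os.environ.get("SGLANG_TRTLLM_MHA_DEBUG", "0") == "1"


def check_page_table(page_table, cache_seqlens, num_pages, layer_id, capturing=False):
    """Raise RuntimeError if any page index lies outside [0, num_pages).

    Hypothetical helper: page_table is a list of rows of ints, and
    `capturing` stands in for torch.cuda.is_current_stream_capturing().
    """
    if not debug_enabled() or capturing:
        return  # flag off, or inside graph capture: do nothing
    flat = [idx for row in page_table for idx in row]
    if not flat:
        return
    lo, hi = min(flat), max(flat)
    if lo >= 0 and hi < num_pages:
        return  # every index is addressable
    dump_dir = os.environ.get("SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug")
    os.makedirs(dump_dir, exist_ok=True)
    path = os.path.join(dump_dir, f"page_table_oob_layer{layer_id}_{time.time_ns()}.pkl")
    with open(path, "wb") as f:
        # pickle stands in for torch.save in this torch-free sketch.
        pickle.dump(
            {"page_table": page_table, "cache_seqlens": cache_seqlens, "num_pages": num_pages},
            f,
        )
    raise RuntimeError(
        f"[trtllm_mha DEBUG] OOB page_table @ layer {layer_id}: "
        f"page_table.min={lo} page_table.max={hi} num_pages_in_cache={num_pages}; "
        f"dumped to {path}"
    )
```

Writing the dump before raising means the evidence survives even if the caller swallows the exception.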

Cost

mode                         overhead
flag off (default)           none
flag on, eager decode        one extra max().item() + .min().item() per decode forward (~10-50 µs depending on batch size)
flag on, cuda-graph replay   none (guarded out)
flag on, cuda-graph capture  none (guarded out)

Tests

No unit test — this is an opt-in diagnostic. Functional verification
on Gemma-4-E4B-IT (crash_repro/repro_e4b_bounds.sh) catches the
crash deterministically; sample dump committed under
agent-pad/runs/.../crash_repro/dumps_e4b/page_table_oob_layer22_*.pt.

Provenance

  • All changes original to this fork. No upstream code copied.
  • Co-authored-by: Claude on the commit.

CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens.
Turns the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.

See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.

Co-authored-by: Claude
…rap)

Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.

Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check.  Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.

Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:

  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*

That is, the SWA allocator is NOT the source of the OOB.  Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table.  Specifically:

  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]

For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token).  full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page.  That slot is in-bounds for the SWA buffer, but the
trtllm_mha kernel treats the row as the *whole* sequence-length window
and dereferences it.

This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.
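
The trailing-entry situation and the first fix option (masking) can be illustrated with plain lists standing in for tensors. This is a hypothetical sketch: the values are made up, stale_positions and mask_trailing are illustrative names, and the sentinel the kernel would tolerate is not established here.

```python
# Plain lists stand in for tensors; all values are made up for illustration.
req_to_token = [
    [10, 11, 12, 13, 14, 15],  # request 0: 6 positions written
    [20, 21, 0, 0, 0, 0],      # request 1: 2 positions written, rest still zero
]
cache_seqlens = [6, 2]
max_seq_len_k = max(cache_seqlens)

# List equivalent of req_to_token[req_pool_indices, :max_seq_len_k].
page_table = [row[:max_seq_len_k] for row in req_to_token]


def stale_positions(table, seqlens):
    """Return (row, col) pairs that lie beyond each row's own sequence length."""
    return [
        (r, c)
        for r, row in enumerate(table)
        for c in range(seqlens[r], len(row))
    ]


def mask_trailing(table, seqlens, fill):
    """Overwrite positions past each row's length with a sentinel value."""
    return [
        [v if c < seqlens[r] else fill for c, v in enumerate(row)]
        for r, row in enumerate(table)
    ]


stale = stale_positions(page_table, cache_seqlens)
masked = mask_trailing(page_table, cache_seqlens, fill=-1)
```

Here request 1 contributes four stale positions, each of which would otherwise be translated as if it were full slot 0.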

Co-authored-by: Claude