debug: trtllm_mha page_table bounds-check (SGLANG_TRTLLM_MHA_DEBUG=1)#3
Open
pyc96 wants to merge 2 commits into
Open
debug: trtllm_mha page_table bounds-check (SGLANG_TRTLLM_MHA_DEBUG=1)#3pyc96 wants to merge 2 commits into
pyc96 wants to merge 2 commits into
Conversation
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache that traps OOB page indices and dumps page_table + cache_seqlens. Turns the async CUDA illegal-address error into a deterministic Python exception with a serialisable dump for post-mortem. See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh. Co-authored-by: Claude
…rap)
Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.
Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check. Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.
Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:
triton backend: SWA usage reaches 1.00, ZERO trap fires, no crash
trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
CUDA illegal address crash in fmhaSm100fKernel_*
That is, the SWA allocator is NOT the source of the OOB. Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table. Specifically:
metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]
For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token). full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).
This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.
Co-authored-by: Claude
This was referenced May 23, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds an opt-in bounds-check that runs right before
flashinfer.decode.trtllm_batch_decode_with_kv_cacheand turns theotherwise-async
CUDA error: an illegal memory access was encounteredfrom
fmhaSm100fKernel_*SlidingOrChunkedCausal*into a deterministicPython exception, plus dumps
page_table+cache_seqlens_int32forpost-mortem.
Stacked on top of #2 (SWA ratio). Base = same branch.
Staged on
pyc96/sglangonly — not opening upstream againstsgl-project/sglang.Motivation
There is a real off-by-one in the trtllm_mha + hybrid-SWA + MTP path
that crashes the SGLang serving process under load. Reproducer:
(Same crash reproduces on the 26B-A4B-IT model too; E4B picks it up
faster — ~30 s instead of ~90 s.)
cuda-gdb on the coredump:
With this debug flag on, the trap fires before the kernel reads OOB:
Off-by-one root cause (best evidence, see triage report):
the SWA
alloc_extend_kernelcan emit token indices up to(num_pages) * page_size + page_size - 1 = 554111, which afterdividing by
page_sizegive page index8657— exactly one pagepast the K-cache buffer's
8657-page capacity(
size + page_size = 554048rows = pages[0, 8656]only).The proper fix is structural (either drop the last page from the free
list and reduce
available_sizeconsistently, or bump the K-cachebuffer by one page). Both options need wider audit — see triage
report. This PR ships the trap+dump separately so the bug becomes
debuggable even before the structural fix lands.
Implementation
python/sglang/srt/layers/attention/trtllm_mha_backend.py::forward_decode:SGLANG_TRTLLM_MHA_DEBUG=1and the current CUDA stream isnot capturing a graph, compute
page_table.max().item()/.min().item()and compare againstk_cache.shape[0].torch.savethe page table + cache_seqlens + shape info to$SGLANG_TRTLLM_MHA_DEBUG_DIR(default/tmp/trtllm_mha_debug),then
logger.error()+raise RuntimeError(...).torch.cuda.is_current_stream_capturing()guard ensures the.item()sync does not break cuda-graph capture/replay.python/sglang/srt/environ.py: adds theSGLANG_TRTLLM_MHA_DEBUGenv var (default
False).Cost
max().item()+.min().item()per decode forward (~10-50 µs depending on batch size)Tests
No unit test — this is an opt-in diagnostic. Functional verification
on Gemma-4-E4B-IT (
crash_repro/repro_e4b_bounds.sh) catches thecrash deterministically; sample dump committed under
agent-pad/runs/.../crash_repro/dumps_e4b/page_table_oob_layer22_*.pt.Provenance
Co-authored-by: Claudeon the commit.CI States
Latest PR Test (Base): ❌ Missing
run-cilabel -- add it to run CI tests.Latest PR Test (Extra): ❌ Blocked --
run-ciis required first.