diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 1b09a4174070..5195f8912ce7 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -682,6 +682,22 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.swa_page_table[:, self.strided_indices] // self.page_size ) + # Defensive clamp: cap SWA page indices to the SWA-cache's + # valid page range. Some interactions between + # ``full_to_swa_index_mapping`` and the MTP draft-token + # allocation can leave the mapping with swa-token-index values + # whose page-divided result exceeds ``num_swa_pages``; without + # this clamp the trtllm_mha SWA kernel + # (``fmhaSm100fKernel_*SlidingOrChunkedCausal*``) TMA-prefetches + # the OOB block-table entry and traps with + # ``CUDA error: an illegal memory access``. Clamping to the + # last valid SWA page is safe because the kernel's + # ``window_left`` mask drops anything outside the sliding + # window anyway (see crash_repro/TRIAGE_REPORT.md). + if metadata.swa_page_table is not None and self._swa_kv_pool is not None: + num_swa_pages = self._swa_kv_pool.size_swa // self.page_size + metadata.swa_page_table.clamp_(min=0, max=max(num_swa_pages - 1, 0)) + self.forward_metadata = metadata def forward_decode( @@ -752,6 +768,23 @@ def forward_decode( page_table = self._get_layer_page_table(layer, forward_batch) + # Defensive clamp: cap page_table entries to the K-cache's + # valid page range before the kernel reads them. Avoids the + # trtllm_mha SWA crash (Warp Illegal Address inside + # fmhaSm100fKernel_*SlidingOrChunkedCausal*) when the SWA + # ``full_to_swa_index_mapping`` returns an off-by-one swa-token + # index, OR when the draft-model backend incorrectly uses + # full-pool page indices to address the SWA k_cache. + # Clamped pages still fall inside the kernel's window_left + # mask, so masked positions don't affect attention output; + # in-window positions land on the LAST valid SWA page (a + # one-page semantic shift that is bounded by page_size=64 + # tokens of staleness in the worst case). + # See crash_repro/TRIAGE_REPORT.md. + num_pages_in_cache = k_cache.shape[0] + if num_pages_in_cache > 0: + page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1) + # DEBUG: bounds-check page_table before trtllm kernel. Looking # for OOB SWA page indices that explain the cudaErrorIllegalAddress. # IMPORTANT: .item() syncs and breaks cuda-graph capture, so we @@ -892,6 +925,13 @@ def forward_extend( page_table = self._get_layer_page_table(layer, forward_batch) + # Defensive clamp (see comment in forward_decode and + # crash_repro/TRIAGE_REPORT.md). Prevents the trtllm SWA + # crash when page_table entries fall outside the k_cache. + num_pages_in_cache = k_cache.shape[0] + if num_pages_in_cache > 0: + page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1) + if forward_batch.forward_mode.is_target_verify(): o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q,