diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index cb61867ad0f2..869ac14b4dcb 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -706,22 +706,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.swa_page_table[:, self.strided_indices] // self.page_size ) - # Defensive clamp: cap SWA page indices to the SWA-cache's - # valid page range. Some interactions between - # ``full_to_swa_index_mapping`` and the MTP draft-token - # allocation can leave the mapping with swa-token-index values - # whose page-divided result exceeds ``num_swa_pages``; without - # this clamp the trtllm_mha SWA kernel - # (``fmhaSm100fKernel_*SlidingOrChunkedCausal*``) TMA-prefetches - # the OOB block-table entry and traps with - # ``CUDA error: an illegal memory access``. Clamping to the - # last valid SWA page is safe because the kernel's - # ``window_left`` mask drops anything outside the sliding - # window anyway (see crash_repro/TRIAGE_REPORT.md). - if metadata.swa_page_table is not None and self._swa_kv_pool is not None: - num_swa_pages = self._swa_kv_pool.size_swa // self.page_size - metadata.swa_page_table.clamp_(min=0, max=max(num_swa_pages - 1, 0)) - self.forward_metadata = metadata def forward_decode( @@ -792,23 +776,6 @@ def forward_decode( page_table = self._get_layer_page_table(layer, forward_batch) - # Defensive clamp: cap page_table entries to the K-cache's - # valid page range before the kernel reads them. Avoids the - # trtllm_mha SWA crash (Warp Illegal Address inside - # fmhaSm100fKernel_*SlidingOrChunkedCausal*) when the SWA - # ``full_to_swa_index_mapping`` returns an off-by-one swa-token - # index, OR when the draft-model backend incorrectly uses - # full-pool page indices to address the SWA k_cache. - # Clamped pages still fall inside the kernel's window_left - # mask, so masked positions don't affect attention output; - # in-window positions land on the LAST valid SWA page (a - # one-page semantic shift that is bounded by page_size=64 - # tokens of staleness in the worst case). - # See crash_repro/TRIAGE_REPORT.md. - num_pages_in_cache = k_cache.shape[0] - if num_pages_in_cache > 0: - page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1) - # DEBUG: bounds-check page_table before trtllm kernel. Looking # for OOB SWA page indices that explain the cudaErrorIllegalAddress. # IMPORTANT: .item() syncs and breaks cuda-graph capture, so we @@ -949,13 +916,6 @@ def forward_extend( page_table = self._get_layer_page_table(layer, forward_batch) - # Defensive clamp (see comment in forward_decode and - # crash_repro/TRIAGE_REPORT.md). Prevents the trtllm SWA - # crash when page_table entries fall outside the k_cache. - num_pages_in_cache = k_cache.shape[0] - if num_pages_in_cache > 0: - page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1) - if forward_batch.forward_mode.is_target_verify(): o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q,