Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,22 +706,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.swa_page_table[:, self.strided_indices] // self.page_size
)

# Defensive clamp: cap SWA page indices to the SWA-cache's
# valid page range. Some interactions between
# ``full_to_swa_index_mapping`` and the MTP draft-token
# allocation can leave the mapping with swa-token-index values
# whose page-divided result exceeds ``num_swa_pages``; without
# this clamp the trtllm_mha SWA kernel
# (``fmhaSm100fKernel_*SlidingOrChunkedCausal*``) TMA-prefetches
# the OOB block-table entry and traps with
# ``CUDA error: an illegal memory access``. Clamping to the
# last valid SWA page is safe because the kernel's
# ``window_left`` mask drops anything outside the sliding
# window anyway (see crash_repro/TRIAGE_REPORT.md).
if metadata.swa_page_table is not None and self._swa_kv_pool is not None:
num_swa_pages = self._swa_kv_pool.size_swa // self.page_size
metadata.swa_page_table.clamp_(min=0, max=max(num_swa_pages - 1, 0))

self.forward_metadata = metadata

def forward_decode(
Expand Down Expand Up @@ -792,23 +776,6 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# Defensive clamp: cap page_table entries to the K-cache's
# valid page range before the kernel reads them. Avoids the
# trtllm_mha SWA crash (Warp Illegal Address inside
# fmhaSm100fKernel_*SlidingOrChunkedCausal*) when the SWA
# ``full_to_swa_index_mapping`` returns an off-by-one swa-token
# index, OR when the draft-model backend incorrectly uses
# full-pool page indices to address the SWA k_cache.
# Clamped entries at positions outside the sliding window are
# dropped by the kernel's window_left mask, so they don't affect
# attention output; clamped entries at in-window positions read
# the LAST valid SWA page instead (a one-page semantic shift that
# is bounded by page_size=64 tokens of staleness in the worst case).
# See crash_repro/TRIAGE_REPORT.md.
num_pages_in_cache = k_cache.shape[0]
if num_pages_in_cache > 0:
page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
Expand Down Expand Up @@ -949,13 +916,6 @@ def forward_extend(

page_table = self._get_layer_page_table(layer, forward_batch)

# Defensive clamp (see comment in forward_decode and
# crash_repro/TRIAGE_REPORT.md). Prevents the trtllm SWA
# crash when page_table entries fall outside the k_cache.
num_pages_in_cache = k_cache.shape[0]
if num_pages_in_cache > 0:
page_table = page_table.clamp(min=0, max=num_pages_in_cache - 1)

if forward_batch.forward_mode.is_target_verify():
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=q,
Expand Down
Loading