Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ class Envs:
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
# Debug flag: bounds-check trtllm_mha page_table before the kernel call.
# Catches OOB SWA page indices that otherwise surface as CUDA illegal
# address errors deep inside the attention kernel. Set to 1 to enable.
SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
56 changes: 56 additions & 0 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,62 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
# only do this when stream capture is not active.
if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and (
not torch.cuda.is_current_stream_capturing()
):
import os

import torch as _t

cs = self.forward_metadata.cache_seqlens_int32
kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim)
num_pages_in_cache = int(kc_shape[0])
# 1) max-value check
pt_max = int(page_table.max().item())
pt_min = int(page_table.min().item())
if pt_max >= num_pages_in_cache or pt_min < 0:
# Pre-emptively dump and abort before the kernel reads OOB.
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
ts = int(_t.cuda.current_stream().cuda_stream)
fn = (
f"{dump_dir}/page_table_oob_layer{layer.layer_id}_"
f"stream{ts}_{int(_t.cuda.device_count())}.pt"
)
_t.save(
{
"page_table": page_table.detach().cpu(),
"cache_seqlens_int32": cs.detach().cpu(),
"k_cache_shape": list(kc_shape),
"num_pages_in_cache": num_pages_in_cache,
"page_size": self.page_size,
"sliding_window": layer.sliding_window_size,
"layer_id": layer.layer_id,
"forward_mode": str(forward_batch.forward_mode),
"is_swa_layer": (
self._swa_kv_pool.layers_mapping[layer.layer_id][1]
if self.use_sliding_window_kv_pool
else False
),
},
fn,
)
msg = (
f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} "
f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): "
f"page_table.max={pt_max} page_table.min={pt_min} "
f"num_pages_in_cache={num_pages_in_cache}. "
f"Dumped to {fn}"
)
logger.error(msg)
raise RuntimeError(msg)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
Expand Down
70 changes: 70 additions & 0 deletions python/sglang/srt/mem_cache/swa_memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024

# Opt-in debug instrumentation: log when the SWA allocator returns an index
# >= swa_pool_size. Backend-independent. Set ``SGLANG_TRTLLM_MHA_DEBUG=1``
# to enable.
#
# Empirical finding under Gemma-4-E4B-IT + MTP + summarisation 8 k/1 k x 80
# at SWA usage up to 1.00 (triton backend) and up to 0.85+ (trtllm_mha
# backend that crashes): this trap **never fires** under either backend, so
# the SWA allocator is NOT producing OOB indices. The trtllm_mha crash is
# downstream of the allocator -- specifically in
# ``trtllm_mha_backend.init_forward_metadata`` where
# ``metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]``
# pulls in *trailing* positions past each row's cache_seqlens whose
# req_to_token entries were never written (= 0). The translation
# ``full_to_swa_index_mapping[0]`` is the swa slot assigned to full slot 0
# at the last alloc; it can address an arbitrary swa page that may or may
# not be in-bounds. See crash_repro/TRIAGE_REPORT.md.
import os as _os

_DEBUG_SWA_ALLOC_OOB = _os.environ.get("SGLANG_TRTLLM_MHA_DEBUG", "").lower() in (
"1",
"true",
"yes",
)


class SWAKVPool(BaseSWAKVPool):
"""KV cache with separate pools for full and SWA attention layers."""
Expand Down Expand Up @@ -495,8 +519,51 @@ def alloc_extend(
else:
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices

# DEBUG: instrument SWA allocator OOB writes (independent of
# attention backend). Catches the off-by-one in
# alloc_extend_kernel Part 1 (last_loc + 1 + offset overflowing
# pool_size when last_loc is near the pool end). See
# crash_repro/TRIAGE_REPORT.md.
if _DEBUG_SWA_ALLOC_OOB:
self._maybe_log_swa_oob(alloc_swa_indices, "alloc_extend")

return alloc_full_indices

def _maybe_log_swa_oob(self, alloc_swa_indices: torch.Tensor, ctx: str) -> None:
"""If any swa index is >= ``self._size_swa``, log + dump."""
import os
max_val = int(alloc_swa_indices.max().item())
if max_val >= self._size_swa:
min_val = int(alloc_swa_indices.min().item())
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
fn = (
f"{dump_dir}/swa_alloc_oob_{ctx}_max{max_val}_size{self._size_swa}_"
f"{int(torch.cuda.current_stream().cuda_stream)}.pt"
)
torch.save(
{
"ctx": ctx,
"alloc_swa_indices": alloc_swa_indices.detach().cpu(),
"swa_pool_size": self._size_swa,
"page_size": self.page_size,
"swa_max_value_returned": max_val,
"swa_min_value_returned": min_val,
"oob_count": int((alloc_swa_indices >= self._size_swa).sum().item()),
},
fn,
)
msg = (
f"[SWA alloc DEBUG] OOB swa index from {ctx}: "
f"max={max_val} swa_pool_size={self._size_swa}; "
f"first OOB at flat-idx "
f"{int((alloc_swa_indices >= self._size_swa).nonzero().flatten()[0].item())}. "
f"Dumped to {fn}"
)
logger.error(msg)

def alloc_extend_swa_tail(
self,
prefix_lens: torch.Tensor,
Expand Down Expand Up @@ -590,6 +657,9 @@ def alloc_decode(
else:
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices

if _DEBUG_SWA_ALLOC_OOB:
self._maybe_log_swa_oob(alloc_swa_indices, "alloc_decode")

return alloc_full_indices

def free(self, free_index: torch.Tensor):
Expand Down
Loading