diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 2791aeec9a8e..54ee243e7c2c 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -412,6 +412,10 @@ class Envs: # None = standard attention. See https://arxiv.org/abs/2512.12087 SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None) SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None) + # Debug flag: bounds-check trtllm_mha page_table before the kernel call. + # Catches OOB SWA page indices that otherwise surface as CUDA illegal + # address errors deep inside the attention kernel. Set to 1 to enable. + SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False) # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion # transport issue on GB200/GB300 platforms is fixed and verified resolved. SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index e68bcb95e822..1b09a4174070 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -752,6 +752,62 @@ def forward_decode( page_table = self._get_layer_page_table(layer, forward_batch) + # DEBUG: bounds-check page_table before trtllm kernel. Looking + # for OOB SWA page indices that explain the cudaErrorIllegalAddress. + # IMPORTANT: .item() syncs and breaks cuda-graph capture, so we + # only do this when stream capture is not active. + if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and ( + not torch.cuda.is_current_stream_capturing() + ): + import os + + import torch as _t + + cs = self.forward_metadata.cache_seqlens_int32 + kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim) + num_pages_in_cache = int(kc_shape[0]) + # 1) max-value check + pt_max = int(page_table.max().item()) + pt_min = int(page_table.min().item()) + if pt_max >= num_pages_in_cache or pt_min < 0: + # Pre-emptively dump and abort before the kernel reads OOB. + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + ts = int(_t.cuda.current_stream().cuda_stream) + fn = ( + f"{dump_dir}/page_table_oob_layer{layer.layer_id}_" + f"stream{ts}_{int(_t.cuda.device_count())}.pt" + ) + _t.save( + { + "page_table": page_table.detach().cpu(), + "cache_seqlens_int32": cs.detach().cpu(), + "k_cache_shape": list(kc_shape), + "num_pages_in_cache": num_pages_in_cache, + "page_size": self.page_size, + "sliding_window": layer.sliding_window_size, + "layer_id": layer.layer_id, + "forward_mode": str(forward_batch.forward_mode), + "is_swa_layer": ( + self._swa_kv_pool.layers_mapping[layer.layer_id][1] + if self.use_sliding_window_kv_pool + else False + ), + }, + fn, + ) + msg = ( + f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} " + f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): " + f"page_table.max={pt_max} page_table.min={pt_min} " + f"num_pages_in_cache={num_pages_in_cache}. " + f"Dumped to {fn}" + ) + logger.error(msg) + raise RuntimeError(msg) + # Call TRT-LLM kernel # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index bd1205708351..4f5fc878c1a4 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -25,6 +25,30 @@ logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 +# Opt-in debug instrumentation: log when the SWA allocator returns an index +# >= swa_pool_size. Backend-independent. Set ``SGLANG_TRTLLM_MHA_DEBUG=1`` +# to enable. +# +# Empirical finding under Gemma-4-E4B-IT + MTP + summarisation 8 k/1 k x 80 +# at SWA usage up to 1.00 (triton backend) and up to 0.85+ (trtllm_mha +# backend that crashes): this trap **never fires** under either backend, so +# the SWA allocator is NOT producing OOB indices. The trtllm_mha crash is +# downstream of the allocator -- specifically in +# ``trtllm_mha_backend.init_forward_metadata`` where +# ``metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]`` +# pulls in *trailing* positions past each row's cache_seqlens whose +# req_to_token entries were never written (= 0). The translation +# ``full_to_swa_index_mapping[0]`` is the swa slot assigned to full slot 0 +# at the last alloc; it can address an arbitrary swa page that may or may +# not be in-bounds. See crash_repro/TRIAGE_REPORT.md. +import os as _os + +_DEBUG_SWA_ALLOC_OOB = _os.environ.get("SGLANG_TRTLLM_MHA_DEBUG", "").lower() in ( + "1", + "true", + "yes", +) + class SWAKVPool(BaseSWAKVPool): """KV cache with separate pools for full and SWA attention layers.""" @@ -495,8 +519,51 @@ def alloc_extend( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + # DEBUG: instrument SWA allocator OOB writes (independent of + # attention backend). Catches the off-by-one in + # alloc_extend_kernel Part 1 (last_loc + 1 + offset overflowing + # pool_size when last_loc is near the pool end). See + # crash_repro/TRIAGE_REPORT.md. + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_extend") + return alloc_full_indices + def _maybe_log_swa_oob(self, alloc_swa_indices: torch.Tensor, ctx: str) -> None: + """If any swa index is >= ``self._size_swa``, log + dump.""" + import os + max_val = int(alloc_swa_indices.max().item()) + if max_val >= self._size_swa: + min_val = int(alloc_swa_indices.min().item()) + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + fn = ( + f"{dump_dir}/swa_alloc_oob_{ctx}_max{max_val}_size{self._size_swa}_" + f"{int(torch.cuda.current_stream().cuda_stream)}.pt" + ) + torch.save( + { + "ctx": ctx, + "alloc_swa_indices": alloc_swa_indices.detach().cpu(), + "swa_pool_size": self._size_swa, + "page_size": self.page_size, + "swa_max_value_returned": max_val, + "swa_min_value_returned": min_val, + "oob_count": int((alloc_swa_indices >= self._size_swa).sum().item()), + }, + fn, + ) + msg = ( + f"[SWA alloc DEBUG] OOB swa index from {ctx}: " + f"max={max_val} swa_pool_size={self._size_swa}; " + f"first OOB at flat-idx " + f"{int((alloc_swa_indices >= self._size_swa).nonzero().flatten()[0].item())}. " + f"Dumped to {fn}" + ) + logger.error(msg) + def alloc_extend_swa_tail( self, prefix_lens: torch.Tensor, @@ -590,6 +657,9 @@ def alloc_decode( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_decode") + return alloc_full_indices def free(self, free_index: torch.Tensor):