Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ class Envs:
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
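# Debug flag: bounds-check the trtllm_mha page_table before the decode kernel
# call. Catches OOB SWA page indices that otherwise surface as CUDA illegal
# address errors deep inside the attention kernel. The check is skipped while
# a CUDA stream capture is active; on failure the page table is dumped under
# SGLANG_TRTLLM_MHA_DEBUG_DIR (default /tmp/trtllm_mha_debug) and a
# RuntimeError is raised. Set to 1 to enable.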
SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
84 changes: 82 additions & 2 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
self._swa_kv_pool: Optional[SWAKVPool] = (
kv_pool if self.use_sliding_window_kv_pool else None
)
# The model has SWA semantics whenever ANY of its layers carries a
# sliding window size > 0. Use ``model_runner.sliding_window_size``
# as the canonical signal: model_runner sets it from the model's
# ``get_attention_sliding_window_size`` or ``config.sliding_window_size``.
# We need this signal *separately* from the SWA-pool detection
# because the FROZEN_KV_MTP draft backend's pool starts non-SWA and
# gets swapped to the target's SWA pool at forward time; we must
# have allocated SWA-page-table buffers BEFORE that swap.
_model_sw = getattr(model_runner, "sliding_window_size", None)
self.model_has_sliding_window: bool = (
_model_sw is not None and _model_sw > 0
)

# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
Expand Down Expand Up @@ -161,8 +173,20 @@ def _maybe_translate_swa(
def _alloc_swa_page_table(
self, max_bs: int, max_num_pages: int
) -> Optional[torch.Tensor]:
"""Allocate a SWA page_table buffer, or return None for non-SWA models."""
if not self.use_sliding_window_kv_pool:
"""Allocate a SWA page_table buffer, or return None for non-SWA models.

Note: we eagerly allocate when ``self.model_has_sliding_window`` is
true even if ``self.use_sliding_window_kv_pool`` is currently
``False`` at init time. This is needed for the FROZEN_KV_MTP draft
backend: at init it has no SWA pool, but at forward time
``target_kv_pool_view`` swaps in the target's SWA pool (see
``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the
pre-allocated buffer the draft backend would build full-pool
page_table values for SWA layers and crash the trtllm_mha
``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with
``Warp Illegal Address``.
"""
if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window:
return None
return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)

Expand Down Expand Up @@ -752,6 +776,62 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
# only do this when stream capture is not active.
if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and (
not torch.cuda.is_current_stream_capturing()
):
import os

import torch as _t

cs = self.forward_metadata.cache_seqlens_int32
kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim)
num_pages_in_cache = int(kc_shape[0])
# Range check: every page index must satisfy 0 <= idx < num_pages_in_cache.
pt_max = int(page_table.max().item())
pt_min = int(page_table.min().item())
if pt_max >= num_pages_in_cache or pt_min < 0:
# Pre-emptively dump and abort before the kernel reads OOB.
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
ts = int(_t.cuda.current_stream().cuda_stream)
fn = (
f"{dump_dir}/page_table_oob_layer{layer.layer_id}_"
f"stream{ts}_{int(_t.cuda.device_count())}.pt"
)
_t.save(
{
"page_table": page_table.detach().cpu(),
"cache_seqlens_int32": cs.detach().cpu(),
"k_cache_shape": list(kc_shape),
"num_pages_in_cache": num_pages_in_cache,
"page_size": self.page_size,
"sliding_window": layer.sliding_window_size,
"layer_id": layer.layer_id,
"forward_mode": str(forward_batch.forward_mode),
"is_swa_layer": (
self._swa_kv_pool.layers_mapping[layer.layer_id][1]
if self.use_sliding_window_kv_pool
else False
),
},
fn,
)
msg = (
f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} "
f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): "
f"page_table.max={pt_max} page_table.min={pt_min} "
f"num_pages_in_cache={num_pages_in_cache}. "
f"Dumped to {fn}"
)
logger.error(msg)
raise RuntimeError(msg)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
Expand Down
Loading
Loading