diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 5195f8912ce7..cb61867ad0f2 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -133,6 +133,18 @@ def __init__( self._swa_kv_pool: Optional[SWAKVPool] = ( kv_pool if self.use_sliding_window_kv_pool else None ) + # The model has SWA semantics whenever ANY of its layers carries a + # sliding window size > 0. Use ``model_runner.sliding_window_size`` + # as the canonical signal: model_runner sets it from the model's + # ``get_attention_sliding_window_size`` or ``config.sliding_window_size``. + # We need this signal *separately* from the SWA-pool detection + # because the FROZEN_KV_MTP draft backend's pool starts non-SWA and + # gets swapped to the target's SWA pool at forward time; we must + # have allocated SWA-page-table buffers BEFORE that swap. + _model_sw = getattr(model_runner, "sliding_window_size", None) + self.model_has_sliding_window: bool = ( + _model_sw is not None and _model_sw > 0 + ) # Forward metadata self.forward_metadata: Optional[TRTLLMMHAMetadata] = None @@ -161,8 +173,20 @@ def _maybe_translate_swa( def _alloc_swa_page_table( self, max_bs: int, max_num_pages: int ) -> Optional[torch.Tensor]: - """Allocate a SWA page_table buffer, or return None for non-SWA models.""" - if not self.use_sliding_window_kv_pool: + """Allocate a SWA page_table buffer, or return None for non-SWA models. + + Note: we eagerly allocate when ``self.model_has_sliding_window`` is + true even if ``self.use_sliding_window_kv_pool`` is currently + ``False`` at init time. This is needed for the FROZEN_KV_MTP draft + backend: at init it has no SWA pool, but at forward time + ``target_kv_pool_view`` swaps in the target's SWA pool (see + ``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the + pre-allocated buffer the draft backend would build full-pool + page_table values for SWA layers and crash the trtllm_mha + ``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with + ``Warp Illegal Address``. + """ + if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window: return None return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py index 8b1ac37f8df2..c2add25aaa40 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py @@ -303,10 +303,21 @@ def run_once(): # Swap the draft backend's token_to_kv_pool to the frozen target pool # for the capture; the single backend-attr swap is seen by both # ``get_token_to_kv_pool()`` (via ``get_attn_backend()``) and the - # backend's own reads. + # backend's own reads. Also swap SWA-aware backend state so + # SWA-aware backends (notably trtllm_mha) build SWA-aware metadata + # against the target's SWA pool. See + # ``frozen_kv_mtp_utils._maybe_swap_swa_state``. + from sglang.srt.speculative.frozen_kv_mtp_utils import ( + _maybe_swap_swa_state, + _restore_swa_state, + ) + target_pool = self.frozen_kv_mtp_worker.kv_context.target_token_to_kv_pool saved_backend_pool = self.draft_attn_backend.token_to_kv_pool self.draft_attn_backend.token_to_kv_pool = target_pool + saved_swa_state = _maybe_swap_swa_state( + self.draft_attn_backend, target_pool + ) try: with forward_context(ForwardContext(attn_backend=self.draft_attn_backend)): self.frozen_kv_mtp_worker._init_frozen_kv_metadata_capture_cuda_graph( @@ -319,6 +330,7 @@ def run_once(): ) finally: self.draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(self.draft_attn_backend, saved_swa_state) set_global_graph_memory_pool(graph.pool()) return graph, out diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index dbd63c2e444c..d2d7a6c17d59 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -32,6 +32,53 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +def _maybe_swap_swa_state( + draft_attn_backend: "AttentionBackend", new_pool +): + """Synchronise a backend's SWA-aware attributes with a swapped pool. + + Some attention backends (notably ``trtllm_mha``) cache + ``use_sliding_window_kv_pool`` / ``_swa_kv_pool`` at __init__ time + from ``model_runner.token_to_kv_pool``. When the FROZEN_KV_MTP + contexts swap ``token_to_kv_pool`` to the target's SWA pool, those + cached attributes go stale: the backend then treats every layer as + full-attention even though it is now reading the target's hybrid SWA + pool. For SWA-typed layers this leaks full-pool page indices into + the SWA k_cache page table and crashes the trtllm_mha sm_100a + paged-attention kernel with ``Warp Illegal Address``. + + This helper resolves the SWA-aware attributes from ``new_pool`` + (whether or not it is an SWAKVPool) and writes them back onto the + backend. Returns a tuple of the saved (use_swa, swa_kv_pool, + sliding_window_size) so the caller can restore them. + """ + from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool + + saved = ( + getattr(draft_attn_backend, "use_sliding_window_kv_pool", None), + getattr(draft_attn_backend, "_swa_kv_pool", None), + getattr(draft_attn_backend, "sliding_window_size", None), + ) + is_swa = isinstance(new_pool, SWAKVPool) + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = is_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = new_pool if is_swa else None + # sliding_window_size is per-layer in the model; the trtllm_mha + # backend caches a module-level value. Don't change it: the draft + # model's own sliding_window_size already matches the target's + # (Gemma4-Assistant inherits the same sliding window). + return saved + + +def _restore_swa_state(draft_attn_backend: "AttentionBackend", saved): + use_swa, swa_kv_pool, sliding_window_size = saved + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = use_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = swa_kv_pool + + @contextmanager def frozen_kv_target_view( forward_batch: ForwardBatch, @@ -56,11 +103,15 @@ def frozen_kv_target_view( forward_batch.spec_info = None saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: forward_batch.spec_info = saved_spec_info draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) @contextmanager @@ -84,10 +135,14 @@ def target_kv_pool_view( ) saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) def set_frozen_kv_positions(forward_batch: ForwardBatch, topk: int) -> None: