Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
self._swa_kv_pool: Optional[SWAKVPool] = (
kv_pool if self.use_sliding_window_kv_pool else None
)
# The model has SWA semantics whenever ANY of its layers carries a
# sliding window size > 0. Use ``model_runner.sliding_window_size``
# as the canonical signal: model_runner sets it from the model's
# ``get_attention_sliding_window_size`` or ``config.sliding_window_size``.
# We need this signal *separately* from the SWA-pool detection
# because the FROZEN_KV_MTP draft backend's pool starts non-SWA and
# gets swapped to the target's SWA pool at forward time; we must
# have allocated SWA-page-table buffers BEFORE that swap.
_model_sw = getattr(model_runner, "sliding_window_size", None)
self.model_has_sliding_window: bool = (
_model_sw is not None and _model_sw > 0
)

# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
Expand Down Expand Up @@ -161,8 +173,20 @@ def _maybe_translate_swa(
def _alloc_swa_page_table(
self, max_bs: int, max_num_pages: int
) -> Optional[torch.Tensor]:
"""Allocate a SWA page_table buffer, or return None for non-SWA models."""
if not self.use_sliding_window_kv_pool:
"""Allocate a SWA page_table buffer, or return None for non-SWA models.

Note: we eagerly allocate when ``self.model_has_sliding_window`` is
true even if ``self.use_sliding_window_kv_pool`` is currently
``False`` at init time. This is needed for the FROZEN_KV_MTP draft
backend: at init it has no SWA pool, but at forward time
``target_kv_pool_view`` swaps in the target's SWA pool (see
``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the
pre-allocated buffer the draft backend would build full-pool
page_table values for SWA layers and crash the trtllm_mha
``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with
``Warp Illegal Address``.
"""
if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window:
return None
return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)

Expand Down
14 changes: 13 additions & 1 deletion python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,21 @@ def run_once():
# Swap the draft backend's token_to_kv_pool to the frozen target pool
# for the capture; the single backend-attr swap is seen by both
# ``get_token_to_kv_pool()`` (via ``get_attn_backend()``) and the
# backend's own reads.
# backend's own reads. Also swap SWA-aware backend state so
# SWA-aware backends (notably trtllm_mha) build SWA-aware metadata
# against the target's SWA pool. See
# ``frozen_kv_mtp_utils._maybe_swap_swa_state``.
from sglang.srt.speculative.frozen_kv_mtp_utils import (
_maybe_swap_swa_state,
_restore_swa_state,
)

target_pool = self.frozen_kv_mtp_worker.kv_context.target_token_to_kv_pool
saved_backend_pool = self.draft_attn_backend.token_to_kv_pool
self.draft_attn_backend.token_to_kv_pool = target_pool
saved_swa_state = _maybe_swap_swa_state(
self.draft_attn_backend, target_pool
)
try:
with forward_context(ForwardContext(attn_backend=self.draft_attn_backend)):
self.frozen_kv_mtp_worker._init_frozen_kv_metadata_capture_cuda_graph(
Expand All @@ -319,6 +330,7 @@ def run_once():
)
finally:
self.draft_attn_backend.token_to_kv_pool = saved_backend_pool
_restore_swa_state(self.draft_attn_backend, saved_swa_state)
set_global_graph_memory_pool(graph.pool())
return graph, out

Expand Down
55 changes: 55 additions & 0 deletions python/sglang/srt/speculative/frozen_kv_mtp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,53 @@
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend


def _maybe_swap_swa_state(
draft_attn_backend: "AttentionBackend", new_pool
):
"""Synchronise a backend's SWA-aware attributes with a swapped pool.

Some attention backends (notably ``trtllm_mha``) cache
``use_sliding_window_kv_pool`` / ``_swa_kv_pool`` at __init__ time
from ``model_runner.token_to_kv_pool``. When the FROZEN_KV_MTP
contexts swap ``token_to_kv_pool`` to the target's SWA pool, those
cached attributes go stale: the backend then treats every layer as
full-attention even though it is now reading the target's hybrid SWA
pool. For SWA-typed layers this leaks full-pool page indices into
the SWA k_cache page table and crashes the trtllm_mha sm_100a
paged-attention kernel with ``Warp Illegal Address``.

This helper resolves the SWA-aware attributes from ``new_pool``
(whether or not it is an SWAKVPool) and writes them back onto the
backend. Returns a tuple of the saved (use_swa, swa_kv_pool,
sliding_window_size) so the caller can restore them.
"""
from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool

saved = (
getattr(draft_attn_backend, "use_sliding_window_kv_pool", None),
getattr(draft_attn_backend, "_swa_kv_pool", None),
getattr(draft_attn_backend, "sliding_window_size", None),
)
is_swa = isinstance(new_pool, SWAKVPool)
if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"):
draft_attn_backend.use_sliding_window_kv_pool = is_swa
if hasattr(draft_attn_backend, "_swa_kv_pool"):
draft_attn_backend._swa_kv_pool = new_pool if is_swa else None
# sliding_window_size is per-layer in the model; the trtllm_mha
# backend caches a module-level value. Don't change it: the draft
# model's own sliding_window_size already matches the target's
# (Gemma4-Assistant inherits the same sliding window).
return saved


def _restore_swa_state(draft_attn_backend: "AttentionBackend", saved):
use_swa, swa_kv_pool, sliding_window_size = saved
if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"):
draft_attn_backend.use_sliding_window_kv_pool = use_swa
if hasattr(draft_attn_backend, "_swa_kv_pool"):
draft_attn_backend._swa_kv_pool = swa_kv_pool


@contextmanager
def frozen_kv_target_view(
forward_batch: ForwardBatch,
Expand All @@ -56,11 +103,15 @@ def frozen_kv_target_view(
forward_batch.spec_info = None
saved_backend_pool = draft_attn_backend.token_to_kv_pool
draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool
saved_swa_state = _maybe_swap_swa_state(
draft_attn_backend, kv_context.target_token_to_kv_pool
)
try:
yield
finally:
forward_batch.spec_info = saved_spec_info
draft_attn_backend.token_to_kv_pool = saved_backend_pool
_restore_swa_state(draft_attn_backend, saved_swa_state)


@contextmanager
Expand All @@ -84,10 +135,14 @@ def target_kv_pool_view(
)
saved_backend_pool = draft_attn_backend.token_to_kv_pool
draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool
saved_swa_state = _maybe_swap_swa_state(
draft_attn_backend, kv_context.target_token_to_kv_pool
)
try:
yield
finally:
draft_attn_backend.token_to_kv_pool = saved_backend_pool
_restore_swa_state(draft_attn_backend, saved_swa_state)


def set_frozen_kv_positions(forward_batch: ForwardBatch, topk: int) -> None:
Expand Down
Loading