Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
e07a7ac
Fix two assistant-MTP regressions surfaced by frozen-KV E4B smoke test
pyc96 May 22, 2026
2c94273
Merge branch 'main' into pyc/fix/gemma4-assistant-mtp-regressions
pyc96 May 22, 2026
2a516ce
Fix Gemma-4 BF16 MoE backend auto-select on SM100
pyc96 May 22, 2026
155cc4a
Merge branch 'main' into pyc/fix/gemma4-assistant-mtp-regressions
pyc96 May 22, 2026
0ea98c6
perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)
pyc96 May 22, 2026
b12237d
perf(gemma4): default swa_full_tokens_ratio=0.15 for the 25:5 SWA:ful…
pyc96 May 22, 2026
7e925d8
debug: trtllm_mha page_table bounds-check (SGLANG_TRTLLM_MHA_DEBUG=1)
pyc96 May 23, 2026
aa45f66
debug: SWA allocator OOB instrumentation (companion to bounds-check t…
pyc96 May 23, 2026
5547e41
fix(trtllm_mha): clamp page_table to k_cache page range to prevent SW…
pyc96 May 23, 2026
a0a8f1e
fix(trtllm_mha + FROZEN_KV_MTP): swap SWA-aware state with target pool
pyc96 May 23, 2026
3a60af0
Revert "fix(trtllm_mha): clamp page_table to k_cache page range"
pyc96 May 23, 2026
b0e87f3
fix(gemma4): only apply swa_full_tokens_ratio=0.15 to MoE variants
pyc96 May 23, 2026
f6513a4
perf(gemma4): close triton-attn TPOT gap (fused PLE tail + piecewise …
May 24, 2026
232415c
perf(gemma4): port vLLM Inductor's triple-rmsnorm fusion (post-attn p…
May 25, 2026
563ac65
perf(gemma4 MM): batch vision encoder and embed_vision calls
May 23, 2026
a0225a1
perf(gemma4): YOCO fast-prefill for E2B/E4B (port of vllm#22628 + #38…
May 24, 2026
88faaff
fix(gemma4): FROZEN_KV_MTP zero-accept crash + scheduler merge/filter…
May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion python/sglang/srt/arg_groups/speculative_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,24 @@ def _resolve_speculative_algorithm_alias(

if speculative_algorithm == "NEXTN" or speculative_algorithm == "EAGLE":
if is_gemma4_draft:
# Opt-out: set SGLANG_GEMMA4_FORCE_EAGLE=1 to keep NEXTN/EAGLE
# on the upstream EAGLE worker (and skip the FROZEN_KV_MTP
# promotion). Useful for A/B testing when FROZEN_KV_MTP's
# FrozenKVMTPWorker overhead exceeds its spec-decode gain on
# a given workload (see runs/20260525_mtp_comparison/).
import os

if os.environ.get("SGLANG_GEMMA4_FORCE_EAGLE", "0") == "1":
logger.info(
"SGLANG_GEMMA4_FORCE_EAGLE=1: keeping "
f"--speculative-algorithm {speculative_algorithm} on the "
"upstream EAGLE worker (skipping FROZEN_KV_MTP promotion)."
)
return "EAGLE"
logger.info(
"Detected Gemma4AssistantForCausalLM draft; "
f"promoting --speculative-algorithm {speculative_algorithm} to FROZEN_KV_MTP."
f"promoting --speculative-algorithm {speculative_algorithm} to FROZEN_KV_MTP. "
"Set SGLANG_GEMMA4_FORCE_EAGLE=1 to opt out."
)
return "FROZEN_KV_MTP"
return "EAGLE"
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ class Envs:
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
# Debug flag: bounds-check trtllm_mha page_table before the kernel call.
# Catches OOB SWA page indices that otherwise surface as CUDA illegal
# address errors deep inside the attention kernel. Set to 1 to enable.
SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
84 changes: 82 additions & 2 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
self._swa_kv_pool: Optional[SWAKVPool] = (
kv_pool if self.use_sliding_window_kv_pool else None
)
# The model has SWA semantics whenever ANY of its layers carries a
# sliding window size > 0. Use ``model_runner.sliding_window_size``
# as the canonical signal: model_runner sets it from the model's
# ``get_attention_sliding_window_size`` or ``config.sliding_window_size``.
# We need this signal *separately* from the SWA-pool detection
# because the FROZEN_KV_MTP draft backend's pool starts non-SWA and
# gets swapped to the target's SWA pool at forward time; we must
# have allocated SWA-page-table buffers BEFORE that swap.
_model_sw = getattr(model_runner, "sliding_window_size", None)
self.model_has_sliding_window: bool = (
_model_sw is not None and _model_sw > 0
)

# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
Expand Down Expand Up @@ -161,8 +173,20 @@ def _maybe_translate_swa(
def _alloc_swa_page_table(
self, max_bs: int, max_num_pages: int
) -> Optional[torch.Tensor]:
"""Allocate a SWA page_table buffer, or return None for non-SWA models."""
if not self.use_sliding_window_kv_pool:
"""Allocate a SWA page_table buffer, or return None for non-SWA models.

Note: we eagerly allocate when ``self.model_has_sliding_window`` is
true even if ``self.use_sliding_window_kv_pool`` is currently
``False`` at init time. This is needed for the FROZEN_KV_MTP draft
backend: at init it has no SWA pool, but at forward time
``target_kv_pool_view`` swaps in the target's SWA pool (see
``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the
pre-allocated buffer the draft backend would build full-pool
page_table values for SWA layers and crash the trtllm_mha
``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with
``Warp Illegal Address``.
"""
if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window:
return None
return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)

Expand Down Expand Up @@ -752,6 +776,62 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
# only do this when stream capture is not active.
if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and (
not torch.cuda.is_current_stream_capturing()
):
import os

import torch as _t

cs = self.forward_metadata.cache_seqlens_int32
kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim)
num_pages_in_cache = int(kc_shape[0])
# 1) max-value check
pt_max = int(page_table.max().item())
pt_min = int(page_table.min().item())
if pt_max >= num_pages_in_cache or pt_min < 0:
# Pre-emptively dump and abort before the kernel reads OOB.
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
ts = int(_t.cuda.current_stream().cuda_stream)
fn = (
f"{dump_dir}/page_table_oob_layer{layer.layer_id}_"
f"stream{ts}_{int(_t.cuda.device_count())}.pt"
)
_t.save(
{
"page_table": page_table.detach().cpu(),
"cache_seqlens_int32": cs.detach().cpu(),
"k_cache_shape": list(kc_shape),
"num_pages_in_cache": num_pages_in_cache,
"page_size": self.page_size,
"sliding_window": layer.sliding_window_size,
"layer_id": layer.layer_id,
"forward_mode": str(forward_batch.forward_mode),
"is_swa_layer": (
self._swa_kv_pool.layers_mapping[layer.layer_id][1]
if self.use_sliding_window_kv_pool
else False
),
},
fn,
)
msg = (
f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} "
f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): "
f"page_table.max={pt_max} page_table.min={pt_min} "
f"num_pages_in_cache={num_pages_in_cache}. "
f"Dumped to {fn}"
)
logger.error(msg)
raise RuntimeError(msg)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
Expand Down
Loading
Loading