Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ class Envs:
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
# Debug flag: bounds-check trtllm_mha page_table before the kernel call.
# Catches OOB SWA page indices that otherwise surface as CUDA illegal
# address errors deep inside the attention kernel. Set to 1 to enable.
SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
84 changes: 82 additions & 2 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
self._swa_kv_pool: Optional[SWAKVPool] = (
kv_pool if self.use_sliding_window_kv_pool else None
)
# The model has SWA semantics whenever ANY of its layers carries a
# sliding window size > 0. Use ``model_runner.sliding_window_size``
# as the canonical signal: model_runner sets it from the model's
# ``get_attention_sliding_window_size`` or ``config.sliding_window_size``.
# We need this signal *separately* from the SWA-pool detection
# because the FROZEN_KV_MTP draft backend's pool starts non-SWA and
# gets swapped to the target's SWA pool at forward time; we must
# have allocated SWA-page-table buffers BEFORE that swap.
_model_sw = getattr(model_runner, "sliding_window_size", None)
self.model_has_sliding_window: bool = (
_model_sw is not None and _model_sw > 0
)

# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
Expand Down Expand Up @@ -161,8 +173,20 @@ def _maybe_translate_swa(
def _alloc_swa_page_table(
    self, max_bs: int, max_num_pages: int
) -> Optional[torch.Tensor]:
    """Allocate a SWA page_table buffer, or return None for non-SWA models.

    The buffer is allocated eagerly whenever ``self.model_has_sliding_window``
    is true, even if ``self.use_sliding_window_kv_pool`` is still ``False``
    at init time. This is needed for the FROZEN_KV_MTP draft backend: at
    init it has no SWA pool, but at forward time ``target_kv_pool_view``
    swaps in the target's SWA pool (see
    ``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the
    pre-allocated buffer the draft backend would build full-pool
    page_table values for SWA layers and crash the trtllm_mha
    ``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with
    ``Warp Illegal Address``.

    Args:
        max_bs: Number of rows of the buffer (maximum batch size).
        max_num_pages: Number of columns of the buffer (maximum pages per
            request).

    Returns:
        A zero-filled ``int32`` tensor of shape ``(max_bs, max_num_pages)``
        on ``self.device``, or ``None`` when neither an SWA KV pool is in
        use nor the model declares a sliding window.
    """
    needs_swa_table = (
        self.use_sliding_window_kv_pool or self.model_has_sliding_window
    )
    if not needs_swa_table:
        return None
    return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)

Expand Down Expand Up @@ -752,6 +776,62 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
# only do this when stream capture is not active.
if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and (
not torch.cuda.is_current_stream_capturing()
):
import os

import torch as _t

cs = self.forward_metadata.cache_seqlens_int32
kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim)
num_pages_in_cache = int(kc_shape[0])
# 1) max-value check
pt_max = int(page_table.max().item())
pt_min = int(page_table.min().item())
if pt_max >= num_pages_in_cache or pt_min < 0:
# Pre-emptively dump and abort before the kernel reads OOB.
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
ts = int(_t.cuda.current_stream().cuda_stream)
fn = (
f"{dump_dir}/page_table_oob_layer{layer.layer_id}_"
f"stream{ts}_{int(_t.cuda.device_count())}.pt"
)
_t.save(
{
"page_table": page_table.detach().cpu(),
"cache_seqlens_int32": cs.detach().cpu(),
"k_cache_shape": list(kc_shape),
"num_pages_in_cache": num_pages_in_cache,
"page_size": self.page_size,
"sliding_window": layer.sliding_window_size,
"layer_id": layer.layer_id,
"forward_mode": str(forward_batch.forward_mode),
"is_swa_layer": (
self._swa_kv_pool.layers_mapping[layer.layer_id][1]
if self.use_sliding_window_kv_pool
else False
),
},
fn,
)
msg = (
f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} "
f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): "
f"page_table.max={pt_max} page_table.min={pt_min} "
f"num_pages_in_cache={num_pages_in_cache}. "
f"Dumped to {fn}"
)
logger.error(msg)
raise RuntimeError(msg)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
Expand Down
172 changes: 172 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into
a single kernel pass to reduce kernel launch overhead.

Also provides a single-launch fused router for Gemma4 MoE (PR #26120 in
pyc96/sglang fork): replaces the per-layer ``torch.topk`` ->
``softmax`` -> ``per_expert_scale[ids]`` -> ``mul`` -> ``cast`` chain in
``Gemma4MoE.routing_function`` with one Triton kernel.

The reference design comes from vLLM PR #39083
(``_gemma4_routing_kernel`` / ``gemma4_fused_routing_kernel_triton``),
which is apache-2.0. Our kernel is rewritten in SGLang style and uses
the identity ``softmax(all)[topk] / sum(softmax(all)[topk]) =
softmax(topk_logits)`` already exploited by SGLang's torch routing
function, so the math is bitwise-comparable to the prior fp32 path.
"""

from typing import Optional
Expand Down Expand Up @@ -283,3 +295,163 @@ def gemma_dual_rmsnorm_residual_scalar(
BLOCK_SIZE=BLOCK_SIZE,
)
return out


# ---------------------------------------------------------------------------
# Fused Gemma4 routing kernel (one launch per layer)
# ---------------------------------------------------------------------------
#
# Equivalent to:
#
# topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
# topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
# topk_weights = topk_weights * per_expert_scale[topk_ids]
# return topk_weights.float(), topk_ids.int()
#
# but completes the entire computation in one Triton program per token.
#
# Algorithm notes:
# * Loads all E logits per token into one program; for Gemma4
# ``E = num_experts = 128`` so ``BLOCK_E = next_pow2(E) = 128`` and the
# work fits in a single warp with `num_warps=1`.
# * Computes ``softmax-of-topk`` by:
# - using an ascending ``tl.sort`` on (order-reversed logit key, expert_id)
# pairs packed into int64 — this gives a fully vectorized top-K without a
# K-step loop (ties go to the lower expert id) and is intended to match
# the result of ``torch.topk``.
# - taking the K largest logits via a mask on the first K sorted lanes
# - normalizing in fp32 (matches ``softmax`` default dtype)
# - multiplying by ``per_expert_scale[topk_ids]``
# * Writes ``topk_weights`` (fp32) and ``topk_ids`` (int32) in one
# pass, matching the output dtypes the SGLang MoE topk wrapper
# expects.
#
# Reference algorithm: vLLM PR #39083 ``_gemma4_routing_kernel`` (apache-2.0).
# Our independent implementation follows the same sort+mask+softmax scheme.
@triton.jit
def _gemma4_routing_kernel(
    gating_ptr,  # [T, E] router logits, any float dtype
    per_expert_scale_ptr,  # [E] per-expert scale (any float dtype)
    topk_weights_ptr,  # [T, K] fp32 out
    topk_ids_ptr,  # [T, K] int32 out
    stride_g_t,  # stride of gating in the token dim
    E: tl.constexpr,
    K: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    # One program handles one token row: it selects the K largest logits,
    # applies a softmax over just those K values, scales each weight by
    # per_expert_scale[expert_id], and writes the K weights and ids.
    #
    # Promote the program id to int64 so that the row offsets
    # (pid * stride_g_t and pid * K) cannot wrap around in 32-bit
    # arithmetic on very large batches.
    pid = tl.program_id(0).to(tl.int64)
    offs_e = tl.arange(0, BLOCK_E)
    valid = offs_e < E

    # Load logits into fp32. Padding lanes (offs_e >= E) are given -inf here
    # and are pushed to the very end of the sort order further below.
    logits = tl.load(
        gating_ptr + pid * stride_g_t + offs_e,
        mask=valid,
        other=-float("inf"),
    ).to(tl.float32)

    # Build a sortable 32-bit key from the raw float bits such that
    # *ascending* key order corresponds to *descending* logit order:
    #   * non-negative floats: invert all bits (larger value -> smaller key)
    #   * negative floats:     clear the sign bit (more negative -> larger key)
    # Keys of non-negative floats end up with bit 31 set, so once the key is
    # shifted into the high half of an int64 they become negative numbers
    # and therefore sort ahead of every negative-float key.
    MIN32 = -2147483648
    logit_bits = logits.to(tl.int32, bitcast=True)
    sign = logit_bits >> 31
    key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32)
    # Give padding lanes the largest possible key so they land after every
    # real logit in the ascending sort and never enter the first K lanes.
    key = tl.where(valid, key, 0x7FFFFFFF)
    # Pack (key, expert id) into one int64: the key occupies the high 32 bits
    # and the expert id the low 32 bits, so equal logits are ordered by
    # ascending expert id and no separate index buffer is needed.
    # NOTE(review): the tie order is intended to mirror torch.topk — confirm.
    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
    packed = (sk64 << 32) | offs_e.to(tl.int64)

    # Ascending sort of the packed keys == descending sort of the logits,
    # so the first K lanes hold the K largest logits.
    sorted_p = tl.sort(packed, descending=False)
    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)

    # Invert the key mapping to recover the original logit bit patterns.
    sign_k = all_keys >> 31
    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
    all_logits = all_bits.to(tl.float32, bitcast=True)

    # Softmax over the K largest logits only. Subtract the maximum of the
    # selected logits for numerical stability; lanes outside the top K are
    # excluded from both the max and the sum.
    top_mask = offs_e < K
    max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0)
    # exp(x) is spelled as exp2(x * log2(e)) so the kernel also works on
    # older Triton releases that lack tl.math.exp.
    raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
    raw_exp = tl.where(top_mask, raw_exp, 0.0)

    denom = tl.sum(raw_exp, axis=0)
    # Guard against a zero denominator so the division below stays finite.
    denom = tl.where(denom > 0.0, denom, 1.0)
    weights = raw_exp / denom

    # Multiply by per_expert_scale[topk_ids]. per_expert_scale may be stored
    # in any float dtype; it is cast to fp32 before the multiply.
    scales = tl.load(
        per_expert_scale_ptr + all_ids.to(tl.int64),
        mask=top_mask,
        other=1.0,
    ).to(tl.float32)
    weights = weights * scales

    # Only the first K lanes are written; outputs are laid out as dense
    # [T, K] rows.
    base_off = pid * K + offs_e
    tl.store(topk_weights_ptr + base_off, weights, mask=top_mask)
    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)


def gemma4_fused_routing(
    gating_output: torch.Tensor,
    per_expert_scale: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """One-launch Gemma4 router.

    Selects the ``topk`` largest router logits per token, applies a softmax
    over those logits and multiplies each weight by the scale of its expert,
    using a single Triton kernel launch with one program per token.

    Args:
        gating_output: [T, E] router logits in any floating dtype; will be
            cast to fp32 inside the kernel.
        per_expert_scale: [E] per-expert scale, any floating dtype. Must be
            on the same device as ``gating_output``.
        topk: number of experts to keep per token (``topk <= E``).

    Returns:
        topk_weights: [T, topk] fp32 (matches SGLang TopK contract).
        topk_ids: [T, topk] int32 (matches SGLang TopK contract).

    Raises:
        AssertionError: if the input ranks, the expert dimension, ``topk``
            or the devices are inconsistent.
    """
    assert gating_output.dim() == 2, "expected [T, E] router logits"
    assert per_expert_scale.dim() == 1
    assert per_expert_scale.shape[0] == gating_output.shape[1]
    T, E = gating_output.shape
    assert topk <= E
    # The kernel dereferences both tensors directly, so a tensor that lives
    # on a different device must be rejected before the launch.
    assert (
        per_expert_scale.device == gating_output.device
    ), "per_expert_scale must be on the same device as gating_output"

    topk_weights = torch.empty(
        (T, topk), dtype=torch.float32, device=gating_output.device
    )
    topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device)

    # Nothing to route: return the (empty) outputs before doing any copies
    # or launch-configuration work.
    if T == 0:
        return topk_weights, topk_ids

    # The kernel addresses a token row as ``row * stride(0) + expert``, so the
    # inner-most dim must be contiguous. Most call sites already pass a
    # contiguous tensor (router proj output), in which case this is a no-op.
    gating_output = gating_output.contiguous()
    per_expert_scale = per_expert_scale.contiguous()

    # tl.arange needs a power-of-two extent; lanes >= E are masked in-kernel.
    BLOCK_E = triton.next_power_of_2(E)

    _gemma4_routing_kernel[(T,)](
        gating_output,
        per_expert_scale,
        topk_weights,
        topk_ids,
        gating_output.stride(0),
        E,
        topk,
        BLOCK_E,
        num_warps=1,
    )
    return topk_weights, topk_ids
70 changes: 70 additions & 0 deletions python/sglang/srt/mem_cache/swa_memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024

# Opt-in debug instrumentation: log when the SWA allocator returns an index
# >= swa_pool_size. Backend-independent. Set ``SGLANG_TRTLLM_MHA_DEBUG=1``
# to enable.
#
# Empirical finding under Gemma-4-E4B-IT + MTP + summarisation 8 k/1 k x 80
# at SWA usage up to 1.00 (triton backend) and up to 0.85+ (trtllm_mha
# backend that crashes): this trap **never fires** under either backend, so
# the SWA allocator is NOT producing OOB indices. The trtllm_mha crash is
# downstream of the allocator -- specifically in
# ``trtllm_mha_backend.init_forward_metadata`` where
# ``metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]``
# pulls in *trailing* positions past each row's cache_seqlens whose
# req_to_token entries were never written (= 0). The translation
# ``full_to_swa_index_mapping[0]`` is the swa slot assigned to full slot 0
# at the last alloc; it can address an arbitrary swa page that may or may
# not be in-bounds. See crash_repro/TRIAGE_REPORT.md.
import os as _os

# Evaluated once at import time. The flag is on when the environment variable
# is set to "1", "true" or "yes" (case-insensitive); anything else, including
# an unset variable, leaves it off.
_DEBUG_SWA_ALLOC_OOB = _os.getenv("SGLANG_TRTLLM_MHA_DEBUG", "").lower() in {
    "1",
    "true",
    "yes",
}


class SWAKVPool(BaseSWAKVPool):
"""KV cache with separate pools for full and SWA attention layers."""
Expand Down Expand Up @@ -495,8 +519,51 @@ def alloc_extend(
else:
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices

# DEBUG: instrument SWA allocator OOB writes (independent of
# attention backend). Intended to catch a suspected off-by-one in
# alloc_extend_kernel Part 1 (last_loc + 1 + offset overflowing
# pool_size when last_loc is near the pool end). See
# crash_repro/TRIAGE_REPORT.md.
if _DEBUG_SWA_ALLOC_OOB:
self._maybe_log_swa_oob(alloc_swa_indices, "alloc_extend")

return alloc_full_indices

def _maybe_log_swa_oob(self, alloc_swa_indices: torch.Tensor, ctx: str) -> None:
    """If any swa index is >= ``self._size_swa``, log + dump.

    Debug helper: when the largest allocated SWA index reaches or exceeds
    ``self._size_swa``, the indices and a few pool parameters are saved with
    ``torch.save`` and an error is logged. The function never raises for an
    out-of-bounds index; it only reports it.

    Args:
        alloc_swa_indices: SWA indices returned by the allocator. An empty
            tensor is accepted and treated as "nothing to check".
        ctx: Short label of the call site; it is embedded in the dump file
            name and in the log message.

    Side effects:
        On an out-of-bounds index, creates the directory named by the
        ``SGLANG_TRTLLM_MHA_DEBUG_DIR`` environment variable (default
        ``/tmp/trtllm_mha_debug``) and writes a ``.pt`` dump into it.
    """
    import os

    # An empty allocation has no indices to validate, and ``.max()`` would
    # raise on a zero-element tensor.
    if alloc_swa_indices.numel() == 0:
        return

    # ``.item()`` forces a device-to-host sync; acceptable on this opt-in
    # debug path.
    max_val = int(alloc_swa_indices.max().item())
    if max_val < self._size_swa:
        return

    # Computed once and reused for both the OOB count and the first offender.
    oob_mask = alloc_swa_indices >= self._size_swa
    min_val = int(alloc_swa_indices.min().item())
    dump_dir = os.environ.get(
        "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
    )
    os.makedirs(dump_dir, exist_ok=True)
    fn = (
        f"{dump_dir}/swa_alloc_oob_{ctx}_max{max_val}_size{self._size_swa}_"
        f"{int(torch.cuda.current_stream().cuda_stream)}.pt"
    )
    torch.save(
        {
            "ctx": ctx,
            "alloc_swa_indices": alloc_swa_indices.detach().cpu(),
            "swa_pool_size": self._size_swa,
            "page_size": self.page_size,
            "swa_max_value_returned": max_val,
            "swa_min_value_returned": min_val,
            "oob_count": int(oob_mask.sum().item()),
        },
        fn,
    )
    first_oob = int(oob_mask.nonzero().flatten()[0].item())
    msg = (
        f"[SWA alloc DEBUG] OOB swa index from {ctx}: "
        f"max={max_val} swa_pool_size={self._size_swa}; "
        f"first OOB at flat-idx "
        f"{first_oob}. "
        f"Dumped to {fn}"
    )
    logger.error(msg)

def alloc_extend_swa_tail(
self,
prefix_lens: torch.Tensor,
Expand Down Expand Up @@ -590,6 +657,9 @@ def alloc_decode(
else:
self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices

if _DEBUG_SWA_ALLOC_OOB:
self._maybe_log_swa_oob(alloc_swa_indices, "alloc_decode")

return alloc_full_indices

def free(self, free_index: torch.Tensor):
Expand Down
Loading
Loading