Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e07a7ac
Fix two assistant-MTP regressions surfaced by frozen-KV E4B smoke test
pyc96 May 22, 2026
2c94273
Merge branch 'main' into pyc/fix/gemma4-assistant-mtp-regressions
pyc96 May 22, 2026
2a516ce
Fix Gemma-4 BF16 MoE backend auto-select on SM100
pyc96 May 22, 2026
155cc4a
Merge branch 'main' into pyc/fix/gemma4-assistant-mtp-regressions
pyc96 May 22, 2026
0ea98c6
perf(gemma4 MTP): single-launch fused router (topk + softmax + scale)
pyc96 May 22, 2026
b12237d
perf(gemma4): default swa_full_tokens_ratio=0.15 for the 25:5 SWA:ful…
pyc96 May 22, 2026
7e925d8
debug: trtllm_mha page_table bounds-check (SGLANG_TRTLLM_MHA_DEBUG=1)
pyc96 May 23, 2026
aa45f66
debug: SWA allocator OOB instrumentation (companion to bounds-check t…
pyc96 May 23, 2026
5547e41
fix(trtllm_mha): clamp page_table to k_cache page range to prevent SW…
pyc96 May 23, 2026
a0a8f1e
fix(trtllm_mha + FROZEN_KV_MTP): swap SWA-aware state with target pool
pyc96 May 23, 2026
3a60af0
Revert "fix(trtllm_mha): clamp page_table to k_cache page range"
pyc96 May 23, 2026
b0e87f3
fix(gemma4): only apply swa_full_tokens_ratio=0.15 to MoE variants
pyc96 May 23, 2026
f6513a4
perf(gemma4): close triton-attn TPOT gap (fused PLE tail + piecewise …
May 24, 2026
232415c
perf(gemma4): port vLLM Inductor's triple-rmsnorm fusion (post-attn p…
May 25, 2026
563ac65
perf(gemma4 MM): batch vision encoder and embed_vision calls
May 23, 2026
a0225a1
perf(gemma4): YOCO fast-prefill for E2B/E4B (port of vllm#22628 + #38…
May 24, 2026
38e66f0
JIT custom_all_reduce/tp_qknorm: use reinterpret_cast not std::bit_cast
pyc96 May 21, 2026
192bdea
perf(gemma4 MTP H100): tune Triton extend tile for Lq=256 / sm_90
pyc96 May 23, 2026
7ac3895
perf(gemma4): add Gemma4ForConditionalGeneration to mm_disabled_models
pyc96 May 24, 2026
0c98fb5
perf(gemma4 31b): cap chunked_prefill_size=4096 + mem_fraction floor=…
pyc96 May 25, 2026
c7aaac2
feat(gemma4 ARF): add wrapper + Gemma3MLP skip_all_reduce + auto-enab…
pyc96 May 25, 2026
9094a0a
feat(gemma4 ARF): wire FlashInfer AR+RMSNorm into post-attention site…
pyc96 May 25, 2026
f5c8815
fix(gemma4 PCG): explicit Gemma-4 arch gate (mm_disabled_models remov…
pyc96 May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ struct CustomAllReducePull : public CustomAllReduceBase {
RuntimeCheck(shot == 1 || shot == 2, "Invalid shot count: ", shot);
RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
RuntimeCheck(is_type<DType>(input.dtype()), "Input dtype mismatch");
RuntimeCheck(std::bit_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
RuntimeCheck(m_pull_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(static_cast<int64_t>(num_items) == num_items_int64, "Number of items exceeds 4G limit");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ struct CustomAllReducePush : public CustomAllReduceBase {
RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch");
RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
RuntimeCheck(is_type<DType>(input.dtype()), "Input dtype mismatch");
RuntimeCheck(std::bit_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(shot == 1, "Push all-reduce only supports 1-shot, got: ", shot);
RuntimeCheck(static_cast<int64_t>(num_items) == num_items_int64, "Number of items exceeds 4G limit");
Expand Down
11 changes: 7 additions & 4 deletions python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,13 @@ struct FusedParallelQKNormAcrossHead : public CustomAllReduceBase {
const auto needed_buffer_bytes = static_cast<int64_t>(num_tokens) * 2 * sizeof(float);
RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch");
RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(std::bit_cast<intptr_t>(params.q_ptr) % 16 == 0, "q pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.k_ptr) % 16 == 0, "k pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(params.q_ptr) % 16 == 0, "q pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.k_ptr) % 16 == 0, "k pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned");
RuntimeCheck(needed_buffer_bytes <= m_push_buffer_bytes, "Push buffer is too small");

LaunchKernel(num_blocks, num_threads, device) //
Expand Down
5 changes: 2 additions & 3 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def __init__(
if enable_multimodal is None:
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Gemma4ForConditionalGeneration",
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
Expand Down Expand Up @@ -914,7 +915,6 @@ def _parse_quant_hf_config(self):
if not is_local:
# Conditional import based on SGLANG_USE_MODELSCOPE environment variable
if envs.SGLANG_USE_MODELSCOPE.get():

from modelscope import HubApi, model_file_download

hf_api = HubApi()
Expand Down Expand Up @@ -1649,8 +1649,7 @@ def compute_mla_mscale_scaling(rope_scaling: dict, base_scaling: float) -> float
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
if "factor" not in rope_scaling:
logger.warning(
"rope_scaling missing 'factor', defaulting to 1.0. "
"Check model accuracy.",
"rope_scaling missing 'factor', defaulting to 1.0. Check model accuracy.",
)
scaling_factor = rope_scaling.get("factor", 1.0)
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ class Envs:
# None = standard attention. See https://arxiv.org/abs/2512.12087
SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None)
# Debug flag: bounds-check trtllm_mha page_table before the kernel call.
# Catches OOB SWA page indices that otherwise surface as CUDA illegal
# address errors deep inside the attention kernel. Set to 1 to enable.
SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False)
# TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion
# transport issue on GB200/GB300 platforms is fixed and verified resolved.
SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None)
Expand Down
98 changes: 82 additions & 16 deletions python/sglang/srt/layers/attention/triton_ops/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,34 @@
_is_hip = is_hip()


def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
def _get_block_sizes_for_extend_attention(
Lq: int,
Lv: int,
*,
batch_size: int = 0,
max_len_extend: int = 0,
):
"""
Get block sizes and configuration for extend attention kernels.

Args:
Lq: Query head dimension
Lv: Value head dimension
batch_size: Number of sequences in the batch (kw-only). Used by the
H100 (sm_90, Lq<=256) heuristic to pick a smaller tile for
high-bs spec-decode verify shapes where the default (128, 64, w8)
wastes work per program. ``0`` (default) is treated as "unknown"
and preserves the legacy tile.
max_len_extend: Maximum extend length per sequence in the batch
(kw-only). Used together with batch_size to distinguish
high-bs *verify* shapes (small max_len_extend, e.g. 4 for
num_draft_tokens=4) from high-bs *chunked prefill* shapes
(larger max_len_extend). ``0`` (default) is treated as
"unknown" and falls back to the long-extend tile.

Returns:
tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps,
num_stages)
"""
# Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension
if Lq == 576:
Expand All @@ -59,6 +77,8 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):

BLOCK_DV = triton.next_power_of_2(Lv)

num_stages = 1

# Determine BLOCK_M, BLOCK_N, and num_warps based on hardware
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
Expand All @@ -82,8 +102,48 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
BLOCK_M, BLOCK_N = (16, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 9:
# Hopper architecture (H100, etc.)
if Lq <= 256:
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 64)
elif Lq <= 256:
# H100 / sm_90, 128 < Lq <= 256. Tuned for head_dim == 256 (e.g.
# Gemma-4-26B-A4B-IT). The legacy (128, 64, w8, s1)
# tile is severely oversized for both the long-extend
# initial-prefill shape (bs=1, ext=8k) and the high-bs
# MTP verify shape (bs=32, ext=4, prefix>=1k) — see
# the microbench in the H100 SOTA run artifact dir
# ``patches/bench_extend_attn_gemma4_26b.py`` (and the
# ``patches/extend_attn_microbench_*.log`` artifacts).
# Microbench winners on bf16, num_q_heads=8, num_kv_heads=4:
# prefill long ext=8192 bs=1 2657us -> 1908us -28% (32,64,w4,s2)
# prefill chat ext=1000 bs=1 128us -> 56us -56% (32,64,w4,s2)
# verify chat ext=4 pf=1000 bs=32 616us -> 144us -77% (16,64,w4,s2)
# verify summ ext=4 pf=8000 bs=32 1076us-> 191us -82% (16,64,w4,s2)
# verify burst ext=4 pf=64 bs=32 94us -> 22us -77% (32,32,w4,s2)
# chunked-prefill ext=512 bs=8 136us -> 92us -32% (32,64,w4,s2)
# chunked-prefill ext=1024 bs=16 752us -> 559us -26% (32,64,w4,s2)
# The (16, 64, w4, s2) tile that dominates the high-bs
# *verify* path (max_len_extend = num_draft_tokens, very
# small) regresses the high-bs *chunked-prefill* path
# (max_len_extend = chunked_prefill_size_per_seq, larger)
# by ~30 %. Gate on BOTH batch_size and max_len_extend
# so chunked prefill keeps (32, 64, w4, s2).
if batch_size >= 8 and 0 < max_len_extend <= 16:
BLOCK_M, BLOCK_N = (16, 64)
num_warps = 4
num_stages = 2
else:
BLOCK_M, BLOCK_N = (32, 64)
num_warps = 4
num_stages = 2
return (
BLOCK_DMODEL,
BLOCK_DPE,
BLOCK_DV,
BLOCK_M,
BLOCK_N,
num_warps,
num_stages,
)
else:
BLOCK_M, BLOCK_N = (32, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
Expand All @@ -109,7 +169,7 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):

num_warps = 4 if Lq <= 64 else 8

return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps
return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages


@triton.jit
Expand Down Expand Up @@ -591,23 +651,26 @@ def extend_attention_fwd(
v_extend.shape[-1],
)

# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)

sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]

# Get block sizes and configuration. Pass batch_size + max_len_extend so
# the H100 Lq<=256 heuristic can pick the spec-decode-verify tile
# (only when extend is tiny) vs the chunked-prefill / long-extend tile.
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = (
_get_block_sizes_for_extend_attention(
Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend
)
)

USE_CUSTOM_MASK = custom_mask is not None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

HAS_SINK = sinks is not None

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1

extra_kargs = {}
if _is_hip:
Expand Down Expand Up @@ -1001,15 +1064,19 @@ def extend_attention_fwd_unified(
"""
Lq, Lv = q.shape[-1], v_buffer.shape[-1]

# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)

sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]

# Get block sizes and configuration. Pass batch_size + max_len_extend so
# the H100 Lq<=256 heuristic can pick the spec-decode-verify tile
# (only when extend is tiny) vs the chunked-prefill / long-extend tile.
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = (
_get_block_sizes_for_extend_attention(
Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend
)
)

USE_CUSTOM_MASK = custom_mask is not None
HAS_SINK = sinks is not None

Expand All @@ -1020,7 +1087,6 @@ def extend_attention_fwd_unified(
window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device)

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1

extra_kargs = {}
if _is_hip:
Expand Down
84 changes: 82 additions & 2 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def __init__(
self._swa_kv_pool: Optional[SWAKVPool] = (
kv_pool if self.use_sliding_window_kv_pool else None
)
# The model has SWA semantics whenever ANY of its layers carries a
# sliding window size > 0. Use ``model_runner.sliding_window_size``
# as the canonical signal: model_runner sets it from the model's
# ``get_attention_sliding_window_size`` or ``config.sliding_window_size``.
# We need this signal *separately* from the SWA-pool detection
# because the FROZEN_KV_MTP draft backend's pool starts non-SWA and
# gets swapped to the target's SWA pool at forward time; we must
# have allocated SWA-page-table buffers BEFORE that swap.
_model_sw = getattr(model_runner, "sliding_window_size", None)
self.model_has_sliding_window: bool = (
_model_sw is not None and _model_sw > 0
)

# Forward metadata
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
Expand Down Expand Up @@ -161,8 +173,20 @@ def _maybe_translate_swa(
def _alloc_swa_page_table(
self, max_bs: int, max_num_pages: int
) -> Optional[torch.Tensor]:
"""Allocate a SWA page_table buffer, or return None for non-SWA models."""
if not self.use_sliding_window_kv_pool:
"""Allocate a SWA page_table buffer, or return None for non-SWA models.

Note: we eagerly allocate when ``self.model_has_sliding_window`` is
true even if ``self.use_sliding_window_kv_pool`` is currently
``False`` at init time. This is needed for the FROZEN_KV_MTP draft
backend: at init it has no SWA pool, but at forward time
``target_kv_pool_view`` swaps in the target's SWA pool (see
``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the
pre-allocated buffer the draft backend would build full-pool
page_table values for SWA layers and crash the trtllm_mha
``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with
``Warp Illegal Address``.
"""
if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window:
return None
return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device)

Expand Down Expand Up @@ -752,6 +776,62 @@ def forward_decode(

page_table = self._get_layer_page_table(layer, forward_batch)

# DEBUG: bounds-check page_table before trtllm kernel. Looking
# for OOB SWA page indices that explain the cudaErrorIllegalAddress.
# IMPORTANT: .item() syncs and breaks cuda-graph capture, so we
# only do this when stream capture is not active.
if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and (
not torch.cuda.is_current_stream_capturing()
):
import os

import torch as _t

cs = self.forward_metadata.cache_seqlens_int32
kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim)
num_pages_in_cache = int(kc_shape[0])
# Range check: every page index must fall in [0, num_pages_in_cache).
pt_max = int(page_table.max().item())
pt_min = int(page_table.min().item())
if pt_max >= num_pages_in_cache or pt_min < 0:
# Pre-emptively dump and abort before the kernel reads OOB.
dump_dir = os.environ.get(
"SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug"
)
os.makedirs(dump_dir, exist_ok=True)
ts = int(_t.cuda.current_stream().cuda_stream)
fn = (
f"{dump_dir}/page_table_oob_layer{layer.layer_id}_"
f"stream{ts}_{int(_t.cuda.device_count())}.pt"
)
_t.save(
{
"page_table": page_table.detach().cpu(),
"cache_seqlens_int32": cs.detach().cpu(),
"k_cache_shape": list(kc_shape),
"num_pages_in_cache": num_pages_in_cache,
"page_size": self.page_size,
"sliding_window": layer.sliding_window_size,
"layer_id": layer.layer_id,
"forward_mode": str(forward_batch.forward_mode),
"is_swa_layer": (
self._swa_kv_pool.layers_mapping[layer.layer_id][1]
if self.use_sliding_window_kv_pool
else False
),
},
fn,
)
msg = (
f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} "
f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): "
f"page_table.max={pt_max} page_table.min={pt_min} "
f"num_pages_in_cache={num_pages_in_cache}. "
f"Dumped to {fn}"
)
logger.error(msg)
raise RuntimeError(msg)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
Expand Down
Loading
Loading