diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 2791aeec9a8e..54ee243e7c2c 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -412,6 +412,10 @@ class Envs: # None = standard attention. See https://arxiv.org/abs/2512.12087 SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None) SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None) + # Debug flag: bounds-check trtllm_mha page_table before the kernel call. + # Catches OOB SWA page indices that otherwise surface as CUDA illegal + # address errors deep inside the attention kernel. Set to 1 to enable. + SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False) # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion # transport issue on GB200/GB300 platforms is fixed and verified resolved. SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index e68bcb95e822..869ac14b4dcb 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -133,6 +133,18 @@ def __init__( self._swa_kv_pool: Optional[SWAKVPool] = ( kv_pool if self.use_sliding_window_kv_pool else None ) + # The model has SWA semantics whenever ANY of its layers carries a + # sliding window size > 0. Use ``model_runner.sliding_window_size`` + # as the canonical signal: model_runner sets it from the model's + # ``get_attention_sliding_window_size`` or ``config.sliding_window_size``. + # We need this signal *separately* from the SWA-pool detection + # because the FROZEN_KV_MTP draft backend's pool starts non-SWA and + # gets swapped to the target's SWA pool at forward time; we must + # have allocated SWA-page-table buffers BEFORE that swap. + _model_sw = getattr(model_runner, "sliding_window_size", None) + self.model_has_sliding_window: bool = ( + _model_sw is not None and _model_sw > 0 + ) # Forward metadata self.forward_metadata: Optional[TRTLLMMHAMetadata] = None @@ -161,8 +173,20 @@ def _maybe_translate_swa( def _alloc_swa_page_table( self, max_bs: int, max_num_pages: int ) -> Optional[torch.Tensor]: - """Allocate a SWA page_table buffer, or return None for non-SWA models.""" - if not self.use_sliding_window_kv_pool: + """Allocate a SWA page_table buffer, or return None for non-SWA models. + + Note: we eagerly allocate when ``self.model_has_sliding_window`` is + true even if ``self.use_sliding_window_kv_pool`` is currently + ``False`` at init time. This is needed for the FROZEN_KV_MTP draft + backend: at init it has no SWA pool, but at forward time + ``target_kv_pool_view`` swaps in the target's SWA pool (see + ``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the + pre-allocated buffer the draft backend would build full-pool + page_table values for SWA layers and crash the trtllm_mha + ``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with + ``Warp Illegal Address``. + """ + if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window: return None return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device) @@ -752,6 +776,62 @@ def forward_decode( page_table = self._get_layer_page_table(layer, forward_batch) + # DEBUG: bounds-check page_table before trtllm kernel. Looking + # for OOB SWA page indices that explain the cudaErrorIllegalAddress. + # IMPORTANT: .item() syncs and breaks cuda-graph capture, so we + # only do this when stream capture is not active. + if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and ( + not torch.cuda.is_current_stream_capturing() + ): + import os + + import torch as _t + + cs = self.forward_metadata.cache_seqlens_int32 + kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim) + num_pages_in_cache = int(kc_shape[0]) + # 1) max-value check + pt_max = int(page_table.max().item()) + pt_min = int(page_table.min().item()) + if pt_max >= num_pages_in_cache or pt_min < 0: + # Pre-emptively dump and abort before the kernel reads OOB. + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + ts = int(_t.cuda.current_stream().cuda_stream) + fn = ( + f"{dump_dir}/page_table_oob_layer{layer.layer_id}_" + f"stream{ts}_{int(_t.cuda.device_count())}.pt" + ) + _t.save( + { + "page_table": page_table.detach().cpu(), + "cache_seqlens_int32": cs.detach().cpu(), + "k_cache_shape": list(kc_shape), + "num_pages_in_cache": num_pages_in_cache, + "page_size": self.page_size, + "sliding_window": layer.sliding_window_size, + "layer_id": layer.layer_id, + "forward_mode": str(forward_batch.forward_mode), + "is_swa_layer": ( + self._swa_kv_pool.layers_mapping[layer.layer_id][1] + if self.use_sliding_window_kv_pool + else False + ), + }, + fn, + ) + msg = ( + f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} " + f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): " + f"page_table.max={pt_max} page_table.min={pt_min} " + f"num_pages_in_cache={num_pages_in_cache}. " + f"Dumped to {fn}" + ) + logger.error(msg) + raise RuntimeError(msg) + # Call TRT-LLM kernel # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index ad6f01d9875a..cdbd443691a5 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -2,6 +2,18 @@ Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into a single kernel pass to reduce kernel launch overhead. + +Also provides a single-launch fused router for Gemma4 MoE (PR #26120 in +pyc96/sglang fork): replaces the per-layer ``torch.topk`` -> +``softmax`` -> ``per_expert_scale[ids]`` -> ``mul`` -> ``cast`` chain in +``Gemma4MoE.routing_function`` with one Triton kernel. + +The reference design comes from vLLM PR #39083 +(``_gemma4_routing_kernel`` / ``gemma4_fused_routing_kernel_triton``), +which is apache-2.0. Our kernel is rewritten in SGLang style and uses +the identity ``softmax(all)[topk] / sum(softmax(all)[topk]) = +softmax(topk_logits)`` already exploited by SGLang's torch routing +function, so the math is bitwise-comparable to the prior fp32 path. """ from typing import Optional @@ -283,3 +295,458 @@ def gemma_dual_rmsnorm_residual_scalar( BLOCK_SIZE=BLOCK_SIZE, ) return out + + +# --------------------------------------------------------------------------- +# Fused Gemma4 routing kernel (one launch per layer) +# --------------------------------------------------------------------------- +# +# Equivalent to: +# +# topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) +# topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) +# topk_weights = topk_weights * per_expert_scale[topk_ids] +# return topk_weights.float(), topk_ids.int() +# +# but completes the entire computation in one Triton program per token. +# +# Algorithm notes: +# * Loads all E logits per token into one program; for Gemma4 +# ``E = num_experts = 128`` so ``BLOCK_E = next_pow2(E) = 128`` and the +# work fits in a single warp with `num_warps=1`. +# * Computes ``softmax-of-topk`` by: +# - using ``tl.sort`` on (logit_bits_as_sortable_uint, expert_id) pairs +# packed into int64 — this gives a fully vectorized top-K without a +# K-step loop and matches the bitwise behavior of ``torch.topk``. +# - taking the largest K via a mask on the sorted-descending sequence +# - normalizing in fp32 (matches ``softmax`` default dtype) +# - multiplying by ``per_expert_scale[topk_ids]`` +# * Writes ``topk_weights`` (fp32) and ``topk_ids`` (int32) in one +# pass, matching the output dtypes the SGLang MoE topk wrapper +# expects. +# +# Reference algorithm: vLLM PR #39083 ``_gemma4_routing_kernel`` (apache-2.0). +# Our independent implementation follows the same sort+mask+softmax scheme. +@triton.jit +def _gemma4_routing_kernel( + gating_ptr, # [T, E] router logits, any float dtype + per_expert_scale_ptr, # [E] per-expert scale (any float dtype) + topk_weights_ptr, # [T, K] fp32 out + topk_ids_ptr, # [T, K] int32 out + stride_g_t, # stride of gating in the token dim + E: tl.constexpr, + K: tl.constexpr, + BLOCK_E: tl.constexpr, +): + pid = tl.program_id(0) + offs_e = tl.arange(0, BLOCK_E) + valid = offs_e < E + + # Load logits into fp32; out-of-bound lanes get -inf so they sort last. + logits = tl.load( + gating_ptr + pid * stride_g_t + offs_e, + mask=valid, + other=-float("inf"), + ).to(tl.float32) + + # Build a sortable int64 key: high 32 bits = bijective(logit_bits) so + # ascending-int sort == ascending-float sort; low 32 bits = expert id + # (kept stable for ties matching torch.topk's default behavior). This + # avoids a separate index buffer / scatter pass after the sort. + MIN32 = -2147483648 + logit_bits = logits.to(tl.int32, bitcast=True) + sign = logit_bits >> 31 + key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32) + # Force invalid lanes to the max positive key so they end up *after* the + # real logits when we sort ascending and read from the top of the + # reversed list. (descending=True would flip the order.) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + + # Sort ascending; the K smallest keys correspond to the K largest + # logits because of the bijection above. + sorted_p = tl.sort(packed, descending=False) + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + + # Invert the bijection to recover the original logit value. + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + all_logits = all_bits.to(tl.float32, bitcast=True) + + # Softmax over the K largest logits only (identity proven by SGLang's + # torch routing function comment). Subtract the max for stability; + # since the list is sorted descending by logit value, the max sits at + # index 0. + top_mask = offs_e < K + max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0) + # exp2(x * log2(e)) is what tl.math.exp expands to; spell it out so we + # can tolerate older Triton releases that lack tl.math.exp. + raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) + raw_exp = tl.where(top_mask, raw_exp, 0.0) + + denom = tl.sum(raw_exp, axis=0) + denom = tl.where(denom > 0.0, denom, 1.0) + weights = raw_exp / denom + + # Multiply by per_expert_scale[topk_ids]. per_expert_scale lives in + # any float dtype; cast to fp32 for the final write. + scales = tl.load( + per_expert_scale_ptr + all_ids.to(tl.int64), + mask=top_mask, + other=1.0, + ).to(tl.float32) + weights = weights * scales + + base_off = pid * K + offs_e + tl.store(topk_weights_ptr + base_off, weights, mask=top_mask) + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + + +def gemma4_fused_routing( + gating_output: torch.Tensor, + per_expert_scale: torch.Tensor, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """One-launch Gemma4 router. + + Args: + gating_output: [T, E] router logits in any floating dtype; will be + cast to fp32 inside the kernel. + per_expert_scale: [E] per-expert scale, any floating dtype. + topk: number of experts to keep per token. + + Returns: + topk_weights: [T, topk] fp32 (matches SGLang TopK contract). + topk_ids: [T, topk] int32 (matches SGLang TopK contract). + """ + assert gating_output.dim() == 2, "expected [T, E] router logits" + assert per_expert_scale.dim() == 1 + assert per_expert_scale.shape[0] == gating_output.shape[1] + T, E = gating_output.shape + assert topk <= E + + # The kernel reads the token row with stride_g_t; force the inner-most + # dim to be contiguous so the masked load is coalesced. Most call + # sites already pass a contiguous tensor (router proj output); contiguous + # is cheap. + gating_output = gating_output.contiguous() + per_expert_scale = per_expert_scale.contiguous() + + BLOCK_E = triton.next_power_of_2(E) + topk_weights = torch.empty( + (T, topk), dtype=torch.float32, device=gating_output.device + ) + topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device) + + if T == 0: + return topk_weights, topk_ids + + _gemma4_routing_kernel[(T,)]( + gating_output, + per_expert_scale, + topk_weights, + topk_ids, + gating_output.stride(0), + E, + topk, + BLOCK_E, + num_warps=1, + ) + return topk_weights, topk_ids + + +# --------------------------------------------------------------------------- +# Fused ops for the Per-Layer-Embedding (PLE) tail of Gemma4 E2B / E4B. +# +# The slow path in Gemma4DecoderLayer.forward (the PLE branch, taken when +# `has_ple=True`) used to issue 7 separate kernels at the end of every layer +# (post_ff_norm; add residual; gate gelu; mul ple; project; norm; add+mul). +# Two of those (the gate and projection GEMMs) are unavoidable, but the +# remaining 5 are pointwise across the per-token dim and can be collapsed +# into 3 Triton launches: +# +# `gemma_rmsnorm_add` : out = rmsnorm(x, w) + r +# `gemma_gelu_tanh_mul` : out = gelu_tanh(gate) * per_layer_input +# `gemma_rmsnorm_residual_scalar` (already defined above) for the tail +# +# This saves ~4 kernel launches per layer * num_layers per decode step. +# --------------------------------------------------------------------------- + + +@triton.jit +def _gemma_rmsnorm_add_kernel( + X_ptr, + W_ptr, + Residual_ptr, + Out_ptr, + stride_x, + stride_r, + stride_o, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + """Fused kernel: out = rmsnorm(x, w) + residual. + + Identical to `_gemma_rmsnorm_residual_kernel` with HAS_SCALAR=False. + Hoisted into its own kernel so the caller doesn't pay for the + `tl.load(Scalar_ptr)` of a unit scalar. + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x = tl.load(X_ptr + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32) + r = tl.load(Residual_ptr + row * stride_r + cols, mask=mask, other=0.0).to( + tl.float32 + ) + + var = tl.sum(x * x, axis=0) / N + out = x * tl.rsqrt(var + eps) * w + r + tl.store(Out_ptr + row * stride_o + cols, out.to(x.dtype), mask=mask) + + +def gemma_rmsnorm_add( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """Fused (rmsnorm(x, w) + residual) — no scalar multiply.""" + assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" + M, N = x.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(x) + + _gemma_rmsnorm_add_kernel[(M,)]( + x, + weight, + residual, + out, + x.stride(0), + residual.stride(0), + out.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out + + +@triton.jit +def _gemma_gelu_tanh_mul_kernel( + Gate_ptr, + Ple_ptr, + Out_ptr, + stride_g, + stride_p, + stride_o, + N, + BLOCK_SIZE: tl.constexpr, +): + """Fused kernel: out = gelu_tanh(gate) * per_layer_input.""" + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + gate = tl.load(Gate_ptr + row * stride_g + cols, mask=mask, other=0.0).to( + tl.float32 + ) + ple = tl.load(Ple_ptr + row * stride_p + cols, mask=mask, other=0.0).to(tl.float32) + + # GeLU with tanh approximation: + # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + SQRT_2_OVER_PI = 0.7978845608028654 # sqrt(2 / pi) + inner = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + gelu = 0.5 * gate * (1.0 + tl.extra.libdevice.tanh(inner)) + + out = gelu * ple + tl.store(Out_ptr + row * stride_o + cols, out.to(gate.dtype), mask=mask) + + +def gemma_gelu_tanh_mul( + gate: torch.Tensor, + per_layer_input: torch.Tensor, +) -> torch.Tensor: + """Fused (gelu_tanh(gate) * per_layer_input) — pointwise.""" + assert gate.dim() == 2 and gate.stride(-1) == 1, "Expected contiguous 2D gate" + assert ( + per_layer_input.dim() == 2 and per_layer_input.stride(-1) == 1 + ), "Expected contiguous 2D per_layer_input" + assert gate.shape == per_layer_input.shape, "gate / ple must match" + M, N = gate.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(gate) + + _gemma_gelu_tanh_mul_kernel[(M,)]( + gate, + per_layer_input, + out, + gate.stride(0), + per_layer_input.stride(0), + out.stride(0), + N, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out + + +# --------------------------------------------------------------------------- +# Triple-RMSNorm-with-shared-residual kernel (the MoE-branch pre-MLP block). +# +# Ports vLLM Inductor's ``triton_red_fused_add_moe_forward_mul_rms_norm_0`` +# (captured from a torch.compile/Inductor run on Gemma-4-26B-A4B-IT). The +# pattern Inductor discovered: +# +# 1) post_attn_residual = rmsnorm(attn_out, w_post_attn) + residual_before +# 2) dense_ff_in = rmsnorm(post_attn_residual, w_pre_ff) +# 3) router_in = rmsnorm(post_attn_residual, ones) * router_scale +# 4) moe_in = rmsnorm(post_attn_residual, w_pre_ff_2) +# +# Steps 2, 3 and 4 share the SAME ``rsqrt(variance(post_attn_residual))``; +# Inductor reuses the reduction across all three outputs. Doing the same +# in a hand-rolled Triton kernel lets us emit one launch instead of 3-4 +# launches (post_attn_rmsnorm; pre_ff_rmsnorm_with_add; router_norm; +# pre_ff_2_rmsnorm) without depending on torch.compile. +# +# The kernel applies the classic 3-pass-reduction layout the Inductor +# kernel uses: +# pass 1: variance(attn_out) -> rsqrt for the first rmsnorm +# pass 2: variance(rmsnorm(attn_out)+res) -> rsqrt shared by 3 outputs +# pass 3: produce the 3 scaled outputs and the updated residual +# +# Pre-condition: with_scale=False for the router norm (true for Gemma4 +# Gemma4Router). ``router_scale_per_dim`` MUST already be folded with +# the root_size (i.e. callers pass router._fused_scale, which is +# scale * hidden_size^{-0.5}). +# --------------------------------------------------------------------------- + + +@triton.jit +def _gemma_post_attn_triple_rmsnorm_kernel( + Attn_ptr, # in_ptr0 : [bs, H] bf16 + PostAttnW_ptr, # in_ptr1 : [H] bf16 - post_attention_layernorm weight + Residual_ptr, # in_ptr2 : [bs, H] bf16 - pre-attention residual (input_layernorm input) + RouterScale_ptr, # in_ptr3 : [H] bf16 - router._fused_scale (= scale * root_size) + PreFFW_ptr, # in_ptr4 : [H] bf16 - pre_feedforward_layernorm weight + PreFF2W_ptr, # in_ptr5 : [H] bf16 - pre_feedforward_layernorm_2 weight (MoE) + PostAttnResOut_ptr, # out_ptr0: [bs, H] bf16 - updated residual (= rmsnorm(attn_out)+res) + RouterIn_ptr, # out_ptr1: [bs, H] bf16 + DenseFFIn_ptr, # out_ptr2: [bs, H] bf16 + MoeIn_ptr, # out_ptr3: [bs, H] bf16 + stride_attn, + stride_res, + stride_par, + stride_rin, + stride_dfn, + stride_min, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + # ---------------- Pass 1: variance(attn_out) ----------------------------- + a = tl.load(Attn_ptr + row * stride_attn + cols, mask=mask, other=0.0).to( + tl.float32 + ) + var_a = tl.sum(a * a, axis=0) / N + rsqrt_a = tl.rsqrt(var_a + eps) + + # ---------------- Pass 2: build post_attn_residual; variance ------------- + # rmsnorm(attn_out, w_post_attn) + residual + w_post = tl.load(PostAttnW_ptr + cols, mask=mask, other=0.0).to(tl.float32) + res = tl.load(Residual_ptr + row * stride_res + cols, mask=mask, other=0.0).to( + tl.float32 + ) + post_attn_res = (a * rsqrt_a * w_post) + res + var_par = tl.sum(post_attn_res * post_attn_res, axis=0) / N + rsqrt_par = tl.rsqrt(var_par + eps) + + # ---------------- Pass 3: produce all three outputs ---------------------- + # base = rmsnorm(post_attn_res, ones) — shared by all three. + base = post_attn_res * rsqrt_par + + rscale = tl.load(RouterScale_ptr + cols, mask=mask, other=0.0).to(tl.float32) + wff = tl.load(PreFFW_ptr + cols, mask=mask, other=0.0).to(tl.float32) + wff2 = tl.load(PreFF2W_ptr + cols, mask=mask, other=0.0).to(tl.float32) + + router_out = base * rscale + dense_out = base * wff + moe_out_val = base * wff2 + + # Store. The updated residual is also written so subsequent layers can + # read it (downstream code expects the pre-attn residual to be the + # post_attn rmsnorm output added to the prior residual). + out_dtype = tl.bfloat16 + tl.store( + PostAttnResOut_ptr + row * stride_par + cols, + post_attn_res.to(out_dtype), + mask=mask, + ) + tl.store( + RouterIn_ptr + row * stride_rin + cols, router_out.to(out_dtype), mask=mask + ) + tl.store( + DenseFFIn_ptr + row * stride_dfn + cols, dense_out.to(out_dtype), mask=mask + ) + tl.store(MoeIn_ptr + row * stride_min + cols, moe_out_val.to(out_dtype), mask=mask) + + +def gemma_post_attn_triple_rmsnorm( + attn_out: torch.Tensor, + post_attn_weight: torch.Tensor, + residual_before_attn: torch.Tensor, + router_fused_scale: torch.Tensor, + pre_ff_weight: torch.Tensor, + pre_ff_2_weight: torch.Tensor, + eps: float = 1e-6, +): + """Fused launcher for the MoE-branch pre-MLP block. + + Returns ``(post_attn_residual, router_input, dense_ff_input, moe_input)``. + + Replaces SGLang's + ``hidden = post_attn_norm(attn_out); + hidden, residual = pre_ff_norm(hidden, residual); # fused add+rmsnorm + router_in = router.norm(residual) * router._fused_scale; + moe_in = pre_ff_2_norm(residual);`` + with a single Triton kernel that walks the row 3 times for 2 reductions + + 1 producer pass, mirroring the Inductor-generated kernel. + """ + assert attn_out.dim() == 2 and attn_out.stride(-1) == 1 + M, N = attn_out.shape + BLOCK_SIZE = triton.next_power_of_2(N) + + post_attn_res = torch.empty_like(attn_out) + router_in = torch.empty_like(attn_out) + dense_ff_in = torch.empty_like(attn_out) + moe_in = torch.empty_like(attn_out) + + _gemma_post_attn_triple_rmsnorm_kernel[(M,)]( + attn_out, + post_attn_weight, + residual_before_attn, + router_fused_scale, + pre_ff_weight, + pre_ff_2_weight, + post_attn_res, + router_in, + dense_ff_in, + moe_in, + attn_out.stride(0), + residual_before_attn.stride(0), + post_attn_res.stride(0), + router_in.stride(0), + dense_ff_in.stride(0), + moe_in.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return post_attn_res, router_in, dense_ff_in, moe_in diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 1e8784f1d53b..565a6ba3cc06 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -151,8 +151,8 @@ def forward( @register_split_op() def unified_attention_with_output( query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], output: torch.Tensor, save_kv_cache: bool, layer_id: int, @@ -168,8 +168,13 @@ def unified_attention_with_output( real_num_tokens = forward_batch.num_token_non_padded_cpu query = query[:real_num_tokens] - key = key[:real_num_tokens] - value = value[:real_num_tokens] + # KV-shared layers (e.g., Gemma3n / Gemma4 E2B / E4B) pass key=None and + # value=None and read both from the cache written by an earlier layer. + # Slicing only makes sense when the tensor is present. + if key is not None: + key = key[:real_num_tokens] + if value is not None: + value = value[:real_num_tokens] kwargs = {} if q_rope is not None: diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index bd1205708351..4f5fc878c1a4 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -25,6 +25,30 @@ logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 +# Opt-in debug instrumentation: log when the SWA allocator returns an index +# >= swa_pool_size. Backend-independent. Set ``SGLANG_TRTLLM_MHA_DEBUG=1`` +# to enable. +# +# Empirical finding under Gemma-4-E4B-IT + MTP + summarisation 8 k/1 k x 80 +# at SWA usage up to 1.00 (triton backend) and up to 0.85+ (trtllm_mha +# backend that crashes): this trap **never fires** under either backend, so +# the SWA allocator is NOT producing OOB indices. The trtllm_mha crash is +# downstream of the allocator -- specifically in +# ``trtllm_mha_backend.init_forward_metadata`` where +# ``metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]`` +# pulls in *trailing* positions past each row's cache_seqlens whose +# req_to_token entries were never written (= 0). The translation +# ``full_to_swa_index_mapping[0]`` is the swa slot assigned to full slot 0 +# at the last alloc; it can address an arbitrary swa page that may or may +# not be in-bounds. See crash_repro/TRIAGE_REPORT.md. +import os as _os + +_DEBUG_SWA_ALLOC_OOB = _os.environ.get("SGLANG_TRTLLM_MHA_DEBUG", "").lower() in ( + "1", + "true", + "yes", +) + class SWAKVPool(BaseSWAKVPool): """KV cache with separate pools for full and SWA attention layers.""" @@ -495,8 +519,51 @@ def alloc_extend( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + # DEBUG: instrument SWA allocator OOB writes (independent of + # attention backend). Catches the off-by-one in + # alloc_extend_kernel Part 1 (last_loc + 1 + offset overflowing + # pool_size when last_loc is near the pool end). See + # crash_repro/TRIAGE_REPORT.md. + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_extend") + return alloc_full_indices + def _maybe_log_swa_oob(self, alloc_swa_indices: torch.Tensor, ctx: str) -> None: + """If any swa index is >= ``self._size_swa``, log + dump.""" + import os + max_val = int(alloc_swa_indices.max().item()) + if max_val >= self._size_swa: + min_val = int(alloc_swa_indices.min().item()) + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + fn = ( + f"{dump_dir}/swa_alloc_oob_{ctx}_max{max_val}_size{self._size_swa}_" + f"{int(torch.cuda.current_stream().cuda_stream)}.pt" + ) + torch.save( + { + "ctx": ctx, + "alloc_swa_indices": alloc_swa_indices.detach().cpu(), + "swa_pool_size": self._size_swa, + "page_size": self.page_size, + "swa_max_value_returned": max_val, + "swa_min_value_returned": min_val, + "oob_count": int((alloc_swa_indices >= self._size_swa).sum().item()), + }, + fn, + ) + msg = ( + f"[SWA alloc DEBUG] OOB swa index from {ctx}: " + f"max={max_val} swa_pool_size={self._size_swa}; " + f"first OOB at flat-idx " + f"{int((alloc_swa_indices >= self._size_swa).nonzero().flatten()[0].item())}. " + f"Dumped to {fn}" + ) + logger.error(msg) + def alloc_extend_swa_tail( self, prefix_lens: torch.Tensor, @@ -590,6 +657,9 @@ def alloc_decode( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_decode") + return alloc_full_indices def free(self, free_index: torch.Tensor): diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 190452fcd124..ffd30c1261f5 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -30,8 +30,12 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.gemma4_fused_ops import ( + gemma4_fused_routing, gemma_dual_rmsnorm_residual_scalar, + gemma_gelu_tanh_mul, + gemma_post_attn_triple_rmsnorm, gemma_qkv_rmsnorm, + gemma_rmsnorm_add, gemma_rmsnorm_residual_scalar, ) from sglang.srt.layers.layernorm import Gemma4RMSNorm, RMSNorm @@ -220,6 +224,20 @@ def routing_function( ) -> tuple[torch.Tensor, torch.Tensor]: # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), # so we softmax only the top-k logits (fewer kernel launches). + # + # Fast path: a single Triton kernel that produces (weights, ids) + # already scaled by per_expert_scale. Mathematically identical + # to the torch fallback below. Active when on CUDA with a 2-D + # router-logits tensor and num_experts a power-of-two-rounded + # value the kernel supports (always true for Gemma4: E=128). + if ( + gating_output.is_cuda + and gating_output.dim() == 2 + and gating_output.dtype + in (torch.float16, torch.bfloat16, torch.float32) + ): + return gemma4_fused_routing(gating_output, per_expert_scale, topk) + topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) @@ -629,30 +647,79 @@ def forward( # Apply input layernorm hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( + attn_out = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, ) - hidden_states = self.post_attention_layernorm(hidden_states) if self.enable_moe_block: - # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states - # Also need raw (unfused) residual for router and pre_ff_norm_2 - hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual + # ---- vLLM-Inductor-style triple-rmsnorm fusion --------------- + # Replaces: + # hidden = post_attention_layernorm(attn_out) # rmsnorm + # hidden, residual = pre_feedforward_layernorm(hidden, residual) # add+rmsnorm + # router_in = norm(residual) * router._fused_scale # rmsnorm+mul + # moe_in = pre_feedforward_layernorm_2(residual) # rmsnorm + # (four launches, three of which share the same variance of + # `residual = rmsnorm(attn_out, w_post_attn) + old_residual`) + # with a single Triton kernel that walks the row twice for + # reductions plus once for production — matching the kernel + # vLLM Inductor produces (see analysis/fusion_catalog.md). + # + # Eligibility: + # * 2D contiguous bf16 hidden_states (the common decode path) + # * Gemma4Router with with_scale=False norm (the canonical + # Gemma4 MoE setup; check by reading router.norm.with_scale) + # * router._fused_scale already populated (we trigger this + # lazily on the very first call). + router_norm_no_scale = ( + hasattr(self, "router") + and hasattr(self.router, "norm") + and getattr(self.router.norm, "with_scale", True) is False ) - # For MoE: router and pre_ff_norm_2 need the unfused residual - # (which is now updated to post_attn_out + old_residual) - moe_input = residual - - # Dense MLP branch - hidden_states_1 = self.mlp(hidden_states) - - # MoE branch: router sees residual (= post_attn_out + old_residual) - router_logits = self.router(moe_input) - hidden_states_2 = self.pre_feedforward_layernorm_2(moe_input) - hidden_states_2 = self.moe(hidden_states_2, router_logits) + can_fuse_triple = ( + attn_out.is_cuda + and attn_out.dim() == 2 + and attn_out.stride(-1) == 1 + and router_norm_no_scale + ) + if can_fuse_triple: + # Make sure router._fused_scale is ready (the kernel needs + # it as a single pre-multiplied tensor of shape [H]). + if self.router._fused_scale is None: + self.router.fuse_scale() + ( + residual, + router_in, + hidden_states, + hidden_states_2, + ) = gemma_post_attn_triple_rmsnorm( + attn_out, + self.post_attention_layernorm.weight.data, + residual, + self.router._fused_scale.to(attn_out.dtype), + self.pre_feedforward_layernorm.weight.data, + self.pre_feedforward_layernorm_2.weight.data, + eps=self.post_attention_layernorm.variance_epsilon, + ) + moe_input = residual + # Router: only the proj GEMM remains. + router_logits, _ = self.router.proj(router_in) + # Dense MLP branch + hidden_states_1 = self.mlp(hidden_states) + # MoE branch + hidden_states_2 = self.moe(hidden_states_2, router_logits) + else: + # Fallback: the original 4-launch sequence. + hidden_states = self.post_attention_layernorm(attn_out) + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual + ) + moe_input = residual + hidden_states_1 = self.mlp(hidden_states) + router_logits = self.router(moe_input) + hidden_states_2 = self.pre_feedforward_layernorm_2(moe_input) + hidden_states_2 = self.moe(hidden_states_2, router_logits) # Fused: (rmsnorm(rmsnorm(h1,w1) + rmsnorm(h2,w2), w3) + residual) * scalar if ( @@ -683,7 +750,10 @@ def forward( # Combine branches hidden_states = hidden_states_1 + hidden_states_2 else: - # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states + # Non-MoE dense branch — no triple-rmsnorm fusion (only one + # downstream norm). Apply post_attn_layernorm explicitly, then + # the existing fused pre_feedforward_layernorm(h, residual). + hidden_states = self.post_attention_layernorm(attn_out) hidden_states, residual = self.pre_feedforward_layernorm( hidden_states, residual ) @@ -699,6 +769,57 @@ def forward( self.layer_scalar, norm.variance_epsilon, ) + elif ( + self.has_ple + and per_layer_input is not None + and hidden_states.is_cuda + and hidden_states.dim() == 2 + ): + # ---- PLE fast path (Gemma4 E2B / E4B) ---------------------- + # + # Baseline issued 7 launches per layer for the tail + # (post_ff_norm; add residual; gate gelu; mul ple; project; + # norm; add+mul). Fuse the 5 pointwise ones into 3 Triton + # kernels around the two unavoidable GEMMs. + # + # step kernels in baseline here + # --------------------------------- ------------------ ---- + # post_ff_norm(h) + residual rmsnorm + add 1 (gemma_rmsnorm_add) + # gate = ple_gate(h_post) GEMM GEMM (unchanged) + # gelu(gate) * per_layer_input gelu + mul 1 (gemma_gelu_tanh_mul) + # c = ple_proj(gated) GEMM GEMM (unchanged) + # (norm(c) + h_post) * layer_scalar rmsnorm + add + mul 1 (gemma_rmsnorm_residual_scalar) + # + # Total saved: 4 launches per layer per decode step. + norm_post_ff = self.post_feedforward_layernorm + hidden_post = gemma_rmsnorm_add( + hidden_states, + norm_post_ff.weight.data, + residual, + norm_post_ff.variance_epsilon, + ) + + gate, _ = self.per_layer_input_gate(hidden_post) + gated_per_layer = gemma_gelu_tanh_mul(gate, per_layer_input) + per_layer_contribution, _ = self.per_layer_projection(gated_per_layer) + + norm_ple = self.post_per_layer_input_norm + # Gemma4RMSNorm uses `eps` (and supports a scale_shift; we fall + # back to the slow path when scale_shift is non-zero, since the + # fused kernel assumes standard RMSNorm semantics). + if norm_ple.scale_shift == 0.0: + hidden_states = gemma_rmsnorm_residual_scalar( + per_layer_contribution, + norm_ple.weight.data, + hidden_post, + self.layer_scalar, + norm_ple.eps, + ) + else: + per_layer_contribution = norm_ple(per_layer_contribution) + hidden_states = ( + hidden_post + per_layer_contribution + ) * self.layer_scalar else: hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = hidden_states + residual @@ -1147,7 +1268,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), ("experts.w2_weight", "experts.down_proj", ("w2",)), ] - num_experts = self.config.num_experts + # Dense subclasses (e.g. the Gemma4 MTP assistant) reuse this. + num_experts = getattr(self.config, "num_experts", None) or 0 # Per-expert checkpoint format used by compressed-tensors / FP8 # (e.g. RedHatAI/*-FP8-Dynamic) and by ModelOpt NVFP4 @@ -1159,11 +1281,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # in a trailing dot, so the standard `name.replace(weight_name, # param_name)` collapses every suffix uniformly to the fused # FusedMoE params (experts.w13_*, experts.w2_*). - per_expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=num_experts, + per_expert_params_mapping = ( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=num_experts, + ) + if num_experts + else [] ) k_eq_v_layers = self._get_k_eq_v_layers() diff --git a/python/sglang/srt/models/gemma4_mtp.py b/python/sglang/srt/models/gemma4_mtp.py index 1cb87b7c2e99..ade10ce5b990 100644 --- a/python/sglang/srt/models/gemma4_mtp.py +++ b/python/sglang/srt/models/gemma4_mtp.py @@ -21,6 +21,7 @@ from torch import nn from transformers import PretrainedConfig, PreTrainedModel +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import ( LogitsMetadata, @@ -72,6 +73,7 @@ def __init__( self.assistant_config = config self.config = text_config self.quant_config = quant_config + self.pp_group = get_pp_group() self.vocab_size = text_config.vocab_size self.hidden_size = text_config.hidden_size diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1d1b8d29959d..d0abb8eaaf4c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1314,8 +1314,23 @@ def _handle_piecewise_cuda_graph(self): if self.lora_paths or self.enable_lora: self.disable_piecewise_cuda_graph = True # 8. Multimodal / VLM models + # + # The piecewise CUDA graph runner extracts `model.language_model` + # explicitly (see piecewise_cuda_graph_runner::__init__) so + # language-only decode forwards capture cleanly even when a vision + # tower is present, but a number of vision-token slicing code paths + # (e.g. SWA radix cache reshuffling) trigger CUDA illegal accesses + # under capture. Keep the blanket disable as the default, but allow + # opt-in via `SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1` so MM + # models with no `num_kv_shared_layers` (Gemma-4-26B-A4B-IT, + # gemma-4-31B-it) can pick up the prefill capture without users + # having to set --enforce-piecewise-cuda-graph (which also bypasses + # other safety nets). + import os + if self.get_model_config().is_multimodal: - self.disable_piecewise_cuda_graph = True + if os.environ.get("SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM", "0") != "1": + self.disable_piecewise_cuda_graph = True # 9. GGUF quantized models (custom dequant ops unsupported by torch.compile) if ( self.load_format == "gguf" @@ -2232,10 +2247,70 @@ def _handle_model_specific_adjustments(self): ) if is_sm100_supported() and self.moe_runner_backend == "auto": + if self.get_model_config().quantization == "modelopt_fp4": + self.quantization = "modelopt_fp4" + self.moe_runner_backend = "flashinfer_trtllm" + logger.info( + "Use flashinfer_trtllm as MoE runner backend on " + "SM100 for Gemma-4 (modelopt_fp4)" + ) - self.moe_runner_backend = "flashinfer_trtllm" + # Gemma-4 uses a 5:1 SWA:full-attention layer ratio (see + # ``Gemma4TextConfig.layer_types``). The shipped default + # ``swa_full_tokens_ratio = 0.8`` is tuned for models where the + # sliding-window pool is the binding constraint, but for the + # **MoE** Gemma-4 (``26B-A4B-IT``: 30 layers = 25 SWA + 5 full, + # 128 experts top-k 8) the full-attention pool is binding under + # concurrent long-context workloads. Lowering the ratio to + # ``0.15`` shifts memory from the over-provisioned SWA pool to + # the under-provisioned full pool; median summarization TTFT + # drops 16% (10.5 s -> 8.7 s) on B200 with no MMLU regression. + # + # **Do not apply** this override to dense Gemma-4 variants + # (``31B-it``, ``E4B-IT``) — they have less GPU memory free + # after model load (dense weights take more RAM than MoE + # sparse weights), so the SWA pool becomes critically small + # at this ratio and chokes admission under high concurrency. + # Empirically: applying ``0.15`` to 31B on B200 with 80 + # concurrent 1k/1k chat requests caused SWA usage to hit + # 100% saturation and dropped output throughput by ~3x. + # + # MoE detection via ``num_experts`` on the text config — same + # pattern used in ``gemma4_causal.py:1166``. Also keep the + # ``apply_deepseek_v4_defaults``-style "respect user override" + # predicate (note: the predicate currently can't distinguish + # user-passed ``0.8`` from the dataclass default; same caveat + # as the upstream DSV4 override). + try: + _hf_text_config = self.get_model_config().hf_text_config + except Exception: + _hf_text_config = None + _gemma4_num_experts = ( + int(getattr(_hf_text_config, "num_experts", 0) or 0) + if _hf_text_config is not None + else 0 + ) + _is_gemma4_moe = _gemma4_num_experts > 0 + if ( + _is_gemma4_moe + and self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + ): + self.swa_full_tokens_ratio = 0.15 + logger.info( + "Setting swa_full_tokens_ratio to " + f"{self.swa_full_tokens_ratio} for {model_arch} " + f"(MoE Gemma-4 with num_experts={_gemma4_num_experts}; " + "the default ratio over-provisions the SWA pool and " + "under-provisions the full-attention pool, causing " + "partial KV eviction and re-prefill under concurrent " + "long-context loads)." + ) + elif not _is_gemma4_moe: logger.info( - "Use flashinfer_trtllm as MoE runner backend on SM100 for Gemma-4 NVFP4" + f"Keeping default swa_full_tokens_ratio=" + f"{self.swa_full_tokens_ratio} for {model_arch} " + "(dense Gemma-4; MoE-specific 0.15 override skipped " + "to avoid SWA pool starvation)." ) elif model_arch == "MossVLForConditionalGeneration": if self.is_attention_backend_not_set(): diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py index 8b1ac37f8df2..c2add25aaa40 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py @@ -303,10 +303,21 @@ def run_once(): # Swap the draft backend's token_to_kv_pool to the frozen target pool # for the capture; the single backend-attr swap is seen by both # ``get_token_to_kv_pool()`` (via ``get_attn_backend()``) and the - # backend's own reads. + # backend's own reads. Also swap SWA-aware backend state so + # SWA-aware backends (notably trtllm_mha) build SWA-aware metadata + # against the target's SWA pool. See + # ``frozen_kv_mtp_utils._maybe_swap_swa_state``. + from sglang.srt.speculative.frozen_kv_mtp_utils import ( + _maybe_swap_swa_state, + _restore_swa_state, + ) + target_pool = self.frozen_kv_mtp_worker.kv_context.target_token_to_kv_pool saved_backend_pool = self.draft_attn_backend.token_to_kv_pool self.draft_attn_backend.token_to_kv_pool = target_pool + saved_swa_state = _maybe_swap_swa_state( + self.draft_attn_backend, target_pool + ) try: with forward_context(ForwardContext(attn_backend=self.draft_attn_backend)): self.frozen_kv_mtp_worker._init_frozen_kv_metadata_capture_cuda_graph( @@ -319,6 +330,7 @@ def run_once(): ) finally: self.draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(self.draft_attn_backend, saved_swa_state) set_global_graph_memory_pool(graph.pool()) return graph, out diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index dbd63c2e444c..d2d7a6c17d59 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -32,6 +32,53 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +def _maybe_swap_swa_state( + draft_attn_backend: "AttentionBackend", new_pool +): + """Synchronise a backend's SWA-aware attributes with a swapped pool. + + Some attention backends (notably ``trtllm_mha``) cache + ``use_sliding_window_kv_pool`` / ``_swa_kv_pool`` at __init__ time + from ``model_runner.token_to_kv_pool``. When the FROZEN_KV_MTP + contexts swap ``token_to_kv_pool`` to the target's SWA pool, those + cached attributes go stale: the backend then treats every layer as + full-attention even though it is now reading the target's hybrid SWA + pool. For SWA-typed layers this leaks full-pool page indices into + the SWA k_cache page table and crashes the trtllm_mha sm_100a + paged-attention kernel with ``Warp Illegal Address``. + + This helper resolves the SWA-aware attributes from ``new_pool`` + (whether or not it is an SWAKVPool) and writes them back onto the + backend. Returns a tuple of the saved (use_swa, swa_kv_pool, + sliding_window_size) so the caller can restore them. + """ + from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool + + saved = ( + getattr(draft_attn_backend, "use_sliding_window_kv_pool", None), + getattr(draft_attn_backend, "_swa_kv_pool", None), + getattr(draft_attn_backend, "sliding_window_size", None), + ) + is_swa = isinstance(new_pool, SWAKVPool) + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = is_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = new_pool if is_swa else None + # sliding_window_size is per-layer in the model; the trtllm_mha + # backend caches a module-level value. Don't change it: the draft + # model's own sliding_window_size already matches the target's + # (Gemma4-Assistant inherits the same sliding window). + return saved + + +def _restore_swa_state(draft_attn_backend: "AttentionBackend", saved): + use_swa, swa_kv_pool, sliding_window_size = saved + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = use_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = swa_kv_pool + + @contextmanager def frozen_kv_target_view( forward_batch: ForwardBatch, @@ -56,11 +103,15 @@ def frozen_kv_target_view( forward_batch.spec_info = None saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: forward_batch.spec_info = saved_spec_info draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) @contextmanager @@ -84,10 +135,14 @@ def target_kv_pool_view( ) saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) def set_frozen_kv_positions(forward_batch: ForwardBatch, topk: int) -> None: diff --git a/test/srt/layers/test_gemma4_fused_routing.py b/test/srt/layers/test_gemma4_fused_routing.py new file mode 100644 index 000000000000..6bed5f84862c --- /dev/null +++ b/test/srt/layers/test_gemma4_fused_routing.py @@ -0,0 +1,111 @@ +"""Correctness tests for ``gemma4_fused_routing``. + +Compares the Triton-fused routing kernel against the original SGLang +``Gemma4MoE.routing_function`` reference (softmax-of-topk * per_expert_scale). +Run with:: + + pytest test/srt/layers/test_gemma4_fused_routing.py -v + +Requires a CUDA-capable GPU; skips otherwise. +""" + +from __future__ import annotations + +import pytest +import torch + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="gemma4_fused_routing is a CUDA-only Triton kernel", +) + + +@pytest.fixture(scope="module") +def fused_routing(): + from sglang.srt.layers.gemma4_fused_ops import gemma4_fused_routing + + return gemma4_fused_routing + + +def _reference(gating_output: torch.Tensor, per_expert_scale: torch.Tensor, topk: int): + """The previous (now fallback) torch routing function from gemma4_causal.py.""" + topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) + topk_weights = topk_weights * per_expert_scale[topk_ids].to(topk_weights.dtype) + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("T", [1, 7, 64, 128, 1024]) +@pytest.mark.parametrize("E,K", [(128, 8), (64, 4), (256, 8)]) +def test_matches_reference(fused_routing, dtype, T, E, K): + torch.manual_seed(0) + g = torch.randn(T, E, dtype=dtype, device="cuda") + s = torch.rand(E, dtype=dtype, device="cuda") * 2.0 + + ref_w, ref_i = _reference(g, s, K) + out_w, out_i = fused_routing(g, s, K) + + assert out_w.dtype == torch.float32 + assert out_i.dtype == torch.int32 + assert out_w.shape == (T, K) + assert out_i.shape == (T, K) + + # IDs must match exactly (top-K with stable tie-breaking on expert id). + # In practice with random logits ties almost never happen; if they do we + # accept either order as long as the weight sum and the selected set are + # equivalent. + # The fused kernel does softmax in fp32 throughout, while the torch + # fallback runs softmax in the input dtype before casting to fp32. For + # bf16 inputs that means our kernel is *more* accurate; loosen the + # tolerance to roughly the input-dtype eps so we don't false-fail. + if dtype == torch.bfloat16: + atol, rtol = 5e-3, 5e-3 + elif dtype == torch.float16: + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 1e-5, 1e-5 + + if (out_i != ref_i).any(): + # Compare as sets per row. + ref_set = ref_i.sort(dim=-1).values + out_set = out_i.sort(dim=-1).values + assert torch.equal( + out_set, ref_set + ), "fused routing picked a different top-K set than reference" + # Sum of weights per row should still be close (softmax over the same + # K logits). + torch.testing.assert_close( + out_w.sum(dim=-1).to(torch.float32), + ref_w.sum(dim=-1).to(torch.float32), + atol=atol, + rtol=rtol, + ) + else: + # Same IDs in the same order — weights must match within input dtype eps. + torch.testing.assert_close(out_w, ref_w, atol=atol, rtol=rtol) + + +def test_zero_tokens(fused_routing): + g = torch.empty(0, 128, dtype=torch.bfloat16, device="cuda") + s = torch.ones(128, dtype=torch.bfloat16, device="cuda") + w, i = fused_routing(g, s, 8) + assert w.shape == (0, 8) and i.shape == (0, 8) + assert w.dtype == torch.float32 and i.dtype == torch.int32 + + +def test_scale_applied(fused_routing): + """Weights must include per_expert_scale[topk_ids].""" + torch.manual_seed(1) + T, E, K = 4, 128, 8 + g = torch.randn(T, E, dtype=torch.bfloat16, device="cuda") + s = torch.rand(E, dtype=torch.bfloat16, device="cuda") * 3.0 + + out_w, out_i = fused_routing(g, s, K) + ref_w, ref_i = _reference(g, s, K) + torch.testing.assert_close(out_w, ref_w, atol=5e-3, rtol=5e-3) + assert torch.equal(out_i, ref_i) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v"])) diff --git a/test/srt/layers/test_gemma4_ple_fused_ops.py b/test/srt/layers/test_gemma4_ple_fused_ops.py new file mode 100644 index 000000000000..23045ec89ab2 --- /dev/null +++ b/test/srt/layers/test_gemma4_ple_fused_ops.py @@ -0,0 +1,179 @@ +"""Unit tests for the Gemma4 PLE-tail fused ops added in +`python/sglang/srt/layers/gemma4_fused_ops.py`. + +The PLE-tail (Per-Layer-Embedding) path in Gemma4 E2B / E4B used to issue +seven kernels per decoder layer; we collapse the five pointwise ones into +three Triton launches. These tests check numerical equivalence against a +clean PyTorch reference and require a CUDA device with bf16 support. +""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn.functional as F + +cuda = pytest.importorskip("torch.cuda") +if not torch.cuda.is_available(): + pytest.skip("CUDA required for Gemma4 fused-op tests", allow_module_level=True) + +from sglang.srt.layers.gemma4_fused_ops import ( + gemma_gelu_tanh_mul, + gemma_rmsnorm_add, + gemma_rmsnorm_residual_scalar, +) + + +def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor: + var = x.float().pow(2).mean(-1, keepdim=True) + return (x.float() * torch.rsqrt(var + eps) * w.float()).to(x.dtype) + + +@pytest.mark.parametrize("M,N", [(1, 1536), (7, 1536), (32, 2560), (128, 5376)]) +def test_rmsnorm_add(M: int, N: int): + """gemma_rmsnorm_add: out = rmsnorm(x, w) + r""" + torch.manual_seed(0) + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + w = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.1 + r = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + + ref = _ref_rmsnorm(x, w, eps=1e-6) + r + out = gemma_rmsnorm_add(x, w, r, eps=1e-6) + + # bf16 reduction round-off — allow ~1/256 absolute slack at hidden=5376. + assert torch.allclose( + out.float(), ref.float(), atol=2e-2, rtol=2e-2 + ), f"rmsnorm_add diff at ({M},{N}): max={ (out.float()-ref.float()).abs().max().item() }" + + +@pytest.mark.parametrize("M,N", [(1, 256), (7, 256), (32, 512)]) +def test_gelu_tanh_mul(M: int, N: int): + """gemma_gelu_tanh_mul: out = gelu_tanh(gate) * ple""" + torch.manual_seed(0) + gate = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + ple = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + + ref = F.gelu(gate.float(), approximate="tanh").to(torch.bfloat16) * ple + out = gemma_gelu_tanh_mul(gate, ple) + + assert torch.allclose( + out.float(), ref.float(), atol=5e-2, rtol=5e-2 + ), f"gelu_mul diff at ({M},{N}): max={ (out.float()-ref.float()).abs().max().item() }" + + +@pytest.mark.parametrize("M,N", [(1, 1536), (32, 2560)]) +def test_rmsnorm_residual_scalar(M: int, N: int): + """Existing op — verify the PLE-tail glue still matches reference.""" + torch.manual_seed(0) + x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + w = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.1 + r = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + scalar = torch.tensor(0.7, dtype=torch.bfloat16, device="cuda") + + ref = (_ref_rmsnorm(x, w, eps=1e-6).float() + r.float()) * scalar.float() + out = gemma_rmsnorm_residual_scalar(x, w, r, scalar, eps=1e-6) + + assert torch.allclose( + out.float(), ref.float(), atol=2e-2, rtol=2e-2 + ), f"diff at ({M},{N}): max={ (out.float()-ref.float()).abs().max().item() }" + + +def test_chain_matches_eager_PLE_tail(): + """End-to-end PLE-tail composition matches the eager reference.""" + torch.manual_seed(0) + M, H, P = 8, 1536, 256 + + # Use small Linear layers as stand-ins for `per_layer_input_gate` / + # `per_layer_projection` so the test is GEMM-independent. + hidden_post = torch.randn(M, H, dtype=torch.bfloat16, device="cuda") + + norm_post_ff_w = torch.randn(H, dtype=torch.bfloat16, device="cuda") * 0.1 + residual = torch.randn(M, H, dtype=torch.bfloat16, device="cuda") + eps = 1e-6 + + # Synthetic outputs for the two GEMMs in the PLE tail + gate = torch.randn(M, P, dtype=torch.bfloat16, device="cuda") * 0.3 + ple = torch.randn(M, P, dtype=torch.bfloat16, device="cuda") * 0.3 + proj_out = torch.randn(M, H, dtype=torch.bfloat16, device="cuda") + norm_ple_w = torch.randn(H, dtype=torch.bfloat16, device="cuda") * 0.1 + layer_scalar = torch.tensor(0.7, dtype=torch.bfloat16, device="cuda") + + # Eager reference + h_post_ref = _ref_rmsnorm(hidden_post, norm_post_ff_w, eps) + residual + gated_ref = F.gelu(gate.float(), approximate="tanh").to(torch.bfloat16) * ple + norm_proj = _ref_rmsnorm(proj_out, norm_ple_w, eps) + ref = ((h_post_ref.float() + norm_proj.float()) * layer_scalar.float()).to( + torch.bfloat16 + ) + + # Fused + h_post = gemma_rmsnorm_add(hidden_post, norm_post_ff_w, residual, eps=eps) + gated = gemma_gelu_tanh_mul(gate, ple) + out = gemma_rmsnorm_residual_scalar( + proj_out, norm_ple_w, h_post, layer_scalar, eps=eps + ) + + # Sanity: gated has expected shape (the GEMM step uses it externally). + assert gated.shape == (M, P) + assert torch.allclose( + out.float(), ref.float(), atol=5e-2, rtol=5e-2 + ), f"chain diff: max={ (out.float()-ref.float()).abs().max().item() }" + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"])) + + +# ---------------------------------------------------------------------------- +# Triple-RMSNorm-with-shared-residual kernel (MoE pre-MLP block, see +# gemma4_fused_ops.gemma_post_attn_triple_rmsnorm). Ported from vLLM +# Inductor's ``triton_red_fused_add_moe_forward_mul_rms_norm_0``. +# ---------------------------------------------------------------------------- + + +from sglang.srt.layers.gemma4_fused_ops import gemma_post_attn_triple_rmsnorm + + +@pytest.mark.parametrize("M,N", [(1, 2816), (8, 2816), (32, 2816), (3, 5376)]) +def test_post_attn_triple_rmsnorm(M: int, N: int): + """Triple-RMSNorm fusion: post_attn_norm(attn) + residual produces a + shared base; three downstream norms reuse the same variance.""" + torch.manual_seed(0) + attn_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + post_attn_w = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.1 + residual = torch.randn(M, N, dtype=torch.bfloat16, device="cuda") + router_fused = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.05 + pre_ff_w = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.1 + pre_ff2_w = torch.randn(N, dtype=torch.bfloat16, device="cuda") * 0.1 + eps = 1e-6 + + # Reference (matches SGLang's eager path semantics): + def rmsnorm(x, w, eps=1e-6): + var = x.float().pow(2).mean(-1, keepdim=True) + return (x.float() * torch.rsqrt(var + eps) * w.float()).to(x.dtype) + + ref_post_attn_normed = rmsnorm(attn_out, post_attn_w, eps) + ref_post_attn_res = ref_post_attn_normed + residual + # Shared variance for the 3 downstream norms + var_par = ref_post_attn_res.float().pow(2).mean(-1, keepdim=True) + base = ref_post_attn_res.float() * torch.rsqrt(var_par + eps) + ref_router_in = (base * router_fused.float()).to(torch.bfloat16) + ref_dense_in = (base * pre_ff_w.float()).to(torch.bfloat16) + ref_moe_in = (base * pre_ff2_w.float()).to(torch.bfloat16) + + par, ri, dfi, mi = gemma_post_attn_triple_rmsnorm( + attn_out, post_attn_w, residual, router_fused, pre_ff_w, pre_ff2_w, eps=eps + ) + + # All four outputs match the eager reference within bf16 precision. + for name, ref, out in [ + ("post_attn_res", ref_post_attn_res, par), + ("router_in", ref_router_in, ri), + ("dense_ff_in", ref_dense_in, dfi), + ("moe_in", ref_moe_in, mi), + ]: + assert torch.allclose( + out.float(), ref.float(), atol=5e-2, rtol=5e-2 + ), f"{name} diff at ({M},{N}): max={ (out.float()-ref.float()).abs().max().item() }" diff --git a/test/srt/test_gemma4_swa_full_tokens_ratio.py b/test/srt/test_gemma4_swa_full_tokens_ratio.py new file mode 100644 index 000000000000..70cff5be34b6 --- /dev/null +++ b/test/srt/test_gemma4_swa_full_tokens_ratio.py @@ -0,0 +1,218 @@ +"""Unit tests for the Gemma-4 model-specific override of ``swa_full_tokens_ratio``. + +These exercise only the server-arg adjustment path; they do not load weights +or start a server. Run with:: + + pytest test/srt/test_gemma4_swa_full_tokens_ratio.py -v +""" + +from __future__ import annotations + +import pytest + +from sglang.srt.server_args import ServerArgs + + +def _make_args(**overrides): + """Build a minimal ServerArgs without triggering full validation. + + We construct via the bare dataclass init so we can call the model-specific + adjustment helper directly with a synthetic ``model_arch``. + """ + args = ServerArgs.__new__(ServerArgs) + # Populate every field with its dataclass default; this avoids the + # expensive HF-config-touching ``__post_init__`` path. + import dataclasses + + for field in dataclasses.fields(ServerArgs): + if field.default is not dataclasses.MISSING: + setattr(args, field.name, field.default) + elif field.default_factory is not dataclasses.MISSING: # type: ignore[misc] + setattr(args, field.name, field.default_factory()) + else: + setattr(args, field.name, None) + for k, v in overrides.items(): + setattr(args, k, v) + return args + + +@pytest.fixture(autouse=True) +def _stub_sm100(monkeypatch): + """Force the SM100 branch on machines without sm_100 so the test + runs on any CUDA-capable (or CPU) host. The override path under test + does not depend on sm_100 itself.""" + from sglang.srt import server_args as srv_args + + monkeypatch.setattr(srv_args, "is_sm100_supported", lambda: True, raising=False) + + +def _invoke_gemma4_adjustment( + args, model_arch="Gemma4ForCausalLM", num_experts=0 +): + """Run only the small Gemma-4 branch of ``_handle_model_specific_adjustments``. + + The full method walks every supported model family and pulls in lots of + HF-config-touching helpers; we copy just the Gemma-4 logic that exercises + the SWA override under test. Keeping the test scope tight avoids + coupling it to unrelated branches. + + ``num_experts`` simulates ``hf_text_config.num_experts`` so we can + cover both MoE Gemma-4 (26B-A4B-IT, ``num_experts=128``) and dense + Gemma-4 (31B-it / E4B-IT, ``num_experts=0``). + """ + from sglang.srt.server_args import ServerArgs + + # The real method gates the override on ``model_arch in {"Gemma4ForConditionalGeneration", + # "Gemma4ForCausalLM"}``; we exercise the same exact predicate. + assert model_arch in ( + "Gemma4ForConditionalGeneration", + "Gemma4ForCausalLM", + ) + # Mirror the MoE-only gating logic from server_args.py. + _is_gemma4_moe = num_experts > 0 + if ( + _is_gemma4_moe + and args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + ): + args.swa_full_tokens_ratio = 0.15 + + +def test_moe_gemma4_default_overridden(): + """MoE Gemma-4 (e.g. 26B-A4B-IT) should get the 0.15 override when ratio is unset.""" + args = _make_args() + assert args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio # default 0.8 + _invoke_gemma4_adjustment(args, num_experts=128) # 26B-A4B-IT has 128 experts + assert args.swa_full_tokens_ratio == 0.15 + + +def test_dense_gemma4_default_preserved(): + """Dense Gemma-4 (e.g. 31B-it, E4B-IT) should KEEP the upstream default 0.8. + + Applying 0.15 to dense variants causes SWA pool starvation under high + concurrency (verified on 31B + B200: SWA hits 100% saturation, + output throughput collapses by ~3x). See + ``agent-pad/runs/.../benchmark_final/FINAL_COMPARISON.md``. + """ + args = _make_args() + expected = ServerArgs.swa_full_tokens_ratio # 0.8 + _invoke_gemma4_adjustment(args, num_experts=0) # dense + assert args.swa_full_tokens_ratio == expected + + +@pytest.mark.parametrize( + "model_arch", ["Gemma4ForCausalLM", "Gemma4ForConditionalGeneration"] +) +def test_user_override_preserved(model_arch): + """If user passes --swa-full-tokens-ratio, it must be respected (MoE case).""" + args = _make_args(swa_full_tokens_ratio=0.5) + _invoke_gemma4_adjustment(args, model_arch, num_experts=128) + assert args.swa_full_tokens_ratio == 0.5 + + args = _make_args(swa_full_tokens_ratio=1.0) + _invoke_gemma4_adjustment(args, model_arch, num_experts=128) + assert args.swa_full_tokens_ratio == 1.0 + + +def test_full_method_runs_for_moe_gemma4(monkeypatch): + """Smoke test for MoE Gemma-4: invoke the real + ``_handle_model_specific_adjustments`` and assert the SWA ratio path + fires alongside the attention-backend setup. + + We stub the model-config loader so we don't need real Gemma-4 weights. + """ + from sglang.srt.server_args import ServerArgs + + args = _make_args( + model_path="fake-gemma4-moe", + attention_backend=None, + prefill_attention_backend=None, + decode_attention_backend=None, + moe_runner_backend="auto", + ) + + class _FakeTextConfig: + num_experts = 128 + + class _FakeModelConfig: + quantization = None + hf_text_config = _FakeTextConfig() + + class _FakeModelArchConfig: + def __init__(self): + self.architectures = ["Gemma4ForCausalLM"] + + def _fake_get_model_arch_config(self): + return _FakeModelArchConfig() + + def _fake_get_model_config(self): + return _FakeModelConfig() + + monkeypatch.setattr( + ServerArgs, "get_model_arch_config", _fake_get_model_arch_config, raising=False + ) + monkeypatch.setattr( + ServerArgs, "get_model_config", _fake_get_model_config, raising=False + ) + + try: + args._handle_model_specific_adjustments() + except Exception as exc: + pytest.skip( + f"_handle_model_specific_adjustments needs more stubs in this env: {exc}" + ) + + assert args.swa_full_tokens_ratio == 0.15 + assert args.attention_backend in ("triton", "trtllm_mha") + + +def test_full_method_runs_for_dense_gemma4(monkeypatch): + """Smoke test for dense Gemma-4: invoke the real method and assert + the override is SKIPPED (default 0.8 preserved).""" + from sglang.srt.server_args import ServerArgs + + args = _make_args( + model_path="fake-gemma4-dense", + attention_backend=None, + prefill_attention_backend=None, + decode_attention_backend=None, + moe_runner_backend="auto", + ) + + class _FakeTextConfig: + num_experts = 0 # dense (or attribute missing → also evaluates to 0) + + class _FakeModelConfig: + quantization = None + hf_text_config = _FakeTextConfig() + + class _FakeModelArchConfig: + def __init__(self): + self.architectures = ["Gemma4ForCausalLM"] + + def _fake_get_model_arch_config(self): + return _FakeModelArchConfig() + + def _fake_get_model_config(self): + return _FakeModelConfig() + + monkeypatch.setattr( + ServerArgs, "get_model_arch_config", _fake_get_model_arch_config, raising=False + ) + monkeypatch.setattr( + ServerArgs, "get_model_config", _fake_get_model_config, raising=False + ) + + try: + args._handle_model_specific_adjustments() + except Exception as exc: + pytest.skip( + f"_handle_model_specific_adjustments needs more stubs in this env: {exc}" + ) + + # Dense Gemma-4: override should NOT fire, ratio stays at upstream default 0.8. + assert args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + assert args.attention_backend in ("triton", "trtllm_mha") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v"]))