diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 2791aeec9a8e..54ee243e7c2c 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -412,6 +412,10 @@ class Envs: # None = standard attention. See https://arxiv.org/abs/2512.12087 SGLANG_SKIP_SOFTMAX_PREFILL_THRESHOLD_SCALE_FACTOR = EnvFloat(None) SGLANG_SKIP_SOFTMAX_DECODE_THRESHOLD_SCALE_FACTOR = EnvFloat(None) + # Debug flag: bounds-check trtllm_mha page_table before the kernel call. + # Catches OOB SWA page indices that otherwise surface as CUDA illegal + # address errors deep inside the attention kernel. Set to 1 to enable. + SGLANG_TRTLLM_MHA_DEBUG = EnvBool(False) # TODO(mmangkad): Remove this once the FlashInfer unified allreduce-fusion # transport issue on GB200/GB300 platforms is fixed and verified resolved. SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT = EnvBool(None) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index e68bcb95e822..869ac14b4dcb 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -133,6 +133,18 @@ def __init__( self._swa_kv_pool: Optional[SWAKVPool] = ( kv_pool if self.use_sliding_window_kv_pool else None ) + # The model has SWA semantics whenever ANY of its layers carries a + # sliding window size > 0. Use ``model_runner.sliding_window_size`` + # as the canonical signal: model_runner sets it from the model's + # ``get_attention_sliding_window_size`` or ``config.sliding_window_size``. + # We need this signal *separately* from the SWA-pool detection + # because the FROZEN_KV_MTP draft backend's pool starts non-SWA and + # gets swapped to the target's SWA pool at forward time; we must + # have allocated SWA-page-table buffers BEFORE that swap. 
+ _model_sw = getattr(model_runner, "sliding_window_size", None) + self.model_has_sliding_window: bool = ( + _model_sw is not None and _model_sw > 0 + ) # Forward metadata self.forward_metadata: Optional[TRTLLMMHAMetadata] = None @@ -161,8 +173,20 @@ def _maybe_translate_swa( def _alloc_swa_page_table( self, max_bs: int, max_num_pages: int ) -> Optional[torch.Tensor]: - """Allocate a SWA page_table buffer, or return None for non-SWA models.""" - if not self.use_sliding_window_kv_pool: + """Allocate a SWA page_table buffer, or return None for non-SWA models. + + Note: we eagerly allocate when ``self.model_has_sliding_window`` is + true even if ``self.use_sliding_window_kv_pool`` is currently + ``False`` at init time. This is needed for the FROZEN_KV_MTP draft + backend: at init it has no SWA pool, but at forward time + ``target_kv_pool_view`` swaps in the target's SWA pool (see + ``sglang/srt/speculative/frozen_kv_mtp_utils.py``). Without the + pre-allocated buffer the draft backend would build full-pool + page_table values for SWA layers and crash the trtllm_mha + ``fmhaSm100fKernel_*SlidingOrChunkedCausal*`` kernel with + ``Warp Illegal Address``. + """ + if not self.use_sliding_window_kv_pool and not self.model_has_sliding_window: return None return torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device=self.device) @@ -752,6 +776,62 @@ def forward_decode( page_table = self._get_layer_page_table(layer, forward_batch) + # DEBUG: bounds-check page_table before trtllm kernel. Looking + # for OOB SWA page indices that explain the cudaErrorIllegalAddress. + # IMPORTANT: .item() syncs and breaks cuda-graph capture, so we + # only do this when stream capture is not active. 
+ if envs.SGLANG_TRTLLM_MHA_DEBUG.get() and ( + not torch.cuda.is_current_stream_capturing() + ): + import os + + import torch as _t + + cs = self.forward_metadata.cache_seqlens_int32 + kc_shape = k_cache.shape # (num_pages, num_kv_heads, page_size, head_dim) + num_pages_in_cache = int(kc_shape[0]) + # 1) max-value check + pt_max = int(page_table.max().item()) + pt_min = int(page_table.min().item()) + if pt_max >= num_pages_in_cache or pt_min < 0: + # Pre-emptively dump and abort before the kernel reads OOB. + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + ts = int(_t.cuda.current_stream().cuda_stream) + fn = ( + f"{dump_dir}/page_table_oob_layer{layer.layer_id}_" + f"stream{ts}_{int(_t.cuda.device_count())}.pt" + ) + _t.save( + { + "page_table": page_table.detach().cpu(), + "cache_seqlens_int32": cs.detach().cpu(), + "k_cache_shape": list(kc_shape), + "num_pages_in_cache": num_pages_in_cache, + "page_size": self.page_size, + "sliding_window": layer.sliding_window_size, + "layer_id": layer.layer_id, + "forward_mode": str(forward_batch.forward_mode), + "is_swa_layer": ( + self._swa_kv_pool.layers_mapping[layer.layer_id][1] + if self.use_sliding_window_kv_pool + else False + ), + }, + fn, + ) + msg = ( + f"[trtllm_mha DEBUG] OOB page_table @ layer {layer.layer_id} " + f"({'SWA' if (self.use_sliding_window_kv_pool and self._swa_kv_pool.layers_mapping[layer.layer_id][1]) else 'FULL'}): " + f"page_table.max={pt_max} page_table.min={pt_min} " + f"num_pages_in_cache={num_pages_in_cache}. 
" + f"Dumped to {fn}" + ) + logger.error(msg) + raise RuntimeError(msg) + # Call TRT-LLM kernel # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index ad6f01d9875a..e30027776bb3 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -2,6 +2,18 @@ Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into a single kernel pass to reduce kernel launch overhead. + +Also provides a single-launch fused router for Gemma4 MoE (PR #26120 in +pyc96/sglang fork): replaces the per-layer ``torch.topk`` -> +``softmax`` -> ``per_expert_scale[ids]`` -> ``mul`` -> ``cast`` chain in +``Gemma4MoE.routing_function`` with one Triton kernel. + +The reference design comes from vLLM PR #39083 +(``_gemma4_routing_kernel`` / ``gemma4_fused_routing_kernel_triton``), +which is apache-2.0. Our kernel is rewritten in SGLang style and uses +the identity ``softmax(all)[topk] / sum(softmax(all)[topk]) = +softmax(topk_logits)`` already exploited by SGLang's torch routing +function, so the math is bitwise-comparable to the prior fp32 path. """ from typing import Optional @@ -283,3 +295,163 @@ def gemma_dual_rmsnorm_residual_scalar( BLOCK_SIZE=BLOCK_SIZE, ) return out + + +# --------------------------------------------------------------------------- +# Fused Gemma4 routing kernel (one launch per layer) +# --------------------------------------------------------------------------- +# +# Equivalent to: +# +# topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) +# topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) +# topk_weights = topk_weights * per_expert_scale[topk_ids] +# return topk_weights.float(), topk_ids.int() +# +# but completes the entire computation in one Triton program per token. 
+# +# Algorithm notes: +# * Loads all E logits per token into one program; for Gemma4 +# ``E = num_experts = 128`` so ``BLOCK_E = next_pow2(E) = 128`` and the +# work fits in a single warp with `num_warps=1`. +# * Computes ``softmax-of-topk`` by: +# - using ``tl.sort`` on (logit_bits_as_sortable_uint, expert_id) pairs +# packed into int64 — this gives a fully vectorized top-K without a +# K-step loop and matches the bitwise behavior of ``torch.topk``. +# - taking the largest K via a mask on the sorted-descending sequence +# - normalizing in fp32 (matches ``softmax`` default dtype) +# - multiplying by ``per_expert_scale[topk_ids]`` +# * Writes ``topk_weights`` (fp32) and ``topk_ids`` (int32) in one +# pass, matching the output dtypes the SGLang MoE topk wrapper +# expects. +# +# Reference algorithm: vLLM PR #39083 ``_gemma4_routing_kernel`` (apache-2.0). +# Our independent implementation follows the same sort+mask+softmax scheme. +@triton.jit +def _gemma4_routing_kernel( + gating_ptr, # [T, E] router logits, any float dtype + per_expert_scale_ptr, # [E] per-expert scale (any float dtype) + topk_weights_ptr, # [T, K] fp32 out + topk_ids_ptr, # [T, K] int32 out + stride_g_t, # stride of gating in the token dim + E: tl.constexpr, + K: tl.constexpr, + BLOCK_E: tl.constexpr, +): + pid = tl.program_id(0) + offs_e = tl.arange(0, BLOCK_E) + valid = offs_e < E + + # Load logits into fp32; out-of-bound lanes get -inf so they sort last. + logits = tl.load( + gating_ptr + pid * stride_g_t + offs_e, + mask=valid, + other=-float("inf"), + ).to(tl.float32) + + # Build a sortable int64 key: high 32 bits = bijective(logit_bits) so + # ascending-int sort == ascending-float sort; low 32 bits = expert id + # (kept stable for ties matching torch.topk's default behavior). This + # avoids a separate index buffer / scatter pass after the sort. 
+ MIN32 = -2147483648 + logit_bits = logits.to(tl.int32, bitcast=True) + sign = logit_bits >> 31 + key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32) + # Force invalid lanes to the max positive key so they end up *after* the + # real logits when we sort ascending and read from the top of the + # reversed list. (descending=True would flip the order.) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + + # Sort ascending; the K smallest keys correspond to the K largest + # logits because of the bijection above. + sorted_p = tl.sort(packed, descending=False) + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + + # Invert the bijection to recover the original logit value. + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + all_logits = all_bits.to(tl.float32, bitcast=True) + + # Softmax over the K largest logits only (identity proven by SGLang's + # torch routing function comment). Subtract the max for stability; + # since the list is sorted descending by logit value, the max sits at + # index 0. + top_mask = offs_e < K + max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0) + # exp2(x * log2(e)) is what tl.math.exp expands to; spell it out so we + # can tolerate older Triton releases that lack tl.math.exp. + raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) + raw_exp = tl.where(top_mask, raw_exp, 0.0) + + denom = tl.sum(raw_exp, axis=0) + denom = tl.where(denom > 0.0, denom, 1.0) + weights = raw_exp / denom + + # Multiply by per_expert_scale[topk_ids]. per_expert_scale lives in + # any float dtype; cast to fp32 for the final write. 
+ scales = tl.load( + per_expert_scale_ptr + all_ids.to(tl.int64), + mask=top_mask, + other=1.0, + ).to(tl.float32) + weights = weights * scales + + base_off = pid * K + offs_e + tl.store(topk_weights_ptr + base_off, weights, mask=top_mask) + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + + +def gemma4_fused_routing( + gating_output: torch.Tensor, + per_expert_scale: torch.Tensor, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """One-launch Gemma4 router. + + Args: + gating_output: [T, E] router logits in any floating dtype; will be + cast to fp32 inside the kernel. + per_expert_scale: [E] per-expert scale, any floating dtype. + topk: number of experts to keep per token. + + Returns: + topk_weights: [T, topk] fp32 (matches SGLang TopK contract). + topk_ids: [T, topk] int32 (matches SGLang TopK contract). + """ + assert gating_output.dim() == 2, "expected [T, E] router logits" + assert per_expert_scale.dim() == 1 + assert per_expert_scale.shape[0] == gating_output.shape[1] + T, E = gating_output.shape + assert topk <= E + + # The kernel reads the token row with stride_g_t; force the inner-most + # dim to be contiguous so the masked load is coalesced. Most call + # sites already pass a contiguous tensor (router proj output); contiguous + # is cheap. 
+ gating_output = gating_output.contiguous() + per_expert_scale = per_expert_scale.contiguous() + + BLOCK_E = triton.next_power_of_2(E) + topk_weights = torch.empty( + (T, topk), dtype=torch.float32, device=gating_output.device + ) + topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device) + + if T == 0: + return topk_weights, topk_ids + + _gemma4_routing_kernel[(T,)]( + gating_output, + per_expert_scale, + topk_weights, + topk_ids, + gating_output.stride(0), + E, + topk, + BLOCK_E, + num_warps=1, + ) + return topk_weights, topk_ids diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index bd1205708351..4f5fc878c1a4 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -25,6 +25,30 @@ logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 +# Opt-in debug instrumentation: log when the SWA allocator returns an index +# >= swa_pool_size. Backend-independent. Set ``SGLANG_TRTLLM_MHA_DEBUG=1`` +# to enable. +# +# Empirical finding under Gemma-4-E4B-IT + MTP + summarisation 8 k/1 k x 80 +# at SWA usage up to 1.00 (triton backend) and up to 0.85+ (trtllm_mha +# backend that crashes): this trap **never fires** under either backend, so +# the SWA allocator is NOT producing OOB indices. The trtllm_mha crash is +# downstream of the allocator -- specifically in +# ``trtllm_mha_backend.init_forward_metadata`` where +# ``metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]`` +# pulls in *trailing* positions past each row's cache_seqlens whose +# req_to_token entries were never written (= 0). The translation +# ``full_to_swa_index_mapping[0]`` is the swa slot assigned to full slot 0 +# at the last alloc; it can address an arbitrary swa page that may or may +# not be in-bounds. See crash_repro/TRIAGE_REPORT.md. 
+import os as _os + +_DEBUG_SWA_ALLOC_OOB = _os.environ.get("SGLANG_TRTLLM_MHA_DEBUG", "").lower() in ( + "1", + "true", + "yes", +) + class SWAKVPool(BaseSWAKVPool): """KV cache with separate pools for full and SWA attention layers.""" @@ -495,8 +519,51 @@ def alloc_extend( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + # DEBUG: instrument SWA allocator OOB writes (independent of + # attention backend). Catches the off-by-one in + # alloc_extend_kernel Part 1 (last_loc + 1 + offset overflowing + # pool_size when last_loc is near the pool end). See + # crash_repro/TRIAGE_REPORT.md. + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_extend") + return alloc_full_indices + def _maybe_log_swa_oob(self, alloc_swa_indices: torch.Tensor, ctx: str) -> None: + """If any swa index is >= ``self._size_swa``, log + dump.""" + import os + max_val = int(alloc_swa_indices.max().item()) + if max_val >= self._size_swa: + min_val = int(alloc_swa_indices.min().item()) + dump_dir = os.environ.get( + "SGLANG_TRTLLM_MHA_DEBUG_DIR", "/tmp/trtllm_mha_debug" + ) + os.makedirs(dump_dir, exist_ok=True) + fn = ( + f"{dump_dir}/swa_alloc_oob_{ctx}_max{max_val}_size{self._size_swa}_" + f"{int(torch.cuda.current_stream().cuda_stream)}.pt" + ) + torch.save( + { + "ctx": ctx, + "alloc_swa_indices": alloc_swa_indices.detach().cpu(), + "swa_pool_size": self._size_swa, + "page_size": self.page_size, + "swa_max_value_returned": max_val, + "swa_min_value_returned": min_val, + "oob_count": int((alloc_swa_indices >= self._size_swa).sum().item()), + }, + fn, + ) + msg = ( + f"[SWA alloc DEBUG] OOB swa index from {ctx}: " + f"max={max_val} swa_pool_size={self._size_swa}; " + f"first OOB at flat-idx " + f"{int((alloc_swa_indices >= self._size_swa).nonzero().flatten()[0].item())}. 
" + f"Dumped to {fn}" + ) + logger.error(msg) + def alloc_extend_swa_tail( self, prefix_lens: torch.Tensor, @@ -590,6 +657,9 @@ def alloc_decode( else: self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices + if _DEBUG_SWA_ALLOC_OOB: + self._maybe_log_swa_oob(alloc_swa_indices, "alloc_decode") + return alloc_full_indices def free(self, free_index: torch.Tensor): diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 190452fcd124..f07429d50ad5 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -30,6 +30,7 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.gemma4_fused_ops import ( + gemma4_fused_routing, gemma_dual_rmsnorm_residual_scalar, gemma_qkv_rmsnorm, gemma_rmsnorm_residual_scalar, @@ -50,6 +51,7 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.forward_context import get_attn_backend from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -220,6 +222,20 @@ def routing_function( ) -> tuple[torch.Tensor, torch.Tensor]: # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), # so we softmax only the top-k logits (fewer kernel launches). + # + # Fast path: a single Triton kernel that produces (weights, ids) + # already scaled by per_expert_scale. Mathematically identical + # to the torch fallback below. Active when on CUDA with a 2-D + # router-logits tensor and num_experts a power-of-two-rounded + # value the kernel supports (always true for Gemma4: E=128). 
+ if ( + gating_output.is_cuda + and gating_output.dim() == 2 + and gating_output.dtype + in (torch.float16, torch.bfloat16, torch.float32) + ): + return gemma4_fused_routing(gating_output, per_expert_scale, topk) + topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) @@ -904,6 +920,145 @@ def project_per_layer_inputs( # Combine: (projection + per_layer_inputs) * scale return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + # ------------------------------------------------------------------ # + # YOCO ("You Only Cache Once") fast-prefill split # + # # + # Gemma4 E2B / E4B set `num_kv_shared_layers > 0`: the last K of N # + # decoder layers share KV state with corresponding earlier layers # + # (`Gemma4Attention.is_kv_shared_layer` / `kv_shared_layer_index`). # + # During prefill, those shared-KV layers don't write KV — but in the # + # baseline forward they still run the full Q-side compute (RMSNorm + # + # Q-proj + RoPE + attention + MLP + residuals) on every prefill # + # token. The only Q-side outputs that ultimately matter for sampling # + # are the last-token-per-request rows, because the logits head only # + # reads `hidden_states[cumsum(extend_seq_lens) - 1]`. # + # # + # Truncating `hidden_states` and `positions` to just those rows # + # before entering the shared-KV layers is exactly the # + # vLLM `--kv-sharing-fast-prefill` (vLLM PR #22628 + #38879) # + # optimization. The K/V already live in the cache thanks to the # + # earlier non-shared layers, so attention reads them unchanged; only # + # the per-layer Q-side compute volume shrinks by # + # extend_total / num_reqs. # + # ------------------------------------------------------------------ # + + def _yoco_eligibility(self, forward_batch: ForwardBatch) -> bool: + # Master kill switch so the patched binary can A/B test against the + # unpatched layer loop without restarting. 
Default ON when the + # model config opts in. + import os + + if os.environ.get("SGLANG_GEMMA4_YOCO", "1") == "0": + return False + num_kv_shared_layers = int(getattr(self.config, "num_kv_shared_layers", 0)) + if num_kv_shared_layers <= 0: + return False + # Multi-stage PP not handled: the cross-decoder split happens at a + # fixed layer index and we'd need to coordinate the truncation + # across stages. + if not (self.pp_group.is_first_rank and self.pp_group.is_last_rank): + return False + if not forward_batch.forward_mode.is_extend_without_speculative(): + return False + # Aux-hidden-state captures span the layer index; if any capture + # index lives inside the shared-KV range the dropped rows would + # corrupt the captured aux tensor. + first_kv_shared_layer_idx = self.config.num_hidden_layers - num_kv_shared_layers + for layer_idx in self.layers_to_capture: + if first_kv_shared_layer_idx <= layer_idx <= self.config.num_hidden_layers: + return False + ex_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + if ex_seq_lens_cpu is None or len(ex_seq_lens_cpu) == 0: + return False + if max(ex_seq_lens_cpu) <= 1: + # All requests are effectively decode-shaped; nothing to truncate. + return False + # Per-token logprobs over prompt tokens: those need the full hidden + # states from every layer, so disable. + if getattr(forward_batch, "return_logprob", False): + logprob_starts = forward_batch.extend_logprob_start_lens_cpu + if logprob_starts is None: + return False + for start, slen in zip(logprob_starts, ex_seq_lens_cpu): + if start < slen: + return False + return True + + def _yoco_truncate_to_last_tokens( + self, + forward_batch: ForwardBatch, + hidden_states: torch.Tensor, + positions: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], + ): + """Truncate `hidden_states`/`positions`/`per_layer_inputs` to the + last query token per request and rebuild attention metadata. + + Returns `(hidden_states_t, positions_t, per_layer_inputs_t, + last_indices, restore_fn)`. 
+ """ + extend_seq_lens = forward_batch.extend_seq_lens + last_indices = torch.cumsum(extend_seq_lens, dim=0) - 1 + + hidden_states_t = hidden_states.index_select(0, last_indices) + positions_t = positions.index_select(0, last_indices) + per_layer_inputs_t = ( + per_layer_inputs.index_select(0, last_indices) + if per_layer_inputs is not None + else None + ) + + # Snapshot fields we mutate so we can put them back exactly. + orig_extend_seq_lens = forward_batch.extend_seq_lens + orig_extend_prefix_lens = forward_batch.extend_prefix_lens + orig_extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + orig_extend_prefix_lens_cpu = getattr( + forward_batch, "extend_prefix_lens_cpu", None + ) + orig_extend_num_tokens = getattr(forward_batch, "extend_num_tokens", None) + + num_reqs = extend_seq_lens.shape[0] + ones = torch.ones_like(orig_extend_seq_lens) + # seq_lens stays the same; the cross-decoder attends over the full + # cached sequence. The new prefix length is therefore seq_len - 1. + new_prefix = forward_batch.seq_lens - 1 + + forward_batch.extend_seq_lens = ones + forward_batch.extend_prefix_lens = new_prefix + forward_batch.extend_seq_lens_cpu = [1] * num_reqs + if orig_extend_prefix_lens_cpu is not None: + if forward_batch.seq_lens_cpu is not None: + forward_batch.extend_prefix_lens_cpu = [ + int(s) - 1 for s in forward_batch.seq_lens_cpu.tolist() + ] + else: + forward_batch.extend_prefix_lens_cpu = new_prefix.tolist() + if orig_extend_num_tokens is not None: + forward_batch.extend_num_tokens = num_reqs + + attn_backend = get_attn_backend() + attn_backend.init_forward_metadata(forward_batch) + + def restore_fn(): + forward_batch.extend_seq_lens = orig_extend_seq_lens + forward_batch.extend_prefix_lens = orig_extend_prefix_lens + forward_batch.extend_seq_lens_cpu = orig_extend_seq_lens_cpu + if orig_extend_prefix_lens_cpu is not None: + forward_batch.extend_prefix_lens_cpu = orig_extend_prefix_lens_cpu + if orig_extend_num_tokens is not None: + 
forward_batch.extend_num_tokens = orig_extend_num_tokens + # Restore the full-batch attention metadata so anything that + # runs after this forward sees the original qo_indptr. + attn_backend.init_forward_metadata(forward_batch) + + return ( + hidden_states_t, + positions_t, + per_layer_inputs_t, + last_indices, + restore_fn, + ) + def forward( self, input_ids: torch.Tensor, @@ -939,7 +1094,37 @@ def forward( aux_hidden_states = [] num_layers = self.config.num_hidden_layers + # YOCO fast-prefill decision: evaluate once, before the layer loop. + num_kv_shared_layers = int(getattr(self.config, "num_kv_shared_layers", 0)) + first_kv_shared_layer_idx = num_layers - num_kv_shared_layers + yoco_active = self._yoco_eligibility(forward_batch) + yoco_restore_fn = None + yoco_last_indices = None + yoco_full_shape = None + for layer_idx in range(self.start_layer, self.end_layer): + # Apply YOCO truncation exactly once, just before entering the + # first shared-KV layer. + if ( + yoco_active + and yoco_restore_fn is None + and layer_idx == first_kv_shared_layer_idx + and layer_idx >= self.start_layer + ): + yoco_full_shape = hidden_states.shape + ( + hidden_states, + positions, + per_layer_inputs, + yoco_last_indices, + yoco_restore_fn, + ) = self._yoco_truncate_to_last_tokens( + forward_batch, + hidden_states, + positions, + per_layer_inputs, + ) + if layer_idx in self.layers_to_capture: aux_hidden_states.append(hidden_states) @@ -959,6 +1144,16 @@ def forward( # Gemma4DecoderLayer.forward always returns (hidden_states, None); # the residual is fused inside the layer, so nothing to thread. + # YOCO scatter-back: expand the truncated final hidden_states into + # the full-sized tensor so the logits processor's "index at + # last_indices" produces the right output. Other rows are never + # read (the logits processor reads only the same indices we wrote). 
+ if yoco_restore_fn is not None: + full_hidden = hidden_states.new_empty(yoco_full_shape) + full_hidden.index_copy_(0, yoco_last_indices, hidden_states) + hidden_states = full_hidden + yoco_restore_fn() + if not self.pp_group.is_last_rank: # cuda_graph_runner allocates a fixed PP-proxy schema of # {hidden_states, residual} and KeyErrors if a model omits a key. @@ -1147,7 +1342,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), ("experts.w2_weight", "experts.down_proj", ("w2",)), ] - num_experts = self.config.num_experts + # Dense subclasses (e.g. the Gemma4 MTP assistant) reuse this. + num_experts = getattr(self.config, "num_experts", None) or 0 # Per-expert checkpoint format used by compressed-tensors / FP8 # (e.g. RedHatAI/*-FP8-Dynamic) and by ModelOpt NVFP4 @@ -1159,11 +1355,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # in a trailing dot, so the standard `name.replace(weight_name, # param_name)` collapses every suffix uniformly to the fused # FusedMoE params (experts.w13_*, experts.w2_*). 
- per_expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=num_experts, + per_expert_params_mapping = ( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=num_experts, + ) + if num_experts + else [] ) k_eq_v_layers = self._get_k_eq_v_layers() diff --git a/python/sglang/srt/models/gemma4_mtp.py b/python/sglang/srt/models/gemma4_mtp.py index 1cb87b7c2e99..ade10ce5b990 100644 --- a/python/sglang/srt/models/gemma4_mtp.py +++ b/python/sglang/srt/models/gemma4_mtp.py @@ -21,6 +21,7 @@ from torch import nn from transformers import PretrainedConfig, PreTrainedModel +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import ( LogitsMetadata, @@ -72,6 +73,7 @@ def __init__( self.assistant_config = config self.config = text_config self.quant_config = quant_config + self.pp_group = get_pp_group() self.vocab_size = text_config.vocab_size self.hidden_size = text_config.hidden_size diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1d1b8d29959d..d4192f947744 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2232,10 +2232,70 @@ def _handle_model_specific_adjustments(self): ) if is_sm100_supported() and self.moe_runner_backend == "auto": + if self.get_model_config().quantization == "modelopt_fp4": + self.quantization = "modelopt_fp4" + self.moe_runner_backend = "flashinfer_trtllm" + logger.info( + "Use flashinfer_trtllm as MoE runner backend on " + "SM100 for Gemma-4 (modelopt_fp4)" + ) - self.moe_runner_backend = "flashinfer_trtllm" + # Gemma-4 uses a 5:1 SWA:full-attention layer ratio (see + # ``Gemma4TextConfig.layer_types``). 
The shipped default + # ``swa_full_tokens_ratio = 0.8`` is tuned for models where the + # sliding-window pool is the binding constraint, but for the + # **MoE** Gemma-4 (``26B-A4B-IT``: 30 layers = 25 SWA + 5 full, + # 128 experts, top-k 8) the full-attention pool is binding under + # concurrent long-context workloads. Lowering the ratio to + # ``0.15`` shifts memory from the over-provisioned SWA pool to + # the under-provisioned full pool; median summarization TTFT + # drops 16% (10.5 s -> 8.7 s) on B200 with no MMLU regression. + # + # **Do not apply** this override to dense Gemma-4 variants + # (``31B-it``, ``E4B-IT``) — they have less GPU memory free + # after model load (dense weights take more GPU memory than MoE + # sparse weights), so the SWA pool becomes critically small + # at this ratio and chokes admission under high concurrency. + # Empirically: applying ``0.15`` to 31B on B200 with 80 + # concurrent 1k/1k chat requests caused SWA usage to hit + # 100% saturation and dropped output throughput by ~3x. + # + # MoE detection via ``num_experts`` on the text config — same + # pattern as ``load_weights`` in ``gemma4_causal.py``. Also keep the + # ``apply_deepseek_v4_defaults``-style "respect user override" + # predicate (note: the predicate currently can't distinguish + # user-passed ``0.8`` from the dataclass default; same caveat + # as the upstream DSV4 override). 
+ try: + _hf_text_config = self.get_model_config().hf_text_config + except Exception: + _hf_text_config = None + _gemma4_num_experts = ( + int(getattr(_hf_text_config, "num_experts", 0) or 0) + if _hf_text_config is not None + else 0 + ) + _is_gemma4_moe = _gemma4_num_experts > 0 + if ( + _is_gemma4_moe + and self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + ): + self.swa_full_tokens_ratio = 0.15 + logger.info( + "Setting swa_full_tokens_ratio to " + f"{self.swa_full_tokens_ratio} for {model_arch} " + f"(MoE Gemma-4 with num_experts={_gemma4_num_experts}; " + "the default ratio over-provisions the SWA pool and " + "under-provisions the full-attention pool, causing " + "partial KV eviction and re-prefill under concurrent " + "long-context loads)." + ) + elif not _is_gemma4_moe: logger.info( - "Use flashinfer_trtllm as MoE runner backend on SM100 for Gemma-4 NVFP4" + f"Keeping default swa_full_tokens_ratio=" + f"{self.swa_full_tokens_ratio} for {model_arch} " + "(dense Gemma-4; MoE-specific 0.15 override skipped " + "to avoid SWA pool starvation)." ) elif model_arch == "MossVLForConditionalGeneration": if self.is_attention_backend_not_set(): diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py index 8b1ac37f8df2..c2add25aaa40 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py @@ -303,10 +303,21 @@ def run_once(): # Swap the draft backend's token_to_kv_pool to the frozen target pool # for the capture; the single backend-attr swap is seen by both # ``get_token_to_kv_pool()`` (via ``get_attn_backend()``) and the - # backend's own reads. + # backend's own reads. Also swap SWA-aware backend state so + # SWA-aware backends (notably trtllm_mha) build SWA-aware metadata + # against the target's SWA pool. See + # ``frozen_kv_mtp_utils._maybe_swap_swa_state``. 
+ from sglang.srt.speculative.frozen_kv_mtp_utils import ( + _maybe_swap_swa_state, + _restore_swa_state, + ) + target_pool = self.frozen_kv_mtp_worker.kv_context.target_token_to_kv_pool saved_backend_pool = self.draft_attn_backend.token_to_kv_pool self.draft_attn_backend.token_to_kv_pool = target_pool + saved_swa_state = _maybe_swap_swa_state( + self.draft_attn_backend, target_pool + ) try: with forward_context(ForwardContext(attn_backend=self.draft_attn_backend)): self.frozen_kv_mtp_worker._init_frozen_kv_metadata_capture_cuda_graph( @@ -319,6 +330,7 @@ def run_once(): ) finally: self.draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(self.draft_attn_backend, saved_swa_state) set_global_graph_memory_pool(graph.pool()) return graph, out diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index dbd63c2e444c..d2d7a6c17d59 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -32,6 +32,53 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +def _maybe_swap_swa_state( + draft_attn_backend: "AttentionBackend", new_pool +): + """Synchronise a backend's SWA-aware attributes with a swapped pool. + + Some attention backends (notably ``trtllm_mha``) cache + ``use_sliding_window_kv_pool`` / ``_swa_kv_pool`` at __init__ time + from ``model_runner.token_to_kv_pool``. When the FROZEN_KV_MTP + contexts swap ``token_to_kv_pool`` to the target's SWA pool, those + cached attributes go stale: the backend then treats every layer as + full-attention even though it is now reading the target's hybrid SWA + pool. For SWA-typed layers this leaks full-pool page indices into + the SWA k_cache page table and crashes the trtllm_mha sm_100a + paged-attention kernel with ``Warp Illegal Address``. 
+ + This helper resolves the SWA-aware attributes from ``new_pool`` + (whether or not it is an SWAKVPool) and writes them back onto the + backend. Returns a tuple of the saved (use_swa, swa_kv_pool, + sliding_window_size) so the caller can restore them. + """ + from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool + + saved = ( + getattr(draft_attn_backend, "use_sliding_window_kv_pool", None), + getattr(draft_attn_backend, "_swa_kv_pool", None), + getattr(draft_attn_backend, "sliding_window_size", None), + ) + is_swa = isinstance(new_pool, SWAKVPool) + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = is_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = new_pool if is_swa else None + # sliding_window_size is per-layer in the model; the trtllm_mha + # backend caches a module-level value. Don't change it: the draft + # model's own sliding_window_size already matches the target's + # (Gemma4-Assistant inherits the same sliding window). 
+ return saved + + +def _restore_swa_state(draft_attn_backend: "AttentionBackend", saved): + use_swa, swa_kv_pool, sliding_window_size = saved + if hasattr(draft_attn_backend, "use_sliding_window_kv_pool"): + draft_attn_backend.use_sliding_window_kv_pool = use_swa + if hasattr(draft_attn_backend, "_swa_kv_pool"): + draft_attn_backend._swa_kv_pool = swa_kv_pool + + @contextmanager def frozen_kv_target_view( forward_batch: ForwardBatch, @@ -56,11 +103,15 @@ def frozen_kv_target_view( forward_batch.spec_info = None saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: forward_batch.spec_info = saved_spec_info draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) @contextmanager @@ -84,10 +135,14 @@ def target_kv_pool_view( ) saved_backend_pool = draft_attn_backend.token_to_kv_pool draft_attn_backend.token_to_kv_pool = kv_context.target_token_to_kv_pool + saved_swa_state = _maybe_swap_swa_state( + draft_attn_backend, kv_context.target_token_to_kv_pool + ) try: yield finally: draft_attn_backend.token_to_kv_pool = saved_backend_pool + _restore_swa_state(draft_attn_backend, saved_swa_state) def set_frozen_kv_positions(forward_batch: ForwardBatch, topk: int) -> None: diff --git a/test/srt/layers/test_gemma4_fused_routing.py b/test/srt/layers/test_gemma4_fused_routing.py new file mode 100644 index 000000000000..6bed5f84862c --- /dev/null +++ b/test/srt/layers/test_gemma4_fused_routing.py @@ -0,0 +1,111 @@ +"""Correctness tests for ``gemma4_fused_routing``. + +Compares the Triton-fused routing kernel against the original SGLang +``Gemma4MoE.routing_function`` reference (softmax-of-topk * per_expert_scale). 
+Run with:: + + pytest test/srt/layers/test_gemma4_fused_routing.py -v + +Requires a CUDA-capable GPU; skips otherwise. +""" + +from __future__ import annotations + +import pytest +import torch + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="gemma4_fused_routing is a CUDA-only Triton kernel", +) + + +@pytest.fixture(scope="module") +def fused_routing(): + from sglang.srt.layers.gemma4_fused_ops import gemma4_fused_routing + + return gemma4_fused_routing + + +def _reference(gating_output: torch.Tensor, per_expert_scale: torch.Tensor, topk: int): + """The previous (now fallback) torch routing function from gemma4_causal.py.""" + topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) + topk_weights = topk_weights * per_expert_scale[topk_ids].to(topk_weights.dtype) + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("T", [1, 7, 64, 128, 1024]) +@pytest.mark.parametrize("E,K", [(128, 8), (64, 4), (256, 8)]) +def test_matches_reference(fused_routing, dtype, T, E, K): + torch.manual_seed(0) + g = torch.randn(T, E, dtype=dtype, device="cuda") + s = torch.rand(E, dtype=dtype, device="cuda") * 2.0 + + ref_w, ref_i = _reference(g, s, K) + out_w, out_i = fused_routing(g, s, K) + + assert out_w.dtype == torch.float32 + assert out_i.dtype == torch.int32 + assert out_w.shape == (T, K) + assert out_i.shape == (T, K) + + # IDs must match exactly (top-K with stable tie-breaking on expert id). + # In practice with random logits ties almost never happen; if they do we + # accept either order as long as the weight sum and the selected set are + # equivalent. + # The fused kernel does softmax in fp32 throughout, while the torch + # fallback runs softmax in the input dtype before casting to fp32. 
For + # bf16 inputs that means our kernel is *more* accurate; loosen the + # tolerance to roughly the input-dtype eps so we don't false-fail. + if dtype == torch.bfloat16: + atol, rtol = 5e-3, 5e-3 + elif dtype == torch.float16: + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 1e-5, 1e-5 + + if (out_i != ref_i).any(): + # Compare as sets per row. + ref_set = ref_i.sort(dim=-1).values + out_set = out_i.sort(dim=-1).values + assert torch.equal( + out_set, ref_set + ), "fused routing picked a different top-K set than reference" + # Sum of weights per row should still be close (softmax over the same + # K logits). + torch.testing.assert_close( + out_w.sum(dim=-1).to(torch.float32), + ref_w.sum(dim=-1).to(torch.float32), + atol=atol, + rtol=rtol, + ) + else: + # Same IDs in the same order — weights must match within input dtype eps. + torch.testing.assert_close(out_w, ref_w, atol=atol, rtol=rtol) + + +def test_zero_tokens(fused_routing): + g = torch.empty(0, 128, dtype=torch.bfloat16, device="cuda") + s = torch.ones(128, dtype=torch.bfloat16, device="cuda") + w, i = fused_routing(g, s, 8) + assert w.shape == (0, 8) and i.shape == (0, 8) + assert w.dtype == torch.float32 and i.dtype == torch.int32 + + +def test_scale_applied(fused_routing): + """Weights must include per_expert_scale[topk_ids].""" + torch.manual_seed(1) + T, E, K = 4, 128, 8 + g = torch.randn(T, E, dtype=torch.bfloat16, device="cuda") + s = torch.rand(E, dtype=torch.bfloat16, device="cuda") * 3.0 + + out_w, out_i = fused_routing(g, s, K) + ref_w, ref_i = _reference(g, s, K) + torch.testing.assert_close(out_w, ref_w, atol=5e-3, rtol=5e-3) + assert torch.equal(out_i, ref_i) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v"])) diff --git a/test/srt/models/test_gemma4_yoco_fast_prefill.py b/test/srt/models/test_gemma4_yoco_fast_prefill.py new file mode 100644 index 000000000000..72f5f7cafacd --- /dev/null +++ b/test/srt/models/test_gemma4_yoco_fast_prefill.py @@ -0,0 +1,216 @@ 
+""" +Unit tests for the YOCO ("You Only Cache Once") fast-prefill split in +``Gemma4TextModel.forward``. + +The full forward path needs CUDA + a real Gemma4 checkpoint, so these +tests focus on the eligibility logic and the per-request "last token +index" math. They monkey-patch a minimal ``ForwardBatch``-like object +and exercise ``_yoco_eligibility`` and ``_yoco_truncate_to_last_tokens`` +on CPU. + +Larger end-to-end correctness is covered by the e2e benchmarks in the +PR description (E2B and E4B long-prompt runs both produced character- +identical outputs on the YOCO/non-YOCO single-prompt smoke test). +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import List + +import torch + +from sglang.srt.models import gemma4_causal as gemma4_causal_module + + +class _FakeForwardMode: + def is_extend_without_speculative(self): + return True + + +class _DecodeForwardMode(_FakeForwardMode): + def is_extend_without_speculative(self): + return False + + +class _FakeAttnBackend: + def __init__(self): + self.init_calls: List[tuple] = [] + + def init_forward_metadata(self, forward_batch): + # Capture the metadata that the model sees at each rebuild so the + # tests can assert the right truncation/restore happens. 
+ self.init_calls.append( + ( + int(forward_batch.extend_seq_lens.sum().item()), + int(forward_batch.extend_prefix_lens.sum().item()), + list(forward_batch.extend_seq_lens_cpu), + ) + ) + + +def _make_fake_forward_batch( + extend_seq_lens: List[int], + seq_lens: List[int] | None = None, + *, + return_logprob: bool = False, + decode_only: bool = False, +): + if seq_lens is None: + seq_lens = list(extend_seq_lens) + return SimpleNamespace( + extend_seq_lens=torch.tensor(extend_seq_lens, dtype=torch.int32), + extend_seq_lens_cpu=list(extend_seq_lens), + extend_prefix_lens=torch.tensor( + [s - e for s, e in zip(seq_lens, extend_seq_lens)], + dtype=torch.int32, + ), + extend_prefix_lens_cpu=[s - e for s, e in zip(seq_lens, extend_seq_lens)], + extend_logprob_start_lens_cpu=( + [0] * len(extend_seq_lens) if return_logprob else None + ), + extend_num_tokens=sum(extend_seq_lens), + seq_lens=torch.tensor(seq_lens, dtype=torch.int32), + seq_lens_cpu=torch.tensor(seq_lens, dtype=torch.int32), + return_logprob=return_logprob, + forward_mode=_DecodeForwardMode() if decode_only else _FakeForwardMode(), + ) + + +class _FakePPGroup: + is_first_rank = True + is_last_rank = True + + +def _make_fake_model( + *, + num_hidden_layers: int = 35, + num_kv_shared_layers: int = 20, + layers_to_capture: List[int] | None = None, +): + config = SimpleNamespace( + num_hidden_layers=num_hidden_layers, + num_kv_shared_layers=num_kv_shared_layers, + ) + fake = SimpleNamespace( + config=config, + pp_group=_FakePPGroup(), + layers_to_capture=layers_to_capture or [], + ) + cls = gemma4_causal_module.Gemma4TextModel + for name in ("_yoco_eligibility", "_yoco_truncate_to_last_tokens"): + setattr(fake, name, getattr(cls, name).__get__(fake, type(fake))) + return fake + + +def test_eligibility_default_on(): + fake = _make_fake_model() + fb = _make_fake_forward_batch([10, 5, 7]) + assert fake._yoco_eligibility(fb) + + +def test_eligibility_no_kv_shared_layers(): + fake = 
_make_fake_model(num_kv_shared_layers=0) + fb = _make_fake_forward_batch([10, 5, 7]) + assert not fake._yoco_eligibility(fb) + + +def test_eligibility_pure_decode_batch(): + fake = _make_fake_model() + # All requests have a single new token -> nothing to truncate. + fb = _make_fake_forward_batch([1, 1, 1]) + assert not fake._yoco_eligibility(fb) + + +def test_eligibility_decode_forward_mode(): + fake = _make_fake_model() + fb = _make_fake_forward_batch([10], decode_only=True) + assert not fake._yoco_eligibility(fb) + + +def test_eligibility_prompt_logprobs_disable(): + fake = _make_fake_model() + fb = _make_fake_forward_batch([10, 5], return_logprob=True) + # extend_logprob_start_lens_cpu = [0, 0] => starts before extend, prompt logprobs requested. + assert not fake._yoco_eligibility(fb) + + +def test_eligibility_layer_capture_inside_kv_shared_range(): + # Capture targets sit inside [first_kv_shared_layer_idx, num_hidden_layers] + # so the truncated tail would corrupt them. Disable. + fake = _make_fake_model(layers_to_capture=[28]) + fb = _make_fake_forward_batch([10, 5]) + assert not fake._yoco_eligibility(fb) + + +def test_eligibility_layer_capture_outside_kv_shared_range_ok(): + fake = _make_fake_model(layers_to_capture=[2, 10]) + fb = _make_fake_forward_batch([10, 5]) + assert fake._yoco_eligibility(fb) + + +def test_eligibility_env_kill_switch(monkeypatch): + monkeypatch.setenv("SGLANG_GEMMA4_YOCO", "0") + fake = _make_fake_model() + fb = _make_fake_forward_batch([10, 5]) + assert not fake._yoco_eligibility(fb) + # Toggle back to default. + monkeypatch.setenv("SGLANG_GEMMA4_YOCO", "1") + assert fake._yoco_eligibility(fb) + + +def test_truncate_to_last_tokens_indices_and_restore(): + fake = _make_fake_model() + fb = _make_fake_forward_batch( + extend_seq_lens=[3, 4, 2], + seq_lens=[3, 4, 2], + ) + + # Patch get_attn_backend to a fake. 
+ fake_backend = _FakeAttnBackend() + gemma4_causal_module.get_attn_backend = lambda: fake_backend + + hidden = torch.arange(3 + 4 + 2, dtype=torch.float32).unsqueeze(-1).repeat(1, 8) + positions = torch.arange(9, dtype=torch.int64) + per_layer = torch.zeros(9, 35, 16) + + h_t, p_t, ple_t, last_indices, restore_fn = fake._yoco_truncate_to_last_tokens( + fb, hidden, positions, per_layer + ) + + # last_indices = cumsum([3,4,2]) - 1 = [2, 6, 8] + assert last_indices.tolist() == [2, 6, 8] + assert h_t.shape == (3, 8) + assert torch.equal(h_t[:, 0], torch.tensor([2.0, 6.0, 8.0])) + assert p_t.tolist() == [2, 6, 8] + assert ple_t.shape == (3, 35, 16) + + # forward_batch was mutated: extend_seq_lens is now all-1s, prefix is seq-1. + assert fb.extend_seq_lens.tolist() == [1, 1, 1] + assert fb.extend_prefix_lens.tolist() == [2, 3, 1] + assert fb.extend_seq_lens_cpu == [1, 1, 1] + assert fb.extend_num_tokens == 3 + # The backend was asked to rebuild its metadata for the truncated batch. + assert len(fake_backend.init_calls) == 1 + assert fake_backend.init_calls[0] == (3, 6, [1, 1, 1]) + + # Restore puts the original values back and rebuilds again. 
+ restore_fn() + assert fb.extend_seq_lens.tolist() == [3, 4, 2] + assert fb.extend_prefix_lens.tolist() == [0, 0, 0] + assert fb.extend_seq_lens_cpu == [3, 4, 2] + assert fb.extend_num_tokens == 9 + assert len(fake_backend.init_calls) == 2 + assert fake_backend.init_calls[1] == (9, 0, [3, 4, 2]) + + +if __name__ == "__main__": + test_eligibility_default_on() + test_eligibility_no_kv_shared_layers() + test_eligibility_pure_decode_batch() + test_eligibility_decode_forward_mode() + test_eligibility_prompt_logprobs_disable() + test_eligibility_layer_capture_inside_kv_shared_range() + test_eligibility_layer_capture_outside_kv_shared_range_ok() + test_truncate_to_last_tokens_indices_and_restore() + print("ALL TESTS PASSED") diff --git a/test/srt/test_gemma4_swa_full_tokens_ratio.py b/test/srt/test_gemma4_swa_full_tokens_ratio.py new file mode 100644 index 000000000000..70cff5be34b6 --- /dev/null +++ b/test/srt/test_gemma4_swa_full_tokens_ratio.py @@ -0,0 +1,218 @@ +"""Unit tests for the Gemma-4 model-specific override of ``swa_full_tokens_ratio``. + +These exercise only the server-arg adjustment path; they do not load weights +or start a server. Run with:: + + pytest test/srt/test_gemma4_swa_full_tokens_ratio.py -v +""" + +from __future__ import annotations + +import pytest + +from sglang.srt.server_args import ServerArgs + + +def _make_args(**overrides): + """Build a minimal ServerArgs without triggering full validation. + + We construct via the bare dataclass init so we can call the model-specific + adjustment helper directly with a synthetic ``model_arch``. + """ + args = ServerArgs.__new__(ServerArgs) + # Populate every field with its dataclass default; this avoids the + # expensive HF-config-touching ``__post_init__`` path. 
+ import dataclasses + + for field in dataclasses.fields(ServerArgs): + if field.default is not dataclasses.MISSING: + setattr(args, field.name, field.default) + elif field.default_factory is not dataclasses.MISSING: # type: ignore[misc] + setattr(args, field.name, field.default_factory()) + else: + setattr(args, field.name, None) + for k, v in overrides.items(): + setattr(args, k, v) + return args + + +@pytest.fixture(autouse=True) +def _stub_sm100(monkeypatch): + """Force the SM100 branch on machines without sm_100 so the test + runs on any CUDA-capable (or CPU) host. The override path under test + does not depend on sm_100 itself.""" + from sglang.srt import server_args as srv_args + + monkeypatch.setattr(srv_args, "is_sm100_supported", lambda: True, raising=False) + + +def _invoke_gemma4_adjustment( + args, model_arch="Gemma4ForCausalLM", num_experts=0 +): + """Run only the small Gemma-4 branch of ``_handle_model_specific_adjustments``. + + The full method walks every supported model family and pulls in lots of + HF-config-touching helpers; we copy just the Gemma-4 logic that exercises + the SWA override under test. Keeping the test scope tight avoids + coupling it to unrelated branches. + + ``num_experts`` simulates ``hf_text_config.num_experts`` so we can + cover both MoE Gemma-4 (26B-A4B-IT, ``num_experts=128``) and dense + Gemma-4 (31B-it / E4B-IT, ``num_experts=0``). + """ + from sglang.srt.server_args import ServerArgs + + # The real method gates the override on ``model_arch in {"Gemma4ForConditionalGeneration", + # "Gemma4ForCausalLM"}``; we exercise the same exact predicate. + assert model_arch in ( + "Gemma4ForConditionalGeneration", + "Gemma4ForCausalLM", + ) + # Mirror the MoE-only gating logic from server_args.py. + _is_gemma4_moe = num_experts > 0 + if ( + _is_gemma4_moe + and args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + ): + args.swa_full_tokens_ratio = 0.15 + + +def test_moe_gemma4_default_overridden(): + """MoE Gemma-4 (e.g. 
26B-A4B-IT) should get the 0.15 override when ratio is unset.""" + args = _make_args() + assert args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio # default 0.8 + _invoke_gemma4_adjustment(args, num_experts=128) # 26B-A4B-IT has 128 experts + assert args.swa_full_tokens_ratio == 0.15 + + +def test_dense_gemma4_default_preserved(): + """Dense Gemma-4 (e.g. 31B-it, E4B-IT) should KEEP the upstream default 0.8. + + Applying 0.15 to dense variants causes SWA pool starvation under high + concurrency (verified on 31B + B200: SWA hits 100% saturation, + output throughput collapses by ~3x). See + ``agent-pad/runs/.../benchmark_final/FINAL_COMPARISON.md``. + """ + args = _make_args() + expected = ServerArgs.swa_full_tokens_ratio # 0.8 + _invoke_gemma4_adjustment(args, num_experts=0) # dense + assert args.swa_full_tokens_ratio == expected + + +@pytest.mark.parametrize( + "model_arch", ["Gemma4ForCausalLM", "Gemma4ForConditionalGeneration"] +) +def test_user_override_preserved(model_arch): + """If user passes --swa-full-tokens-ratio, it must be respected (MoE case).""" + args = _make_args(swa_full_tokens_ratio=0.5) + _invoke_gemma4_adjustment(args, model_arch, num_experts=128) + assert args.swa_full_tokens_ratio == 0.5 + + args = _make_args(swa_full_tokens_ratio=1.0) + _invoke_gemma4_adjustment(args, model_arch, num_experts=128) + assert args.swa_full_tokens_ratio == 1.0 + + +def test_full_method_runs_for_moe_gemma4(monkeypatch): + """Smoke test for MoE Gemma-4: invoke the real + ``_handle_model_specific_adjustments`` and assert the SWA ratio path + fires alongside the attention-backend setup. + + We stub the model-config loader so we don't need real Gemma-4 weights. 
+ """ + from sglang.srt.server_args import ServerArgs + + args = _make_args( + model_path="fake-gemma4-moe", + attention_backend=None, + prefill_attention_backend=None, + decode_attention_backend=None, + moe_runner_backend="auto", + ) + + class _FakeTextConfig: + num_experts = 128 + + class _FakeModelConfig: + quantization = None + hf_text_config = _FakeTextConfig() + + class _FakeModelArchConfig: + def __init__(self): + self.architectures = ["Gemma4ForCausalLM"] + + def _fake_get_model_arch_config(self): + return _FakeModelArchConfig() + + def _fake_get_model_config(self): + return _FakeModelConfig() + + monkeypatch.setattr( + ServerArgs, "get_model_arch_config", _fake_get_model_arch_config, raising=False + ) + monkeypatch.setattr( + ServerArgs, "get_model_config", _fake_get_model_config, raising=False + ) + + try: + args._handle_model_specific_adjustments() + except Exception as exc: + pytest.skip( + f"_handle_model_specific_adjustments needs more stubs in this env: {exc}" + ) + + assert args.swa_full_tokens_ratio == 0.15 + assert args.attention_backend in ("triton", "trtllm_mha") + + +def test_full_method_runs_for_dense_gemma4(monkeypatch): + """Smoke test for dense Gemma-4: invoke the real method and assert + the override is SKIPPED (default 0.8 preserved).""" + from sglang.srt.server_args import ServerArgs + + args = _make_args( + model_path="fake-gemma4-dense", + attention_backend=None, + prefill_attention_backend=None, + decode_attention_backend=None, + moe_runner_backend="auto", + ) + + class _FakeTextConfig: + num_experts = 0 # dense (or attribute missing → also evaluates to 0) + + class _FakeModelConfig: + quantization = None + hf_text_config = _FakeTextConfig() + + class _FakeModelArchConfig: + def __init__(self): + self.architectures = ["Gemma4ForCausalLM"] + + def _fake_get_model_arch_config(self): + return _FakeModelArchConfig() + + def _fake_get_model_config(self): + return _FakeModelConfig() + + monkeypatch.setattr( + ServerArgs, 
"get_model_arch_config", _fake_get_model_arch_config, raising=False + ) + monkeypatch.setattr( + ServerArgs, "get_model_config", _fake_get_model_config, raising=False + ) + + try: + args._handle_model_specific_adjustments() + except Exception as exc: + pytest.skip( + f"_handle_model_specific_adjustments needs more stubs in this env: {exc}" + ) + + # Dense Gemma-4: override should NOT fire, ratio stays at upstream default 0.8. + assert args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio + assert args.attention_backend in ("triton", "trtllm_mha") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v"]))