diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index ad6f01d9875a..037846feafa0 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -4,6 +4,7 @@ a single kernel pass to reduce kernel launch overhead. """ +from enum import Enum from typing import Optional import torch @@ -11,6 +12,21 @@ import triton.language as tl +class ProjAndNormMode(Enum): + """Projection + RMSNorm layout for a Gemma4 attention layer. + + Q_ONLY KV-sharing layer; only Q is projected and normalised. + QK_ONLY attention_k_eq_v layer; Q and a shared K/V are projected, + the fused norm derives K and V from one K projection. + QKV_FULL Standard layer; Q, K, V are projected and normalised + independently. + """ + + Q_ONLY = "q" + QK_ONLY = "qk" + QKV_FULL = "qkv" + + @triton.jit def _gemma_rmsnorm_residual_kernel( X_ptr, @@ -135,27 +151,39 @@ def _gemma_dual_rmsnorm_residual_kernel( @triton.jit def _gemma_qkv_rmsnorm_kernel( Q_ptr, - K_ptr, - V_ptr, + K_in_ptr, + V_in_ptr, + K_out_ptr, + V_out_ptr, Q_w_ptr, K_w_ptr, stride_q_m, - stride_k_m, - stride_v_m, + stride_kin_m, + stride_vin_m, + stride_kout_m, + stride_vout_m, NUM_Q_HEADS: tl.constexpr, NUM_KV_HEADS: tl.constexpr, HEAD_DIM: tl.constexpr, eps, HAS_KV: tl.constexpr, + K_EQ_V: tl.constexpr, BLOCK: tl.constexpr, ): """Per-token fused RMSNorm of Q (with q_w), K (with k_w), V (no scale). - Layout assumption: each tensor's last dim packs (num_heads, head_dim) contiguously - so per-head offset is `h * HEAD_DIM`. The token (M) stride is taken from - stride_*_m so the kernel works on strided views (e.g. slices of a larger - qkv buffer produced by `qkv.split`) without requiring `.contiguous()` copies. - V uses `weight=ones` semantics so the multiply-by-weight is omitted. + Three modes, selected via the ``HAS_KV`` / ``K_EQ_V`` constexpr toggles: + + * **Q-only** (``HAS_KV=False``): normalises Q in-place from ``Q_ptr``. + K/V pointers are unused. + * **QKV** (``HAS_KV=True, K_EQ_V=False``): normalises Q, K, V in-place. + ``K_in_ptr == K_out_ptr`` and ``V_in_ptr == V_out_ptr`` (the launcher + passes the same tensor for input and output). + * **K=V (a.k.a. ``attention_k_eq_v``)** (``HAS_KV=True, K_EQ_V=True``): + normalises Q in-place. ``K_in_ptr`` is the shared raw K/V projection; + ``V_in_ptr`` is unused. ``K_out_ptr`` receives ``norm(KV) * k_weight`` + and ``V_out_ptr`` receives ``norm(KV)``. One rrms per (token, head) is + shared between K and V. """ m = tl.program_id(0) cols = tl.arange(0, BLOCK) @@ -163,7 +191,7 @@ def _gemma_qkv_rmsnorm_kernel( qw = tl.load(Q_w_ptr + cols, mask=mask, other=0.0).to(tl.float32) - # Q heads + # Q heads — in-place for h in tl.static_range(NUM_Q_HEADS): off = m * stride_q_m + h * HEAD_DIM + cols x = tl.load(Q_ptr + off, mask=mask, other=0.0).to(tl.float32) @@ -174,21 +202,42 @@ def _gemma_qkv_rmsnorm_kernel( if HAS_KV: kw = tl.load(K_w_ptr + cols, mask=mask, other=0.0).to(tl.float32) - # K heads - for h in tl.static_range(NUM_KV_HEADS): - off = m * stride_k_m + h * HEAD_DIM + cols - x = tl.load(K_ptr + off, mask=mask, other=0.0).to(tl.float32) - rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps) - out = x * rrms * kw - tl.store(K_ptr + off, out.to(K_ptr.dtype.element_ty), mask=mask) - - # V heads (no scaling: V-norm uses weight=ones) - for h in tl.static_range(NUM_KV_HEADS): - off = m * stride_v_m + h * HEAD_DIM + cols - x = tl.load(V_ptr + off, mask=mask, other=0.0).to(tl.float32) - rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps) - out = x * rrms - tl.store(V_ptr + off, out.to(V_ptr.dtype.element_ty), mask=mask) + if K_EQ_V: + # Shared KV input: one read of KV per head, two writes. + for h in tl.static_range(NUM_KV_HEADS): + in_off = m * stride_kin_m + h * HEAD_DIM + cols + x = tl.load(K_in_ptr + in_off, mask=mask, other=0.0).to(tl.float32) + rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps) + v_out = x * rrms + k_out = v_out * kw + k_off = m * stride_kout_m + h * HEAD_DIM + cols + v_off = m * stride_vout_m + h * HEAD_DIM + cols + tl.store( + K_out_ptr + k_off, + k_out.to(K_out_ptr.dtype.element_ty), + mask=mask, + ) + tl.store( + V_out_ptr + v_off, + v_out.to(V_out_ptr.dtype.element_ty), + mask=mask, + ) + else: + # Separate K and V inputs, normalised in-place. + for h in tl.static_range(NUM_KV_HEADS): + off = m * stride_kin_m + h * HEAD_DIM + cols + x = tl.load(K_in_ptr + off, mask=mask, other=0.0).to(tl.float32) + rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps) + out = x * rrms * kw + tl.store(K_in_ptr + off, out.to(K_in_ptr.dtype.element_ty), mask=mask) + + # V heads (no scaling: V-norm uses weight=ones) + for h in tl.static_range(NUM_KV_HEADS): + off = m * stride_vin_m + h * HEAD_DIM + cols + x = tl.load(V_in_ptr + off, mask=mask, other=0.0).to(tl.float32) + rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps) + out = x * rrms + tl.store(V_in_ptr + off, out.to(V_in_ptr.dtype.element_ty), mask=mask) def gemma_qkv_rmsnorm( @@ -201,19 +250,20 @@ def gemma_qkv_rmsnorm( num_kv_heads: int, head_dim: int, eps: float = 1e-6, -) -> None: - """In-place fused RMSNorm on Q, K, V for Gemma4 attention. - - All three norms compute `x * rsqrt(mean(x^2) + eps)` independently per head. - Q is scaled by `q_weight`, K by `k_weight`, V by 1 (Gemma4's V-norm has - `with_scale=False`). - - Inputs may be 2D `(M, num_heads * head_dim)` or strided views of a larger - buffer (such as q/k/v slices from `qkv.split`). The kernel uses the actual - `stride(0)` so no `.contiguous()` copy is required. Within a token, the - last dim must be contiguous so heads pack as `h * head_dim` offsets. - - If k and v are both None (KV-shared layer), only Q is normalized. + *, + mode: ProjAndNormMode = ProjAndNormMode.QKV_FULL, +) -> Optional[tuple[torch.Tensor, torch.Tensor]]: + """Fused per-head RMSNorm on Q, K, V (or any subset) for Gemma4. + + Q is scaled by q_weight, K by k_weight, V by 1 (Gemma4 V-norm uses + with_scale=False). The caller picks the layout via ``mode``: + + Q_ONLY k=None, v=None. Q normalised in place. Returns None. + QKV_FULL k and v non-None. Q, K, V normalised in place. Returns + None. + QK_ONLY k is the shared K/V projection, v=None. Q normalised in + place; fresh K and V tensors are allocated. Returns + (k_out, v_out). """ assert q.is_cuda assert q.stride(-1) == 1, "Q's last dim must be contiguous" @@ -221,29 +271,77 @@ def gemma_qkv_rmsnorm( M = q.shape[0] if q.dim() >= 2 else 1 BLOCK = triton.next_power_of_2(head_dim) - has_kv = k is not None and v is not None - if has_kv: + # Resolve the mode + allocate outputs if needed. + if mode is ProjAndNormMode.QK_ONLY: + assert ( + k is not None and v is None + ), "QK_ONLY expects k=, v=None" + assert k.is_cuda and k.stride(-1) == 1 + assert k_weight is not None and k_weight.shape[-1] == head_dim + assert ( + q.shape[0] == k.shape[0] + ), f"M mismatch: q.shape[0]={q.shape[0]} vs kv.shape[0]={k.shape[0]}" + has_kv = True + k_eq_v = True + k_in = k + v_in = q # unused; just need a valid pointer for triton. + k_out = torch.empty_like(k) + v_out = torch.empty_like(k) + stride_kin_m = k.stride(0) + stride_vin_m = 0 + stride_kout_m = k_out.stride(0) + stride_vout_m = v_out.stride(0) + elif mode is ProjAndNormMode.QKV_FULL: + assert k is not None and v is not None, "QKV_FULL expects non-None k and v" assert k.is_cuda and v.is_cuda assert k.stride(-1) == 1 and v.stride(-1) == 1 assert k_weight is not None and k_weight.shape[-1] == head_dim + has_kv = True + k_eq_v = False + k_in = k + v_in = v + # In-place: outputs == inputs. + k_out = k + v_out = v + stride_kin_m = k.stride(0) + stride_vin_m = v.stride(0) + stride_kout_m = k.stride(0) + stride_vout_m = v.stride(0) + else: + assert mode is ProjAndNormMode.Q_ONLY + assert k is None and v is None, "Q_ONLY requires both k and v to be None" + has_kv = False + k_eq_v = False + # Unused pointers; pass q for safety. + k_in = v_in = k_out = v_out = q + stride_kin_m = stride_vin_m = stride_kout_m = stride_vout_m = 0 _gemma_qkv_rmsnorm_kernel[(M,)]( q, - k if has_kv else q, - v if has_kv else q, + k_in, + v_in, + k_out, + v_out, q_weight, k_weight if has_kv else q_weight, q.stride(0), - k.stride(0) if has_kv else 0, - v.stride(0) if has_kv else 0, + stride_kin_m, + stride_vin_m, + stride_kout_m, + stride_vout_m, NUM_Q_HEADS=num_q_heads, NUM_KV_HEADS=num_kv_heads if has_kv else 0, HEAD_DIM=head_dim, eps=eps, HAS_KV=has_kv, + K_EQ_V=k_eq_v, BLOCK=BLOCK, ) + if mode is ProjAndNormMode.QK_ONLY: + return k_out, v_out + return None + def gemma_dual_rmsnorm_residual_scalar( x1: torch.Tensor, diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index c406f12a2b6c..c0951ba32b78 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -30,12 +30,15 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.gemma4_fused_ops import ( + ProjAndNormMode, gemma_dual_rmsnorm_residual_scalar, gemma_qkv_rmsnorm, gemma_rmsnorm_residual_scalar, ) from sglang.srt.layers.layernorm import Gemma4RMSNorm, RMSNorm from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, @@ -124,6 +127,32 @@ def pp_filter_load_weight( return False +def get_k_eq_v_layers(text_config) -> Set[int]: + """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" + if not getattr(text_config, "attention_k_eq_v", False): + return set() + return {i for i, lt in enumerate(text_config.layer_types) if lt == "full_attention"} + + +def _resolve_proj_mode( + layer_id: int, + *, + k_eq_v_layers: Set[int], + kv_shared_layers: Set[int], +) -> ProjAndNormMode: + """Pick the projection + norm layout for one decoder layer. + + Q-only wins over QK-only when a layer appears in both sets (KV + sharing borrows another layer's K/V cache, so the K projection + would be dead weight). + """ + if layer_id in kv_shared_layers: + return ProjAndNormMode.Q_ONLY + if layer_id in k_eq_v_layers: + return ProjAndNormMode.QK_ONLY + return ProjAndNormMode.QKV_FULL + + class Gemma4Router(nn.Module): """Router for Gemma4 MoE that preprocesses input before projection. @@ -269,6 +298,7 @@ def __init__( max_position_embeddings: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + proj_mode: ProjAndNormMode = ProjAndNormMode.QKV_FULL, ) -> None: super().__init__() @@ -305,15 +335,48 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=config.attention_bias, - quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), - ) + # Single source of truth for this layer's projection + norm + # layout. See ProjAndNormMode. + self.proj_mode = proj_mode + + # Build exactly the projection layer this layout needs. + self.q_proj = None + self.qk_proj = None + self.qkv_proj = None + if proj_mode is ProjAndNormMode.Q_ONLY: + # KV-sharing: K/V come from another layer's cache. + self.q_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), + ) + elif proj_mode is ProjAndNormMode.QK_ONLY: + # attention_k_eq_v: Q (shard 0) and K (shard 1) merged into + # one Linear; V is derived from K at runtime via the fused + # QK_ONLY norm kernel. + self.qk_proj = MergedColumnParallelLinear( + hidden_size, + [ + self.total_num_heads * self.head_dim, # Q + self.total_num_kv_heads * self.head_dim, # K + ], + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qk_proj", prefix), + ) + else: + assert proj_mode is ProjAndNormMode.QKV_FULL + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, @@ -322,17 +385,21 @@ def __init__( prefix=add_prefix("o_proj", prefix), ) - self.q_norm = Gemma4RMSNorm( - self.head_dim, - eps=config.rms_norm_eps, - ) - self.k_norm = Gemma4RMSNorm( - self.head_dim, - eps=config.rms_norm_eps, - ) - self.v_norm = Gemma4RMSNorm( - self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False - ) + # Norms: Q-only layers have only q_norm in the checkpoint, so + # don't allocate k_norm / v_norm (they'd silently stay zero-init + # and waste memory). + self.q_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + if proj_mode is ProjAndNormMode.Q_ONLY: + self.k_norm = None + self.v_norm = None + else: + self.k_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + scale_shift=0.0, + with_scale=False, + ) if layer_type in config.rope_parameters: rope_parameters = dict(config.rope_parameters[layer_type]) @@ -342,12 +409,12 @@ def __init__( rope_theta=10000.0, ) - # KV sharing logic + # KV sharing: proj_mode == Q_ONLY is the single source of truth. + # The caller (e.g. assistant MTP) may force this even when the + # config wouldn't. num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers - self.is_kv_shared_layer = ( - layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 - ) + self.is_kv_shared_layer = proj_mode is ProjAndNormMode.Q_ONLY self.kv_shared_layer_index = None if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: @@ -394,75 +461,15 @@ def forward( forward_batch: ForwardBatch, **kwargs, ): - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Fused Q/K/V RMSNorm: replaces three separate norm kernels with one. - # Preconditions for the fused path: tensors on CUDA, q_norm/k_norm use - # the standard norm*weight (scale_shift==0) and v_norm has weight=ones - # (with_scale=False) — the canonical Gemma4 attention configuration. - is_kv_shared = ( - self.is_kv_shared_layer and self.kv_shared_layer_index is not None - ) - can_fuse_qkv_norm = ( - q.is_cuda - and self.q_norm.scale_shift == 0.0 - and self.k_norm.scale_shift == 0.0 - and not self.v_norm.with_scale - ) - if can_fuse_qkv_norm: - if is_kv_shared: - gemma_qkv_rmsnorm( - q, - None, - None, - self.q_norm.weight.data, - None, - num_q_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - eps=self.q_norm.eps, - ) - k = None - v = None - else: - gemma_qkv_rmsnorm( - q, - k, - v, - self.q_norm.weight.data, - self.k_norm.weight.data, - num_q_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - eps=self.q_norm.eps, - ) - # Match the original norm path's output shapes: q stays 2D, - # k/v become 3D so the subsequent `.flatten(-2, -1)` works. - # Use reshape (not view) since k/v are strided slice views of - # the qkv buffer and may not satisfy view's contiguity rules. - k = k.reshape(-1, self.num_kv_heads, self.head_dim) - v = v.reshape(-1, self.num_kv_heads, self.head_dim) - else: - q = q.unflatten(-1, (self.num_heads, self.head_dim)) - q = self.q_norm(q) - q = q.flatten(-2, -1) - if is_kv_shared: - k = None - v = None - else: - k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) - k = self.k_norm(k) - v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) - v = self.v_norm(v) + q, k, v = self._project_and_norm(hidden_states) - # Apply rotary embedding + # Apply rotary embedding. K is None for Q-only layers (KV-shared); + # rotary needs a key input so we pass a zero stand-in. if k is not None: k = k.flatten(-2, -1) q, k = self.rotary_emb(positions, q, k) k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) else: - # Rotary embedding requires a key input; use zeros since KV is shared from another layer dummy_k = torch.zeros_like(q[:, : self.kv_size]) q, _ = self.rotary_emb(positions, q, dummy_k) @@ -477,9 +484,119 @@ def forward( if attn_output.dim() == 3: attn_output = attn_output.flatten(-2, -1) output, _ = self.o_proj(attn_output) - return output + def _can_use_fused_norm(self, q: torch.Tensor) -> bool: + """The fused kernel assumes the canonical Gemma4 norm settings: + Q/K use the standard ``x * rrms * w`` form and V has unit scale. + Anything else falls back to per-tensor Gemma4RMSNorm modules.""" + return ( + q.is_cuda + and self.q_norm.scale_shift == 0.0 + and (self.k_norm is None or self.k_norm.scale_shift == 0.0) + and (self.v_norm is None or not self.v_norm.with_scale) + ) + + def _project_and_norm(self, hidden_states: torch.Tensor): + """Return (q, k, v) ready for rotary + attention.""" + if self.proj_mode is ProjAndNormMode.Q_ONLY: + return self._project_and_norm_q_only(hidden_states) + if self.proj_mode is ProjAndNormMode.QK_ONLY: + return self._project_and_norm_qk(hidden_states) + return self._project_and_norm_qkv(hidden_states) + + def _project_and_norm_q_only(self, hidden_states: torch.Tensor): + """Q only project. + + KV is read from another layer's cache; only Q is + projected and normalised. K/V are returned as ``None`` so the + downstream rotary code uses a dummy and the attention call + reads KV from cache. + """ + q, _ = self.q_proj(hidden_states) + if self._can_use_fused_norm(q): + gemma_qkv_rmsnorm( + q, + None, + None, + self.q_norm.weight.data, + None, + num_q_heads=self.num_heads, + num_kv_heads=0, + head_dim=self.head_dim, + eps=self.q_norm.eps, + mode=ProjAndNormMode.Q_ONLY, + ) + else: + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + return q, None, None + + def _project_and_norm_qk(self, hidden_states: torch.Tensor): + """Q + K projection (V derived from K). + + Uses the K_EQ_V mode of the fused norm kernel: K and V are allocated out-of-place from a single shared K input read. + """ + qk, _ = self.qk_proj(hidden_states) + q, k = qk.split([self.q_size, self.kv_size], dim=-1) + if self._can_use_fused_norm(q): + # K is a strided slice of the qk buffer; the kernel respects + # stride(0) so no .contiguous() copy is needed. + k_out, v_out = gemma_qkv_rmsnorm( + q, + k, + None, + self.q_norm.weight.data, + self.k_norm.weight.data, + num_q_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.q_norm.eps, + mode=ProjAndNormMode.QK_ONLY, + ) + k = k_out.reshape(-1, self.num_kv_heads, self.head_dim) + v = v_out.reshape(-1, self.num_kv_heads, self.head_dim) + else: + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(k) + k = self.k_norm(k) + return q, k, v + + def _project_and_norm_qkv(self, hidden_states: torch.Tensor): + """Standard QKV projection with three independent shards.""" + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self._can_use_fused_norm(q): + gemma_qkv_rmsnorm( + q, + k, + v, + self.q_norm.weight.data, + self.k_norm.weight.data, + num_q_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + eps=self.q_norm.eps, + mode=ProjAndNormMode.QKV_FULL, + ) + # Use reshape (not view) since k/v are strided slice views of + # the qkv buffer and may not satisfy view's contiguity rules. + k = k.reshape(-1, self.num_kv_heads, self.head_dim) + v = v.reshape(-1, self.num_kv_heads, self.head_dim) + else: + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + return q, k, v + class Gemma4DecoderLayer(nn.Module): def __init__( @@ -488,6 +605,7 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + proj_mode: ProjAndNormMode = ProjAndNormMode.QKV_FULL, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -512,6 +630,7 @@ def __init__( head_dim=head_dim, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), + proj_mode=proj_mode, ) first_kv_shared_layer_idx = config.num_hidden_layers - getattr( @@ -723,6 +842,7 @@ def __init__( config: Gemma4TextConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + kv_shared_layer_indices: Optional[Set[int]] = None, ) -> None: super().__init__(config=config) self.config = config @@ -810,6 +930,21 @@ def __init__( self.per_layer_input_scale = None self.per_layer_projection_scale = None + # Resolve per-layer projection / norm layout. The caller may + # override the config-derived KV-sharing set via + # kv_shared_layer_indices (MTP assistant passes every layer to + # read from the target's cache). + k_eq_v_layers = get_k_eq_v_layers(config) + if kv_shared_layer_indices is None: + n_shared = getattr(config, "num_kv_shared_layers", 0) + kv_shared_layer_indices = ( + set( + range(config.num_hidden_layers - n_shared, config.num_hidden_layers) + ) + if n_shared > 0 + else set() + ) + self.layers, self.start_layer, self.end_layer = make_layers( config.num_hidden_layers, lambda idx, prefix: Gemma4DecoderLayer( @@ -817,6 +952,11 @@ def __init__( config=config, quant_config=quant_config, prefix=prefix, + proj_mode=_resolve_proj_mode( + idx, + k_eq_v_layers=k_eq_v_layers, + kv_shared_layers=kv_shared_layer_indices, + ), ), pp_rank=self.pp_group.rank_in_group, pp_size=self.pp_group.world_size, @@ -1123,13 +1263,9 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states ) - def _get_k_eq_v_layers(self) -> set: + def _get_k_eq_v_layers(self) -> Set[int]: """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" - if not getattr(self.config, "attention_k_eq_v", False): - return set() - return { - i for i, lt in enumerate(self.config.layer_types) if lt == "full_attention" - } + return get_k_eq_v_layers(self.config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -1141,6 +1277,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] + # k_eq_v layers use MergedColumnParallelLinear (qk_proj) instead + # of QKVParallelLinear (qkv_proj). Map checkpoint q_proj / k_proj + # to integer shard ids 0 and 1 respectively. V is derived from K + # at runtime via the K_EQ_V fused norm, so there is no v shard + # and no v_proj weight in the checkpoint for these layers. + k_eq_v_stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qk_proj", "q_proj", 0), + ("qk_proj", "k_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + fused_expert_params_mapping = [ # (param_name, ckpt_weight_name, shard_ids) # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) @@ -1210,16 +1359,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ): continue - # attention_k_eq_v: full-attention layers have no v_proj in the - # checkpoint (K and V share weights). When we see a k_proj weight - # for one of these layers, load it into both the "k" and "v" shards - # of the fused QKV so the forward produces v_raw == k_raw. - should_dup_k_to_v = ( - ".k_proj." in name - and k_eq_v_layers - and (m := re.search(r"layers\.(\d+)\.", name)) is not None - and int(m.group(1)) in k_eq_v_layers - ) + # Determine whether this weight belongs to a k_eq_v layer. + is_k_eq_v_layer = False + if k_eq_v_layers: + m = re.search(r"layers\.(\d+)\.", name) + if m is not None: + is_k_eq_v_layer = int(m.group(1)) in k_eq_v_layers # MoE expert weights checked first (gate_up_proj contains "up_proj" # which would false-match the stacked dense MLP mapping). @@ -1274,7 +1419,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) break else: - for param_name, weight_name, shard_id in stacked_params_mapping: + # 3) Stacked dense projection weights. k_eq_v layers + # pack only Q+K into qk_proj (V is derived at + # runtime from K via the K_EQ_V fused norm), so + # they need a different mapping than the standard + # qkv_proj layers. + mapping = ( + k_eq_v_stacked_params_mapping + if is_k_eq_v_layer + else stacked_params_mapping + ) + for param_name, weight_name, shard_id in mapping: name = orig_name if weight_name not in name: continue @@ -1284,8 +1439,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - if should_dup_k_to_v: - weight_loader(param, loaded_weight, "v") loaded_params.add(name) break else: diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index cafc31f20ce8..d9a2153ac53a 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -58,7 +58,11 @@ maybe_remap_kv_scale_name, ) from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder -from sglang.srt.models.gemma4_causal import Gemma4TextModel, pp_filter_load_weight +from sglang.srt.models.gemma4_causal import ( + Gemma4TextModel, + get_k_eq_v_layers, + pp_filter_load_weight, +) from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -684,6 +688,17 @@ def tie_weights(self, recompute_mapping=False): (".gate_up_proj", ".gate_proj", 0), ] + # k_eq_v layers use MergedColumnParallelLinear (qk_proj) instead of + # QKVParallelLinear (qkv_proj). Map checkpoint q_proj / k_proj to + # integer shard ids 0 and 1 respectively. + k_eq_v_stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qk_proj", ".q_proj", 0), + (".qk_proj", ".k_proj", 1), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] + # Regex for fused QKV in vision/audio towers. # Vision: *.self_attn.{q,k,v}_proj.* Audio: *.attn.{q,k,v}_proj.* _RE_TOWER_QKV = re.compile( @@ -797,14 +812,9 @@ def _remap_tower_name(name: str, params_dict: dict) -> str: return name - def _get_k_eq_v_layers(self) -> set: + def _get_k_eq_v_layers(self) -> Set[int]: """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" - text_config = self.config.text_config - if not getattr(text_config, "attention_k_eq_v", False): - return set() - return { - i for i, lt in enumerate(text_config.layer_types) if lt == "full_attention" - } + return get_k_eq_v_layers(self.config.text_config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): k_eq_v_layers = self._get_k_eq_v_layers() @@ -899,17 +909,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "vision_tower." in name or "audio_tower." in name: name = self._remap_tower_name(name, params_dict) - # attention_k_eq_v: full-attention layers have no v_proj in the - # checkpoint (K and V share weights). When we see a k_proj weight - # for one of these layers, load it into both the "k" and "v" shards - # of the fused QKV so the forward produces v_raw == k_raw. - should_dup_k_to_v = ( - ".k_proj." in name - and k_eq_v_layers - and "language_model." in name - and (m := re.search(r"layers\.(\d+)\.", name)) is not None - and int(m.group(1)) in k_eq_v_layers - ) + # Determine whether this weight belongs to a k_eq_v layer. + is_k_eq_v_layer = False + if k_eq_v_layers and "language_model." in name: + m_layer = re.search(r"layers\.(\d+)\.", name) + if m_layer is not None: + is_k_eq_v_layer = int(m_layer.group(1)) in k_eq_v_layers # MoE expert weights checked first (gate_up_proj contains "up_proj" # which would false-match the stacked dense MLP mapping). @@ -964,11 +969,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) break else: + # 3) Stacked dense projection weights. k_eq_v layers + # pack only Q+K into qk_proj (V is derived at + # runtime from K via the K_EQ_V fused norm), so + # they need a different mapping than the standard + # qkv_proj layers. + mapping = ( + self.k_eq_v_stacked_params_mapping + if is_k_eq_v_layer + else self.stacked_params_mapping + ) for ( param_name, weight_name, shard_id, - ) in self.stacked_params_mapping: + ) in mapping: name = orig_name if weight_name not in name: continue @@ -978,8 +993,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - if should_dup_k_to_v: - weight_loader(param, loaded_weight, "v") loaded_params.add(name) break else: diff --git a/python/sglang/srt/models/gemma4_mtp.py b/python/sglang/srt/models/gemma4_mtp.py index ade10ce5b990..ccf1b98b8825 100644 --- a/python/sglang/srt/models/gemma4_mtp.py +++ b/python/sglang/srt/models/gemma4_mtp.py @@ -15,7 +15,7 @@ import copy import logging -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Set, Tuple import torch from torch import nn @@ -73,6 +73,8 @@ def __init__( self.assistant_config = config self.config = text_config self.quant_config = quant_config + n_layers = text_config.num_hidden_layers + self._assistant_kv_shared_layers: Set[int] = set(range(n_layers)) self.pp_group = get_pp_group() self.vocab_size = text_config.vocab_size @@ -98,6 +100,7 @@ def __init__( config=text_config, quant_config=quant_config, prefix=add_prefix("model", prefix), + kv_shared_layer_indices=self._assistant_kv_shared_layers, ) self.post_projection = ReplicatedLinear( self.hidden_size,