Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 142 additions & 44 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
a single kernel pass to reduce kernel launch overhead.
"""

from enum import Enum
from typing import Optional

import torch
import triton
import triton.language as tl


class ProjAndNormMode(Enum):
    """Select which projections a Gemma4 attention layer feeds to the fused RMSNorm.

    Members:
        Q_ONLY: KV-sharing layer. Only Q is projected and normalised.
        QK_ONLY: ``attention_k_eq_v`` layer. Q and one shared K/V projection
            are produced; the fused norm derives both K and V from that
            single K projection.
        QKV_FULL: Standard layer. Q, K and V are each projected and
            normalised independently.
    """

    # Members are listed from the fewest projected tensors to the most.
    Q_ONLY = "q"
    QK_ONLY = "qk"
    QKV_FULL = "qkv"


@triton.jit
def _gemma_rmsnorm_residual_kernel(
X_ptr,
Expand Down Expand Up @@ -135,35 +151,47 @@ def _gemma_dual_rmsnorm_residual_kernel(
@triton.jit
def _gemma_qkv_rmsnorm_kernel(
Q_ptr,
K_ptr,
V_ptr,
K_in_ptr,
V_in_ptr,
K_out_ptr,
V_out_ptr,
Q_w_ptr,
K_w_ptr,
stride_q_m,
stride_k_m,
stride_v_m,
stride_kin_m,
stride_vin_m,
stride_kout_m,
stride_vout_m,
NUM_Q_HEADS: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
HEAD_DIM: tl.constexpr,
eps,
HAS_KV: tl.constexpr,
K_EQ_V: tl.constexpr,
BLOCK: tl.constexpr,
):
"""Per-token fused RMSNorm of Q (with q_w), K (with k_w), V (no scale).

Layout assumption: each tensor's last dim packs (num_heads, head_dim) contiguously
so per-head offset is `h * HEAD_DIM`. The token (M) stride is taken from
stride_*_m so the kernel works on strided views (e.g. slices of a larger
qkv buffer produced by `qkv.split`) without requiring `.contiguous()` copies.
V uses `weight=ones` semantics so the multiply-by-weight is omitted.
Three modes, selected via the ``HAS_KV`` / ``K_EQ_V`` constexpr toggles:

* **Q-only** (``HAS_KV=False``): normalises Q in-place from ``Q_ptr``.
K/V pointers are unused.
* **QKV** (``HAS_KV=True, K_EQ_V=False``): normalises Q, K, V in-place.
``K_in_ptr == K_out_ptr`` and ``V_in_ptr == V_out_ptr`` (the launcher
passes the same tensor for input and output).
* **K=V (a.k.a. ``attention_k_eq_v``)** (``HAS_KV=True, K_EQ_V=True``):
normalises Q in-place. ``K_in_ptr`` is the shared raw K/V projection;
``V_in_ptr`` is unused. ``K_out_ptr`` receives ``norm(KV) * k_weight``
and ``V_out_ptr`` receives ``norm(KV)``. One rrms per (token, head) is
shared between K and V.
"""
m = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < HEAD_DIM

qw = tl.load(Q_w_ptr + cols, mask=mask, other=0.0).to(tl.float32)

# Q heads
# Q heads — in-place
for h in tl.static_range(NUM_Q_HEADS):
off = m * stride_q_m + h * HEAD_DIM + cols
x = tl.load(Q_ptr + off, mask=mask, other=0.0).to(tl.float32)
Expand All @@ -174,21 +202,42 @@ def _gemma_qkv_rmsnorm_kernel(
if HAS_KV:
kw = tl.load(K_w_ptr + cols, mask=mask, other=0.0).to(tl.float32)

# K heads
for h in tl.static_range(NUM_KV_HEADS):
off = m * stride_k_m + h * HEAD_DIM + cols
x = tl.load(K_ptr + off, mask=mask, other=0.0).to(tl.float32)
rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
out = x * rrms * kw
tl.store(K_ptr + off, out.to(K_ptr.dtype.element_ty), mask=mask)

# V heads (no scaling: V-norm uses weight=ones)
for h in tl.static_range(NUM_KV_HEADS):
off = m * stride_v_m + h * HEAD_DIM + cols
x = tl.load(V_ptr + off, mask=mask, other=0.0).to(tl.float32)
rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
out = x * rrms
tl.store(V_ptr + off, out.to(V_ptr.dtype.element_ty), mask=mask)
if K_EQ_V:
# Shared KV input: one read of KV per head, two writes.
for h in tl.static_range(NUM_KV_HEADS):
in_off = m * stride_kin_m + h * HEAD_DIM + cols
x = tl.load(K_in_ptr + in_off, mask=mask, other=0.0).to(tl.float32)
rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
v_out = x * rrms
k_out = v_out * kw
k_off = m * stride_kout_m + h * HEAD_DIM + cols
v_off = m * stride_vout_m + h * HEAD_DIM + cols
tl.store(
K_out_ptr + k_off,
k_out.to(K_out_ptr.dtype.element_ty),
mask=mask,
)
tl.store(
V_out_ptr + v_off,
v_out.to(V_out_ptr.dtype.element_ty),
mask=mask,
)
else:
# Separate K and V inputs, normalised in-place.
for h in tl.static_range(NUM_KV_HEADS):
off = m * stride_kin_m + h * HEAD_DIM + cols
x = tl.load(K_in_ptr + off, mask=mask, other=0.0).to(tl.float32)
rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
out = x * rrms * kw
tl.store(K_in_ptr + off, out.to(K_in_ptr.dtype.element_ty), mask=mask)

# V heads (no scaling: V-norm uses weight=ones)
for h in tl.static_range(NUM_KV_HEADS):
off = m * stride_vin_m + h * HEAD_DIM + cols
x = tl.load(V_in_ptr + off, mask=mask, other=0.0).to(tl.float32)
rrms = tl.rsqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
out = x * rrms
tl.store(V_in_ptr + off, out.to(V_in_ptr.dtype.element_ty), mask=mask)


def gemma_qkv_rmsnorm(
Expand All @@ -201,49 +250,98 @@ def gemma_qkv_rmsnorm(
num_kv_heads: int,
head_dim: int,
eps: float = 1e-6,
) -> None:
"""In-place fused RMSNorm on Q, K, V for Gemma4 attention.

All three norms compute `x * rsqrt(mean(x^2) + eps)` independently per head.
Q is scaled by `q_weight`, K by `k_weight`, V by 1 (Gemma4's V-norm has
`with_scale=False`).

Inputs may be 2D `(M, num_heads * head_dim)` or strided views of a larger
buffer (such as q/k/v slices from `qkv.split`). The kernel uses the actual
`stride(0)` so no `.contiguous()` copy is required. Within a token, the
last dim must be contiguous so heads pack as `h * head_dim` offsets.

If k and v are both None (KV-shared layer), only Q is normalized.
*,
mode: ProjAndNormMode = ProjAndNormMode.QKV_FULL,
) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
"""Fused per-head RMSNorm on Q, K, V (or any subset) for Gemma4.

Q is scaled by q_weight, K by k_weight, V by 1 (Gemma4 V-norm uses
with_scale=False). The caller picks the layout via ``mode``:

Q_ONLY k=None, v=None. Q normalised in place. Returns None.
QKV_FULL k and v non-None. Q, K, V normalised in place. Returns
None.
QK_ONLY k is the shared K/V projection, v=None. Q normalised in
place; fresh K and V tensors are allocated. Returns
(k_out, v_out).
"""
assert q.is_cuda
assert q.stride(-1) == 1, "Q's last dim must be contiguous"
assert q_weight.shape[-1] == head_dim
M = q.shape[0] if q.dim() >= 2 else 1
BLOCK = triton.next_power_of_2(head_dim)

has_kv = k is not None and v is not None
if has_kv:
# Resolve the mode + allocate outputs if needed.
if mode is ProjAndNormMode.QK_ONLY:
assert (
k is not None and v is None
), "QK_ONLY expects k=<shared KV input>, v=None"
assert k.is_cuda and k.stride(-1) == 1
assert k_weight is not None and k_weight.shape[-1] == head_dim
assert (
q.shape[0] == k.shape[0]
), f"M mismatch: q.shape[0]={q.shape[0]} vs kv.shape[0]={k.shape[0]}"
has_kv = True
k_eq_v = True
k_in = k
v_in = q # unused; just need a valid pointer for triton.
k_out = torch.empty_like(k)
v_out = torch.empty_like(k)
stride_kin_m = k.stride(0)
stride_vin_m = 0
stride_kout_m = k_out.stride(0)
stride_vout_m = v_out.stride(0)
elif mode is ProjAndNormMode.QKV_FULL:
assert k is not None and v is not None, "QKV_FULL expects non-None k and v"
assert k.is_cuda and v.is_cuda
assert k.stride(-1) == 1 and v.stride(-1) == 1
assert k_weight is not None and k_weight.shape[-1] == head_dim
has_kv = True
k_eq_v = False
k_in = k
v_in = v
# In-place: outputs == inputs.
k_out = k
v_out = v
stride_kin_m = k.stride(0)
stride_vin_m = v.stride(0)
stride_kout_m = k.stride(0)
stride_vout_m = v.stride(0)
else:
assert mode is ProjAndNormMode.Q_ONLY
assert k is None and v is None, "Q_ONLY requires both k and v to be None"
has_kv = False
k_eq_v = False
# Unused pointers; pass q for safety.
k_in = v_in = k_out = v_out = q
stride_kin_m = stride_vin_m = stride_kout_m = stride_vout_m = 0

_gemma_qkv_rmsnorm_kernel[(M,)](
q,
k if has_kv else q,
v if has_kv else q,
k_in,
v_in,
k_out,
v_out,
q_weight,
k_weight if has_kv else q_weight,
q.stride(0),
k.stride(0) if has_kv else 0,
v.stride(0) if has_kv else 0,
stride_kin_m,
stride_vin_m,
stride_kout_m,
stride_vout_m,
NUM_Q_HEADS=num_q_heads,
NUM_KV_HEADS=num_kv_heads if has_kv else 0,
HEAD_DIM=head_dim,
eps=eps,
HAS_KV=has_kv,
K_EQ_V=k_eq_v,
BLOCK=BLOCK,
)

if mode is ProjAndNormMode.QK_ONLY:
return k_out, v_out
return None


def gemma_dual_rmsnorm_residual_scalar(
x1: torch.Tensor,
Expand Down
Loading
Loading