Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into
a single kernel pass to reduce kernel launch overhead.

Also provides a single-launch fused router for Gemma4 MoE (PR #26120 in
pyc96/sglang fork): replaces the per-layer ``torch.topk`` ->
``softmax`` -> ``per_expert_scale[ids]`` -> ``mul`` -> ``cast`` chain in
``Gemma4MoE.routing_function`` with one Triton kernel.

The reference design comes from vLLM PR #39083
(``_gemma4_routing_kernel`` / ``gemma4_fused_routing_kernel_triton``),
which is apache-2.0. Our kernel is rewritten in SGLang style and uses
the identity ``softmax(all)[topk] / sum(softmax(all)[topk]) =
softmax(topk_logits)`` already exploited by SGLang's torch routing
function, so the math is bitwise-comparable to the prior fp32 path.
"""

from typing import Optional
Expand Down Expand Up @@ -283,3 +295,186 @@ def gemma_dual_rmsnorm_residual_scalar(
BLOCK_SIZE=BLOCK_SIZE,
)
return out


# ---------------------------------------------------------------------------
# Fused Gemma4 routing kernel (one launch per layer)
# ---------------------------------------------------------------------------
#
# Equivalent to:
#
# topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
# topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
# topk_weights = topk_weights * per_expert_scale[topk_ids]
# return topk_weights.float(), topk_ids.int()
#
# but completes the entire computation in one Triton program per token.
#
# Algorithm notes:
# * Loads all E logits per token into one program; for Gemma4
# ``E = num_experts = 128`` so ``BLOCK_E = next_pow2(E) = 128`` and the
# work fits in a single warp with `num_warps=1`.
# * Computes ``softmax-of-topk`` by:
# - building a bijection (logit_bits -> int32 key) that is anti-monotone
# on the float value, then packing ``(key, expert_id)`` into int64.
# After the ``<<32`` shift, the int32 key's high bit lands in bit 63
# of the int64, so Triton's signed ascending ``tl.sort`` yields the
# logits in *descending* float order without a K-step loop or a
# separate index scatter.
# - taking the largest K via a mask on positions 0..K-1 of the sorted
# output
# - normalizing in fp32 (matches ``softmax`` default dtype)
# - multiplying by ``per_expert_scale[topk_ids]``
# * Writes ``topk_weights`` (fp32) and ``topk_ids`` (int32) in one
# pass, matching the output dtypes the SGLang MoE topk wrapper
# expects.
# * Compatibility with quantized MoE backends: this fast path runs whenever
# the model calls ``Gemma4MoE.routing_function`` via
# ``select_experts``. That covers unquantized BF16/FP16, FP8/W8A8 (where
# the standard topk path is used), and the ``flashinfer_trtllm_routed``
# NVFP4 path. The default ``flashinfer_trtllm`` NVFP4 backend uses a
# BypassedTopKOutput and does routing inside the trtllm kernel, so this
# function is neither called nor needed there.
#
# Reference algorithm: vLLM PR #39083 ``_gemma4_routing_kernel`` (apache-2.0).
# Our independent implementation follows the same sort+mask+softmax scheme.
@triton.jit
def _gemma4_routing_kernel(
gating_ptr, # [T, E] router logits, any float dtype
per_expert_scale_ptr, # [E] per-expert scale (any float dtype)
topk_weights_ptr, # [T, K] fp32 out
topk_ids_ptr, # [T, K] int32 out
stride_g_t, # stride of gating in the token dim
E: tl.constexpr,
K: tl.constexpr,
BLOCK_E: tl.constexpr,
):
pid = tl.program_id(0)
offs_e = tl.arange(0, BLOCK_E)
valid = offs_e < E

# Load logits into fp32; out-of-bound lanes get -inf so they sort last.
logits = tl.load(
gating_ptr + pid * stride_g_t + offs_e,
mask=valid,
other=-float("inf"),
).to(tl.float32)

# Build a sortable int64 key: high 32 bits = bijective(logit_bits) +
# expert id in the low 32 bits. The bijection is anti-monotone on the
# float value, and the ``<<32`` shift below moves the int32 key's high
# bit into the int64 sign bit, so ``tl.sort(..., descending=False)``
# (which is *signed* int64 ascending) yields the original logits in
# *descending* float order. Ties are broken by expert id ascending
# (lower id wins), which is a stable choice but not guaranteed to
# match ``torch.topk``'s tie-break (torch.topk's order is
# implementation-defined). Random fp inputs effectively never collide,
# so the test compares as sets when IDs differ.
MIN32 = -2147483648
logit_bits = logits.to(tl.int32, bitcast=True)
sign = logit_bits >> 31
key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32)
# Force invalid lanes to the max positive key so they end up *after* the
# real logits when we sort ascending: positive int32 key -> positive
# int64 packed -> sorts after the bit-63-set packed values that carry
# real logits.
key = tl.where(valid, key, 0x7FFFFFFF)
sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
packed = (sk64 << 32) | offs_e.to(tl.int64)

# Signed ascending int64 sort. Real positive logits become negative
# int64 (bit 63 set) and sort first; negative logits become positive
# int64 and sort after; invalid lanes (key=0x7fffffff) sort last.
sorted_p = tl.sort(packed, descending=False)
all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)

# Invert the bijection to recover the original logit value.
sign_k = all_keys >> 31
all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
all_logits = all_bits.to(tl.float32, bitcast=True)

# Softmax over the K largest logits only (identity proven by SGLang's
# torch routing function comment). Subtract the max for stability;
# since the list is sorted descending by logit value, the max sits at
# index 0.
top_mask = offs_e < K
max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0)
# exp2(x * log2(e)) is what tl.math.exp expands to; spell it out so we
# can tolerate older Triton releases that lack tl.math.exp.
raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
raw_exp = tl.where(top_mask, raw_exp, 0.0)

denom = tl.sum(raw_exp, axis=0)
denom = tl.where(denom > 0.0, denom, 1.0)
weights = raw_exp / denom

# Multiply by per_expert_scale[topk_ids]. per_expert_scale lives in
# any float dtype; cast to fp32 for the final write.
scales = tl.load(
per_expert_scale_ptr + all_ids.to(tl.int64),
mask=top_mask,
other=1.0,
).to(tl.float32)
weights = weights * scales

base_off = pid * K + offs_e
tl.store(topk_weights_ptr + base_off, weights, mask=top_mask)
tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)


def gemma4_fused_routing(
gating_output: torch.Tensor,
per_expert_scale: torch.Tensor,
topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""One-launch Gemma4 router.

Args:
gating_output: [T, E] router logits in any floating dtype; will be
cast to fp32 inside the kernel.
per_expert_scale: [E] per-expert scale, any floating dtype.
topk: number of experts to keep per token.

Returns:
topk_weights: [T, topk] fp32 (matches SGLang TopK contract).
topk_ids: [T, topk] int32 (matches SGLang TopK contract).
"""
assert gating_output.dim() == 2, "expected [T, E] router logits"
assert per_expert_scale.dim() == 1
assert per_expert_scale.shape[0] == gating_output.shape[1]
T, E = gating_output.shape
assert topk <= E
# Guard against pathological E that would blow up the compiler / register
# budget. Gemma4 ships with E=128; even hypothetical 4x variants stay
# well under this cap.
assert E <= 1024, f"gemma4_fused_routing only supports E<=1024, got E={E}"

# The kernel reads the token row with stride_g_t; force the inner-most
# dim to be contiguous so the masked load is coalesced. Most call
# sites already pass a contiguous tensor (router proj output); contiguous
# is cheap.
gating_output = gating_output.contiguous()
per_expert_scale = per_expert_scale.contiguous()

BLOCK_E = triton.next_power_of_2(E)
topk_weights = torch.empty(
(T, topk), dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device)

if T == 0:
return topk_weights, topk_ids

_gemma4_routing_kernel[(T,)](
gating_output,
per_expert_scale,
topk_weights,
topk_ids,
gating_output.stride(0),
E,
topk,
BLOCK_E,
num_warps=1,
)
return topk_weights, topk_ids
15 changes: 15 additions & 0 deletions python/sglang/srt/models/gemma4_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.gemma4_fused_ops import (
gemma4_fused_routing,
gemma_dual_rmsnorm_residual_scalar,
gemma_qkv_rmsnorm,
gemma_rmsnorm_residual_scalar,
Expand Down Expand Up @@ -220,6 +221,20 @@ def routing_function(
) -> tuple[torch.Tensor, torch.Tensor]:
# softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits),
# so we softmax only the top-k logits (fewer kernel launches).
#
# Fast path: a single Triton kernel that produces (weights, ids)
# already scaled by per_expert_scale. Mathematically identical
# to the torch fallback below. Active when on CUDA with a 2-D
# router-logits tensor and num_experts a power-of-two-rounded
# value the kernel supports (always true for Gemma4: E=128).
if (
gating_output.is_cuda
and gating_output.dim() == 2
and gating_output.dtype
in (torch.float16, torch.bfloat16, torch.float32)
):
return gemma4_fused_routing(gating_output, per_expert_scale, topk)

topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)

Expand Down
111 changes: 111 additions & 0 deletions test/srt/layers/test_gemma4_fused_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Correctness tests for ``gemma4_fused_routing``.

Compares the Triton-fused routing kernel against the original SGLang
``Gemma4MoE.routing_function`` reference (softmax-of-topk * per_expert_scale).
Run with::

pytest test/srt/layers/test_gemma4_fused_routing.py -v

Requires a CUDA-capable GPU; skips otherwise.
"""

from __future__ import annotations

import pytest
import torch

pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="gemma4_fused_routing is a CUDA-only Triton kernel",
)


@pytest.fixture(scope="module")
def fused_routing():
from sglang.srt.layers.gemma4_fused_ops import gemma4_fused_routing

return gemma4_fused_routing


def _reference(gating_output: torch.Tensor, per_expert_scale: torch.Tensor, topk: int):
"""The previous (now fallback) torch routing function from gemma4_causal.py."""
topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
topk_weights = topk_weights * per_expert_scale[topk_ids].to(topk_weights.dtype)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("T", [1, 7, 64, 128, 1024])
@pytest.mark.parametrize("E,K", [(128, 8), (64, 4), (256, 8)])
def test_matches_reference(fused_routing, dtype, T, E, K):
torch.manual_seed(0)
g = torch.randn(T, E, dtype=dtype, device="cuda")
s = torch.rand(E, dtype=dtype, device="cuda") * 2.0

ref_w, ref_i = _reference(g, s, K)
out_w, out_i = fused_routing(g, s, K)

assert out_w.dtype == torch.float32
assert out_i.dtype == torch.int32
assert out_w.shape == (T, K)
assert out_i.shape == (T, K)

# IDs must match exactly (top-K with stable tie-breaking on expert id).
# In practice with random logits ties almost never happen; if they do we
# accept either order as long as the weight sum and the selected set are
# equivalent.
# The fused kernel does softmax in fp32 throughout, while the torch
# fallback runs softmax in the input dtype before casting to fp32. For
# bf16 inputs that means our kernel is *more* accurate; loosen the
# tolerance to roughly the input-dtype eps so we don't false-fail.
if dtype == torch.bfloat16:
atol, rtol = 5e-3, 5e-3
elif dtype == torch.float16:
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 1e-5, 1e-5

if (out_i != ref_i).any():
# Compare as sets per row.
ref_set = ref_i.sort(dim=-1).values
out_set = out_i.sort(dim=-1).values
assert torch.equal(
out_set, ref_set
), "fused routing picked a different top-K set than reference"
# Sum of weights per row should still be close (softmax over the same
# K logits).
torch.testing.assert_close(
out_w.sum(dim=-1).to(torch.float32),
ref_w.sum(dim=-1).to(torch.float32),
atol=atol,
rtol=rtol,
)
else:
# Same IDs in the same order — weights must match within input dtype eps.
torch.testing.assert_close(out_w, ref_w, atol=atol, rtol=rtol)


def test_zero_tokens(fused_routing):
g = torch.empty(0, 128, dtype=torch.bfloat16, device="cuda")
s = torch.ones(128, dtype=torch.bfloat16, device="cuda")
w, i = fused_routing(g, s, 8)
assert w.shape == (0, 8) and i.shape == (0, 8)
assert w.dtype == torch.float32 and i.dtype == torch.int32


def test_scale_applied(fused_routing):
"""Weights must include per_expert_scale[topk_ids]."""
torch.manual_seed(1)
T, E, K = 4, 128, 8
g = torch.randn(T, E, dtype=torch.bfloat16, device="cuda")
s = torch.rand(E, dtype=torch.bfloat16, device="cuda") * 3.0

out_w, out_i = fused_routing(g, s, K)
ref_w, ref_i = _reference(g, s, K)
torch.testing.assert_close(out_w, ref_w, atol=5e-3, rtol=5e-3)
assert torch.equal(out_i, ref_i)


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__, "-v"]))
Loading