Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,117 @@ def gemma_dual_rmsnorm_residual_scalar(
BLOCK_SIZE=BLOCK_SIZE,
)
return out


@triton.jit
def _gemma4_routing_kernel(
gating_ptr, # [T, E] router logits, any float dtype
per_expert_scale_ptr, # [E] per-expert scale (any float dtype)
topk_weights_ptr, # [T, K] fp32 out
topk_ids_ptr, # [T, K] int32 out
stride_g_t, # stride of gating in the token dim
E: tl.constexpr,
K: tl.constexpr,
BLOCK_E: tl.constexpr,
):
pid = tl.program_id(0)
offs_e = tl.arange(0, BLOCK_E)
valid = offs_e < E

logits = tl.load(
gating_ptr + pid * stride_g_t + offs_e,
mask=valid,
other=-float("inf"),
).to(tl.float32)

# Pack (sort_key, expert_id) into one int64 so a single signed-ascending
# tl.sort yields logits in descending float order. The key bijection is
# anti-monotone on the float value, and the <<32 shift moves its high bit
# into the int64 sign bit. Ties break by expert id ascending. Invalid
# lanes use a max key so they sort last.
MIN32 = -2147483648
logit_bits = logits.to(tl.int32, bitcast=True)
sign = logit_bits >> 31
key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32)
key = tl.where(valid, key, 0x7FFFFFFF)
sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
packed = (sk64 << 32) | offs_e.to(tl.int64)

sorted_p = tl.sort(packed, descending=False)
all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)

# Invert the key bijection to recover the original logit value.
sign_k = all_keys >> 31
all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
all_logits = all_bits.to(tl.float32, bitcast=True)

# softmax over the top-K logits; max sits at index 0 (sorted descending).
top_mask = offs_e < K
max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0)
raw_exp = tl.where(top_mask, tl.exp(all_logits - max_l), 0.0)

denom = tl.sum(raw_exp, axis=0)
denom = tl.where(denom > 0.0, denom, 1.0)
weights = raw_exp / denom

scales = tl.load(
per_expert_scale_ptr + all_ids.to(tl.int64),
mask=top_mask,
other=1.0,
).to(tl.float32)
weights = weights * scales

base_off = pid * K + offs_e
tl.store(topk_weights_ptr + base_off, weights, mask=top_mask)
tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)


def gemma4_fused_routing(
gating_output: torch.Tensor,
per_expert_scale: torch.Tensor,
topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""One-launch Gemma4 router.

Args:
gating_output: [T, E] router logits in any floating dtype; will be
cast to fp32 inside the kernel.
per_expert_scale: [E] per-expert scale, any floating dtype.
topk: number of experts to keep per token.

Returns:
topk_weights: [T, topk] fp32 (matches SGLang TopK contract).
topk_ids: [T, topk] int32 (matches SGLang TopK contract).
"""
assert gating_output.dim() == 2, "expected [T, E] router logits"
assert per_expert_scale.dim() == 1
assert per_expert_scale.shape[0] == gating_output.shape[1]
T, E = gating_output.shape
assert topk <= E, f"topk ({topk}) must be <= E ({E})"
assert E <= 1024, f"gemma4_fused_routing only supports E<=1024, got E={E}"

gating_output = gating_output.contiguous()
per_expert_scale = per_expert_scale.contiguous()

BLOCK_E = triton.next_power_of_2(E)
topk_weights = torch.empty(
(T, topk), dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device)

if T == 0:
return topk_weights, topk_ids

_gemma4_routing_kernel[(T,)](
gating_output,
per_expert_scale,
topk_weights,
topk_ids,
gating_output.stride(0),
E,
topk,
BLOCK_E,
num_warps=1,
)
return topk_weights, topk_ids
9 changes: 9 additions & 0 deletions python/sglang/srt/models/gemma4_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.gemma4_fused_ops import (
gemma4_fused_routing,
gemma_dual_rmsnorm_residual_scalar,
gemma_qkv_rmsnorm,
gemma_rmsnorm_residual_scalar,
Expand Down Expand Up @@ -220,6 +221,14 @@ def routing_function(
) -> tuple[torch.Tensor, torch.Tensor]:
# softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits),
# so we softmax only the top-k logits (fewer kernel launches).
if (
gating_output.is_cuda
and gating_output.dim() == 2
and gating_output.dtype
in (torch.float16, torch.bfloat16, torch.float32)
):
return gemma4_fused_routing(gating_output, per_expert_scale, topk)

topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)

Expand Down
106 changes: 106 additions & 0 deletions test/registered/kernels/test_gemma4_fused_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Correctness tests for ``gemma4_fused_routing``.

Compares the Triton-fused routing kernel against the original SGLang
``Gemma4MoE.routing_function`` reference (softmax-of-topk * per_expert_scale).
Run with::

pytest test/registered/kernels/test_gemma4_fused_routing.py -v

Requires a CUDA-capable GPU; skips otherwise.
"""

from __future__ import annotations

import pytest
import torch

from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=60, stage="base-b", runner_config="1-gpu-small")

pytestmark = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="gemma4_fused_routing is a CUDA-only Triton kernel",
)


@pytest.fixture(scope="module")
def fused_routing():
from sglang.srt.layers.gemma4_fused_ops import gemma4_fused_routing

return gemma4_fused_routing


def _reference(gating_output: torch.Tensor, per_expert_scale: torch.Tensor, topk: int):
"""The previous (now fallback) torch routing function from gemma4_causal.py."""
topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
topk_weights = topk_weights * per_expert_scale[topk_ids].to(topk_weights.dtype)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("T", [1, 7, 64, 128, 1024])
@pytest.mark.parametrize("E,K", [(128, 8), (64, 4), (256, 8)])
def test_matches_reference(fused_routing, dtype, T, E, K):
torch.manual_seed(0)
g = torch.randn(T, E, dtype=dtype, device="cuda")
s = torch.rand(E, dtype=dtype, device="cuda") * 2.0

ref_w, ref_i = _reference(g, s, K)
out_w, out_i = fused_routing(g, s, K)

assert out_w.dtype == torch.float32
assert out_i.dtype == torch.int32
assert out_w.shape == (T, K)
assert out_i.shape == (T, K)

# The fused kernel does softmax in fp32 while the torch fallback uses the
# input dtype, so tolerances are set to roughly the input-dtype eps.
if dtype == torch.bfloat16:
atol, rtol = 5e-3, 5e-3
elif dtype == torch.float16:
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 1e-5, 1e-5

if (out_i != ref_i).any():
# Tie-break order may differ; require the same top-K set and weight sum.
ref_set = ref_i.sort(dim=-1).values
out_set = out_i.sort(dim=-1).values
assert torch.equal(
out_set, ref_set
), "fused routing picked a different top-K set than reference"
torch.testing.assert_close(
out_w.sum(dim=-1).to(torch.float32),
ref_w.sum(dim=-1).to(torch.float32),
atol=atol,
rtol=rtol,
)
else:
torch.testing.assert_close(out_w, ref_w, atol=atol, rtol=rtol)


def test_zero_tokens(fused_routing):
g = torch.empty(0, 128, dtype=torch.bfloat16, device="cuda")
s = torch.ones(128, dtype=torch.bfloat16, device="cuda")
w, i = fused_routing(g, s, 8)
assert w.shape == (0, 8) and i.shape == (0, 8)
assert w.dtype == torch.float32 and i.dtype == torch.int32


def test_scale_applied(fused_routing):
"""Weights must include per_expert_scale[topk_ids]."""
torch.manual_seed(1)
T, E, K = 4, 128, 8
g = torch.randn(T, E, dtype=torch.bfloat16, device="cuda")
s = torch.rand(E, dtype=torch.bfloat16, device="cuda") * 3.0

out_w, out_i = fused_routing(g, s, K)
ref_w, ref_i = _reference(g, s, K)
torch.testing.assert_close(out_w, ref_w, atol=5e-3, rtol=5e-3)
assert torch.equal(out_i, ref_i)


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__, "-v"]))
Loading