diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py
index ad6f01d9875a..14a677e66c5c 100644
--- a/python/sglang/srt/layers/gemma4_fused_ops.py
+++ b/python/sglang/srt/layers/gemma4_fused_ops.py
@@ -283,3 +283,148 @@ def gemma_dual_rmsnorm_residual_scalar(
         BLOCK_SIZE=BLOCK_SIZE,
     )
     return out
+
+
+# Upper bound on the expert count accepted by ``gemma4_fused_routing``.
+# Callers that dispatch to the fused path should check this bound first and
+# fall back to another implementation for wider routers.
+GEMMA4_FUSED_ROUTING_MAX_EXPERTS = 1024
+
+
+@triton.jit
+def _gemma4_routing_kernel(
+    gating_ptr,  # [T, E] router logits, any float dtype
+    per_expert_scale_ptr,  # [E] per-expert scale (any float dtype)
+    topk_weights_ptr,  # [T, K] fp32 out
+    topk_ids_ptr,  # [T, K] int32 out
+    stride_g_t,  # stride of gating in the token dim
+    E: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+):
+    """Route one token per program: top-K selection, softmax, expert scale.
+
+    The program sorts the E logits of row ``pid`` in descending order, keeps
+    the first K, softmaxes them in fp32, multiplies each weight by the scale
+    of its expert, and stores K weights and K expert ids for that row.
+    """
+    # Widen the row index so the ``pid * stride_g_t`` and ``pid * K`` offsets
+    # are computed in 64 bits and cannot wrap for very large inputs.
+    pid = tl.program_id(0).to(tl.int64)
+    offs_e = tl.arange(0, BLOCK_E)
+    valid = offs_e < E
+
+    logits = tl.load(
+        gating_ptr + pid * stride_g_t + offs_e,
+        mask=valid,
+        other=-float("inf"),
+    ).to(tl.float32)
+
+    # Pack (sort_key, expert_id) into one int64 so a single signed-ascending
+    # tl.sort yields logits in descending float order. The key bijection is
+    # anti-monotone on the float value, and the <<32 shift moves its high bit
+    # into the int64 sign bit. Ties break by expert id ascending. Invalid
+    # lanes use a max key so they sort last.
+    MIN32 = -2147483648
+    logit_bits = logits.to(tl.int32, bitcast=True)
+    sign = logit_bits >> 31
+    key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32)
+    key = tl.where(valid, key, 0x7FFFFFFF)
+    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
+    packed = (sk64 << 32) | offs_e.to(tl.int64)
+
+    sorted_p = tl.sort(packed, descending=False)
+    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
+    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)
+
+    # Invert the key bijection to recover the original logit value.
+    sign_k = all_keys >> 31
+    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
+    all_logits = all_bits.to(tl.float32, bitcast=True)
+
+    # softmax over the top-K logits; max sits at index 0 (sorted descending).
+    top_mask = offs_e < K
+    max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0)
+    raw_exp = tl.where(top_mask, tl.exp(all_logits - max_l), 0.0)
+
+    # Avoid dividing by zero if the sum is not positive.
+    denom = tl.sum(raw_exp, axis=0)
+    denom = tl.where(denom > 0.0, denom, 1.0)
+    weights = raw_exp / denom
+
+    # Only the first K sorted lanes gather a scale; the rest are masked off.
+    scales = tl.load(
+        per_expert_scale_ptr + all_ids.to(tl.int64),
+        mask=top_mask,
+        other=1.0,
+    ).to(tl.float32)
+    weights = weights * scales
+
+    base_off = pid * K + offs_e
+    tl.store(topk_weights_ptr + base_off, weights, mask=top_mask)
+    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)
+
+
+def gemma4_fused_routing(
+    gating_output: torch.Tensor,
+    per_expert_scale: torch.Tensor,
+    topk: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """One-launch Gemma4 router.
+
+    Selects the ``topk`` largest logits per token, softmaxes only those
+    logits in fp32 and multiplies each weight by the scale of its expert.
+    Within a token the outputs are ordered by descending logit; equal logits
+    are ordered by ascending expert id.
+
+    Args:
+        gating_output: [T, E] router logits in any floating dtype; will be
+            cast to fp32 inside the kernel.
+        per_expert_scale: [E] per-expert scale, any floating dtype.
+        topk: number of experts to keep per token.
+
+    Returns:
+        topk_weights: [T, topk] fp32 (matches SGLang TopK contract).
+        topk_ids: [T, topk] int32 (matches SGLang TopK contract).
+
+    Raises:
+        AssertionError: if the shapes are inconsistent, ``topk > E``, or
+            ``E > GEMMA4_FUSED_ROUTING_MAX_EXPERTS``.
+    """
+    assert gating_output.dim() == 2, "expected [T, E] router logits"
+    assert per_expert_scale.dim() == 1
+    assert per_expert_scale.shape[0] == gating_output.shape[1]
+    T, E = gating_output.shape
+    assert topk <= E, f"topk ({topk}) must be <= E ({E})"
+    assert E <= GEMMA4_FUSED_ROUTING_MAX_EXPERTS, (
+        "gemma4_fused_routing only supports "
+        f"E<={GEMMA4_FUSED_ROUTING_MAX_EXPERTS}, got E={E}"
+    )
+
+    # The kernel reads both inputs with unit stride along E.
+    gating_output = gating_output.contiguous()
+    per_expert_scale = per_expert_scale.contiguous()
+
+    # tl.arange needs a power-of-two extent.
+    BLOCK_E = triton.next_power_of_2(E)
+    topk_weights = torch.empty(
+        (T, topk), dtype=torch.float32, device=gating_output.device
+    )
+    topk_ids = torch.empty((T, topk), dtype=torch.int32, device=gating_output.device)
+
+    # Nothing to route for an empty batch; skip the launch.
+    if T == 0:
+        return topk_weights, topk_ids
+
+    _gemma4_routing_kernel[(T,)](
+        gating_output,
+        per_expert_scale,
+        topk_weights,
+        topk_ids,
+        gating_output.stride(0),
+        E,
+        topk,
+        BLOCK_E,
+        num_warps=1,
+    )
+    return topk_weights, topk_ids
diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py
index c406f12a2b6c..29b0cfd4b56d 100644
--- a/python/sglang/srt/models/gemma4_causal.py
+++ b/python/sglang/srt/models/gemma4_causal.py
@@ -30,6 +30,8 @@
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.gemma4_fused_ops import (
+    GEMMA4_FUSED_ROUTING_MAX_EXPERTS,
+    gemma4_fused_routing,
     gemma_dual_rmsnorm_residual_scalar,
     gemma_qkv_rmsnorm,
     gemma_rmsnorm_residual_scalar,
@@ -220,6 +222,18 @@ def routing_function(
         ) -> tuple[torch.Tensor, torch.Tensor]:
             # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits),
             # so we softmax only the top-k logits (fewer kernel launches).
+            # Fast path: a single fused Triton launch. Inputs it does not cover
+            # (non-CUDA tensors, other ranks/dtypes, too many experts) fall through
+            # to the torch implementation below.
+            if (
+                gating_output.is_cuda
+                and gating_output.dim() == 2
+                and gating_output.dtype
+                in (torch.float16, torch.bfloat16, torch.float32)
+                and gating_output.shape[-1] <= GEMMA4_FUSED_ROUTING_MAX_EXPERTS
+            ):
+                return gemma4_fused_routing(gating_output, per_expert_scale, topk)
+
             topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
             topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
 
diff --git a/test/registered/kernels/test_gemma4_fused_routing.py b/test/registered/kernels/test_gemma4_fused_routing.py
new file mode 100644
index 000000000000..487ab82002d2
--- /dev/null
+++ b/test/registered/kernels/test_gemma4_fused_routing.py
@@ -0,0 +1,127 @@
+"""Correctness tests for ``gemma4_fused_routing``.
+
+Compares the Triton-fused routing kernel against the original SGLang
+``Gemma4MoE.routing_function`` reference (softmax-of-topk * per_expert_scale).
+Run with::
+
+    pytest test/registered/kernels/test_gemma4_fused_routing.py -v
+
+Requires a CUDA-capable GPU; skips otherwise.
+"""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+from sglang.test.ci.ci_register import register_cuda_ci
+
+register_cuda_ci(est_time=60, stage="base-b", runner_config="1-gpu-small")
+
+pytestmark = pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="gemma4_fused_routing is a CUDA-only Triton kernel",
+)
+
+
+@pytest.fixture(scope="module")
+def fused_routing():
+    """Import the wrapper lazily, only once a test actually runs."""
+    from sglang.srt.layers.gemma4_fused_ops import gemma4_fused_routing
+
+    return gemma4_fused_routing
+
+
+def _reference(gating_output: torch.Tensor, per_expert_scale: torch.Tensor, topk: int):
+    """The previous (now fallback) torch routing function from gemma4_causal.py."""
+    topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
+    topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
+    topk_weights = topk_weights * per_expert_scale[topk_ids].to(topk_weights.dtype)
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def _weights_from_ids(
+    gating_output: torch.Tensor, per_expert_scale: torch.Tensor, ids: torch.Tensor
+) -> torch.Tensor:
+    """Recompute routing weights in fp32 for an already-chosen set of expert ids."""
+    idx = ids.long()
+    picked = gating_output.gather(-1, idx).float()
+    return torch.softmax(picked, dim=-1) * per_expert_scale[idx].float()
+
+
+# (96, 6) pads the expert block to 128 lanes, exercising the masked lanes.
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("T", [1, 7, 64, 128, 1024])
+@pytest.mark.parametrize("E,K", [(128, 8), (64, 4), (256, 8), (96, 6)])
+def test_matches_reference(fused_routing, dtype, T, E, K):
+    torch.manual_seed(0)
+    g = torch.randn(T, E, dtype=dtype, device="cuda")
+    s = torch.rand(E, dtype=dtype, device="cuda") * 2.0
+
+    ref_logits = torch.topk(g, k=K, dim=-1).values
+    ref_w, ref_i = _reference(g, s, K)
+    out_w, out_i = fused_routing(g, s, K)
+
+    assert out_w.dtype == torch.float32
+    assert out_i.dtype == torch.int32
+    assert out_w.shape == (T, K)
+    assert out_i.shape == (T, K)
+
+    # torch.topk makes no documented promise about the order of equal values,
+    # and low-precision logits tie often. Compare the selected logit values,
+    # which are well defined even when ties make the ids ambiguous.
+    assert torch.equal(
+        g.gather(-1, out_i.long()), ref_logits
+    ), "fused routing selected different top-K logits than torch.topk"
+    # Each token must route to K distinct experts.
+    assert (out_i.sort(dim=-1).values.diff(dim=-1) > 0).all()
+
+    # The fused kernel does softmax in fp32 while the torch fallback uses the
+    # input dtype, so tolerances are set to roughly the input-dtype eps.
+    if dtype == torch.bfloat16:
+        atol, rtol = 5e-3, 5e-3
+    elif dtype == torch.float16:
+        atol, rtol = 1e-3, 1e-3
+    else:
+        atol, rtol = 1e-5, 1e-5
+
+    # Rows whose ids agree with the reference are compared element-wise.
+    same = (out_i == ref_i).all(dim=-1)
+    torch.testing.assert_close(out_w[same], ref_w[same], atol=atol, rtol=rtol)
+    # Every row, including tie-reordered ones, must match an fp32 recomputation
+    # from the ids the kernel actually returned.
+    torch.testing.assert_close(
+        out_w, _weights_from_ids(g, s, out_i), atol=1e-5, rtol=1e-5
+    )
+
+
+def test_zero_tokens(fused_routing):
+    g = torch.empty(0, 128, dtype=torch.bfloat16, device="cuda")
+    s = torch.ones(128, dtype=torch.bfloat16, device="cuda")
+    w, i = fused_routing(g, s, 8)
+    assert w.shape == (0, 8) and i.shape == (0, 8)
+    assert w.dtype == torch.float32 and i.dtype == torch.int32
+
+
+def test_scale_applied(fused_routing):
+    """Weights must include per_expert_scale[topk_ids]; selection must ignore it."""
+    torch.manual_seed(1)
+    T, E, K = 4, 128, 8
+    g = torch.randn(T, E, dtype=torch.bfloat16, device="cuda")
+    s = torch.rand(E, dtype=torch.bfloat16, device="cuda") * 3.0
+
+    out_w, out_i = fused_routing(g, s, K)
+    unit_w, unit_i = fused_routing(g, torch.ones_like(s), K)
+
+    # The scale is applied after selection, so it must not change the ids.
+    # Comparing two kernel runs also avoids depending on torch.topk tie order.
+    assert torch.equal(out_i, unit_i)
+    scale = s[out_i.long()].float()
+    torch.testing.assert_close(out_w, unit_w * scale, atol=1e-6, rtol=1e-6)
+    torch.testing.assert_close(
+        out_w, _weights_from_ids(g, s, out_i), atol=1e-5, rtol=1e-5
+    )
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__, "-v"]))