Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into
a single kernel pass to reduce kernel launch overhead.

Also provides a single-launch fused router for Gemma4 MoE (PR #26120 in
pyc96/sglang fork): replaces the per-layer ``torch.topk`` ->
``softmax`` -> ``per_expert_scale[ids]`` -> ``mul`` -> ``cast`` chain in
``Gemma4MoE.routing_function`` with one Triton kernel.

The reference design comes from vLLM PR #39083
(``_gemma4_routing_kernel`` / ``gemma4_fused_routing_kernel_triton``),
which is apache-2.0. Our kernel is rewritten in SGLang style and uses
the identity ``softmax(all)[topk] / sum(softmax(all)[topk]) =
softmax(topk_logits)`` already exploited by SGLang's torch routing
function, so the math is bitwise-comparable to the prior fp32 path.
"""

from typing import Optional
Expand Down Expand Up @@ -283,3 +295,163 @@ def gemma_dual_rmsnorm_residual_scalar(
BLOCK_SIZE=BLOCK_SIZE,
)
return out


# ---------------------------------------------------------------------------
# Fused Gemma4 routing kernel (one launch per layer)
# ---------------------------------------------------------------------------
#
# Equivalent to:
#
# topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
# topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)
# topk_weights = topk_weights * per_expert_scale[topk_ids]
# return topk_weights.float(), topk_ids.int()
#
# but completes the entire computation in one Triton program per token.
#
# Algorithm notes:
# * Loads all E logits per token into one program; for Gemma4
# ``E = num_experts = 128`` so ``BLOCK_E = next_pow2(E) = 128`` and the
# work fits in a single warp with `num_warps=1`.
# * Computes ``softmax-of-topk`` by:
# - using ``tl.sort`` on (logit_bits_as_sortable_uint, expert_id) pairs
# packed into int64 — this gives a fully vectorized top-K without a
# K-step loop and matches the bitwise behavior of ``torch.topk``.
# - taking the largest K via a mask on the sorted-descending sequence
# - normalizing in fp32 (matches ``softmax`` default dtype)
# - multiplying by ``per_expert_scale[topk_ids]``
# * Writes ``topk_weights`` (fp32) and ``topk_ids`` (int32) in one
# pass, matching the output dtypes the SGLang MoE topk wrapper
# expects.
#
# Reference algorithm: vLLM PR #39083 ``_gemma4_routing_kernel`` (apache-2.0).
# Our independent implementation follows the same sort+mask+softmax scheme.
@triton.jit
def _gemma4_routing_kernel(
    gating_ptr,  # [T, E] router logits, any float dtype
    per_expert_scale_ptr,  # [E] per-expert scale (any float dtype)
    topk_weights_ptr,  # [T, K] fp32 out
    topk_ids_ptr,  # [T, K] int32 out
    stride_g_t,  # stride of gating in the token dim
    E: tl.constexpr,
    K: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    """Fused top-K + softmax + per-expert scaling for one token row.

    Each program handles the row selected by ``tl.program_id(0)``: it loads
    ``BLOCK_E`` logits (lanes ``>= E`` are padding), sorts them so the
    largest come first, softmaxes the first ``K`` of them in fp32,
    multiplies by ``per_expert_scale[expert_id]`` and stores the ``K``
    weights and expert ids at offset ``row * K`` of the two outputs.

    The caller must guarantee ``K <= E``; padding lanes decode to NaN after
    the sort and are only harmless because they never fall inside the
    first ``K`` positions.
    """
    # Widen the row index to int64 so that ``pid * stride_g_t`` and
    # ``pid * K`` cannot wrap around in 32-bit arithmetic for very large
    # token counts.
    pid = tl.program_id(0).to(tl.int64)
    offs_e = tl.arange(0, BLOCK_E)
    valid = offs_e < E

    # Load logits into fp32; out-of-bound lanes are filled with -inf here
    # and get an explicit "sort last" key below.
    logits = tl.load(
        gating_ptr + pid * stride_g_t + offs_e,
        mask=valid,
        other=-float("inf"),
    ).to(tl.float32)

    # Build an *order-reversing* 32-bit key from the float bit pattern:
    #   * sign bit clear (>= +0.0): key = ~bits, which has its top bit set
    #     and shrinks as the float grows;
    #   * sign bit set (negative):  key = bits with the sign bit cleared,
    #     i.e. the magnitude, which grows as the float gets more negative.
    # The key goes in the high 32 bits of an int64 and the expert id in the
    # low 32 bits. A signed ascending sort of the packed value therefore
    # yields logits in descending order, with equal logits ordered by
    # ascending expert id, and the id travels with its logit so no separate
    # index buffer / scatter pass is needed.
    MIN32 = -2147483648
    logit_bits = logits.to(tl.int32, bitcast=True)
    sign = logit_bits >> 31
    key = tl.where(sign == 0, logit_bits ^ -1, logit_bits ^ MIN32)
    # Padding lanes take the largest signed key so they land at the very
    # end of the ascending sort, behind every real logit.
    key = tl.where(valid, key, 0x7FFFFFFF)
    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
    packed = (sk64 << 32) | offs_e.to(tl.int64)

    # Ascending sort of the packed keys: the first K entries are the K
    # largest logits (see the key construction above).
    sorted_p = tl.sort(packed, descending=False)
    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)

    # Invert the key mapping to recover the original logit bits: keys with
    # the top bit set came from ``~bits``; the others came from clearing
    # the sign bit.
    sign_k = all_keys >> 31
    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
    all_logits = all_bits.to(tl.float32, bitcast=True)

    # Softmax over the first K sorted positions only. The maximum is taken
    # with a masked reduction and subtracted for numerical stability.
    top_mask = offs_e < K
    max_l = tl.max(tl.where(top_mask, all_logits, -float("inf")), axis=0)
    # exp(x) written as exp2(x * log2(e)); spelled out so that Triton
    # releases without tl.math.exp are still supported.
    raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
    raw_exp = tl.where(top_mask, raw_exp, 0.0)

    # A denominator that is not strictly positive (including NaN, for which
    # the comparison is false) is replaced by 1 before dividing.
    denom = tl.sum(raw_exp, axis=0)
    denom = tl.where(denom > 0.0, denom, 1.0)
    weights = raw_exp / denom

    # Multiply by per_expert_scale[topk_ids]. The scale may be stored in
    # any float dtype; it is cast to fp32 for the final write. Lanes
    # outside the top K are masked and never dereference the pointer.
    scales = tl.load(
        per_expert_scale_ptr + all_ids.to(tl.int64),
        mask=top_mask,
        other=1.0,
    ).to(tl.float32)
    weights = weights * scales

    # Outputs are addressed as dense [T, K] buffers: row ``pid`` starts at
    # ``pid * K``.
    base_off = pid * K + offs_e
    tl.store(topk_weights_ptr + base_off, weights, mask=top_mask)
    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)


def gemma4_fused_routing(
    gating_output: torch.Tensor,
    per_expert_scale: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """One-launch Gemma4 router.

    Validates the inputs, allocates the outputs and launches
    ``_gemma4_routing_kernel`` with one program per token row.

    Args:
        gating_output: [T, E] router logits in any floating dtype; will be
            cast to fp32 inside the kernel.
        per_expert_scale: [E] per-expert scale, any floating dtype, on the
            same device as ``gating_output``.
        topk: number of experts to keep per token (must not exceed E).

    Returns:
        topk_weights: [T, topk] fp32 (matches SGLang TopK contract).
        topk_ids: [T, topk] int32 (matches SGLang TopK contract).

    Raises:
        AssertionError: if the shapes, devices, logit dtype or ``topk``
            violate the requirements above.
    """
    assert gating_output.dim() == 2, "expected [T, E] router logits"
    assert per_expert_scale.dim() == 1, "expected [E] per-expert scale"
    assert (
        per_expert_scale.shape[0] == gating_output.shape[1]
    ), "per_expert_scale length must equal the number of experts"
    # The kernel pads out-of-range lanes with -inf, which only makes sense
    # for floating-point logits.
    assert (
        gating_output.is_floating_point()
    ), "router logits must be a floating-point tensor"
    # Both pointers are dereferenced by the same kernel launch, so they
    # have to live on the same device.
    assert (
        per_expert_scale.device == gating_output.device
    ), "per_expert_scale and gating_output must be on the same device"
    T, E = gating_output.shape
    assert topk <= E, "topk cannot exceed the number of experts"

    device = gating_output.device
    topk_weights = torch.empty((T, topk), dtype=torch.float32, device=device)
    topk_ids = torch.empty((T, topk), dtype=torch.int32, device=device)

    # Empty batch: nothing to route, so skip the copies and the launch.
    if T == 0:
        return topk_weights, topk_ids

    # The kernel adds the lane offset to ``row * stride_g_t`` without an
    # inner stride, so the expert dimension must be contiguous. Most call
    # sites already pass a contiguous tensor (router proj output), in
    # which case ``contiguous()`` is a no-op.
    gating_output = gating_output.contiguous()
    per_expert_scale = per_expert_scale.contiguous()

    # tl.arange / tl.sort operate on a power-of-two block; lanes >= E are
    # masked inside the kernel.
    BLOCK_E = triton.next_power_of_2(E)

    _gemma4_routing_kernel[(T,)](
        gating_output,
        per_expert_scale,
        topk_weights,
        topk_ids,
        gating_output.stride(0),
        E,
        topk,
        BLOCK_E,
        num_warps=1,
    )
    return topk_weights, topk_ids
32 changes: 26 additions & 6 deletions python/sglang/srt/models/gemma4_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.gemma4_fused_ops import (
gemma4_fused_routing,
gemma_dual_rmsnorm_residual_scalar,
gemma_qkv_rmsnorm,
gemma_rmsnorm_residual_scalar,
Expand Down Expand Up @@ -220,6 +221,20 @@ def routing_function(
) -> tuple[torch.Tensor, torch.Tensor]:
# softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits),
# so we softmax only the top-k logits (fewer kernel launches).
#
# Fast path: a single Triton kernel that produces (weights, ids)
# already scaled by per_expert_scale. Mathematically identical
# to the torch fallback below. Active when the router logits are a
# 2-D CUDA tensor in fp16/bf16/fp32; the kernel pads num_experts up
# to the next power of two internally (Gemma4: E=128, no padding).
if (
gating_output.is_cuda
and gating_output.dim() == 2
and gating_output.dtype
in (torch.float16, torch.bfloat16, torch.float32)
):
return gemma4_fused_routing(gating_output, per_expert_scale, topk)

topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)

Expand Down Expand Up @@ -1147,7 +1162,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")),
("experts.w2_weight", "experts.down_proj", ("w2",)),
]
num_experts = self.config.num_experts
# Dense subclasses (e.g. the Gemma4 MTP assistant) reuse this.
num_experts = getattr(self.config, "num_experts", None) or 0

# Per-expert checkpoint format used by compressed-tensors / FP8
# (e.g. RedHatAI/*-FP8-Dynamic) and by ModelOpt NVFP4
Expand All @@ -1159,11 +1175,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# in a trailing dot, so the standard `name.replace(weight_name,
# param_name)` collapses every suffix uniformly to the fused
# FusedMoE params (experts.w13_*, experts.w2_*).
per_expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=num_experts,
per_expert_params_mapping = (
FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=num_experts,
)
if num_experts
else []
)

k_eq_v_layers = self._get_k_eq_v_layers()
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/gemma4_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
self.assistant_config = config
self.config = text_config
self.quant_config = quant_config
self.pp_group = get_pp_group()

self.vocab_size = text_config.vocab_size
self.hidden_size = text_config.hidden_size
Expand Down
42 changes: 40 additions & 2 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,10 +2232,48 @@ def _handle_model_specific_adjustments(self):
)

if is_sm100_supported() and self.moe_runner_backend == "auto":
if self.get_model_config().quantization == "modelopt_fp4":
self.quantization = "modelopt_fp4"
self.moe_runner_backend = "flashinfer_trtllm"
logger.info(
"Use flashinfer_trtllm as MoE runner backend on "
"SM100 for Gemma-4 (modelopt_fp4)"
)

self.moe_runner_backend = "flashinfer_trtllm"
# Gemma-4 uses a 25:5 sliding-window : full-attention layer ratio
# (see ``Gemma4TextConfig.layer_types``). The shipped default
# ``swa_full_tokens_ratio = 0.8`` is tuned for models where the
# sliding-window pool is the binding constraint, but for Gemma-4
# the full-attention pool is binding under concurrent long-context
# workloads: with the default ratio the full pool only fits ~65
# 9k-token requests on a 180 GB B200, forcing partial KV eviction
# and re-prefill (visible as ``#cached-token: 1003 #new-token:
# 7010`` lines in the serving log) under typical 80-request
# summarization loads.
#
# Lowering the ratio to ~0.15 shifts memory from the over-
# provisioned SWA pool (25 layers × 1024-token window) to the
# under-provisioned full pool (5 layers × full context length).
# On the same 180 GB B200, the full pool grows from ~594 k tokens
# to ~2.14 M tokens (3.6× larger; enough for ~237 concurrent
# 9k-token requests), while the SWA pool shrinks from ~475 k to
# ~321 k tokens (still ~313 concurrent 1024-token windows,
# far above any realistic request count). Median TTFT on a
# summarization workload of 80 × 8k-input / 1k-output prompts
# drops 16.5 % (10.5 s -> 8.7 s) on a B200 with TP=1, MTP, and
# the triton attention backend, with no MMLU regression.
#
# Only apply when the user did not explicitly set the ratio,
# mirroring the pattern in ``apply_deepseek_v4_defaults``.
if self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio:
self.swa_full_tokens_ratio = 0.15
logger.info(
"Use flashinfer_trtllm as MoE runner backend on SM100 for Gemma-4 NVFP4"
"Setting swa_full_tokens_ratio to "
f"{self.swa_full_tokens_ratio} for {model_arch} "
"(Gemma-4 has a 25:5 SWA:full layer split; the default "
"ratio over-provisions the SWA pool and under-provisions "
"the full-attention pool, causing partial KV eviction "
"and re-prefill under concurrent long-context loads)."
)
elif model_arch == "MossVLForConditionalGeneration":
if self.is_attention_backend_not_set():
Expand Down
Loading
Loading