Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,70 @@ def gemma4_arf_rmsnorm_residual_scalar(
return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps)


def gemma4_arf_rmsnorm_only(
    x: torch.Tensor,
    norm_module,
    use_attn_tp_group: bool = True,
) -> torch.Tensor:
    """All-reduce ``x`` across TP ranks and apply a single-argument RMSNorm.

    Intended for Gemma-4's ``post_attention_layernorm``, which sits directly
    after the attention output all-reduce and before any residual add. The
    result is meant to match the unfused sequence::

        x_reduced = tensor_model_parallel_all_reduce(x)
        return norm_module.forward(x_reduced)

    Fast path:
        FlashInfer's TRT-LLM ``allreduce_fusion`` entry point only offers the
        all-reduce + residual-add + RMSNorm pattern. To reuse it for a
        residual-free norm, an all-zero residual is passed, so that
        ``rmsnorm(AR(x) + 0) == rmsnorm(AR(x))``. This mirrors the approach
        vLLM takes in its ``AllReduceRMSNormPattern``.

    Args:
        x: Output of the attention ``RowParallelLinear`` projection that has
            NOT yet been all-reduced (the caller must have invoked that layer
            with ``skip_all_reduce=True``). The fused path is attempted only
            for 2-D CUDA tensors.
        norm_module: The layer's ``post_attention_layernorm``. Must expose
            ``weight``, ``variance_epsilon`` and ``forward``. It is expected
            to be a plain ``RMSNorm`` (``rmsnorm(x) * weight``), not a
            ``Gemma4RMSNorm`` whose ``(weight + scale_shift)`` gamma the
            FlashInfer pattern cannot express.
        use_attn_tp_group: Forwarded unchanged to
            ``flashinfer_allreduce_residual_rmsnorm``.

    Returns:
        The normalized, all-reduced tensor.

    Fallback:
        If the tensor is not eligible, ``apply_flashinfer_allreduce_fusion``
        declines for this token count, or the fused helper returns ``None``
        as its first output, the function performs
        ``tensor_model_parallel_all_reduce`` followed by
        ``norm_module.forward`` -- i.e. the pre-fusion code path.
    """
    # Function-scope imports, kept as in the original (presumably to avoid
    # import cycles between the layers and distributed packages -- confirm).
    from sglang.srt.distributed import tensor_model_parallel_all_reduce
    from sglang.srt.layers.communicator import apply_flashinfer_allreduce_fusion
    from sglang.srt.layers.flashinfer_comm_fusion import (
        flashinfer_allreduce_residual_rmsnorm,
    )

    # Short-circuit order matters: the runtime predicate is consulted only
    # for 2-D CUDA inputs, and is given the leading (token) dimension.
    fusion_eligible = (
        x.is_cuda
        and x.dim() == 2
        and apply_flashinfer_allreduce_fusion(x.shape[0])
    )

    if fusion_eligible:
        # The zero residual satisfies the fused kernel's signature without
        # changing the math; the residual output is discarded.
        fused_out, _ = flashinfer_allreduce_residual_rmsnorm(
            input_tensor=x,
            residual=torch.zeros_like(x),
            weight=norm_module.weight.data,
            eps=norm_module.variance_epsilon,
            use_attn_tp_group=use_attn_tp_group,
        )
        # A ``None`` result means the fused helper did not run; fall through
        # to the unfused path below.
        if fused_out is not None:
            return fused_out

    # Unfused path: explicit all-reduce, then the module's own forward.
    return norm_module.forward(tensor_model_parallel_all_reduce(x))


@triton.jit
def _gemma_dual_rmsnorm_residual_kernel(
X1_ptr,
Expand Down
52 changes: 47 additions & 5 deletions python/sglang/srt/models/gemma4_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.gemma4_fused_ops import (
gemma4_arf_rmsnorm_only,
gemma4_arf_rmsnorm_residual_scalar,
gemma4_fused_routing,
gemma_dual_rmsnorm_residual_scalar,
gemma_qkv_rmsnorm,
Expand Down Expand Up @@ -407,6 +409,7 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
skip_all_reduce: bool = False,
**kwargs,
):
qkv, _ = self.qkv_proj(hidden_states)
Expand Down Expand Up @@ -491,7 +494,14 @@ def forward(
)
if attn_output.dim() == 3:
attn_output = attn_output.flatten(-2, -1)
output, _ = self.o_proj(attn_output)
# ARF-fast-path: when the caller signals it will fuse the
# ``o_proj`` TP all-reduce with the downstream
# ``post_attention_layernorm`` via
# ``gemma4_arf_rmsnorm_only``, ``o_proj`` must NOT do its own
# all-reduce (otherwise the output is all-reduced twice). Safe
# default ``skip_all_reduce=False`` preserves current behavior
# for all non-ARF callers.
output, _ = self.o_proj(attn_output, skip_all_reduce=skip_all_reduce)

return output

Expand Down Expand Up @@ -621,6 +631,22 @@ def __init__(
self.has_ple = self.hidden_size_per_layer_input > 0
self.prefix = prefix

# FlashInfer AR+RMSNorm fusion opt-in (PR-B/2 of the Gemma-4 ARF
# stack). Cache the server-arg flag at __init__ time to avoid a
# per-step lookup; the actual runtime gate also checks
# ``apply_flashinfer_allreduce_fusion(num_tokens)`` inside the fused
# wrappers (``gemma4_arf_rmsnorm_only`` and
# ``gemma4_arf_rmsnorm_residual_scalar``). ARF is only enabled for
# dense (non-MoE, non-PLE) layers in v0; it covers the post-attention
# norm as well as the post-FF combine.
try:
_server_args = get_global_server_args()
except Exception:
_server_args = None
self._arf_enabled = (
bool(getattr(_server_args, "enable_flashinfer_allreduce_fusion", False))
and not self.enable_moe_block
and not self.has_ple
)

def forward(
self,
positions: torch.Tensor,
Expand All @@ -644,12 +670,28 @@ def forward(

# Apply input layernorm
hidden_states = self.input_layernorm(hidden_states)
# ARF fast-path for the post-attention all-reduce + RMSNorm.
# When ``self._arf_enabled`` is True, ``self_attn`` skips its
# internal ``o_proj`` all-reduce and the downstream
# ``gemma4_arf_rmsnorm_only`` calls FlashInfer's TRT-LLM fused
# AR+RMSNorm kernel (with a zero residual to satisfy the
# FlashInfer API; the residual contribution vanishes
# mathematically). The wrapper falls back to plain
# AR + post_attention_layernorm if the runtime predicate is
# False (batch too large, FlashInfer unavailable, etc.).
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
skip_all_reduce=self._arf_enabled,
)
hidden_states = self.post_attention_layernorm(hidden_states)
if self._arf_enabled:
hidden_states = gemma4_arf_rmsnorm_only(
hidden_states,
self.post_attention_layernorm,
)
else:
hidden_states = self.post_attention_layernorm(hidden_states)

if self.enable_moe_block:
# Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states
Expand Down Expand Up @@ -943,9 +985,9 @@ def forward(
)
hidden_states = input_embeds
else:
assert (
pp_proxy_tensors is not None
), "pp_proxy_tensors is required on non-first PP ranks"
assert pp_proxy_tensors is not None, (
"pp_proxy_tensors is required on non-first PP ranks"
)
hidden_states = pp_proxy_tensors["hidden_states"]
# PLE inputs were computed on rank 0 and forwarded along the
# pipeline; non-PLE models simply omit the key.
Expand Down
Loading