Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,76 @@ def gemma_rmsnorm_residual_scalar(
return out


def gemma4_arf_rmsnorm_residual_scalar(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
    scalar: torch.Tensor,
    eps: float = 1e-6,
    use_attn_tp_group: bool = True,
) -> torch.Tensor:
    """Fused TP all-reduce + ``(rmsnorm(x) + residual) * scalar``.

    Used for the Gemma-4 dense post-FF combine. Computes the same quantity
    as::

        x_reduced = tensor_model_parallel_all_reduce(x)
        return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps)

    When ``apply_flashinfer_allreduce_fusion`` reports that the fused path is
    usable for this batch size, the all-reduce and the RMSNorm are executed
    by ``flashinfer_allreduce_residual_rmsnorm`` in one call, followed by an
    element-wise ``+ residual`` and ``* scalar`` tail.

    Why the fused call receives a zero residual:
        The FlashInfer pattern is "residual-add then RMSNorm", i.e. it
        normalises ``allreduce(x) + residual``. The Gemma-4 post-FF combine
        needs the residual added *after* the norm. RMSNorm is not linear, so
        handing the real residual to the fused kernel would return
        ``rmsnorm(x + residual)`` instead of ``rmsnorm(x) + residual``.
        A zero residual makes the kernel's ``norm_out`` equal to
        ``rmsnorm(allreduce(x)) * weight``; the real residual is added
        afterwards.

    Caller contract:
        * The caller must pass ``skip_all_reduce=True`` to the upstream
          ``RowParallelLinear`` that produced ``x`` so the all-reduce is
          not performed twice.
        * ``x`` is the not-yet-reduced output of that ``down_proj``.
        * ``residual`` is the full pre-FF hidden state (already replicated).
        * ``scalar`` is the Gemma-4 ``layer_scalar`` buffer (shape ``[1]``).
        * ``use_attn_tp_group`` is forwarded to the FlashInfer wrapper to
          select which TP group's workspace is used.

    NOTE(review): the fused path assumes ``weight`` is applied by FlashInfer
    the same way ``gemma_rmsnorm_residual_scalar`` applies it (plain
    multiplicative gain, no implicit ``1 + weight`` offset) -- confirm
    against that helper. The two paths use different kernels, so results
    agree up to floating-point rounding rather than bit-for-bit.

    Args:
        x: 2-D (for the fused path) partial-sum activations to all-reduce.
        weight: RMSNorm gain.
        residual: Tensor added to the normalised activations.
        scalar: Multiplier applied to the sum.
        eps: RMSNorm epsilon.
        use_attn_tp_group: Forwarded to the FlashInfer fused wrapper.

    Returns:
        ``(rmsnorm(allreduce(x)) + residual) * scalar``.
    """
    # Lazy imports keep distributed/communicator modules out of module
    # load time.
    from sglang.srt.distributed import tensor_model_parallel_all_reduce
    from sglang.srt.layers.communicator import apply_flashinfer_allreduce_fusion
    from sglang.srt.layers.flashinfer_comm_fusion import (
        flashinfer_allreduce_residual_rmsnorm,
    )

    if x.is_cuda and x.dim() == 2 and apply_flashinfer_allreduce_fusion(x.shape[0]):
        # Zero residual: the kernel then normalises allreduce(x) alone.
        norm_out, _ = flashinfer_allreduce_residual_rmsnorm(
            input_tensor=x,
            residual=torch.zeros_like(x),
            weight=weight,
            eps=eps,
            use_attn_tp_group=use_attn_tp_group,
        )
        # The wrapper signals "fusion not performed" with ``None``; in that
        # case ``x`` has not been reduced and we fall through.
        if norm_out is not None:
            return (norm_out + residual) * scalar

    # Fallback: explicit all-reduce followed by the existing fused
    # norm/residual/scalar helper.
    x_reduced = tensor_model_parallel_all_reduce(x)
    return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps)


@triton.jit
def _gemma_dual_rmsnorm_residual_kernel(
X1_ptr,
Expand Down
12 changes: 10 additions & 2 deletions python/sglang/srt/models/gemma3_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,18 @@ def __init__(
self.act_fn = GeluAndMul()
self.prefix = prefix

def forward(self, x: torch.Tensor, skip_all_reduce: bool = False) -> torch.Tensor:
    """Run the MLP: ``gate_up_proj`` -> ``act_fn`` -> ``down_proj``.

    Args:
        x: Input hidden states.
        skip_all_reduce: Forwarded to ``self.down_proj``. When ``True`` the
            down projection omits its TP all-reduce so the caller can fuse
            it into a downstream operation (see
            ``gemma4_arf_rmsnorm_residual_scalar`` for the Gemma-4 post-FF
            combine fusion). The default ``False`` keeps the in-line
            all-reduce for every existing caller.

    Returns:
        The down-projected activations (not yet all-reduced when
        ``skip_all_reduce`` is ``True``).
    """
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    # Exactly one down projection; the flag decides whether it all-reduces.
    x, _ = self.down_proj(x, skip_all_reduce=skip_all_reduce)
    return x


Expand Down
Loading
Loading