From 87aac49c427eea195135bbd436f3a3f9b4955ffc Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Mon, 25 May 2026 05:36:28 +0000 Subject: [PATCH] feat(gemma4 ARF): wire FlashInfer AR+RMSNorm into post-attention site (PR-B/2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR-B of the 2-PR ARF stack. Wires the fused TP all-reduce + RMSNorm path into Gemma-4's post-attention site, which (per the architectural analysis) is the only point in Gemma-4's residual flow that mathematically matches FlashInfer's kARResidualRMSNorm pattern. What this PR does NOT do (and why): * Does NOT wire ARF at the post-FF combine site (gemma_rmsnorm_residual_scalar). Gemma's post-FF formula is (rmsnorm(x) + residual) * scalar — i.e. residual is added AFTER the norm — while FlashInfer's kARResidualRMSNorm computes rmsnorm(x + residual) (residual added BEFORE the norm). Empirically verified the two produce different outputs (max diff 7.27, mean 2.09 on a 4x8 sample). An attempted wiring at this site produced token soup. * Does NOT wire ARF at the next-layer input_layernorm. No AR boundary exists immediately upstream of input_layernorm (the post-FF combine already absorbed the residual). * Does NOT touch the MoE dual-branch combine (gemma_dual_rmsnorm_residual_scalar). Two upstream AR boundaries (dense MLP + MoE); out of scope for v0. * Does NOT touch PLE-enabled variants (E4B/E2B); guarded by self.has_ple. Why Site #1 (post-attention) works: Gemma-4's flow after attention is: o_proj -> tensor_model_parallel_all_reduce -> post_attention_layernorm(h) where post_attention_layernorm is a STANDARD RMSNorm (not Gemma4RMSNorm), so the math is rmsnorm(AR(x)) * weight. FlashInfer's kARResidualRMSNorm expects a residual but accepts a zero residual: rmsnorm(AR(x) + 0) == rmsnorm(AR(x)). This is the same workaround vLLM uses in AllReduceRMSNormPattern. 
Changes: * python/sglang/srt/layers/gemma4_fused_ops.py: New function gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True) that: - Calls flashinfer_allreduce_residual_rmsnorm with a zero residual, discards the residual output, returns just the rmsnorm output. - Falls back to tensor_model_parallel_all_reduce(x) + norm_module.forward(_) when the predicate is False or flashinfer returns (None, None). The PR-A wrapper gemma4_arf_rmsnorm_residual_scalar is kept as infrastructure for any future Gemma-4 variant whose residual flow matches Llama's (it is currently unused by gemma4_causal.py). * python/sglang/srt/models/gemma4_causal.py: - Imports gemma4_arf_rmsnorm_only (alongside the existing gemma4_arf_rmsnorm_residual_scalar). - Threads skip_all_reduce kwarg through Gemma4Attention.forward to the o_proj call (default False preserves current behavior). - At the post-attention site, when self._arf_enabled (set in __init__ based on get_global_server_args().enable_flashinfer_allreduce_fusion and gated on not enable_moe_block and not has_ple): * self_attn is called with skip_all_reduce=True * gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm) replaces self.post_attention_layernorm(hidden_states) Validation (google/gemma-4-31B-it, H100 TP=2, triton, FROZEN_KV_MTP, 80 prompts, warmup 2, seed 1): Per-prompt parity (20 greedy prompts, temp=0): match_rate = 19/20 = 0.95 The 1 mismatch is semantically equivalent (both correct explanations of overfitting with slightly different wording); diverges at ~token 100, consistent with bf16 numerical drift compounding across decode steps when the fused FlashInfer kernel uses fp32 accumulation slightly differently from the unfused AR+RMS sequence. 
MMLU N=500 (seed 0, temp 0): ARF off: 0.780 (390/500) [exact baseline] ARF on : 0.778 (389/500) delta = -0.2 pp [within +/- 1 pp] Benchmark: Metric | ARF off | ARF on | Delta ---------------|--------:|---------:|------ chat tok/s | 1442 | **1479** | **+2.6%** chat med TTFT | 2826 | 2811 | -0.5% chat med TPOT | 29.7 | **28.7** | **-3.4%** summ tok/s | 303 | 308 | +1.7% summ med TTFT | 77838 | 76242 | -2.1% summ med TPOT | 29.8 | 30.3 | +1.7% (noise) accept length | 3.12 | 3.15 | +1.0% The wins are on the lower end of vLLM's advertised 5-20% E2E range for fuse_allreduce_rms. Expected: only 1 of 2 per-layer AR boundaries is fused (Site #1 only; Site #2 / Site #3 are mathematically incompatible with FlashInfer's kARResidualRMSNorm semantics). Stack base: pyc/gemma4-arf-ops @ be87667a1 Co-authored-by: Claude --- python/sglang/srt/layers/gemma4_fused_ops.py | 64 ++++++++++++++++++++ python/sglang/srt/models/gemma4_causal.py | 52 ++++++++++++++-- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index 127c42b29573..4ef2c00f0e87 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -163,6 +163,70 @@ def gemma4_arf_rmsnorm_residual_scalar( return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps) +def gemma4_arf_rmsnorm_only( + x: torch.Tensor, + norm_module, + use_attn_tp_group: bool = True, +) -> torch.Tensor: + """Fused TP all-reduce + single-arg RMSNorm for Gemma-4 + ``post_attention_layernorm``. + + Numerically equivalent to:: + + x_reduced = tensor_model_parallel_all_reduce(x) + return norm_module.forward(x_reduced) + + where ``norm_module`` is a standard SGLang ``RMSNorm`` whose math is + ``rmsnorm(x) * weight``. 
This wrapper is the **correct fusion site** + for Gemma-4's residual flow because Gemma-4 places a single-arg + RMSNorm immediately after the attention all-reduce (before any + residual addition). + + Why the zero-residual trick: + FlashInfer's TRT-LLM ``allreduce_fusion`` API only exposes the + ``kARResidualRMSNorm`` pattern (no residual-less variant). vLLM's + ``AllReduceRMSNormPattern`` solves this by synthesizing a + ``torch.zeros_like(input)`` residual; the math + ``rmsnorm(AR(x) + 0) == rmsnorm(AR(x))`` makes the residual + contribution vanish. We follow the same convention here. + + Caller contract: + * Caller must pass ``skip_all_reduce=True`` to the upstream + ``RowParallelLinear`` whose output is ``x``. + * ``x`` must be the not-yet-all-reduced post-attention projection. + * ``norm_module`` is the Gemma-4 layer's + ``post_attention_layernorm`` (a ``RMSNorm`` instance — *not* a + ``Gemma4RMSNorm``, because the latter's ``(weight + scale_shift)`` + gamma is not currently expressible in FlashInfer's pattern). + + Fallback: when FlashInfer is unavailable, batch too large, workspace + not ready, ``x`` is not a 2-D CUDA tensor, or the predicate is False, + falls back to ``tensor_model_parallel_all_reduce(x)`` followed by + ``norm_module.forward(...)``, bit-identical to the pre-fusion path. + """ + from sglang.srt.distributed import tensor_model_parallel_all_reduce + from sglang.srt.layers.communicator import apply_flashinfer_allreduce_fusion + from sglang.srt.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + if x.is_cuda and x.dim() == 2 and apply_flashinfer_allreduce_fusion(x.shape[0]): + zero_residual = torch.zeros_like(x) + norm_out, _residual_out = flashinfer_allreduce_residual_rmsnorm( + input_tensor=x, + residual=zero_residual, + weight=norm_module.weight.data, + eps=norm_module.variance_epsilon, + use_attn_tp_group=use_attn_tp_group, + ) + if norm_out is not None: + return norm_out + + # Fallback: identical to the pre-fusion code path. 
+ x_reduced = tensor_model_parallel_all_reduce(x) + return norm_module.forward(x_reduced) + + @triton.jit def _gemma_dual_rmsnorm_residual_kernel( X1_ptr, diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index a943730cc893..4d292caad7dc 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -30,6 +30,8 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.gemma4_fused_ops import ( + gemma4_arf_rmsnorm_only, + gemma4_arf_rmsnorm_residual_scalar, gemma4_fused_routing, gemma_dual_rmsnorm_residual_scalar, gemma_qkv_rmsnorm, @@ -407,6 +409,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + skip_all_reduce: bool = False, **kwargs, ): qkv, _ = self.qkv_proj(hidden_states) @@ -491,7 +494,14 @@ def forward( ) if attn_output.dim() == 3: attn_output = attn_output.flatten(-2, -1) - output, _ = self.o_proj(attn_output) + # ARF-fast-path: when the caller signals it will fuse the + # ``o_proj`` TP all-reduce with the downstream + # ``post_attention_layernorm`` via + # ``gemma4_arf_rmsnorm_only``, ``o_proj`` must NOT do its own + # all-reduce (otherwise the output would be all-reduced twice). Safe + # default ``skip_all_reduce=False`` preserves current behavior + # for all non-ARF callers. + output, _ = self.o_proj(attn_output, skip_all_reduce=skip_all_reduce) return output @@ -621,6 +631,22 @@ def __init__( self.has_ple = self.hidden_size_per_layer_input > 0 self.prefix = prefix + # FlashInfer AR+RMSNorm fusion opt-in (PR-B/2 of the Gemma-4 ARF + # stack). Cache the server-arg flag at __init__ time to avoid a + # per-step lookup; the actual runtime gate also checks + # ``apply_flashinfer_allreduce_fusion(num_tokens)`` inside + # ``gemma4_arf_rmsnorm_only``. ARF is only wired into the dense + # (non-MoE, non-PLE) post-attention site in v0. 
+ try: + _server_args = get_global_server_args() + except Exception: + _server_args = None + self._arf_enabled = ( + bool(getattr(_server_args, "enable_flashinfer_allreduce_fusion", False)) + and not self.enable_moe_block + and not self.has_ple + ) + def forward( self, positions: torch.Tensor, @@ -644,12 +670,28 @@ def forward( # Apply input layernorm hidden_states = self.input_layernorm(hidden_states) + # ARF fast-path for the post-attention all-reduce + RMSNorm. + # When ``self._arf_enabled`` is True, ``self_attn`` skips its + # internal ``o_proj`` all-reduce and the downstream + # ``gemma4_arf_rmsnorm_only`` calls FlashInfer's TRT-LLM fused + # AR+RMSNorm kernel (with a zero residual to satisfy the + # FlashInfer API; the residual contribution vanishes + # mathematically). The wrapper falls back to plain + # AR + post_attention_layernorm if the runtime predicate is + # False (batch too large, FlashInfer unavailable, etc.). hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, + skip_all_reduce=self._arf_enabled, ) - hidden_states = self.post_attention_layernorm(hidden_states) + if self._arf_enabled: + hidden_states = gemma4_arf_rmsnorm_only( + hidden_states, + self.post_attention_layernorm, + ) + else: + hidden_states = self.post_attention_layernorm(hidden_states) if self.enable_moe_block: # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states @@ -943,9 +985,9 @@ def forward( ) hidden_states = input_embeds else: - assert ( - pp_proxy_tensors is not None - ), "pp_proxy_tensors is required on non-first PP ranks" + assert pp_proxy_tensors is not None, ( + "pp_proxy_tensors is required on non-first PP ranks" + ) hidden_states = pp_proxy_tensors["hidden_states"] # PLE inputs were computed on rank 0 and forwarded along the # pipeline; non-PLE models simply omit the key.