feat(gemma4 ARF): wire FlashInfer AR+RMSNorm into post-attention site (PR-B/2) #20
Draft
pyc96 wants to merge 1 commit into
… (PR-B/2)

PR-B of the 2-PR ARF stack. Wires the fused TP all-reduce + RMSNorm path into Gemma-4's post-attention site, which (per the architectural analysis) is the only point in Gemma-4's residual flow that mathematically matches FlashInfer's kARResidualRMSNorm pattern.

What this PR does NOT do (and why):

* Does NOT wire ARF at the post-FF combine site (gemma_rmsnorm_residual_scalar). Gemma's post-FF formula is (rmsnorm(x) + residual) * scalar, i.e. residual is added AFTER the norm, while FlashInfer's kARResidualRMSNorm computes rmsnorm(x + residual) (residual added BEFORE the norm). Empirically verified the two produce different outputs (max diff 7.27, mean 2.09 on a 4x8 sample). An attempted wiring at this site produced token soup.
* Does NOT wire ARF at the next-layer input_layernorm. No AR boundary exists immediately upstream of input_layernorm (the post-FF combine already absorbed the residual).
* Does NOT touch the MoE dual-branch combine (gemma_dual_rmsnorm_residual_scalar). Two upstream AR boundaries (dense MLP + MoE); out of scope for v0.
* Does NOT touch PLE-enabled variants (E4B/E2B); guarded by self.has_ple.

Why Site #1 (post-attention) works:

Gemma-4's flow after attention is:

    o_proj -> tensor_model_parallel_all_reduce -> post_attention_layernorm(h)

where post_attention_layernorm is a STANDARD RMSNorm (not Gemma4RMSNorm), so the math is rmsnorm(AR(x)) * weight. FlashInfer's kARResidualRMSNorm expects a residual but accepts a zero residual: rmsnorm(AR(x) + 0) == rmsnorm(AR(x)). This is the same workaround vLLM uses in AllReduceRMSNormPattern.

Changes:

* python/sglang/srt/layers/gemma4_fused_ops.py: New function gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True) that:
  - Calls flashinfer_allreduce_residual_rmsnorm with a zero residual, discards the residual output, returns just the rmsnorm output.
  - Falls back to tensor_model_parallel_all_reduce(x) + norm_module.forward(_) when the predicate is False or flashinfer returns (None, None).

  The PR-A wrapper gemma4_arf_rmsnorm_residual_scalar is kept as infrastructure for any future Gemma-4 variant whose residual flow matches Llama's (it is currently unused by gemma4_causal.py).
* python/sglang/srt/models/gemma4_causal.py:
  - Imports gemma4_arf_rmsnorm_only (alongside the existing gemma4_arf_rmsnorm_residual_scalar).
  - Threads skip_all_reduce kwarg through Gemma4Attention.forward to the o_proj call (default False preserves current behavior).
  - At the post-attention site, when self._arf_enabled (set in __init__ based on get_global_server_args().enable_flashinfer_allreduce_fusion and gated on not enable_moe_block and not has_ple):
    * self_attn is called with skip_all_reduce=True
    * gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm) replaces self.post_attention_layernorm(hidden_states)

Validation (google/gemma-4-31B-it, H100 TP=2, triton, FROZEN_KV_MTP, 80 prompts, warmup 2, seed 1):

Per-prompt parity (20 greedy prompts, temp=0):

    match_rate = 19/20 = 0.95

The 1 mismatch is semantically equivalent (both correct explanations of overfitting with slightly different wording); diverges at ~token 100, consistent with bf16 numerical drift compounding across decode steps when the fused FlashInfer kernel uses fp32 accumulation slightly differently from the unfused AR+RMS sequence.

MMLU N=500 (seed 0, temp 0):

    ARF off: 0.780 (390/500) [exact baseline]
    ARF on : 0.778 (389/500)
    delta = -0.2 pp [within +/- 1 pp]

Benchmark:

Metric         | ARF off |   ARF on | Delta
---------------|--------:|---------:|------
chat tok/s     |    1442 | **1479** | **+2.6%**
chat med TTFT  |    2826 |     2811 | -0.5%
chat med TPOT  |    29.7 | **28.7** | **-3.4%**
summ tok/s     |     303 |      308 | +1.7%
summ med TTFT  |   77838 |    76242 | -2.1%
summ med TPOT  |    29.8 |     30.3 | +1.7% (noise)
accept length  |    3.12 |     3.15 | +1.0%

The wins are on the lower end of vLLM's advertised 5-20% E2E range for fuse_allreduce_rms. Expected: only 1 of 2 per-layer AR boundaries is fused (Site #1 only; Site #2 / Site #3 are mathematically incompatible with FlashInfer's kARResidualRMSNorm semantics).

Stack base: pyc/gemma4-arf-ops @ be87667

Co-authored-by: Claude
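The control flow described for gemma4_arf_rmsnorm_only (try the fused kernel with a zero residual, fall back to plain all-reduce + RMSNorm when the predicate is False or the kernel returns (None, None)) can be sketched in pure Python. This is an illustration only, not the PR's code: every function below is a hypothetical stand-in that works on Python lists, and the "all-reduce" is just an elementwise sum over per-rank shards.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Standard RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def all_reduce(shards: List[Vector]) -> Vector:
    """Stand-in for a TP all-reduce: elementwise sum over per-rank shards."""
    return [sum(vals) for vals in zip(*shards)]


def fused_ar_residual_rmsnorm(
    shards: List[Vector], residual: Vector, weight: Vector, available: bool
) -> Tuple[Optional[Vector], Optional[Vector]]:
    """Stand-in for the fused kernel: rmsnorm(AR(x) + residual).

    Returns (None, None) when the fused path is unavailable (assumed contract).
    """
    if not available:
        return None, None
    new_residual = [s + r for s, r in zip(all_reduce(shards), residual)]
    return rms_norm(new_residual, weight), new_residual


def arf_rmsnorm_only(
    shards: List[Vector], weight: Vector, arf_enabled: bool, fused_available: bool
) -> Vector:
    """Fused path with a zero residual; unfused AR + RMSNorm as fallback."""
    if arf_enabled:
        zero = [0.0] * len(weight)
        out, _ = fused_ar_residual_rmsnorm(shards, zero, weight, fused_available)
        if out is not None:
            return out  # the residual output is discarded
    return rms_norm(all_reduce(shards), weight)


def arf_enabled_for_layer(flag: bool, enable_moe_block: bool, has_ple: bool) -> bool:
    """Gate described in the commit: server flag on, no MoE block, no PLE."""
    return flag and not enable_moe_block and not has_ple


shards = [[1.0, 2.0, -1.0, 0.5], [0.5, -1.0, 2.0, 1.5]]  # two "ranks"
weight = [1.0, 0.5, 2.0, 1.0]
fused = arf_rmsnorm_only(shards, weight, True, True)
fallback = arf_rmsnorm_only(shards, weight, True, False)
```

In this toy setting `fused` and `fallback` agree, which is the property the real wrapper relies on up to kernel-level numerical differences.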
Summary
Stacked on PR-A #19. Wires FlashInfer's TRT-LLM `kARResidualRMSNorm` fused kernel into Gemma-4's post-attention site (the `o_proj` AR + `post_attention_layernorm` pair). On H100 TP=2 with `google/gemma-4-31B-it` + FROZEN_KV_MTP this delivers +2.6 % chat tok/s, -3.4 % chat TPOT, -2.1 % summ TTFT, with MMLU essentially unchanged at 0.778 vs 0.780 (1 question difference).
What this PR does NOT do (and why)
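Two of the three theoretical fusion sites in Gemma-4 are mathematically incompatible with FlashInfer's `kARResidualRMSNorm` semantics:

- Post-FF combine (`gemma_rmsnorm_residual_scalar`): Gemma computes `(rmsnorm(x) + residual) * scalar`, adding the residual AFTER the norm, while `kARResidualRMSNorm` computes `rmsnorm(x + residual)`, adding it BEFORE the norm. The two produce different outputs (max diff 7.27, mean 2.09 on a 4x8 sample), and an attempted wiring at this site produced token soup.
- Next-layer `input_layernorm`: no AR boundary exists immediately upstream, because the post-FF combine has already absorbed the residual.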
Only Site #1 (post-attention) has the right shape. The flow there is `o_proj` → AR → `post_attention_layernorm(x)`, where `post_attention_layernorm` is a STANDARD `RMSNorm` (not `Gemma4RMSNorm`) with single-arg semantics `rmsnorm(x) * weight`. FlashInfer's `kARResidualRMSNorm` API requires a residual, but a zero residual makes the contribution vanish (`rmsnorm(AR(x) + 0) == rmsnorm(AR(x))`). This is the same workaround vLLM uses in `AllReduceRMSNormPattern`.
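A small numeric check makes the distinction concrete. The sketch below uses plain Python lists and hypothetical helper names (`fused_pattern`, `gemma_post_ff`); it is not the FlashInfer or SGLang implementation, only the two formulas quoted in this PR applied to made-up inputs.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Standard RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def fused_pattern(x, residual, weight):
    """Residual added BEFORE the norm: rmsnorm(x + residual)."""
    return rms_norm([a + r for a, r in zip(x, residual)], weight)


def gemma_post_ff(x, residual, weight, scalar):
    """Residual added AFTER the norm: (rmsnorm(x) + residual) * scalar."""
    return [(n + r) * scalar for n, r in zip(rms_norm(x, weight), residual)]


# Made-up inputs, for illustration only.
x = [1.0, -2.0, 3.0, 0.5]
residual = [0.5, 0.5, -1.0, 2.0]
weight = [1.0, 1.0, 1.0, 1.0]
zero = [0.0] * len(x)

# With a zero residual the fused pattern collapses to plain RMSNorm.
zero_residual_matches = fused_pattern(x, zero, weight) == rms_norm(x, weight)

# With a real residual the two formulas disagree, even with scalar = 1.
max_diff = max(
    abs(a - b)
    for a, b in zip(
        fused_pattern(x, residual, weight),
        gemma_post_ff(x, residual, weight, 1.0),
    )
)
```

`zero_residual_matches` holds because adding `0.0` leaves each element unchanged, while `max_diff` is clearly non-zero, which is why the post-FF combine cannot be mapped onto the same fused pattern.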
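What's in PR-B
- `python/sglang/srt/layers/gemma4_fused_ops.py`: new function `gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True)`. It calls `flashinfer_allreduce_residual_rmsnorm` with a zero residual, discards the residual output, and returns only the rmsnorm output. It falls back to `tensor_model_parallel_all_reduce(x)` + `norm_module.forward(_)` when the predicate is False or FlashInfer returns `(None, None)`. The PR-A wrapper `gemma4_arf_rmsnorm_residual_scalar` is kept as infrastructure but is currently unused by `gemma4_causal.py`.
- `python/sglang/srt/models/gemma4_causal.py`: imports `gemma4_arf_rmsnorm_only`; threads a `skip_all_reduce` kwarg through `Gemma4Attention.forward` to the `o_proj` call (default False preserves current behavior); at the post-attention site, when `self._arf_enabled`, calls `self_attn` with `skip_all_reduce=True` and replaces `self.post_attention_layernorm(hidden_states)` with `gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm)`.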
Validation
Per-prompt parity (20 greedy prompts, temp=0, max_tokens=64)
```
match_rate = 19/20 = 0.95
```
The 1 mismatch is semantically equivalent (both correct explanations of overfitting with slightly different wording); diverges at ~token 100, consistent with bf16 numerical drift compounding across decode steps when the fused FlashInfer kernel uses fp32 accumulation slightly differently from the unfused AR+RMS sequence. No correctness regression — both responses pass a manual semantic check.
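For reference, a parity check of this kind can be computed along the following lines. This is a minimal sketch with hypothetical helper names and toy token IDs, not the script used for the numbers above; it treats two greedy outputs as matching only when their token sequences are identical.

```python
from typing import List, Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (tok_a, tok_b) in enumerate(zip(a, b)):
        if tok_a != tok_b:
            return i
    if len(a) != len(b):
        # One sequence is a strict prefix of the other.
        return min(len(a), len(b))
    return None


def match_rate(baseline: List[Sequence[int]], candidate: List[Sequence[int]]) -> float:
    """Fraction of prompts whose outputs are token-for-token identical."""
    if len(baseline) != len(candidate):
        raise ValueError("both runs must cover the same prompts")
    matches = sum(
        first_divergence(a, b) is None for a, b in zip(baseline, candidate)
    )
    return matches / len(baseline)


# Toy data: 20 prompts, one of which diverges at token index 2.
arf_off = [[1, 2, 3, 4]] * 20
arf_on = [[1, 2, 3, 4]] * 19 + [[1, 2, 9, 4]]
rate = match_rate(arf_off, arf_on)
diverged_at = first_divergence(arf_off[-1], arf_on[-1])
```

Reporting the divergence index alongside the rate helps separate late numerical drift from an early, structural mismatch.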
MMLU N=500 (seed 0, temp 0)
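| Config | Accuracy | Correct |
|---|---:|---:|
| ARF off | 0.780 | 390/500 |
| ARF on | 0.778 | 389/500 |

Δ = −0.2 pp (1 question difference; within the ±1 pp bar). The ARF-off result is identical to the pre-patch SGLang baseline of 0.780.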
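Benchmark (`google/gemma-4-31B-it`, H100 TP=2, triton, FROZEN_KV_MTP, 80 prompts, warmup 2, seed 1)

| Metric | ARF off | ARF on | Delta |
|---|---:|---:|---|
| chat tok/s | 1442 | **1479** | **+2.6%** |
| chat med TTFT | 2826 | 2811 | -0.5% |
| chat med TPOT | 29.7 | **28.7** | **-3.4%** |
| summ tok/s | 303 | 308 | +1.7% |
| summ med TTFT | 77838 | 76242 | -2.1% |
| summ med TPOT | 29.8 | 30.3 | +1.7% (noise) |
| accept length | 3.12 | 3.15 | +1.0% |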
The wins are on the lower end of vLLM's advertised 5-20 % E2E range for `fuse_allreduce_rms`. This is expected, because only 1 of the 2 per-layer AR boundaries is fused (Site #1 only; Sites #2/#3 are mathematically incompatible). Fusing the second site would require restructuring Gemma's residual flow to Llama-style, which is a major behavioral change with no guarantee of MMLU stability.
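The relative deltas quoted for this benchmark follow from the raw ARF-off and ARF-on values reported in this PR. The snippet below is a small sketch of that arithmetic; the helper name `pct_delta` and the dictionary layout are illustrative choices, not part of the benchmark tooling.

```python
def pct_delta(off: float, on: float) -> float:
    """Relative change of the ARF-on value versus ARF-off, in percent."""
    return (on - off) / off * 100.0


# (ARF off, ARF on) pairs as reported in this PR.
results = {
    "chat tok/s": (1442, 1479),
    "chat med TTFT": (2826, 2811),
    "chat med TPOT": (29.7, 28.7),
    "summ tok/s": (303, 308),
    "summ med TTFT": (77838, 76242),
    "summ med TPOT": (29.8, 30.3),
    "accept length": (3.12, 3.15),
}

deltas = {name: round(pct_delta(off, on), 1) for name, (off, on) in results.items()}

for name, delta in deltas.items():
    print(f"{name:<14} {delta:+.1f}%")
```

For throughput and accept length a positive delta is an improvement; for TTFT and TPOT a negative delta is.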
Server-log evidence
```
[arf_on startup] Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for Gemma4ForConditionalGeneration
[arf_off startup] Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for Gemma4ForConditionalGeneration
FlashInfer allreduce fusion is forcibly disabled via --enforce-disable-flashinfer-allreduce-fusion.
```
Stack
Stack base: `pyc/gemma4-arf-ops` @ `be87667a1`
Plan: `.humanize/yoco-gemma4/refined-plan.md`
CI States
Latest PR Test (Base): ❌ Missing `run-ci` label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- `run-ci` is required first.