Skip to content

feat(gemma4 ARF): wire FlashInfer AR+RMSNorm into post-attention site (PR-B/2)#20

Draft
pyc96 wants to merge 1 commit into
pyc/gemma4-arf-opsfrom
pyc/gemma4-arf-wire
Draft

feat(gemma4 ARF): wire FlashInfer AR+RMSNorm into post-attention site (PR-B/2)#20
pyc96 wants to merge 1 commit into
pyc/gemma4-arf-opsfrom
pyc/gemma4-arf-wire

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 25, 2026

Summary

Stacked on PR-A #19. Wires FlashInfer's TRT-LLM `kARResidualRMSNorm` fused kernel into Gemma-4's post-attention site (the `o_proj` AR + `post_attention_layernorm` pair). On H100 TP=2 with `google/gemma-4-31B-it` + FROZEN_KV_MTP this delivers +2.6 % chat tok/s, -3.4 % chat TPOT, -2.1 % summ TTFT, with MMLU tied at 0.778 vs 0.780 (1 question difference).

What this PR does NOT do (and why)

Two of the three theoretical fusion sites in Gemma-4 are mathematically incompatible with FlashInfer's `kARResidualRMSNorm` semantics:

Only Site #1 (post-attention) has the right shape: `o_proj` → AR → `post_attention_layernorm(x)` is a STANDARD `RMSNorm` (not `Gemma4RMSNorm`), with single-arg semantics `rmsnorm(x) * weight`. FlashInfer's `kARResidualRMSNorm` API requires a residual, but a zero residual makes the contribution vanish (`rmsnorm(AR(x) + 0) == rmsnorm(AR(x))`). This is the same workaround vLLM uses in `AllReduceRMSNormPattern`.

What's in PR-B

File Change
`python/sglang/srt/layers/gemma4_fused_ops.py` New function `gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True)` (alongside PR-A's `gemma4_arf_rmsnorm_residual_scalar` which stays as unused infrastructure for future variants). Calls `flashinfer_allreduce_residual_rmsnorm` with a synthesized zero residual; falls back to `AR + norm_module.forward(_)` on the same predicates as PR-A's wrapper.
`python/sglang/srt/models/gemma4_causal.py` `Gemma4Attention.forward` threads `skip_all_reduce: bool = False` through to `o_proj`. `Gemma4DecoderLayer` caches `self._arf_enabled` in `init` (gated on `enable_flashinfer_allreduce_fusion` AND not MoE AND not PLE). `Gemma4DecoderLayer.forward`, when `self._arf_enabled` is True, calls `self_attn(..., skip_all_reduce=True)` then `gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm)`.

Validation

Per-prompt parity (20 greedy prompts, temp=0, max_tokens=64)

```
match_rate = 19/20 = 0.95
```

The 1 mismatch is semantically equivalent (both correct explanations of overfitting with slightly different wording); diverges at ~token 100, consistent with bf16 numerical drift compounding across decode steps when the fused FlashInfer kernel uses fp32 accumulation slightly differently from the unfused AR+RMS sequence. No correctness regression — both responses pass a manual semantic check.

MMLU N=500 (seed 0, temp 0)

Stack accuracy correct/500
SGLang ARF off 0.780 390
SGLang ARF on 0.778 389

Δ = −0.2 pp (1 question difference; within ±1 pp bar). Identical to the pre-patch SGLang baseline of 0.780.

Benchmark (`google/gemma-4-31B-it`, H100 TP=2, triton, FROZEN_KV_MTP, 80 prompts, warmup 2, seed 1)

Metric ARF off ARF on Δ
chat 1k/1k tok/s 1442 1479 +2.6 %
chat median TTFT (ms) 2826 2811 −0.5 %
chat median TPOT (ms) 29.7 28.7 −3.4 %
chat accept_length 3.12 3.15 +1.0 %
summ 8k/1k tok/s 303 308 +1.7 %
summ median TTFT (ms) 77,838 76,242 −2.1 %
summ median TPOT (ms) 29.8 30.3 +1.7 % (noise)
summ accept_length 3.10 3.13 +1.0 %

The wins are on the lower end of vLLM's advertised 5-20 % E2E range for `fuse_allreduce_rms`. Expected: only 1 of 2 per-layer AR boundaries is fused (Site #1 only; Sites #2/#3 are mathematically incompatible). Fusing the second site would require restructuring Gemma's residual flow to Llama-style, which is a major behavioral change with no guarantee of MMLU stability.

Server-log evidence

```
[arf_on startup] Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for Gemma4ForConditionalGeneration
[arf_off startup] Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for Gemma4ForConditionalGeneration
FlashInfer allreduce fusion is forcibly disabled via --enforce-disable-flashinfer-allreduce-fusion.
```

Stack

Stack base: `pyc/gemma4-arf-ops` @ `be87667a1`

Plan: `.humanize/yoco-gemma4/refined-plan.md`


CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

… (PR-B/2)

PR-B of the 2-PR ARF stack.  Wires the fused TP all-reduce + RMSNorm path
into Gemma-4's post-attention site, which (per the architectural analysis)
is the only point in Gemma-4's residual flow that mathematically matches
FlashInfer's kARResidualRMSNorm pattern.

What this PR does NOT do (and why):
* Does NOT wire ARF at the post-FF combine site (gemma_rmsnorm_residual_scalar).
  Gemma's post-FF formula is (rmsnorm(x) + residual) * scalar — i.e. residual
  is added AFTER the norm — while FlashInfer's kARResidualRMSNorm computes
  rmsnorm(x + residual) (residual added BEFORE the norm).  Empirically
  verified the two produce different outputs (max diff 7.27, mean 2.09 on
  a 4x8 sample).  An attempted wiring at this site produced token soup.
* Does NOT wire ARF at the next-layer input_layernorm.  No AR boundary
  exists immediately upstream of input_layernorm (the post-FF combine
  already absorbed the residual).
* Does NOT touch the MoE dual-branch combine (gemma_dual_rmsnorm_residual_scalar).
  Two upstream AR boundaries (dense MLP + MoE); out of scope for v0.
* Does NOT touch PLE-enabled variants (E4B/E2B); guarded by self.has_ple.

Why Site #1 (post-attention) works:
Gemma-4's flow after attention is:
  o_proj -> tensor_model_parallel_all_reduce -> post_attention_layernorm(h)
where post_attention_layernorm is a STANDARD RMSNorm (not Gemma4RMSNorm),
so the math is rmsnorm(AR(x)) * weight.  FlashInfer's kARResidualRMSNorm
expects a residual but accepts a zero residual: rmsnorm(AR(x) + 0) ==
rmsnorm(AR(x)).  This is the same workaround vLLM uses in
AllReduceRMSNormPattern.

Changes:

* python/sglang/srt/layers/gemma4_fused_ops.py:
  New function gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True)
  that:
  - Calls flashinfer_allreduce_residual_rmsnorm with a zero residual,
    discards the residual output, returns just the rmsnorm output.
  - Falls back to tensor_model_parallel_all_reduce(x) + norm_module.forward(_)
    when the predicate is False or flashinfer returns (None, None).
  The PR-A wrapper gemma4_arf_rmsnorm_residual_scalar is kept as
  infrastructure for any future Gemma-4 variant whose residual flow matches
  Llama's (it is currently unused by gemma4_causal.py).

* python/sglang/srt/models/gemma4_causal.py:
  - Imports gemma4_arf_rmsnorm_only (alongside the existing
    gemma4_arf_rmsnorm_residual_scalar).
  - Threads skip_all_reduce kwarg through Gemma4Attention.forward to the
    o_proj call (default False preserves current behavior).
  - At the post-attention site, when self._arf_enabled (set in __init__
    based on get_global_server_args().enable_flashinfer_allreduce_fusion
    and gated on not enable_moe_block and not has_ple):
      * self_attn is called with skip_all_reduce=True
      * gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm)
        replaces self.post_attention_layernorm(hidden_states)

Validation (google/gemma-4-31B-it, H100 TP=2, triton, FROZEN_KV_MTP,
80 prompts, warmup 2, seed 1):

  Per-prompt parity (20 greedy prompts, temp=0):
    match_rate = 19/20 = 0.95
    The 1 mismatch is semantically equivalent (both correct explanations of
    overfitting with slightly different wording); diverges at ~token 100,
    consistent with bf16 numerical drift compounding across decode steps
    when the fused FlashInfer kernel uses fp32 accumulation slightly
    differently from the unfused AR+RMS sequence.

  MMLU N=500 (seed 0, temp 0):
    ARF off: 0.780 (390/500)  [exact baseline]
    ARF on : 0.778 (389/500)  delta = -0.2 pp  [within +/- 1 pp]

  Benchmark:
    Metric         | ARF off | ARF on   | Delta
    ---------------|--------:|---------:|------
    chat tok/s     |  1442   | **1479** | **+2.6%**
    chat med TTFT  |  2826   |  2811    | -0.5%
    chat med TPOT  |  29.7   | **28.7** | **-3.4%**
    summ tok/s     |   303   |   308    | +1.7%
    summ med TTFT  | 77838   | 76242    | -2.1%
    summ med TPOT  |  29.8   |  30.3    | +1.7% (noise)
    accept length  |  3.12   |  3.15    | +1.0%

  The wins are on the lower end of vLLM's advertised 5-20% E2E range
  for fuse_allreduce_rms.  Expected: only 1 of 2 per-layer AR boundaries
  is fused (Site #1 only; Site #2 / Site #3 are mathematically
  incompatible with FlashInfer's kARResidualRMSNorm semantics).

Stack base: pyc/gemma4-arf-ops @ be87667

Co-authored-by: Claude
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant