feat(gemma4 ARF): infrastructure - wrapper + Gemma3MLP skip_all_reduce + auto-enable (PR-A/2) #19
PR-A of a 2-PR stack that wires SGLang's existing
flashinfer_allreduce_residual_rmsnorm fusion into Gemma-4's dense post-FF
combine path. This PR adds the building blocks; PR-B wires them into
Gemma4DecoderLayer.forward.
Background: vLLM's fuse_allreduce_rms Inductor pass is technically enabled
for Gemma-4 at compile mode O2 but never matches Gemma-4's residual flow
(Gemma uses RMSNorm(x) + residual rather than the two-arg RMSNorm(x,
residual) form Llama uses). SGLang already exposes
flashinfer_allreduce_residual_rmsnorm as a direct-call Python op used by
Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc. By calling it explicitly from the
Gemma-4 model code at the post-FF combine site, we get the fusion vLLM
nominally has but never actually delivers on Gemma-4.
Changes:
* python/sglang/srt/layers/gemma4_fused_ops.py:
  New function gemma4_arf_rmsnorm_residual_scalar(x, weight, residual,
  scalar, eps, use_attn_tp_group=True) that:
  - Checks apply_flashinfer_allreduce_fusion(num_tokens) and calls
    flashinfer_allreduce_residual_rmsnorm to fuse AR + residual_add +
    RMSNorm into one TRT-LLM communication kernel.
  - On success, applies the Gemma-4 layer_scalar tail as a one-launch
    broadcast mul.
  - On any fallback signal (predicate false, non-cuda input, flashinfer
    returns (None, None) for batch>2048 / workspace-init-failed /
    non-contiguous / FlashInfer unavailable), falls back to the explicit
    tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar
    sequence with bit-identical semantics to the pre-fusion path.
* python/sglang/srt/models/gemma3_causal.py:
  Threads skip_all_reduce kwarg through Gemma3MLP.forward (= Gemma4MLP
  via alias) so the caller can opt the down_proj into AR-skip mode.
  Default False preserves current behavior for every other caller.
* python/sglang/srt/server_args.py:
  Adds Gemma4ForCausalLM + Gemma4ForConditionalGeneration to the
  flashinfer_allreduce_fusion auto-enable allow-list, gated on the same
  preconditions as the existing 13 archs (SM90/100, TP>1, single-node,
  not H20, no DP-attn, no MoE-A2A).
  Server log on TP=2 H100 with default args now shows
  'Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for
  Gemma4ForCausalLM'
* test/registered/unit/layers/test_gemma4_arf_ops.py:
  4 unit tests with FlashInfer + all-reduce fully mocked (runs on CPU):
  - test_success_path_uses_flashinfer_and_applies_scalar: asserts
    out == norm_out * scalar and that AR helper / fallback kernel are
    NOT invoked.
  - test_fallback_when_flashinfer_returns_none: asserts AR + fallback
    kernel are invoked when flashinfer returns (None, None).
  - test_predicate_off_uses_fallback_directly: asserts flashinfer is not
    called when apply_flashinfer_allreduce_fusion returns False.
  - test_non_cuda_input_takes_fallback: asserts the is_cuda gate
    short-circuits to fallback for CPU tensors.
  All 4 tests pass:
    Ran 4 tests in 1.053s
    OK
No runtime behavior change without PR-B (the model code still calls the
plain gemma_rmsnorm_residual_scalar; the new wrapper is unused).
The diff in server_args.py is ~325 lines but only 9 are mine -- the rest
is auto-format reflow of assert statements.
Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b
Co-authored-by: Claude
Summary
PR-A of a 2-PR stack that adds FlashInfer AllReduce + RMSNorm fusion infrastructure for Gemma-4. It does not change any runtime behavior on its own: the wrapper is added but not wired. PR-B (#TBD) wires it into the post-FF combine site in `Gemma4DecoderLayer.forward`.
Background
vLLM's `fuse_allreduce_rms` Inductor pass is technically enabled for Gemma-4 at compile mode O2 (per vllm/config/vllm.py:122-142) but never matches Gemma-4's residual flow because Gemma uses `RMSNorm(x) + residual` rather than the Llama-style two-arg `RMSNorm(x, residual)` the pattern looks for. SGLang already has the `flashinfer_allreduce_residual_rmsnorm` direct-call API (used by Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc.). Adding it to Gemma-4 gives SGLang a fusion vLLM doesn't actually deliver on Gemma-4 today.
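To make the mismatch concrete, here is a minimal pure-Python sketch of the two residual flows named above. The function names `llama_style` and `gemma_style` are illustrative only, and the weight is assumed to be a plain multiplicative gain; the real Gemma norm may parameterize it differently.

```python
import math


def rmsnorm(x, weight, eps=1e-6):
    """Plain RMSNorm over one token's hidden vector (a list of floats)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def llama_style(x, residual, weight):
    """Two-arg form RMSNorm(x, residual): add first, then normalize the sum.

    Returns (normed, new_residual).
    """
    summed = [a + b for a, b in zip(x, residual)]
    return rmsnorm(summed, weight), summed


def gemma_style(x, residual, weight):
    """RMSNorm(x) + residual: normalize x alone, then add the residual."""
    normed = rmsnorm(x, weight)
    return [a + b for a, b in zip(normed, residual)]
```

Because the add and the norm happen in opposite orders, a pattern written for one form will not match a graph built from the other.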
What's in PR-A
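The wrapper's dispatch logic, as described in the commit message, can be sketched roughly as follows. This is a control-flow skeleton only, not the code in `gemma4_fused_ops.py`: the function name, the list-based stand-in for tensors, and every injected callable are assumptions made so the sketch runs without a GPU, and the real return type may differ.

```python
from typing import Callable, List, Optional, Tuple

Vec = List[float]  # stand-in for a tensor: one token's hidden vector


def arf_rmsnorm_residual_scalar_sketch(
    x: Vec,
    residual: Vec,
    scalar: float,
    *,
    is_cuda: bool,
    fusion_enabled: Callable[[int], bool],
    fused_op: Callable[[Vec, Vec], Tuple[Optional[Vec], Optional[Vec]]],
    all_reduce: Callable[[Vec], Vec],
    fallback_kernel: Callable[[Vec, Vec, float], Vec],
    num_tokens: int = 1,
) -> Vec:
    """Hypothetical skeleton of the success/fallback split."""
    if is_cuda and fusion_enabled(num_tokens):
        norm_out, _residual_out = fused_op(x, residual)
        if norm_out is not None:
            # Fused path succeeded: apply the layer_scalar tail.
            return [v * scalar for v in norm_out]
    # Any fallback signal lands here: explicit all-reduce first,
    # then the unfused norm + residual + scalar kernel.
    return fallback_kernel(all_reduce(x), residual, scalar)
```

Routing every failure mode through one fallback branch keeps the unfused path identical to what the model ran before the fusion existed.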
What this PR does NOT change
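Per the commit message, `skip_all_reduce` defaults to `False`, so existing callers of the MLP are unaffected. A toy sketch of that opt-in threading is below; `ToyRowParallelDown` and `ToyMLP` are hypothetical stand-ins, not SGLang classes, and the `0.5` factor merely pretends to be one rank's partial result.

```python
class ToyRowParallelDown:
    """Toy stand-in for a row-parallel down_proj (hypothetical)."""

    def __init__(self, all_reduce):
        self.all_reduce = all_reduce

    def __call__(self, x, skip_all_reduce=False):
        # Pretend this is the partial sum computed on the local rank.
        partial = [v * 0.5 for v in x]
        # When skipping, the caller takes over the reduction.
        return partial if skip_all_reduce else self.all_reduce(partial)


class ToyMLP:
    """Threads the kwarg through forward; the default keeps old behavior."""

    def __init__(self, down_proj):
        self.down_proj = down_proj

    def forward(self, x, skip_all_reduce=False):
        return self.down_proj(x, skip_all_reduce=skip_all_reduce)
```

A caller that passes `skip_all_reduce=True` receives the unreduced partial output and must perform the reduction itself, for example through a fused all-reduce kernel.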
Tests
```
$ python test/registered/unit/layers/test_gemma4_arf_ops.py
....
Ran 4 tests in 1.053s
OK
```
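For readers unfamiliar with the mocking approach, the following is a small self-contained sketch of how a success path and a fallback path can be asserted on CPU with `unittest.mock`. It is not the contents of `test_gemma4_arf_ops.py`; the `dispatch` function and the test names are hypothetical miniatures using floats in place of tensors.

```python
import unittest
from unittest import mock


def dispatch(x, scalar, fused_op, all_reduce, fallback_kernel):
    """Hypothetical miniature of a fused-or-fallback wrapper."""
    norm_out, _ = fused_op(x)
    if norm_out is not None:
        return norm_out * scalar
    return fallback_kernel(all_reduce(x), scalar)


class DispatchSketchTest(unittest.TestCase):
    def test_success_path_skips_fallback(self):
        fused = mock.Mock(return_value=(3.0, 1.0))
        ar, fb = mock.Mock(), mock.Mock()
        self.assertEqual(dispatch(5.0, 2.0, fused, ar, fb), 6.0)
        # Neither the all-reduce helper nor the fallback kernel should run.
        ar.assert_not_called()
        fb.assert_not_called()

    def test_none_result_triggers_fallback(self):
        fused = mock.Mock(return_value=(None, None))
        ar = mock.Mock(return_value=10.0)
        fb = mock.Mock(return_value=42.0)
        self.assertEqual(dispatch(5.0, 2.0, fused, ar, fb), 42.0)
        # The fallback kernel must receive the all-reduced input.
        ar.assert_called_once_with(5.0)
        fb.assert_called_once_with(10.0, 2.0)
```

Saved to a file, the sketch can be run with `python -m unittest <file>`.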
Diff noise note
`server_args.py` shows 327 lines changed but only 9 are mine — the rest is auto-format reflow of unrelated `assert` statements. Real change: 4 new lines in the auto-enable list (`"Gemma4ForCausalLM"`, `"Gemma4ForConditionalGeneration"`) + 5-line comment.
Stack
Stack base: `pyc/sota-gemma4-31b-mm-disabled` @ `3a3195b30`
Plan: `.humanize/yoco-gemma4/refined-plan.md` (the ARF stack inherits the same review-plan structure)
CI States
Latest PR Test (Base): ❌ Missing `run-ci` label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- `run-ci` is required first.