Skip to content

feat(gemma4 ARF): infrastructure - wrapper + Gemma3MLP skip_all_reduce + auto-enable (PR-A/2)#19

Draft
pyc96 wants to merge 1 commit into
pyc/sota-gemma4-31b-mm-disabledfrom
pyc/gemma4-arf-ops
Draft

feat(gemma4 ARF): infrastructure - wrapper + Gemma3MLP skip_all_reduce + auto-enable (PR-A/2)#19
pyc96 wants to merge 1 commit into
pyc/sota-gemma4-31b-mm-disabledfrom
pyc/gemma4-arf-ops

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 25, 2026

Summary

PR-A of a 2-PR stack that adds FlashInfer AllReduce + RMSNorm fusion infrastructure for Gemma-4. Does not change any runtime behavior on its own — the wrapper is added but not wired. PR-B (#TBD) wires it into the post-attention site in `Gemma4DecoderLayer.forward`.

Background

vLLM's `fuse_allreduce_rms` Inductor pass is technically enabled for Gemma-4 at compile mode O2 (per vllm/config/vllm.py:122-142) but never matches Gemma-4's residual flow because Gemma uses `RMSNorm(x) + residual` rather than the Llama-style two-arg `RMSNorm(x, residual)` the pattern looks for. SGLang already has the `flashinfer_allreduce_residual_rmsnorm` direct-call API (used by Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc.). Adding it to Gemma-4 gives SGLang a fusion vLLM doesn't actually deliver on Gemma-4 today.

What's in PR-A

File Change
`python/sglang/srt/layers/gemma4_fused_ops.py` New function `gemma4_arf_rmsnorm_residual_scalar(x, weight, residual, scalar, eps, use_attn_tp_group=True)` that wraps `flashinfer_allreduce_residual_rmsnorm` and applies a `* scalar` tail. Falls back to `tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar` on any predicate-False / non-CUDA / non-2D / FlashInfer-unavailable signal.
`python/sglang/srt/models/gemma3_causal.py` `Gemma3MLP.forward` (= `Gemma4MLP.forward` via alias) now accepts `skip_all_reduce: bool = False`, threaded through to `down_proj`. Default `False` preserves current behavior for every other caller.
`python/sglang/srt/server_args.py` Adds `Gemma4ForCausalLM` + `Gemma4ForConditionalGeneration` to the `flashinfer_allreduce_fusion` auto-enable allow-list. Same preconditions as the existing 13 archs (SM90/100, TP>1, single-node, not H20, no DP-attn, no MoE-A2A).
`test/registered/unit/layers/test_gemma4_arf_ops.py` 4 unit tests with FlashInfer + all-reduce fully mocked (runs on CPU): success path, fallback path, predicate-off path, non-cuda fallback. All pass.

What this PR does NOT change

  • `gemma4_causal.py` is unchanged — the new wrapper is not called yet.
  • No runtime behavior change for any deployment that doesn't pass `--enable-flashinfer-allreduce-fusion` AND wait for PR-B to wire the call site.
  • The auto-enable list change does mean a Gemma-4 server on TP=2 H100 will log `"Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for Gemma4ForCausalLM"` at startup, but without PR-B's wiring nothing actually fires.

Tests

```
$ python test/registered/unit/layers/test_gemma4_arf_ops.py
....
Ran 4 tests in 1.053s
OK
```

Test What it verifies
`test_success_path_uses_flashinfer_and_applies_scalar` wrapper returns `norm_out * scalar` from mocked FlashInfer; AR helper / fallback kernel NOT invoked
`test_fallback_when_flashinfer_returns_none` wrapper falls back to `AR + kernel` sequence when FlashInfer returns `(None, None)`
`test_predicate_off_uses_fallback_directly` wrapper takes fallback without calling FlashInfer when predicate is False
`test_non_cuda_input_takes_fallback` CPU tensors short-circuit through fallback (is_cuda gate)

Diff noise note

`server_args.py` shows 327 lines changed but only 9 are mine — the rest is auto-format reflow of unrelated `assert` statements. Real change: 4 new lines in the auto-enable list (`"Gemma4ForCausalLM"`, `"Gemma4ForConditionalGeneration"`) + 5-line comment.

Stack

Stack base: `pyc/sota-gemma4-31b-mm-disabled` @ `3a3195b30`

Plan: `.humanize/yoco-gemma4/refined-plan.md` (the ARF stack inherits the same review-plan structure)


CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

…le (PR-A/2)

PR-A of a 2-PR stack that wires SGLang's existing
flashinfer_allreduce_residual_rmsnorm fusion into Gemma-4's dense post-FF
combine path. This PR adds the building blocks; PR-B wires them into
Gemma4DecoderLayer.forward.

Background: vLLM's fuse_allreduce_rms Inductor pass is technically enabled
for Gemma-4 at compile mode O2 but never matches Gemma-4's residual flow
(Gemma uses RMSNorm(x) + residual rather than the two-arg RMSNorm(x,
residual) form Llama uses). SGLang already exposes
flashinfer_allreduce_residual_rmsnorm as a direct-call Python op used by
Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc. By calling it explicitly from the
Gemma-4 model code at the post-FF combine site, we get the fusion vLLM
nominally has but never actually delivers on Gemma-4.

Changes:

* python/sglang/srt/layers/gemma4_fused_ops.py:
  New function gemma4_arf_rmsnorm_residual_scalar(x, weight, residual,
  scalar, eps, use_attn_tp_group=True) that:
  - Checks apply_flashinfer_allreduce_fusion(num_tokens) and calls
    flashinfer_allreduce_residual_rmsnorm to fuse AR + residual_add +
    RMSNorm into one TRT-LLM communication kernel.
  - On success, applies the Gemma-4 layer_scalar tail as a one-launch
    broadcast mul.
  - On any fallback signal (predicate false, non-cuda input, flashinfer
    returns (None, None) for batch>2048 / workspace-init-failed /
    non-contiguous / FlashInfer unavailable), falls back to the explicit
    tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar
    sequence with bit-identical semantics to the pre-fusion path.

* python/sglang/srt/models/gemma3_causal.py:
  Threads skip_all_reduce kwarg through Gemma3MLP.forward (= Gemma4MLP
  via alias) so the caller can opt the down_proj into AR-skip mode.
  Default False preserves current behavior for every other caller.

* python/sglang/srt/server_args.py:
  Adds Gemma4ForCausalLM + Gemma4ForConditionalGeneration to the
  flashinfer_allreduce_fusion auto-enable allow-list, gated on the same
  preconditions as the existing 13 archs (SM90/100, TP>1, single-node,
  not H20, no DP-attn, no MoE-A2A).
  Server log on TP=2 H100 with default args now shows
    'Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for
     Gemma4ForCausalLM'

* test/registered/unit/layers/test_gemma4_arf_ops.py:
  4 unit tests with FlashInfer + all-reduce fully mocked (runs on CPU):
  - test_success_path_uses_flashinfer_and_applies_scalar: asserts
    out == norm_out * scalar and that AR helper / fallback kernel are
    NOT invoked.
  - test_fallback_when_flashinfer_returns_none: asserts AR + fallback
    kernel are invoked when flashinfer returns (None, None).
  - test_predicate_off_uses_fallback_directly: asserts flashinfer is not
    called when apply_flashinfer_allreduce_fusion returns False.
  - test_non_cuda_input_takes_fallback: asserts the is_cuda gate short-
    circuits to fallback for CPU tensors.

All 4 tests pass:
  Ran 4 tests in 1.053s
  OK

No runtime behavior change without PR-B (the model code still calls the
plain gemma_rmsnorm_residual_scalar; the new wrapper is unused).

The diff in server_args.py is ~325 lines but only 9 are mine -- the rest
is auto-format reflow of assert statements.

Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b
Co-authored-by: Claude
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant