
[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion #40710

Open
tpopp wants to merge 4 commits into vllm-project:main from tpopp:tpopp/gdn-rmsnorm-quant-fusion

Conversation


@tpopp tpopp commented Apr 23, 2026

This PR adds a compilation fusion pass (AiterRMSNormGatedFp8GroupQuantPattern) that fuses the decomposed RMSNormGated + reshape + group FP8 quantization sequence into a single AITER Triton kernel call (fused_rms_gated_fp8_group_quant). This pattern appears in GatedDeltaNetAttention layers (e.g., Qwen3-Next) where each attention head's output goes through gated RMS normalization, is reshaped back to the full hidden dimension, and then group-quantized to FP8 before the output projection linear layer.

Results:
A ~9 µs pair of kernels is fused into a single ~4.5 µs kernel. For Qwen3-Next this translates to a 1-3% end-to-end improvement, depending on how small the workload is (concurrency 1 vs. 128).

Motivation

In models using GatedDeltaNetAttention (such as Qwen3-Next-80B-A3B-Instruct-FP8), the output path of each attention block performs:

  1. RMSNormGated on per-head tensors (N*H, D) with a gating tensor
  2. Reshape to (N, H*D)
  3. Group FP8 quantization with GroupShape(1, 128)

These three operations decompose into many elementwise and reduction kernels when torch.compile lowers them. By matching this pattern in the FX graph and replacing it with a single fused Triton kernel from AITER, we eliminate multiple GPU kernel launches and intermediate memory traffic.
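The three decomposed steps can be sketched in plain numpy. This is a minimal illustration of the math the pass matches, not the AITER kernel: the SiLU gating, the gating order (gate applied before the norm, as in Hugging Face's Mamba2 RMSNormGated), and the FP8 e4m3 max of 448 are all assumptions.

```python
import numpy as np

def rms_norm_gated(x, gate, weight, eps=1e-6):
    # Gated RMS norm sketch. Applying SiLU(gate) before normalization
    # is an assumption about RMSNormGated.forward_native.
    x = x * (gate / (1.0 + np.exp(-gate)))              # SiLU gating
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def group_fp8_quant(x, group_size=128, fp8_max=448.0):
    # GroupShape(1, 128): one dynamic scale per contiguous group of
    # 128 elements along the last dim; 448 is the FP8 e4m3 max.
    n, h = x.shape
    g = x.reshape(n, h // group_size, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / fp8_max
    scale = np.maximum(scale, 1e-12)                    # avoid division by zero
    q = np.clip(g / scale, -fp8_max, fp8_max)           # would be cast to fp8
    return q.reshape(n, h), scale.reshape(n, h // group_size)

# The unfused sequence the pass matches (N tokens, H heads, D head_dim):
N, H, D = 4, 8, 128
rng = np.random.default_rng(0)
x = rng.standard_normal((N * H, D)).astype(np.float32)
gate = rng.standard_normal((N * H, D)).astype(np.float32)
w = np.ones(D, dtype=np.float32)

y = rms_norm_gated(x, gate, w)        # 1. gated RMS norm on (N*H, D)
y = y.reshape(N, H * D)               # 2. reshape to (N, H*D)
q, scales = group_fp8_quant(y, 128)   # 3. group FP8 quant, GroupShape(1, 128)
```

Each of these steps lowers to at least one kernel plus intermediate tensors; the fused AITER kernel performs all three in one pass.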

Changes
• Register rocm_aiter_fused_rms_gated_fp8_group_quant custom op wrapping aiter.ops.triton.quant.fused_rms_gated_fp8_group_quant
• Add rocm_aiter_ops.are_gdn_triton_kernels_available() — checks whether the required AITER Triton kernels (causal_conv1d_update_single_token, gated_delta_net) are importable, allowing graceful fallback on older AITER builds that lack the GDN kernels
• rocm_aiter_fusion.py: Add AiterRMSNormGatedFp8GroupQuantPattern, which matches the decomposed norm→reshape→quant graph and replaces it with the fused op. Add a _fold_consecutive_reshapes pre-processing pass (needed because make_fx faithfully records chained reshapes, which must be folded for the pattern to match). Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context. Gate the pattern on are_gdn_triton_kernels_available()
• matcher_utils.py: Add MatcherRMSNormGated pattern tracer that traces RMSNormGated.forward_static for use in pm.register_replacement. Extend MatcherQuantFP8 to support Triton-based quant op matching
• layernorm.py: Extract RMSNormGated.forward_static as a @staticmethod so both forward_native and the matcher can share the same pure-PyTorch implementation. forward_native delegates to it
• test_fusion.py: Add unit tests (TestGatedModel) for the fusion pattern covering positive match cases (aiter quant, non-aiter quant, per-token dynamic) and negative cases (wrong group shape, per-tensor quant)
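The _fold_consecutive_reshapes idea can be illustrated with a toy sketch over a simplified (op, arg) trace rather than a real FX graph: since only the last target shape of a reshape chain is observable, back-to-back reshapes collapse into one.

```python
def fold_consecutive_reshapes(ops):
    # Toy model of the pre-processing pass: collapse runs of
    # consecutive "reshape" ops, keeping only the final target shape,
    # so the traced pattern can match the folded graph.
    folded = []
    for op, arg in ops:
        if op == "reshape" and folded and folded[-1][0] == "reshape":
            folded[-1] = ("reshape", arg)   # keep only the final shape
        else:
            folded.append((op, arg))
    return folded

# A trace like the one make_fx records for the gated-norm path:
trace = [
    ("rms_norm_gated", None),
    ("reshape", (4, 8, 128)),   # intermediate reshape recorded by make_fx
    ("reshape", (4, 1024)),     # final shape the quant pattern expects
    ("quant_fp8", None),
]
folded = fold_consecutive_reshapes(trace)
# folded == [("rms_norm_gated", None), ("reshape", (4, 1024)), ("quant_fp8", None)]
```

The real pass operates on FX graph nodes and rewires users; the list-based version above only shows why folding is safe.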

AITER Dependency

The fused Triton kernel (fused_rms_gated_fp8_group_quant) is provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423, "[Triton] optimized decode kernels for Qwen3-Next model"). The fusion pass is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include that PR.

Benchmark Results

Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + PR #2423
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark command: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos

Pattern matching verification:
• With fusion: RocmAiterRMSNormQuantFusionPass replaced 5 patterns (1+2+2 across repeated-layer subgraphs — the 4 additional matches are from AiterRMSNormGatedFp8GroupQuantPattern)
• Without fusion (pattern commented out): replaced 1 pattern (only the existing non-gated AiterRMSNormDynamicQuantPattern)

Throughput (ISL=1024, OSL=1024, concurrency=4):

┌─────────────────────────────────┬─────────────┬──────────┬───────┐
│ Metric                          │ With Fusion │ Baseline │ Delta │
├─────────────────────────────────┼─────────────┼──────────┼───────┤
│ Output token throughput (tok/s) │      467.05 │   456.52 │ +2.3% │
│ Total token throughput (tok/s)  │      934.11 │   913.04 │ +2.3% │
│ Mean TPOT (ms)                  │        8.44 │     8.66 │ −2.5% │
│ P99 TPOT (ms)                   │        8.67 │     8.98 │ −3.5% │
│ Mean E2EL (ms)                  │       8,769 │    8,971 │ −2.3% │
└─────────────────────────────────┴─────────────┴──────────┴───────┘

Accuracy (lm_eval, gsm8k, 5-shot):

┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter           │ With Fusion    │ Baseline       │ Delta                       │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8605 ±0.0095 │ 0.8506 ±0.0098 │ +0.0099 (within error bars) │
│ strict-match     │ 0.8089 ±0.0108 │ 0.8097 ±0.0108 │ −0.0008 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘

Accuracy is statistically indistinguishable from baseline, indicating the fusion is numerically safe.

Test plan

• [x] Unit tests: pytest tests/compile/passes/test_fusion.py -k "gated" — positive and negative pattern match cases
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline
• [x] vllm bench serve — throughput improved ~2.3%, TPOT improved ~2.5%
• [x] Verified graceful no-op when AITER lacks GDN kernels (are_gdn_triton_kernels_available() == False)


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the rocm Related to AMD ROCm label Apr 23, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 23, 2026

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 23, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces fusion support for RMSNormGated followed by FP8 group quantization on ROCm platforms using the aiter library. Key changes include the registration of a new fused custom operator, the implementation of a MatcherRMSNormGated class, and updates to the RocmAiterRMSNormQuantFusionPass to discover and fuse these patterns. Feedback focuses on critical safety issues regarding the global monkey-patching of the pattern matcher's type handling, which could lead to incorrect matches for other operators. Additionally, improvements were suggested to ensure the gated fusion pattern correctly supports both aiter and decomposed quantization variants and strictly validates the supported group size of 128 to prevent numerical errors.

Comment on lines +513 to +523
_orig_fx_to_pat = pm.fx_to_pattern

def _relaxed_fx_to_pattern(*a, **kw):
    kw["ignore_types"] = (int, torch.SymInt)
    return _orig_fx_to_pat(*a, **kw)

pm.fx_to_pattern = _relaxed_fx_to_pattern
try:
    self.matched_count = self.patterns.apply(graph)
finally:
    pm.fx_to_pattern = _orig_fx_to_pat
critical

Monkey-patching pm.fx_to_pattern to ignore all int and torch.SymInt types is extremely dangerous. This change affects all patterns registered in self.patterns, including those that rely on specific integer arguments for correctness (e.g., group_size=128 in AiterRMSFp8GroupQuantPattern). If a graph contains a quantization op with a different group size (e.g., 64), the matcher will incorrectly identify it as a match, leading to a replacement with a fused op that uses the wrong group size. This will cause silent numerical errors. A more targeted approach to handle SymInt in reshapes should be used instead of a global type ignore.

@tpopp (Contributor, Author) replied:
I haven't found a better approach, given the shortcomings of the PyTorch pattern-matching infrastructure this builds on. This is becoming a common problem, especially when multiple reshapes exist.

Comment thread vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Comment thread vllm/compilation/passes/fusion/rocm_aiter_fusion.py Outdated
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 5c39363 to 2c82404 Compare April 23, 2026 17:09
@mergify mergify Bot removed the needs-rebase label Apr 23, 2026
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 2c82404 to d4f1b17 Compare May 4, 2026 06:53
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch 3 times, most recently from 31da8cb to 7b6683e Compare May 4, 2026 14:08

tpopp commented May 5, 2026

Some cleanup has been done; this now needs higher-level feedback and a ready label to allow more complete testing.

@dllehr-amd

@gshtras Can you add the ready label for me?

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label May 5, 2026
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 5753895 to 3307453 Compare May 6, 2026 08:09

tpopp commented May 6, 2026

Rebased to retrigger CI. The failures were pre-existing.


tpopp commented May 6, 2026

@tjtanaa seems to be the relevant CODEOWNER for this PR.

tpopp and others added 4 commits May 6, 2026 04:06
…y check

Register fused_rms_gated_fp8_group_quant custom op that wraps the
aiter Triton kernel for fused gated RMSNorm + FP8 group quantization.
Also add are_gdn_triton_kernels_available() to check whether the
required aiter Triton kernels (conv1d single-token, gated delta net)
are importable, allowing graceful fallback on older aiter versions.

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
Implement pattern matching and replacement for decomposed RMSNormGated
followed by group FP8 quantization, fusing them into a single aiter
Triton kernel (fused_rms_gated_fp8_group_quant).

Key changes:
- Add AiterRMSNormGatedFp8GroupQuantPattern in rocm_aiter_fusion.py
  that matches the decomposed norm+reshape+quant graph and replaces it
  with the fused op
- Extend MatcherQuantFP8 and MatcherRMSNormGated in matcher_utils.py
  to support the gated norm pattern tracing
- Add forward_static to RMSNormGated for code sharing with the matcher
  and have forward_native delegate to it
- Simplify input_quant_fp8.py by extracting shared logic into
  forward_static
- Dynamically infer num_heads/head_dim from GatedDeltaNetAttention
  layers via static_forward_context
- Register per-token dynamic quant patterns for both aiter and
  non-aiter quant ops to handle +/- quant_fp8 configurations
- Gate the gated pattern on are_gdn_triton_kernels_available()
- Add unit tests for the fusion pattern (positive and negative cases)

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
- Remove unused MatcherFusedAddRMSNorm and its dead imports (RMSNorm,
  RMS_ADD_OP)
- Move fold_consecutive_reshapes to vllm_inductor_pass.py next to the
  related _fx_view_to_reshape helper
- Add docstrings to new _aiter_ops methods
  (fused_rms_gated_fp8_group_quant impl and getter)
- Check fused_rms_gated_fp8_group_quant importability in
  are_gdn_triton_kernels_available
- Restore docstring on RMSNormGated.forward_native

Signed-off-by: Tres Popp <trespopp@gmail.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Iterate over use_triton for group quant patterns so both the CK and
triton backends are matched.  Use a set to deduplicate when quant_fp8
is disabled (forward_native is identical for both use_triton values).
Add a head_dim == 128 guard to AiterRMSNormGatedFp8GroupQuantPattern
since the fused kernel hardcodes group_size=head_dim.

Rename _fx_view_to_reshape to fx_view_to_reshape as it is not private.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
@tpopp tpopp force-pushed the tpopp/gdn-rmsnorm-quant-fusion branch from 3307453 to f7fa464 Compare May 6, 2026 09:06
Labels

• ready: ONLY add when PR is ready to merge/full CI is needed
• rocm: Related to AMD ROCm

Projects

Status: Todo
