[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion#40710
[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion#40710tpopp wants to merge 4 commits intovllm-project:mainfrom
Conversation
|
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Code Review
This pull request introduces fusion support for RMSNormGated followed by FP8 group quantization on ROCm platforms using the aiter library. Key changes include the registration of a new fused custom operator, the implementation of a MatcherRMSNormGated class, and updates to the RocmAiterRMSNormQuantFusionPass to discover and fuse these patterns. Feedback focuses on critical safety issues regarding the global monkey-patching of the pattern matcher's type handling, which could lead to incorrect matches for other operators. Additionally, improvements were suggested to ensure the gated fusion pattern correctly supports both aiter and decomposed quantization variants and strictly validates the supported group size of 128 to prevent numerical errors.
| _orig_fx_to_pat = pm.fx_to_pattern | ||
|
|
||
| def _relaxed_fx_to_pattern(*a, **kw): | ||
| kw["ignore_types"] = (int, torch.SymInt) | ||
| return _orig_fx_to_pat(*a, **kw) | ||
|
|
||
| pm.fx_to_pattern = _relaxed_fx_to_pattern | ||
| try: | ||
| self.matched_count = self.patterns.apply(graph) | ||
| finally: | ||
| pm.fx_to_pattern = _orig_fx_to_pat |
There was a problem hiding this comment.
Monkey-patching pm.fx_to_pattern to ignore all int and torch.SymInt types is extremely dangerous. This change affects all patterns registered in self.patterns, including those that rely on specific integer arguments for correctness (e.g., group_size=128 in AiterRMSFp8GroupQuantPattern). If a graph contains a quantization op with a different group size (e.g., 64), the matcher will incorrectly identify it as a match, leading to a replacement with a fused op that uses the wrong group size. This will cause silent numerical errors. A more targeted approach to handle SymInt in reshapes should be used instead of a global type ignore.
There was a problem hiding this comment.
I haven't found a better approach due to the shortcomings of the pytorch pattern matching based approach. This is becoming a common problem, especially when multiple reshapes exist.
5c39363 to
2c82404
Compare
2c82404 to
d4f1b17
Compare
31da8cb to
7b6683e
Compare
|
Some cleanup has been done and needs higher level feedback and a ready label to allow more complete testing. |
|
@gshtras Can you add the |
5753895 to
3307453
Compare
|
Rebased to retrigger CI. Failures were existing failures. |
|
@tjtanaa seems to be the relevant CODEOWNER for this PR. |
…y check Register fused_rms_gated_fp8_group_quant custom op that wraps the aiter Triton kernel for fused gated RMSNorm + FP8 group quantization. Also add are_gdn_triton_kernels_available() to check whether the required aiter Triton kernels (conv1d single-token, gated delta net) are importable, allowing graceful fallback on older aiter versions. Made-with: Cursor Signed-off-by: Tres Popp <tres.popp@amd.com>
Implement pattern matching and replacement for decomposed RMSNormGated followed by group FP8 quantization, fusing them into a single aiter Triton kernel (fused_rms_gated_fp8_group_quant). Key changes: - Add AiterRMSNormGatedFp8GroupQuantPattern in rocm_aiter_fusion.py that matches the decomposed norm+reshape+quant graph and replaces it with the fused op - Extend MatcherQuantFP8 and MatcherRMSNormGated in matcher_utils.py to support the gated norm pattern tracing - Add forward_static to RMSNormGated for code sharing with the matcher and have forward_native delegate to it - Simplify input_quant_fp8.py by extracting shared logic into forward_static - Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context - Register per-token dynamic quant patterns for both aiter and non-aiter quant ops to handle +/- quant_fp8 configurations - Gate the gated pattern on are_gdn_triton_kernels_available() - Add unit tests for the fusion pattern (positive and negative cases) Made-with: Cursor Signed-off-by: Tres Popp <tres.popp@amd.com>
- Remove unused MatcherFusedAddRMSNorm and its dead imports (RMSNorm, RMS_ADD_OP) - Move fold_consecutive_reshapes to vllm_inductor_pass.py next to the related _fx_view_to_reshape helper - Add docstrings to new _aiter_ops methods (fused_rms_gated_fp8_group_quant impl and getter) - Check fused_rms_gated_fp8_group_quant importability in are_gdn_triton_kernels_available - Restore docstring on RMSNormGated.forward_native Signed-off-by: Tres Popp <trespopp@gmail.com> Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Tres Popp <tres.popp@amd.com>
Iterate over use_triton for group quant patterns so both the CK and triton backends are matched. Use a set to deduplicate when quant_fp8 is disabled (forward_native is identical for both use_triton values). Add a head_dim == 128 guard to AiterRMSNormGatedFp8GroupQuantPattern since the fused kernel hardcodes group_size=head_dim. Rename _fx_view_to_reshape to fx_view_to_reshape as it is not private. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Tres Popp <tres.popp@amd.com>
3307453 to
f7fa464
Compare
This PR adds a compilation fusion pass (AiterRMSNormGatedFp8GroupQuantPattern) that fuses the decomposed RMSNormGated + reshape + group FP8 quantization sequence into a single AITER Triton kernel call (fused_rms_gated_fp8_group_quant). This pattern appears in GatedDeltaNetAttention layers (e.g., Qwen3-Next) where each attention head's output goes through gated RMS normalization, is reshaped back to the full hidden dimension, and then group-quantized to FP8 before the output projection linear layer.
Results:
a 9us set of 2 kernels can be combined to 4.5us. In the case of Qwen3Next, this can be a 1-3% improvement depending on how small the workload is (concurrency 1 vs 128).
Motivation
In models using GatedDeltaNetAttention (such as Qwen3-Next-80B-A3B-Instruct-FP8), the output path of each attention block performs:
These three operations decompose into many elementwise and reduction kernels when torch.compile lowers them. By matching this pattern in the FX graph and replacing it with a single fused Triton kernel from AITER, we eliminate multiple GPU kernel launches and intermediate memory traffic.
Changes
• Register rocm_aiter_fused_rms_gated_fp8_group_quant custom op wrapping aiter.ops.triton.quant.fused_rms_gated_fp8_group_quant
• Add rocm_aiter_ops.are_gdn_triton_kernels_available() — checks whether the required AITER Triton kernels (causal_conv1d_update_single_token, gated_delta_net) are importable, allowing graceful fallback on older AITER builds that lack the GDN kernels
• rocm_aiter_fusion.py: Add AiterRMSNormGatedFp8GroupQuantPattern that matches the decomposed norm→reshape→quant graph and replaces it with the fused op. Add _fold_consecutive_reshapes pre-processing pass (needed because make_fx faithfully
records chained reshapes that must be folded for the pattern to match). Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context. Gate the pattern on are_gdn_triton_kernels_available()
• matcher_utils.py: Add MatcherRMSNormGated pattern tracer that traces RMSNormGated.forward_static for use in pm.register_replacement. Extend MatcherQuantFP8 to support Triton-based quant op matching
• layernorm.py: Extract RMSNormGated.forward_static as a @staticmethod so both forward_native and the matcher can share the same pure-PyTorch implementation. forward_native delegates to it
• test_fusion.py: Add unit tests (TestGatedModel) for the fusion pattern covering positive match cases (aiter quant, non-aiter quant, per-token dynamic) and negative cases (wrong group shape, per-tensor quant)
AITER Dependency
The fused Triton kernel (fused_rms_gated_fp8_group_quant) is provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fusion pass is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR.
Benchmark Results
Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + PR #2423
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark command: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos
Pattern matching verification:
• With fusion: RocmAiterRMSNormQuantFusionPass replaced 5 patterns (1+2+2 across repeated-layer subgraphs — the 4 additional matches are from AiterRMSNormGatedFp8GroupQuantPattern)
• Without fusion (pattern commented out): replaced 1 pattern (only the existing non-gated AiterRMSNormDynamicQuantPattern)
Throughput (ISL=1024, OSL=1024, concurrency=4):
┌─────────────────────────────────┬─────────────┬──────────┬───────┐
│ Metric │ With Fusion │ Baseline │ Delta │
├─────────────────────────────────┼─────────────┼──────────┼───────┤
│ Output token throughput (tok/s) │ 467.05 │ 456.52 │ +2.3% │
│ Total token throughput (tok/s) │ 934.11 │ 913.04 │ +2.3% │
│ Mean TPOT (ms) │ 8.44 │ 8.66 │ −2.5% │
│ P99 TPOT (ms) │ 8.67 │ 8.98 │ −3.5% │
│ Mean E2EL (ms) │ 8,769 │ 8,971 │ −2.3% │
└─────────────────────────────────┴─────────────┴──────────┴───────┘
Accuracy (lm_eval, gsm8k, 5-shot):
┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ With Fusion │ Baseline │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8605 ±0.0095 │ 0.8506 ±0.0098 │ +0.0099 (within error bars) │
│ strict-match │ 0.8089 ±0.0108 │ 0.8097 ±0.0108 │ −0.0008 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘
Accuracy is statistically identical — the fusion is numerically safe.
Test plan
• [x] Unit tests: pytest tests/compile/passes/test_fusion.py -k "gated" — positive and negative pattern match cases
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline
• [x] vllm bench serve — throughput improved ~2.3%, TPOT improved ~2.5%
• [x] Verified graceful no-op when AITER lacks GDN kernels (are_gdn_triton_kernels_available() == False)