[Graph][Fusion] Add MatmulAllReduceAddRMSNorm graph fusion for npugraph_ex. #6006

Merged
wangxiyuan merged 2 commits into vllm-project:main from ForBetterCodeNine:graph_ex_0119 on Jan 27, 2026
Conversation

@ForBetterCodeNine (Contributor) commented Jan 19, 2026

What this PR does / why we need it?

This PR builds on PR #5011 and further extends the npu_graph_ex_passes module. On top of that prior work, it adds graph optimization support for the add_rms_quant fused operator when a bias term is present, ensuring the fusion pattern is correctly registered and matched in the computation graph.

This PR additionally performs the MatmulAllReduceAddRMSNorm operator fusion and adds corresponding ST test cases for regression monitoring.

Does this PR introduce any user-facing change?

No

How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2c24bc6

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a new GraphEX fusion pass for Matmul + AllReduce + AddRMSNorm and an end-to-end test for the fused operator. While the initiative is good, the implementation contains several critical issues that must be addressed. The test file has a bug in tensor shape initialization and an unused parameter. The fusion pass file has multiple critical bugs, including a typo, incorrect argument passing, a missing base class inheritance, and a potential correctness issue in the replacement logic. Please review the detailed comments for fixes.

x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
else:
x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
x2 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)

critical

The shape of tensor x2 is incorrect. For torch.nn.functional.linear(a, b) where a is x1 with shape [m, k], the weight b (x2 here) should have shape [n, k]. The current code initializes x2 with shape [m, k]. This will cause a shape mismatch error during the npu_add_rms_norm operation, as the residual tensor has shape [m, n], while the output of linear will have an incorrect shape.

Suggested change
x2 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)
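For reference, the expected shape relationship can be checked with a tiny standalone snippet (CPU tensors, illustrative only; `m`, `k`, `n` are arbitrary here):

```python
import torch

m, k, n = 8, 64, 32
x1 = torch.normal(0, 0.1, [m, k])          # input:  [m, k]
x2 = torch.normal(0, 0.1, [n, k])          # weight: [n, k]; linear computes x1 @ x2.T
out = torch.nn.functional.linear(x1, x2)   # output: [m, n], matching a residual of shape [m, n]
assert out.shape == (m, n)
```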

self.local_rank,
self.eps,
True,
False,

critical

The is_allgather_add_out parameter is hardcoded to False. The pattern being replaced produces a full tensor for the second output (add_out), as it's derived from an all-reduced tensor. If is_allgather_add_out=False causes the fused operator to return a sharded add_out, this would be a correctness issue, as subsequent layers expect a full tensor. The corresponding test for this operator uses is_allgather_add_out=True. It should likely be True here as well to ensure the replacement is correct for middle layers.

Suggested change
False,
True,
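To illustrate why the second output must be a full tensor, here is a minimal sketch using CPU tensors and the gloo backend as a stand-in for the NPU/HCCL setup (names, shapes, and the port are illustrative, not taken from the PR). Because add_out is computed after the all-reduce, every rank ends up with the identical full tensor, which is what downstream layers consume:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def check_add_out(rank: int, world_size: int):
    # gloo on CPU stands in for HCCL on NPU; illustrative only
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    partial = torch.full((2, 4), float(rank + 1))  # each rank's partial matmul result
    residual = torch.ones(2, 4)
    dist.all_reduce(partial)                       # full sum, identical on every rank
    add_out = partial + residual                   # the pattern's second output
    expected = 1.0 + sum(range(1, world_size + 1))
    assert torch.equal(add_out, torch.full((2, 4), expected))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(check_add_out, args=(2,), nprocs=2)
```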

hidden_size = 4096
x = torch.randn(batch_size, seq_len, hidden_size, device="npu")
weight = torch.randn(hidden_size, hidden_size, device="npu")
residual = torch.rann(batch_size, seq_len, hidden_size, device="npu")

critical

There is a typo torch.rann. It should be torch.randn. This will cause an AttributeError when get_inputs is called.

Suggested change
residual = torch.rann(batch_size, seq_len, hidden_size, device="npu")
residual = torch.randn(batch_size, seq_len, hidden_size, device="npu")

torchair.register_replacement(
search_fn=pattern,
replace_fn=replacement,
example_inputs=self.get_inputs(),

critical

The example_inputs argument to torchair.register_replacement expects a callable, but it's being passed the result of self.get_inputs(). This will cause an error during pattern registration. It should be self.get_inputs to pass the method itself.

Suggested change
example_inputs=self.get_inputs(),
example_inputs=self.get_inputs,

Comment on lines +135 to +137
class GraphEXMatmulAllReduceAddRMSNormPass:
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)

critical

The class GraphEXMatmulAllReduceAddRMSNormPass calls super().__init__ but does not inherit from any base class, which will cause a TypeError. Based on its usage in GraphFusionPassManager, it should inherit from VllmInductorPass. You'll also need to add from vllm.compilation.vllm_inductor_pass import VllmInductorPass at the top of the file.

Suggested change
class GraphEXMatmulAllReduceAddRMSNormPass:
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
class GraphEXMatmulAllReduceAddRMSNormPass(VllmInductorPass):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)

return rmsnorm_ret, add_ret


def worker(rank, ep_world_size, batch_size, m, k, n):

high

The batch_size parameter is defined but never used within the worker function. It should be removed from the function signature to improve code clarity. The call to mp.spawn on line 134 and the args tuple on line 133 should also be updated accordingly.

Suggested change
def worker(rank, ep_world_size, batch_size, m, k, n):
def worker(rank, ep_world_size, m, k, n):

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@ForBetterCodeNine force-pushed the graph_ex_0119 branch 5 times, most recently from 40b05fc to 61807f8 on January 20, 2026 02:25
@ForBetterCodeNine changed the title from "Graph ex 0119" to "[Graph][Fusion] Add MatmulAllReduceAddRMSNorm graph fusion for npugraph_ex." on Jan 20, 2026
global_rank_id = 0


def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon):
A collaborator commented:
Is this test related to graph_ex?
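For context, a golden reference for this fusion would typically compute the unfused sequence in plain torch. A minimal single-process sketch, assuming the signature above takes an activation `a`, a weight `b` laid out for `F.linear`, a residual, an RMSNorm weight `gamma`, and `epsilon` (the all-reduce is omitted since there is only one rank; the function name is hypothetical):

```python
import torch

def golden_matmul_allreduce_add_rmsnorm_sketch(a, b, residual, gamma, epsilon):
    mm = torch.nn.functional.linear(a, b)        # [m, n] partial matmul result
    # in a multi-rank run, mm would be all-reduced across ranks here
    add_out = mm + residual                      # second output of the fused op
    variance = add_out.float().pow(2).mean(dim=-1, keepdim=True)
    rms_out = (add_out.float() * torch.rsqrt(variance + epsilon)).to(add_out.dtype) * gamma
    return rms_out, add_out

# illustrative shapes only
a, b = torch.randn(8, 64), torch.randn(32, 64)
residual, gamma = torch.randn(8, 32), torch.ones(32)
rms_out, add_out = golden_matmul_allreduce_add_rmsnorm_sketch(a, b, residual, gamma, 1e-6)
```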

Signed-off-by: cjian <2318164299@qq.com>
@ForBetterCodeNine force-pushed the graph_ex_0119 branch 3 times, most recently from 20278fd to 021092b on January 27, 2026 01:38
mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
all_reduce_ = tensor_model_parallel_all_reduce(mm)
output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight)
output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None)
A collaborator commented:
This change is caused by the introduction of custom operators. As a result, AddRmsNorm in the graph is replaced with AddRmsNormBias.

Signed-off-by: cjian <2318164299@qq.com>
@wangxiyuan merged commit 54e8389 into vllm-project:main on Jan 27, 2026
20 checks passed
@ForBetterCodeNine deleted the graph_ex_0119 branch on January 27, 2026 12:22
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 28, 2026
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
@wangxiyuan mentioned this pull request on Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026