[Graph][Fusion] Add MatmulAllReduceAddRMSNorm graph fusion for npugraph_ex. #6006
Conversation
Code Review
This pull request introduces a new GraphEX fusion pass for Matmul + AllReduce + AddRMSNorm and an end-to-end test for the fused operator. While the initiative is good, the implementation contains several critical issues that must be addressed. The test file has a bug in tensor shape initialization and an unused parameter. The fusion pass file has multiple critical bugs, including a typo, incorrect argument passing, a missing base class inheritance, and a potential correctness issue in the replacement logic. Please review the detailed comments for fixes.
| x2 = torch.ones([n, k], dtype=DTYPE).npu(rank) | ||
| else: | ||
| x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank) | ||
| x2 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank) |
The shape of tensor x2 is incorrect. For torch.nn.functional.linear(a, b) where a is x1 with shape [m, k], the weight b (x2 here) should have shape [n, k]. The current code initializes x2 with shape [m, k]. This will cause a shape mismatch error during the npu_add_rms_norm operation, as the residual tensor has shape [m, n], while the output of linear will have an incorrect shape.
| x2 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank) | |
| x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank) |
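For reference, a CPU-only shape sketch (the sizes below are made up for illustration; the real test derives m, k, n from its arguments) of why the weight needs shape [n, k]:

```python
import torch

# Illustrative sizes only.
m, k, n = 8, 16, 32
x1 = torch.normal(0, 0.1, [m, k])        # activation: [m, k]
x2 = torch.normal(0, 0.1, [n, k])        # linear weight must be [n, k]
residual = torch.normal(0, 0.1, [m, n])  # residual consumed by add_rms_norm: [m, n]

out = torch.nn.functional.linear(x1, x2)  # [m, k] x [n, k]^T -> [m, n]
assert out.shape == residual.shape == torch.Size([m, n])
```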
| self.local_rank, | ||
| self.eps, | ||
| True, | ||
| False, |
The is_allgather_add_out parameter is hardcoded to False. The pattern being replaced produces a full tensor for the second output (add_out), as it's derived from an all-reduced tensor. If is_allgather_add_out=False causes the fused operator to return a sharded add_out, this would be a correctness issue, as subsequent layers expect a full tensor. The corresponding test for this operator uses is_allgather_add_out=True. It should likely be True here as well to ensure the replacement is correct for middle layers.
| False, | |
| True, |
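To make the concern concrete, here is a rough unfused reference of the matched pattern in plain torch.distributed (the names and the rms-norm arithmetic are illustrative, not the fused kernel's exact semantics): add_out is the full all-reduced tensor on every rank, so the fused replacement should not hand a shard to the next layer.

```python
import torch
import torch.distributed as dist

def reference_pattern(x, weight, residual, gamma, eps=1e-6):
    # Unfused reference of the matched pattern: every rank computes a partial
    # matmul result, all_reduce turns it into the full [m, n] tensor, and the
    # residual add therefore also produces a full tensor on every rank.
    partial = torch.nn.functional.linear(x, weight)
    dist.all_reduce(partial)          # full [m, n] tensor on every rank
    add_out = partial + residual      # also full; this is what downstream layers consume
    variance = add_out.pow(2).mean(-1, keepdim=True)
    norm = add_out * torch.rsqrt(variance + eps)
    return norm * gamma, add_out
```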
| hidden_size = 4096 | ||
| x = torch.randn(batch_size, seq_len, hidden_size, device="npu") | ||
| weight = torch.randn(hidden_size, hidden_size, device="npu") | ||
| residual = torch.rann(batch_size, seq_len, hidden_size, device="npu") |
The quoted pattern code contains a typo: torch.rann should be torch.randn. As written, this line raises an AttributeError as soon as the example inputs are constructed, before the fusion pattern can even be traced.
| torchair.register_replacement( | ||
| search_fn=pattern, | ||
| replace_fn=replacement, | ||
| example_inputs=self.get_inputs(), |
The example_inputs argument to torchair.register_replacement expects a callable, but it's being passed the result of self.get_inputs(). This will cause an error during pattern registration. It should be self.get_inputs to pass the method itself.
| example_inputs=self.get_inputs(), | |
| example_inputs=self.get_inputs, |
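The difference is simply a bound method versus its return value; a tiny CPU-only illustration (the class and shapes here are made up, not the actual pass):

```python
import torch

class ExamplePattern:
    # Illustrative stand-in for the pass class.
    def get_inputs(self):
        return [torch.randn(4, 4096), torch.randn(4096, 4096), torch.randn(4, 4096)]

pattern = ExamplePattern()
as_callable = pattern.get_inputs    # bound method: what register_replacement expects
as_value = pattern.get_inputs()     # list of tensors, built eagerly at registration time
assert callable(as_callable) and not callable(as_value)
```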
| class GraphEXMatmulAllReduceAddRMSNormPass: | ||
| def __init__(self, vllm_config: VllmConfig): | ||
| super().__init__(vllm_config) |
The class GraphEXMatmulAllReduceAddRMSNormPass calls super().__init__ but does not inherit from any base class, which will cause a TypeError. Based on its usage in GraphFusionPassManager, it should inherit from VllmInductorPass. You'll also need to add from vllm.compilation.vllm_inductor_pass import VllmInductorPass at the top of the file.
| class GraphEXMatmulAllReduceAddRMSNormPass: | |
| def __init__(self, vllm_config: VllmConfig): | |
| super().__init__(vllm_config) | |
| from vllm.compilation.vllm_inductor_pass import VllmInductorPass | |
| class GraphEXMatmulAllReduceAddRMSNormPass(VllmInductorPass): | |
| def __init__(self, vllm_config: VllmConfig): | |
| super().__init__(vllm_config) |
| return rmsnorm_ret, add_ret | ||
| def worker(rank, ep_world_size, batch_size, m, k, n): |
The batch_size parameter is defined but never used within the worker function. It should be removed from the function signature to improve code clarity. The call to mp.spawn on line 134 and the args tuple on line 133 should also be updated accordingly.
| def worker(rank, ep_world_size, batch_size, m, k, n): | |
| def worker(rank, ep_world_size, m, k, n): |
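If helpful, a minimal sketch of the trimmed signature together with a matching mp.spawn call (the values below are placeholders, not the ones on lines 133 and 134):

```python
import torch.multiprocessing as mp

def worker(rank, ep_world_size, m, k, n):
    # rank is injected by mp.spawn as the first positional argument; the rest
    # comes from the args tuple below. batch_size is no longer passed through.
    ...

if __name__ == "__main__":
    ep_world_size, m, k, n = 2, 8, 4096, 4096   # placeholder values
    mp.spawn(worker, args=(ep_world_size, m, k, n), nprocs=ep_world_size)
```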
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides.
Force-pushed from 40b05fc to 61807f8
| global_rank_id = 0 | ||
| def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon): |
Is this test related to graph_ex?
Force-pushed from 61807f8 to 0f5e65a
Signed-off-by: cjian <2318164299@qq.com>
Force-pushed from 20278fd to 021092b
| mm = torch.ops.vllm.unquantized_gemm(x, weight, None) | ||
| all_reduce_ = tensor_model_parallel_all_reduce(mm) | ||
| output = torch.ops.npu.npu_add_rms_norm(all_reduce_, residual, rms_norm_weight) | ||
| output = torch.ops._C_ascend.npu_add_rms_norm_bias(all_reduce_, residual, rms_norm_weight, None) |
This change is caused by the introduction of custom operators. As a result, AddRmsNorm in the graph is replaced with AddRmsNormBias.
Signed-off-by: cjian <2318164299@qq.com>
Force-pushed from 021092b to 5772319
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (86 commits)
  [refactor] refactor excute_model and _dymmy_run method (vllm-project#6043)
  [Refactor] profiler config optimze (vllm-project#6141)
  [Graph][Fusion] Add MatmulAllReduceAddRMSNorm graph fusion for npugraph_ex. (vllm-project#6006)
  [UT]: refactoring 310p ops ut (vllm-project#6296)
  [Refact.]: refactoring 310p-kv cache allocator, align with main branch (vllm-project#6270)
  [Misc] Removes unnecessary graph size re-initialization (vllm-project#6280)
  [Main2Main] Upgrade vllm commit to 0123 (vllm-project#6169)
  [BugFix] Fix wheel package build workflow (vllm-project#6276)
  [CI][BugFix] Qwen3-Next nightly test fix. (vllm-project#6247)
  [Doc] quick fix for vllm-ascend version (vllm-project#6278)
  [Community] Nominate whx-sjtu as maintainer (vllm-project#6268)
  [Lint] Fix mypy issue to make CI happy (vllm-project#6272)
  BugFix: Fix moe_load accumulation error in ACL graph mode (vllm-project#6182)
  [Patch] Remove the patch of ECExampleConnector (vllm-project#5976)
  [Bugfix] Fix PP+PCP and PP+flashcomm1 bugs (vllm-project#5416)
  [Feat] proxy delay to remove instances (vllm-project#5934)
  [CI] Add workfolw_dispatch for nightly image build (vllm-project#6269)
  [bugfix][npugraph_ex]fix static kernel uninstall issue (vllm-project#6128)
  [Doc] 310P Documents update (vllm-project#6246)
  [Feature] Mooncake connector get remote ptp size (vllm-project#5822)
  ...
What this PR does / why we need it?
This PR builds upon PR #5011 and further enhances the npu_graph_ex_passes module. Building on that prior work, we added graph optimization support for the fused add_rms_quant operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched into the computation graph.
This time, we add the MatmulAllReduceAddRMSNorm operator fusion and corresponding ST test cases for regression monitoring.
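For orientation, the unfused sequence the new pass matches looks roughly like the snippet below (reproduced from the pattern shown in the review above; the vllm.distributed import path is an assumption):

```python
import torch
from vllm.distributed import tensor_model_parallel_all_reduce

def unfused_pattern(x, weight, residual, rms_norm_weight):
    # The three-op sequence this pass matches; after fusion it is replaced by
    # a single fused Matmul + AllReduce + AddRMSNorm kernel.
    mm = torch.ops.vllm.unquantized_gemm(x, weight, None)
    all_reduce_ = tensor_model_parallel_all_reduce(mm)
    return torch.ops._C_ascend.npu_add_rms_norm_bias(
        all_reduce_, residual, rms_norm_weight, None)
```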
Does this PR introduce any user-facing change?
No
How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2c24bc6