[Graph][Fusion] Add AddRMSNorm (with bias) #5491
Conversation
Code Review
This pull request introduces a new TorchAIR replacement pattern to fuse RMS normalization, bias addition, and quantization into a single npu_add_rms_norm_quant operation. While the goal is to improve performance, the current implementation contains several critical issues. The fusion pattern is incorrectly defined using the target fused operator instead of the sequence of operations to be fused. Additionally, there are errors in the return value indexing of the replacement function and in the signature of the example input generator function. These issues need to be addressed for the fusion to work as intended.
```python
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, residual,
                                              rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[1]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                              torch.qint8, -1, False)
return quantized_output, out1
```
The pattern definition is incorrect. It should describe the sequence of operations to be fused, which are npu_add_rms_norm, bias addition, and npu_quantize. However, it's currently using npu_add_rms_norm_quant, which is the fused operator itself. This will not match the intended graph pattern. Furthermore, npu_add_rms_norm returns the residual at index 2, not 1.
Suggested change:

```python
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                        rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
                                              torch.qint8, -1, False)
return quantized_output, out1
```
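The indexing point can be checked without NPU hardware. The following is a hypothetical pure-Python stand-in for `npu_add_rms_norm` (not the real kernel), written to mirror the output order the review describes, `(y, rstd, x_out)`, so the post-add residual sits at index 2:

```python
import math

# Hypothetical stand-in for npu_add_rms_norm, assuming the output order
# the review states: (y, rstd, x_out). Illustration only.
def add_rms_norm(x, residual, weight, eps):
    x_out = [a + b for a, b in zip(x, residual)]       # residual add
    rstd = 1.0 / math.sqrt(sum(v * v for v in x_out) / len(x_out) + eps)
    y = [v * rstd * w for v, w in zip(x_out, weight)]  # RMS-normalized output
    return y, rstd, x_out

output = add_rms_norm([1.0, 2.0], [3.0, 4.0], [1.0, 1.0], 0.0)
residual_out = output[2]  # index 2, not 1, holds the residual sum
print(residual_out)       # [4.0, 6.0]
```

Here `output[1]` is the reciprocal standard deviation, which is why a pattern that reads the residual from index 1 would capture the wrong tensor.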
```python
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
```
The replacement function's return values do not match the pattern's return values. The pattern returns (quantized_output, residual). The npu_add_rms_norm_quant operator returns a tuple of (quantized_output, residual, rstd). The residual is at index 1, but the code returns output[2], which is rstd. This mismatch will lead to incorrect behavior downstream.
Suggested change:

```python
quantized_output = output[0]
out1 = output[1]
return quantized_output, out1
```
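To make the mismatch concrete, here is a trivial stub that returns placeholders in the tuple order the review states for `npu_add_rms_norm_quant`, `(quantized_output, residual, rstd)`; it is illustrative only and stands in for the real operator:

```python
def fused_add_rms_norm_quant_stub():
    # Placeholder tuple in the order the review describes:
    # (quantized_output, residual, rstd).
    return "quantized", "residual", "rstd"

output = fused_add_rms_norm_quant_stub()
assert output[1] == "residual"  # the residual lives at index 1
assert output[2] == "rstd"      # output[2] is rstd, the wrong value to return
```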
```python
def get_inputs(self):
    """
    Generate example inputs for the AddRMSNormQuant fusion pattern.
    """
    rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
    residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
    rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
    rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
    scale = torch.ones(4, device="npu", dtype=self.dtype)
    scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
    offset = torch.zeros(4, device="npu", dtype=self.dtype)
    return [
        rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
        offset, rmsnorm_bias
    ]
```
The get_inputs function is defined with self as its first parameter and accesses self.dtype, but it is a nested function, not a class method. Calling it with no arguments therefore raises a TypeError (missing the self argument), and even if an argument were supplied, self.dtype would raise an AttributeError. The function should be defined without self.
Suggested change:

```python
def get_inputs():
    """
    Generate example inputs for the AddRMSNormQuant fusion pattern.
    """
    dtype = torch.float16
    rms_norm_input = torch.randn(2, 4, device="npu", dtype=dtype)
    residual = torch.randn(2, 4, device="npu", dtype=dtype)
    rms_norm_weight = torch.randn(4, device="npu", dtype=dtype)
    rmsnorm_bias = torch.randn(4, device="npu", dtype=dtype)
    scale = torch.ones(4, device="npu", dtype=dtype)
    scale_reciprocal = torch.ones(4, device="npu", dtype=dtype)
    offset = torch.zeros(4, device="npu", dtype=dtype)
    return [
        rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
        offset, rmsnorm_bias
    ]
```
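The `self` issue reproduces without any NPU dependency. A minimal sketch (hypothetical `register_pattern` wrapper, stand-in string dtype): a nested function that declares `self` fails with a TypeError the moment the pattern machinery calls it with no arguments, while the fixed version binds its dtype locally:

```python
def register_pattern():
    # Buggy variant: nested helper declares `self` though it is not a method.
    def get_inputs_buggy(self):
        return self.dtype

    try:
        get_inputs_buggy()  # called with no args, as a pattern registry would
    except TypeError as exc:
        print(f"TypeError: {exc}")

    # Fixed variant: no `self`; the dtype is bound locally instead.
    def get_inputs():
        dtype = "float16"  # stand-in for the former class attribute
        return dtype

    return get_inputs()

print(register_pattern())  # float16
```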
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides.
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the npu_graph_ex_passes module. Based on prior work, we have added graph optimization support for the add_rms_quant fused operator in scenarios where a bias term is present, ensuring the fusion pattern is correctly registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.

For more details, refer to RFC vllm-project#4715.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
```python
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=enable_expert_parallel,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.98,
    max_num_batched_tokens=512,
    # load_format="dummy",
    max_model_len=2048,
    max_num_seqs=16,
    quantization="ascend",
    additional_config={
        "refresh": True,
        "enable_npugraph_ex": True
    },
    compilation_config={
        "cudagraph_capture_sizes": [8, 16],
        "cudagraph_mode": "FULL_DECODE_ONLY",
    },
)

if profile_dir:
    llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
if profile_dir:
    llm.stop_profile()

for i, output in enumerate(outputs):
    if i >= 5:
        break
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
```

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
Signed-off-by: wjunLu <wjunlu217@gmail.com>