Skip to content

[Graph][Fusion] Add AddRMSNorm(with bias)#5491

Merged
wangxiyuan merged 1 commit intovllm-project:mainfrom
ForBetterCodeNine:add_rms
Dec 31, 2025
Merged

[Graph][Fusion] Add AddRMSNorm(with bias)#5491
wangxiyuan merged 1 commit intovllm-project:mainfrom
ForBetterCodeNine:add_rms

Conversation

@ForBetterCodeNine
Copy link
Copy Markdown
Contributor

@ForBetterCodeNine ForBetterCodeNine commented Dec 29, 2025

What this PR does / why we need it?

This PR builds upon PR #5011 and aims to further enhance the npu_graph_ex_passes module. Based on prior work, we have added graph optimization support for the add_rms_quant fused operator in scenarios where a bias term is present—ensuring the fusion pattern is correctly registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark results show that, compared to the unfused baseline, enabling this fusion pass significantly improves inference throughput for W8A8 quantized models.
For more details can refer to the RFC:#4715

Does this PR introduce any user-facing change?

no

How was this patch tested?

llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new TorchAIR replacement pattern to fuse RMS normalization, bias addition, and quantization into a single npu_add_rms_norm_quant operation. While the goal is to improve performance, the current implementation contains several critical issues. The fusion pattern is incorrectly defined using the target fused operator instead of the sequence of operations to be fused. Additionally, there are errors in the return value indexing of the replacement function and in the signature of the example input generator function. These issues need to be addressed for the fusion to work as intended.

Comment on lines +163 to +170
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[1]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The pattern definition is incorrect. It should describe the sequence of operations to be fused, which are npu_add_rms_norm, bias addition, and npu_quantize. However, it's currently using npu_add_rms_norm_quant, which is the fused operator itself. This will not match the intended graph pattern. Furthermore, npu_add_rms_norm returns the residual at index 2, not 1.

Suggested change
output = torch.ops.npu.npu_add_rms_norm_quant(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[1]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1

Comment on lines +187 to +189
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The replacement function's return values do not match the pattern's return values. The pattern returns (quantized_output, residual). The npu_add_rms_norm_quant operator returns a tuple of (quantized_output, residual, rstd). The residual is at index 1, but the code returns output[2], which is rstd. This mismatch will lead to incorrect behavior downstream.

Suggested change
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
quantized_output = output[0]
out1 = output[1]
return quantized_output, out1

Comment on lines +191 to +205
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The get_inputs function is defined with self as its first parameter and accesses self.dtype, but it is a nested function, not a class method. This will cause a TypeError when called and an AttributeError if it were to run. The function should be defined without self.

Suggested change
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
scale = torch.ones(4, device="npu", dtype=self.dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=self.dtype)
offset = torch.zeros(4, device="npu", dtype=self.dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]
def get_inputs():
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
dtype = torch.float16
rms_norm_input = torch.randn(2, 4, device="npu", dtype=dtype)
residual = torch.randn(2, 4, device="npu", dtype=dtype)
rms_norm_weight = torch.randn(4, device="npu", dtype=dtype)
rmsnorm_bias = torch.randn(4, device="npu", dtype=dtype)
scale = torch.ones(4, device="npu", dtype=dtype)
scale_reciprocal = torch.ones(4, device="npu", dtype=dtype)
offset = torch.zeros(4, device="npu", dtype=dtype)
return [
rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal,
offset, rmsnorm_bias
]

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@ForBetterCodeNine ForBetterCodeNine force-pushed the add_rms branch 8 times, most recently from 08930e7 to 7ee557a Compare December 30, 2025 07:12
@ForBetterCodeNine ForBetterCodeNine changed the title add rms norm quant with bias [Graph][Fusion] Add AddRMSNorm(with bias) Dec 30, 2025
@ForBetterCodeNine ForBetterCodeNine force-pushed the add_rms branch 6 times, most recently from d4448bd to f7a08ca Compare December 31, 2025 06:08
Signed-off-by: cjian <2318164299@qq.com>
@wangxiyuan wangxiyuan merged commit 80fc0f5 into vllm-project:main Dec 31, 2025
19 checks passed
@ForBetterCodeNine ForBetterCodeNine deleted the add_rms branch January 4, 2026 02:16
wjunLu pushed a commit to wjunLu/vllm-ascend that referenced this pull request Jan 4, 2026
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:vllm-project#4715

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
Signed-off-by: wjunLu <wjunlu217@gmail.com>
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:vllm-project#4715

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:vllm-project#4715

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:vllm-project#4715

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
### What this PR does / why we need it?
This PR builds upon PR vllm-project#5011 and aims to further enhance the
npu_graph_ex_passes module. Based on prior work, we have added graph
optimization support for the add_rms_quant fused operator in scenarios
where a bias term is present—ensuring the fusion pattern is correctly
registered and matched into the computation graph.

For validation, we switched to the Qwen3-235B-A22B-W8A8 model. Benchmark
results show that, compared to the unfused baseline, enabling this
fusion pass significantly improves inference throughput for W8A8
quantized models.
For more details can refer to the
RFC:vllm-project#4715

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
```
llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=False,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=512,
        # load_format="dummy",
        max_model_len=2048,
        max_num_seqs=16,
        quantization="ascend",
        additional_config={
            "refresh": True,
            "enable_npugraph_ex": True
        },
        compilation_config={
            "cudagraph_capture_sizes": [8, 16],
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
    )
    if profile_dir:
        llm.start_profile()
    outputs = llm.generate(prompts, sampling_params)
    if profile_dir:
        llm.stop_profile()
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )
```
- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: cjian <2318164299@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants