From e7c6e1eb41acda56ac085b509a174a7391eed14e Mon Sep 17 00:00:00 2001 From: cjian <2318164299@qq.com> Date: Thu, 8 Jan 2026 15:57:33 +0800 Subject: [PATCH 1/2] qkv norm Signed-off-by: cjian <2318164299@qq.com> --- .../ut/compilation/test_add_rms_norm_quant.py | 148 -------- .../test_npugraph_ex_utils_check.py | 54 +++ vllm_ascend/compilation/compiler_interface.py | 4 - .../compilation/npu_graph_ex_pass_manager.py | 51 +++ .../npugraph_ex_passes/add_rms_norm_quant.py | 320 ------------------ .../graphex_norm_quant_fusion_pass.py | 316 +++++++++++++++++ .../graphex_qknorm_rope_fusion_pass.py | 227 +++++++++++++ .../npugraph_ex_passes/utils/__init__.py | 0 .../utils/npugraph_ex_utils_check.py | 53 +++ vllm_ascend/platform.py | 8 +- 10 files changed, 707 insertions(+), 474 deletions(-) delete mode 100644 tests/ut/compilation/test_add_rms_norm_quant.py create mode 100644 tests/ut/compilation/test_npugraph_ex_utils_check.py create mode 100644 vllm_ascend/compilation/npu_graph_ex_pass_manager.py delete mode 100644 vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/utils/__init__.py create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py diff --git a/tests/ut/compilation/test_add_rms_norm_quant.py b/tests/ut/compilation/test_add_rms_norm_quant.py deleted file mode 100644 index d056676ca46..00000000000 --- a/tests/ut/compilation/test_add_rms_norm_quant.py +++ /dev/null @@ -1,148 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# - -import sys -from unittest import mock - -import torch - - -def get_inputs(): - """ - Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern. - """ - rms_norm_input = torch.randn(2, 4) - residual = torch.randn(2, 4) - rms_norm_weight = torch.randn(4) - rmsnorm_bias = torch.randn(4) - scale = torch.ones(4) - offset = torch.zeros(4) - return [ - rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias - ] - - -def _extra_stream_scope_check_for_test(match) -> bool: - """ - Copied from the original implementation for testability. - Checks if all nodes in the same stream. - """ - non_default_streams = set() - has_default = False - - for node in match.nodes: - if node.op == "call_function": - current_stream = node.meta.get("stream_label") - if current_stream is None: - has_default = True - else: - non_default_streams.add(current_stream) - if len(non_default_streams) > 1: - return False - - if has_default and len(non_default_streams) > 0: - return False - - return True - - -def test_extra_stream_scope_check(): - """Test the stream scope check logic.""" - - class MockNode: - - def __init__(self, stream_label=None): - self.op = "call_function" - self.meta = {"stream_label": stream_label} - - class MockMatch: - - def __init__(self, nodes): - self.nodes = nodes - - # Test 1: all default stream (None) → OK - match1 = MockMatch([MockNode(None), MockNode(None)]) - assert _extra_stream_scope_check_for_test(match1) is True - - # Test 2: all same non-default stream → OK - match2 = MockMatch([MockNode("s1"), MockNode("s1")]) - assert _extra_stream_scope_check_for_test(match2) is True - - # Test 3: mixed streams → FAIL - match3 = MockMatch([MockNode("s1"), MockNode("s2")]) - assert _extra_stream_scope_check_for_test(match3) is False - - # Test 4: default + non-default → FAIL - match4 = MockMatch([MockNode(None), MockNode("s1")]) - assert _extra_stream_scope_check_for_test(match4) is False - - # Test 5: empty nodes → OK (edge case) - match5 = MockMatch([]) - assert _extra_stream_scope_check_for_test(match5) is True - - -def test_replacement_function_without_torch_npu(caplog): - with mock.patch.dict(sys.modules, { - 'torch_npu': None, - 'torchair': None, - 'torch_npu.dynamo': None - }): - if 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant' in sys.modules: - del sys.modules[ - 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant'] - - try: - from vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant import \ - replacement_add_rms_norm_quant_with_bias - result = replacement_add_rms_norm_quant_with_bias(epsilon=1e-5) - assert result is None - except (ImportError, AttributeError): - pass - - -def test_get_inputs_sp_pattern_with_bias(): - """ - Test that get_inputs generates tensors with correct shapes and device. - This test verifies the internal get_inputs function used in the pattern. - """ - try: - import torch - except ImportError: - return # Skip if torch is not available - - inputs = get_inputs() - ( - rms_norm_input, - residual, - rms_norm_weight, - scale, - offset, - rmsnorm_bias, - ) = inputs - - # Verify shapes - assert rms_norm_input.shape == (2, 4) - assert residual.shape == (2, 4) - assert rms_norm_weight.shape == (4, ) - assert rmsnorm_bias.shape == (4, ) - assert scale.shape == (4, ) - assert offset.shape == (4, ) - - # Verify number of inputs - assert len(inputs) == 6 - - # Verify specific values - assert torch.all(scale == 1.0) - assert torch.all(offset == 0.0) diff --git a/tests/ut/compilation/test_npugraph_ex_utils_check.py b/tests/ut/compilation/test_npugraph_ex_utils_check.py new file mode 100644 index 00000000000..2be1ce58fe8 --- /dev/null +++ b/tests/ut/compilation/test_npugraph_ex_utils_check.py @@ -0,0 +1,54 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import \ + extra_stream_scope_check + + +def test_extra_stream_scope_check_logic(): + """ + Test the extra_stream_scope_check logic used in both fusion patterns. + This is a pure function test (copied logic for testability). + """ + + class MockNode: + + def __init__(self, stream_label=None): + self.op = "call_function" + self.meta = {"stream_label": stream_label} + + class MockMatch: + + def __init__(self, nodes): + self.nodes = nodes + + # Test 1: all default → OK + assert extra_stream_scope_check( + MockMatch([MockNode(None), MockNode(None)])) is True + + # Test 2: same non-default → OK + assert extra_stream_scope_check( + MockMatch([MockNode("s1"), MockNode("s1")])) is True + + # Test 3: mixed non-default → FAIL + assert extra_stream_scope_check( + MockMatch([MockNode("s1"), MockNode("s2")])) is False + + # Test 4: default + non-default → FAIL + assert extra_stream_scope_check( + MockMatch([MockNode(None), MockNode("s1")])) is False + + # Test 5: empty → OK + assert extra_stream_scope_check(MockMatch([])) is True diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index b9917bb1734..02986cf1af9 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -90,10 +90,6 @@ def npugraph_ex_compile( graph.recompile() import torchair - # TODO: use a better way to lazy register replacement, instead of import one by one - # As an example, we directly import here to register replacement. - # import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant # noqa - torch.npu.set_compile_mode(jit_compile=False) config = torchair.CompilerConfig() # use aclgraph mode, avoid the transformation from fx graph to Ascend IR. diff --git a/vllm_ascend/compilation/npu_graph_ex_pass_manager.py b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py new file mode 100644 index 00000000000..be810de0842 --- /dev/null +++ b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from torch import fx as fx +from vllm.compilation.inductor_pass import get_pass_context +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig + + +class NpuGraphEXPassManager: + """ + A pass manager for npu_graph ex fusion passes. + It handles the configuration and execution of passes. + The counterpart in vllm is PostGradPassManager. Since torch_npu + does not support triton for now, we define our own pass manager. + """ + + def __init__(self): + self.passes: list[VllmInductorPass] = [] + + def __call__(self, graph: fx.Graph) -> fx.Graph: + compile_range = get_pass_context().compile_range + + for pass_ in self.passes: + if pass_.is_applicable_for_range(compile_range): + pass_(graph) + graph.recompiler() + return graph + + def add(self, pass_: VllmInductorPass): + assert isinstance(pass_, VllmInductorPass) + self.passes.append(pass_) + + def configure(self, config: VllmConfig): + # By default, we enable the graph fusion and quantization fusion pass. + self.ascend_compilation_config: dict = config.additional_config.get("ascend_compilation_config", {}) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py deleted file mode 100644 index 9da4548c5c5..00000000000 --- a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py +++ /dev/null @@ -1,320 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import functools - -import torch -from torch._inductor.pattern_matcher import Match -from vllm.logger import logger - - -def _extra_stream_scope_check(match: Match) -> bool: - """ - Checks if all nodes in the same stream. - """ - non_default_streams = set() - has_default = False - - for node in match.nodes: - if node.op == "call_function": - current_stream = node.meta.get("stream_label") - if current_stream is None: - has_default = True - else: - non_default_streams.add(current_stream) - if len(non_default_streams) > 1: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations." - ) - return False - - if has_default and len(non_default_streams) > 0: - logger.debug( - f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " - f"Multiple streams found: {non_default_streams}. " - f"Fusion is not supported for cross-stream operations." - ) - return False - - return True - - -@functools.lru_cache(None) -# The replacement registered here will be actually executed after AOT. -def replacement_add_rms_norm_quant(epsilon): - def pattern( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - ): - """ - Pattern for AddRMSNormQuant fusion. - """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) - out0 = output[0] - out1 = output[2] - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) - return quantized_output, out1 - - def replacement( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - ): - """ - Replacement for the AddRMSNormQuant fusion. - """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel - # which is opposite to the npu_quantize kernel. - 1.0 / scale, - offset, - epsilon=epsilon, - ) - quantized_output = output[0] - out1 = output[2] - return quantized_output, out1 - - def get_inputs(): - """ - Generate example inputs for the AddRMSNormQuant fusion pattern. - """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - scale = torch.tensor([1.0], device="npu") - offset = torch.tensor([0.0], device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset] - - import torchair - - torchair.register_replacement( - search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check - ) - - -# The replacement registered here will be actually executed after AOT. -def replacement_add_rms_norm_quant_with_bias(epsilon): - def pattern( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - bias: torch.Tensor, - ): - """ - Pattern for AddRMSNormQuantWithBias fusion. - """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) - out0 = output[0] - out1 = output[2] - out0 = out0 + bias - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) - return quantized_output, out1 - - def replacement( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - bias: torch.Tensor, - ): - """ - Replacement for AddRMSNormQuantWithBias fusion. - """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel - # which is opposite to the npu_quantize kernel. - 1.0 / scale, - offset, - epsilon=epsilon, - beta=bias, - ) - quantized_output = output[0] - out1 = output[2] - return quantized_output, out1 - - def get_inputs(): - """ - Generate example inputs for the AddRMSNormQuantWithBias fusion pattern. - """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - rmsnorm_bias = torch.randn(4, device="npu") - scale = torch.ones(4, device="npu") - offset = torch.zeros(4, device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias] - - import torchair - - torchair.register_replacement( - search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check - ) - - -# The replacement registered here will be actually executed after AOT. -def replacement_add_rms_norm_quant_sp_pattern(epsilon): - def pattern( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - ): - """ - Pattern for AddRMSNormQuantSPPattern fusion. - """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) - out0 = output[0] - out1 = output[2] - out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) - return quantized_output, out1 - - def replacement( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - ): - """ - Replacement for the AddRMSNormQuantSPPattern fusion. - """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel - # which is opposite to the npu_quantize kernel. - 1.0 / scale, - offset, - epsilon=epsilon, - ) - quantized_output = output[0] - out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) - return quantized_output, out1 - - def get_inputs(): - """ - Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern. - """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - scale = torch.ones(4, device="npu") - offset = torch.zeros(4, device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset] - - import torchair - - torchair.register_replacement( - search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check - ) - - -# The replacement registered here will be actually executed after AOT. -def replacement_add_rms_norm_quant_sp_pattern_with_bias(epsilon): - def pattern( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - bias: torch.Tensor, - ): - """ - Pattern for AddRMSNormQuantSPPatternWithBias fusion. - """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, epsilon) - out0 = output[0] - out1 = output[2] - out0 = out0 + bias - out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, torch.qint8, -1, False) - return quantized_output, out1 - - def replacement( - rms_norm_input: torch.Tensor, - residual: torch.Tensor, - rms_norm_weight: torch.Tensor, - scale: torch.Tensor, - offset: torch.Tensor, - bias: torch.Tensor, - ): - """ - Replacement for the AddRMSNormQuantSPPatternWithBias fusion. - """ - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, - residual, - rms_norm_weight, - # The inverse of scale is required by npu_add_rms_norm_quant kernel - # which is opposite to the npu_quantize kernel. - 1.0 / scale, - offset, - epsilon=epsilon, - beta=bias, - ) - quantized_output = output[0] - out1 = output[2] - quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) - return quantized_output, out1 - - def get_inputs(): - """ - Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern. - """ - rms_norm_input = torch.randn(2, 4, device="npu") - residual = torch.randn(2, 4, device="npu") - rms_norm_weight = torch.randn(4, device="npu") - rmsnorm_bias = torch.randn(4, device="npu") - scale = torch.ones(4, device="npu") - offset = torch.zeros(4, device="npu") - return [rms_norm_input, residual, rms_norm_weight, scale, offset, rmsnorm_bias] - - import torchair - - torchair.register_replacement( - search_fn=pattern, replace_fn=replacement, example_inputs=get_inputs(), extra_check=_extra_stream_scope_check - ) - - -# register converter for pass -common_epsilons = [1e-5, 1e-6] -for eps in common_epsilons: - logger.info(f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}") - replacement_add_rms_norm_quant(eps) - replacement_add_rms_norm_quant_with_bias(eps) - replacement_add_rms_norm_quant_sp_pattern(eps) - replacement_add_rms_norm_quant_sp_pattern_with_bias(eps) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py new file mode 100644 index 00000000000..5c41100a1cf --- /dev/null +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py @@ -0,0 +1,316 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torchair +from vllm.config import VllmConfig +from vllm.config.compilation import Range +from vllm.logger import logger + +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check + + +class GraphEXAddRMSNormQuantPattern: + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + scale_reciprocal = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset] + + def register(self): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): + """ + Pattern for AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) + return quantized_output, out1 + + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps + ) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXAddRMSNormQuantPatternWithBias: + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuantWithBias fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + rmsnorm_bias = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + scale_reciprocal = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias] + + # The replacement registered here will be actually executed after AOT. + def register(self): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): + """ + Pattern for AddRMSNormQuantWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) + return quantized_output, out1 + + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): + """ + Replacement for AddRMSNormQuantWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias + ) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXAddRMSNormQuantSPPattern: + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuantSPPattern fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + scale_reciprocal = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset] + + # The replacement registered here will be actually executed after AOT. + def register(self): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): + """ + Pattern for AddRMSNormQuantSPPattern fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) + return quantized_output, out1 + + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + ): + """ + Replacement for the AddRMSNormQuantSPPattern fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps + ) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) + return quantized_output, out1 + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXAddRMSNormQuantSPPatternWithBias: + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuantSPPatternWithBias fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + rmsnorm_bias = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + scale_reciprocal = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, scale_reciprocal, offset, rmsnorm_bias] + + # The replacement registered here will be actually executed after AOT. + def register(self): + def pattern( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): + """ + Pattern for AddRMSNormQuantSPPatternWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) + quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) + return quantized_output, out1 + + def replacement( + rms_norm_input: torch.Tensor, + residual: torch.Tensor, + rms_norm_weight: torch.Tensor, + scale: torch.Tensor, + scale_reciprocal: torch.Tensor, + offset: torch.Tensor, + bias: torch.Tensor, + ): + """ + Replacement for the AddRMSNormQuantSPPatternWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, scale, offset, epsilon=self.eps, beta=bias + ) + quantized_output = output[0] + out1 = output[2] + quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(quantized_output, True) + return quantized_output, out1 + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXAddRMSNormFusionPass: + """ + A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. + """ + + def __init__(self, vllm_config: VllmConfig): + dtype = vllm_config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.debug("Quant fusion not enabled: unsupported dtype %s", dtype) + return + + common_epsilons = [1e-5, 1e-6] + for eps in common_epsilons: + GraphEXAddRMSNormQuantPattern(vllm_config, eps=eps).register() + GraphEXAddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register() + GraphEXAddRMSNormQuantSPPattern(vllm_config, eps=eps).register() + GraphEXAddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register() + + def __call__(self, graph: torch.fx.Graph): + pass + + def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + """ + return True diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py new file mode 100644 index 00000000000..3317d132c58 --- /dev/null +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_qknorm_rope_fusion_pass.py @@ -0,0 +1,227 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torchair +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.compilation import Range +from vllm.logger import logger + +from vllm_ascend.compilation.npugraph_ex_passes.utils.npugraph_ex_utils_check import extra_stream_scope_check + + +class GraphEXQKNormRopeFusionPattern: + def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6): + self.vllm_config = vllm_config + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.device = vllm_config.device_config.device if vllm_config.device_config else None + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + return [qkv, q_weight, k_weight, cos, sin] + + # The replacement registered here will be actually executed after AOT. + def register(self): + def pattern( + qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) + + q_flat = q_norm_out.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + + k_flat = k_norm_out.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ): + results = torch.ops.vllm.qkv_rmsnorm_rope( + input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + eps=self.eps, + q_bias=None, + k_bias=None, + sin=sin, + cos=cos, + ) + + return results + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXQKNormRopeFusionPatternWithBias: + def __init__(self, vllm_config, head_dim, num_heads, num_kv_heads, eps=1e-6): + self.vllm_config = vllm_config + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.device = vllm_config.device_config.device if vllm_config.device_config else None + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + cos = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + sin = torch.empty(1, T, 1, self.head_dim, dtype=torch.bfloat16, device="npu") + return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] + + # The replacement registered here will be actually executed after AOT. + def register(self): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) + q_normed = q_norm_out + q_bias + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) + k_normed = k_norm_out + k_bias + + q_flat = q_normed.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + + k_flat = k_normed.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + results = torch.ops.vllm.qkv_rmsnorm_rope( + input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + eps=self.eps, + q_bias=q_bias, + k_bias=k_bias, + sin=sin, + cos=cos, + ) + + return results + + torchair.register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=self.get_inputs(), + extra_check=extra_stream_scope_check, + ) + + +class GraphEXQKNormRopeFusionPass: + """ + A pass for fusing QKV split and RMSNorm operations into a single qk_rmsnorm operator. + """ + + def __init__(self, vllm_config: VllmConfig): + dtype = vllm_config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.debug("QKNorm and Rope fusion not enabled: unsupported dtype %s", dtype) + return + # use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config(vllm_config, Attention) + if len(attn_layers) == 0: + logger.debug("QKNorm and Rope fusion enabled, but no Attention layers were discovered.") + return + layer = next(iter(attn_layers.values())) + for epsilon in [1e-6, 1e-5]: + if layer.head_size != 128: + logger.debug("QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size) + continue + GraphEXQKNormRopeFusionPattern( + vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + ).register() + GraphEXQKNormRopeFusionPatternWithBias( + vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + ).register() + + def __call__(self, graph: torch.fx.Graph): + pass + + def is_applicable_for_range(self, compile_range: Range) -> bool: + """ + Check if the pass is applicable for the current configuration. + """ + return True diff --git a/vllm_ascend/compilation/npugraph_ex_passes/utils/__init__.py b/vllm_ascend/compilation/npugraph_ex_passes/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py new file mode 100644 index 00000000000..481a16ed8c9 --- /dev/null +++ b/vllm_ascend/compilation/npugraph_ex_passes/utils/npugraph_ex_utils_check.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from torch._inductor.pattern_matcher import Match +from vllm.logger import logger + + +def extra_stream_scope_check(match: Match) -> bool: + """ + Checks if all nodes in the same stream. + """ + non_default_streams = set() + has_default = False + + for node in match.nodes: + if node.op == "call_function": + current_stream = node.meta.get("stream_label") + if current_stream is None: + has_default = True + else: + non_default_streams.add(current_stream) + if len(non_default_streams) > 1: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations." + ) + return False + + if has_default and len(non_default_streams) > 0: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations." + ) + return False + + return True diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 65e36223a3c..f32b51a8195 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -30,7 +30,7 @@ # todo: please remove it when solve cuda hard code in vllm os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" -from vllm_ascend.ascend_config import init_ascend_config +from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config # isort: off from vllm_ascend.utils import ( @@ -121,7 +121,11 @@ def get_pass_manager_cls(cls) -> str: Get the pass manager class for this platform. It will be registered as a custom pass under the current_platform.pass_key. """ - return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager" + ascend_config = get_ascend_config() + if ascend_config.enable_npugraph_ex: + return "vllm_ascend.compilation.npu_graph_ex_pass_manager.NpuGraphEXPassManager" + else: + return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager" @classmethod def get_compile_backend(self) -> str: From 50c72f9a15af87275a6b10ea2ffb818a56be0c6d Mon Sep 17 00:00:00 2001 From: cjian <2318164299@qq.com> Date: Wed, 21 Jan 2026 09:25:29 +0800 Subject: [PATCH 2/2] 1 Signed-off-by: cjian <2318164299@qq.com> --- vllm_ascend/platform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f32b51a8195..10da06a542a 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -121,8 +121,8 @@ def get_pass_manager_cls(cls) -> str: Get the pass manager class for this platform. It will be registered as a custom pass under the current_platform.pass_key. """ - ascend_config = get_ascend_config() - if ascend_config.enable_npugraph_ex: + npugraph_ex_config = get_ascend_config().npugraph_ex_config + if npugraph_ex_config.enable: return "vllm_ascend.compilation.npu_graph_ex_pass_manager.NpuGraphEXPassManager" else: return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager"