From d3a57c6fa21217cfab4fd151b7e0fb066ec44c8b Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Tue, 14 Oct 2025 03:29:19 +0000 Subject: [PATCH 01/25] Define quant fusion pass Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 38 +++- vllm_ascend/compilation/compiler_interface.py | 71 ++++++++ vllm_ascend/compilation/pass_manager.py | 55 ++++++ vllm_ascend/compilation/quant_fusion_pass.py | 93 ++++++++++ vllm_ascend/ops/layernorm.py | 169 ++++++------------ vllm_ascend/platform.py | 1 + 6 files changed, 300 insertions(+), 127 deletions(-) create mode 100644 vllm_ascend/compilation/compiler_interface.py create mode 100644 vllm_ascend/compilation/pass_manager.py create mode 100644 vllm_ascend/compilation/quant_fusion_pass.py diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 115dbef1209..d418aa39dec 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -36,8 +36,10 @@ def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} torchair_graph_config = additional_config.get("torchair_graph_config", {}) - self.torchair_graph_config = TorchairGraphConfig( - torchair_graph_config, vllm_config, additional_config) + self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config) + + ascend_compilation_config = additional_config.get("ascend_compilation_config", {}) + self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config) ascend_scheduler_config = additional_config.get( "ascend_scheduler_config", {}) @@ -136,12 +138,21 @@ def __init__(self, vllm_config): if self.pd_tp_ratio == 0: raise AssertionError( "Only support P node tp size lagger then D node tp size") - self.SLO_limits_for_dynamic_batch = additional_config.get( - "SLO_limits_for_dynamic_batch", -1) - from vllm_ascend.utils import \ - get_flashcomm2_oproj_tp_size_and_validate_config - self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( - self, vllm_config) + +class AscendCompilationConfig: + """ + Configuration Object for ascend_compilation_config from additional_config + """ + + def __init__(self, + enable_graph_rewriter: bool = True, + fx_graph_eager: bool = False, + enable_quantization_fusion: bool = True, + **kwargs): + self.enable_graph_rewriter = enable_graph_rewriter + self.fx_graph_eager = fx_graph_eager + self.enable_quantization_fusion = enable_quantization_fusion + # Add more compilation related configs here as needed class TorchairGraphConfig: @@ -326,6 +337,17 @@ def check_ascend_config(vllm_config, enforce_eager): "it has been disabled automatically.") # aclgraph case else: + # This graph fusion can actually works on eager mode. + if ascend_config.ascend_compilation_config.enable_graph_rewriter: + logger.info( + "Graph rewriter enabled! Automatic kernel fusion is expected." + ) + + if ascend_config.ascend_compilation_config.enable_quantization_fusion: + logger.info( + "Quantization fusion enabled! op fusion on quantization are expected. " + ) + if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type if "qwen" not in model_type: diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py new file mode 100644 index 00000000000..bcfa300a8ef --- /dev/null +++ b/vllm_ascend/compilation/compiler_interface.py @@ -0,0 +1,71 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx +from vllm.compilation.compiler_interface import CompilerInterface +from vllm.compilation.counter import compilation_counter + + +def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: + """ + Extract the dtype from the kwargs dictionary. + """ + dtype_list = [] + for value in args: + if isinstance(value, torch.Tensor): + dtype_list.append(value.dtype) + return dtype_list + + +def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: + """ + Extract the shapes from the kwargs dictionary. + """ + shape_list = [] + for value in args: + if isinstance(value, torch.Tensor): + shape_list.append(value.shape) + return shape_list + +class AscendAdaptor(CompilerInterface): + name = "AscendAdaptor" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + + current_pass_manager = compiler_config["graph_rewriter_manager"] + arg_dtypes = get_dtype_from_args(example_inputs) + arg_shapes = get_shapes_from_args(example_inputs) + kwargs = { + "runtime_shape": runtime_shape, + "arg_shapes": arg_shapes, + "arg_dtypes": arg_dtypes + } + graph = current_pass_manager(graph, **kwargs) + compilation_counter.num_eager_compiles += 1 + + \ No newline at end of file diff --git a/vllm_ascend/compilation/pass_manager.py b/vllm_ascend/compilation/pass_manager.py new file mode 100644 index 00000000000..a21667c9af0 --- /dev/null +++ b/vllm_ascend/compilation/pass_manager.py @@ -0,0 +1,55 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from torch import fx as fx +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig + + +class GraphRewritePassManager: + """ + A pass manager for graph rewriting passes. + It handles the configuration and execution of passes. + The counterpart in vllm is PostGradPassManager. Since torch_npu does not + support inductor and triton for now, we choose to adopt the graph rewriter on + fx graph rather than the inductor pass manager. + """ + + def __init__(self): + self.passes: list[VllmInductorPass] = [] + + def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: + for pass_ in self.passes: + if pass_.is_applicable(**kwargs): + pass_(graph) + graph.recompile() + return graph + + def add(self, pass_: VllmInductorPass): + assert isinstance(pass_, VllmInductorPass) + self.passes.append(pass_) + + def configure(self, config: VllmConfig): + # By default, we enable the graph rewriter and quantization fusion pass. + self.ascend_compilation_config: dict = config.additional_config.get( + "ascend_compilation_config", {}) + if self.ascend_compilation_config.get("enable_quantization_fusion", + True): + from .quant_fusion_pass import AscendQuantFusionPass + self.passes.append(AscendQuantFusionPass(config)) + # Add more passes here as needed \ No newline at end of file diff --git a/vllm_ascend/compilation/quant_fusion_pass.py b/vllm_ascend/compilation/quant_fusion_pass.py new file mode 100644 index 00000000000..e76970060da --- /dev/null +++ b/vllm_ascend/compilation/quant_fusion_pass.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Callable, List, Tuple + +import torch +from torch.fx.subgraph_rewriter import replace_pattern +from vllm.compilation.vllm_inductor_pass import VllmInductorPass + + +class AddRMSNormQuantPattern: + + def __init__(self, vllm_config): + self.vllm_config = vllm_config + + def register(self, patterns: List[Tuple[Callable, Callable]]): + + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): + """ + Pattern for AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, 1e-6) + out0 = output[0] + out1 = output[2] + quantized_output = torch.ops.npu.npu_quantize( + out0, scale, offset, torch.qint8, -1, False) + return quantized_output, out1 + + def replace(rms_norm_input, residual, rms_norm_weight, scale, offset): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + 1. / + scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + offset, + epsilon=1e-6) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + patterns.append((pattern, replace)) + + +class AscendQuantFusionPass(VllmInductorPass): + """ + A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. + """ + + def __init__(self, vllm_config): + super().__init__(vllm_config) + self.patterns: List[Tuple[Callable, Callable]] = [] + # Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list + AddRMSNormQuantPattern(vllm_config).register(self.patterns) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + for pattern, replace in self.patterns: + replace_pattern(graph, pattern, replace) + self.end_and_log() + + def is_applicable(self, **kwargs): + """ + Check if the pass is applicable for the current configuration. + """ + arg_dtypes = kwargs.get("arg_dtypes", None) + if arg_dtypes is None: + return False + # We assume the first tensor's dtype is the data type of this model, update this solution when there is + # better solution. + dtype = arg_dtypes[0] if isinstance( + arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes + # We found that the kernel npu_add_rms_norm_quant accept varying data format for different dtypes, therefore, we only + # provide the solution on bfloat16 here. + return dtype in (torch.bfloat16, ) \ No newline at end of file diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index da5051c0aad..606d90777c5 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -23,85 +23,47 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm -def _addrmsnorm_forward_oot( - self, - x: torch.Tensor, - residual: torch.Tensor, - layer: Optional[torch.nn.Module] = None, - bias: Optional[torch.nn.Parameter] = None, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - import torch_npu - - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - - if layer is not None and get_ascend_device_type( - ) != AscendDeviceType._310P: - layer_cls_name = layer.__class__.__name__ - try: - weight_prefetch_method = get_forward_context( - ).weight_prefetch_method - except AssertionError: - weight_prefetch_method = None - - # prefetch qkvo_proj.weight preprocess - if weight_prefetch_method: - weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( - layer_cls_name=layer_cls_name, - weight=layer.weight, - start_flag=x, - ) - # add_rms_norm_quant - x, _, residual = torch_npu.npu_add_rms_norm_quant( - x, - residual, - self.weight, - layer.aclnn_input_scale, - layer.aclnn_input_offset, - beta=bias, - epsilon=self.variance_epsilon) - - # prefetch qkvo_proj.weight postprocess - if weight_prefetch_method: - weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( - layer_cls_name=layer_cls_name, - stop_flag=x, - ) - - else: - if get_ascend_device_type() == AscendDeviceType._310P: - orig_dtype = residual.dtype - x = x + residual.to(x.dtype) - residual = x.to(orig_dtype) - x, _ = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) - else: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon) - if bias is not None: - x.add_(bias) - torch.ops.vllm.maybe_wait_prefetch_done(x) - return x, residual +# class AddRMSNormW8A8Quant(RMSNorm): +# # Fuse AddRmsNorm and W8A8 quantization ops together + +# def __init__( +# self, +# hidden_size: int, +# layer: torch.nn.Module, +# eps: float = 1e-6, +# var_hidden_size: Optional[int] = None, +# has_weight: bool = True, +# dtype: Optional[torch.dtype] = None, +# ) -> None: +# super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) +# self.layer = layer + +# def forward( +# self, +# x: torch.Tensor, +# residual: Optional[torch.Tensor] = None, +# ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: +# import torch_npu + +# if residual is not None: +# residual = torch.ops.vllm.maybe_chunk_residual(x, residual) +# assert x.size(0) == residual.size(0) +# x, _, residual = torch_npu.npu_add_rms_norm_quant( +# x, +# residual, +# self.weight, +# self.layer.aclnn_input_scale, +# self.layer.aclnn_input_offset, +# epsilon=self.variance_epsilon) +# torch.ops.vllm.maybe_wait_prefetch_done(x) +# return x, residual + +# x, residual = torch_npu.npu_rms_norm(x, self.weight, +# self.variance_epsilon) +# return x class AscendRMSNorm(RMSNorm): - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - vllm_config = get_current_vllm_config() - self.bias = None - # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - def forward_oot( self, x: torch.Tensor, @@ -109,59 +71,28 @@ def forward_oot( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu + from vllm_ascend.utils import is_310p if residual is not None: residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) - x, residual = _addrmsnorm_forward_oot( - self, x, residual, self.next_need_quant_fusion_linear, - self.bias) + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual + x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) if self.bias is not None: x.add_(self.bias) return x - @property - def next_need_quant_fusion_linear(self): - try: - forward_context = get_forward_context() - if not forward_context.addrmsnorm_quant_fusion_enabled or \ - forward_context.layer_idx == forward_context.num_hidden_layers: - return None - except AssertionError: - return None - - next_linear = None - model_instance = forward_context.model_instance - layer_idx = forward_context.layer_idx - fusion_linear = forward_context.fusion_linear - next_linear = None - if fusion_linear == "qkv_dense": - next_linear = model_instance.model.layers[ - layer_idx].self_attn.qkv_proj - forward_context.fusion_linear = "gate_up_dense" - elif fusion_linear == "gate_up_dense": - next_linear = model_instance.model.layers[ - layer_idx].mlp.gate_up_proj - forward_context.fusion_linear = "qkv_dense" - # if prefetch_mlp_weight enabled, following accumulation operation - # does not need to be repeated - if not forward_context.prefetch_mlp_enabled: - forward_context.layer_idx += 1 - elif fusion_linear == "qkv_moe": - next_linear = model_instance.model.layers[ - layer_idx].self_attn.qkv_proj - forward_context.fusion_linear = "gate_moe" - elif fusion_linear == "gate_moe": - forward_context.fusion_linear = "qkv_moe" - forward_context.layer_idx += 1 - from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod - if next_linear is not None and \ - not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): - next_linear = None - return next_linear - class AscendQuantRMSNorm(AscendRMSNorm): diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f59d1ed1e50..ac8d9681a8d 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -30,6 +30,7 @@ init_ascend_config) from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) +from vllm_ascend.compilation.compiler_interface import AscendAdaptor # isort: off from vllm_ascend.utils import ( From 2695e37df6d11beca8a4514c935ede32d32ec37b Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Tue, 14 Oct 2025 08:39:39 +0000 Subject: [PATCH 02/25] fix Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/compilation/acl_graph.py | 8 +++- vllm_ascend/compilation/compiler_interface.py | 1 + ...nager.py => graph_rewrite_pass_manager.py} | 0 vllm_ascend/ops/layernorm.py | 40 ------------------- 4 files changed, 8 insertions(+), 41 deletions(-) rename vllm_ascend/compilation/{pass_manager.py => graph_rewrite_pass_manager.py} (100%) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 025ff3c12ca..9c4e33b0e66 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -65,6 +65,11 @@ def __init__(self, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config + self.ascend_compilation_config: dict = vllm_config.additional_config.get( + "ascend_compilation_config", {}) + self.fx_graph_eager = self.ascend_compilation_config.get( + "fx_graph_eager", False) + self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -101,7 +106,8 @@ def __call__(self, *args, **kwargs): aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ - aclgraph_runtime_mode != self.runtime_mode: + aclgraph_runtime_mode != self.runtime_mode or \ + self.fx_graph_eager: # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without aclgraphs. # We do not trigger capture/replay if the runtime mode is not diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index bcfa300a8ef..3b25e86f8ee 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -67,5 +67,6 @@ def compile( } graph = current_pass_manager(graph, **kwargs) compilation_counter.num_eager_compiles += 1 + return graph, None \ No newline at end of file diff --git a/vllm_ascend/compilation/pass_manager.py b/vllm_ascend/compilation/graph_rewrite_pass_manager.py similarity index 100% rename from vllm_ascend/compilation/pass_manager.py rename to vllm_ascend/compilation/graph_rewrite_pass_manager.py diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 606d90777c5..3bec6aecf39 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -23,46 +23,6 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm -# class AddRMSNormW8A8Quant(RMSNorm): -# # Fuse AddRmsNorm and W8A8 quantization ops together - -# def __init__( -# self, -# hidden_size: int, -# layer: torch.nn.Module, -# eps: float = 1e-6, -# var_hidden_size: Optional[int] = None, -# has_weight: bool = True, -# dtype: Optional[torch.dtype] = None, -# ) -> None: -# super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) -# self.layer = layer - -# def forward( -# self, -# x: torch.Tensor, -# residual: Optional[torch.Tensor] = None, -# ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: -# import torch_npu - -# if residual is not None: -# residual = torch.ops.vllm.maybe_chunk_residual(x, residual) -# assert x.size(0) == residual.size(0) -# x, _, residual = torch_npu.npu_add_rms_norm_quant( -# x, -# residual, -# self.weight, -# self.layer.aclnn_input_scale, -# self.layer.aclnn_input_offset, -# epsilon=self.variance_epsilon) -# torch.ops.vllm.maybe_wait_prefetch_done(x) -# return x, residual - -# x, residual = torch_npu.npu_rms_norm(x, self.weight, -# self.variance_epsilon) -# return x - - class AscendRMSNorm(RMSNorm): def forward_oot( self, From a9f1d3c68b43c6e7a8926123c0458f5e6f499ab8 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Tue, 14 Oct 2025 10:52:45 +0000 Subject: [PATCH 03/25] format file Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 11 +++++++---- vllm_ascend/compilation/compiler_interface.py | 5 ++--- vllm_ascend/compilation/graph_rewrite_pass_manager.py | 4 ++-- vllm_ascend/compilation/quant_fusion_pass.py | 11 ++++++----- vllm_ascend/ops/layernorm.py | 1 + vllm_ascend/platform.py | 1 + 6 files changed, 19 insertions(+), 14 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d418aa39dec..4f288acced2 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -37,9 +37,11 @@ def __init__(self, vllm_config): torchair_graph_config = additional_config.get("torchair_graph_config", {}) self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config) - - ascend_compilation_config = additional_config.get("ascend_compilation_config", {}) - self.ascend_compilation_config = AscendCompilationConfig(**ascend_compilation_config) + + ascend_compilation_config = additional_config.get( + "ascend_compilation_config", {}) + self.ascend_compilation_config = AscendCompilationConfig( + **ascend_compilation_config) ascend_scheduler_config = additional_config.get( "ascend_scheduler_config", {}) @@ -139,6 +141,7 @@ def __init__(self, vllm_config): raise AssertionError( "Only support P node tp size lagger then D node tp size") + class AscendCompilationConfig: """ Configuration Object for ascend_compilation_config from additional_config @@ -347,7 +350,7 @@ def check_ascend_config(vllm_config, enforce_eager): logger.info( "Quantization fusion enabled! op fusion on quantization are expected. " ) - + if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type if "qwen" not in model_type: diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 3b25e86f8ee..5745b8e18c8 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -45,9 +45,10 @@ def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: shape_list.append(value.shape) return shape_list + class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" - + def compile( self, graph: fx.GraphModule, @@ -68,5 +69,3 @@ def compile( graph = current_pass_manager(graph, **kwargs) compilation_counter.num_eager_compiles += 1 return graph, None - - \ No newline at end of file diff --git a/vllm_ascend/compilation/graph_rewrite_pass_manager.py b/vllm_ascend/compilation/graph_rewrite_pass_manager.py index a21667c9af0..f1ed7d357e2 100644 --- a/vllm_ascend/compilation/graph_rewrite_pass_manager.py +++ b/vllm_ascend/compilation/graph_rewrite_pass_manager.py @@ -43,7 +43,7 @@ def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: def add(self, pass_: VllmInductorPass): assert isinstance(pass_, VllmInductorPass) self.passes.append(pass_) - + def configure(self, config: VllmConfig): # By default, we enable the graph rewriter and quantization fusion pass. self.ascend_compilation_config: dict = config.additional_config.get( @@ -52,4 +52,4 @@ def configure(self, config: VllmConfig): True): from .quant_fusion_pass import AscendQuantFusionPass self.passes.append(AscendQuantFusionPass(config)) - # Add more passes here as needed \ No newline at end of file + # Add more passes here as needed diff --git a/vllm_ascend/compilation/quant_fusion_pass.py b/vllm_ascend/compilation/quant_fusion_pass.py index e76970060da..ad72ff8a70b 100644 --- a/vllm_ascend/compilation/quant_fusion_pass.py +++ b/vllm_ascend/compilation/quant_fusion_pass.py @@ -21,7 +21,7 @@ import torch from torch.fx.subgraph_rewriter import replace_pattern from vllm.compilation.vllm_inductor_pass import VllmInductorPass - + class AddRMSNormQuantPattern: @@ -34,7 +34,8 @@ def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, 1e-6) + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, 1e-6) out0 = output[0] out1 = output[2] quantized_output = torch.ops.npu.npu_quantize( @@ -70,11 +71,11 @@ def __init__(self, vllm_config): self.patterns: List[Tuple[Callable, Callable]] = [] # Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list AddRMSNormQuantPattern(vllm_config).register(self.patterns) - + def __call__(self, graph: torch.fx.Graph): self.begin() for pattern, replace in self.patterns: - replace_pattern(graph, pattern, replace) + replace_pattern(graph, pattern, replace) self.end_and_log() def is_applicable(self, **kwargs): @@ -90,4 +91,4 @@ def is_applicable(self, **kwargs): arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes # We found that the kernel npu_add_rms_norm_quant accept varying data format for different dtypes, therefore, we only # provide the solution on bfloat16 here. - return dtype in (torch.bfloat16, ) \ No newline at end of file + return dtype in (torch.bfloat16, ) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 3bec6aecf39..614850c983f 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -24,6 +24,7 @@ class AscendRMSNorm(RMSNorm): + def forward_oot( self, x: torch.Tensor, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index ac8d9681a8d..78a571156f2 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -226,6 +226,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: compilation_config.use_inductor = False compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) + compilation_config.oot_compiler = AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( From f111d704d3d9c57082bbe0bc11b9bd80e5afd762 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Fri, 31 Oct 2025 02:56:33 +0000 Subject: [PATCH 04/25] Change to graph fusion Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 4 +-- vllm_ascend/compilation/compiler_interface.py | 33 ++++++++++------- ...anager.py => graph_fusion_pass_manager.py} | 2 +- vllm_ascend/compilation/quant_fusion_pass.py | 36 ++++++++++++------- vllm_ascend/platform.py | 4 +++ 5 files changed, 51 insertions(+), 28 deletions(-) rename vllm_ascend/compilation/{graph_rewrite_pass_manager.py => graph_fusion_pass_manager.py} (98%) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 4f288acced2..907cf2a988e 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -148,11 +148,11 @@ class AscendCompilationConfig: """ def __init__(self, - enable_graph_rewriter: bool = True, + enable_graph_fusion: bool = True, fx_graph_eager: bool = False, enable_quantization_fusion: bool = True, **kwargs): - self.enable_graph_rewriter = enable_graph_rewriter + self.enable_graph_fusion = enable_graph_fusion self.fx_graph_eager = fx_graph_eager self.enable_quantization_fusion = enable_quantization_fusion # Add more compilation related configs here as needed diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 5745b8e18c8..bc5ada2aacf 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -22,7 +22,9 @@ import torch.fx as fx from vllm.compilation.compiler_interface import CompilerInterface from vllm.compilation.counter import compilation_counter - +from torch._dynamo.backends.common import aot_autograd +from torch._inductor.decomposition import select_decomp_table +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: """ @@ -57,15 +59,22 @@ def compile( runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: + def compile_inner(graph, example_inputs): + current_pass_manager = compiler_config["graph_fusion_manager"] + arg_dtypes = get_dtype_from_args(example_inputs) + arg_shapes = get_shapes_from_args(example_inputs) + kwargs = { + "runtime_shape": runtime_shape, + "arg_shapes": arg_shapes, + "arg_dtypes": arg_dtypes + } + graph = current_pass_manager(graph, **kwargs) + graph = decompose_auto_functionalized(graph) + compilation_counter.num_eager_compiles += 1 + return graph, None - current_pass_manager = compiler_config["graph_rewriter_manager"] - arg_dtypes = get_dtype_from_args(example_inputs) - arg_shapes = get_shapes_from_args(example_inputs) - kwargs = { - "runtime_shape": runtime_shape, - "arg_shapes": arg_shapes, - "arg_dtypes": arg_dtypes - } - graph = current_pass_manager(graph, **kwargs) - compilation_counter.num_eager_compiles += 1 - return graph, None + # Use the default decomposition table to decompose operators. + decompositions = select_decomp_table() + + # Use AOT Autograd to handle the forward compilation. + return aot_autograd(fw_compiler=compile_inner, decompositions=decompositions)(graph, example_inputs), None diff --git a/vllm_ascend/compilation/graph_rewrite_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py similarity index 98% rename from vllm_ascend/compilation/graph_rewrite_pass_manager.py rename to vllm_ascend/compilation/graph_fusion_pass_manager.py index f1ed7d357e2..86129aef62f 100644 --- a/vllm_ascend/compilation/graph_rewrite_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -21,7 +21,7 @@ from vllm.config import VllmConfig -class GraphRewritePassManager: +class AscendFusionPassManager: """ A pass manager for graph rewriting passes. It handles the configuration and execution of passes. diff --git a/vllm_ascend/compilation/quant_fusion_pass.py b/vllm_ascend/compilation/quant_fusion_pass.py index ad72ff8a70b..a22d63458da 100644 --- a/vllm_ascend/compilation/quant_fusion_pass.py +++ b/vllm_ascend/compilation/quant_fusion_pass.py @@ -15,21 +15,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from typing import Callable, List, Tuple - import torch -from torch.fx.subgraph_rewriter import replace_pattern from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from torch._inductor.pattern_matcher import PatternMatcherPass +import torch._inductor.pattern_matcher as pm class AddRMSNormQuantPattern: def __init__(self, vllm_config): self.vllm_config = vllm_config + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.tensor([1.0], device="npu") + offset = torch.tensor([0.0], device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, offset] - def register(self, patterns: List[Tuple[Callable, Callable]]): - + def register(self, pm_pass: PatternMatcherPass): + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): """ Pattern for AddRMSNormQuant fusion. @@ -42,7 +51,7 @@ def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replace(rms_norm_input, residual, rms_norm_weight, scale, offset): + def replacement(rms_norm_input, residual, rms_norm_weight, scale, offset): """ Replacement for the AddRMSNormQuant fusion. """ @@ -57,8 +66,8 @@ def replace(rms_norm_input, residual, rms_norm_weight, scale, offset): quantized_output = output[0] out1 = output[2] return quantized_output, out1 - - patterns.append((pattern, replace)) + + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) class AscendQuantFusionPass(VllmInductorPass): @@ -68,15 +77,16 @@ class AscendQuantFusionPass(VllmInductorPass): def __init__(self, vllm_config): super().__init__(vllm_config) - self.patterns: List[Tuple[Callable, Callable]] = [] + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rmsnorm_quant_fusion_pass" + ) # Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list AddRMSNormQuantPattern(vllm_config).register(self.patterns) def __call__(self, graph: torch.fx.Graph): self.begin() - for pattern, replace in self.patterns: - replace_pattern(graph, pattern, replace) - self.end_and_log() + matched_count = self.patterns.apply(graph) + self.end_and_log() def is_applicable(self, **kwargs): """ diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 78a571156f2..950ee48f685 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -66,6 +66,10 @@ class NPUPlatform(Platform): def is_sleep_mode_available(self) -> bool: return True + + @property + def pass_key(self) -> str: + return "graph_fusion_manager" @classmethod def pre_register_and_update(cls, From 1b43fc2a88b377fb0205d38f5200e4726e4fbe8a Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Fri, 31 Oct 2025 06:29:10 +0000 Subject: [PATCH 05/25] tiny fix Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 4 ++-- vllm_ascend/compilation/compiler_interface.py | 9 ++------- vllm_ascend/compilation/graph_fusion_pass_manager.py | 4 ++-- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 907cf2a988e..2902500f828 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -341,9 +341,9 @@ def check_ascend_config(vllm_config, enforce_eager): # aclgraph case else: # This graph fusion can actually works on eager mode. - if ascend_config.ascend_compilation_config.enable_graph_rewriter: + if ascend_config.ascend_compilation_config.enable_graph_fusion: logger.info( - "Graph rewriter enabled! Automatic kernel fusion is expected." + "graph fusion enabled! Automatic kernel fusion is expected." ) if ascend_config.ascend_compilation_config.enable_quantization_fusion: diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index bc5ada2aacf..41039669310 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -69,12 +69,7 @@ def compile_inner(graph, example_inputs): "arg_dtypes": arg_dtypes } graph = current_pass_manager(graph, **kwargs) - graph = decompose_auto_functionalized(graph) - compilation_counter.num_eager_compiles += 1 - return graph, None + return graph - # Use the default decomposition table to decompose operators. - decompositions = select_decomp_table() - # Use AOT Autograd to handle the forward compilation. - return aot_autograd(fw_compiler=compile_inner, decompositions=decompositions)(graph, example_inputs), None + return aot_autograd(fw_compiler=compile_inner)(graph, example_inputs), None diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 86129aef62f..1254a0de2cb 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -21,7 +21,7 @@ from vllm.config import VllmConfig -class AscendFusionPassManager: +class GraphFusionPassManager: """ A pass manager for graph rewriting passes. It handles the configuration and execution of passes. @@ -37,7 +37,7 @@ def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: for pass_ in self.passes: if pass_.is_applicable(**kwargs): pass_(graph) - graph.recompile() + # graph.recompile() # 这句话是必写吗 return graph def add(self, pass_: VllmInductorPass): From 4d78271b03569acea9727c24f69f24200dbc0f13 Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Tue, 4 Nov 2025 11:52:45 +0000 Subject: [PATCH 06/25] tiny fix Signed-off-by: Icey <1790571317@qq.com> --- vllm_ascend/ops/layernorm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 614850c983f..025e4fae86d 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -34,8 +34,6 @@ def forward_oot( from vllm_ascend.utils import is_310p if residual is not None: - residual = torch.ops.vllm.maybe_chunk_residual(x, residual) - assert x.size(0) == residual.size(0) if is_310p(): orig_dtype = residual.dtype x = x + residual.to(x.dtype) @@ -45,7 +43,6 @@ def forward_oot( else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) - torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, From 8699caebbd119d6a913b5d26c763943f1318ad0b Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 12 Nov 2025 09:33:16 +0000 Subject: [PATCH 07/25] fix graph output Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 63 ++++++++++++++++--- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 41039669310..a094723c7af 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -17,14 +17,13 @@ # from typing import Any, Callable, Optional - +import functools import torch import torch.fx as fx from vllm.compilation.compiler_interface import CompilerInterface -from vllm.compilation.counter import compilation_counter from torch._dynamo.backends.common import aot_autograd -from torch._inductor.decomposition import select_decomp_table -from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized +from torch._inductor.utils import output_node +import torch.utils._pytree as pytree def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: """ @@ -48,9 +47,47 @@ def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: return shape_list +def graph_returns_tuple(gm: fx.GraphModule) -> bool: + """True if a FX graph returns a tuple""" + if not isinstance(gm, fx.GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: fx.GraphModule, +) -> tuple[Any, fx.GraphModule]: + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + Returns: + spec: The original output structure specification + gm: The modified GraphModule that returns a tuple + """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + return spec, gm + + class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" - def compile( self, graph: fx.GraphModule, @@ -71,5 +108,17 @@ def compile_inner(graph, example_inputs): graph = current_pass_manager(graph, **kwargs) return graph - # Use AOT Autograd to handle the forward compilation. - return aot_autograd(fw_compiler=compile_inner)(graph, example_inputs), None + if not graph_returns_tuple(graph): + spec, graph = make_graph_return_tuple(graph) + else: + spec = None + + compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, example_inputs) + + if spec is not None: + @functools.wraps(compiled_fn) + def wrapper(*args, **kwargs): + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + return wrapper, None + else: + return compiled_fn, None \ No newline at end of file From 46f6ba153b25b5e0f084d7a41d13f7050a1d9176 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 13 Nov 2025 07:05:28 +0000 Subject: [PATCH 08/25] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 43 +++++++++++-------- .../compilation/graph_fusion_pass_manager.py | 1 - vllm_ascend/compilation/quant_fusion_pass.py | 20 ++++----- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index a094723c7af..31825ae68e9 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -16,14 +16,16 @@ # limitations under the License. # -from typing import Any, Callable, Optional import functools +from typing import Any, Callable, Optional + import torch import torch.fx as fx -from vllm.compilation.compiler_interface import CompilerInterface +import torch.utils._pytree as pytree from torch._dynamo.backends.common import aot_autograd from torch._inductor.utils import output_node -import torch.utils._pytree as pytree +from vllm.compilation.compiler_interface import CompilerInterface + def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: """ @@ -51,23 +53,20 @@ def graph_returns_tuple(gm: fx.GraphModule) -> bool: """True if a FX graph returns a tuple""" if not isinstance(gm, fx.GraphModule): return True # can't check this, assume true - (rv,) = output_node(gm).args + (rv, ) = output_node(gm).args if isinstance(rv, (list, tuple)): return True - if ( - isinstance(rv, torch.fx.node.Node) - and hasattr(rv.target, "_schema") - and len(rv.target._schema.returns) > 1 - and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) - ): + if (isinstance(rv, torch.fx.node.Node) and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 and all( + str(ret.type) == "Tensor" + for ret in rv.target._schema.returns)): # for graphs whose result is one node with multiple outputs return True return False - - + + def make_graph_return_tuple( - gm: fx.GraphModule, -) -> tuple[Any, fx.GraphModule]: + gm: fx.GraphModule, ) -> tuple[Any, fx.GraphModule]: """ Mutate gm so it returns a tuple. This is only needed for graphs not created by torchdynamo that return non-tuples. @@ -76,18 +75,19 @@ def make_graph_return_tuple( gm: The modified GraphModule that returns a tuple """ node = output_node(gm) - (rv,) = node.args + (rv, ) = node.args rv, spec = pytree.tree_flatten(rv) with gm.graph.inserting_before(node): gm.graph.output(rv) gm.graph.erase_node(node) assert graph_returns_tuple(gm) - + return spec, gm class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" + def compile( self, graph: fx.GraphModule, @@ -96,6 +96,7 @@ def compile( runtime_shape: Optional[int] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: + def compile_inner(graph, example_inputs): current_pass_manager = compiler_config["graph_fusion_manager"] arg_dtypes = get_dtype_from_args(example_inputs) @@ -113,12 +114,16 @@ def compile_inner(graph, example_inputs): else: spec = None - compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, example_inputs) + compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, + example_inputs) if spec is not None: + @functools.wraps(compiled_fn) def wrapper(*args, **kwargs): - return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), + spec) + return wrapper, None else: - return compiled_fn, None \ No newline at end of file + return compiled_fn, None diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 1254a0de2cb..bff150c9cb4 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -37,7 +37,6 @@ def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: for pass_ in self.passes: if pass_.is_applicable(**kwargs): pass_(graph) - # graph.recompile() # 这句话是必写吗 return graph def add(self, pass_: VllmInductorPass): diff --git a/vllm_ascend/compilation/quant_fusion_pass.py b/vllm_ascend/compilation/quant_fusion_pass.py index a22d63458da..dfd35fb5088 100644 --- a/vllm_ascend/compilation/quant_fusion_pass.py +++ b/vllm_ascend/compilation/quant_fusion_pass.py @@ -16,16 +16,16 @@ # limitations under the License. # import torch -from vllm.compilation.vllm_inductor_pass import VllmInductorPass -from torch._inductor.pattern_matcher import PatternMatcherPass import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import PatternMatcherPass +from vllm.compilation.vllm_inductor_pass import VllmInductorPass class AddRMSNormQuantPattern: def __init__(self, vllm_config): self.vllm_config = vllm_config - + def get_inputs(self): """ Generate example inputs for the AddRMSNormQuant fusion pattern. @@ -38,7 +38,7 @@ def get_inputs(self): return [rms_norm_input, residual, rms_norm_weight, scale, offset] def register(self, pm_pass: PatternMatcherPass): - + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): """ Pattern for AddRMSNormQuant fusion. @@ -51,7 +51,8 @@ def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input, residual, rms_norm_weight, scale, offset): + def replacement(rms_norm_input, residual, rms_norm_weight, scale, + offset): """ Replacement for the AddRMSNormQuant fusion. """ @@ -66,8 +67,9 @@ def replacement(rms_norm_input, residual, rms_norm_weight, scale, offset): quantized_output = output[0] out1 = output[2] return quantized_output, out1 - - pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) class AscendQuantFusionPass(VllmInductorPass): @@ -78,9 +80,7 @@ class AscendQuantFusionPass(VllmInductorPass): def __init__(self, vllm_config): super().__init__(vllm_config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass" - ) - # Register the AddRMSNormQuant fusion pattern into the graph rewriter pattern list + pass_name="rmsnorm_quant_fusion_pass") AddRMSNormQuantPattern(vllm_config).register(self.patterns) def __call__(self, graph: torch.fx.Graph): From 4c4e8483766ca611beef420a1f99df477470f98d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 13 Nov 2025 12:58:11 +0000 Subject: [PATCH 09/25] remove auto-grad Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 77 +++---------------- 1 file changed, 10 insertions(+), 67 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 31825ae68e9..ee57afec3e3 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -16,7 +16,6 @@ # limitations under the License. # -import functools from typing import Any, Callable, Optional import torch @@ -49,42 +48,6 @@ def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: return shape_list -def graph_returns_tuple(gm: fx.GraphModule) -> bool: - """True if a FX graph returns a tuple""" - if not isinstance(gm, fx.GraphModule): - return True # can't check this, assume true - (rv, ) = output_node(gm).args - if isinstance(rv, (list, tuple)): - return True - if (isinstance(rv, torch.fx.node.Node) and hasattr(rv.target, "_schema") - and len(rv.target._schema.returns) > 1 and all( - str(ret.type) == "Tensor" - for ret in rv.target._schema.returns)): - # for graphs whose result is one node with multiple outputs - return True - return False - - -def make_graph_return_tuple( - gm: fx.GraphModule, ) -> tuple[Any, fx.GraphModule]: - """ - Mutate gm so it returns a tuple. This is only needed for graphs - not created by torchdynamo that return non-tuples. - Returns: - spec: The original output structure specification - gm: The modified GraphModule that returns a tuple - """ - node = output_node(gm) - (rv, ) = node.args - rv, spec = pytree.tree_flatten(rv) - with gm.graph.inserting_before(node): - gm.graph.output(rv) - gm.graph.erase_node(node) - assert graph_returns_tuple(gm) - - return spec, gm - - class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" @@ -97,33 +60,13 @@ def compile( key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: - def compile_inner(graph, example_inputs): - current_pass_manager = compiler_config["graph_fusion_manager"] - arg_dtypes = get_dtype_from_args(example_inputs) - arg_shapes = get_shapes_from_args(example_inputs) - kwargs = { - "runtime_shape": runtime_shape, - "arg_shapes": arg_shapes, - "arg_dtypes": arg_dtypes - } - graph = current_pass_manager(graph, **kwargs) - return graph - - if not graph_returns_tuple(graph): - spec, graph = make_graph_return_tuple(graph) - else: - spec = None - - compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, - example_inputs) - - if spec is not None: - - @functools.wraps(compiled_fn) - def wrapper(*args, **kwargs): - return pytree.tree_unflatten(compiled_fn(*args, **kwargs), - spec) - - return wrapper, None - else: - return compiled_fn, None + current_pass_manager = compiler_config["graph_fusion_manager"] + arg_dtypes = get_dtype_from_args(example_inputs) + arg_shapes = get_shapes_from_args(example_inputs) + kwargs = { + "runtime_shape": runtime_shape, + "arg_shapes": arg_shapes, + "arg_dtypes": arg_dtypes + } + graph = current_pass_manager(graph, **kwargs) + return graph, None \ No newline at end of file From 444d8b177bd1b016111a887639d2a8373853f9df Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 18 Nov 2025 01:29:15 +0000 Subject: [PATCH 10/25] recover autograd Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/compiler_interface.py | 77 ++++++++++++++++--- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index ee57afec3e3..e60eeb31e98 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import functools from typing import Any, Callable, Optional import torch @@ -48,6 +48,41 @@ def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: return shape_list +def graph_returns_tuple(gm: fx.GraphModule) -> bool: + """True if a FX graph returns a tuple""" + if not isinstance(gm, fx.GraphModule): + return True # can't check this, assume true + (rv, ) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if (isinstance(rv, torch.fx.node.Node) and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 and all( + str(ret.type) == "Tensor" + for ret in rv.target._schema.returns)): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: fx.GraphModule, ) -> tuple[Any, fx.GraphModule]: + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + Returns: + spec: The original output structure specification + gm: The modified GraphModule that returns a tuple + """ + node = output_node(gm) + (rv, ) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + return spec, gm + class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" @@ -60,13 +95,33 @@ def compile( key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: - current_pass_manager = compiler_config["graph_fusion_manager"] - arg_dtypes = get_dtype_from_args(example_inputs) - arg_shapes = get_shapes_from_args(example_inputs) - kwargs = { - "runtime_shape": runtime_shape, - "arg_shapes": arg_shapes, - "arg_dtypes": arg_dtypes - } - graph = current_pass_manager(graph, **kwargs) - return graph, None \ No newline at end of file + def compile_inner(graph, example_inputs): + current_pass_manager = compiler_config["graph_fusion_manager"] + arg_dtypes = get_dtype_from_args(example_inputs) + arg_shapes = get_shapes_from_args(example_inputs) + kwargs = { + "runtime_shape": runtime_shape, + "arg_shapes": arg_shapes, + "arg_dtypes": arg_dtypes + } + graph = current_pass_manager(graph, **kwargs) + return graph + + if not graph_returns_tuple(graph): + spec, graph = make_graph_return_tuple(graph) + else: + spec = None + + compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, + example_inputs) + + if spec is not None: + + @functools.wraps(compiled_fn) + def wrapper(*args, **kwargs): + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), + spec) + + return wrapper, None + else: + return compiled_fn, None \ No newline at end of file From 2b4922c6833c7c1e8eb108097ef34ebc435a4831 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 20 Nov 2025 04:56:37 +0000 Subject: [PATCH 11/25] resolve conflict Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 14 ++++++++------ vllm_ascend/ops/layernorm.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 2902500f828..d2b420ee7f8 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -36,12 +36,9 @@ def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} torchair_graph_config = additional_config.get("torchair_graph_config", {}) - self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config) - - ascend_compilation_config = additional_config.get( - "ascend_compilation_config", {}) - self.ascend_compilation_config = AscendCompilationConfig( - **ascend_compilation_config) + + self.torchair_graph_config = TorchairGraphConfig( + torchair_graph_config, vllm_config, additional_config) ascend_scheduler_config = additional_config.get( "ascend_scheduler_config", {}) @@ -140,6 +137,11 @@ def __init__(self, vllm_config): if self.pd_tp_ratio == 0: raise AssertionError( "Only support P node tp size lagger then D node tp size") + self.SLO_limits_for_dynamic_batch = additional_config.get( + "SLO_limits_for_dynamic_batch", -1) + from vllm_ascend.utils import \ + get_flashcomm2_oproj_tp_size_and_validate_config + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(self, vllm_config) class AscendCompilationConfig: diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 025e4fae86d..b65647f5622 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -24,6 +24,23 @@ class AscendRMSNorm(RMSNorm): + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + vllm_config = get_current_vllm_config() + self.bias = None + # quantization with anti_method m4 will generate none-zero norm bias + if vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) def forward_oot( self, From 1445020e52828eced1157d7dcc87db65dfc9df80 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 20 Nov 2025 06:59:46 +0000 Subject: [PATCH 12/25] adapot vllm change Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 5 +++++ vllm_ascend/platform.py | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index d2b420ee7f8..1fed2bf4984 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -39,6 +39,11 @@ def __init__(self, vllm_config): self.torchair_graph_config = TorchairGraphConfig( torchair_graph_config, vllm_config, additional_config) + + ascend_compilation_config = additional_config.get( + "ascend_compilation_config", {}) + self.ascend_compilation_config = AscendCompilationConfig( + **ascend_compilation_config) ascend_scheduler_config = additional_config.get( "ascend_scheduler_config", {}) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 950ee48f685..4e4515677d5 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -70,6 +70,15 @@ def is_sleep_mode_available(self) -> bool: @property def pass_key(self) -> str: return "graph_fusion_manager" + + @classmethod + def get_pass_manager_cls(cls) -> str: + return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager" + + @classmethod + def get_compile_backend(self) -> str: + from vllm_ascend.compilation.compiler_interface import AscendAdaptor + return AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ @classmethod def pre_register_and_update(cls, From e4b290454bac4ff2a0bf320ed6c2f59503172bcb Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 26 Nov 2025 08:31:28 +0000 Subject: [PATCH 13/25] solve accuracy problem Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 42 +++++++-- vllm_ascend/compilation/acl_graph.py | 1 - vllm_ascend/compilation/compiler_interface.py | 89 ++++++++----------- .../compilation/graph_fusion_pass_manager.py | 13 ++- .../{ => passes}/quant_fusion_pass.py | 38 ++++---- vllm_ascend/ops/layernorm.py | 7 +- vllm_ascend/ops/rotary_embedding.py | 2 +- vllm_ascend/platform.py | 6 +- 8 files changed, 101 insertions(+), 97 deletions(-) rename vllm_ascend/compilation/{ => passes}/quant_fusion_pass.py (75%) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 1fed2bf4984..9ecb71d663d 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -36,10 +36,10 @@ def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} torchair_graph_config = additional_config.get("torchair_graph_config", {}) - + self.torchair_graph_config = TorchairGraphConfig( torchair_graph_config, vllm_config, additional_config) - + ascend_compilation_config = additional_config.get( "ascend_compilation_config", {}) self.ascend_compilation_config = AscendCompilationConfig( @@ -146,12 +146,17 @@ def __init__(self, vllm_config): "SLO_limits_for_dynamic_batch", -1) from vllm_ascend.utils import \ get_flashcomm2_oproj_tp_size_and_validate_config - self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(self, vllm_config) + self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( + self, vllm_config) class AscendCompilationConfig: """ - Configuration Object for ascend_compilation_config from additional_config + Configuration for controlling the behavior of Ascend graph optimization. + + This class provides a way to configure graph fusion optimizations. + These configurations directly impact the performance and behavior of models + deployed on Ascend platforms. """ def __init__(self, @@ -159,6 +164,27 @@ def __init__(self, fx_graph_eager: bool = False, enable_quantization_fusion: bool = True, **kwargs): + """ + Initialize the configuration. + + Args: + enable_graph_fusion (bool): Whether to enable graph fusion optimization. + When set to True, the system will attempt to fuse multiple operations + into more efficient computational graph structures to improve performance. + Default: True + + fx_graph_eager (bool): Whether to use eager mode for graph transformation. + When set to False, uses symbolic execution for graph transformation; + when set to True, uses eager execution mode. + Default: False + + enable_quantization_fusion (bool): Whether to enable quantization fusion optimization. + When set to True, the system will optimize quantization-related operations, + reducing the number of quantization/dequantization nodes. + Default: True + + **kwargs: Additional optional parameters for forward compatibility and configuration extension. + """ self.enable_graph_fusion = enable_graph_fusion self.fx_graph_eager = fx_graph_eager self.enable_quantization_fusion = enable_quantization_fusion @@ -353,10 +379,10 @@ def check_ascend_config(vllm_config, enforce_eager): "graph fusion enabled! Automatic kernel fusion is expected." ) - if ascend_config.ascend_compilation_config.enable_quantization_fusion: - logger.info( - "Quantization fusion enabled! op fusion on quantization are expected. " - ) + if ascend_config.ascend_compilation_config.enable_quantization_fusion: + logger.info( + "Quantization fusion enabled! op fusion on quantization are expected. " + ) if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 9c4e33b0e66..ea42dff5608 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -69,7 +69,6 @@ def __init__(self, "ascend_compilation_config", {}) self.fx_graph_eager = self.ascend_compilation_config.get( "fx_graph_eager", False) - self.graph_pool = graph_pool self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index e60eeb31e98..262e5e8c399 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -16,38 +16,19 @@ # limitations under the License. # import functools +from collections.abc import Sequence from typing import Any, Callable, Optional import torch import torch.fx as fx import torch.utils._pytree as pytree from torch._dynamo.backends.common import aot_autograd -from torch._inductor.utils import output_node +from torch._inductor.decomposition import select_decomp_table +from torch._inductor.utils import InputType, output_node +from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface -def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: - """ - Extract the dtype from the kwargs dictionary. - """ - dtype_list = [] - for value in args: - if isinstance(value, torch.Tensor): - dtype_list.append(value.dtype) - return dtype_list - - -def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: - """ - Extract the shapes from the kwargs dictionary. - """ - shape_list = [] - for value in args: - if isinstance(value, torch.Tensor): - shape_list.append(value.shape) - return shape_list - - def graph_returns_tuple(gm: fx.GraphModule) -> bool: """True if a FX graph returns a tuple""" if not isinstance(gm, fx.GraphModule): @@ -65,13 +46,13 @@ def graph_returns_tuple(gm: fx.GraphModule) -> bool: def make_graph_return_tuple( - gm: fx.GraphModule, ) -> tuple[Any, fx.GraphModule]: + gm: GraphModule, + inputs: Sequence[InputType], + compile_gm: Callable[..., Any], +) -> Callable[..., Any]: """ Mutate gm so it returns a tuple. This is only needed for graphs not created by torchdynamo that return non-tuples. - Returns: - spec: The original output structure specification - gm: The modified GraphModule that returns a tuple """ node = output_node(gm) (rv, ) = node.args @@ -81,7 +62,26 @@ def make_graph_return_tuple( gm.graph.erase_node(node) assert graph_returns_tuple(gm) - return spec, gm + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def compile_fx(model_: GraphModule, example_inputs_: list, + inner_compile: Callable, decompositions: dict) -> Callable: + recursive_compile_fx = functools.partial(compile_fx, + inner_compile=inner_compile, + decompositions=decompositions) + + if not graph_returns_tuple(model_): + return make_graph_return_tuple(model_, example_inputs_, + recursive_compile_fx) + return aot_autograd(fw_compiler=inner_compile)(model_, example_inputs_) + class AscendAdaptor(CompilerInterface): name = "AscendAdaptor" @@ -97,31 +97,16 @@ def compile( def compile_inner(graph, example_inputs): current_pass_manager = compiler_config["graph_fusion_manager"] - arg_dtypes = get_dtype_from_args(example_inputs) - arg_shapes = get_shapes_from_args(example_inputs) - kwargs = { - "runtime_shape": runtime_shape, - "arg_shapes": arg_shapes, - "arg_dtypes": arg_dtypes - } - graph = current_pass_manager(graph, **kwargs) + graph = current_pass_manager(graph, runtime_shape) return graph - if not graph_returns_tuple(graph): - spec, graph = make_graph_return_tuple(graph) - else: - spec = None - - compiled_fn = aot_autograd(fw_compiler=compile_inner)(graph, - example_inputs) - - if spec is not None: + decompositions = select_decomp_table() - @functools.wraps(compiled_fn) - def wrapper(*args, **kwargs): - return pytree.tree_unflatten(compiled_fn(*args, **kwargs), - spec) + compiled_fn = compile_fx( + model_=graph, + example_inputs_=example_inputs, + inner_compile=compile_inner, + decompositions=decompositions, + ) - return wrapper, None - else: - return compiled_fn, None \ No newline at end of file + return compiled_fn, None diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index bff150c9cb4..33d11cc1c4a 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -23,19 +23,18 @@ class GraphFusionPassManager: """ - A pass manager for graph rewriting passes. + A pass manager for graph fusion passes. It handles the configuration and execution of passes. - The counterpart in vllm is PostGradPassManager. Since torch_npu does not - support inductor and triton for now, we choose to adopt the graph rewriter on - fx graph rather than the inductor pass manager. + The counterpart in vllm is PostGradPassManager. Since torch_npu + does not support triton for now, we define our own pass manager. """ def __init__(self): self.passes: list[VllmInductorPass] = [] - def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: + def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph: for pass_ in self.passes: - if pass_.is_applicable(**kwargs): + if pass_.is_applicable(runtime_shape): pass_(graph) return graph @@ -49,6 +48,6 @@ def configure(self, config: VllmConfig): "ascend_compilation_config", {}) if self.ascend_compilation_config.get("enable_quantization_fusion", True): - from .quant_fusion_pass import AscendQuantFusionPass + from .passes.quant_fusion_pass import AscendQuantFusionPass self.passes.append(AscendQuantFusionPass(config)) # Add more passes here as needed diff --git a/vllm_ascend/compilation/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py similarity index 75% rename from vllm_ascend/compilation/quant_fusion_pass.py rename to vllm_ascend/compilation/passes/quant_fusion_pass.py index dfd35fb5088..e2c42ca228a 100644 --- a/vllm_ascend/compilation/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -23,8 +23,9 @@ class AddRMSNormQuantPattern: - def __init__(self, vllm_config): + def __init__(self, vllm_config, eps=1e-6): self.vllm_config = vllm_config + self.eps = eps def get_inputs(self): """ @@ -41,10 +42,10 @@ def register(self, pm_pass: PatternMatcherPass): def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): """ - Pattern for AddRMSNormQuant fusion. - """ + Pattern for AddRMSNormQuant fusion. + """ output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, - rms_norm_weight, 1e-6) + rms_norm_weight, self.eps) out0 = output[0] out1 = output[2] quantized_output = torch.ops.npu.npu_quantize( @@ -54,8 +55,8 @@ def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): def replacement(rms_norm_input, residual, rms_norm_weight, scale, offset): """ - Replacement for the AddRMSNormQuant fusion. - """ + Replacement for the AddRMSNormQuant fusion. + """ output = torch.ops.npu.npu_add_rms_norm_quant( rms_norm_input, residual, @@ -63,7 +64,7 @@ def replacement(rms_norm_input, residual, rms_norm_weight, scale, 1. / scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. offset, - epsilon=1e-6) + epsilon=self.eps) quantized_output = output[0] out1 = output[2] return quantized_output, out1 @@ -79,26 +80,21 @@ class AscendQuantFusionPass(VllmInductorPass): def __init__(self, vllm_config): super().__init__(vllm_config) - self.patterns: PatternMatcherPass = PatternMatcherPass( + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( pass_name="rmsnorm_quant_fusion_pass") - AddRMSNormQuantPattern(vllm_config).register(self.patterns) + + common_epsilons = [1e-5, 1e-6] + for eps in common_epsilons: + AddRMSNormQuantPattern(vllm_config, + eps=eps).register(self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() - matched_count = self.patterns.apply(graph) + matched_count = self.pattern_match_passes.apply(graph) self.end_and_log() - def is_applicable(self, **kwargs): + def is_applicable(self, runtime_shape): """ Check if the pass is applicable for the current configuration. """ - arg_dtypes = kwargs.get("arg_dtypes", None) - if arg_dtypes is None: - return False - # We assume the first tensor's dtype is the data type of this model, update this solution when there is - # better solution. - dtype = arg_dtypes[0] if isinstance( - arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes - # We found that the kernel npu_add_rms_norm_quant accept varying data format for different dtypes, therefore, we only - # provide the solution on bfloat16 here. - return dtype in (torch.bfloat16, ) + return True diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index b65647f5622..6a95d5bf88d 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -19,12 +19,11 @@ import torch from vllm.config import get_current_vllm_config -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm class AscendRMSNorm(RMSNorm): - + def __init__( self, hidden_size: int, @@ -49,9 +48,9 @@ def forward_oot( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu - from vllm_ascend.utils import is_310p + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type if residual is not None: - if is_310p(): + if get_ascend_device_type() == AscendDeviceType._310P: orig_dtype = residual.dtype x = x + residual.to(x.dtype) residual = x.to(orig_dtype) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 91a6f09fa1a..c10900a10c8 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -73,7 +73,7 @@ def _rope_forward_oot( query = query.contiguous().view(1, query.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) + query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4e4515677d5..f1945ece4c0 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -66,15 +66,15 @@ class NPUPlatform(Platform): def is_sleep_mode_available(self) -> bool: return True - + @property def pass_key(self) -> str: return "graph_fusion_manager" - + @classmethod def get_pass_manager_cls(cls) -> str: return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager" - + @classmethod def get_compile_backend(self) -> str: from vllm_ascend.compilation.compiler_interface import AscendAdaptor From 501aa12952357fbcdc7f162c0ac94aed3bd7b805 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 26 Nov 2025 09:29:42 +0000 Subject: [PATCH 14/25] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 3 ++- vllm_ascend/platform.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index c10900a10c8..045c9c65343 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -73,7 +73,8 @@ def _rope_forward_oot( query = query.contiguous().view(1, query.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - query, key = torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin) + query, key = torch_npu.npu_apply_rotary_pos_emb( + query, key, self.cos, self.sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f1945ece4c0..bdfb254b1c1 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -227,6 +227,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + from vllm_ascend.compilation.compiler_interface import AscendAdaptor + compilation_config.oot_compiler = AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ + if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.mode = CompilationMode.NONE elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: @@ -239,7 +242,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: compilation_config.use_inductor = False compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) - compilation_config.oot_compiler = AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( From 22dade5e6dbd0933a3da9117370dd3e98be67caa Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 26 Nov 2025 09:52:05 +0000 Subject: [PATCH 15/25] fix __init__ Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/passes/__init__.py | 0 vllm_ascend/compilation/passes/quant_fusion_pass.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 vllm_ascend/compilation/passes/__init__.py diff --git a/vllm_ascend/compilation/passes/__init__.py b/vllm_ascend/compilation/passes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index e2c42ca228a..5968792e38e 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -90,7 +90,7 @@ def __init__(self, vllm_config): def __call__(self, graph: torch.fx.Graph): self.begin() - matched_count = self.pattern_match_passes.apply(graph) + matched_count = self.pattern_match_passes.apply(graph) # noqa: F841 self.end_and_log() def is_applicable(self, runtime_shape): From dba9a3205919e0d47321ee2252b385f8c8bc78d7 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 27 Nov 2025 11:35:10 +0000 Subject: [PATCH 16/25] add doc string and remove unuse code Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/acl_graph.py | 2 + vllm_ascend/compilation/compiler_interface.py | 69 ++++--------------- .../compilation/graph_fusion_pass_manager.py | 6 +- .../compilation/passes/quant_fusion_pass.py | 8 ++- vllm_ascend/ops/rotary_embedding.py | 2 + vllm_ascend/platform.py | 21 ++++-- 6 files changed, 44 insertions(+), 64 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index ea42dff5608..6ea0f41a955 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -113,6 +113,8 @@ def __call__(self, *args, **kwargs): # matches. This enables properly dispatching to the correct # CUDAGraphWrapper when nesting multiple instances with different # runtime modes. + # When fx_graph_eager is specified, we only trigger graph fusion to + # fuse the kernels without further capture them into the aclgraph. return self.runnable(*args, **kwargs) if batch_descriptor not in self.concrete_aclgraph_entries: diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 262e5e8c399..23e07bb441a 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -16,75 +16,36 @@ # limitations under the License. # import functools -from collections.abc import Sequence from typing import Any, Callable, Optional -import torch import torch.fx as fx -import torch.utils._pytree as pytree from torch._dynamo.backends.common import aot_autograd +from torch._inductor.compile_fx import (graph_returns_tuple, + make_graph_return_tuple) from torch._inductor.decomposition import select_decomp_table -from torch._inductor.utils import InputType, output_node from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface -def graph_returns_tuple(gm: fx.GraphModule) -> bool: - """True if a FX graph returns a tuple""" - if not isinstance(gm, fx.GraphModule): - return True # can't check this, assume true - (rv, ) = output_node(gm).args - if isinstance(rv, (list, tuple)): - return True - if (isinstance(rv, torch.fx.node.Node) and hasattr(rv.target, "_schema") - and len(rv.target._schema.returns) > 1 and all( - str(ret.type) == "Tensor" - for ret in rv.target._schema.returns)): - # for graphs whose result is one node with multiple outputs - return True - return False - - -def make_graph_return_tuple( - gm: GraphModule, - inputs: Sequence[InputType], - compile_gm: Callable[..., Any], -) -> Callable[..., Any]: - """ - Mutate gm so it returns a tuple. This is only needed for graphs - not created by torchdynamo that return non-tuples. - """ - node = output_node(gm) - (rv, ) = node.args - rv, spec = pytree.tree_flatten(rv) - with gm.graph.inserting_before(node): - gm.graph.output(rv) - gm.graph.erase_node(node) - assert graph_returns_tuple(gm) - - compiled_fn = compile_gm(gm, inputs) - - @functools.wraps(compiled_fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) - - return wrapper - - -def compile_fx(model_: GraphModule, example_inputs_: list, +def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable: recursive_compile_fx = functools.partial(compile_fx, inner_compile=inner_compile, decompositions=decompositions) - if not graph_returns_tuple(model_): - return make_graph_return_tuple(model_, example_inputs_, + if not graph_returns_tuple(graph): + return make_graph_return_tuple(graph, example_inputs, recursive_compile_fx) - return aot_autograd(fw_compiler=inner_compile)(model_, example_inputs_) + return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs) -class AscendAdaptor(CompilerInterface): - name = "AscendAdaptor" +class AscendCompiler(CompilerInterface): + """ + AscendCompiler is a custom compiler interface for the Ascend platform. + This class provides a method to compile a PyTorch FX graph module with + specific configurations for graph fusion and decomposition. + """ + name = "AscendCompiler" def compile( self, @@ -103,8 +64,8 @@ def compile_inner(graph, example_inputs): decompositions = select_decomp_table() compiled_fn = compile_fx( - model_=graph, - example_inputs_=example_inputs, + graph=graph, + example_inputs=example_inputs, inner_compile=compile_inner, decompositions=decompositions, ) diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 33d11cc1c4a..27e7e3c14bc 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -43,11 +43,11 @@ def add(self, pass_: VllmInductorPass): self.passes.append(pass_) def configure(self, config: VllmConfig): - # By default, we enable the graph rewriter and quantization fusion pass. + # By default, we enable the graph fusion and quantization fusion pass. self.ascend_compilation_config: dict = config.additional_config.get( "ascend_compilation_config", {}) - if self.ascend_compilation_config.get("enable_quantization_fusion", - True): + self.pass_config = config.compilation_config.pass_config + if self.pass_config.get("enable_ascend_quant_fusion_pass", True): from .passes.quant_fusion_pass import AscendQuantFusionPass self.passes.append(AscendQuantFusionPass(config)) # Add more passes here as needed diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index 5968792e38e..aecbb5718db 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging + import torch import torch._inductor.pattern_matcher as pm from torch._inductor.pattern_matcher import PatternMatcherPass @@ -88,10 +90,10 @@ def __init__(self, vllm_config): AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes) + @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): - self.begin() - matched_count = self.pattern_match_passes.apply(graph) # noqa: F841 - self.end_and_log() + self.matched_count = self.pattern_match_passes.apply(graph) + logging.info("Replaced %s patterns", self.matched_count) def is_applicable(self, runtime_shape): """ diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 045c9c65343..ee9dd9f9302 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -73,6 +73,8 @@ def _rope_forward_oot( query = query.contiguous().view(1, query.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size) + # Although this function modifies in-place, please retain the function's return value. + # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( query, key, self.cos, self.sin) elif self.rotary_dim < self.head_size: diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index bdfb254b1c1..7ffd6fa253e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -69,16 +69,29 @@ def is_sleep_mode_available(self) -> bool: @property def pass_key(self) -> str: + """ + Inductor config key for the PassManager custom pass, for example 'post_grad_custom_post_pass'. + It is a parameter of inductor_config used to register custom passes. + Currently, we only use Inductor's 'pattern matcher' functionality, so we define our own pass_key. + """ return "graph_fusion_manager" @classmethod def get_pass_manager_cls(cls) -> str: + """ + Get the pass manager class for this platform. + It will be registered as a custom pass under the current_platform.pass_key. + """ return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager" @classmethod def get_compile_backend(self) -> str: - from vllm_ascend.compilation.compiler_interface import AscendAdaptor - return AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ + """ + Get the custom compile backend. Previously, we used EagerAdaptor by default. + To use graph fusion operations, we defined our own backend compiler. + """ + from vllm_ascend.compilation.compiler_interface import AscendCompiler + return AscendCompiler.__module__ + "." + AscendCompiler.__name__ @classmethod def pre_register_and_update(cls, @@ -227,8 +240,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - from vllm_ascend.compilation.compiler_interface import AscendAdaptor - compilation_config.oot_compiler = AscendAdaptor.__module__ + "." + AscendAdaptor.__name__ + from vllm_ascend.compilation.compiler_interface import AscendCompiler + compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__ if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.mode = CompilationMode.NONE From 49daa473e2dfebdb14646b48c4410a20b16f38a5 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 27 Nov 2025 12:11:52 +0000 Subject: [PATCH 17/25] remove unuse fx_graph_eager define Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 8 -------- vllm_ascend/compilation/acl_graph.py | 7 +------ vllm_ascend/compilation/graph_fusion_pass_manager.py | 4 ++-- vllm_ascend/compilation/passes/quant_fusion_pass.py | 3 ++- 4 files changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 9ecb71d663d..c2b333ecab9 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -161,7 +161,6 @@ class AscendCompilationConfig: def __init__(self, enable_graph_fusion: bool = True, - fx_graph_eager: bool = False, enable_quantization_fusion: bool = True, **kwargs): """ @@ -173,11 +172,6 @@ def __init__(self, into more efficient computational graph structures to improve performance. Default: True - fx_graph_eager (bool): Whether to use eager mode for graph transformation. - When set to False, uses symbolic execution for graph transformation; - when set to True, uses eager execution mode. - Default: False - enable_quantization_fusion (bool): Whether to enable quantization fusion optimization. When set to True, the system will optimize quantization-related operations, reducing the number of quantization/dequantization nodes. @@ -186,7 +180,6 @@ def __init__(self, **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.enable_graph_fusion = enable_graph_fusion - self.fx_graph_eager = fx_graph_eager self.enable_quantization_fusion = enable_quantization_fusion # Add more compilation related configs here as needed @@ -373,7 +366,6 @@ def check_ascend_config(vllm_config, enforce_eager): "it has been disabled automatically.") # aclgraph case else: - # This graph fusion can actually works on eager mode. if ascend_config.ascend_compilation_config.enable_graph_fusion: logger.info( "graph fusion enabled! Automatic kernel fusion is expected." diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 6ea0f41a955..2e347691d19 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -67,8 +67,6 @@ def __init__(self, self.vllm_config = vllm_config self.ascend_compilation_config: dict = vllm_config.additional_config.get( "ascend_compilation_config", {}) - self.fx_graph_eager = self.ascend_compilation_config.get( - "fx_graph_eager", False) self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -105,16 +103,13 @@ def __call__(self, *args, **kwargs): aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ - aclgraph_runtime_mode != self.runtime_mode or \ - self.fx_graph_eager: + aclgraph_runtime_mode != self.runtime_mode: # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without aclgraphs. # We do not trigger capture/replay if the runtime mode is not # matches. This enables properly dispatching to the correct # CUDAGraphWrapper when nesting multiple instances with different # runtime modes. - # When fx_graph_eager is specified, we only trigger graph fusion to - # fuse the kernels without further capture them into the aclgraph. return self.runnable(*args, **kwargs) if batch_descriptor not in self.concrete_aclgraph_entries: diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 27e7e3c14bc..dc6e4ba6dcb 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -46,8 +46,8 @@ def configure(self, config: VllmConfig): # By default, we enable the graph fusion and quantization fusion pass. self.ascend_compilation_config: dict = config.additional_config.get( "ascend_compilation_config", {}) - self.pass_config = config.compilation_config.pass_config - if self.pass_config.get("enable_ascend_quant_fusion_pass", True): + if self.ascend_compilation_config.get( + "enable_ascend_quant_fusion_pass", True): from .passes.quant_fusion_pass import AscendQuantFusionPass self.passes.append(AscendQuantFusionPass(config)) # Add more passes here as needed diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index aecbb5718db..174393d7e1b 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -90,10 +90,11 @@ def __init__(self, vllm_config): AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes) - @VllmInductorPass.time_and_log def __call__(self, graph: torch.fx.Graph): + self.begin() self.matched_count = self.pattern_match_passes.apply(graph) logging.info("Replaced %s patterns", self.matched_count) + self.end_and_log() def is_applicable(self, runtime_shape): """ From 7d0a33656140604ac340637dd186a286b7e80ed0 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 27 Nov 2025 12:34:24 +0000 Subject: [PATCH 18/25] remove unuse ascend forward_context Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 11c1d3a0373..bd4a3509e26 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -159,25 +159,6 @@ def set_ascend_forward_context( forward_context.weight_prefetch_method = weight_prefetch_method forward_context.is_mtp_model = is_mtp_model - # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. - # It will be improved later by implementing operator fusion through the FX graph. - # - # set for addrmsnorm+quant fusion. - # this optim now just support dense models due to the specific operators used. - # Once the necessary conditions are met, support for MOE models will also be added. - from vllm_ascend.quantization.quant_config import AscendQuantConfig - model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"] - addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ - vllm_config.model_config.hf_config.model_type in model_type_scope and \ - forward_context.layer_idx is not None - if addrmsnorm_quant_fusion_enabled: - forward_context.model_instance = model_instance - forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers - forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" - if vllm_config.model_config.hf_config.model_type == "qwen3_moe": - forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe" - forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled - if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens From b3c36e3ddc2b4fa4c10b9dcd2fb5f79933071e76 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 28 Nov 2025 07:16:18 +0000 Subject: [PATCH 19/25] add unit tests to the pr Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/ut/ops/test_layernorm.py | 176 +++++++-------------------------- tests/ut/test_ascend_config.py | 17 ++++ tests/ut/test_platform.py | 7 +- 3 files changed, 57 insertions(+), 143 deletions(-) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index 77af2649aae..e50656e85d4 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,16 +1,17 @@ -import unittest from unittest.mock import patch import pytest import torch -from pytest_mock import MockerFixture from vllm.model_executor.layers.layernorm import RMSNorm -from tests.ut.base import PytestBase -from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.utils import AscendDeviceType +@pytest.fixture +def dummy_tensor(): + return torch.randn(4, 8, dtype=torch.float16) + + def mock_rms_norm(x, weight, eps): return x + 1, None @@ -19,145 +20,38 @@ def mock_add_rms_norm(x, residual, weight, eps): return 2 * x, None, 2 * residual -def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale, - quant_offset, beta, epsilon): - x_out = 2 * x - residual_out = 2 * residual - x_out_quant = x_out.to(torch.int8) - residual_out_quant = residual_out.to(torch.int8) - return x_out_quant, None, residual_out_quant - - -class TestAscendRMSNorm(PytestBase): +@pytest.mark.parametrize("is_310p", [True, False]) +@pytest.mark.parametrize("residual", + [None, torch.randn(4, 8, dtype=torch.float32)]) +@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) +@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) +def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p, residual, + dummy_tensor): - @pytest.fixture(autouse=True) - def context(self, mocker: MockerFixture): - mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) - mocker.patch("torch_npu.npu_add_rms_norm", - side_effect=mock_add_rms_norm) - mocker.patch("torch_npu.npu_add_rms_norm_quant", - side_effect=mock_add_rms_norm_quant_with_bias) - mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done", - side_effect=lambda x: None) - - # Test case for the most common and basic scenario - @pytest.mark.parametrize( - "residual", [None, torch.randn(4, 8, dtype=torch.float16)]) - @patch("torch.ops.vllm.maybe_chunk_residual") - def test_forward_oot_basic(self, mock_maybe_chunk_residual, residual): - mock_maybe_chunk_residual.side_effect = lambda x, residual: residual + with patch("vllm_ascend.utils.get_ascend_device_type", + return_value=AscendDeviceType._310P + if is_310p else AscendDeviceType._910_93): layer = RMSNorm(hidden_size=8, eps=1e-05) - x = torch.randn(4, 8, dtype=torch.float16) if residual is not None: - x_out, residual_out = layer.forward_oot(x, residual) - - x_out_expected = 2 * x - residual_out_expected = 2 * residual - - assert torch.allclose(x_out, x_out_expected) - assert torch.allclose(residual_out, residual_out_expected) + out_x, out_residual = layer.forward_oot(dummy_tensor, residual) + + if is_310p: + expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype) + expected_out_x = expected_arg_x + 1 + expected_out_residual = expected_arg_x.to(residual.dtype) + + mock_rmsnorm.assert_called_once() + assert torch.allclose(out_x, expected_out_x) + assert torch.allclose(out_residual, expected_out_residual) + else: + expected_out_x = 2 * dummy_tensor + expected_out_residual = 2 * residual + mock_add_rmsnorm.assert_called_once() + assert torch.allclose(out_x, expected_out_x) + assert torch.allclose(out_residual, expected_out_residual) else: - x_out = layer.forward(x, residual) - x_out_expected = x + 1 - - assert torch.allclose(x_out, x_out_expected) - - # Test case for addrmsnorm + w8a8 quant fusion - def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture): - mock_soc_version = mocker.patch( - "vllm_ascend.utils.get_ascend_device_type") - mock_soc_version.return_value = AscendDeviceType._910_93 - mock_get_forward_context = mocker.patch( - "vllm_ascend.ops.layernorm.get_forward_context") - - # Simulating a scenario with quant_fusion enabled - mock_forward_context = mocker.MagicMock() - - mock_model_instance = mocker.MagicMock() - mock_forward_context.model_instance = mock_model_instance - num_hidden_layers = 3 - mock_model_instance.model.layers = [ - mocker.MagicMock() for _ in range(num_hidden_layers) - ] - - mock_layer_0 = mock_model_instance.model.layers[0] - mock_layer_0.self_attn.qkv_proj = mocker.MagicMock() - mock_layer_0.mlp.gate_up_proj = mocker.MagicMock() - - mock_layer_1 = mock_model_instance.model.layers[1] - mock_layer_1.self_attn.qkv_proj = mocker.MagicMock() - mock_layer_1.mlp.gate_up_proj = mocker.MagicMock() - - mock_quant_method_0_qkv = mocker.MagicMock() - mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod() - mock_quant_method_0_gate_up = mocker.MagicMock() - mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod() - mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv - mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up - - mock_quant_method_1_qkv = mocker.MagicMock() - mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod() - mock_quant_method_1_gate_up = mocker.MagicMock() - mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod() - mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv - mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up - - mock_get_forward_context.return_value = mock_forward_context - - mock_forward_context.addrmsnorm_quant_fusion_enabled = True - mock_forward_context.prefetch_mlp_enabled = False - mock_forward_context.layer_idx = 0 - mock_forward_context.num_hidden_layers = num_hidden_layers - mock_forward_context.fusion_linear = "gate_up_dense" - mock_forward_context.weight_prefetch_method = None - mocker.patch("torch.ops.vllm.maybe_chunk_residual", - lambda x, residual: residual) - - # Ensure fusion and layer_idx increment are handled correctly - x = torch.randn(4, 8, dtype=torch.float16) - residual = torch.randn(4, 8, dtype=torch.float16) - layer = RMSNorm(hidden_size=8, eps=1e-05) - - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 2 - assert mock_forward_context.fusion_linear == "qkv_dense" - assert mock_forward_context.layer_idx == 1 - - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 4 - assert mock_forward_context.fusion_linear == "gate_up_dense" - assert mock_forward_context.layer_idx == 1 - - mock_forward_context.fusion_linear = "gate_moe" - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 5 - fusion_linear_expected = "qkv_moe" - assert mock_forward_context.fusion_linear == fusion_linear_expected - assert mock_forward_context.layer_idx == 2 - - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 6 - fusion_linear_expected = "gate_moe" - assert mock_forward_context.fusion_linear == fusion_linear_expected - assert mock_forward_context.layer_idx == 2 - - # last layer returned directly - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 7 - assert mock_forward_context.fusion_linear == "qkv_moe" - assert mock_forward_context.layer_idx == 3 - - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 8 - assert mock_forward_context.fusion_linear == "qkv_moe" - assert mock_forward_context.layer_idx == 3 - + out_x = layer.forward_oot(dummy_tensor, residual) + expected_out_x = dummy_tensor + 1 -if __name__ == '__main__': - unittest.main() + mock_rmsnorm.assert_called_once() + assert torch.allclose(out_x, expected_out_x) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index be066179f1d..0f9db6764f9 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -56,6 +56,13 @@ def test_init_ascend_config_without_additional_config(self): self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertFalse(torchair_graph_config.enable_kv_nz) + ascend_compilation_config = ascend_config.ascend_compilation_config + self.assertTrue(ascend_compilation_config.enable_graph_fusion) + self.assertTrue(ascend_compilation_config.enable_quantization_fusion) + + ascend_scheduler_config = ascend_config.ascend_scheduler_config + self.assertFalse(ascend_scheduler_config.enabled) + @_clean_up_ascend_config def test_init_ascend_config_with_additional_config(self): test_vllm_config = VllmConfig() @@ -70,6 +77,10 @@ def test_init_ascend_config_with_additional_config(self): "enable_frozen_parameter": True, "enable_kv_nz": True }, + "ascend_compilation_config": { + "enable_graph_fusion": False, + "enable_quantization_fusion": False, + }, "multistream_overlap_shared_expert": True, "expert_map_path": "test_expert_map_path", "refresh": True, @@ -87,6 +98,12 @@ def test_init_ascend_config_with_additional_config(self): self.assertTrue(torchair_graph_config.enable_view_optimize) self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertTrue(torchair_graph_config.enable_kv_nz) + ascend_compilation_config = ascend_config.ascend_compilation_config + self.assertFalse(ascend_compilation_config.enable_graph_fusion) + self.assertFalse(ascend_compilation_config.enable_quantization_fusion) + + ascend_scheduler_config = ascend_config.ascend_scheduler_config + self.assertTrue(ascend_scheduler_config.enabled) @_clean_up_ascend_config def test_init_ascend_config_with_refresh(self): diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 5dedff7faa7..ed54256329b 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -685,9 +685,12 @@ def test_aclgraph_enable(self): importlib.reload(platform) self.platform.check_and_update_config(VllmConfig) + target_msg = "PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode" + found = any(target_msg in log for log in cm.output) + self.assertTrue( - "PIECEWISE compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode" in cm.output[0]) + found, + f"Expected log message not found. Captured logs: {cm.output}") self.assertEqual( VllmConfig.compilation_config.mode, From 4e4b07526d40cb6a98503a811f8e47eb66cd1810 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 28 Nov 2025 09:49:29 +0000 Subject: [PATCH 20/25] modify ascend compilation config and update platform config Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/ut/test_ascend_config.py | 3 --- vllm_ascend/ascend_config.py | 16 +--------------- vllm_ascend/compilation/acl_graph.py | 4 +--- .../compilation/graph_fusion_pass_manager.py | 4 ++-- vllm_ascend/platform.py | 7 +++++++ 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 0f9db6764f9..71a12814c50 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -57,7 +57,6 @@ def test_init_ascend_config_without_additional_config(self): self.assertFalse(torchair_graph_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config - self.assertTrue(ascend_compilation_config.enable_graph_fusion) self.assertTrue(ascend_compilation_config.enable_quantization_fusion) ascend_scheduler_config = ascend_config.ascend_scheduler_config @@ -78,7 +77,6 @@ def test_init_ascend_config_with_additional_config(self): "enable_kv_nz": True }, "ascend_compilation_config": { - "enable_graph_fusion": False, "enable_quantization_fusion": False, }, "multistream_overlap_shared_expert": True, @@ -99,7 +97,6 @@ def test_init_ascend_config_with_additional_config(self): self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertTrue(torchair_graph_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config - self.assertFalse(ascend_compilation_config.enable_graph_fusion) self.assertFalse(ascend_compilation_config.enable_quantization_fusion) ascend_scheduler_config = ascend_config.ascend_scheduler_config diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index c2b333ecab9..e5bb7e00b17 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -159,19 +159,11 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, - enable_graph_fusion: bool = True, - enable_quantization_fusion: bool = True, - **kwargs): + def __init__(self, enable_quantization_fusion: bool = True, **kwargs): """ Initialize the configuration. Args: - enable_graph_fusion (bool): Whether to enable graph fusion optimization. - When set to True, the system will attempt to fuse multiple operations - into more efficient computational graph structures to improve performance. - Default: True - enable_quantization_fusion (bool): Whether to enable quantization fusion optimization. When set to True, the system will optimize quantization-related operations, reducing the number of quantization/dequantization nodes. @@ -179,7 +171,6 @@ def __init__(self, **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ - self.enable_graph_fusion = enable_graph_fusion self.enable_quantization_fusion = enable_quantization_fusion # Add more compilation related configs here as needed @@ -366,11 +357,6 @@ def check_ascend_config(vllm_config, enforce_eager): "it has been disabled automatically.") # aclgraph case else: - if ascend_config.ascend_compilation_config.enable_graph_fusion: - logger.info( - "graph fusion enabled! Automatic kernel fusion is expected." - ) - if ascend_config.ascend_compilation_config.enable_quantization_fusion: logger.info( "Quantization fusion enabled! op fusion on quantization are expected. " diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 2e347691d19..025ff3c12ca 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -65,8 +65,6 @@ def __init__(self, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config - self.ascend_compilation_config: dict = vllm_config.additional_config.get( - "ascend_compilation_config", {}) self.runtime_mode = runtime_mode self.compilation_config = vllm_config.compilation_config @@ -103,7 +101,7 @@ def __call__(self, *args, **kwargs): aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ - aclgraph_runtime_mode != self.runtime_mode: + aclgraph_runtime_mode != self.runtime_mode: # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without aclgraphs. # We do not trigger capture/replay if the runtime mode is not diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index dc6e4ba6dcb..79f24fcd9fc 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -46,8 +46,8 @@ def configure(self, config: VllmConfig): # By default, we enable the graph fusion and quantization fusion pass. self.ascend_compilation_config: dict = config.additional_config.get( "ascend_compilation_config", {}) - if self.ascend_compilation_config.get( - "enable_ascend_quant_fusion_pass", True): + if self.ascend_compilation_config.get("enable_quantization_fusion", + True): from .passes.quant_fusion_pass import AscendQuantFusionPass self.passes.append(AscendQuantFusionPass(config)) # Add more passes here as needed diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 7ffd6fa253e..f8bee8d51de 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -162,6 +162,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config ascend_scheduler_config = ascend_config.ascend_scheduler_config + ascend_compilation_config = ascend_config.ascend_compilation_config + if ascend_compilation_config: + vllm_config.additional_config.setdefault( + "ascend_compilation_config", {}).update( + vars(ascend_compilation_config + ) if not isinstance(ascend_compilation_config, dict) + else ascend_compilation_config) kv_cache_dtype = vllm_config.additional_config.get( "kv_cache_dtype", None) From 247877ed43bba4c16e1bc263619bc0e7fec1853c Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 1 Dec 2025 06:26:53 +0000 Subject: [PATCH 21/25] add dtype check and e2e test Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_quantize_fusion.py | 205 ++++++++++++++++++ .../compilation/passes/quant_fusion_pass.py | 6 + 2 files changed, 211 insertions(+) create mode 100644 tests/e2e/singlecard/test_quantize_fusion.py diff --git a/tests/e2e/singlecard/test_quantize_fusion.py b/tests/e2e/singlecard/test_quantize_fusion.py new file mode 100644 index 00000000000..e564654b378 --- /dev/null +++ b/tests/e2e/singlecard/test_quantize_fusion.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable, Sequence +from copy import deepcopy +from typing import Any, Callable, List, Optional, Sequence + +import pytest +import torch +import torch.fx as fx +import torch.nn as nn +import torch_npu +import vllm.config +from torch._inductor.decomposition import select_decomp_table +from vllm.compilation.fx_utils import OpOverload +from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config + +from vllm_ascend.compilation.compiler_interface import compile_fx +from vllm_ascend.compilation.passes.quant_fusion_pass import \ + AscendQuantFusionPass + + +class TestModel(nn.Module): + """ + A minimal test model that simulates the pattern: + AddRMSNorm → Quantization + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.rms_norm_weight = nn.Parameter( + torch.randn(hidden_size, device=device)) + self.quant_scale = torch.tensor([1.0], device=device) + self.quant_offset = torch.tensor([0.0], device=device) + + def forward(self, x): + """ + Forward pass: + 1. Perform npu_add_rms_norm + 2. Quantize the normalized output to int8 + Returns both quantized output and updated residual. + """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm( + x, residual, self.rms_norm_weight, self.eps) + + quantized_output = torch_npu.npu_quantize(norm_output, + self.quant_scale, + self.quant_offset, + torch.qint8, -1, False) + + return quantized_output, new_residual + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm.default, + torch.ops.npu.npu_quantize.default + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [torch.ops.npu.npu_add_rms_norm_quant.default] + + +class TestBackend: + """ + A custom compilation backend for testing operator fusion passes. + It applies the AscendQuantFusionPass during graph compilation and + records the FX graph before and after the transformation. + """ + + def __init__(self): + vllm_config = get_current_vllm_config() + compile_config = vllm_config.compilation_config + self.custom_passes = [AscendQuantFusionPass(vllm_config=vllm_config)] + self.inductor_config = compile_config.inductor_compile_config + self.inductor_config["graph_fusion_manager"] = self.post_pass + + # Placeholders to store FX graphs for verification + self.graph_pre_pass = None + self.graph_post_pass = None + + def post_pass(self, + graph: fx.Graph, + runtime_shape: int | None = None) -> fx.Graph: + """ + Apply custom graph transformation passes. + """ + self.graph_pre_pass = deepcopy(graph) + for pass_ in self.custom_passes: + pass_(graph) + self.graph_post_pass = deepcopy(graph) + return graph + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None + ) -> tuple[Optional[Callable], Optional[Any]]: + """ + Compile the FX graph using vLLM's Ascend compiler interface. + Wraps the post-pass logic into the inner_compile callback. + """ + + def compile_inner(graph, example_inputs): + current_pass_manager = compiler_config["graph_fusion_manager"] + return current_pass_manager(graph, runtime_shape) + + decompositions = select_decomp_table() + compiled_fn = compile_fx( + graph=graph, + example_inputs=example_inputs, + inner_compile=compile_inner, + decompositions=decompositions, + ) + return compiled_fn, None + + def __call__(self, gm: fx.GraphModule, + example_inputs: List[Any]) -> callable: + """ + Make the backend callable by torch.compile(). + Returns a compiled executable function. + """ + compiled_fn, _ = self.compile( + gm, + example_inputs, + compiler_config={"graph_fusion_manager": self.post_pass}, + runtime_shape=None, + key=None, + ) + return compiled_fn + + def find_nodes_by_target(self, graph: fx.GraphModule, + target: OpOverload) -> List[fx.Node]: + """Helper to find all FX nodes that call a specific operator.""" + return [ + node for node in graph.graph.nodes + if hasattr(node, 'target') and node.target == target + ] + + def check_before_ops(self, + ops: Sequence[OpOverload], + fully_replaced: bool = True): + """ + Verify that the original (unfused) operators exist before the pass + and are fully removed afterward (if fully_replaced=True). + """ + for op in ops: + num_pre = len(self.find_nodes_by_target(self.graph_pre_pass, op)) + num_post = len(self.find_nodes_by_target(self.graph_post_pass, op)) + print(f"Op {op}: pre={num_pre}, post={num_post}") + + assert num_pre > 0, f"Op {op} not found in pre-pass graph" + if fully_replaced: + assert num_post == 0, f"Unexpected op {op} in post-pass graph: {num_post} nodes remain" + + def check_after_ops(self, ops: Sequence[OpOverload]): + """Verify that the fused operator appears in the transformed graph.""" + for op in ops: + num_post = len(self.find_nodes_by_target(self.graph_post_pass, op)) + print(f"Op {op}: post={num_post}") + assert num_post > 0, f"Op {op} not found in post-pass graph" + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("num_tokens", [257]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): + """ + End-to-end test for AddRMSNorm+Quantize fusion. + Compares: Operator presence/absence before and after graph transformation + """ + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype)) + + with vllm.config.set_current_vllm_config(vllm_config): + backend = TestBackend() + model = TestModel(hidden_size, eps, device="npu") + model = model.to("npu") + + x = torch.rand(num_tokens, + hidden_size, + device="npu", + dtype=dtype, + requires_grad=False) + + result_unfused = model(x) + print("Unfused result:", [t.shape for t in result_unfused]) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) + print("Fused result:", [t.shape for t in result_fused]) + + print("=== Checking operator fusion ===") + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index 174393d7e1b..b4fac64802e 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -85,6 +85,12 @@ def __init__(self, vllm_config): self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( pass_name="rmsnorm_quant_fusion_pass") + dtype = vllm_config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logging.info("Quant fusion not enabled: unsupported dtype %s", + dtype) + return + common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: AddRMSNormQuantPattern(vllm_config, From 9652682f0f086758f24b96e1d2786d434362c055 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 1 Dec 2025 06:30:20 +0000 Subject: [PATCH 22/25] add license to e2e test Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_quantize_fusion.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/e2e/singlecard/test_quantize_fusion.py b/tests/e2e/singlecard/test_quantize_fusion.py index e564654b378..b5a4e2d32e0 100644 --- a/tests/e2e/singlecard/test_quantize_fusion.py +++ b/tests/e2e/singlecard/test_quantize_fusion.py @@ -1,6 +1,19 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from collections.abc import Callable, Sequence from copy import deepcopy from typing import Any, Callable, List, Optional, Sequence From 3aaaa447c94243e0e21eac285a5a6fcd4ce66080 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 2 Dec 2025 11:07:12 +0000 Subject: [PATCH 23/25] resolve conflict Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_quantize_fusion.py | 4 +--- vllm_ascend/platform.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/e2e/singlecard/test_quantize_fusion.py b/tests/e2e/singlecard/test_quantize_fusion.py index b5a4e2d32e0..17c089984f5 100644 --- a/tests/e2e/singlecard/test_quantize_fusion.py +++ b/tests/e2e/singlecard/test_quantize_fusion.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from collections.abc import Callable, Sequence from copy import deepcopy from typing import Any, Callable, List, Optional, Sequence @@ -135,8 +134,7 @@ def compile_inner(graph, example_inputs): ) return compiled_fn, None - def __call__(self, gm: fx.GraphModule, - example_inputs: List[Any]) -> callable: + def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]): """ Make the backend callable by torch.compile(). Returns a compiled executable function. diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f8bee8d51de..ad66674a614 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -30,7 +30,6 @@ init_ascend_config) from vllm_ascend.torchair.utils import (check_torchair_cache_exist, delete_torchair_cache_file) -from vllm_ascend.compilation.compiler_interface import AscendAdaptor # isort: off from vllm_ascend.utils import ( From 65270f2e7aa65b50aec934ac8e465e3318fc5809 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 3 Dec 2025 03:32:21 +0000 Subject: [PATCH 24/25] fix moe w8a8 accuracy and fix ut Signed-off-by: wxsIcey <1790571317@qq.com> --- ...test_quantize_fusion.py => test_quant_fusion.py} | 10 ++++++---- tests/ut/test_ascend_config.py | 3 --- .../compilation/graph_fusion_pass_manager.py | 4 ++-- vllm_ascend/compilation/passes/quant_fusion_pass.py | 13 ++++++++----- vllm_ascend/ops/layernorm.py | 2 ++ 5 files changed, 18 insertions(+), 14 deletions(-) rename tests/e2e/singlecard/{test_quantize_fusion.py => test_quant_fusion.py} (96%) diff --git a/tests/e2e/singlecard/test_quantize_fusion.py b/tests/e2e/singlecard/test_quant_fusion.py similarity index 96% rename from tests/e2e/singlecard/test_quantize_fusion.py rename to tests/e2e/singlecard/test_quant_fusion.py index 17c089984f5..340fa8dee9b 100644 --- a/tests/e2e/singlecard/test_quantize_fusion.py +++ b/tests/e2e/singlecard/test_quant_fusion.py @@ -29,7 +29,7 @@ from vllm_ascend.compilation.compiler_interface import compile_fx from vllm_ascend.compilation.passes.quant_fusion_pass import \ - AscendQuantFusionPass + AddRMSNormQuantFusionPass class TestModel(nn.Module): @@ -81,14 +81,16 @@ def ops_in_model_after(self) -> List[OpOverload]: class TestBackend: """ A custom compilation backend for testing operator fusion passes. - It applies the AscendQuantFusionPass during graph compilation and + It applies the AddRMSNormQuantFusionPass during graph compilation and records the FX graph before and after the transformation. """ def __init__(self): vllm_config = get_current_vllm_config() compile_config = vllm_config.compilation_config - self.custom_passes = [AscendQuantFusionPass(vllm_config=vllm_config)] + self.custom_passes = [ + AddRMSNormQuantFusionPass(vllm_config=vllm_config) + ] self.inductor_config = compile_config.inductor_compile_config self.inductor_config["graph_fusion_manager"] = self.post_pass @@ -184,7 +186,7 @@ def check_after_ops(self, ops: Sequence[OpOverload]): @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): +def test_rmsnorm_quant_fusion(dtype, hidden_size, num_tokens, eps): """ End-to-end test for AddRMSNorm+Quantize fusion. Compares: Operator presence/absence before and after graph transformation diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 71a12814c50..19203a4e874 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -99,9 +99,6 @@ def test_init_ascend_config_with_additional_config(self): ascend_compilation_config = ascend_config.ascend_compilation_config self.assertFalse(ascend_compilation_config.enable_quantization_fusion) - ascend_scheduler_config = ascend_config.ascend_scheduler_config - self.assertTrue(ascend_scheduler_config.enabled) - @_clean_up_ascend_config def test_init_ascend_config_with_refresh(self): test_vllm_config = VllmConfig() diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 79f24fcd9fc..b46bc135321 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -48,6 +48,6 @@ def configure(self, config: VllmConfig): "ascend_compilation_config", {}) if self.ascend_compilation_config.get("enable_quantization_fusion", True): - from .passes.quant_fusion_pass import AscendQuantFusionPass - self.passes.append(AscendQuantFusionPass(config)) + from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass + self.passes.append(AddRMSNormQuantFusionPass(config)) # Add more passes here as needed diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index b4fac64802e..6dc8fe70b3c 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -42,7 +42,9 @@ def get_inputs(self): def register(self, pm_pass: PatternMatcherPass): - def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): """ Pattern for AddRMSNormQuant fusion. """ @@ -54,8 +56,9 @@ def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): out0, scale, offset, torch.qint8, -1, False) return quantized_output, out1 - def replacement(rms_norm_input, residual, rms_norm_weight, scale, - offset): + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): """ Replacement for the AddRMSNormQuant fusion. """ @@ -75,7 +78,7 @@ def replacement(rms_norm_input, residual, rms_norm_weight, scale, pm.fwd_only, pm_pass) -class AscendQuantFusionPass(VllmInductorPass): +class AddRMSNormQuantFusionPass(VllmInductorPass): """ A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. """ @@ -99,7 +102,7 @@ def __init__(self, vllm_config): def __call__(self, graph: torch.fx.Graph): self.begin() self.matched_count = self.pattern_match_passes.apply(graph) - logging.info("Replaced %s patterns", self.matched_count) + logging.debug("Replaced %s patterns", self.matched_count) self.end_and_log() def is_applicable(self, runtime_shape): diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 6a95d5bf88d..cdbba32f7df 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -59,6 +59,8 @@ def forward_oot( else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) + if self.bias is not None: + x.add_(self.bias) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, From 2bcbeb4ed20813fe909eccac88953a51d91b2a39 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Wed, 3 Dec 2025 03:58:22 +0000 Subject: [PATCH 25/25] remove unuse code and reformat code Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_quant_fusion.py | 3 ++- tests/ut/test_ascend_config.py | 3 --- vllm_ascend/compilation/passes/quant_fusion_pass.py | 7 ++++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/e2e/singlecard/test_quant_fusion.py b/tests/e2e/singlecard/test_quant_fusion.py index 340fa8dee9b..c6f7aecfa39 100644 --- a/tests/e2e/singlecard/test_quant_fusion.py +++ b/tests/e2e/singlecard/test_quant_fusion.py @@ -186,7 +186,8 @@ def check_after_ops(self, ops: Sequence[OpOverload]): @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -def test_rmsnorm_quant_fusion(dtype, hidden_size, num_tokens, eps): +def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, + num_tokens: int, eps: float): """ End-to-end test for AddRMSNorm+Quantize fusion. Compares: Operator presence/absence before and after graph transformation diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 19203a4e874..ac33ae1536d 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -59,9 +59,6 @@ def test_init_ascend_config_without_additional_config(self): ascend_compilation_config = ascend_config.ascend_compilation_config self.assertTrue(ascend_compilation_config.enable_quantization_fusion) - ascend_scheduler_config = ascend_config.ascend_scheduler_config - self.assertFalse(ascend_scheduler_config.enabled) - @_clean_up_ascend_config def test_init_ascend_config_with_additional_config(self): test_vllm_config = VllmConfig() diff --git a/vllm_ascend/compilation/passes/quant_fusion_pass.py b/vllm_ascend/compilation/passes/quant_fusion_pass.py index 6dc8fe70b3c..87f0e1a429b 100644 --- a/vllm_ascend/compilation/passes/quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/quant_fusion_pass.py @@ -21,11 +21,12 @@ import torch._inductor.pattern_matcher as pm from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig class AddRMSNormQuantPattern: - def __init__(self, vllm_config, eps=1e-6): + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): self.vllm_config = vllm_config self.eps = eps @@ -83,7 +84,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. """ - def __init__(self, vllm_config): + def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( pass_name="rmsnorm_quant_fusion_pass") @@ -105,7 +106,7 @@ def __call__(self, graph: torch.fx.Graph): logging.debug("Replaced %s patterns", self.matched_count) self.end_and_log() - def is_applicable(self, runtime_shape): + def is_applicable(self, runtime_shape: int | None = None) -> bool: """ Check if the pass is applicable for the current configuration. """