From 31e4bfc7a8031760a9c3642d0034db16c01f882f Mon Sep 17 00:00:00 2001
From: Zingo Andersen
Date: Sat, 8 Mar 2025 16:00:48 +0100
Subject: [PATCH] Revert "Arm backend: Add FuseViewCopyTransform and
 FuseConstantsPass in arm_p…"
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 3a7c23172c2f2895df5a30ca2db290cd65efd56e.
---
 backends/arm/_passes/arm_pass_manager.py      |  11 +-
 backends/arm/_passes/arm_pass_utils.py        |  25 ---
 .../arm/_passes/fuse_constant_ops_pass.py     | 170 ------------------
 .../passes/test_fuse_constant_ops_pass.py     | 115 ------------
 4 files changed, 2 insertions(+), 319 deletions(-)
 delete mode 100644 backends/arm/_passes/fuse_constant_ops_pass.py
 delete mode 100644 backends/arm/test/passes/test_fuse_constant_ops_pass.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 26ff15db396..f8a4a40648f 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -51,7 +51,6 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -79,7 +78,6 @@
     UnsqueezeScalarPlaceholdersPass,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.transforms.replace_scalar_with_tensor import (
     ReplaceScalarWithTensorArgPass,
 )
@@ -116,6 +114,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
 
@@ -129,12 +128,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
-        self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
-        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
-
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -160,6 +155,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
 
@@ -173,9 +169,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
-        self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
-        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
 
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index a8d06713678..3445886ffa7 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -26,7 +26,6 @@
 )
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import FakeTensor
-from torch.export.graph_signature import InputKind
 
 
 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -45,30 +44,6 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
     )
 
 
-def get_constant_placeholder_kind(
-    exp_prog: ExportedProgram, node: torch.fx.Node
-) -> InputKind:
-    if is_param(exp_prog, node):
-        return InputKind.PARAMETER
-    if is_buffer(exp_prog, node):
-        return InputKind.BUFFER
-    if is_lifted_tensor_constant(exp_prog, node):
-        return InputKind.CONSTANT_TENSOR
-
-    raise RuntimeError("Node is neither PARAMETER, BUFFER nor CONSTANT_TENSOR")
-
-
-def is_persistent_buffer(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool | None:
-    if is_buffer(exp_prog, node):
-        buffer_name = exp_prog.graph_signature.inputs_to_buffers[node.name]
-        if buffer_name in exp_prog.graph_signature.non_persistent_buffers:
-            return False
-        else:
-            return True
-
-    return None
-
-
 def get_param_tensor(
     exp_prog: ExportedProgram, node: torch.fx.Node
 ) -> Optional[torch.Tensor]:
diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py
deleted file mode 100644
index 1fff7d76dfc..00000000000
--- a/backends/arm/_passes/fuse_constant_ops_pass.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-
-import torch._export.utils
-from executorch.backends.arm._passes.arm_pass_utils import (
-    get_constant_placeholder_kind,
-    get_param_tensor,
-    is_persistent_buffer,
-)
-from executorch.backends.transforms.utils import (
-    create_constant_placeholder,
-    delete_constant_placeholder,
-)
-from executorch.exir import ExportedProgram
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-logger = logging.getLogger(__name__)
-
-
-class FuseConstantOpsPass(ExportPass):
-    """
-    Fuses ops with only placeholder parameters into one placeholder parameter node with the op
-    pre-calulcated on its data.
-
-    Original:
-        state_dict = {x_tensor_name : data}
-        def f():
-            return x.view(...)
-
-    After pass:
-        state_dict = {x_tensor_name_fused_const : data.view(...)}
-        def f():
-            return x
-    """
-
-    def __init__(self, exported_program: ExportedProgram) -> None:
-        super().__init__()
-        self.exported_program = exported_program
-
-    def fuse_nodes(self, node) -> bool:
-        """
-        Takes a node with only parameter inputs and replaces it with one constant tensor node with
-        the operations already carried out on the data.
- """ - - if node.target == exir_ops.edge.aten.full.default: - # Create data from args - size, fill_value = node.args - dtype = node.kwargs["dtype"] - data = torch.full(size, float(fill_value), dtype=dtype) - - insert_pos = list(node.graph.nodes)[0] - else: - # Extract tensors and args from the node - - if len(node.all_input_nodes) == 0: - raise RuntimeError("No inputs found") - - data_list = [ - get_param_tensor(self.exported_program, input_node) - for input_node in node.all_input_nodes - ] - - args = node.args[len(node.all_input_nodes) :] - kwargs = node.kwargs - - if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0: - dequantize_op = ( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - ) - - for i in range(len(node.all_input_nodes)): - q_params = node.meta["input_qparams"][i] - data_list[i] = dequantize_op( - data_list[i], - q_params.scale, - q_params.zp, - q_params.qmin, - q_params.qmax, - q_params.dtype, - ) - - # Run the op on the extracted tensor - data = node.target(*data_list, *args, **kwargs) - - if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: - quantize_op = ( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default - ) - q_params = node.meta["output_qparams"][0] - data = quantize_op( - data, - q_params.scale, - q_params.zp, - q_params.qmin, - q_params.qmax, - q_params.dtype, - ) - - insert_pos = list(node.all_input_nodes)[0] - - # Make new node the same kind as the first constant input - input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos) - persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos) - - # Create new node - with node.graph.inserting_before(insert_pos): - const_node = create_constant_placeholder( - exp_program=self.exported_program, - graph=node.graph, - kind=input_kind, - name=node.name + "_fused_const", - data=data, - persistent_buffer=persistent_buffer, - ) - - node.replace_all_uses_with(const_node) - - return True - - def call(self, graph_module): - modified = True - input_nodes_to_delete = [] - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target == torch.ops.tosa._table.default: - continue - if node.target == exir_ops.edge.aten.repeat.default: - _, multiples = node.args - # Do not fuse if the repeat creates a larger output, i.e. 
-                if any((multiple > 1 for multiple in multiples)):
-                    continue
-
-            input_nodes = node.all_input_nodes
-            input_nodes_constant = (
-                torch._export.utils.is_param(self.exported_program, input_node)
-                or torch._export.utils.is_lifted_tensor_constant(
-                    self.exported_program, input_node
-                )
-                or torch._export.utils.is_buffer(self.exported_program, input_node)
-                for input_node in input_nodes
-            )
-            input_nodes_single_users = (
-                len(input_node.users) == 1 for input_node in input_nodes
-            )
-
-            if all(input_nodes_constant) and all(input_nodes_single_users):
-                try:
-                    self.fuse_nodes(node)
-                    graph_module.recompile()  # Recompile needed to catch chains of constant ops
-                    input_nodes_to_delete.extend(input_nodes)
-                except Exception as e:
-                    logger.warning(
-                        f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
-                    )
-
-        if modified:
-            graph_module.graph.eliminate_dead_code()
-            for input_node in input_nodes_to_delete:
-                delete_constant_placeholder(self.exported_program, input_node)
-
-            graph_module = super().call(graph_module).graph_module
-
-        return PassResult(graph_module, True)
diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py
deleted file mode 100644
index 80d7293607f..00000000000
--- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import operator
-from typing import Tuple
-
-import torch
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
-from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.test_pipeline import (
-    PassPipeline,
-    TosaPipelineBI,
-)
-
-input_t = Tuple[torch.Tensor]  # Input x
-
-
-class FuseParameter(torch.nn.Module):
-    ops_before_pass = {
-        "executorch_exir_dialects_edge__ops_aten_full_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_addmm_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
-    }
-    ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}
-    ops_not_after_pass = [
-        "executorch_exir_dialects_edge__ops_aten_full_default",
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
-        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
-        "executorch_exir_dialects_edge__ops_aten_addmm_default",
-    ]
-
-    def __init__(
-        self,
-        in_features: int = 1,
-        out_features: int = 1,
-        bias: bool = True,
-    ):
-        super().__init__()
-        self.fc = torch.nn.Linear(
-            in_features=in_features,
-            out_features=out_features,
-            bias=bias,
-        )
-
-    def forward(self, x):
-        return self.fc(torch.ones(1)) + x
-
-
-class FuseBuffer(torch.nn.Module):
-    ops_before_pass = {
-        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
-    }
-    ops_after_pass = {
-        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
-    }
-    ops_not_after_pass = [
-        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
-    ]
-
-    def forward(self, x: torch.Tensor):
-        return (x + 1) * 2
-
-
-class FuseLiftedTensor(torch.nn.Module):
-    ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_select_copy_int": 1, - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} - ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_select_copy_int"] - - def __init__( - self, - ): - super().__init__() - self.lifted_tensor = torch.rand(2) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - sliced = self.lifted_tensor[0] - return operator.add(sliced, x) - - -modules = { - "fuse_parameter": FuseParameter(), - "fuse_buffer": FuseBuffer(), - "fuse_const_tensor": FuseLiftedTensor(), -} - - -@common.parametrize("module", modules) -def test_fuse_batchnorm_tosa_MI(module): - pipeline = PassPipeline[input_t]( - module=module, - test_data=(torch.rand(1),), - tosa_version="TOSA-0.80+MI", - ops_before_pass=module.ops_before_pass, - ops_after_pass=module.ops_after_pass, - ops_not_after_pass=module.ops_not_after_pass, - passes_with_exported_program=[FuseConstantOpsPass], - ) - pipeline.run() - - -@common.parametrize("module", modules) -def test_fuse_batchnorm_tosa_BI(module): - pipeline = TosaPipelineBI[input_t]( - module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True - ) - pipeline.run()