diff --git a/backends/arm/test/passes/test_fuse_view_copy.py b/backends/arm/test/passes/test_fuse_view_copy.py new file mode 100644 index 00000000000..7bf931349b6 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_view_copy.py @@ -0,0 +1,82 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform + + +class FuseSequentialViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.view((1, 2, 3, 4)).view((2, 3, 4, 1)).view((2, 3, 4)) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class FuseSequentialWithNoopsViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + return ( + x.view((1, 2, 3, 4)) + .clone() + .view((2, 3, 4, 1)) + .to(dtype=torch.int32) + .view((2, 3, 4)) + .abs() + .reciprocal() + .sqrt() + .view((12, 2)) + ) + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 4, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 1, + } + + +class DontFuseBranchingViews(torch.nn.Module): + def forward(self, x: torch.Tensor): + x = x.view((1, 2, 3, 4)) + x1 = x.abs().view((2, 3, 4, 1)) + x2 = x.ceil().view((2, 3, 4, 1)) + return x1 + x2 + + data = (torch.randn(2, 3, 1, 4),) + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_view_copy": 3, + } + + +tests = { + "fuse_sequential_views": FuseSequentialViews(), + "fuse_sequential_with_noops_views": FuseSequentialWithNoopsViews(), + "dont_fuse_branching_views": DontFuseBranchingViews(), +} + + +@common.parametrize("model", tests) +def test_fuse_view_copy(model): + pipeline = PassPipeline( + model, + model.data, + quantize=False, + ops_before_pass=model.ops_before_pass, + ops_after_pass=model.ops_after_pass, + pass_list=[FuseViewCopyTransform], + ) + pipeline.run() diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py index 1972513d2ef..b7c52f95fa3 100644 --- a/backends/transforms/fuse_view_copy.py +++ b/backends/transforms/fuse_view_copy.py @@ -14,9 +14,41 @@ from executorch.exir.pass_base import ExportPass, PassResult +UNARY_ELEMENTWISE_OPS = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.neg.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.round.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.silu.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.sign.default, + exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.log.default, +] + + def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: """ - Find chains of view_copy nodes and merge them into one view_copy node. + Find chains of view_copy nodes and unary elementwise ops and set all + view_copy nodes to have the final shape. The views will then be removed + by the remove_noop_view_copy call. + Only merges view_copy nodes that are not used by any other nodes. """ ops = exir_ops.edge @@ -24,21 +56,25 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool] modified = False for node in graph.nodes: if node.op == "call_function" and node.target == view_op: - # find ending view_copy node in chain + # Find a chain of unary elementwise ops and save all view_copy nodes end_node = node + view_ops = [node] while ( end_node.op == "call_function" - and end_node.target == view_op + and end_node.target in UNARY_ELEMENTWISE_OPS and len(end_node.users) == 1 - and list(end_node.users)[0].target == view_op + and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS ): end_node = list(end_node.users)[0] - # we can swap the first node's shape arg with the last node's shape arg - if node != end_node: - with graph.inserting_after(node): - new_args = (node.args[0], end_node.args[1]) + if end_node.target == view_op: + view_ops.append(end_node) + + # Set all view_copy nodes to have the final shape + if len(view_ops) > 1: + final_shape = view_ops[-1].args[1] + for node in view_ops: + new_args = (node.args[0], final_shape) node.args = new_args - end_node.replace_all_uses_with(node) modified = True graph.eliminate_dead_code() @@ -67,10 +103,14 @@ class FuseViewCopyTransform(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph) - graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph) - modified = merge_modified or noop_modified + graph_module.graph, modified = merge_view_copy_chains(graph_module.graph) if modified: graph_module.recompile() graph_module = super().call(graph_module).graph_module + + graph_module.graph, modified = remove_noop_view_copy(graph_module.graph) + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified)