Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions backends/arm/test/passes/test_fuse_view_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform


class FuseSequentialViews(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x.view((1, 2, 3, 4)).view((2, 3, 4, 1)).view((2, 3, 4))

data = (torch.randn(2, 3, 1, 4),)
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 3,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 1,
}


class FuseSequentialWithNoopsViews(torch.nn.Module):
def forward(self, x: torch.Tensor):
return (
x.view((1, 2, 3, 4))
.clone()
.view((2, 3, 4, 1))
.to(dtype=torch.int32)
.view((2, 3, 4))
.abs()
.reciprocal()
.sqrt()
.view((12, 2))
)

data = (torch.randn(2, 3, 1, 4),)
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 4,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 1,
}


class DontFuseBranchingViews(torch.nn.Module):
def forward(self, x: torch.Tensor):
x = x.view((1, 2, 3, 4))
x1 = x.abs().view((2, 3, 4, 1))
x2 = x.ceil().view((2, 3, 4, 1))
return x1 + x2

data = (torch.randn(2, 3, 1, 4),)
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 3,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten_view_copy": 3,
}


tests = {
"fuse_sequential_views": FuseSequentialViews(),
"fuse_sequential_with_noops_views": FuseSequentialWithNoopsViews(),
"dont_fuse_branching_views": DontFuseBranchingViews(),
}


@common.parametrize("model", tests)
def test_fuse_view_copy(model):
pipeline = PassPipeline(
model,
model.data,
quantize=False,
ops_before_pass=model.ops_before_pass,
ops_after_pass=model.ops_after_pass,
pass_list=[FuseViewCopyTransform],
)
pipeline.run()
64 changes: 52 additions & 12 deletions backends/transforms/fuse_view_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,67 @@
from executorch.exir.pass_base import ExportPass, PassResult


UNARY_ELEMENTWISE_OPS = [
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.alias_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.round.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.silu.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.sign.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.log.default,
]


def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
"""
Find chains of view_copy nodes and merge them into one view_copy node.
Find chains of view_copy nodes and unary elementwise ops and set all
view_copy nodes to have the final shape. The views will then be removed
by the remove_noop_view_copy call.

Only merges view_copy nodes that are not used by any other nodes.
"""
ops = exir_ops.edge
view_op = ops.aten.view_copy.default
modified = False
for node in graph.nodes:
if node.op == "call_function" and node.target == view_op:
# find ending view_copy node in chain
# Find a chain of unary elementwise ops and save all view_copy nodes
end_node = node
view_ops = [node]
while (
end_node.op == "call_function"
and end_node.target == view_op
and end_node.target in UNARY_ELEMENTWISE_OPS
and len(end_node.users) == 1
and list(end_node.users)[0].target == view_op
and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS
):
end_node = list(end_node.users)[0]
# we can swap the first node's shape arg with the last node's shape arg
if node != end_node:
with graph.inserting_after(node):
new_args = (node.args[0], end_node.args[1])
if end_node.target == view_op:
view_ops.append(end_node)

# Set all view_copy nodes to have the final shape
if len(view_ops) > 1:
final_shape = view_ops[-1].args[1]
for node in view_ops:
new_args = (node.args[0], final_shape)
node.args = new_args
end_node.replace_all_uses_with(node)
modified = True

graph.eliminate_dead_code()
Expand Down Expand Up @@ -67,10 +103,14 @@ class FuseViewCopyTransform(ExportPass):
_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph)
graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)
modified = merge_modified or noop_modified
graph_module.graph, modified = merge_view_copy_chains(graph_module.graph)
if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

graph_module.graph, modified = remove_noop_view_copy(graph_module.graph)
if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
Loading