Commit 53136fa

GregoryComer authored and facebook-github-bot committed
Remove no-op clones (#15838)
Summary: Pull Request resolved: #15838
Differential Revision: D86588171
Pulled By: GregoryComer
1 parent 101e915 commit 53136fa
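
For context on what counts as a no-op here: a clone is removable when it changes neither values nor memory layout, i.e. no explicit memory_format (or torch.preserve_format), or the input is already laid out in the requested format. A minimal, self-contained sketch of that condition on plain tensors (clone_is_noop is an illustrative name, not part of this commit):

import torch

def clone_is_noop(x: torch.Tensor, memory_format=None) -> bool:
    # Mirrors the condition added to _should_remove_node below: remove the
    # clone when no layout change is requested, or when the input already
    # satisfies the requested memory format.
    fmt = memory_format or torch.preserve_format
    return fmt == torch.preserve_format or x.is_contiguous(memory_format=fmt)

x = torch.randn(2, 3, 4, 5)                       # contiguous NCHW tensor
print(clone_is_noop(x))                           # True: values and layout unchanged
print(clone_is_noop(x, torch.channels_last))      # False: layout would change
x_cl = x.to(memory_format=torch.channels_last)
print(clone_is_noop(x_cl, torch.channels_last))   # True: already channels_last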

File tree (4 files changed, +71 −54 lines):

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py
backends/transforms/test/test_remove_clone_ops.py
exir/passes/remove_noop_pass.py
exir/tests/test_passes.py


backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def forward(self, x):
         return self.block(x)
 
 
+@unittest.skip("Clones are optimized out of the graph.")
 class TestCloneConverter(unittest.TestCase):
     __test__ = False  # Prevent interfering with PyTest tests

backends/transforms/test/test_remove_clone_ops.py

Lines changed: 2 additions & 28 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import unittest
 
 import torch
@@ -164,34 +166,6 @@ def test_clone_non_identity_survives(self):
         assert torch.allclose(actual, expected)
         assert is_channel_last_dim_order(actual)
 
-    def test_clone_identity_removed(self):
-        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
-
-        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
-            model = SimpleCloneChannelsLastModule()
-            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
-
-            exported = export(model.eval(), (x,), strict=True)
-            before_epm = to_edge(
-                exported,
-                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
-            )
-
-            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
-                before_epm.exported_program().graph_module.code
-            )
-
-            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
-
-            FileCheck().check_not(clone_op_str).run(
-                updated_epm.exported_program().graph_module.code
-            )
-
-            expected = before_epm.exported_program().module()(x)
-            actual = updated_epm.exported_program().module()(x)
-            assert torch.allclose(actual, expected)
-            assert is_channel_last_dim_order(actual)
-
 
 if __name__ == "__main__":
     unittest.main()

exir/passes/remove_noop_pass.py

Lines changed: 36 additions & 26 deletions
@@ -56,35 +56,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
         dequant_nodes = []
 
         for node in graph_module.graph.nodes:
-            if node.op != "call_function":
-                continue
-
-            if node.target not in (
-                torch.ops.aten.to.dtype,
-                torch.ops.aten.dropout.default,
-                torch.ops.aten.slice_copy.Tensor,
-            ):
-                continue
-
-            orig_tensor = node.args[0].meta["val"]
-
-            if orig_tensor is node.meta["val"]:
-                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                # Otherwise, removing only the op will suffice.
+            if RemoveNoopPass._should_remove_node(node):
                 if node.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [node.args[0]]
                 node.replace_all_uses_with(node.args[0])
-                continue
-
-            if node.target == torch.ops.aten.slice_copy.Tensor:
-                # Only do this check if all the dims are static.
-                if all(isinstance(dim, int) for dim in orig_tensor.size()):
-                    if orig_tensor.shape == node.meta["val"].shape:
-                        # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
-                        # Otherwise, removing only the op will suffice.
-                        if node.args[0].target in _DEQUANT_OPS:
-                            dequant_nodes += [node.args[0]]
-                        node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
         eliminate_dq_q(graph_module, dequant_nodes)
@@ -93,6 +68,41 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
         return PassResult(graph_module, True)
 
+    @staticmethod
+    def _should_remove_node(node: torch.fx.Node) -> bool:
+        if node.op != "call_function":
+            return False
+
+        input_meta_val = (
+            node.args[0].meta.get("val", None)
+            if len(node.args) > 0 and hasattr(node.args[0], "meta")
+            else None
+        )
+
+        if input_meta_val is not None:
+            if node.target in (
+                torch.ops.aten.to.dtype,
+                torch.ops.aten.dropout.default,
+            ):
+                return input_meta_val is node.meta["val"]
+            elif node.target == torch.ops.aten.slice_copy.Tensor:
+                # Only do this check if all the dims are static.
+                return (
+                    all(isinstance(dim, int) for dim in input_meta_val.size())
+                    and input_meta_val.shape == node.meta["val"].shape
+                )
+            elif node.target == torch.ops.aten.clone.default:
+                # Remove if memory_format=None, preserve_format, or input already has the target memory format.
+                dest_memory_format = (
+                    node.kwargs.get("memory_format", None) or torch.preserve_format
+                )
+                return (
+                    dest_memory_format == torch.preserve_format
+                    or input_meta_val.is_contiguous(memory_format=dest_memory_format)
+                )
+
+        return False
+
 
 class RemoveToCopyPass(ExportPass):
     """

exir/tests/test_passes.py

Lines changed: 32 additions & 0 deletions
@@ -2093,3 +2093,35 @@ def forward(self, x):
             prop_tensor.is_contiguous(),
             f"Propagated tensor is not contiguous: {prop_tensor.stride()}",
         )
+
+    def test_remove_noop_pass_clone(self) -> None:
+        """
+        Verify the no-op clones are removed from the graph.
+        """
+
+        class CloneModel(torch.nn.Module):
+            def forward(self, x):
+                return x.clone() + x.clone()
+
+        model = CloneModel()
+        inputs = (torch.randn(1, 16),)
+
+        ep = torch.export.export(model, inputs)
+        lowered = to_edge_transform_and_lower(ep)
+
+        # Sanity check the test - we should see clones in the exported program
+        self.assertTrue(
+            any(
+                n.op == "call_function" and n.target == torch.ops.aten.clone.default
+                for n in ep.graph.nodes
+            )
+        )
+
+        # Since the clone ops are no-ops, they should be gone.
+        self.assertFalse(
+            any(
+                n.op == "call_function"
+                and n.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
+                for n in lowered.exported_program().graph.nodes
+            )
+        )
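
A follow-up note on the new assertions: before lowering, clones appear as torch.ops.aten.clone.default, while in the lowered program a surviving clone would show up as the edge dim-order op _clone_dim_order, which is presumably why the two checks target different operators. The same pattern can be factored into a small helper (graph_has_op is an illustrative name; the exir_ops import path is assumed):

import torch
from executorch.exir.dialects._ops import ops as exir_ops  # import path assumed

def graph_has_op(graph: torch.fx.Graph, target) -> bool:
    # True if any call_function node in the graph calls the given target.
    return any(n.op == "call_function" and n.target == target for n in graph.nodes)

# Mirrors the test above, with ep / lowered produced the same way:
# assert graph_has_op(ep.graph, torch.ops.aten.clone.default)
# assert not graph_has_op(
#     lowered.exported_program().graph,
#     exir_ops.edge.dim_order_ops._clone_dim_order.default,
# )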
