Skip to content

Commit e9b7170

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Update ReplaceSqueezeAndUnsqueezeWithViewPass to use new pass interface (#15757)
Summary: As titled, now it is more efficient and correctly updates the modified bit. Updated tests, too Differential Revision: D86785126
1 parent 64ee840 commit e9b7170

File tree

2 files changed

+34
-24
lines changed

2 files changed

+34
-24
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -193,39 +193,39 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
193193

194194

195195
@register_cadence_pass(CadencePassAttribute(opt_level=0))
196-
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
196+
class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface):
197197
"""
198198
When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
199199
view_copy op
200200
"""
201201

202-
def call_operator(
203-
self,
204-
op,
205-
args: Tuple[Argument, ...],
206-
kwargs: Dict[str, Argument],
207-
meta: NodeMetadata,
208-
) -> ProxyValue:
209-
# Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
210-
# which allows us to cover all overloads.
211-
if get_edge_overload_packet(op) not in {
212-
exir_ops.edge.aten.squeeze_copy,
213-
exir_ops.edge.aten.unsqueeze_copy,
214-
}:
215-
return super().call_operator(op, args, kwargs, meta)
202+
@property
203+
def targets(self) -> list[EdgeOpOverload]:
204+
return [
205+
exir_ops.edge.aten.squeeze_copy.default,
206+
exir_ops.edge.aten.squeeze_copy.dim,
207+
exir_ops.edge.aten.squeeze_copy.dims,
208+
exir_ops.edge.aten.unsqueeze_copy.default,
209+
]
210+
211+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
216212
# Get the output tensor shape
217-
out_shape = meta["val"].shape
213+
out_shape = node.meta["val"].shape
218214

219215
# Bail out if any dim is not an int (dynamic shape)
220216
for dim in list(out_shape):
221217
if not isinstance(dim, int):
222-
return super().call_operator(op, args, kwargs, meta)
218+
return False
223219

224-
# Return a view op with the new shape
225-
view_args = (args[0], list(out_shape))
226-
return super().call_operator(
227-
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
228-
)
220+
# Replace with view op with the new shape
221+
with node.graph.inserting_before(node):
222+
new_node = node.graph.call_function(
223+
exir_ops.edge.aten.view_copy.default,
224+
args=(node.args[0], list(out_shape)),
225+
)
226+
new_node.meta = node.meta
227+
node.replace_all_uses_with(new_node)
228+
return True
229229

230230

231231
@register_cadence_pass(CadencePassAttribute(opt_level=0))

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,12 @@ def test_replace_squeeze_with_view(
989989
args=(x,),
990990
)
991991
p = ReplaceSqueezeAndUnsqueezeWithViewPass()
992-
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
992+
result = cast(PassResult, p(original_gm))
993+
994+
# Assert: Verify the pass modified the graph
995+
self.assertTrue(result.modified)
996+
graph_after_passes = result.graph_module
997+
993998
self.assertIsNotNone(graph_after_passes)
994999
self.assertEqual(
9951000
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -1024,7 +1029,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None:
10241029
args=(x, dim),
10251030
)
10261031
p = ReplaceSqueezeAndUnsqueezeWithViewPass()
1027-
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
1032+
result = cast(PassResult, p(original_gm))
1033+
1034+
# Assert: Verify the pass modified the graph
1035+
self.assertTrue(result.modified)
1036+
graph_after_passes = result.graph_module
1037+
10281038
self.assertIsNotNone(graph_after_passes)
10291039
self.assertEqual(
10301040
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),

0 commit comments

Comments
 (0)