Skip to content

Commit 47c08d9

Browse files
authored
Update ReplaceLogicalNotBooleanWhereWithWherePass to use new pass interface
Differential Revision: D86782910 Pull Request resolved: #15755
1 parent bee30ac commit 47c08d9

File tree

2 files changed

+141
-33
lines changed

2 files changed

+141
-33
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:
6969

7070

7171
@register_cadence_pass(CadencePassAttribute(opt_level=0))
72-
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
72+
class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface):
7373
"""
7474
A where op with a logical_not and a boolean tensor can be replaced
7575
by a where op with flipped inputs and the initial boolean tensor.
7676
"""
7777

78-
def replace_logical_nop_where_with_where(
79-
self, graph_module: torch.fx.GraphModule
80-
) -> None:
81-
graph = graph_module.graph
82-
for node in graph.nodes:
83-
# We are only interested in where nodes
84-
if node.target != exir_ops.edge.aten.where.self:
85-
continue
78+
@property
79+
def targets(self) -> list[EdgeOpOverload]:
80+
return [exir_ops.edge.aten.where.self]
8681

87-
# If the third arg is not a logical_not, bail.
88-
if node.args[0].target != exir_ops.edge.aten.logical_not.default:
89-
continue
82+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
83+
# If the first arg is not a logical_not, bail.
84+
if not isinstance(node.args[0], torch.fx.Node):
85+
return False
9086

91-
# Get the third arg node and its input
92-
logical_not_node = node.args[0]
93-
logical_not_input_node = logical_not_node.args[0]
87+
logical_not_node = cast(torch.fx.Node, node.args[0])
88+
if logical_not_node.target != exir_ops.edge.aten.logical_not.default:
89+
return False
9490

95-
# If the logical_not input is not a boolean tensor, bail.
96-
if logical_not_input_node.meta["val"].dtype != torch.bool:
97-
continue
91+
# Get the first arg node and its input
92+
if not isinstance(logical_not_node.args[0], torch.fx.Node):
93+
return False
9894

99-
# Replace the where op with another one, flipping the inputs and using the boolean
100-
# tensor from logical_not.
101-
with graph.inserting_before(node):
102-
linear_node = graph.call_function(
103-
exir_ops.edge.aten.where.self,
104-
args=(logical_not_node.args[0], node.args[2], node.args[1]),
105-
)
106-
# Replace all the uses
107-
node.replace_all_uses_with(linear_node)
95+
logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0])
10896

109-
graph_module.recompile()
110-
graph_module.graph.eliminate_dead_code()
97+
# If the logical_not input is not a boolean tensor, bail.
98+
if logical_not_input_node.meta["val"].dtype != torch.bool:
99+
return False
111100

112-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
113-
self.replace_logical_nop_where_with_where(graph_module)
114-
result = super().call(graph_module)
115-
return result
101+
# Replace the where op with another one, flipping the inputs and using the boolean
102+
# tensor from logical_not.
103+
with node.graph.inserting_before(node):
104+
new_node = node.graph.call_function(
105+
exir_ops.edge.aten.where.self,
106+
args=(logical_not_input_node, node.args[2], node.args[1]),
107+
)
108+
new_node.meta = node.meta
109+
# Replace all the uses
110+
node.replace_all_uses_with(new_node)
111+
return True
116112

117113

118114
@register_cadence_pass(CadencePassAttribute(opt_level=0))

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ReplaceFunctionallyEquivalentOpTargets,
3434
ReplaceIm2RowWithViewPass,
3535
ReplaceLinearWithFullyConnectedOpPass,
36+
ReplaceLogicalNotBooleanWhereWithWherePass,
3637
ReplaceMatmulWithTransposedMatmulPass,
3738
ReplaceMMWithAddMMPass,
3839
ReplaceMulTensorWithMulAndFullOpsPass,
@@ -2053,3 +2054,114 @@ def test_replace_quantized_embedding(
20532054
),
20542055
1,
20552056
)
2057+
2058+
2059+
class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase):
2060+
"""Tests for the ReplaceLogicalNotBooleanWhereWithWherePass."""
2061+
2062+
def test_replace_where_with_logical_not_boolean(self) -> None:
2063+
"""Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x)."""
2064+
# Setup: Create a graph with where(logical_not(bool_cond), x, y)
2065+
builder = GraphBuilder()
2066+
bool_cond_ = torch.randn(4, 8) > 0
2067+
x_ = torch.randn(4, 8)
2068+
y_ = torch.randn(4, 8)
2069+
2070+
bool_cond = builder.placeholder("bool_cond", bool_cond_)
2071+
x = builder.placeholder("x", x_)
2072+
y = builder.placeholder("y", y_)
2073+
2074+
# Create logical_not node
2075+
logical_not = builder.call_operator(
2076+
op=exir_ops.edge.aten.logical_not.default,
2077+
args=(bool_cond,),
2078+
)
2079+
2080+
# Create where node using logical_not
2081+
where_node = builder.call_operator(
2082+
op=exir_ops.edge.aten.where.self,
2083+
args=(logical_not, x, y),
2084+
)
2085+
builder.output([where_node])
2086+
original_gm = builder.get_graph_module()
2087+
2088+
# Make a copy of the original graph before applying the pass
2089+
original_gm_copy = copy.deepcopy(original_gm)
2090+
2091+
# Execute: Apply the replacement pass
2092+
p = ReplaceLogicalNotBooleanWhereWithWherePass()
2093+
result = cast(PassResult, p(original_gm))
2094+
2095+
# Assert: Verify the pass modified the graph
2096+
self.assertTrue(result.modified)
2097+
graph_after_passes = result.graph_module
2098+
2099+
# Assert: Verify logical_not is removed (dead code elimination)
2100+
self.assertEqual(
2101+
count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
2102+
0,
2103+
)
2104+
2105+
# Assert: Verify where node still exists
2106+
self.assertEqual(
2107+
count_node(graph_after_passes, exir_ops.edge.aten.where.self),
2108+
1,
2109+
)
2110+
2111+
# Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped)
2112+
where_nodes = list(
2113+
graph_after_passes.graph.find_nodes(
2114+
op="call_function", target=exir_ops.edge.aten.where.self
2115+
)
2116+
)
2117+
for node in where_nodes:
2118+
# First arg should be the original bool_cond (not the logical_not)
2119+
self.assertEqual(node.args[0].name, "bool_cond")
2120+
# Second and third args should be swapped (y, x instead of x, y)
2121+
self.assertEqual(node.args[1].name, "y")
2122+
self.assertEqual(node.args[2].name, "x")
2123+
2124+
# Assert: Verify outputs match exactly by running both graphs
2125+
validate(
2126+
original_gm_copy,
2127+
graph_after_passes,
2128+
(bool_cond_, x_, y_),
2129+
"ReplaceLogicalNotBooleanWhereWithWherePass",
2130+
)
2131+
2132+
def test_no_replacement_without_logical_not(self) -> None:
2133+
"""Test that the pass does NOT apply when there's no logical_not."""
2134+
# Setup: Create a graph with where(bool_cond, x, y) without logical_not
2135+
builder = GraphBuilder()
2136+
bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
2137+
x = builder.placeholder("x", torch.randn(4, 8))
2138+
y = builder.placeholder("y", torch.randn(4, 8))
2139+
2140+
# Create where node directly without logical_not
2141+
where_node = builder.call_operator(
2142+
op=exir_ops.edge.aten.where.self,
2143+
args=(bool_cond, x, y),
2144+
)
2145+
builder.output([where_node])
2146+
original_gm = builder.get_graph_module()
2147+
2148+
# Execute: Apply the replacement pass
2149+
p = ReplaceLogicalNotBooleanWhereWithWherePass()
2150+
result = cast(PassResult, p(original_gm))
2151+
2152+
# Assert: Verify the pass did NOT modify the graph
2153+
self.assertFalse(result.modified)
2154+
graph_after_passes = result.graph_module
2155+
2156+
# Assert: Verify where node still exists unchanged
2157+
self.assertEqual(
2158+
count_node(graph_after_passes, exir_ops.edge.aten.where.self),
2159+
1,
2160+
)
2161+
2162+
for node in graph_after_passes.graph.find_nodes(
2163+
op="call_function", target=exir_ops.edge.aten.where.self
2164+
):
2165+
self.assertEqual(node.args[0].name, "bool_cond")
2166+
self.assertEqual(node.args[1].name, "x")
2167+
self.assertEqual(node.args[2].name, "y")

0 commit comments

Comments
 (0)