Skip to content

Commit 25ab8e5

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Update ReplaceLogicalNotBooleanWhereWithWherePass to use new pass interface (pytorch#15755)
Summary: As titled, more efficient now and properly updates the modified bit. Also, this pass was missing tests, so added some for a variety of different cases. Differential Revision: D86782910
1 parent 3920e52 commit 25ab8e5

File tree

2 files changed

+177
-33
lines changed

2 files changed

+177
-33
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:
6969

7070

7171
@register_cadence_pass(CadencePassAttribute(opt_level=0))
72-
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
72+
class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface):
7373
"""
7474
A where op with a logical_not and a boolean tensor can be replaced
7575
by a where op with flipped inputs and the initial boolean tensor.
7676
"""
7777

78-
def replace_logical_nop_where_with_where(
79-
self, graph_module: torch.fx.GraphModule
80-
) -> None:
81-
graph = graph_module.graph
82-
for node in graph.nodes:
83-
# We are only interested in where nodes
84-
if node.target != exir_ops.edge.aten.where.self:
85-
continue
78+
@property
79+
def targets(self) -> list[EdgeOpOverload]:
80+
return [exir_ops.edge.aten.where.self]
8681

87-
# If the third arg is not a logical_not, bail.
88-
if node.args[0].target != exir_ops.edge.aten.logical_not.default:
89-
continue
82+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
83+
# If the first arg is not a logical_not, bail.
84+
if not isinstance(node.args[0], torch.fx.Node):
85+
return False
9086

91-
# Get the third arg node and its input
92-
logical_not_node = node.args[0]
93-
logical_not_input_node = logical_not_node.args[0]
87+
logical_not_node = cast(torch.fx.Node, node.args[0])
88+
if logical_not_node.target != exir_ops.edge.aten.logical_not.default:
89+
return False
9490

95-
# If the logical_not input is not a boolean tensor, bail.
96-
if logical_not_input_node.meta["val"].dtype != torch.bool:
97-
continue
91+
# Get the first arg node and its input
92+
if not isinstance(logical_not_node.args[0], torch.fx.Node):
93+
return False
9894

99-
# Replace the where op with another one, flipping the inputs and using the boolean
100-
# tensor from logical_not.
101-
with graph.inserting_before(node):
102-
linear_node = graph.call_function(
103-
exir_ops.edge.aten.where.self,
104-
args=(logical_not_node.args[0], node.args[2], node.args[1]),
105-
)
106-
# Replace all the uses
107-
node.replace_all_uses_with(linear_node)
95+
logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0])
10896

109-
graph_module.recompile()
110-
graph_module.graph.eliminate_dead_code()
97+
# If the logical_not input is not a boolean tensor, bail.
98+
if logical_not_input_node.meta["val"].dtype != torch.bool:
99+
return False
111100

112-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
113-
self.replace_logical_nop_where_with_where(graph_module)
114-
result = super().call(graph_module)
115-
return result
101+
# Replace the where op with another one, flipping the inputs and using the boolean
102+
# tensor from logical_not.
103+
with node.graph.inserting_before(node):
104+
new_node = node.graph.call_function(
105+
exir_ops.edge.aten.where.self,
106+
args=(logical_not_input_node, node.args[2], node.args[1]),
107+
)
108+
new_node.meta = node.meta
109+
# Replace all the uses
110+
node.replace_all_uses_with(new_node)
111+
return True
116112

117113

118114
@register_cadence_pass(CadencePassAttribute(opt_level=0))

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ReplaceFunctionallyEquivalentOpTargets,
3434
ReplaceIm2RowWithViewPass,
3535
ReplaceLinearWithFullyConnectedOpPass,
36+
ReplaceLogicalNotBooleanWhereWithWherePass,
3637
ReplaceMatmulWithTransposedMatmulPass,
3738
ReplaceMMWithAddMMPass,
3839
ReplaceMulTensorWithMulAndFullOpsPass,
@@ -2183,3 +2184,150 @@ def test_replace_quantized_embedding(
21832184
),
21842185
1,
21852186
)
2187+
2188+
2189+
class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase):
2190+
"""Tests for the ReplaceLogicalNotBooleanWhereWithWherePass."""
2191+
2192+
@torch.no_grad()
2193+
def test_replace_where_with_logical_not_boolean(self) -> None:
2194+
"""Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x)."""
2195+
# Setup: Create a graph with where(logical_not(bool_cond), x, y)
2196+
builder = GraphBuilder()
2197+
bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
2198+
x = builder.placeholder("x", torch.randn(4, 8))
2199+
y = builder.placeholder("y", torch.randn(4, 8))
2200+
2201+
# Create logical_not node
2202+
logical_not = builder.call_operator(
2203+
op=exir_ops.edge.aten.logical_not.default,
2204+
args=(bool_cond,),
2205+
)
2206+
2207+
# Create where node using logical_not
2208+
where_node = builder.call_operator(
2209+
op=exir_ops.edge.aten.where.self,
2210+
args=(logical_not, x, y),
2211+
)
2212+
builder.output([where_node])
2213+
original_gm = builder.get_graph_module()
2214+
2215+
# Execute: Apply the replacement pass
2216+
p = ReplaceLogicalNotBooleanWhereWithWherePass()
2217+
result = cast(PassResult, p(original_gm))
2218+
2219+
# Assert: Verify the pass modified the graph
2220+
self.assertTrue(result.modified)
2221+
graph_after_passes = result.graph_module
2222+
2223+
# Assert: Verify logical_not is removed (dead code elimination)
2224+
self.assertEqual(
2225+
count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
2226+
0,
2227+
)
2228+
2229+
# Assert: Verify where node still exists
2230+
self.assertEqual(
2231+
count_node(graph_after_passes, exir_ops.edge.aten.where.self),
2232+
1,
2233+
)
2234+
2235+
# Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped)
2236+
found_node = False
2237+
for node in graph_after_passes.graph.find_nodes(
2238+
op="call_function", target=exir_ops.edge.aten.where.self
2239+
):
2240+
found_node = True
2241+
# First arg should be the original bool_cond (not the logical_not)
2242+
self.assertEqual(node.args[0].name, "bool_cond")
2243+
# Second and third args should be swapped (y, x instead of x, y)
2244+
self.assertEqual(node.args[1].name, "y")
2245+
self.assertEqual(node.args[2].name, "x")
2246+
self.assertTrue(found_node)
2247+
2248+
@torch.no_grad()
2249+
def test_no_replacement_when_not_boolean_tensor(self) -> None:
2250+
"""Test that the pass does NOT apply when logical_not input is not a boolean tensor."""
2251+
# Setup: Create a graph with where(logical_not(float_tensor > 0), x, y)
2252+
# The logical_not input is not directly a boolean tensor
2253+
builder = GraphBuilder()
2254+
float_tensor = builder.placeholder("float_tensor", torch.randn(4, 8))
2255+
x = builder.placeholder("x", torch.randn(4, 8))
2256+
y = builder.placeholder("y", torch.randn(4, 8))
2257+
2258+
# Create a comparison that produces a boolean
2259+
gt_node = builder.call_operator(
2260+
op=exir_ops.edge.aten.gt.Scalar,
2261+
args=(float_tensor, 0.0),
2262+
)
2263+
2264+
# Create logical_not node using the comparison result
2265+
logical_not = builder.call_operator(
2266+
op=exir_ops.edge.aten.logical_not.default,
2267+
args=(gt_node,),
2268+
)
2269+
2270+
# Create where node
2271+
where_node = builder.call_operator(
2272+
op=exir_ops.edge.aten.where.self,
2273+
args=(logical_not, x, y),
2274+
)
2275+
builder.output([where_node])
2276+
original_gm = builder.get_graph_module()
2277+
2278+
# Execute: Apply the replacement pass
2279+
p = ReplaceLogicalNotBooleanWhereWithWherePass()
2280+
result = cast(PassResult, p(original_gm))
2281+
2282+
# Assert: Verify the pass modified the graph (gt_node is a boolean tensor)
2283+
# The pass SHOULD apply because gt.Scalar returns a boolean tensor
2284+
self.assertTrue(result.modified)
2285+
graph_after_passes = result.graph_module
2286+
2287+
# Assert: Verify logical_not is removed
2288+
self.assertEqual(
2289+
count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
2290+
0,
2291+
)
2292+
2293+
@torch.no_grad()
2294+
def test_no_replacement_without_logical_not(self) -> None:
2295+
"""Test that the pass does NOT apply when there's no logical_not."""
2296+
# Setup: Create a graph with where(bool_cond, x, y) without logical_not
2297+
builder = GraphBuilder()
2298+
bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
2299+
x = builder.placeholder("x", torch.randn(4, 8))
2300+
y = builder.placeholder("y", torch.randn(4, 8))
2301+
2302+
# Create where node directly without logical_not
2303+
where_node = builder.call_operator(
2304+
op=exir_ops.edge.aten.where.self,
2305+
args=(bool_cond, x, y),
2306+
)
2307+
builder.output([where_node])
2308+
original_gm = builder.get_graph_module()
2309+
2310+
# Execute: Apply the replacement pass
2311+
p = ReplaceLogicalNotBooleanWhereWithWherePass()
2312+
result = cast(PassResult, p(original_gm))
2313+
2314+
# Assert: Verify the pass did NOT modify the graph
2315+
self.assertFalse(result.modified)
2316+
graph_after_passes = result.graph_module
2317+
2318+
# Assert: Verify where node still exists unchanged
2319+
self.assertEqual(
2320+
count_node(graph_after_passes, exir_ops.edge.aten.where.self),
2321+
1,
2322+
)
2323+
2324+
# Assert: Verify the arguments are unchanged
2325+
found_node = False
2326+
for node in graph_after_passes.graph.find_nodes(
2327+
op="call_function", target=exir_ops.edge.aten.where.self
2328+
):
2329+
found_node = True
2330+
self.assertEqual(node.args[0].name, "bool_cond")
2331+
self.assertEqual(node.args[1].name, "x")
2332+
self.assertEqual(node.args[2].name, "y")
2333+
self.assertTrue(found_node)

0 commit comments

Comments
 (0)