|
33 | 33 | ReplaceFunctionallyEquivalentOpTargets, |
34 | 34 | ReplaceIm2RowWithViewPass, |
35 | 35 | ReplaceLinearWithFullyConnectedOpPass, |
| 36 | + ReplaceLogicalNotBooleanWhereWithWherePass, |
36 | 37 | ReplaceMatmulWithTransposedMatmulPass, |
37 | 38 | ReplaceMMWithAddMMPass, |
38 | 39 | ReplaceMulTensorWithMulAndFullOpsPass, |
@@ -2183,3 +2184,150 @@ def test_replace_quantized_embedding( |
2183 | 2184 | ), |
2184 | 2185 | 1, |
2185 | 2186 | ) |
| 2187 | + |
| 2188 | + |
| 2189 | +class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): |
| 2190 | + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" |
| 2191 | + |
| 2192 | + @torch.no_grad() |
| 2193 | + def test_replace_where_with_logical_not_boolean(self) -> None: |
| 2194 | + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" |
| 2195 | + # Setup: Create a graph with where(logical_not(bool_cond), x, y) |
| 2196 | + builder = GraphBuilder() |
| 2197 | + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) |
| 2198 | + x = builder.placeholder("x", torch.randn(4, 8)) |
| 2199 | + y = builder.placeholder("y", torch.randn(4, 8)) |
| 2200 | + |
| 2201 | + # Create logical_not node |
| 2202 | + logical_not = builder.call_operator( |
| 2203 | + op=exir_ops.edge.aten.logical_not.default, |
| 2204 | + args=(bool_cond,), |
| 2205 | + ) |
| 2206 | + |
| 2207 | + # Create where node using logical_not |
| 2208 | + where_node = builder.call_operator( |
| 2209 | + op=exir_ops.edge.aten.where.self, |
| 2210 | + args=(logical_not, x, y), |
| 2211 | + ) |
| 2212 | + builder.output([where_node]) |
| 2213 | + original_gm = builder.get_graph_module() |
| 2214 | + |
| 2215 | + # Execute: Apply the replacement pass |
| 2216 | + p = ReplaceLogicalNotBooleanWhereWithWherePass() |
| 2217 | + result = cast(PassResult, p(original_gm)) |
| 2218 | + |
| 2219 | + # Assert: Verify the pass modified the graph |
| 2220 | + self.assertTrue(result.modified) |
| 2221 | + graph_after_passes = result.graph_module |
| 2222 | + |
| 2223 | + # Assert: Verify logical_not is removed (dead code elimination) |
| 2224 | + self.assertEqual( |
| 2225 | + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), |
| 2226 | + 0, |
| 2227 | + ) |
| 2228 | + |
| 2229 | + # Assert: Verify where node still exists |
| 2230 | + self.assertEqual( |
| 2231 | + count_node(graph_after_passes, exir_ops.edge.aten.where.self), |
| 2232 | + 1, |
| 2233 | + ) |
| 2234 | + |
| 2235 | + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) |
| 2236 | + found_node = False |
| 2237 | + for node in graph_after_passes.graph.find_nodes( |
| 2238 | + op="call_function", target=exir_ops.edge.aten.where.self |
| 2239 | + ): |
| 2240 | + found_node = True |
| 2241 | + # First arg should be the original bool_cond (not the logical_not) |
| 2242 | + self.assertEqual(node.args[0].name, "bool_cond") |
| 2243 | + # Second and third args should be swapped (y, x instead of x, y) |
| 2244 | + self.assertEqual(node.args[1].name, "y") |
| 2245 | + self.assertEqual(node.args[2].name, "x") |
| 2246 | + self.assertTrue(found_node) |
| 2247 | + |
| 2248 | + @torch.no_grad() |
| 2249 | + def test_no_replacement_when_not_boolean_tensor(self) -> None: |
| 2250 | + """Test that the pass does NOT apply when logical_not input is not a boolean tensor.""" |
| 2251 | + # Setup: Create a graph with where(logical_not(float_tensor > 0), x, y) |
| 2252 | + # The logical_not input is not directly a boolean tensor |
| 2253 | + builder = GraphBuilder() |
| 2254 | + float_tensor = builder.placeholder("float_tensor", torch.randn(4, 8)) |
| 2255 | + x = builder.placeholder("x", torch.randn(4, 8)) |
| 2256 | + y = builder.placeholder("y", torch.randn(4, 8)) |
| 2257 | + |
| 2258 | + # Create a comparison that produces a boolean |
| 2259 | + gt_node = builder.call_operator( |
| 2260 | + op=exir_ops.edge.aten.gt.Scalar, |
| 2261 | + args=(float_tensor, 0.0), |
| 2262 | + ) |
| 2263 | + |
| 2264 | + # Create logical_not node using the comparison result |
| 2265 | + logical_not = builder.call_operator( |
| 2266 | + op=exir_ops.edge.aten.logical_not.default, |
| 2267 | + args=(gt_node,), |
| 2268 | + ) |
| 2269 | + |
| 2270 | + # Create where node |
| 2271 | + where_node = builder.call_operator( |
| 2272 | + op=exir_ops.edge.aten.where.self, |
| 2273 | + args=(logical_not, x, y), |
| 2274 | + ) |
| 2275 | + builder.output([where_node]) |
| 2276 | + original_gm = builder.get_graph_module() |
| 2277 | + |
| 2278 | + # Execute: Apply the replacement pass |
| 2279 | + p = ReplaceLogicalNotBooleanWhereWithWherePass() |
| 2280 | + result = cast(PassResult, p(original_gm)) |
| 2281 | + |
| 2282 | + # Assert: Verify the pass modified the graph (gt_node is a boolean tensor) |
| 2283 | + # The pass SHOULD apply because gt.Scalar returns a boolean tensor |
| 2284 | + self.assertTrue(result.modified) |
| 2285 | + graph_after_passes = result.graph_module |
| 2286 | + |
| 2287 | + # Assert: Verify logical_not is removed |
| 2288 | + self.assertEqual( |
| 2289 | + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), |
| 2290 | + 0, |
| 2291 | + ) |
| 2292 | + |
| 2293 | + @torch.no_grad() |
| 2294 | + def test_no_replacement_without_logical_not(self) -> None: |
| 2295 | + """Test that the pass does NOT apply when there's no logical_not.""" |
| 2296 | + # Setup: Create a graph with where(bool_cond, x, y) without logical_not |
| 2297 | + builder = GraphBuilder() |
| 2298 | + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) |
| 2299 | + x = builder.placeholder("x", torch.randn(4, 8)) |
| 2300 | + y = builder.placeholder("y", torch.randn(4, 8)) |
| 2301 | + |
| 2302 | + # Create where node directly without logical_not |
| 2303 | + where_node = builder.call_operator( |
| 2304 | + op=exir_ops.edge.aten.where.self, |
| 2305 | + args=(bool_cond, x, y), |
| 2306 | + ) |
| 2307 | + builder.output([where_node]) |
| 2308 | + original_gm = builder.get_graph_module() |
| 2309 | + |
| 2310 | + # Execute: Apply the replacement pass |
| 2311 | + p = ReplaceLogicalNotBooleanWhereWithWherePass() |
| 2312 | + result = cast(PassResult, p(original_gm)) |
| 2313 | + |
| 2314 | + # Assert: Verify the pass did NOT modify the graph |
| 2315 | + self.assertFalse(result.modified) |
| 2316 | + graph_after_passes = result.graph_module |
| 2317 | + |
| 2318 | + # Assert: Verify where node still exists unchanged |
| 2319 | + self.assertEqual( |
| 2320 | + count_node(graph_after_passes, exir_ops.edge.aten.where.self), |
| 2321 | + 1, |
| 2322 | + ) |
| 2323 | + |
| 2324 | + # Assert: Verify the arguments are unchanged |
| 2325 | + found_node = False |
| 2326 | + for node in graph_after_passes.graph.find_nodes( |
| 2327 | + op="call_function", target=exir_ops.edge.aten.where.self |
| 2328 | + ): |
| 2329 | + found_node = True |
| 2330 | + self.assertEqual(node.args[0].name, "bool_cond") |
| 2331 | + self.assertEqual(node.args[1].name, "x") |
| 2332 | + self.assertEqual(node.args[2].name, "y") |
| 2333 | + self.assertTrue(found_node) |
0 commit comments