|
33 | 33 | ReplaceFunctionallyEquivalentOpTargets, |
34 | 34 | ReplaceIm2RowWithViewPass, |
35 | 35 | ReplaceLinearWithFullyConnectedOpPass, |
| 36 | + ReplaceLogicalNotBooleanWhereWithWherePass, |
36 | 37 | ReplaceMatmulWithTransposedMatmulPass, |
37 | 38 | ReplaceMMWithAddMMPass, |
38 | 39 | ReplaceMulTensorWithMulAndFullOpsPass, |
@@ -2053,3 +2054,112 @@ def test_replace_quantized_embedding( |
2053 | 2054 | ), |
2054 | 2055 | 1, |
2055 | 2056 | ) |
| 2057 | + |
| 2058 | + |
| 2059 | +class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): |
| 2060 | + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" |
| 2061 | + |
| 2062 | + def test_replace_where_with_logical_not_boolean(self) -> None: |
| 2063 | + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" |
| 2064 | + # Setup: Create a graph with where(logical_not(bool_cond), x, y) |
| 2065 | + builder = GraphBuilder() |
| 2066 | + bool_cond_ = torch.randn(4, 8) > 0 |
| 2067 | + x_ = torch.randn(4, 8) |
| 2068 | + y_ = torch.randn(4, 8) |
| 2069 | + |
| 2070 | + bool_cond = builder.placeholder("bool_cond", bool_cond_) |
| 2071 | + x = builder.placeholder("x", x_) |
| 2072 | + y = builder.placeholder("y", y_) |
| 2073 | + |
| 2074 | + # Create logical_not node |
| 2075 | + logical_not = builder.call_operator( |
| 2076 | + op=exir_ops.edge.aten.logical_not.default, |
| 2077 | + args=(bool_cond,), |
| 2078 | + ) |
| 2079 | + |
| 2080 | + # Create where node using logical_not |
| 2081 | + where_node = builder.call_operator( |
| 2082 | + op=exir_ops.edge.aten.where.self, |
| 2083 | + args=(logical_not, x, y), |
| 2084 | + ) |
| 2085 | + builder.output([where_node]) |
| 2086 | + original_gm = builder.get_graph_module() |
| 2087 | + |
| 2088 | + # Make a copy of the original graph before applying the pass |
| 2089 | + original_gm_copy = copy.deepcopy(original_gm) |
| 2090 | + |
| 2091 | + # Execute: Apply the replacement pass |
| 2092 | + p = ReplaceLogicalNotBooleanWhereWithWherePass() |
| 2093 | + result = cast(PassResult, p(original_gm)) |
| 2094 | + |
| 2095 | + # Assert: Verify the pass modified the graph |
| 2096 | + self.assertTrue(result.modified) |
| 2097 | + graph_after_passes = result.graph_module |
| 2098 | + |
| 2099 | + # Assert: Verify logical_not is removed (dead code elimination) |
| 2100 | + self.assertEqual( |
| 2101 | + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), |
| 2102 | + 0, |
| 2103 | + ) |
| 2104 | + |
| 2105 | + # Assert: Verify where node still exists |
| 2106 | + self.assertEqual( |
| 2107 | + count_node(graph_after_passes, exir_ops.edge.aten.where.self), |
| 2108 | + 1, |
| 2109 | + ) |
| 2110 | + |
| 2111 | + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) |
| 2112 | + where_nodes = list(graph_after_passes.graph.find_nodes( |
| 2113 | + op="call_function", target=exir_ops.edge.aten.where.self |
| 2114 | + )) |
| 2115 | + for node in where_nodes: |
| 2116 | + # First arg should be the original bool_cond (not the logical_not) |
| 2117 | + self.assertEqual(node.args[0].name, "bool_cond") |
| 2118 | + # Second and third args should be swapped (y, x instead of x, y) |
| 2119 | + self.assertEqual(node.args[1].name, "y") |
| 2120 | + self.assertEqual(node.args[2].name, "x") |
| 2121 | + |
| 2122 | + # Assert: Verify outputs match exactly by running both graphs |
| 2123 | + validate( |
| 2124 | + original_gm_copy, |
| 2125 | + graph_after_passes, |
| 2126 | + (bool_cond_, x_, y_), |
| 2127 | + "ReplaceLogicalNotBooleanWhereWithWherePass", |
| 2128 | + ) |
| 2129 | + |
| 2130 | + def test_no_replacement_without_logical_not(self) -> None: |
| 2131 | + """Test that the pass does NOT apply when there's no logical_not.""" |
| 2132 | + # Setup: Create a graph with where(bool_cond, x, y) without logical_not |
| 2133 | + builder = GraphBuilder() |
| 2134 | + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) |
| 2135 | + x = builder.placeholder("x", torch.randn(4, 8)) |
| 2136 | + y = builder.placeholder("y", torch.randn(4, 8)) |
| 2137 | + |
| 2138 | + # Create where node directly without logical_not |
| 2139 | + where_node = builder.call_operator( |
| 2140 | + op=exir_ops.edge.aten.where.self, |
| 2141 | + args=(bool_cond, x, y), |
| 2142 | + ) |
| 2143 | + builder.output([where_node]) |
| 2144 | + original_gm = builder.get_graph_module() |
| 2145 | + |
| 2146 | + # Execute: Apply the replacement pass |
| 2147 | + p = ReplaceLogicalNotBooleanWhereWithWherePass() |
| 2148 | + result = cast(PassResult, p(original_gm)) |
| 2149 | + |
| 2150 | + # Assert: Verify the pass did NOT modify the graph |
| 2151 | + self.assertFalse(result.modified) |
| 2152 | + graph_after_passes = result.graph_module |
| 2153 | + |
| 2154 | + # Assert: Verify where node still exists unchanged |
| 2155 | + self.assertEqual( |
| 2156 | + count_node(graph_after_passes, exir_ops.edge.aten.where.self), |
| 2157 | + 1, |
| 2158 | + ) |
| 2159 | + |
| 2160 | + for node in graph_after_passes.graph.find_nodes( |
| 2161 | + op="call_function", target=exir_ops.edge.aten.where.self |
| 2162 | + ): |
| 2163 | + self.assertEqual(node.args[0].name, "bool_cond") |
| 2164 | + self.assertEqual(node.args[1].name, "x") |
| 2165 | + self.assertEqual(node.args[2].name, "y") |
0 commit comments