diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 218667c331a5..690dc31075c3 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -25,6 +25,7 @@ #include "common_subexpr_elim_tools.h" +#include #include // For the class Pass and the class PassContext #include #include // For the ExprDeepEqual analysis @@ -727,7 +728,10 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { // For now, we just check the syntactic equality, but that could later become a semantic test, // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo // associativity (like (x+y)+z and x+(y+z)), etc. - return EqualTerms(a, b); + arith::Analyzer analyser; + PrimExpr a_simplified = analyser.Simplify(a); + PrimExpr b_simplified = analyser.Simplify(b); + return EqualTerms(a_simplified, b_simplified); } /*! diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 17c0cbdd99c6..01c231d9629c 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -16,6 +16,51 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T + + +@T.prim_func +def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = x * (y + z) + B[i2] = x * y + x * z + + +@T.prim_func +def func_distributivity_expected( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, x * (y + z)): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + + +@T.prim_func +def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None: + B = T.buffer_decl((50,), "int32") + B[i1] = (x + y) + z + B[i2] = x + (y + z) + + +@T.prim_func +def func_associativity_expected( + i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 +) -> None: + B = T.buffer_decl((50,), "int32") + cse_var_1 = T.var("int32") + with T.let(cse_var_1, (x + y) + z): + B[i1] = cse_var_1 + B[i2] = cse_var_1 + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.CommonSubexprElimTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + # A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels def test_cse(): @@ -305,8 +350,18 @@ def test_cse_cascade(): assert tvm.ir.structural_equal(store3.value, cse_var_2) +def test_semantic_equiv_distributivity(): + _check(func_distributivity, func_distributivity_expected) + + +def test_semantic_equiv_associativity(): + _check(func_associativity, func_associativity_expected) + + if __name__ == "__main__": test_cse() test_cse_ifNode_1() test_cse_ifNode_2() test_cse_cascade() + test_semantic_equiv_distributivity() + test_semantic_equiv_associativity()