Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "common_subexpr_elim_tools.h"

#include <tvm/arith/analyzer.h>
#include <tvm/ir/transform.h> // For the class Pass and the class PassContext
#include <tvm/runtime/container/string.h>
#include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis
Expand Down Expand Up @@ -727,7 +728,10 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
// For now, we just check the syntactic equality, but that could later become a semantic test,
// for instance identifying computations modulo commutativity (like x+y and y+x), or modulo
// associativity (like (x+y)+z and x+(y+z)), etc.
return EqualTerms(a, b);
arith::Analyzer analyser;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yuanfz98 , One thing to note here is that running recursively simplification on every subexpression would result in a possibly quadratic complexity wrt to the expression size, so we would want to use it with care, perhaps only triger with limited expression length to avoid long compilation time.

In this particular case, directly running a simplification pass before the common subexpr elim would have a same effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree, I will close this PR. I will follow the ongoing development of pass dependencies and come back to this feature after. Thanks for your reply !
cc @FranckQC @zhiics

PrimExpr a_simplified = analyser.Simplify(a);
PrimExpr b_simplified = analyser.Simplify(b);
return EqualTerms(a_simplified, b_simplified);
}

/*!
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_tir_transform_common_subexpr_elim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,51 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T


@T.prim_func
def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
B = T.buffer_decl((50,), "int32")
B[i1] = x * (y + z)
B[i2] = x * y + x * z


@T.prim_func
def func_distributivity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.buffer_decl((50,), "int32")
cse_var_1 = T.var("int32")
with T.let(cse_var_1, x * (y + z)):
B[i1] = cse_var_1
B[i2] = cse_var_1


@T.prim_func
def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
B = T.buffer_decl((50,), "int32")
B[i1] = (x + y) + z
B[i2] = x + (y + z)


@T.prim_func
def func_associativity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.buffer_decl((50,), "int32")
cse_var_1 = T.var("int32")
with T.let(cse_var_1, (x + y) + z):
B[i1] = cse_var_1
B[i2] = cse_var_1


def _check(original, transformed):
func = original
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.CommonSubexprElimTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed)


# A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels
def test_cse():
Expand Down Expand Up @@ -305,8 +350,18 @@ def test_cse_cascade():
assert tvm.ir.structural_equal(store3.value, cse_var_2)


def test_semantic_equiv_distributivity():
_check(func_distributivity, func_distributivity_expected)


def test_semantic_equiv_associativity():
_check(func_associativity, func_associativity_expected)


if __name__ == "__main__":
test_cse()
test_cse_ifNode_1()
test_cse_ifNode_2()
test_cse_cascade()
test_semantic_equiv_distributivity()
test_semantic_equiv_associativity()