Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<PrimExpr> lanes;

// Vector rules
if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
Expand Down Expand Up @@ -697,9 +698,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
} else if (op->dtype.is_float()) {
} else {
// Cancellation rules. Deliberately off of the integer path, to
// avoid introducing checks on the side effects for the fast path.
//
// These simplifications do not preserve NaN/Inf that may occur in
// the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does
// not cancel out. However, since models should not encounter NaN
// in the first place, this allows better simplification for the
// supported path.
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
Expand Down Expand Up @@ -1678,6 +1685,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<PrimExpr> lanes;
PConst<PrimExpr> ctrue(make_const(ret->dtype, true));

// vector rule
if (ret->dtype.is_scalable_or_fixed_length_vector()) {
Expand All @@ -1698,6 +1706,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2);
TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
TVM_TRY_REWRITE(x == x, ctrue);
} else {
// Mimic the cancellation rules for SubNode. For Index datatypes,
// we skip the check for side effects.
//
// These simplifications do not preserve NaN/Inf that may occur in
// the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does
// not cancel out. However, since models should not encounter NaN
// in the first place, this allows better simplification for the
// supported path.
TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
}
return std::move(ret);
}
Expand Down
36 changes: 36 additions & 0 deletions tests/python/arith/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,42 @@ class TestSelect(BaseCompare):
)


class TestCancellation(BaseCompare):
var_int8 = tir.Var("var_int8", "int8")
var_int32 = tir.Var("var_int32", "int32")
var_int64 = tir.Var("var_int64", "int64")
var_uint8 = tir.Var("var_uint8", "uint8")
var_uint32 = tir.Var("var_uint32", "uint32")
var_uint64 = tir.Var("var_uint64", "uint64")

test_case = tvm.testing.parameter(
TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")),
TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")),
TestCase(var_int8 - var_int8, tir.const(0, "int8")),
TestCase(var_int32 - var_int32, tir.const(0, "int32")),
TestCase(var_int64 - var_int64, tir.const(0, "int64")),
TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")),
TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")),
TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")),
TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")),
TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")),
TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")),
TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")),
TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")),
TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")),
TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")),
TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")),
TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")),
TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")),
TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")),
TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")),
TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")),
TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")),
TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")),
TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")),
)


class TestAddIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")

Expand Down
29 changes: 29 additions & 0 deletions tests/python/arith/test_arith_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index():
)


dtype = tvm.testing.parameter(
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
)


def test_can_prove_self_identity(dtype):
ana = tvm.arith.Analyzer()

n = tir.Var("n", dtype)
assert ana.can_prove(n == n)


def test_can_prove_self_equal_to_self(dtype):
ana = tvm.arith.Analyzer()

n = tir.Var("n", dtype)
assert ana.can_prove_equal(n, n)


def test_simplify_symbolic_comparison():
ana = tvm.arith.Analyzer()

Expand Down