diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f4d4a9048ced..3682054e8e4b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; + // Vector rules if (op->dtype.is_scalable_or_fixed_length_vector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); @@ -697,9 +698,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1)); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); - } else if (op->dtype.is_float()) { + } else { // Cancellation rules. Deliberately off of the integer path, to // avoid introducing checks on the side effects for the fast path. + // + // These simplifications do not preserve NaN/Inf that may occur in + // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does + // not cancel out. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x), SideEffect(x.Eval()) <= CallEffectKind::kReadState); TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState); @@ -1678,6 +1685,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; + PConst ctrue(make_const(ret->dtype, true)); // vector rule if (ret->dtype.is_scalable_or_fixed_length_vector()) { @@ -1698,6 +1706,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2); TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); + TVM_TRY_REWRITE(x == x, ctrue); + } else { + // Mimic the cancellation rules for SubNode. For Index datatypes, + // we skip the check for side effects. + // + // These simplifications do not preserve NaN/Inf that may occur in + // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does + // not cancel out. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. + TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } return std::move(ret); } diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 1ebaab53af2d..90f0aeef47d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -321,6 +321,42 @@ class TestSelect(BaseCompare): ) +class TestCancellation(BaseCompare): + var_int8 = tir.Var("var_int8", "int8") + var_int32 = tir.Var("var_int32", "int32") + var_int64 = tir.Var("var_int64", "int64") + var_uint8 = tir.Var("var_uint8", "uint8") + var_uint32 = tir.Var("var_uint32", "uint32") + var_uint64 = tir.Var("var_uint64", "uint64") + + test_case = tvm.testing.parameter( + TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")), + TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")), + TestCase(var_int8 - var_int8, tir.const(0, "int8")), + TestCase(var_int32 - var_int32, tir.const(0, "int32")), + TestCase(var_int64 - var_int64, tir.const(0, "int64")), + TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")), + TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")), + TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")), + TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")), + TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")), + TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")), + TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")), + TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")), + TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")), + TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")), + TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")), + TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")), + TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")), + TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")), + TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")), + TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")), + ) + + class TestAddIndex(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 9a0245d27487..3b0237740045 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index(): ) +dtype = tvm.testing.parameter( + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", +) + + +def test_can_prove_self_identity(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove(n == n) + + +def test_can_prove_self_equal_to_self(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove_equal(n, n) + + def test_simplify_symbolic_comparison(): ana = tvm.arith.Analyzer()