From 8458bd78bf7ccd52845d2431e6a6bd74fe20d168 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Jul 2024 15:15:00 -0500 Subject: [PATCH 1/2] [TIR][Analyzer] Simplify `x==x` expressions for all dtypes Prior to this commit, there was no rule to simplify `x == x` into `True`. In some cases, despite not having an explicit rewrite rule in `RewriteSimplifier`, the `RewriteSimplifier::CanProve` function would check if `x-x` simplifies to zero, relying on the rewrite rules used for `tir::Sub`. However, the rule to rewrite `x-x` into zero was only enabled for `int32`, `int64`, and floating-point types, so relying on this behavior was inconsistent. This commit updates the rewrite rules for both `tir::EQ` and `tir::Sub` to check for simplification of `x-x` or `x==x`, regardless of the datatype. This change preserves the fast-path for index data-types, in which `int32` and `int64` expressions may be simplified without checking for side effects. For all other dtypes, the cancellation only applies when evaluating `x` has no side effects. 
--- src/arith/rewrite_simplify.cc | 9 ++++- .../arith/test_arith_rewrite_simplify.py | 36 +++++++++++++++++++ tests/python/arith/test_arith_simplify.py | 29 +++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f4d4a9048ced..59d15c464b47 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -543,6 +543,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; + // Vector rules if (op->dtype.is_scalable_or_fixed_length_vector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); @@ -697,7 +698,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1)); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); - } else if (op->dtype.is_float()) { + } else { // Cancellation rules. Deliberately off of the integer path, to // avoid introducing checks on the side effects for the fast path. TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x), @@ -1678,6 +1679,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; + PConst ctrue(make_const(ret->dtype, true)); // vector rule if (ret->dtype.is_scalable_or_fixed_length_vector()) { @@ -1698,6 +1700,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { TVM_TRY_REWRITE(c1 - x == c2, x == c1 - c2); TVM_TRY_REWRITE(x + c1 == c2, x == c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); + TVM_TRY_REWRITE(x == x, ctrue); + } else { + // Mimic the cancellation rules for SubNode. For Index datatypes, + // we skip the check for side effects. 
+ TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } return std::move(ret); } diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 1ebaab53af2d..90f0aeef47d7 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -321,6 +321,42 @@ class TestSelect(BaseCompare): ) +class TestCancellation(BaseCompare): + var_int8 = tir.Var("var_int8", "int8") + var_int32 = tir.Var("var_int32", "int32") + var_int64 = tir.Var("var_int64", "int64") + var_uint8 = tir.Var("var_uint8", "uint8") + var_uint32 = tir.Var("var_uint32", "uint32") + var_uint64 = tir.Var("var_uint64", "uint64") + + test_case = tvm.testing.parameter( + TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")), + TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")), + TestCase(var_int8 - var_int8, tir.const(0, "int8")), + TestCase(var_int32 - var_int32, tir.const(0, "int32")), + TestCase(var_int64 - var_int64, tir.const(0, "int64")), + TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")), + TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")), + TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")), + TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")), + TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")), + TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")), + TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")), + TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")), + TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")), + TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")), + 
TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")), + TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")), + TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")), + TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")), + TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")), + TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")), + TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")), + ) + + class TestAddIndex(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 9a0245d27487..3b0237740045 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -38,6 +38,35 @@ def test_simplify_reshape_flattened_index(): ) +dtype = tvm.testing.parameter( + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", +) + + +def test_can_prove_self_identity(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove(n == n) + + +def test_can_prove_self_equal_to_self(dtype): + ana = tvm.arith.Analyzer() + + n = tir.Var("n", dtype) + assert ana.can_prove_equal(n, n) + + def test_simplify_symbolic_comparison(): ana = tvm.arith.Analyzer() From 474dfd1762f041dd5ef0ba6908c01047ec9be85a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 18 Jul 2024 16:06:49 -0500 Subject: [PATCH 2/2] Add comment about simplifications of NaN/Inf --- src/arith/rewrite_simplify.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 59d15c464b47..3682054e8e4b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -701,6 +701,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { } else { // Cancellation rules. 
Deliberately off of the integer path, to // avoid introducing checks on the side effects for the fast path. + // + // These simplifications do not preserve NaN/Inf that may occur in + // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does + // not cancel out. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x), SideEffect(x.Eval()) <= CallEffectKind::kReadState); TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState); @@ -1704,6 +1710,12 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { } else { // Mimic the cancellation rules for SubNode. For Index datatypes, // we skip the check for side effects. + // + // This simplification does not preserve NaN that may occur in + // the inputs. For IEEE floats, `NaN == NaN` is `False`, and does + // not simplify to `True`. However, since models should not encounter NaN + // in the first place, this allows better simplification for the + // supported path. TVM_TRY_REWRITE_IF(x == x, ctrue, SideEffect(x.Eval()) <= CallEffectKind::kReadState); } return std::move(ret);