Skip to content

Commit 3c6f9c9

Browse files
authored
[Arith] Added simplification rule for multiple equality compares (#15628)
The expression `(x==y) && (x==z)` requires that `y==z`. When `y` and `z` are constants, this can allow better constant folding by rewriting `(x==c1) && (x==c2)` into `(x==c1) && (c1==c2)`. This commit adds the above rewrite, and the corresponding rewrite of the negative expression.
1 parent c921781 commit 3c6f9c9

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
18561856
}),
18571857
cfalse, c2.Eval()->value > c1.Eval()->value);
18581858

1859+
TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
18591860
TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2);
18601861

18611862
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3,
@@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
20002001
TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
20012002
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
20022003

2004+
TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2);
20032005
TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
20042006
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
20052007

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,7 @@ class TestLogical(BaseCompare):
951951
TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")),
952952
TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")),
953953
TestCase(tvm.tir.And(x == 1, x != 2), x == 1),
954+
TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")),
954955
TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool")),
955956
TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool")),
956957
TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool")),
@@ -965,6 +966,7 @@ class TestLogical(BaseCompare):
965966
TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")),
966967
TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")),
967968
TestCase(tvm.tir.Or(x != 1, x == 2), x != 1),
969+
TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")),
968970
TestCase(
969971
tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)),
970972
tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),

0 commit comments

Comments
 (0)