Skip to content

Commit

Permalink
[ARITH] RewriteSimplifier: improved cmp simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed Mar 20, 2019
1 parent f81e287 commit f0f3d31
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class RewriteSimplifier::Impl : public IRMutator {
kEQ,
kGT,
kLT,
kGE,
kLE,
kNE
};
// reference to the main analyzer
Expand Down Expand Up @@ -140,6 +142,12 @@ class RewriteSimplifier::Impl : public IRMutator {
if (dbound->max_value < val) {
return kLT;
}
if (dbound->min_value >= val) {
return kGE;
}
if (dbound->max_value <= val) {
return kLE;
}
return kUnknown;
}

Expand Down Expand Up @@ -994,12 +1002,10 @@ Mutate_(const EQ* op, const Expr& self) {

if (IsIndexType(op->a.type())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result != kUnknown) {
if (result == kEQ) {
return make_const(op->type, true);
} else {
return make_const(op->type, false);
}
if (result == kEQ) {
return make_const(op->type, true);
} else if (result == kNE || result == kGT || result == kLT) {
return make_const(op->type, false);
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
Expand Down Expand Up @@ -1055,7 +1061,7 @@ Mutate_(const LT* op, const Expr& self) {
if (result == kLT) {
return make_const(op->type, true);
}
if (result == kEQ || result == kGT) {
if (result == kEQ || result == kGT || result == kGE) {
return make_const(op->type, false);
}

Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,21 @@ def test_cmp_simplify():
ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x))
ck.verify(x + 1 < tvm.max(8, x), x < 7)

ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True)
ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True)

ck.verify(x < 11, tvm.const(1, "bool"))
ck.verify(x <= 10, tvm.const(1, "bool"))
ck.verify(z <= 5, tvm.const(1, "bool"))
ck.verify(x + y <= 10, tvm.const(1, "bool"))
ck.verify(x + y >= -10, tvm.const(1, "bool"))
ck.verify(z - 5 <= y + 10, tvm.const(1, "bool"))
ck.verify(tvm.all(x > -1, z <= x + 5), tvm.const(1, "bool"))
ck.verify(x*y <= 0, tvm.const(1, "bool"))
ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool"))
ck.verify(y*y >= 0, tvm.const(1, "bool"))


def test_logical_simplify():
ck = RewriteChecker()
Expand Down

0 comments on commit f0f3d31

Please sign in to comment.