Skip to content

Commit 2eeb37e

Browse files
authored
[Arith][Bugfix] Simplify "x - 1 < y" into "x <= y" (#14528)
This simplification was introduced in #13217, and was erroneously removed in #13933. This commit re-enables this simplification, and adds unit tests to prevent any future regression.
1 parent e1b49c8 commit 2eeb37e

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <tvm/tir/op.h>
3030

3131
#include <algorithm>
32+
#include <tuple>
3233
#include <utility>
3334

3435
#include "../target/datatype/registry.h"
@@ -120,6 +121,23 @@ PrimExpr NormalizeBooleanOperators(PrimExpr expr) {
120121
}
121122
}
122123

124+
std::tuple<PrimExpr, int64_t> ExtractConstantOffset(const PrimExpr& expr) {
125+
PVar<PrimExpr> x;
126+
PVar<IntImm> c1;
127+
128+
// Any (c1+x) terms are normalized into (x+c1), so we don't need to
129+
// check for it.
130+
if ((x + c1).Match(expr)) {
131+
return {x.Eval(), c1.Eval()->value};
132+
} else if ((x - c1).Match(expr)) {
133+
return {x.Eval(), -c1.Eval()->value};
134+
} else if ((c1 - x).Match(expr)) {
135+
return {x.Eval(), c1.Eval()->value};
136+
} else {
137+
return {expr, 0};
138+
}
139+
}
140+
123141
CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) {
124142
CompareResult output = CompareResult::kUnknown;
125143

@@ -1664,20 +1682,28 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
16641682
TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1);
16651683
TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);
16661684

1667-
if ((x + c1 < y + c2).Match(ret)) {
1668-
int64_t diff = c2.Eval()->value - c1.Eval()->value;
1669-
PrimExpr out = [&]() {
1670-
if (diff == 0) {
1671-
return (x < y).Eval();
1672-
} else if (diff == 1) {
1673-
return (x <= y).Eval();
1674-
} else if (diff < 0) {
1675-
return (x + (-diff) < y).Eval();
1676-
} else {
1677-
return (x < y + diff).Eval();
1678-
}
1679-
}();
1680-
return RecursiveRewrite(out);
1685+
auto merge_constants = [&]() -> Optional<PrimExpr> {
1686+
auto [lhs, lhs_offset] = ExtractConstantOffset(ret->a);
1687+
auto [rhs, rhs_offset] = ExtractConstantOffset(ret->b);
1688+
if (lhs_offset == 0 && rhs_offset == 0) {
1689+
return NullOpt;
1690+
}
1691+
1692+
int64_t diff = rhs_offset - lhs_offset;
1693+
if (diff == 0) {
1694+
return lhs < rhs;
1695+
} else if (diff == 1) {
1696+
return lhs <= rhs;
1697+
} else if (diff < 0 && rhs_offset != 0) {
1698+
return lhs + make_const(lhs.dtype(), -diff) < rhs;
1699+
} else if (diff > 0 && lhs_offset != 0) {
1700+
return lhs < rhs + make_const(rhs.dtype(), diff);
1701+
}
1702+
1703+
return NullOpt;
1704+
}();
1705+
if (merge_constants) {
1706+
return RecursiveRewrite(merge_constants.value());
16811707
}
16821708
}
16831709
return std::move(ret);

tests/python/unittest/test_arith_rewrite_simplify.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,23 @@ class TestComparisons(BaseCompare):
750750
TestCase((x - 10).equal(0), x.equal(10)),
751751
TestCase((10 - x).equal(0), x.equal(10)),
752752
TestCase((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0))),
753+
# Write LT as LE for integer arguments, if possible
754+
TestCase(x - 1 < y, x <= y),
755+
TestCase(x + (-1) < y, x <= y),
756+
TestCase(x < y - (-1), x <= y),
757+
TestCase(x < y + 1, x <= y),
758+
TestCase(x + 2 < y + 3, x <= y),
759+
TestCase(x - 3 < y - 2, x <= y),
760+
TestCase(x - 3 < y + (-2), x <= y),
761+
TestCase(x + (-3) < y - 2, x <= y),
762+
# Merge constants on the LHS/RHS of a LT expression.
763+
TestCase(x + 10 < y + 10, x < y),
764+
TestCase(x + 5 < y + 10, x < y + 5),
765+
TestCase(x + 10 < y + 5, x + 5 < y),
766+
TestCase(x - 5 < y - 10, x + 5 < y),
767+
TestCase(x - 10 < y - 5, x < y + 5),
768+
TestCase(x < y - 10, x + 10 < y),
769+
TestCase(x - 10 < y, x < y + 10),
753770
# cmp bound
754771
TestCase(x + y < x + z, y < z),
755772
TestCase(x + y < z + x, y < z),
@@ -815,7 +832,7 @@ class TestComparisons(BaseCompare):
815832
TestCase(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4))),
816833
TestCase(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2)),
817834
TestCase(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2)),
818-
TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4) + (-2), y)),
835+
TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4), y + 2)),
819836
# floor div
820837
TestCase(fld(x, 2) < 3, x < 6),
821838
TestCase(3 < fld(x, 2), tvm.tir.LT(7, x)),
@@ -833,7 +850,7 @@ class TestComparisons(BaseCompare):
833850
TestCase(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4))),
834851
TestCase(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2)),
835852
TestCase(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2)),
836-
TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y)),
853+
TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4), y + 2)),
837854
# End DivMod Rules
838855
# merging flm/fld into known value
839856
TestCase(tir.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28),

0 commit comments

Comments
 (0)