Skip to content

Commit dbe6ce6

Browse files
[ARITH] Fix canonical simplify for LE with incorrect range assumptions
Fix a bug in canonical simplification of less-than expressions where the algorithm incorrectly assumed variables could have negative values when simplifying expressions of the form `ax + b < c`. The previous implementation checked if `-d < xn < d` before simplifying, but this was incorrect when variables are constrained to non-negative ranges. For example, with constraints `0 < x, y < 2` and expression `2x + y < 8`, the algorithm would incorrectly check if `-2 < y < 2` and then simplify to `x < 4`. However, when x=4 and y=-1, we get 2*4 + (-1) = 7 < 8, which satisfies the original constraint but violates the intended variable bounds. The fix changes the range check to `0 <= xn < d`, ensuring that simplification only occurs when variables are properly bounded from below at zero. Co-authored-by: FeiyangChen <[email protected]>
1 parent bcae402 commit dbe6ce6

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

src/arith/canonical_simplify.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
13911391
// First convert a < b into a - b < 0
13921392
PrimExpr expr = this->CanonicalMutate(op->a - op->b);
13931393
// Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c)
1394-
// 1. if can prove -d < xn < d, then we can simplify
1394+
// 1. if can prove 0 <= xn < d, then we can simplify
13951395
// the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < c/d,
13961396
// e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 2`
13971397
// 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d % (m/d)
@@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
14171417
ICHECK(extra->dtype == dtype);
14181418
PrimExpr normal_extra = extra->Normalize();
14191419
if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
1420-
this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
1421-
// Case 1. -d < xn < d
1420+
this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) {
1421+
// Case 1. 0 <= xn < d
14221422
divisible.CopyOnWrite()->DivideBy(gcd);
14231423
return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
14241424
} else if (extra->args.size() == 1 &&

src/runtime/pack_args.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ enum ArgConvertCode {
134134
};
135135

136136
inline ArgConvertCode GetArgConvertCode(DLDataType t) {
137-
ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now";
137+
ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now";
138138
if (t.code == kDLInt) {
139139
if (t.bits == 64U) return INT64_TO_INT64;
140140
if (t.bits == 32U) return INT64_TO_INT32;

tests/python/arith/test_arith_canonical_simplify.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,6 @@ def test_simplify_le():
448448
ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))
449449

450450
ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
451-
ck.verify(x * 8 + y - z < 16, x < 2)
452451

453452
n = te.size_var("n")
454453
ck.verify(x * 8 + y < n, x * 8 + y < n)

0 commit comments

Comments
 (0)