Skip to content

Commit

Permalink
[ARITH] Tight bound for floormod (#6771)
Browse files Browse the repository at this point in the history
  • Loading branch information
hzfan authored Oct 28, 2020
1 parent 3d624ec commit b4858d4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ class ConstIntBoundAnalyzer::Impl
}

Entry VisitExpr_(const FloorModNode* op) final {
/* let a / b = x + y, where x is integer, y \in [0, 1)
* floormod(a, b) = a - floordiv(a, b) * b
* floordiv(a, b) = x
* floormod(a, b) = a - floordiv(a, b) * b
* = a - x * b
* = a - (a / b - y) * b
* = a - a + y * b
* = y * b
* note that 0 <= y < 1
* when b > 0, 0 <= b * y < b
* 0 <= b * y <= b - 1
* when b < 0, b < b * y <= 0
* b + 1 <= b * y <= 0
* In all cases, min(0, b + 1) <= b * y <= max(0, b - 1)
* min(0, b_min + 1) <= b * y <= max(0, b_max - 1)
* That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1)
*/
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
Expand All @@ -259,9 +276,11 @@ class ConstIntBoundAnalyzer::Impl
}
} else {
ICHECK(!b.is_const(0)) << "floormod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
return Everything(op->dtype);
int64_t b_min_cap = InfAwareAdd(b.min_value, 1);
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
return Intersect(MakeBound(std::min(static_cast<int64_t>(0), b_min_cap),
std::max(static_cast<int64_t>(0), b_max_cap)),
Everything(op->dtype));
}
}

Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,17 @@ def test_let_bound():
assert bd.max_value == 2


def test_floormod_negative_divisor():
analyzer = tvm.arith.Analyzer()
flm, fld = tvm.te.floormod, tvm.te.floordiv
a, b = te.var("a"), te.var("b")
analyzer.update(a, tvm.arith.ConstIntBound(0, 6))
analyzer.update(b, tvm.arith.ConstIntBound(-5, 7))
bd = analyzer.const_int_bound(flm(a, b))
assert bd.min_value == -4
assert bd.max_value == 6


if __name__ == "__main__":
test_let_bound()
test_dtype_bound()
Expand All @@ -318,3 +329,4 @@ def test_let_bound():
test_shift_and_bound()
test_mix_index_bound()
test_size_var_bound()
test_floormod_negative_divisor()

0 comments on commit b4858d4

Please sign in to comment.