Skip to content

Commit 21154c2

Browse files
[TE] Fix Const Int bound analysis to handle uints for division (#10102)
* case to handle uints * add unit test
1 parent 4d0dac3 commit 21154c2

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/arith/const_int_bound.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ class ConstIntBoundAnalyzer::Impl
430430
// the domain ranges.
431431

432432
// If the range of b contains 0, then some infinity will be involved
433-
if (b.min_value <= 0 && 0 <= b.max_value) {
433+
if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) {
434434
Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
435435
Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt);
436436

@@ -439,6 +439,10 @@ class ConstIntBoundAnalyzer::Impl
439439

440440
return MakeBound(std::min(e_neg.min_value, e_pos.min_value),
441441
std::max(e_neg.max_value, e_pos.max_value));
442+
} else if (b.min_value == 0 && dt.is_uint()) {
443+
// uints only have one sided bounds
444+
Entry assumed_b = MakeBound(1, b.max_value);
445+
return BinaryOpBoundary(a, assumed_b, op);
442446
}
443447
// If the range of b does not have 0, use BinaryOpBoundary.
444448
return BinaryOpBoundary(a, b, op);

tests/python/unittest/test_arith_const_int_bound.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ def test_floordiv_bound():
195195
assert bd.min_value == -9
196196
assert bd.max_value == 9
197197

198+
# Test handling unsigned integers well
199+
x, y = te.var("x", dtype="uint32"), te.var("y", dtype="uint32")
200+
analyzer.update(x, tvm.arith.ConstIntBound(1, 4), override=True)
201+
analyzer.update(y, tvm.arith.ConstIntBound(0, 12), override=True)
202+
bd = analyzer.const_int_bound(fld(x, y))
203+
assert bd.min_value == 0
204+
assert bd.max_value == 4
205+
198206

199207
def test_floormod_bound():
200208
analyzer = tvm.arith.Analyzer()

0 commit comments

Comments
 (0)