Skip to content

Commit

Permalink
[schedule] Improve ceil_divide in tile/split (apache#3842)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and wweic committed Sep 16, 2019
1 parent 9cf7a73 commit 8dbc631
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
if (actx->CanProve(a % b == 0)) {
return actx->Simplify(a / b);
}
return actx->Simplify((a + (b - 1)) / b);
};

Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,33 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)

def test_bound_split_divisible():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((8 * m, l), name='A')
B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8

def test_bound_tile_divisible():
m = tvm.var('m')
l = tvm.var('l')
shape = (8 * m, 32 * l)
A = tvm.placeholder(shape, name='A')
B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8
assert bounds[yo].extent == l
assert bounds[yi].extent.value == 32

def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
Expand Down Expand Up @@ -393,3 +420,5 @@ def _check(B, A=A):
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
test_bound_split_divisible()
test_bound_tile_divisible()

0 comments on commit 8dbc631

Please sign in to comment.