From a5d790449d6a2ea27686de8ae8bc1a90d9cf21a1 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 6 Sep 2019 21:29:31 +0800 Subject: [PATCH] [schedule] Improve ceil_divide in tile/split (#3842) --- src/schedule/message_passing.cc | 3 ++ .../unittest/test_schedule_bound_inference.py | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index b39a6c6ed7070..c5c79ea4229d6 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage, arith::Analyzer* actx, bool allow_missing) { auto ceil_div = [actx](Expr a, Expr b) { + if (actx->CanProve(a % b == 0)) { + return actx->Simplify(a / b); + } return actx->Simplify((a + (b - 1)) / b); }; diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 21be6b7ec8bd6..1ff985356ee87 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -69,6 +69,33 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_split_divisible(): + m = tvm.var('m') + l = tvm.var('l') + A = tvm.placeholder((8 * m, l), name='A') + B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B') + s = tvm.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], 8) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent == m + assert bounds[xi].extent.value == 8 + +def test_bound_tile_divisible(): + m = tvm.var('m') + l = tvm.var('l') + shape = (8 * m, 32 * l) + A = tvm.placeholder(shape, name='A') + B = tvm.compute(shape, lambda i, j: A[i, j], name='B') + s = tvm.create_schedule(B.op) + xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent == m + assert bounds[xi].extent.value == 8 + assert bounds[yo].extent == l + assert bounds[yi].extent.value == 32 + def test_bound_fusesplit1(): m = tvm.var('m') l = tvm.var('l') @@ -393,3 +420,5 @@ def _check(B, A=A): test_bound_simplification_failure() test_bound_fusesplit1() test_bound_fusesplit2() + test_bound_split_divisible() + test_bound_tile_divisible()