From bf5f79efb4933582a09212b8a36467e461f2dd53 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 27 Aug 2019 15:32:03 -0700 Subject: [PATCH 1/4] [schedule] Improve ceil_divide in tile/split --- src/schedule/message_passing.cc | 3 ++ .../unittest/test_schedule_bound_inference.py | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index b39a6c6ed707..c5c79ea4229d 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage, arith::Analyzer* actx, bool allow_missing) { auto ceil_div = [actx](Expr a, Expr b) { + if (actx->CanProve(a % b == 0)) { + return actx->Simplify(a / b); + } return actx->Simplify((a + (b - 1)) / b); }; diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index 21be6b7ec8bd..1ff985356ee8 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -69,6 +69,33 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) +def test_bound_split_divisible(): + m = tvm.var('m') + l = tvm.var('l') + A = tvm.placeholder((8 * m, l), name='A') + B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B') + s = tvm.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], 8) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent == m + assert bounds[xi].extent.value == 8 + +def test_bound_tile_divisible(): + m = tvm.var('m') + l = tvm.var('l') + shape = (8 * m, 32 * l) + A = tvm.placeholder(shape, name='A') + B = tvm.compute(shape, lambda i, j: A[i, j], name='B') + s = tvm.create_schedule(B.op) + xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + assert bounds[xo].extent == m + assert bounds[xi].extent.value == 8 + assert bounds[yo].extent == l + assert bounds[yi].extent.value == 32 + def test_bound_fusesplit1(): m = tvm.var('m') l = tvm.var('l') @@ -393,3 +420,5 @@ def _check(B, A=A): test_bound_simplification_failure() test_bound_fusesplit1() test_bound_fusesplit2() + test_bound_split_divisible() + test_bound_tile_divisible() From 0e9685ef858d87de2dfb57c2bd6ea95f9bd59ca3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 29 Aug 2019 09:12:55 -0700 Subject: [PATCH 2/4] support arithmetic size (only multiplication for now) in codegen --- src/pass/arg_binder.cc | 28 +++++++++++++++++----- tests/python/unittest/test_codegen_llvm.py | 22 +++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 8268fc4e1aed..01a2f96efc3f 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -46,25 +46,41 @@ void BinderAddAssert(Expr cond, } } +// Deal with things like arg = 8 * M +// TODO(Yizhi Liu): make it more general +std::pair TryExtractVariable(const Expr& arg, const Expr& value) { + Expr var, factor; + if (const Mul* op = arg.as()) { + var = op->a.as() ? op->a : op->b; + factor = op->a.as() ? op->b : op->a; + if (var.as() && factor.as()) { + return std::make_pair(var, value / factor); + } + } + return std::make_pair(arg, value); +} + bool ArgBinder::Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.type(), value.type()); - if (const Variable* v = arg.as()) { + Expr new_arg, new_value; + std::tie(new_arg, new_value) = TryExtractVariable(Simplify(arg), value); + if (const Variable* v = new_arg.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { - Var v_arg(arg.node_); + Var v_arg(new_arg.node_); defs_.emplace_back(v_arg); if (with_lets) { - (*def_map_)[v] = arg; - init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0))); + (*def_map_)[v] = new_arg; + init_nest_.emplace_back(LetStmt::make(v_arg, new_value, Evaluate::make(0))); } else { - (*def_map_)[v] = value; + (*def_map_)[v] = new_value; } return true; } else { - BinderAddAssert(it->second == value, arg_name, &asserts_); + BinderAddAssert(it->second == new_value, arg_name, &asserts_); } } else { BinderAddAssert(arg == value, arg_name, &asserts_); diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index 34dad36a9076..b868681336bd 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -343,6 +343,27 @@ def check_llvm(n): check_llvm(64) +def test_llvm_arith_size(): + print("test_llvm_arith_size") + def check_llvm(N, n): + if not tvm.module.enabled("llvm"): + return + A = tvm.placeholder((N, ), name='A') + C = tvm.compute((N,), lambda i: A[i], name='C') + s = tvm.create_schedule(C.op) + # build and invoke the kernel. + f = tvm.build(s, [A, C], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) + c = tvm.nd.empty((n,), A.dtype, ctx) + f(a, c) + c_np = a.asnumpy() + tvm.testing.assert_allclose(c.asnumpy(), c_np) + check_llvm(8 * tvm.var('N'), 8) + check_llvm(32 * tvm.var('N'), 64) + + def test_rank_zero(): def check_llvm(n): if not tvm.module.enabled("llvm"): @@ -587,6 +608,7 @@ def vectorizer(op): test_llvm_bool() test_llvm_persist_parallel() test_llvm_condition() + test_llvm_arith_size() test_llvm_vadd_pipeline() test_llvm_add_pipeline() test_llvm_intrin() From 67f9385d134f3c312165796f2b6b72f7b05f3ec9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 29 Aug 2019 09:19:12 -0700 Subject: [PATCH 3/4] fix lint --- src/pass/arg_binder.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 01a2f96efc3f..a6b2d9968569 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "ir_util.h" #include "arg_binder.h" #include "../arithmetic/compute_expr.h" From 425d5c768b6b16d3be249a4ce1894e66fa365e18 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 2 Sep 2019 23:09:30 -0700 Subject: [PATCH 4/4] revert codegen change --- src/pass/arg_binder.cc | 29 +++++----------------- tests/python/unittest/test_codegen_llvm.py | 22 ---------------- 2 files changed, 6 insertions(+), 45 deletions(-) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index a6b2d9968569..8268fc4e1aed 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include "ir_util.h" #include "arg_binder.h" #include "../arithmetic/compute_expr.h" @@ -47,41 +46,25 @@ void BinderAddAssert(Expr cond, } } -// Deal with things like arg = 8 * M -// TODO(Yizhi Liu): make it more general -std::pair TryExtractVariable(const Expr& arg, const Expr& value) { - Expr var, factor; - if (const Mul* op = arg.as()) { - var = op->a.as() ? op->a : op->b; - factor = op->a.as() ? op->b : op->a; - if (var.as() && factor.as()) { - return std::make_pair(var, value / factor); - } - } - return std::make_pair(arg, value); -} - bool ArgBinder::Bind_(const Expr& arg, const Expr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.type(), value.type()); - Expr new_arg, new_value; - std::tie(new_arg, new_value) = TryExtractVariable(Simplify(arg), value); - if (const Variable* v = new_arg.as()) { + if (const Variable* v = arg.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { - Var v_arg(new_arg.node_); + Var v_arg(arg.node_); defs_.emplace_back(v_arg); if (with_lets) { - (*def_map_)[v] = new_arg; - init_nest_.emplace_back(LetStmt::make(v_arg, new_value, Evaluate::make(0))); + (*def_map_)[v] = arg; + init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0))); } else { - (*def_map_)[v] = new_value; + (*def_map_)[v] = value; } return true; } else { - BinderAddAssert(it->second == new_value, arg_name, &asserts_); + BinderAddAssert(it->second == value, arg_name, &asserts_); } } else { BinderAddAssert(arg == value, arg_name, &asserts_); diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index b868681336bd..34dad36a9076 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -343,27 +343,6 @@ def check_llvm(n): check_llvm(64) -def test_llvm_arith_size(): - print("test_llvm_arith_size") - def check_llvm(N, n): - if not tvm.module.enabled("llvm"): - return - A = tvm.placeholder((N, ), name='A') - C = tvm.compute((N,), lambda i: A[i], name='C') - s = tvm.create_schedule(C.op) - # build and invoke the kernel. - f = tvm.build(s, [A, C], "llvm") - ctx = tvm.cpu(0) - # launch the kernel. - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.empty((n,), A.dtype, ctx) - f(a, c) - c_np = a.asnumpy() - tvm.testing.assert_allclose(c.asnumpy(), c_np) - check_llvm(8 * tvm.var('N'), 8) - check_llvm(32 * tvm.var('N'), 64) - - def test_rank_zero(): def check_llvm(n): if not tvm.module.enabled("llvm"): @@ -608,7 +587,6 @@ def vectorizer(op): test_llvm_bool() test_llvm_persist_parallel() test_llvm_condition() - test_llvm_arith_size() test_llvm_vadd_pipeline() test_llvm_add_pipeline() test_llvm_intrin()