
[schedule] Improve ceil_divide in tile/split #3842

Merged · 5 commits · Sep 6, 2019
29 changes: 23 additions & 6 deletions src/pass/arg_binder.cc
@@ -24,6 +24,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include <utility>
#include "ir_util.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
@@ -46,25 +47,41 @@ void BinderAddAssert(Expr cond,
}
}

// Deal with things like arg = 8 * M
// TODO(Yizhi Liu): make it more general
std::pair<Expr, Expr> TryExtractVariable(const Expr& arg, const Expr& value) {
Member:

Please comment on an example use-case.
This logic handles an additional constraint like 8 * arg = value, and we would like to know the intended application. Normally this would generate a constraint instead, and the simplifier should be able to simplify it, so this logic would not be needed.

Member Author:

Without the fix, TVM complains with errors like "TVMError: Not all Vars are passed in api_args: 'M' 'K' 'N' does not appear in api_args".

The example use-case is operator computation for MXNet. To get decent performance for shapes unknown at compile time, we generate multiple kernels for one operator and dispatch according to the shape at runtime. However, if the shape is purely symbolic, tiling/splitting has to generate if-else guards, which hurts performance. One idea is to "hint" TVM that a given shape is divisible by 8, so it doesn't generate if-else for tile sizes 2/4/8. This is where 8*M comes from.
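The performance concern described in this comment can be illustrated with a plain-Python model of loop splitting (an analogy only — `split_loop` is a made-up helper, not TVM code):

```python
# Splitting a loop of extent n by factor f. If f divides n, the split
# becomes two clean nested loops; otherwise every inner iteration needs
# a bound check -- the "if-else" the comment refers to.

def split_loop(n, f, body):
    if n % f == 0:
        # Divisible case: no guard needed in the inner loop.
        for xo in range(n // f):
            for xi in range(f):
                body(xo * f + xi)
    else:
        # Non-divisible case: guard each iteration against overrun.
        for xo in range((n + f - 1) // f):
            for xi in range(f):
                i = xo * f + xi
                if i < n:  # per-iteration check hurts performance
                    body(i)
```

Knowing a shape is 8*M lets the compiler always take the first branch for tile sizes that divide 8.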

Member:

I see, I think we could use better expressions for this kind of hint. Your proposal is interesting.

We could also resolve this problem with alternatives; in particular, we can do something like AssertExpr(x, x % 8 == 0). Given that this change affects performance more broadly, can we separate it out from this PR?

Member Author:

Do you mean removing the second commit (support arithmetic size in codegen) from this PR and merging the first one (improve ceil_divide in tile/split) for now?

Member:

Yes, we can use a second PR for that

Member Author:

reverted

Expr var, factor;
if (const Mul* op = arg.as<Mul>()) {
var = op->a.as<Variable>() ? op->a : op->b;
factor = op->a.as<Variable>() ? op->b : op->a;
if (var.as<Variable>() && factor.as<IntImm>()) {
return std::make_pair(var, value / factor);
}
}
return std::make_pair(arg, value);
}

bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.type(), value.type());
if (const Variable* v = arg.as<Variable>()) {
Expr new_arg, new_value;
std::tie(new_arg, new_value) = TryExtractVariable(Simplify(arg), value);
if (const Variable* v = new_arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg(arg.node_);
Var v_arg(new_arg.node_);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
(*def_map_)[v] = new_arg;
init_nest_.emplace_back(LetStmt::make(v_arg, new_value, Evaluate::make(0)));
} else {
(*def_map_)[v] = value;
(*def_map_)[v] = new_value;
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
BinderAddAssert(it->second == new_value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
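The `TryExtractVariable` helper above can be modeled in plain Python (a simplified illustration of the idea, not TVM's IR — expressions are represented here as tuples):

```python
# Simplified model of TryExtractVariable: given a symbolic argument like
# 8 * M bound to a concrete runtime value, recover the variable M and the
# value it must take (value / 8). Only handles var * const, mirroring the
# "TODO: make it more general" note in the C++ code.

def try_extract_variable(arg, value):
    """arg is either a variable name (str) or a ('mul', a, b) tuple."""
    if isinstance(arg, tuple) and arg[0] == 'mul':
        a, b = arg[1], arg[2]
        var = a if isinstance(a, str) else b
        factor = b if isinstance(a, str) else a
        if isinstance(var, str) and isinstance(factor, int):
            return var, value // factor
    # Not of the form var * const: bind the argument unchanged.
    return arg, value
```

For example, binding shape `8 * M` to a runtime extent of 64 yields the binding `M = 8`, which is what lets the symbolic variable appear in `api_args`.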
3 changes: 3 additions & 0 deletions src/schedule/message_passing.cc
@@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage,
arith::Analyzer* actx,
bool allow_missing) {
auto ceil_div = [actx](Expr a, Expr b) {
if (actx->CanProve(a % b == 0)) {
return actx->Simplify(a / b);
}
return actx->Simplify((a + (b - 1)) / b);
};

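The `ceil_div` change above is the heart of the PR: when the analyzer can prove the extent is divisible by the tile size, exact division is used, so no clamping or tail handling is generated. A plain-Python sketch (integers stand in for TVM's symbolic `Analyzer` expressions):

```python
# Sketch of the improved ceil_div. The divisibility test plays the role
# of actx->CanProve(a % b == 0); with a symbolic a = 8 * M and b = 8,
# the proof succeeds and the result simplifies to M exactly.

def ceil_div(a, b):
    if a % b == 0:
        # Provably divisible: exact division, no remainder handling.
        return a // b
    # General case: classic ceiling division.
    return (a + b - 1) // b
```

For example, `ceil_div(64, 8)` takes the exact branch and returns 8, while `ceil_div(65, 8)` falls back to ceiling division and returns 9.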
22 changes: 22 additions & 0 deletions tests/python/unittest/test_codegen_llvm.py
@@ -343,6 +343,27 @@ def check_llvm(n):
check_llvm(64)


def test_llvm_arith_size():
print("test_llvm_arith_size")
def check_llvm(N, n):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((N, ), name='A')
C = tvm.compute((N,), lambda i: A[i], name='C')
s = tvm.create_schedule(C.op)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
f(a, c)
c_np = a.asnumpy()
tvm.testing.assert_allclose(c.asnumpy(), c_np)
check_llvm(8 * tvm.var('N'), 8)
check_llvm(32 * tvm.var('N'), 64)


def test_rank_zero():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
@@ -587,6 +608,7 @@ def vectorizer(op):
test_llvm_bool()
test_llvm_persist_parallel()
test_llvm_condition()
test_llvm_arith_size()
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
29 changes: 29 additions & 0 deletions tests/python/unittest/test_schedule_bound_inference.py
@@ -69,6 +69,33 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16)

def test_bound_split_divisible():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((8 * m, l), name='A')
B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], 8)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8

def test_bound_tile_divisible():
m = tvm.var('m')
l = tvm.var('l')
shape = (8 * m, 32 * l)
A = tvm.placeholder(shape, name='A')
B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert bounds[xo].extent == m
assert bounds[xi].extent.value == 8
assert bounds[yo].extent == l
assert bounds[yi].extent.value == 32

def test_bound_fusesplit1():
m = tvm.var('m')
l = tvm.var('l')
@@ -393,3 +420,5 @@ def _check(B, A=A):
test_bound_simplification_failure()
test_bound_fusesplit1()
test_bound_fusesplit2()
test_bound_split_divisible()
test_bound_tile_divisible()