Skip to content

Commit

Permalink
support arithmetic size (only multiplication for now) in codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Aug 29, 2019
1 parent bf5f79e commit 0e9685e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,41 @@ void BinderAddAssert(Expr cond,
}
}

// Deal with things like arg = 8 * M
// TODO(Yizhi Liu): make it more general
std::pair<Expr, Expr> TryExtractVariable(const Expr& arg, const Expr& value) {
Expr var, factor;
if (const Mul* op = arg.as<Mul>()) {
var = op->a.as<Variable>() ? op->a : op->b;
factor = op->a.as<Variable>() ? op->b : op->a;
if (var.as<Variable>() && factor.as<IntImm>()) {
return std::make_pair(var, value / factor);
}
}
return std::make_pair(arg, value);
}

bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.type(), value.type());
if (const Variable* v = arg.as<Variable>()) {
Expr new_arg, new_value;
std::tie(new_arg, new_value) = TryExtractVariable(Simplify(arg), value);
if (const Variable* v = new_arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg(arg.node_);
Var v_arg(new_arg.node_);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
(*def_map_)[v] = new_arg;
init_nest_.emplace_back(LetStmt::make(v_arg, new_value, Evaluate::make(0)));
} else {
(*def_map_)[v] = value;
(*def_map_)[v] = new_value;
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
BinderAddAssert(it->second == new_value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,27 @@ def check_llvm(n):
check_llvm(64)


def test_llvm_arith_size():
print("test_llvm_arith_size")
def check_llvm(N, n):
if not tvm.module.enabled("llvm"):
return
A = tvm.placeholder((N, ), name='A')
C = tvm.compute((N,), lambda i: A[i], name='C')
s = tvm.create_schedule(C.op)
# build and invoke the kernel.
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
c = tvm.nd.empty((n,), A.dtype, ctx)
f(a, c)
c_np = a.asnumpy()
tvm.testing.assert_allclose(c.asnumpy(), c_np)
check_llvm(8 * tvm.var('N'), 8)
check_llvm(32 * tvm.var('N'), 64)


def test_rank_zero():
def check_llvm(n):
if not tvm.module.enabled("llvm"):
Expand Down Expand Up @@ -587,6 +608,7 @@ def vectorizer(op):
test_llvm_bool()
test_llvm_persist_parallel()
test_llvm_condition()
test_llvm_arith_size()
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
Expand Down

0 comments on commit 0e9685e

Please sign in to comment.