Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintFinalReturn();
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
Expand All @@ -132,8 +131,6 @@ void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

void CodeGenC::PrintFinalReturn() {}

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }

void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
Expand Down Expand Up @@ -537,6 +534,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
PrintExpr(op->args[0], os);
os << " ) return ";
PrintExpr(op->args[1], os);
} else if (op->op.same_as(builtin::ret())) {
os << "return ";
PrintExpr(op->args[0], os);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
Expand Down
4 changes: 0 additions & 4 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Print the final return at the end the function.
*/
virtual void PrintFinalReturn(); // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
Expand Down
5 changes: 0 additions & 5 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
<< "TVM_DLL int32_t";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
this->PrintIndent();
stream << "return 0;\n";
}

std::string CodeGenCHost::Finish() { // NOLINT(*)
std::string ret = decl_stream.str();
if (emit_fwd_func_decl_) {
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class CodeGenCHost : public CodeGenC {

void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)

// overload visitor functions
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
6 changes: 6 additions & 0 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_HANDLE);
} else if (op->op.same_as(builtin::ret())) {
CHECK(op->args.size() == 1 && op->args[0]->IsInstance<IntImmNode>() &&
op->args[0].as<IntImmNode>()->value == 0)
<< "StackVM does not support return values, "
<< "and the return value " << op->args
<< " is not special case of returning an error code of zero.";
} else {
LOG(FATAL) << "unknown function call " << op->op;
}
Expand Down
9 changes: 7 additions & 2 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
}

// Return error code of zero on success
body = SeqStmt({body, Evaluate(ret(Integer(0)))});

// Apply all argument assertions
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_args;
std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);

func_ptr->body = body;
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop));
}

func_ptr->body = MergeNest(device_init, func_ptr->body);
Stmt body = MergeNest(device_init, SeqStmt({func_ptr->body, Evaluate(ret(Integer(0)))}));

func_ptr->body = body;
func_ptr->params = args;
func_ptr->ret_type = PrimType(DataType::Int(32));
func_ptr->buffer_map = Map<Var, Buffer>();
Expand Down
9 changes: 7 additions & 2 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def check_packed_func(target="llvm"):
node = prim_func.body

# Recursively visit PrimFunc until we meet the for-loop:
while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
node = node.body
while True:
if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
node = node.body
elif isinstance(node, tvm.tir.SeqStmt):
node = node[0]
else:
break

# For-loop:
assert isinstance(node, tvm.tir.stmt.For)
Expand Down
15 changes: 12 additions & 3 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ def _find_assignment(stmt, var_name):


def _find_next(stmt, type):
while not isinstance(stmt, type):
stmt = stmt.body
return stmt
search_stack = [stmt]

while search_stack:
stmt = search_stack.pop()
if isinstance(stmt, type):
return stmt
elif isinstance(stmt, tvm.tir.SeqStmt):
search_stack.extend(reversed(stmt))
else:
search_stack.append(stmt.body)

return None


def _find_compute_scope(func):
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 2)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")):
Expand Down Expand Up @@ -215,6 +216,7 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")):
Expand Down Expand Up @@ -259,11 +261,13 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)
T.ret(T.int32(0))

return mod

Expand Down Expand Up @@ -316,13 +320,15 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
T.evaluate(A_data)
T.ret(T.int32(0))

return mod

Expand Down