diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index f6792c1a4e8b..bcdd0bfea0dd 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(stream); + PrintType(f->ret_type, stream); this->PrintExtraAttrs(f); this->stream << " " << static_cast(global_symbol.value()) << "("; @@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->stream << "}\n\n"; } -void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; } +void CodeGenC::PrintFuncPrefix(std::ostream& os) {} void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} @@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); - this->GenerateForwardFunctionDeclarations(func->value, op->args); + Array arg_types; + for (size_t i = 1; i < op->args.size(); i++) { + arg_types.push_back(GetType(op->args[i])); + } + Type ret_type = GetTypeFromRuntimeDataType(op->dtype); + this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 4f0da5a9dbad..de9c2f1745cb 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor, /*! * \brief Generate forward function declarations. * \param global_symbol The symbolc of the target function. - * \param args The arguments to the function. + * \param arg_types The argument types to the function. + * \param ret_type The return type of the function * \param os The output stream. */ virtual void GenerateForwardFunctionDeclarations(String global_symbol, - const Array& args) {} + const Array& arg_types, + const Type& ret_type) {} + /*! * \brief Print external function call. * \param ret_type The return type. diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 1d8071774e9e..e98852c270d5 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) { function_names_.push_back(runtime::symbol::tvm_module_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); + PrintType(f->ret_type, stream); stream << " " << tvm::runtime::symbol::tvm_module_main << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " << "int* out_ret_tcode, void* resource_handle) {\n"; @@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) { } void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, - const Array& args) { + + const Array& arg_types, + const Type& ret_type) { if (!emit_fwd_func_decl_) { return; } @@ -107,13 +110,13 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, } } this->PrintFuncPrefix(fwd_decl_stream); + this->PrintType(ret_type, fwd_decl_stream); fwd_decl_stream << " " << global_symbol << "("; - for (size_t i = 1; i < args.size(); ++i) { - CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream); - fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream); - if (i < args.size() - 1) { + for (size_t i = 0; i < arg_types.size(); ++i) { + if (i > 0) { fwd_decl_stream << ", "; } + CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream); } fwd_decl_stream << ");\n"; } @@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) os << "#ifdef __cplusplus\n" << "extern \"C\"\n" << "#endif\n" - << "TVM_DLL int32_t"; + << "TVM_DLL "; } void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 6bae574627d5..9c71f197f0e1 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC { void AddFunctionsOrdered(std::vector> functions); void DefineModuleName(); + using CodeGenC::PrintType; void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*) void PrintFinalReturn() final; // NOLINT(*) @@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC { void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) - virtual void GenerateForwardFunctionDeclarations(String global_symbol, - const Array& args); // NOLINT(*) + void GenerateForwardFunctionDeclarations(String global_symbol, const Array& arg_types, + const Type& ret_type) override; Array GetFunctionNames() { return function_names_; } private: diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ec8695a2a038..cd0ec0e34f03 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) { ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } -void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; } +void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; } class ThreadIdxExtractor : public tir::StmtVisitor { private: diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index fa4ca7d34ba8..c15d2253d716 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { } } -void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; } +void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; } void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) { for (Var arg : f->params) { diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 8463d6ac4147..83046de10701 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { } } -void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; } +void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; } void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { for (size_t i = 0; i < f->params.size(); ++i) {