Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

Expand Down Expand Up @@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n";
}

void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
void CodeGenC::PrintFuncPrefix(std::ostream& os) {}

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

Expand Down Expand Up @@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
this->GenerateForwardFunctionDeclarations(func->value, op->args);
Array<Type> arg_types;
for (size_t i = 1; i < op->args.size(); i++) {
arg_types.push_back(GetType(op->args[i]));
}
Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
Expand Down
7 changes: 5 additions & 2 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*!
* \brief Generate forward function declarations.
* \param global_symbol The symbolc of the target function.
* \param args The arguments to the function.
* \param arg_types The argument types to the function.
* \param ret_type The return type of the function
* \param os The output stream.
*/
virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {}
const Array<Type>& arg_types,
const Type& ret_type) {}

/*!
* \brief Print external function call.
* \param ret_type The return type.
Expand Down
15 changes: 9 additions & 6 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
Expand All @@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
}

void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {

const Array<Type>& arg_types,
const Type& ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
Expand All @@ -107,13 +110,13 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 1; i < args.size(); ++i) {
CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
if (i < args.size() - 1) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
Expand All @@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
<< "TVM_DLL ";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
Expand Down
5 changes: 3 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC {
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();

using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
Expand All @@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC {

void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)

virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args); // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol, const Array<Type>& arg_types,
const Type& ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }

private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
}
}

void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; }
void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; }

void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
for (Var arg : f->params) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
}
}

void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }
void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; }

void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down