Skip to content

Commit 74456cf

Browse files
committed
Omit declarations that are present in TVM include files
1 parent deb6693 commit 74456cf

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

src/target/source/codegen_c_host.cc

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d
5252
declared_globals_.clear();
5353
decl_stream << "// tvm target: " << target_str << "\n";
5454
decl_stream << "#define TVM_EXPORTS\n";
55-
decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
56-
decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
55+
DeclareIncludeTVMRuntimeAPI();
56+
DeclareIncludeTVMBackendAPI();
5757
decl_stream << "#include <math.h>\n";
5858
decl_stream << "#include <stdbool.h>\n";
5959
if (devices.find("ethos-u") != devices.end()) {
@@ -70,6 +70,66 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d
7070
CodeGenC::Init(output_ssa);
7171
}
7272

73+
void CodeGenCHost::DeclareIncludeTVMRuntimeAPI() {
74+
decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
75+
included_function_names_.insert("TVMAPISetLastError");
76+
included_function_names_.insert("TVMAPISetLastPythonError");
77+
included_function_names_.insert("TVMGetLastPythonError");
78+
included_function_names_.insert("TVMGetLastError");
79+
included_function_names_.insert("TVMGetLastBacktrace");
80+
included_function_names_.insert("TVMDropLastPythonError");
81+
included_function_names_.insert("TVMThrowLastError");
82+
included_function_names_.insert("TVMModLoadFromFile");
83+
included_function_names_.insert("TVMModImport");
84+
included_function_names_.insert("TVMModGetFunction");
85+
included_function_names_.insert("TVMModFree");
86+
included_function_names_.insert("TVMFuncFree");
87+
included_function_names_.insert("TVMFuncCall");
88+
included_function_names_.insert("TVMCFuncSetReturn");
89+
included_function_names_.insert("TVMCbArgToReturn");
90+
included_function_names_.insert("TVMFuncCreateFromCFunc");
91+
included_function_names_.insert("TVMFuncRegisterGlobal");
92+
included_function_names_.insert("TVMFuncGetGlobal");
93+
included_function_names_.insert("TVMFuncListGlobalNames");
94+
included_function_names_.insert("TVMFuncRemoveGlobal");
95+
included_function_names_.insert("TVMArrayAlloc");
96+
included_function_names_.insert("TVMArrayFree");
97+
included_function_names_.insert("TVMArrayCopyFromBytes");
98+
included_function_names_.insert("TVMArrayCopyToBytes");
99+
included_function_names_.insert("TVMArrayCopyFromTo");
100+
included_function_names_.insert("TVMArrayFromDLPack");
101+
included_function_names_.insert("TVMArrayToDLPack");
102+
included_function_names_.insert("TVMDLManagedTensorCallDeleter");
103+
included_function_names_.insert("TVMStreamCreate");
104+
included_function_names_.insert("TVMStreamFree");
105+
included_function_names_.insert("TVMSetStream");
106+
included_function_names_.insert("TVMSynchronize");
107+
included_function_names_.insert("TVMStreamStreamSynchronize");
108+
included_function_names_.insert("TVMObjectGetTypeIndex");
109+
included_function_names_.insert("TVMObjectTypeKey2Index");
110+
included_function_names_.insert("TVMObjectTypeIndex2Key");
111+
included_function_names_.insert("TVMObjectRetain");
112+
included_function_names_.insert("TVMObjectFree");
113+
included_function_names_.insert("TVMByteArrayFree");
114+
included_function_names_.insert("TVMDeviceAllocDataSpace");
115+
included_function_names_.insert("TVMDeviceAllocDataSpaceWithScope");
116+
included_function_names_.insert("TVMDeviceFreeDataSpace");
117+
included_function_names_.insert("TVMDeviceCopyDataFromTo");
118+
included_function_names_.insert("TVMObjectDerivedFrom");
119+
}
120+
121+
void CodeGenCHost::DeclareIncludeTVMBackendAPI() {
122+
decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
123+
included_function_names_.insert("TVMBackendGetFuncFromEnv");
124+
included_function_names_.insert("TVMBackendRegisterSystemLibSymbol");
125+
included_function_names_.insert("TVMBackendAllocWorkspace");
126+
included_function_names_.insert("TVMBackendFreeWorkspace");
127+
included_function_names_.insert("TVMBackendRegisterEnvCAPI");
128+
included_function_names_.insert("TVMBackendParallelLaunch");
129+
included_function_names_.insert("TVMBackendParallelBarrier");
130+
included_function_names_.insert("TVMBackendRunOnce");
131+
}
132+
73133
void CodeGenCHost::InitGlobalContext() {
74134
decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n";
75135
}

src/target/source/codegen_c_host.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,21 @@ class CodeGenCHost : public CodeGenC {
7373
const Type& ret_type) override;
7474
Array<String> GetFunctionNames() { return function_names_; }
7575

76+
protected:
77+
/* \brief Names declared in external headers
78+
*
79+
* When encountering a `builtin::call_extern`, a forward declaration
80+
* will usually be generated based on the arguments used in TIR. In
81+
* some cases, this can conflict with the declaration used in the
82+
* header file. For example, the `c_backend_api.h` header declares
83+
* `void* TVMBackendStringRetValue(const char*)`, but the
84+
* auto-generated declaration would have `uint8*` argument.
85+
*
86+
* Names in this set will be excluded from the automatic forward
87+
* declaration, to avoid conflicting declarations.
88+
*/
89+
std::unordered_set<std::string> included_function_names_;
90+
7691
private:
7792
/* \brief Internal structure to store information about function calls */
7893
struct FunctionInfo {
@@ -110,6 +125,9 @@ class CodeGenCHost : public CodeGenC {
110125
template <typename T>
111126
inline void PrintTernaryCondExpr(const T* op, const char* compare,
112127
std::ostream& os); // NOLINT(*)
128+
129+
void DeclareIncludeTVMRuntimeAPI();
130+
void DeclareIncludeTVMBackendAPI();
113131
};
114132

115133
} // namespace codegen

0 commit comments

Comments
 (0)