Skip to content

Commit dbbcf90

Browse files
authored
[FFI][REFATOR] Cleanup entry function to redirect (#18205)
This PR updates the entry function mechanism to create a stub that redirects to the real function. This new behavior helps to simplify the runtime logic supporting entry function. Also updates the name to `__tvm_ffi_main__`
1 parent 6790af8 commit dbbcf90

File tree

17 files changed

+57
-56
lines changed

17 files changed

+57
-56
lines changed

include/tvm/runtime/module.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,14 @@ namespace symbol {
290290
constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi_library_ctx";
291291
/*! \brief Global variable to store binary data alongside a library module. */
292292
constexpr const char* tvm_ffi_library_bin = "__tvm_ffi_library_bin";
293+
/*! \brief Placeholder for the module's entry function. */
294+
constexpr const char* tvm_ffi_main = "__tvm_ffi_main__";
293295
/*! \brief global function to set device */
294296
constexpr const char* tvm_set_device = "__tvm_set_device";
295297
/*! \brief Auxiliary counter to global barrier. */
296298
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
297299
/*! \brief Prepare the global barrier before kernels that uses global barrier. */
298300
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
299-
/*! \brief Placeholder for the module's entry function. */
300-
constexpr const char* tvm_module_main = "__tvm_main__";
301301
} // namespace symbol
302302

303303
// implementations of inline functions.

jvm/core/src/main/java/org/apache/tvm/Module.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ private static Function getApi(String name) {
4646
}
4747

4848
private Function entry = null;
49-
private final String entryName = "__tvm_main__";
49+
private final String entryName = "__tvm_ffi_main__";
5050

5151

5252
/**

python/tvm/runtime/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class Module(tvm.ffi.Object):
103103

104104
def __new__(cls):
105105
instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter
106-
instance.entry_name = "__tvm_main__"
106+
instance.entry_name = "__tvm_ffi_main__"
107107
instance._entry = None
108108
return instance
109109

@@ -118,7 +118,7 @@ def entry_func(self):
118118
"""
119119
if self._entry:
120120
return self._entry
121-
self._entry = self.get_function("__tvm_main__")
121+
self._entry = self.get_function("__tvm_ffi_main__")
122122
return self._entry
123123

124124
def implements_function(self, name, query_imports=False):

src/runtime/cuda/cuda_module.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ class CUDAPrepGlobalBarrier {
258258
ffi::Function CUDAModuleNode::GetFunction(const String& name,
259259
const ObjectPtr<Object>& sptr_to_self) {
260260
ICHECK_EQ(sptr_to_self.get(), this);
261-
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
262261
if (name == symbol::tvm_prepare_global_barrier) {
263262
return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
264263
}

src/runtime/library_module.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,7 @@ class LibraryModuleNode final : public ModuleNode {
5050

5151
ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
5252
TVMFFISafeCallType faddr;
53-
if (name == runtime::symbol::tvm_module_main) {
54-
const char* entry_name =
55-
reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
56-
ICHECK(entry_name != nullptr)
57-
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
58-
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
59-
} else {
60-
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
61-
}
53+
faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
6254
if (faddr == nullptr) return ffi::Function();
6355
return packed_func_wrapper_(faddr, sptr_to_self);
6456
}

src/runtime/metal/metal_module.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
264264
ffi::Function ret;
265265
AUTORELEASEPOOL {
266266
ICHECK_EQ(sptr_to_self.get(), this);
267-
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
268267
auto it = fmap_.find(name);
269268
if (it == fmap_.end()) {
270269
ret = ffi::Function();

src/runtime/opencl/opencl_module.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() {
138138
ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
139139
const ObjectPtr<Object>& sptr_to_self) {
140140
ICHECK_EQ(sptr_to_self.get(), this);
141-
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
142141
auto it = fmap_.find(name);
143142
if (it == fmap_.end()) return ffi::Function();
144143
const FunctionInfo& info = it->second;

src/runtime/rocm/rocm_module.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ class ROCMWrappedFunc {
195195
ffi::Function ROCMModuleNode::GetFunction(const String& name,
196196
const ObjectPtr<Object>& sptr_to_self) {
197197
ICHECK_EQ(sptr_to_self.get(), this);
198-
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
199198
auto it = fmap_.find(name);
200199
if (it == fmap_.end()) return ffi::Function();
201200
const FunctionInfo& info = it->second;

src/runtime/vulkan/vulkan_wrapped_func.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ VulkanModuleNode::~VulkanModuleNode() {
208208
ffi::Function VulkanModuleNode::GetFunction(const String& name,
209209
const ObjectPtr<Object>& sptr_to_self) {
210210
ICHECK_EQ(sptr_to_self.get(), this);
211-
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
212211
auto it = fmap_.find(name);
213212
if (it == fmap_.end()) return ffi::Function();
214213
const FunctionInfo& info = it->second;

src/target/llvm/codegen_cpu.cc

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -229,28 +229,42 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
229229
}
230230

231231
void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) {
232-
llvm::Function* f = module_->getFunction(entry_func_name);
233-
ICHECK(f) << "Function " << entry_func_name << "does not in module";
234-
llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1);
235-
llvm::GlobalVariable* global =
236-
new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr,
237-
runtime::symbol::tvm_module_main);
238-
#if TVM_LLVM_VERSION >= 100
239-
global->setAlignment(llvm::Align(1));
240-
#else
241-
global->setAlignment(1);
242-
#endif
243-
// comdat is needed for windows select any linking to work
244-
// set comdat to Any(weak linking)
232+
// create a wrapper function with tvm_ffi_main name and redirects to the entry function
233+
llvm::Function* target_func = module_->getFunction(entry_func_name);
234+
ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module";
235+
236+
// Create wrapper function
237+
llvm::Function* wrapper_func =
238+
llvm::Function::Create(target_func->getFunctionType(), llvm::Function::WeakAnyLinkage,
239+
runtime::symbol::tvm_ffi_main, module_.get());
240+
241+
// Set attributes (Windows comdat, DLL export, etc.)
245242
if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
246-
llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
243+
llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main);
247244
comdat->setSelectionKind(llvm::Comdat::Any);
248-
global->setComdat(comdat);
245+
wrapper_func->setComdat(comdat);
249246
}
250247

251-
global->setInitializer(
252-
llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name));
253-
global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
248+
wrapper_func->setCallingConv(llvm::CallingConv::C);
249+
wrapper_func->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
250+
251+
// Create simple tail call
252+
llvm::BasicBlock* entry =
253+
llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", wrapper_func);
254+
builder_->SetInsertPoint(entry);
255+
256+
// Forward all arguments to target function
257+
std::vector<llvm::Value*> call_args;
258+
for (llvm::Value& arg : wrapper_func->args()) {
259+
call_args.push_back(&arg);
260+
}
261+
262+
llvm::Value* result = builder_->CreateCall(target_func, call_args);
263+
if (target_func->getReturnType()->isVoidTy()) {
264+
builder_->CreateRetVoid();
265+
} else {
266+
builder_->CreateRet(result);
267+
}
254268
}
255269

256270
std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {

0 commit comments

Comments
 (0)