@@ -229,28 +229,42 @@ void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
229229}
230230
231231void CodeGenCPU::AddMainFunction (const std::string& entry_func_name) {
232- llvm::Function* f = module_->getFunction (entry_func_name);
233- ICHECK (f) << " Function " << entry_func_name << " does not in module" ;
234- llvm::Type* type = llvm::ArrayType::get (t_char_, entry_func_name.length () + 1 );
235- llvm::GlobalVariable* global =
236- new llvm::GlobalVariable (*module_, type, true , llvm::GlobalValue::WeakAnyLinkage, nullptr ,
237- runtime::symbol::tvm_module_main);
238- #if TVM_LLVM_VERSION >= 100
239- global->setAlignment (llvm::Align (1 ));
240- #else
241- global->setAlignment (1 );
242- #endif
243- // comdat is needed for windows select any linking to work
244- // set comdat to Any(weak linking)
232+ // create a wrapper function with tvm_ffi_main name and redirects to the entry function
233+ llvm::Function* target_func = module_->getFunction (entry_func_name);
234+ ICHECK (target_func) << " Function " << entry_func_name << " does not exist in module" ;
235+
236+ // Create wrapper function
237+ llvm::Function* wrapper_func =
238+ llvm::Function::Create (target_func->getFunctionType (), llvm::Function::WeakAnyLinkage,
239+ runtime::symbol::tvm_ffi_main, module_.get ());
240+
241+ // Set attributes (Windows comdat, DLL export, etc.)
245242 if (llvm_target_->GetOrCreateTargetMachine ()->getTargetTriple ().isOSWindows ()) {
246- llvm::Comdat* comdat = module_->getOrInsertComdat (runtime::symbol::tvm_module_main );
243+ llvm::Comdat* comdat = module_->getOrInsertComdat (runtime::symbol::tvm_ffi_main );
247244 comdat->setSelectionKind (llvm::Comdat::Any);
248- global ->setComdat (comdat);
245+ wrapper_func ->setComdat (comdat);
249246 }
250247
251- global->setInitializer (
252- llvm::ConstantDataArray::getString (*llvm_target_->GetContext (), entry_func_name));
253- global->setDLLStorageClass (llvm::GlobalVariable::DLLExportStorageClass);
248+ wrapper_func->setCallingConv (llvm::CallingConv::C);
249+ wrapper_func->setDLLStorageClass (llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
250+
251+ // Create simple tail call
252+ llvm::BasicBlock* entry =
253+ llvm::BasicBlock::Create (*llvm_target_->GetContext (), " entry" , wrapper_func);
254+ builder_->SetInsertPoint (entry);
255+
256+ // Forward all arguments to target function
257+ std::vector<llvm::Value*> call_args;
258+ for (llvm::Value& arg : wrapper_func->args ()) {
259+ call_args.push_back (&arg);
260+ }
261+
262+ llvm::Value* result = builder_->CreateCall (target_func, call_args);
263+ if (target_func->getReturnType ()->isVoidTy ()) {
264+ builder_->CreateRetVoid ();
265+ } else {
266+ builder_->CreateRet (result);
267+ }
254268}
255269
256270std::unique_ptr<llvm::Module> CodeGenCPU::Finish () {
0 commit comments