@@ -224,7 +224,13 @@ void CodeGenLLVM::InitTarget() {
224224#endif // TVM_LLVM_VERSION >= 60
225225}
226226
227- void CodeGenLLVM::AddFunction (const PrimFunc& f) { this ->AddFunctionInternal (f, false ); }
227+ llvm::Function* CodeGenLLVM::DeclareFunction (const GlobalVar& gvar, const PrimFunc& f) {
228+ return this ->DeclareFunctionInternal (gvar, f, false );
229+ }
230+
231+ void CodeGenLLVM::AddFunction (const GlobalVar& gvar, const PrimFunc& f) {
232+ this ->AddFunctionInternal (gvar, f, false );
233+ }
228234
229235void CodeGenLLVM::InitFuncState () {
230236 var_map_.clear ();
@@ -234,15 +240,34 @@ void CodeGenLLVM::InitFuncState() {
234240 analyzer_.reset (new arith::Analyzer ());
235241}
236242
237- void CodeGenLLVM::AddFunctionInternal (const PrimFunc& f, bool ret_void) {
238- this ->InitFuncState ();
243+ std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage (
244+ const GlobalVar& gvar, const PrimFunc& func) {
245+ if (auto global_symbol = func->GetAttr <String>(tvm::attr::kGlobalSymbol )) {
246+ return {global_symbol.value (), llvm::Function::ExternalLinkage};
247+ }
248+
249+ std::string symbol_name = [&]() {
250+ std::stringstream ss;
251+ ss << " _internal_" ;
252+ ss << gvar->name_hint ;
253+ return ss.str ();
254+ }();
255+
256+ return {symbol_name, llvm::Function::PrivateLinkage};
257+ }
258+
259+ llvm::Function* CodeGenLLVM::DeclareFunctionInternal (const GlobalVar& gvar, const PrimFunc& func,
260+ bool ret_void) {
261+ if (auto it = functions_.find (gvar.get ()); it != functions_.end ()) {
262+ return it->second ;
263+ }
239264
240- ICHECK_EQ (f ->buffer_map .size (), 0U )
265+ ICHECK_EQ (func ->buffer_map .size (), 0U )
241266 << " Cannot codegen function with buffer_map, please lower them first" ;
242267
243268 std::vector<llvm::Type*> param_types;
244- is_restricted_ = f ->HasNonzeroAttr (tir::attr::kNoAlias );
245- for (Var param : f ->params ) {
269+ is_restricted_ = func ->HasNonzeroAttr (tir::attr::kNoAlias );
270+ for (Var param : func ->params ) {
246271 param_types.push_back (GetLLVMType (param));
247272 if (!is_restricted_ && param.dtype ().is_handle ()) {
248273 alias_var_set_.insert (param.get ());
@@ -254,17 +279,26 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
254279 llvm::FunctionType* ftype =
255280 llvm::FunctionType::get (ret_void ? t_void_ : t_int_, param_types, false );
256281
257- auto global_symbol = f->GetAttr <String>(tvm::attr::kGlobalSymbol );
258- ICHECK (global_symbol.defined ())
259- << " CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute" ;
260- function_ = module_->getFunction (MakeStringRef (global_symbol.value ()));
261- if (function_ == nullptr ) {
262- function_ = llvm::Function::Create (ftype, llvm::Function::ExternalLinkage,
263- MakeStringRef (global_symbol.value ()), module_.get ());
282+ auto [symbol_name, linkage_type] = GetLinkage (gvar, func);
283+
284+ auto function = module_->getFunction (MakeStringRef (symbol_name));
285+ if (function == nullptr ) {
286+ function =
287+ llvm::Function::Create (ftype, linkage_type, MakeStringRef (symbol_name), module_.get ());
264288 }
265- function_->setCallingConv (llvm::CallingConv::C);
266- function_->setDLLStorageClass (llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
267- SetTargetAttributes (function_);
289+ function->setCallingConv (llvm::CallingConv::C);
290+ function->setDLLStorageClass (llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
291+ SetTargetAttributes (function);
292+
293+ functions_[gvar.get ()] = function;
294+
295+ return function;
296+ }
297+
298+ void CodeGenLLVM::AddFunctionInternal (const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
299+ this ->InitFuncState ();
300+
301+ function_ = DeclareFunctionInternal (gvar, f, ret_void);
268302
269303 // set var map and align information
270304 auto arg_it = function_->arg_begin ();
@@ -1747,9 +1781,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
17471781 VLOG (2 ) << " CreateIntrinsic done" ;
17481782 return x;
17491783 }
1784+ } else if (auto * ptr_gvar = op->op .as <GlobalVarNode>()) {
1785+ auto gvar = GetRef<GlobalVar>(ptr_gvar);
1786+ auto it = functions_.find (ptr_gvar);
1787+ ICHECK (it != functions_.end ()) << " Call to undefined GlobalVar \" " << gvar << " \" " ;
1788+ llvm::Function* callee = it->second ;
1789+ std::vector<llvm::Value*> arg_value;
1790+ for (const auto & arg : op->args ) {
1791+ arg_value.push_back (MakeValue (arg));
1792+ }
1793+ return builder_->CreateCall (callee, arg_value);
1794+
17501795 } else {
1751- ICHECK (op->op .as <GlobalVarNode>());
1752- LOG (FATAL) << " Do not yet support cross function call" ;
1796+ LOG (FATAL) << " Unsupported operation in CallNode: " << op->op ;
17531797 }
17541798}
17551799
0 commit comments