@@ -28,14 +28,16 @@ std::unique_ptr<CodeGenLLVM> CodeGenLLVM::Create(llvm::TargetMachine *tm) {
2828
2929void CodeGenLLVM::Init (const std::string& module_name,
3030 llvm::TargetMachine* tm,
31- llvm::LLVMContext* ctx) {
31+ llvm::LLVMContext* ctx,
32+ bool system_lib) {
3233 InitializeLLVM ();
3334 static_assert (sizeof (TVMValue) == sizeof (double ), " invariant" );
3435 // static_assert(alignof(TVMValue) == alignof(double), "invariant");
3536 // clear maps
3637 var_map_.clear ();
3738 str_map_.clear ();
3839 func_handle_map_.clear ();
40+ export_system_symbols_.clear ();
3941 // initialize types.
4042 if (ctx_ != ctx) {
4143 t_void_ = llvm::Type::getVoidTy (*ctx);
@@ -96,6 +98,13 @@ void CodeGenLLVM::Init(const std::string& module_name,
9698 t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo (), t_void_p_}
9799 , false ),
98100 llvm::Function::ExternalLinkage, " TVMBackendParallelFor" , module_.get ());
101+ if (system_lib) {
102+ f_tvm_register_system_symbol_ = llvm::Function::Create (
103+ llvm::FunctionType::get (t_int_, {t_char_->getPointerTo (), t_void_p_}, false ),
104+ llvm::Function::ExternalLinkage, " TVMBackendRegisterSystemLibSymbol" , module_.get ());
105+ } else {
106+ f_tvm_register_system_symbol_ = nullptr ;
107+ }
99108 this ->InitTarget (tm);
100109 // initialize builder
101110 builder_.reset (new IRBuilder (*ctx));
@@ -125,9 +134,15 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
125134void CodeGenLLVM::InitGlobalContext () {
126135 gv_mod_ctx_ = new llvm::GlobalVariable (
127136 *module_, t_void_p_, false ,
128- llvm::GlobalValue::LinkOnceAnyLinkage, 0 , " __tvm_module_ctx" );
137+ llvm::GlobalValue::LinkOnceAnyLinkage, 0 ,
138+ tvm::runtime::symbol::tvm_module_ctx);
129139 gv_mod_ctx_->setAlignment (data_layout_->getTypeAllocSize (t_void_p_));
130140 gv_mod_ctx_->setInitializer (llvm::Constant::getNullValue (t_void_p_));
141+
142+ if (f_tvm_register_system_symbol_ != nullptr ) {
143+ export_system_symbols_.emplace_back (
144+ std::make_pair (tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_));
145+ }
131146}
132147
133148void CodeGenLLVM::InitFuncState () {
@@ -171,6 +186,11 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
171186 builder_->SetInsertPoint (block);
172187 this ->VisitStmt (f->body );
173188 builder_->CreateRet (ConstInt32 (0 ));
189+
190+ if (f_tvm_register_system_symbol_ != nullptr ) {
191+ export_system_symbols_.emplace_back (
192+ std::make_pair (f->name , builder_->CreatePointerCast (function_, t_void_p_)));
193+ }
174194}
175195
176196void CodeGenLLVM::AddMainFunction (const std::string& entry_func_name) {
@@ -225,13 +245,35 @@ void CodeGenLLVM::Optimize() {
225245}
226246
227247std::unique_ptr<llvm::Module> CodeGenLLVM::Finish () {
248+ this ->AddStartupFunction ();
228249 this ->Optimize ();
229250 var_map_.clear ();
230251 str_map_.clear ();
231252 func_handle_map_.clear ();
253+ export_system_symbols_.clear ();
232254 return std::move (module_);
233255}
234256
257+ void CodeGenLLVM::AddStartupFunction () {
258+ if (export_system_symbols_.size () != 0 ) {
259+ llvm::FunctionType* ftype = llvm::FunctionType::get (t_void_, {}, false );
260+ function_ = llvm::Function::Create (
261+ ftype,
262+ llvm::Function::InternalLinkage,
263+ " __tvm_module_startup" , module_.get ());
264+ llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create (*ctx_, " entry" , function_);
265+ builder_->SetInsertPoint (startup_entry);
266+ for (const auto & kv : export_system_symbols_) {
267+ llvm::Value* name = GetConstString (kv.first );
268+ builder_->CreateCall (
269+ f_tvm_register_system_symbol_, {
270+ name, builder_->CreateBitCast (kv.second , t_void_p_)});
271+ }
272+ llvm::appendToGlobalCtors (*module_, function_, 65535 );
273+ builder_->CreateRet (nullptr );
274+ }
275+ }
276+
235277llvm::Type* CodeGenLLVM::LLVMType (const Type& t) const {
236278 llvm::Type* ret = nullptr ;
237279 if (t.is_uint () || t.is_int ()) {
0 commit comments