Skip to content

Commit cd45513

Browse files
authored
[LLVM] Codegen subroutine call when CallNode::op is GlobalVar (#14901)
* [CodeGen][LLVM] Codegen to generate internal functions Previously, `CodeGenLLVM` required all TIR PrimFuncs to have the `kGlobalSymbol` attribute, using its value as the externally-visible symbol in the generated library. This commit relaxes that requirement, using the presence of `kGlobalSymbol` to indicate whether a function should be exposed externally. If `kGlobalSymbol` is not defined, then the symbol name is generated from the name of the `tvm::GlobalVar` with the prefix `"_internal_"`, and the symbol is not exposed externally. Since this does not change the codegen behavior for any function that was previously supported, this is not a breaking change. * [Codegen][LLVM] Handle callsite for internal functions * [UnitTest][LLVM] Added test for LLVM codegen for subroutine
1 parent 53cee4b commit cd45513

File tree

9 files changed

+170
-62
lines changed

9 files changed

+170
-62
lines changed

src/target/llvm/codegen_amdgpu.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ class CodeGenAMDGPU : public CodeGenLLVM {
8787
CodeGenAMDGPU() = default;
8888
virtual ~CodeGenAMDGPU() = default;
8989

90-
void AddFunction(const PrimFunc& f) final {
90+
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
9191
// add function as void return value
92-
CodeGenLLVM::AddFunctionInternal(f, true);
92+
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
9393
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
9494
std::ostringstream attr;
9595
attr << "1," << DetectROCMmaxThreadsPerBlock();
@@ -262,11 +262,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {
262262

263263
cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false);
264264

265-
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
266-
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
267-
<< "Can only lower IR Module with PrimFuncs";
268-
return Downcast<PrimFunc>(kv.second);
269-
});
265+
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
270266

271267
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
272268
const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");

src/target/llvm/codegen_cpu.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,18 +220,17 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) {
220220
#endif
221221
}
222222

223-
void CodeGenCPU::AddFunction(const PrimFunc& f) {
223+
void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
224224
#if TVM_LLVM_VERSION >= 50
225225
di_subprogram_ = CreateDebugFunction(f);
226226
#endif
227227
EmitDebugLocation(f->span);
228-
CodeGenLLVM::AddFunction(f);
228+
CodeGenLLVM::AddFunction(gvar, f);
229229
if (f_tvm_register_system_symbol_ != nullptr) {
230-
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
231-
ICHECK(global_symbol.defined())
232-
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
233-
export_system_symbols_.emplace_back(
234-
std::make_pair(global_symbol.value().operator std::string(), function_));
230+
if (auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
231+
export_system_symbols_.emplace_back(
232+
std::make_pair(global_symbol.value().operator std::string(), function_));
233+
}
235234
}
236235
AddDebugInformation(f, function_);
237236
}

src/target/llvm/codegen_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class CodeGenCPU : public CodeGenLLVM {
6767
void Init(const std::string& module_name, LLVMTarget* llvm_target,
6868
Optional<String> system_lib_prefix, bool dynamic_lookup,
6969
bool target_c_runtime) override;
70-
void AddFunction(const PrimFunc& f) override;
70+
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override;
7171
void AddMainFunction(const std::string& entry_func_name) override;
7272
std::unique_ptr<llvm::Module> Finish() override;
7373
void VisitStmt_(const AssertStmtNode* op) override;

src/target/llvm/codegen_hexagon.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
547547

548548
auto cg = std::make_unique<CodeGenHexagon>();
549549

550-
std::vector<PrimFunc> funcs;
551550
std::string entry_func;
552551

553552
for (auto kv : mod->functions) {
@@ -562,11 +561,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
562561
ICHECK(global_symbol.defined());
563562
entry_func = global_symbol.value();
564563
}
565-
funcs.emplace_back(f);
566564
}
567565

568566
cg->Init("TVMHexagonModule", llvm_target.get(), NullOpt, false, false);
569-
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
567+
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
570568
if (entry_func.length() != 0) {
571569
cg->AddMainFunction(entry_func);
572570
}

src/target/llvm/codegen_llvm.cc

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,13 @@ void CodeGenLLVM::InitTarget() {
224224
#endif // TVM_LLVM_VERSION >= 60
225225
}
226226

227-
void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
227+
llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
228+
return this->DeclareFunctionInternal(gvar, f, false);
229+
}
230+
231+
void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
232+
this->AddFunctionInternal(gvar, f, false);
233+
}
228234

229235
void CodeGenLLVM::InitFuncState() {
230236
var_map_.clear();
@@ -234,15 +240,34 @@ void CodeGenLLVM::InitFuncState() {
234240
analyzer_.reset(new arith::Analyzer());
235241
}
236242

237-
void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
238-
this->InitFuncState();
243+
std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
244+
const GlobalVar& gvar, const PrimFunc& func) {
245+
if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
246+
return {global_symbol.value(), llvm::Function::ExternalLinkage};
247+
}
248+
249+
std::string symbol_name = [&]() {
250+
std::stringstream ss;
251+
ss << "_internal_";
252+
ss << gvar->name_hint;
253+
return ss.str();
254+
}();
255+
256+
return {symbol_name, llvm::Function::PrivateLinkage};
257+
}
258+
259+
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func,
260+
bool ret_void) {
261+
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
262+
return it->second;
263+
}
239264

240-
ICHECK_EQ(f->buffer_map.size(), 0U)
265+
ICHECK_EQ(func->buffer_map.size(), 0U)
241266
<< "Cannot codegen function with buffer_map, please lower them first";
242267

243268
std::vector<llvm::Type*> param_types;
244-
is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
245-
for (Var param : f->params) {
269+
is_restricted_ = func->HasNonzeroAttr(tir::attr::kNoAlias);
270+
for (Var param : func->params) {
246271
param_types.push_back(GetLLVMType(param));
247272
if (!is_restricted_ && param.dtype().is_handle()) {
248273
alias_var_set_.insert(param.get());
@@ -254,17 +279,26 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
254279
llvm::FunctionType* ftype =
255280
llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
256281

257-
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
258-
ICHECK(global_symbol.defined())
259-
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
260-
function_ = module_->getFunction(MakeStringRef(global_symbol.value()));
261-
if (function_ == nullptr) {
262-
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
263-
MakeStringRef(global_symbol.value()), module_.get());
282+
auto [symbol_name, linkage_type] = GetLinkage(gvar, func);
283+
284+
auto function = module_->getFunction(MakeStringRef(symbol_name));
285+
if (function == nullptr) {
286+
function =
287+
llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get());
264288
}
265-
function_->setCallingConv(llvm::CallingConv::C);
266-
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
267-
SetTargetAttributes(function_);
289+
function->setCallingConv(llvm::CallingConv::C);
290+
function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
291+
SetTargetAttributes(function);
292+
293+
functions_[gvar.get()] = function;
294+
295+
return function;
296+
}
297+
298+
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
299+
this->InitFuncState();
300+
301+
function_ = DeclareFunctionInternal(gvar, f, ret_void);
268302

269303
// set var map and align information
270304
auto arg_it = function_->arg_begin();
@@ -1747,9 +1781,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
17471781
VLOG(2) << "CreateIntrinsic done";
17481782
return x;
17491783
}
1784+
} else if (auto* ptr_gvar = op->op.as<GlobalVarNode>()) {
1785+
auto gvar = GetRef<GlobalVar>(ptr_gvar);
1786+
auto it = functions_.find(ptr_gvar);
1787+
ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\"";
1788+
llvm::Function* callee = it->second;
1789+
std::vector<llvm::Value*> arg_value;
1790+
for (const auto& arg : op->args) {
1791+
arg_value.push_back(MakeValue(arg));
1792+
}
1793+
return builder_->CreateCall(callee, arg_value);
1794+
17501795
} else {
1751-
ICHECK(op->op.as<GlobalVarNode>());
1752-
LOG(FATAL) << "Do not yet support cross function call";
1796+
LOG(FATAL) << "Unsupported operation in CallNode: " << op->op;
17531797
}
17541798
}
17551799

src/target/llvm/codegen_llvm.h

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
#include <algorithm>
6464
#include <memory>
6565
#include <string>
66+
#include <tuple>
6667
#include <unordered_map>
6768
#include <unordered_set>
6869
#include <utility>
@@ -132,11 +133,17 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
132133
*/
133134
void SetFastMathFlags(llvm::FastMathFlags fmf);
134135

136+
virtual llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f);
137+
135138
/*!
136139
* \brief Compile and add function f to the current module.
140+
*
141+
* \param gvar The GlobalVar which may be used to may internal calls
142+
* to this function from elsewhere in the module.
143+
*
137144
* \param f The function to be added.
138145
*/
139-
virtual void AddFunction(const PrimFunc& f);
146+
virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
140147
/*!
141148
* \brief Add main function as the entry name
142149
* \param entry_func_name The name of entry function to be added.
@@ -356,7 +363,28 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
356363
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
357364
// Get correct address space depending on the backend
358365
virtual unsigned GetGlobalAddressSpace() const;
359-
void AddFunctionInternal(const PrimFunc& f, bool ret_void);
366+
367+
/*! \brief Get the linkage parameters for the function
368+
*
369+
* Returns a tuple whose first element is the name of the function
370+
* and whose second element is the linkage type to be used
371+
* (e.g. llvm::Function::ExternalLinkage or
372+
* llvm::Function::PrivateLinkage)
373+
*
374+
* \param func The PrimFunc whose symbol name and linkage type
375+
* should be returned
376+
*
377+
* \param gvar The GlobalVar to be used when generating the symbol
378+
* name. Only used for internal functions, for which the
379+
* kGlobalSymbol attribute is not defined.
380+
*/
381+
std::tuple<std::string, llvm::Function::LinkageTypes> GetLinkage(const GlobalVar& gvar,
382+
const PrimFunc& func);
383+
384+
llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
385+
386+
void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
387+
360388
// Create extern call
361389
llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,
362390
const std::vector<llvm::Value*>& value);
@@ -517,6 +545,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
517545
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
518546
// global strings
519547
std::unordered_map<std::string, llvm::Constant*> str_map_;
548+
549+
// Map from TVM's GlobalVar to the llvm::Function that represents
550+
// that function.
551+
std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;
552+
520553
// Whether current function is restricted
521554
bool is_restricted_{true};
522555
// The analyzer information
@@ -569,18 +602,26 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {
569602

570603
template <typename IterType, typename ConvType>
571604
void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
572-
std::vector<PrimFunc> funcs;
605+
std::vector<std::tuple<GlobalVar, PrimFunc>> funcs;
573606
for (auto it = begin; it != end; ++it) {
574-
funcs.push_back(pfunc(*it));
607+
auto [gvar, func] = *it;
608+
auto converted = pfunc(func);
609+
funcs.push_back({gvar, Downcast<PrimFunc>(converted)});
575610
}
576-
std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) {
577-
std::string name_a = func_a->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
578-
std::string name_b = func_b->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
611+
std::sort(funcs.begin(), funcs.end(), [this](const auto& pair_a, const auto& pair_b) {
612+
const auto& [gvar_a, func_a] = pair_a;
613+
std::string name_a = std::get<std::string>(GetLinkage(gvar_a, func_a));
614+
615+
const auto& [gvar_b, func_b] = pair_b;
616+
std::string name_b = std::get<std::string>(GetLinkage(gvar_b, func_b));
579617
return name_a < name_b;
580618
});
581-
for (auto& f : funcs) {
582-
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
583-
AddFunction(f);
619+
620+
for (const auto& [gvar, func] : funcs) {
621+
DeclareFunction(gvar, func);
622+
}
623+
for (const auto& [gvar, func] : funcs) {
624+
AddFunction(gvar, func);
584625
}
585626
}
586627

src/target/llvm/codegen_nvptx.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ namespace codegen {
6666
// NVPTX code generator.
6767
class CodeGenNVPTX : public CodeGenLLVM {
6868
public:
69-
void AddFunction(const PrimFunc& f) final {
69+
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
7070
// add function as void return value
71-
CodeGenLLVM::AddFunctionInternal(f, true);
71+
return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true);
72+
}
73+
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
74+
// add function as void return value
75+
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
7276
// annotate as kernel function
7377
llvm::LLVMContext* ctx = llvm_target_->GetContext();
7478
module_->getOrInsertNamedMetadata("nvvm.annotations")
@@ -311,11 +315,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) {
311315

312316
cg->Init("TVMPTXModule", llvm_target.get(), NullOpt, false, false);
313317

314-
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
315-
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
316-
<< "Can only lower IR Module with PrimFuncs";
317-
return Downcast<PrimFunc>(kv.second);
318-
});
318+
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
319319

320320
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
321321
const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");

src/target/llvm/llvm_module.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
125125
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
126126
// (EngineBuilder takes ownership of the module).
127127
std::unique_ptr<llvm::Module> module_owning_ptr_;
128-
/* \brief names of the functions declared in this module */
128+
/* \brief names of the external functions declared in this module */
129129
Array<String> function_names_;
130130
};
131131

@@ -295,7 +295,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
295295
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
296296
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(llvm_target.get());
297297

298-
std::vector<PrimFunc> funcs;
299298
std::string entry_func;
300299
relay::Runtime runtime =
301300
mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
@@ -315,12 +314,16 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
315314
}
316315
auto f = Downcast<PrimFunc>(kv.second);
317316
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
318-
ICHECK(global_symbol.defined());
319-
function_names_.push_back(global_symbol.value());
320-
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
321-
entry_func = global_symbol.value();
317+
bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc);
318+
319+
ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally.";
320+
321+
if (global_symbol) {
322+
function_names_.push_back(global_symbol.value());
323+
if (is_entry_func) {
324+
entry_func = global_symbol.value();
325+
}
322326
}
323-
funcs.push_back(f);
324327
}
325328
// TODO(@jroesch): follow up on this condition.
326329
// ICHECK(funcs.size() > 0);
@@ -330,7 +333,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
330333
target_c_runtime);
331334
cg->SetFastMathFlags(llvm_target->GetFastMathFlags());
332335

333-
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
336+
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
334337
if (entry_func.length() != 0) {
335338
cg->AddMainFunction(entry_func);
336339
}

0 commit comments

Comments
 (0)