From abde0d44d0692d84a8e5c014630e61c46bc9c542 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Mar 2023 08:44:34 -0500 Subject: [PATCH 1/3] [CodeGen][LLVM] Codegen to generate internal functions Previously, `CodeGenLLVM` required all TIR PrimFuncs to have the `kGlobalSymbol` attribute, using its value as the externally-visible symbol in the generated library. This commit relaxes that requirement, using the presence of `kGlobalSymbol` to indicate whether a function should be exposed externally. If `kGlobalSymbol` is not defined, then the symbol name is generated from the name of the `tvm::GlobalVar` with the prefix `"_internal_"`, and the symbol is not exposed externally. Since this does not change the codegen behavior for any function that was previously supported, this is not a breaking change. --- src/target/llvm/codegen_amdgpu.cc | 10 ++--- src/target/llvm/codegen_cpu.cc | 13 +++--- src/target/llvm/codegen_cpu.h | 2 +- src/target/llvm/codegen_hexagon.cc | 4 +- src/target/llvm/codegen_llvm.cc | 68 ++++++++++++++++++++++-------- src/target/llvm/codegen_llvm.h | 61 ++++++++++++++++++++++----- src/target/llvm/codegen_nvptx.cc | 14 +++--- src/target/llvm/llvm_module.cc | 19 +++++---- 8 files changed, 131 insertions(+), 60 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 3efe548e1c2e..6bf1ca6eabd5 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -87,9 +87,9 @@ class CodeGenAMDGPU : public CodeGenLLVM { CodeGenAMDGPU() = default; virtual ~CodeGenAMDGPU() = default; - void AddFunction(const PrimFunc& f) final { + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final { // add function as void return value - CodeGenLLVM::AddFunctionInternal(f, true); + CodeGenLLVM::AddFunctionInternal(gvar, f, true); function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); std::ostringstream attr; attr << "1," << DetectROCMmaxThreadsPerBlock(); @@ -262,11 +262,7 @@ 
runtime::Module BuildAMDGPU(IRModule mod, Target target) { cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false); - cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { - ICHECK(kv.second->template IsInstance()) - << "Can only lower IR Module with PrimFuncs"; - return Downcast(kv.second); - }); + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ca7fb4e41476..f129511e5a17 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -220,18 +220,17 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { #endif } -void CodeGenCPU::AddFunction(const PrimFunc& f) { +void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { #if TVM_LLVM_VERSION >= 50 di_subprogram_ = CreateDebugFunction(f); #endif EmitDebugLocation(f->span); - CodeGenLLVM::AddFunction(f); + CodeGenLLVM::AddFunction(gvar, f); if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - export_system_symbols_.emplace_back( - std::make_pair(global_symbol.value().operator std::string(), function_)); + if (auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol)) { + export_system_symbols_.emplace_back( + std::make_pair(global_symbol.value().operator std::string(), function_)); + } } AddDebugInformation(f, function_); } diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 3cc1bbeb419e..2924aee46e6b 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -67,7 +67,7 @@ class CodeGenCPU : public CodeGenLLVM { void Init(const 
std::string& module_name, LLVMTarget* llvm_target, Optional system_lib_prefix, bool dynamic_lookup, bool target_c_runtime) override; - void AddFunction(const PrimFunc& f) override; + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; void VisitStmt_(const AssertStmtNode* op) override; diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index a2f13e98b1e5..2d0945c70498 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -547,7 +547,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { auto cg = std::make_unique(); - std::vector funcs; std::string entry_func; for (auto kv : mod->functions) { @@ -562,11 +561,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { ICHECK(global_symbol.defined()); entry_func = global_symbol.value(); } - funcs.emplace_back(f); } cg->Init("TVMHexagonModule", llvm_target.get(), NullOpt, false, false); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index eb53e9b6dc87..f342f559c614 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -224,7 +224,13 @@ void CodeGenLLVM::InitTarget() { #endif // TVM_LLVM_VERSION >= 60 } -void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } +llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) { + return this->DeclareFunctionInternal(gvar, f, false); +} + +void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { + this->AddFunctionInternal(gvar, f, false); +} void CodeGenLLVM::InitFuncState() { var_map_.clear(); @@ -234,15 +240,34 @@ void 
CodeGenLLVM::InitFuncState() { analyzer_.reset(new arith::Analyzer()); } -void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { - this->InitFuncState(); +std::tuple CodeGenLLVM::GetLinkage( + const GlobalVar& gvar, const PrimFunc& func) { + if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { + return {global_symbol.value(), llvm::Function::ExternalLinkage}; + } + + std::string symbol_name = [&]() { + std::stringstream ss; + ss << "_internal_"; + ss << gvar->name_hint; + return ss.str(); + }(); + + return {symbol_name, llvm::Function::PrivateLinkage}; +} + +llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func, + bool ret_void) { + if (auto it = functions_.find(gvar.get()); it != functions_.end()) { + return it->second; + } - ICHECK_EQ(f->buffer_map.size(), 0U) + ICHECK_EQ(func->buffer_map.size(), 0U) << "Cannot codegen function with buffer_map, please lower them first"; std::vector param_types; - is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias); - for (Var param : f->params) { + is_restricted_ = func->HasNonzeroAttr(tir::attr::kNoAlias); + for (Var param : func->params) { param_types.push_back(GetLLVMType(param)); if (!is_restricted_ && param.dtype().is_handle()) { alias_var_set_.insert(param.get()); @@ -254,17 +279,26 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { llvm::FunctionType* ftype = llvm::FunctionType::get(ret_void ? 
t_void_ : t_int_, param_types, false); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - function_ = module_->getFunction(MakeStringRef(global_symbol.value())); - if (function_ == nullptr) { - function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - MakeStringRef(global_symbol.value()), module_.get()); - } - function_->setCallingConv(llvm::CallingConv::C); - function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); - SetTargetAttributes(function_); + auto [symbol_name, linkage_type] = GetLinkage(gvar, func); + + auto function = module_->getFunction(MakeStringRef(symbol_name)); + if (function == nullptr) { + function = + llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get()); + } + function->setCallingConv(llvm::CallingConv::C); + function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); + SetTargetAttributes(function); + + functions_[gvar.get()] = function; + + return function; +} + +void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) { + this->InitFuncState(); + + function_ = DeclareFunctionInternal(gvar, f, ret_void); // set var map and align information auto arg_it = function_->arg_begin(); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index ca4c916f84d8..07d96e40e07c 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -63,6 +63,7 @@ #include #include #include +#include #include #include #include @@ -132,11 +133,17 @@ class CodeGenLLVM : public ExprFunctor, */ void SetFastMathFlags(llvm::FastMathFlags fmf); + virtual llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f); + /*! * \brief Compile and add function f to the current module. 
+ * + * \param gvar The GlobalVar which may be used to make internal calls + * to this function from elsewhere in the module. + * + * \param f The function to be added. */ - virtual void AddFunction(const PrimFunc& f); + virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& f); /*! * \brief Add main function as the entry name * \param entry_func_name The name of entry function to be added. @@ -356,7 +363,28 @@ class CodeGenLLVM : public ExprFunctor, virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const; // Get correct address space depending on the backend virtual unsigned GetGlobalAddressSpace() const; - void AddFunctionInternal(const PrimFunc& f, bool ret_void); + + /*! \brief Get the linkage parameters for the function + * + * Returns a tuple whose first element is the name of the function + * and whose second element is the linkage type to be used + * (e.g. llvm::Function::ExternalLinkage or + * llvm::Function::PrivateLinkage) + * + * \param func The PrimFunc whose symbol name and linkage type + * should be returned + * + * \param gvar The GlobalVar to be used when generating the symbol + * name. Only used for internal functions, for which the + * kGlobalSymbol attribute is not defined. + */ + std::tuple GetLinkage(const GlobalVar& gvar, + const PrimFunc& func); + + llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void); + + void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void); + // Create extern call llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, const std::vector& value); @@ -517,6 +545,11 @@ class CodeGenLLVM : public ExprFunctor, std::unordered_map var_map_; // global strings std::unordered_map str_map_; + + // Map from TVM's GlobalVar to the llvm::Function that represents + // that function. 
+ std::unordered_map functions_; + // Whether current function is restricted bool is_restricted_{true}; // The analyzer information @@ -569,18 +602,26 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { template void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) { - std::vector funcs; + std::vector> funcs; for (auto it = begin; it != end; ++it) { - funcs.push_back(pfunc(*it)); + auto [gvar, func] = *it; + auto converted = pfunc(func); + funcs.push_back({gvar, Downcast(converted)}); } - std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) { - std::string name_a = func_a->GetAttr(tvm::attr::kGlobalSymbol).value(); - std::string name_b = func_b->GetAttr(tvm::attr::kGlobalSymbol).value(); + std::sort(funcs.begin(), funcs.end(), [this](const auto& pair_a, const auto& pair_b) { + const auto& [gvar_a, func_a] = pair_a; + std::string name_a = std::get(GetLinkage(gvar_a, func_a)); + + const auto& [gvar_b, func_b] = pair_b; + std::string name_b = std::get(GetLinkage(gvar_b, func_b)); return name_a < name_b; }); - for (auto& f : funcs) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - AddFunction(f); + + for (const auto& [gvar, func] : funcs) { + DeclareFunction(gvar, func); + } + for (const auto& [gvar, func] : funcs) { + AddFunction(gvar, func); } } diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 18f60922910b..a40208513079 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -66,9 +66,13 @@ namespace codegen { // NVPTX code generator. 
class CodeGenNVPTX : public CodeGenLLVM { public: - void AddFunction(const PrimFunc& f) final { + llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final { // add function as void return value - CodeGenLLVM::AddFunctionInternal(f, true); + return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true); + } + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final { + // add function as void return value + CodeGenLLVM::AddFunctionInternal(gvar, f, true); // annotate as kernel function llvm::LLVMContext* ctx = llvm_target_->GetContext(); module_->getOrInsertNamedMetadata("nvvm.annotations") @@ -311,11 +315,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { cg->Init("TVMPTXModule", llvm_target.get(), NullOpt, false, false); - cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { - ICHECK(kv.second->template IsInstance()) - << "Can only lower IR Module with PrimFuncs"; - return Downcast(kv.second); - }); + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index b6a0da84752a..a2b080371422 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -125,7 +125,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { // The unique_ptr owning the module. This becomes empty once JIT has been initialized // (EngineBuilder takes ownership of the module). 
std::unique_ptr module_owning_ptr_; - /* \brief names of the functions declared in this module */ + /* \brief names of the external functions declared in this module */ Array function_names_; }; @@ -295,7 +295,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); std::unique_ptr cg = CodeGenLLVM::Create(llvm_target.get()); - std::vector funcs; std::string entry_func; relay::Runtime runtime = mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); @@ -315,12 +314,16 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - function_names_.push_back(global_symbol.value()); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - entry_func = global_symbol.value(); + bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); + + ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; + + if (global_symbol) { + function_names_.push_back(global_symbol.value()); + if (is_entry_func) { + entry_func = global_symbol.value(); + } } - funcs.push_back(f); } // TODO(@jroesch): follow up on this condition. 
// ICHECK(funcs.size() > 0); @@ -330,7 +333,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { target_c_runtime); cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); } From 920ad5850e867e891379ee9ee4f3132ca80dec9a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Mar 2023 09:00:59 -0500 Subject: [PATCH 2/3] [Codegen][LLVM] Handle callsite for internal functions --- src/target/llvm/codegen_llvm.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f342f559c614..e321dd7c5f96 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1781,9 +1781,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { VLOG(2) << "CreateIntrinsic done"; return x; } + } else if (auto* ptr_gvar = op->op.as()) { + auto gvar = GetRef(ptr_gvar); + auto it = functions_.find(ptr_gvar); + ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\""; + llvm::Function* callee = it->second; + std::vector arg_value; + for (const auto& arg : op->args) { + arg_value.push_back(MakeValue(arg)); + } + return builder_->CreateCall(callee, arg_value); + } else { - ICHECK(op->op.as()); - LOG(FATAL) << "Do not yet support cross function call"; + LOG(FATAL) << "Unsupported operation in CallNode: " << op->op; } } From ed149dea6099f9f6c8ab1429d2491fef46a908af Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sun, 21 May 2023 09:49:27 -0500 Subject: [PATCH 3/3] [UnitTest][LLVM] Added test for LLVM codegen for subroutine --- .../unittest/test_target_codegen_llvm.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py 
b/tests/python/unittest/test_target_codegen_llvm.py index 856de716fcc7..f4326e6fc53d 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -28,7 +28,7 @@ from tvm import te from tvm.contrib import clang, utils from tvm.relay.backend import Runtime -from tvm.script import tir as T +from tvm.script import tir as T, ir as I from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id @@ -1022,5 +1022,32 @@ def func(a: T.handle("float64"), b: T.handle("float64"), n: T.int64): tvm.build(func, target="llvm") +@tvm.testing.requires_llvm +def test_subroutine_call(): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, dtype="float32")): + T.func_attr({"global_symbol": "main"}) + mod.subroutine(A.data) + + @T.prim_func + def subroutine(A_data: T.handle("float32")): + # The calling_conv parameter is to prevent MakePackedAPI + # from changing the call signature of the subroutine. + T.func_attr({"global_symbol": "subroutine", "calling_conv": -1}) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 42.0 + + target = "llvm" + dev = tvm.cpu() + + built = tvm.build(mod) + + arr = tvm.nd.array(np.zeros([1], "float32"), device=dev) + built["main"](arr) + assert arr.numpy()[0] == 42.0 + + if __name__ == "__main__": tvm.testing.main()