Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class CodeGenAMDGPU : public CodeGenLLVM {
CodeGenAMDGPU() = default;
virtual ~CodeGenAMDGPU() = default;

void AddFunction(const PrimFunc& f) final {
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
std::ostringstream attr;
attr << "1," << DetectROCMmaxThreadsPerBlock();
Expand Down Expand Up @@ -262,11 +262,7 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) {

cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false);

cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
return Downcast<PrimFunc>(kv.second);
});
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());

llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
Expand Down
13 changes: 6 additions & 7 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,17 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) {
#endif
}

void CodeGenCPU::AddFunction(const PrimFunc& f) {
void CodeGenCPU::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
#if TVM_LLVM_VERSION >= 50
di_subprogram_ = CreateDebugFunction(f);
#endif
EmitDebugLocation(f->span);
CodeGenLLVM::AddFunction(f);
CodeGenLLVM::AddFunction(gvar, f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
if (auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
}
AddDebugInformation(f, function_);
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class CodeGenCPU : public CodeGenLLVM {
void Init(const std::string& module_name, LLVMTarget* llvm_target,
Optional<String> system_lib_prefix, bool dynamic_lookup,
bool target_c_runtime) override;
void AddFunction(const PrimFunc& f) override;
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override;
void VisitStmt_(const AssertStmtNode* op) override;
Expand Down
4 changes: 1 addition & 3 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {

auto cg = std::make_unique<CodeGenHexagon>();

std::vector<PrimFunc> funcs;
std::string entry_func;

for (auto kv : mod->functions) {
Expand All @@ -562,11 +561,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
ICHECK(global_symbol.defined());
entry_func = global_symbol.value();
}
funcs.emplace_back(f);
}

cg->Init("TVMHexagonModule", llvm_target.get(), NullOpt, false, false);
cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
}
Expand Down
80 changes: 62 additions & 18 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,13 @@ void CodeGenLLVM::InitTarget() {
#endif // TVM_LLVM_VERSION >= 60
}

void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
return this->DeclareFunctionInternal(gvar, f, false);
}

void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
this->AddFunctionInternal(gvar, f, false);
}

void CodeGenLLVM::InitFuncState() {
var_map_.clear();
Expand All @@ -234,15 +240,34 @@ void CodeGenLLVM::InitFuncState() {
analyzer_.reset(new arith::Analyzer());
}

void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
this->InitFuncState();
std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
const GlobalVar& gvar, const PrimFunc& func) {
if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
return {global_symbol.value(), llvm::Function::ExternalLinkage};
}

std::string symbol_name = [&]() {
std::stringstream ss;
ss << "_internal_";
ss << gvar->name_hint;
return ss.str();
}();

return {symbol_name, llvm::Function::PrivateLinkage};
}

llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func,
bool ret_void) {
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
return it->second;
}

ICHECK_EQ(f->buffer_map.size(), 0U)
ICHECK_EQ(func->buffer_map.size(), 0U)
<< "Cannot codegen function with buffer_map, please lower them first";

std::vector<llvm::Type*> param_types;
is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
for (Var param : f->params) {
is_restricted_ = func->HasNonzeroAttr(tir::attr::kNoAlias);
for (Var param : func->params) {
param_types.push_back(GetLLVMType(param));
if (!is_restricted_ && param.dtype().is_handle()) {
alias_var_set_.insert(param.get());
Expand All @@ -254,17 +279,26 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
llvm::FunctionType* ftype =
llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);

auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
function_ = module_->getFunction(MakeStringRef(global_symbol.value()));
if (function_ == nullptr) {
function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
MakeStringRef(global_symbol.value()), module_.get());
auto [symbol_name, linkage_type] = GetLinkage(gvar, func);

auto function = module_->getFunction(MakeStringRef(symbol_name));
if (function == nullptr) {
function =
llvm::Function::Create(ftype, linkage_type, MakeStringRef(symbol_name), module_.get());
}
function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
SetTargetAttributes(function_);
function->setCallingConv(llvm::CallingConv::C);
function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
SetTargetAttributes(function);

functions_[gvar.get()] = function;

return function;
}

void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
this->InitFuncState();

function_ = DeclareFunctionInternal(gvar, f, ret_void);

// set var map and align information
auto arg_it = function_->arg_begin();
Expand Down Expand Up @@ -1747,9 +1781,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
VLOG(2) << "CreateIntrinsic done";
return x;
}
} else if (auto* ptr_gvar = op->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(ptr_gvar);
auto it = functions_.find(ptr_gvar);
ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\"";
llvm::Function* callee = it->second;
std::vector<llvm::Value*> arg_value;
for (const auto& arg : op->args) {
arg_value.push_back(MakeValue(arg));
}
return builder_->CreateCall(callee, arg_value);

} else {
ICHECK(op->op.as<GlobalVarNode>());
LOG(FATAL) << "Do not yet support cross function call";
LOG(FATAL) << "Unsupported operation in CallNode: " << op->op;
}
}

Expand Down
61 changes: 51 additions & 10 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -132,11 +133,17 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
void SetFastMathFlags(llvm::FastMathFlags fmf);

virtual llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f);

/*!
* \brief Compile and add function f to the current module.
*
* \param gvar The GlobalVar which may be used to may internal calls
* to this function from elsewhere in the module.
*
* \param f The function to be added.
*/
virtual void AddFunction(const PrimFunc& f);
virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
/*!
* \brief Add main function as the entry name
* \param entry_func_name The name of entry function to be added.
Expand Down Expand Up @@ -356,7 +363,28 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
// Get correct address space depending on the backend
virtual unsigned GetGlobalAddressSpace() const;
void AddFunctionInternal(const PrimFunc& f, bool ret_void);

/*! \brief Get the linkage parameters for the function
*
* Returns a tuple whose first element is the name of the function
* and whose second element is the linkage type to be used
* (e.g. llvm::Function::ExternalLinkage or
* llvm::Function::PrivateLinkage)
*
* \param func The PrimFunc whose symbol name and linkage type
* should be returned
*
* \param gvar The GlobalVar to be used when generating the symbol
* name. Only used for internal functions, for which the
* kGlobalSymbol attribute is not defined.
*/
std::tuple<std::string, llvm::Function::LinkageTypes> GetLinkage(const GlobalVar& gvar,
const PrimFunc& func);

llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);

void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);

// Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,
const std::vector<llvm::Value*>& value);
Expand Down Expand Up @@ -517,6 +545,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;

// Map from TVM's GlobalVar to the llvm::Function that represents
// that function.
std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;

// Whether current function is restricted
bool is_restricted_{true};
// The analyzer information
Expand Down Expand Up @@ -569,18 +602,26 @@ inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) {

template <typename IterType, typename ConvType>
void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfunc) {
std::vector<PrimFunc> funcs;
std::vector<std::tuple<GlobalVar, PrimFunc>> funcs;
for (auto it = begin; it != end; ++it) {
funcs.push_back(pfunc(*it));
auto [gvar, func] = *it;
auto converted = pfunc(func);
funcs.push_back({gvar, Downcast<PrimFunc>(converted)});
}
std::sort(funcs.begin(), funcs.end(), [](PrimFunc func_a, PrimFunc func_b) {
std::string name_a = func_a->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
std::string name_b = func_b->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
std::sort(funcs.begin(), funcs.end(), [this](const auto& pair_a, const auto& pair_b) {
const auto& [gvar_a, func_a] = pair_a;
std::string name_a = std::get<std::string>(GetLinkage(gvar_a, func_a));

const auto& [gvar_b, func_b] = pair_b;
std::string name_b = std::get<std::string>(GetLinkage(gvar_b, func_b));
return name_a < name_b;
});
for (auto& f : funcs) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
AddFunction(f);

for (const auto& [gvar, func] : funcs) {
DeclareFunction(gvar, func);
}
for (const auto& [gvar, func] : funcs) {
AddFunction(gvar, func);
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,13 @@ namespace codegen {
// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
public:
void AddFunction(const PrimFunc& f) final {
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true);
}
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
// annotate as kernel function
llvm::LLVMContext* ctx = llvm_target_->GetContext();
module_->getOrInsertNamedMetadata("nvvm.annotations")
Expand Down Expand Up @@ -311,11 +315,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) {

cg->Init("TVMPTXModule", llvm_target.get(), NullOpt, false, false);

cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) {
ICHECK(kv.second->template IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
return Downcast<PrimFunc>(kv.second);
});
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());

llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path");
Expand Down
19 changes: 11 additions & 8 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
// (EngineBuilder takes ownership of the module).
std::unique_ptr<llvm::Module> module_owning_ptr_;
/* \brief names of the functions declared in this module */
/* \brief names of the external functions declared in this module */
Array<String> function_names_;
};

Expand Down Expand Up @@ -295,7 +295,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(llvm_target.get());

std::vector<PrimFunc> funcs;
std::string entry_func;
relay::Runtime runtime =
mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp"));
Expand All @@ -315,12 +314,16 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
}
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
function_names_.push_back(global_symbol.value());
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
entry_func = global_symbol.value();
bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc);

ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally.";

if (global_symbol) {
function_names_.push_back(global_symbol.value());
if (is_entry_func) {
entry_func = global_symbol.value();
}
}
funcs.push_back(f);
}
// TODO(@jroesch): follow up on this condition.
// ICHECK(funcs.size() > 0);
Expand All @@ -330,7 +333,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
target_c_runtime);
cg->SetFastMathFlags(llvm_target->GetFastMathFlags());

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
}
Expand Down
Loading