Skip to content

Commit 7392432

Browse files
authored
[LLVM] Remove the "ret_void" argument of AddFunction (#15127)
Prior to this commit, the `"ret_void"` argument needed to be explicitly provided to `CodeGenLLVM::AddFunction` and `CodeGenLLVM::DeclareFunction`. If this was inconsistent with the `builtin::ret()` usage within the `PrimFunc`, this could cause the incorrect return type in the generated LLVM-IR, resulting in LLVM IR verification failures. This commit removes the `"ret_void"` argument, instead using the type annotation in `PrimFunc::ret_type`, removing this opportunity for inconsistency. This PR is intended to fix a ROCm regression reported in #14901 (comment).
1 parent bee073b commit 7392432

File tree

4 files changed

+16
-16
lines changed

4 files changed

+16
-16
lines changed

src/target/llvm/codegen_amdgpu.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
8989

9090
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
9191
// add function as void return value
92-
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
92+
CodeGenLLVM::AddFunctionInternal(gvar, f);
9393
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
9494
std::ostringstream attr;
9595
attr << "1," << DetectROCMmaxThreadsPerBlock();

src/target/llvm/codegen_llvm.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,11 @@ void CodeGenLLVM::InitTarget() {
227227
}
228228

229229
llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
230-
return this->DeclareFunctionInternal(gvar, f, false);
230+
return this->DeclareFunctionInternal(gvar, f);
231231
}
232232

233233
void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
234-
this->AddFunctionInternal(gvar, f, false);
234+
this->AddFunctionInternal(gvar, f);
235235
}
236236

237237
void CodeGenLLVM::InitFuncState() {
@@ -258,8 +258,7 @@ std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
258258
return {symbol_name, llvm::Function::PrivateLinkage};
259259
}
260260

261-
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func,
262-
bool ret_void) {
261+
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
263262
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
264263
return it->second;
265264
}
@@ -275,11 +274,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
275274
alias_var_set_.insert(param.get());
276275
}
277276
}
278-
// TODO(tvm-team):
279-
// Update the function type to respect the ret_type field of f.
280-
// Once we allow more flexibility in the PrimFunc.
277+
281278
llvm::FunctionType* ftype =
282-
llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
279+
llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);
283280

284281
auto [symbol_name, linkage_type] = GetLinkage(gvar, func);
285282

@@ -297,10 +294,10 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
297294
return function;
298295
}
299296

300-
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
297+
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
301298
this->InitFuncState();
302299

303-
function_ = DeclareFunctionInternal(gvar, f, ret_void);
300+
function_ = DeclareFunctionInternal(gvar, f);
304301

305302
// set var map and align information
306303
auto arg_it = function_->arg_begin();
@@ -341,7 +338,10 @@ void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f,
341338
#endif
342339

343340
EmitDebugLocation(f->span);
344-
if (ret_void) {
341+
342+
if (IsVoidType(f->ret_type)) {
343+
// All other return types are handled when encountering
344+
// builtin::ret().
345345
builder_->CreateRetVoid();
346346
} else {
347347
builder_->CreateRet(ConstInt32(0));

src/target/llvm/codegen_llvm.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,9 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
381381
std::tuple<std::string, llvm::Function::LinkageTypes> GetLinkage(const GlobalVar& gvar,
382382
const PrimFunc& func);
383383

384-
llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
384+
llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);
385385

386-
void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
386+
void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);
387387

388388
// Create extern call
389389
llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,

src/target/llvm/codegen_nvptx.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM {
6868
public:
6969
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
7070
// add function as void return value
71-
return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true);
71+
return CodeGenLLVM::DeclareFunctionInternal(gvar, f);
7272
}
7373
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
7474
// add function as void return value
75-
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
75+
CodeGenLLVM::AddFunctionInternal(gvar, f);
7676
// annotate as kernel function
7777
llvm::LLVMContext* ctx = llvm_target_->GetContext();
7878
module_->getOrInsertNamedMetadata("nvvm.annotations")

0 commit comments

Comments
 (0)