Skip to content

Commit b8c8aad

Browse files
authored
[BACKEND] Allow nvptx to pass ll ir to CUDAModule (#404)
1 parent 50c7a01 commit b8c8aad

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/codegen/llvm/codegen_nvptx.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
131131
};
132132

133133
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
134-
CHECK(target.length(
135-
) >= 5 &&
134+
CHECK(target.length() >= 5 &&
136135
target.substr(0, 5) == "nvptx");
137136
llvm::TargetMachine* tm = GetLLVMTargetMachine(
138137
"-mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
@@ -144,16 +143,19 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
144143
cg->AddFunction(f);
145144
}
146145
std::unique_ptr<llvm::Module> module = cg->Finish();
147-
llvm::SmallString<8> data;
148-
llvm::raw_svector_ostream dest(data);
149-
dest.SetUnbuffered();
146+
llvm::SmallString<8> data_ptx, data_ll;
147+
llvm::raw_svector_ostream dest_ptx(data_ptx), dest_ll(data_ll);
148+
dest_ptx.SetUnbuffered();
149+
dest_ll.SetUnbuffered();
150150
llvm::legacy::PassManager pass;
151151
CHECK(tm->addPassesToEmitFile(
152-
pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
152+
pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0)
153153
<< "Cannot emit target CGFT_ObjectFile";
154154
pass.run(*module);
155-
std::string ptx(data.begin(), data.end());
156-
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), "");
155+
module->print(dest_ll, nullptr);
156+
std::string ptx(data_ptx.begin(), data_ptx.end());
157+
std::string ll(data_ll.begin(), data_ll.end());
158+
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll);
157159
}
158160

159161
TVM_REGISTER_API("codegen.build_nvptx")

0 commit comments

Comments
 (0)