@@ -131,8 +131,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
131131};
132132
133133runtime::Module BuildNVPTX (Array<LoweredFunc> funcs, std::string target) {
134- CHECK (target.length (
135- ) >= 5 &&
134+ CHECK (target.length () >= 5 &&
136135 target.substr (0 , 5 ) == " nvptx" );
137136 llvm::TargetMachine* tm = GetLLVMTargetMachine (
138137 " -mtriple=nvptx64-nvidia-cuda -mcpu=sm_20" +
@@ -144,16 +143,19 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
144143 cg->AddFunction (f);
145144 }
146145 std::unique_ptr<llvm::Module> module = cg->Finish ();
147- llvm::SmallString<8 > data;
148- llvm::raw_svector_ostream dest (data);
149- dest.SetUnbuffered ();
146+ llvm::SmallString<8 > data_ptx, data_ll;
147+ llvm::raw_svector_ostream dest_ptx (data_ptx), dest_ll (data_ll);
148+ dest_ptx.SetUnbuffered ();
149+ dest_ll.SetUnbuffered ();
150150 llvm::legacy::PassManager pass;
151151 CHECK (tm->addPassesToEmitFile (
152- pass, dest , llvm::TargetMachine::CGFT_AssemblyFile) == 0 )
152+ pass, dest_ptx , llvm::TargetMachine::CGFT_AssemblyFile) == 0 )
153153 << " Cannot emit target CGFT_ObjectFile" ;
154154 pass.run (*module );
155- std::string ptx (data.begin (), data.end ());
156- return CUDAModuleCreate (ptx, " ptx" , ExtractFuncInfo (funcs), " " );
155+ module ->print (dest_ll, nullptr );
156+ std::string ptx (data_ptx.begin (), data_ptx.end ());
157+ std::string ll (data_ll.begin (), data_ll.end ());
158+ return CUDAModuleCreate (ptx, " ptx" , ExtractFuncInfo (funcs), ll);
157159}
158160
159161TVM_REGISTER_API (" codegen.build_nvptx" )
0 commit comments