Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Is there any plan to upgrade to LLVM 14? #871

Closed
wanjunling168 opened this issue Nov 10, 2022 · 3 comments
Closed

Is there any plan to upgrade to LLVM 14? #871

wanjunling168 opened this issue Nov 10, 2022 · 3 comments

Comments

@wanjunling168
Copy link

wanjunling168 commented Nov 10, 2022

BTW, I managed to compile Triton on Windows with LLVM==14.0.6, MSVC 14.29.30133, torch==1.12.1, CUDA==11.6, and cuDNN==8.4.1.50 (all of these libraries were installed via conda)

with the following patches:

diff --git a/CMakeLists.txt b/CMakeLists.txt
index deebd160c..5e72b3b5b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,7 +43,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS  -std=gnu++17")
 ##########
 if("${LLVM_LIBRARY_DIR}" STREQUAL "")
     if(WIN32)
-      find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
+      find_package(LLVM 14 REQUIRED COMPONENTS nvptx amdgpu)

       include_directories(${LLVM_INCLUDE_DIRS})
       separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
diff --git a/include/triton/driver/error.h b/include/triton/driver/error.h
index 6502b7493..65aee8a3c 100755
--- a/include/triton/driver/error.h
+++ b/include/triton/driver/error.h
@@ -4,6 +4,7 @@
 #define _TRITON_DRIVER_ERROR_H_

 #include <exception>
+#include <string>
 #include "triton/driver/dispatch.h"


diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc
index 1057cfef6..a5d485e1a 100644
--- a/lib/codegen/pass.cc
+++ b/lib/codegen/pass.cc
@@ -1,6 +1,7 @@
 #include "triton/codegen/pass.h"

 #include "llvm/IR/Constants.h"
+#include "llvm/Pass.h"  // Add Definition of ModulePass and FunctionPass
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 4bd0baf34..e7b045923 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -332,16 +332,24 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
   builder_->SetInsertPoint(launch_bb);

   //
-  builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0}));
-  builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1}));
-  builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2}));
+  builder_->CreateStore(vals_[launch->get_grid()[0]][{}],
+                        builder_->CreateGEP(get_param_arg_tys[1], grid, {_0, _0}));
+  builder_->CreateStore(vals_[launch->get_grid()[1]][{}],
+                        builder_->CreateGEP(get_param_arg_tys[1], grid, {_0, _1}));
+  builder_->CreateStore(vals_[launch->get_grid()[2]][{}],
+                        builder_->CreateGEP(get_param_arg_tys[1], grid, {_0, _2}));
   Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]);
-  builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0}));
-  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1}));
-  builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2}));
+  builder_->CreateStore(num_warps, builder_->CreateGEP(get_param_arg_tys[2], block, {_0, _0}));
+  builder_->CreateStore(builder_->getInt32(1),
+                        builder_->CreateGEP(get_param_arg_tys[2], block, {_0, _1}));
+  builder_->CreateStore(builder_->getInt32(1),
+                        builder_->CreateGEP(get_param_arg_tys[2], block, {_0, _2}));
   Function* called_fn = fns_[fn];
   Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]);
-  Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)});
+  Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee,
+                       builder_->CreateLoad(get_param_arg_tys[1], grid),
+                       builder_->CreateLoad(get_param_arg_tys[2], block),
+                       builder_->getInt32(0)});
   // forwrd-declare cudaLaunchDeviceV2
   std::vector<Type*> launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()};
   FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false);
@@ -364,7 +372,7 @@ void generator::visit_launch_inst(ir::launch_inst *launch) {
     unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8;
     off = (off + size - 1) / size * size;
     // get pointer to current arg
-    Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off));
+    Value* curr_arg_ptr = builder_->CreateGEP(curr_arg_ty, arg_ptr, builder_->getInt32(off));
     curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space));
     // store arg
     builder_->CreateStore(curr_arg, curr_arg_ptr);
@@ -622,17 +630,17 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp16x4_to_fp8x4(Value *in0
   "}", "=r,r,r", false);
   Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2));
   Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2));
-  packed_in0 = insert_elt(packed_in0, in0, (int)0);
-  packed_in0 = insert_elt(packed_in0, in1, (int)1);
-  packed_in1 = insert_elt(packed_in1, in2, (int)0);
-  packed_in1 = insert_elt(packed_in1, in3, (int)1);
+  packed_in0 = insert_elt(packed_in0, in0, (uint64_t)0);
+  packed_in0 = insert_elt(packed_in0, in1, (uint64_t)1);
+  packed_in1 = insert_elt(packed_in1, in2, (uint64_t)0);
+  packed_in1 = insert_elt(packed_in1, in3, (uint64_t)1);
   Value *in_arg0 = bit_cast(packed_in0, i32_ty);
   Value *in_arg1 = bit_cast(packed_in1, i32_ty);
   Value *ret = call(ptx, {in_arg0, in_arg1});
-  Value *ret0 = extract_elt(ret, (int)0);
-  Value *ret1 = extract_elt(ret, (int)1);
-  Value *ret2 = extract_elt(ret, (int)2);
-  Value *ret3 = extract_elt(ret, (int)3);
+  Value *ret0 = extract_elt(ret, (uint64_t)0);
+  Value *ret1 = extract_elt(ret, (uint64_t)1);
+  Value *ret2 = extract_elt(ret, (uint64_t)2);
+  Value *ret3 = extract_elt(ret, (uint64_t)3);
   return std::make_tuple(ret0, ret1, ret2, ret3);
 }

@@ -726,17 +734,17 @@ std::tuple<Value*, Value*, Value*, Value*> generator::bf16x4_to_fp8x4(Value *in0
   "}", "=r,r,r", false);
   Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2));
   Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2));
-  packed_in0 = insert_elt(packed_in0, in0, (int)0);
-  packed_in0 = insert_elt(packed_in0, in1, (int)1);
-  packed_in1 = insert_elt(packed_in1, in2, (int)0);
-  packed_in1 = insert_elt(packed_in1, in3, (int)1);
+  packed_in0 = insert_elt(packed_in0, in0, (uint64_t)0);
+  packed_in0 = insert_elt(packed_in0, in1, (uint64_t)1);
+  packed_in1 = insert_elt(packed_in1, in2, (uint64_t)0);
+  packed_in1 = insert_elt(packed_in1, in3, (uint64_t)1);
   Value *in_arg0 = bit_cast(packed_in0, i32_ty);
   Value *in_arg1 = bit_cast(packed_in1, i32_ty);
   Value *ret = call(ptx, {in_arg0, in_arg1});
-  Value *ret0 = extract_elt(ret, (int)0);
-  Value *ret1 = extract_elt(ret, (int)1);
-  Value *ret2 = extract_elt(ret, (int)2);
-  Value *ret3 = extract_elt(ret, (int)3);
+  Value *ret0 = extract_elt(ret, (uint64_t)0);
+  Value *ret1 = extract_elt(ret, (uint64_t)1);
+  Value *ret2 = extract_elt(ret, (uint64_t)2);
+  Value *ret3 = extract_elt(ret, (uint64_t)3);
   return std::make_tuple(ret0, ret1, ret2, ret3);
 }

@@ -888,8 +896,8 @@ std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> gener
   "}", "=r,=r,=r,=r,r,r,r", false);

   Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2));
-  packed_in = insert_elt(packed_in, in0, (int)0);
-  packed_in = insert_elt(packed_in, in0, (int)1);
+  packed_in = insert_elt(packed_in, in0, (uint64_t)0);
+  packed_in = insert_elt(packed_in, in0, (uint64_t)1);
   Value *in = bit_cast(packed_in, i32_ty);

   Value *ret = call(ptx, {in, scale_x512, shift});
@@ -986,13 +994,13 @@ std::tuple<Value*, Value*, Value*, Value*> generator::int32_to_float16x4(Value *
 std::tuple<Value*, Value*> generator::prepare_scale_shift(Value *scale, Value *shift){
   Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty));
   Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2));
-  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0);
-  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1);
+  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (uint64_t)0);
+  p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (uint64_t)1);
   p_scale_x512 = bit_cast(p_scale_x512, i32_ty);

   Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2));
-  p_shift = insert_elt(p_shift, shift, (int)0);
-  p_shift = insert_elt(p_shift, shift, (int)1);
+  p_shift = insert_elt(p_shift, shift, (uint64_t)0);
+  p_shift = insert_elt(p_shift, shift, (uint64_t)1);
   p_shift = bit_cast(p_shift, i32_ty);

   return std::make_tuple(p_scale_x512, p_shift);
@@ -1413,7 +1421,7 @@ void generator::visit_store_inst(ir::store_inst * x){
     // ---
     // finally call inline ASM
     // ---
-    InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
+    InlineAsm *i_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
     std::vector<Value*> args = {pred, ptr};
     for(unsigned int ii = 0; ii < n_words; ii++){
       size_t n_subw = width / nbits;
@@ -1429,7 +1437,7 @@ void generator::visit_store_inst(ir::store_inst * x){
     }
     if (has_l2_evict_policy)
       args.push_back(policies_.at(x->get_eviction_policy()));
-    call(_asm, args);
+    call(i_asm, args);
   }
 }
 void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
@@ -3660,7 +3668,7 @@ void generator::visit_function(ir::function* fn) {
     if(attr.is_llvm_attr()){
       llvm::Attribute llattr = cvt(attr);
       if(llattr.getKindAsEnum() != llvm::Attribute::None)
-        ret->addAttribute(id, cvt(attr));
+        ret->addAttributeAtIndex(id, cvt(attr));
     }
   }
   // set metadata
diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc
index a73e6541d..c58cd58c5 100644
--- a/lib/driver/llvm.cc
+++ b/lib/driver/llvm.cc
@@ -40,7 +40,8 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Support/TargetRegistry.h"
+// #include "llvm/Support/TargetRegistry.h"  // that's a LLVM-11 PATH, following is LLVM-14's
+#include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
@Jokeren
Copy link
Contributor

Jokeren commented Nov 11, 2022

Triton-MLIR uses LLVM 14. The master branch is now barely maintained.

@wanjunling168
Copy link
Author

OK, thanks for your reply @Jokeren

@jyizheng
Copy link

Which commit is this patch for?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants