From 6385efb2f85a9b4c7690089825ce895d159ac9c1 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 30 May 2024 20:15:54 +0000 Subject: [PATCH 1/3] [GEN] Add sub_group_reduce operator Signed-off-by: Whitney Tsang --- bin/CMakeLists.txt | 1 + bin/triton-opt.cpp | 2 + test/TritonGEN/tritongen-to-llvm.mlir | 141 ++++++++++++++++++ test/TritonGEN/tritongen.mlir | 33 ++++ .../Dialect/TritonGEN/IR/TritonGENAttrDefs.td | 20 +++ .../Dialect/TritonGEN/IR/TritonGENOps.td | 24 +++ .../intel/include/TritonGENToLLVM/Passes.td | 2 +- .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 9 ++ .../intel/lib/TritonGENToLLVM/CMakeLists.txt | 2 + .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 62 +++++++- 10 files changed, 292 insertions(+), 4 deletions(-) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 99a791fdad..91c5421a12 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -25,6 +25,7 @@ target_link_libraries(triton-opt PRIVATE MLIROptLib MLIRPass MLIRTransforms + MLIRSPIRVDialect ) mlir_check_all_link_libraries(triton-opt) diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 2d2570771a..541bf1a6f5 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -1,10 +1,12 @@ #include "./RegisterTritonDialects.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; registerTritonDialects(registry); + registry.insert(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Triton (GPU) optimizer driver\n", registry)); diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 366219f314..7a3125866e 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -101,6 +101,147 @@ llvm.func @triton_gen.named_barrier(%barrier_id : i32, %thread_group_count : i32 // ----- +llvm.func @triton_gen.sub_group_reduce() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + llvm.return +} + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> +} { + llvm.func @triton_gen.sub_group_reduce() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + llvm.return + } +} + +// ----- + // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]} // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]} // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]} diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 4bfbc1c549..cf388df2c3 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -66,6 +66,39 @@ llvm.func @triton_gen.named_barrier_wait(%barrier_id : i32) { llvm.return } +llvm.func @triton_gen.sub_group_reduce() { + // CHECK-LABEL: triton_gen.sub_group_reduce + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + // CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + // CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + // CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + llvm.return +} + llvm.func @triton_gen.sub_group_shuffle() { // CHECK-LABEL: triton_gen.sub_group_shuffle %0 = llvm.mlir.constant(0 : i32) : i32 diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index b609719818..46c2098653 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -11,6 +11,26 @@ include "mlir/IR/EnumAttr.td" +/// Enum attribute of the different reduce kinds. +def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind", + [ + I32EnumAttrCase<"SUM", 0, "sum">, + I32EnumAttrCase<"PROD", 1, "prod">, + I32EnumAttrCase<"UMIN", 2, "umin">, + I32EnumAttrCase<"UMAX", 3, "umax">, + I32EnumAttrCase<"IMIN", 4, "imin">, + I32EnumAttrCase<"IMAX", 5, "imax">, + I32EnumAttrCase<"OR", 6, "or">, + I32EnumAttrCase<"XOR", 7, "xor">, + I32EnumAttrCase<"AND", 8, "and">, + I32EnumAttrCase<"FSUM", 9, "fsum">, + I32EnumAttrCase<"FPROD", 10, "fprod">, + I32EnumAttrCase<"FMIN", 11, "fmin">, + I32EnumAttrCase<"FMAX", 12, "fmax"> + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + /// Enum attribute of the different shuffle kinds. def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind", [ diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 98f35a8c84..1cd01a0314 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -237,6 +237,30 @@ def TritonGEN_NamedBarrierWaitOp : TritonGEN_Op<"named_barrier_wait">, def IntegerOrFloatType : AnyTypeOf<[AnyInteger, AnyFloat]>; +def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [ + TypesMatchWith<"result and value have the same type", + "res", "value", "$_self">]>, + Results<(outs IntegerOrFloatType:$res)>, + Arguments<(ins IntegerOrFloatType:$value, + TritonGEN_ReduceKindAttr:$kind, + I32Attr:$size)> { + let summary = "Subgroup reduce"; + + string baseDescription = [{ + The `gen.sub_group_reduce` operation is invoked by all work items in a + subgroup, each of them providing a $value. The $size argument is used to + form groups of $size consecutive work items called clusters. Each cluster + performs the reduction operation identified by $kind. The result of the + cluster reduction is propagated to the work items belonging to that cluster. + }]; + + let assemblyFormat = [{ + $kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value) `->` type($res) + }]; + + let hasVerifier = 1; +} + def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [ TypesMatchWith<"result and value have the same type", "res", "value", "$_self">]>, diff --git a/third_party/intel/include/TritonGENToLLVM/Passes.td b/third_party/intel/include/TritonGENToLLVM/Passes.td index 84a0a14186..10e76247fc 100644 --- a/third_party/intel/include/TritonGENToLLVM/Passes.td +++ b/third_party/intel/include/TritonGENToLLVM/Passes.td @@ -16,7 +16,7 @@ def ConvertTritonGENToLLVM : Pass<"convert-tritongen-to-llvm"> { let description = [{ This pass converts the TritonGEN dialect operations to LLVM dialect operations. }]; - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::spirv::SPIRVDialect"]; } #endif // TRITONGEN_TO_LLVM_CONVERSION_PASSES diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index cae45ade80..96e8358713 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -70,6 +70,15 @@ template static LogicalResult verifyInput(Op op) { return success(); } +//===----------------------------------------------------------------------===// +// gen.sub_group_reduce +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::SubGroupReduceOp::verify() { + // TODO: Add verification for SubGroupReduceOp. + return success(); +} + //===----------------------------------------------------------------------===// // gen.matrix.dpas //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt index 7d82d3f36e..f24ba8efcf 100644 --- a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt @@ -12,4 +12,6 @@ add_triton_library(TritonGENToLLVM LINK_LIBS PUBLIC GenISAIntrinsics + MLIRLLVMDialect + MLIRSPIRVDialect ) diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 8213300fa1..91fea65b54 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -19,6 +19,8 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -95,6 +97,13 @@ static std::string getTypeMangling(Type ty) { }); } +/// Get the subgroup size from the target or return a default. +static int getSubgroupSize(Operation *op) { + return spirv::lookupTargetEnvOrDefault(op) + .getResourceLimits() + .getSubgroupSize(); +} + static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter, Value value, Value mask, TritonGEN::ShflKind kind) { @@ -903,6 +912,52 @@ struct TritonGENNamedBarrierWaitLowering } }; +struct TritonSubGroupReduceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + TritonGEN::SubGroupReduceOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TritonGEN::SubGroupReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value val = op.getValue(); + Type val_ty = val.getType(); + llvm::LLVMContext llvmContext; + LLVM::TypeToLLVMIRTranslator typeTranslator(llvmContext); + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + auto kind = rewriter.create( + loc, i8_ty, static_cast(op.getKind())); + + std::string funcName; + SmallVector argTypes; + SmallVector args; + if (getSubgroupSize(op) == op.getSize()) { + funcName = llvm::GenISAIntrinsic::getName( + llvm::GenISAIntrinsic::GenISA_WaveAll, + {typeTranslator.translateType(val_ty)}); + argTypes = {val_ty, i8_ty, i32_ty}; + args = {val, kind, i32_val(0)}; + } else { + funcName = llvm::GenISAIntrinsic::getName( + llvm::GenISAIntrinsic::GenISA_WaveClustered, + {typeTranslator.translateType(val_ty)}); + argTypes = {val_ty, i8_ty, i32_ty, i32_ty}; + auto size = rewriter.create( + loc, i32_ty, static_cast(op.getSize())); + args = {val, kind, size, i32_val(0)}; + } + + LLVM::LLVMFuncOp funcOp = + LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, val_ty); + funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + + rewriter.replaceOp(op, rewriter.create(loc, funcOp, args)); + return success(); + } +}; + struct TritonSubGroupShuffleLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -1066,9 +1121,10 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns( TritonGENSubgroupIdLowering, TritonGENBarrierLowering, TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering, TritonGENNamedBarrierSignalLowering, TritonGENNamedBarrierWaitLowering, - TritonSubGroupShuffleLowering, TritonMatrixDPASLowering, - TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering, - TritonMatrix2DBlockPrefetchLowering>(converter); + TritonSubGroupReduceLowering, TritonSubGroupShuffleLowering, + TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering, + TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering>( + converter); } void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry ®istry) { From 343a89bd89c6ae44baa557faca5706391f649f04 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 31 May 2024 04:19:55 +0000 Subject: [PATCH 2/3] Require spirv.target_env Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-to-llvm.mlir | 146 +++++++++--------- .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 8 +- 2 files changed, 79 insertions(+), 75 deletions(-) diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 7a3125866e..d5ea5d2064 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -101,77 +101,81 @@ llvm.func @triton_gen.named_barrier(%barrier_id : i32, %thread_group_count : i32 // ----- -llvm.func @triton_gen.sub_group_reduce() { - %0 = llvm.mlir.constant(0 : i32) : i32 - // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 - %10 = llvm.mlir.constant(0.0 : f32) : f32 - // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 - // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 - llvm.return +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> +} { + llvm.func @triton_gen.sub_group_reduce() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + llvm.return + } } // ----- diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 91fea65b54..80c072f4e6 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -97,11 +97,11 @@ static std::string getTypeMangling(Type ty) { }); } -/// Get the subgroup size from the target or return a default. +/// Get the subgroup size from the target. static int getSubgroupSize(Operation *op) { - return spirv::lookupTargetEnvOrDefault(op) - .getResourceLimits() - .getSubgroupSize(); + spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op); + assert(attr && "Expecting valid target env attribute"); + return attr.getResourceLimits().getSubgroupSize(); } static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter, From fb5633084ff37db4384e8e439fb326be2b652ee4 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 31 May 2024 14:01:53 +0000 Subject: [PATCH 3/3] address review comments (part 1) Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-to-llvm.mlir | 52 +++++++++---------- test/TritonGEN/tritongen.mlir | 52 +++++++++---------- .../Dialect/TritonGEN/IR/TritonGENOps.td | 11 ++-- .../intel/include/TritonGENToLLVM/Passes.td | 2 +- .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 3 +- 5 files changed, 59 insertions(+), 61 deletions(-) diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index d5ea5d2064..bd3763b98d 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -111,69 +111,69 @@ module attributes { // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 - %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 %10 = llvm.mlir.constant(0.0 : f32) : f32 // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 - %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 llvm.return } } @@ -189,57 +189,57 @@ module attributes { // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 - %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 %10 = llvm.mlir.constant(0.0 : f32) : f32 // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 - %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 - %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 - %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 - %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 llvm.return } } diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index cf388df2c3..b1baa74d25 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -69,33 +69,33 @@ llvm.func @triton_gen.named_barrier_wait(%barrier_id : i32) { llvm.func @triton_gen.sub_group_reduce() { // CHECK-LABEL: triton_gen.sub_group_reduce %0 = llvm.mlir.constant(0 : i32) : i32 - // CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 - %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 - %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 - %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 - %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 - %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 - %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 - %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 - %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 -> i32 - // CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 - %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 -> i32 + // CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 %10 = llvm.mlir.constant(0.0 : f32) : f32 - // CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 - %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 -> f32 - // CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 - %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 -> f32 - // CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 - %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 -> f32 - // CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 - %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 -> f32 + // CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 llvm.return } diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 1cd01a0314..b9f913b0fe 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -238,24 +238,23 @@ def TritonGEN_NamedBarrierWaitOp : TritonGEN_Op<"named_barrier_wait">, def IntegerOrFloatType : AnyTypeOf<[AnyInteger, AnyFloat]>; def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [ - TypesMatchWith<"result and value have the same type", - "res", "value", "$_self">]>, + AllTypesMatch<["res", "value"]>]>, Results<(outs IntegerOrFloatType:$res)>, Arguments<(ins IntegerOrFloatType:$value, TritonGEN_ReduceKindAttr:$kind, I32Attr:$size)> { let summary = "Subgroup reduce"; - string baseDescription = [{ - The `gen.sub_group_reduce` operation is invoked by all work items in a - subgroup, each of them providing a $value. The $size argument is used to + let description = [{ + The `triton_gen.sub_group_reduce` operation is invoked by all work items in + a subgroup, each of them providing a $value. The $size argument is used to form groups of $size consecutive work items called clusters. Each cluster performs the reduction operation identified by $kind. The result of the cluster reduction is propagated to the work items belonging to that cluster. }]; let assemblyFormat = [{ - $kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value) `->` type($res) + $kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value) }]; let hasVerifier = 1; diff --git a/third_party/intel/include/TritonGENToLLVM/Passes.td b/third_party/intel/include/TritonGENToLLVM/Passes.td index 10e76247fc..84a0a14186 100644 --- a/third_party/intel/include/TritonGENToLLVM/Passes.td +++ b/third_party/intel/include/TritonGENToLLVM/Passes.td @@ -16,7 +16,7 @@ def ConvertTritonGENToLLVM : Pass<"convert-tritongen-to-llvm"> { let description = [{ This pass converts the TritonGEN dialect operations to LLVM dialect operations. }]; - let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::spirv::SPIRVDialect"]; + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; } #endif // TRITONGEN_TO_LLVM_CONVERSION_PASSES diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 80c072f4e6..0c46488a61 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -925,8 +925,7 @@ struct TritonSubGroupReduceLowering Type val_ty = val.getType(); llvm::LLVMContext llvmContext; LLVM::TypeToLLVMIRTranslator typeTranslator(llvmContext); - auto moduleOp = - rewriter.getBlock()->getParent()->getParentOfType(); + auto moduleOp = op->getParentOfType(); auto kind = rewriter.create( loc, i8_ty, static_cast(op.getKind()));