diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 99a791fdad..91c5421a12 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -25,6 +25,7 @@ target_link_libraries(triton-opt PRIVATE MLIROptLib MLIRPass MLIRTransforms + MLIRSPIRVDialect ) mlir_check_all_link_libraries(triton-opt) diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp index 2d2570771a..541bf1a6f5 100644 --- a/bin/triton-opt.cpp +++ b/bin/triton-opt.cpp @@ -1,10 +1,12 @@ #include "./RegisterTritonDialects.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; registerTritonDialects(registry); + registry.insert(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Triton (GPU) optimizer driver\n", registry)); diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 366219f314..bd3763b98d 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -101,6 +101,151 @@ llvm.func @triton_gen.named_barrier(%barrier_id : i32, %thread_group_count : i32 // ----- +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> +} { + llvm.func @triton_gen.sub_group_reduce() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 + // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 + llvm.return + } +} + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits> +} { + llvm.func @triton_gen.sub_group_reduce() { + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + // CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8 + // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 + llvm.return + } +} + +// ----- + // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]} // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]} // CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]} diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir index 4bfbc1c549..b1baa74d25 100644 --- a/test/TritonGEN/tritongen.mlir +++ b/test/TritonGEN/tritongen.mlir @@ -66,6 +66,39 @@ llvm.func @triton_gen.named_barrier_wait(%barrier_id : i32) { llvm.return } +llvm.func @triton_gen.sub_group_reduce() { + // CHECK-LABEL: triton_gen.sub_group_reduce + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + %1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + %2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + %3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + %4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + %5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + %6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32 + %7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + %8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32 + // CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32 + %9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32 + %10 = llvm.mlir.constant(0.0 : f32) : f32 + // CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + %11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + %12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + %13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32 + // CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 + %14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32 + llvm.return +} + llvm.func @triton_gen.sub_group_shuffle() { // CHECK-LABEL: triton_gen.sub_group_shuffle %0 = llvm.mlir.constant(0 : i32) : i32 diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index b609719818..46c2098653 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -11,6 +11,26 @@ include "mlir/IR/EnumAttr.td" +/// Enum attribute of the different reduce kinds. +def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind", + [ + I32EnumAttrCase<"SUM", 0, "sum">, + I32EnumAttrCase<"PROD", 1, "prod">, + I32EnumAttrCase<"UMIN", 2, "umin">, + I32EnumAttrCase<"UMAX", 3, "umax">, + I32EnumAttrCase<"IMIN", 4, "imin">, + I32EnumAttrCase<"IMAX", 5, "imax">, + I32EnumAttrCase<"OR", 6, "or">, + I32EnumAttrCase<"XOR", 7, "xor">, + I32EnumAttrCase<"AND", 8, "and">, + I32EnumAttrCase<"FSUM", 9, "fsum">, + I32EnumAttrCase<"FPROD", 10, "fprod">, + I32EnumAttrCase<"FMIN", 11, "fmin">, + I32EnumAttrCase<"FMAX", 12, "fmax"> + ]> { + let cppNamespace = "::mlir::triton::TritonGEN"; +} + /// Enum attribute of the different shuffle kinds. def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind", [ diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td index 98f35a8c84..b9f913b0fe 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td @@ -237,6 +237,29 @@ def TritonGEN_NamedBarrierWaitOp : TritonGEN_Op<"named_barrier_wait">, def IntegerOrFloatType : AnyTypeOf<[AnyInteger, AnyFloat]>; +def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [ + AllTypesMatch<["res", "value"]>]>, + Results<(outs IntegerOrFloatType:$res)>, + Arguments<(ins IntegerOrFloatType:$value, + TritonGEN_ReduceKindAttr:$kind, + I32Attr:$size)> { + let summary = "Subgroup reduce"; + + let description = [{ + The `triton_gen.sub_group_reduce` operation is invoked by all work items in + a subgroup, each of them providing a $value. The $size argument is used to + form groups of $size consecutive work items called clusters. Each cluster + performs the reduction operation identified by $kind. The result of the + cluster reduction is propagated to the work items belonging to that cluster. + }]; + + let assemblyFormat = [{ + $kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value) + }]; + + let hasVerifier = 1; +} + def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [ TypesMatchWith<"result and value have the same type", "res", "value", "$_self">]>, diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index cae45ade80..96e8358713 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -70,6 +70,15 @@ template static LogicalResult verifyInput(Op op) { return success(); } +//===----------------------------------------------------------------------===// +// gen.sub_group_reduce +//===----------------------------------------------------------------------===// + +LogicalResult TritonGEN::SubGroupReduceOp::verify() { + // TODO: Add verification for SubGroupReduceOp. + return success(); +} + //===----------------------------------------------------------------------===// // gen.matrix.dpas //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt index 7d82d3f36e..f24ba8efcf 100644 --- a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt @@ -12,4 +12,6 @@ add_triton_library(TritonGENToLLVM LINK_LIBS PUBLIC GenISAIntrinsics + MLIRLLVMDialect + MLIRSPIRVDialect ) diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 8213300fa1..0c46488a61 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -19,6 +19,8 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -95,6 +97,13 @@ static std::string getTypeMangling(Type ty) { }); } +/// Get the subgroup size from the target. +static int getSubgroupSize(Operation *op) { + spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op); + assert(attr && "Expecting valid target env attribute"); + return attr.getResourceLimits().getSubgroupSize(); +} + static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter, Value value, Value mask, TritonGEN::ShflKind kind) { @@ -903,6 +912,51 @@ struct TritonGENNamedBarrierWaitLowering } }; +struct TritonSubGroupReduceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + TritonGEN::SubGroupReduceOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TritonGEN::SubGroupReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value val = op.getValue(); + Type val_ty = val.getType(); + llvm::LLVMContext llvmContext; + LLVM::TypeToLLVMIRTranslator typeTranslator(llvmContext); + auto moduleOp = op->getParentOfType(); + auto kind = rewriter.create( + loc, i8_ty, static_cast(op.getKind())); + + std::string funcName; + SmallVector argTypes; + SmallVector args; + if (getSubgroupSize(op) == op.getSize()) { + funcName = llvm::GenISAIntrinsic::getName( + llvm::GenISAIntrinsic::GenISA_WaveAll, + {typeTranslator.translateType(val_ty)}); + argTypes = {val_ty, i8_ty, i32_ty}; + args = {val, kind, i32_val(0)}; + } else { + funcName = llvm::GenISAIntrinsic::getName( + llvm::GenISAIntrinsic::GenISA_WaveClustered, + {typeTranslator.translateType(val_ty)}); + argTypes = {val_ty, i8_ty, i32_ty, i32_ty}; + auto size = rewriter.create( + loc, i32_ty, static_cast(op.getSize())); + args = {val, kind, size, i32_val(0)}; + } + + LLVM::LLVMFuncOp funcOp = + LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, val_ty); + funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + + rewriter.replaceOp(op, rewriter.create(loc, funcOp, args)); + return success(); + } +}; + struct TritonSubGroupShuffleLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -1066,9 +1120,10 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns( TritonGENSubgroupIdLowering, TritonGENBarrierLowering, TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering, TritonGENNamedBarrierSignalLowering, TritonGENNamedBarrierWaitLowering, - TritonSubGroupShuffleLowering, TritonMatrixDPASLowering, - TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering, - TritonMatrix2DBlockPrefetchLowering>(converter); + TritonSubGroupReduceLowering, TritonSubGroupShuffleLowering, + TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering, + TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering>( + converter); } void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry ®istry) {