Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ target_link_libraries(triton-opt PRIVATE
MLIROptLib
MLIRPass
MLIRTransforms
MLIRSPIRVDialect
)

mlir_check_all_link_libraries(triton-opt)
Expand Down
2 changes: 2 additions & 0 deletions bin/triton-opt.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "./RegisterTritonDialects.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
registerTritonDialects(registry);
registry.insert<mlir::spirv::SPIRVDialect>();

return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
Expand Down
145 changes: 145 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,151 @@ llvm.func @triton_gen.named_barrier(%barrier_id : i32, %thread_group_count : i32

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 32>>
Comment thread
victor-eds marked this conversation as resolved.
} {
// Clustered lowering case: the enclosing module's target environment declares
// subgroup_size = 32 while every reduce below requests size = 16. Because the
// two differ, each op is expected to become a call to the clustered intrinsic
// (GenISA.WaveClustered) taking (value, i8 kind, i32 size, i32 0). The i8 kind
// constant is the numeric value of the reduce kind (sum = 0 ... fmax = 12).
llvm.func @triton_gen.sub_group_reduce() {
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.i32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (i32, i8, i32, i32) -> i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
// From here on the reduced value is an f32, so the .f32 flavour of the
// intrinsic is expected and [[VAL]] is rebound to the f32 constant.
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8
// CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveClustered.f32([[VAL]], [[KIND]], [[SIZE]], [[ZERO]]) : (f32, i8, i32, i32) -> f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}
}

// -----

// Whole-subgroup lowering case: the target environment declares
// subgroup_size = 16 and every reduce below also requests size = 16. Because
// the two are equal, each op is expected to become a call to the non-clustered
// intrinsic (GenISA.WaveAll) taking (value, i8 kind, i32 0); no size constant
// is materialized. The i8 kind constant is the numeric value of the reduce
// kind (sum = 0 ... fmax = 12).
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
llvm.func @triton_gen.sub_group_reduce() {
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(1 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(2 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(3 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(4 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(5 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(6 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(7 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(8 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.i32([[VAL]], [[KIND]], [[ZERO]]) : (i32, i8, i32) -> i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
// From here on the reduced value is an f32, so the .f32 flavour of the
// intrinsic is expected and [[VAL]] is rebound to the f32 constant.
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(9 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(10 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(11 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: [[KIND:%.*]] = llvm.mlir.constant(12 : i8) : i8
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call @llvm.genx.GenISA.WaveAll.f32([[VAL]], [[KIND]], [[ZERO]]) : (f32, i8, i32) -> f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}
}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]}
Expand Down
33 changes: 33 additions & 0 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,39 @@ llvm.func @triton_gen.named_barrier_wait(%barrier_id : i32) {
llvm.return
}

// Exercises the custom assembly of triton_gen.sub_group_reduce once per reduce
// kind (integer kinds on an i32 value, then the f-prefixed kinds on an f32
// value). Each expected output line is textually identical to the op that
// follows it.
// NOTE(review): presumably this file is run as a parse/print round-trip; the
// lit invocation line is outside the visible hunk, so confirm there.
llvm.func @triton_gen.sub_group_reduce() {
// CHECK-LABEL: triton_gen.sub_group_reduce
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: triton_gen.sub_group_reduce sum %0 {size = 16} : i32
%1 = triton_gen.sub_group_reduce sum %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce prod %0 {size = 16} : i32
%2 = triton_gen.sub_group_reduce prod %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce umin %0 {size = 16} : i32
%3 = triton_gen.sub_group_reduce umin %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce umax %0 {size = 16} : i32
%4 = triton_gen.sub_group_reduce umax %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce imin %0 {size = 16} : i32
%5 = triton_gen.sub_group_reduce imin %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce imax %0 {size = 16} : i32
%6 = triton_gen.sub_group_reduce imax %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32
%7 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32
%8 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
// CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32
%9 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
%10 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
%11 = triton_gen.sub_group_reduce fsum %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
%12 = triton_gen.sub_group_reduce fprod %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
%13 = triton_gen.sub_group_reduce fmin %10 {size = 16} : f32
// CHECK: triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
%14 = triton_gen.sub_group_reduce fmax %10 {size = 16} : f32
llvm.return
}

llvm.func @triton_gen.sub_group_shuffle() {
// CHECK-LABEL: triton_gen.sub_group_shuffle
%0 = llvm.mlir.constant(0 : i32) : i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@

include "mlir/IR/EnumAttr.td"

/// Enum attribute of the different reduce kinds.
///
/// For each case, the third field is the keyword that appears in the textual
/// form of `triton_gen.sub_group_reduce` (e.g. `sum`, `fmax`), and the second
/// field is the numeric value. The TritonGEN-to-LLVM lowering forwards that
/// numeric value unchanged as the i8 "kind" operand of the GenISA wave
/// intrinsic call, so renumbering a case changes the generated calls.
///
/// NOTE(review): the `f`-prefixed kinds appear to be the floating-point
/// variants of sum/prod/min/max (the tests only apply them to f32 values);
/// nothing visible here enforces that pairing with the operand type.
def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind",
[
I32EnumAttrCase<"SUM", 0, "sum">,
I32EnumAttrCase<"PROD", 1, "prod">,
I32EnumAttrCase<"UMIN", 2, "umin">,
I32EnumAttrCase<"UMAX", 3, "umax">,
I32EnumAttrCase<"IMIN", 4, "imin">,
I32EnumAttrCase<"IMAX", 5, "imax">,
I32EnumAttrCase<"OR", 6, "or">,
I32EnumAttrCase<"XOR", 7, "xor">,
I32EnumAttrCase<"AND", 8, "and">,
I32EnumAttrCase<"FSUM", 9, "fsum">,
I32EnumAttrCase<"FPROD", 10, "fprod">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">
]> {
// Generated C++ enum and helpers live in the TritonGEN dialect namespace.
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind",
[
Expand Down
23 changes: 23 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,29 @@ def TritonGEN_NamedBarrierWaitOp : TritonGEN_Op<"named_barrier_wait">,

def IntegerOrFloatType : AnyTypeOf<[AnyInteger, AnyFloat]>;

// Clustered subgroup reduction op.
//
// The operand and the result are constrained to the same integer or float
// type (AllTypesMatch), which is why the assembly format only spells the type
// of $value. Textual form:
//   %r = triton_gen.sub_group_reduce <kind> %v {size = N} : <type>
def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [
AllTypesMatch<["res", "value"]>]>,
Results<(outs IntegerOrFloatType:$res)>,
Arguments<(ins IntegerOrFloatType:$value,
TritonGEN_ReduceKindAttr:$kind,
I32Attr:$size)> {
let summary = "Subgroup reduce";

let description = [{
The `triton_gen.sub_group_reduce` operation is invoked by all work items in
a subgroup, each of them providing a $value. The $size argument is used to
form groups of $size consecutive work items called clusters. Each cluster
performs the reduction operation identified by $kind. The result of the
cluster reduction is propagated to the work items belonging to that cluster.
}];

// $size is printed inside explicit braces ahead of attr-dict rather than as
// part of the generic attribute dictionary.
let assemblyFormat = [{
$kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value)
}];

// Declares a C++ `verify()` member on the op that must be defined by hand.
let hasVerifier = 1;
}

def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [
TypesMatchWith<"result and value have the same type",
"res", "value", "$_self">]>,
Expand Down
9 changes: 9 additions & 0 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ template <typename Op> static LogicalResult verifyInput(Op op) {
return success();
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//

/// Verifier hook for the sub-group reduce op.
///
/// Currently a placeholder: it performs no checks on the value type, the
/// reduce kind or the cluster size, and reports success for every instance.
LogicalResult TritonGEN::SubGroupReduceOp::verify() {
// TODO: Add verification for SubGroupReduceOp.
return success();
}
Comment thread
victor-eds marked this conversation as resolved.

//===----------------------------------------------------------------------===//
// gen.matrix.dpas
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ add_triton_library(TritonGENToLLVM

LINK_LIBS PUBLIC
GenISAIntrinsics
MLIRLLVMDialect
MLIRSPIRVDialect
)
61 changes: 58 additions & 3 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -95,6 +97,13 @@ static std::string getTypeMangling(Type ty) {
});
}

/// Returns the subgroup size advertised by the SPIR-V target environment that
/// applies to \p op.
///
/// The environment is resolved with `spirv::lookupTargetEnv`. A missing
/// attribute is treated as a programming error and is only caught by the
/// assertion, i.e. in builds where assertions are enabled.
static int getSubgroupSize(Operation *op) {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(op);
  assert(targetEnv && "Expecting valid target env attribute");
  const auto limits = targetEnv.getResourceLimits();
  return limits.getSubgroupSize();
}

static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter,
Value value, Value mask,
TritonGEN::ShflKind kind) {
Expand Down Expand Up @@ -903,6 +912,51 @@ struct TritonGENNamedBarrierWaitLowering
}
};

/// Lowers `triton_gen.sub_group_reduce` to a call to a GenISA wave-reduction
/// intrinsic.
///
/// When the requested cluster size equals the subgroup size of the enclosing
/// SPIR-V target environment, `GenISA_WaveAll` is called with
/// (value, kind, 0). Otherwise `GenISA_WaveClustered` is called with
/// (value, kind, size, 0). The reduce kind is forwarded as an i8 constant
/// holding the numeric value of the `ReduceKind` enum case. The callee is
/// declared in the parent module on first use and given the SPIR_FUNC calling
/// convention.
struct TritonSubGroupReduceLowering
    : public ConvertOpToLLVMPattern<TritonGEN::SubGroupReduceOp> {
  using ConvertOpToLLVMPattern<
      TritonGEN::SubGroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(TritonGEN::SubGroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // NOTE(review): i8_ty / i32_ty / i32_val are project helpers defined
    // outside this block; they appear to rely on the local names `rewriter`
    // and `loc`, so keep those names unless that is confirmed otherwise.
    Location loc = op.getLoc();

    // Take the operand from the adaptor, not from the op: inside a conversion
    // pattern the adaptor carries the remapped (already type-converted) value,
    // whereas `op.getValue()` still refers to the original operand.
    Value val = adaptor.getValue();
    Type valTy = val.getType();

    // A throw-away LLVM context is only needed to translate the MLIR type so
    // that the overloaded intrinsic name can be computed.
    llvm::LLVMContext llvmContext;
    LLVM::TypeToLLVMIRTranslator typeTranslator(llvmContext);
    auto moduleOp = op->getParentOfType<ModuleOp>();

    // The constants below are created in the order kind, [size], zero; the
    // lowering tests match them in that order.
    auto kind = rewriter.create<LLVM::ConstantOp>(
        loc, i8_ty, static_cast<int>(op.getKind()));

    std::string funcName;
    SmallVector<Type> argTypes;
    SmallVector<Value> args;
    if (getSubgroupSize(op) == op.getSize()) {
      // The cluster spans the whole subgroup: no size operand is passed.
      funcName = llvm::GenISAIntrinsic::getName(
          llvm::GenISAIntrinsic::GenISA_WaveAll,
          {typeTranslator.translateType(valTy)});
      argTypes = {valTy, i8_ty, i32_ty};
      args = {val, kind, i32_val(0)};
    } else {
      // The cluster differs from the subgroup size: pass the cluster size
      // explicitly.
      funcName = llvm::GenISAIntrinsic::getName(
          llvm::GenISAIntrinsic::GenISA_WaveClustered,
          {typeTranslator.translateType(valTy)});
      argTypes = {valTy, i8_ty, i32_ty, i32_ty};
      auto size = rewriter.create<LLVM::ConstantOp>(
          loc, i32_ty, static_cast<int>(op.getSize()));
      args = {val, kind, size, i32_val(0)};
    }

    LLVM::LLVMFuncOp funcOp =
        LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, valTy);
    funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);

    rewriter.replaceOp(op, rewriter.create<LLVM::CallOp>(loc, funcOp, args));
    return success();
  }
};

struct TritonSubGroupShuffleLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupShuffleOp> {
using ConvertOpToLLVMPattern<
Expand Down Expand Up @@ -1066,9 +1120,10 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
TritonGENSubgroupIdLowering, TritonGENBarrierLowering,
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
TritonGENNamedBarrierSignalLowering, TritonGENNamedBarrierWaitLowering,
TritonSubGroupShuffleLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering>(converter);
TritonSubGroupReduceLowering, TritonSubGroupShuffleLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering>(
converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down