From 33558fafb438d6cda9b77a11ca3122b980a7704d Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 24 Apr 2025 20:09:21 +0000 Subject: [PATCH 1/5] [TTIG_PrefetchOp] Add `mask` argument Signed-off-by: Whitney Tsang --- test/TritonIntelGPU/tritonintelgpu.mlir | 11 ++++++++++ .../TritonIntelGPU/IR/TritonIntelGPUOps.td | 20 +++++++++++++++---- .../lib/Dialect/TritonIntelGPU/IR/Ops.cpp | 6 ++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/test/TritonIntelGPU/tritonintelgpu.mlir b/test/TritonIntelGPU/tritonintelgpu.mlir index 3d486780b2..6311789e49 100644 --- a/test/TritonIntelGPU/tritonintelgpu.mlir +++ b/test/TritonIntelGPU/tritonintelgpu.mlir @@ -50,6 +50,17 @@ tt.func @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg // ----- +tt.func @triton_intel_gpu.prefetch(%arg0: !tt.ptr>, %arg1: tensor<2x32xi1>) { + // CHECK-LABEL: @triton_intel_gpu.prefetch + // CHECK: triton_intel_gpu.prefetch %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + triton_intel_gpu.prefetch %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + tt.return +} + +// ----- + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { tt.func @triton_intel_gpu.sub_group_transpose(%local_buffer : !tt.ptr, %src : tensor<16x16xf16>) -> tensor<16x16xf16> { // CHECK-LABEL: @triton_intel_gpu.sub_group_transpose diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td index 85d33ed5c5..cbc3e63d13 100644 --- 
a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td @@ -107,7 +107,10 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> { let hasFolder = 1; } -def TTIG_PrefetchOp : TTIG_Op<"prefetch"> { +def TTIG_PrefetchOp : TTIG_Op<"prefetch", [ + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, +]> { let summary = "Tensor prefetch operation"; let description = [{ The `prefetch` operation prefetches an input tensor. @@ -117,11 +120,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> { : !tt.ptr ``` }]; - let arguments = (ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, TT_CacheModifierAttr:$cache, - TT_EvictionPolicyAttr:$evict, BoolAttr:$isVolatile); + let arguments = ( + ins AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + TT_CacheModifierAttr:$cache, + TT_EvictionPolicyAttr:$evict, + BoolAttr:$isVolatile + ); let results = (outs); + let builders = [ + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; let assemblyFormat = [{ - operands attr-dict `:` type($ptr) + $ptr (`,` $mask^)? 
attr-dict `:` type($ptr) }]; } diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp index c89b76a491..b025fdc253 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp @@ -197,6 +197,12 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { return {}; } +void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + PrefetchOp::build(builder, state, ptr, /*mask=*/{}, cache, evict, isVolatile); +} + LogicalResult SubGroupTransposeOp::verify() { RankedTensorType srcType = getSrc().getType(); auto mod = getOperation()->getParentOfType(); From 1a385442e3d8385fc4f5c0889109263c6aa1bd81 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 25 Apr 2025 14:41:17 +0000 Subject: [PATCH 2/5] add invalid test case Signed-off-by: Whitney Tsang --- test/TritonIntelGPU/tritonintelgpu-invalid.mlir | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir index 8f209712ff..477d403c9d 100644 --- a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir +++ b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir @@ -130,6 +130,15 @@ tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr>) { // ----- +tt.func @triton_intel_gpu.prefetch(%arg0: !tt.ptr>, %arg1: tensor<4x32xi1>) { + // expected-note@-1 {{prior use here}} + // expected-error@+1 {{use of value '%arg1' expects different type than prior uses: 'tensor<2x32xi1>' vs 'tensor<4x32xi1>'}} + triton_intel_gpu.prefetch %arg0, %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + tt.return +} + +// ----- + #warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, 
triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { From ed31253bfc11eeaa8dcd35ff3109168a23ae79f0 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 25 Apr 2025 04:50:19 +0000 Subject: [PATCH 3/5] [MatmulLoopPipeline] Predicate `PrefetchOp` Signed-off-by: Whitney Tsang --- test/TritonIntelGPU/loop-pipeline.mlir | 34 ++++++++++--------- test/TritonIntelGPU/split-barrier.mlir | 4 +-- .../Pipeliner/MatmulLoopPipeline.cpp | 22 ++++++------ 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/test/TritonIntelGPU/loop-pipeline.mlir b/test/TritonIntelGPU/loop-pipeline.mlir index 4635fedea1..1b43379c1f 100644 --- a/test/TritonIntelGPU/loop-pipeline.mlir +++ b/test/TritonIntelGPU/loop-pipeline.mlir @@ -79,15 +79,17 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, %51 = arith.muli %arg7, %c32_i32 : i32 %52 = tt.splat %51 : i32 -> tensor<32x256xi32, #blocked1> // COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop. 
+ // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} + // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> + // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> + // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>) : i32 { // CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]> // CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]> - // CHECK: triton_intel_gpu.prefetch %[[VAL_106]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> - // CHECK: triton_intel_gpu.prefetch %[[VAL_107]] {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> + // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: %[[VAL_121:.*]] = ttg.convert_layout %[[VAL_116]] : 
tensor<64x32xf16, #[[$BLOCK_0]]> -> tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> @@ -166,12 +168,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>) // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // CHECK-NEXT: scf.yield %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { @@ -239,12 +241,12 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> // CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} 
iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>, !tt.ptr>>) // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // CHECK-NEXT: scf.yield %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { @@ -302,18 +304,18 @@ module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.sup %12 = arith.extsi %arg3 : i32 to i64 // CHECK: scf.for %[[OUTER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (i32) // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR1]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR1]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR2]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR2]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: 
triton_intel_gpu.prefetch [[PTR3]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR3]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR4:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR4]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR4]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK-NEXT: scf.for %[[INNER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x128xf32, #blocked>, !tt.ptr>, !tt.ptr>, !tt.ptr>, !tt.ptr>) // CHECK: [[PTR5:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR5]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR5]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: [[PTR6:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : > - // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR6]] {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch [[PTR6]], {{.*}} {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, 
triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %13 = scf.for %arg6 = %0 to %7 step %c448_i32 iter_args(%arg7 = %10) -> (i32) : i32 { %14 = arith.divsi %arg6, %11 : i32 %15 = arith.muli %14, %c8_i32 : i32 diff --git a/test/TritonIntelGPU/split-barrier.mlir b/test/TritonIntelGPU/split-barrier.mlir index db559c3e8b..a2db6e5c93 100644 --- a/test/TritonIntelGPU/split-barrier.mlir +++ b/test/TritonIntelGPU/split-barrier.mlir @@ -26,7 +26,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait @@ -73,7 +73,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32 // WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr>> - // CHECK-NEXT: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> + // CHECK: triton_intel_gpu.prefetch {{.*}} : !tt.ptr> // CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]> // WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait // SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp 
b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp index 1788ba753b..a7e93f35e4 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -6,6 +6,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -154,7 +155,7 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, Value pred) { Location loc = pred.getLoc(); Value mask = pred; - Type maskType = tt::getI1SameShape(typeLike); + Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike)); if (isa(maskType)) mask = rewriter.create(loc, maskType, pred); @@ -167,18 +168,17 @@ static Value getPredMask(RewriterBase &rewriter, Type typeLike, static Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred) { OpBuilder::InsertionGuard guard(rewriter); - if (mlir::isMemoryEffectFree(op) || isa(op)) + if (mlir::isMemoryEffectFree(op)) return op; - if (auto loadOp = dyn_cast(op)) { - rewriter.setInsertionPoint(loadOp); - Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), - loadOp.getMask(), pred); - loadOp.getMaskMutable().assign(mask); - return loadOp; - } - - llvm_unreachable("don't know how to predicate this operation"); + return TypeSwitch(op) + .Case([&](auto op) { + rewriter.setInsertionPoint(op); + Value mask = + getPredMask(rewriter, op.getPtr().getType(), op.getMask(), pred); + op.getMaskMutable().assign(mask); + return op; + }); } /// Helper to get the defining operation of a value. 
From cc4150d4f734f238289d34e7e662a3272c4b6ee5 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 25 Apr 2025 15:11:17 +0000 Subject: [PATCH 4/5] address review comments Signed-off-by: Whitney Tsang --- test/TritonIntelGPU/loop-pipeline.mlir | 2 +- .../Pipeliner/MatmulLoopPipeline.cpp | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/TritonIntelGPU/loop-pipeline.mlir b/test/TritonIntelGPU/loop-pipeline.mlir index 1b43379c1f..f4f6f394e2 100644 --- a/test/TritonIntelGPU/loop-pipeline.mlir +++ b/test/TritonIntelGPU/loop-pipeline.mlir @@ -81,7 +81,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, // COM: There are 3 stages in loop pipelining, the first 2 prefetching stages are before the loop and the last one is inside the loop. // CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]> - // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch {{.*}}, %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp index a7e93f35e4..d65be0221c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp @@ -150,9 +150,10 @@ static void collectOpsToPipeline(scf::ForOp forOp, } } -/// Combine the current mask with the given predicate. 
-static Value getPredMask(RewriterBase &rewriter, Type typeLike, - Value currentMask, Value pred) { +/// Return a new mask with the same shape as \p typeLike, whose value combines +/// the current mask \p currentMask with the given predicate \p pred. +static Value computeNewMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { Location loc = pred.getLoc(); Value mask = pred; Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike)); @@ -175,7 +176,7 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op, .Case([&](auto op) { rewriter.setInsertionPoint(op); Value mask = - getPredMask(rewriter, op.getPtr().getType(), op.getMask(), pred); + computeNewMask(rewriter, op.getPtr().getType(), op.getMask(), pred); op.getMaskMutable().assign(mask); return op; }); From 5b538dde4cfd993d9adb51cc7f633979c6f559e9 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Sat, 26 Apr 2025 19:12:53 +0000 Subject: [PATCH 5/5] Add CHECKs for loop body prefetch Signed-off-by: Whitney Tsang --- test/TritonIntelGPU/loop-pipeline.mlir | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/TritonIntelGPU/loop-pipeline.mlir b/test/TritonIntelGPU/loop-pipeline.mlir index f4f6f394e2..e16e1405a1 100644 --- a/test/TritonIntelGPU/loop-pipeline.mlir +++ b/test/TritonIntelGPU/loop-pipeline.mlir @@ -86,9 +86,11 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: triton_intel_gpu.prefetch {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: scf.for %[[VAL_92:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args(%[[VAL_93:.*]] = {{.*}}, %[[VAL_94:.*]] = {{.*}}, %[[VAL_95:.*]] = {{.*}}, %[[VAL_96:.*]] = {{.*}}, %[[VAL_97:.*]] = {{.*}}) -> (tensor<64x256xf32, #[[$DPAS]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>) : i32 { + 
// CHECK: %[[LOAD_MASK:.*]] = arith.cmpi slt, {{.*}} // CHECK: %[[VAL_106:.*]] = tt.addptr %[[VAL_94]], {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]>, tensor<64x32xi32, #[[$BLOCK_0]]> // CHECK: %[[VAL_107:.*]] = tt.addptr %[[VAL_95]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>, tensor<32x256xi32, #[[$BLOCK_1]]> - // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> + // CHECK: %[[LOOP_MASK:.*]] = tt.splat %[[LOAD_MASK]] : i1 -> tensor<64x32xi1, #[[$BLOCK_0]]> + // CHECK: triton_intel_gpu.prefetch %[[VAL_106]], %[[LOOP_MASK]] {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: triton_intel_gpu.prefetch %[[VAL_107]], {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]> // CHECK: %[[VAL_116:.*]] = tt.load %[[VAL_96]], {{.*}}, {{.*}} : tensor<64x32x!tt.ptr, #[[$BLOCK_0]]> // CHECK: %[[VAL_120:.*]] = tt.load %[[VAL_97]], {{.*}}, {{.*}} : tensor<32x256x!tt.ptr, #[[$BLOCK_1]]>