diff --git a/test/TritonIntelGPU/prefetch-to-llvm.mlir b/test/TritonIntelGPU/prefetch-to-llvm.mlir index 082f75d5fd..6649d29004 100644 --- a/test/TritonIntelGPU/prefetch-to-llvm.mlir +++ b/test/TritonIntelGPU/prefetch-to-llvm.mlir @@ -1,75 +1,172 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm +// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s // CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} -// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} +// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_2r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { - tt.func public @matmul_with_prefetch(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) { - // CHECK-LABEL: @matmul_with_prefetch +// CHECK-LABEL: llvm.func spir_kernelcc @prefetch_block_ptr( +// CHECK-SAME: %[[BASE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[BASE_HEIGHT:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i64, +// CHECK-SAME: %[[BASE_WIDTH:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i64, +// CHECK-SAME: %[[ROW_STRIDE:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i64) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array} { + tt.func public @prefetch_block_ptr(%arg0: !tt.ptr, %arg2: i64, %arg4: i64, %arg5: i64) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 - // CHECK: %[[ROW_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg0, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_18:.*]] = llvm.zext %[[VAL_17]] : i32 to i64 - // CHECK: %[[VAL_19:.*]] = llvm.trunc %[[VAL_18]] : i64 to i32 - // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_21:.*]] = llvm.urem %[[VAL_19]], %[[VAL_20]] : i32 - // CHECK: %[[VAL_22:.*]] = llvm.udiv %[[VAL_19]], %[[VAL_20]] : i32 - // CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_24:.*]] = llvm.urem %[[VAL_22]], %[[VAL_23]] : i32 - // CHECK: %[[VAL_25:.*]] = llvm.udiv %[[VAL_22]], %[[VAL_23]] : i32 - // CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[ROW_MAJOR_HEIGHT_:.*]] = llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[ROW_MAJOR_WIDTH_:.*]] = llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[ROW_MAJOR_ROW_STRIDE_:.*]] = llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[ROW_MAJOR_BASE:.*]] = 
llvm.extractvalue %[[ROW_MAJOR_BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[VAL_34:.*]] = llvm.mul %[[ROW_MAJOR_WIDTH_]], {{.*}} : i64 - // CHECK: %[[ROW_MAJOR_WIDTH:.*]] = llvm.trunc %[[VAL_34]] : i64 to i32 - // CHECK: %[[ROW_MAJOR_HEIGHT:.*]] = llvm.trunc %[[ROW_MAJOR_HEIGHT_]] : i64 to i32 - // CHECK: %[[ROW_MAJOR_ROW_STRIDE:.*]] = llvm.mul %[[ROW_MAJOR_ROW_STRIDE_]], {{.*}} : i64 - // CHECK: %[[ROW_MAJOR_STRIDE:.*]] = llvm.trunc %[[ROW_MAJOR_ROW_STRIDE]] : i64 to i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_X_:.*]] = llvm.add {{.*}}, %[[ROW_MAJOR_OFFSET_X]] : i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_Y_:.*]] = llvm.add {{.*}}, %[[ROW_MAJOR_OFFSET_Y]] : i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_Y:.*]] = llvm.trunc %[[COLUMN_MAJOR_WARP_OFF_Y_]] : i32 to i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_X:.*]] = llvm.trunc %[[COLUMN_MAJOR_WARP_OFF_X_]] : i32 to i32 - // CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_57:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_59:.*]] = llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_X]], {{.*}}{{\[}}%[[VAL_57]] : i32] : vector<2xi32> - // CHECK: %[[ROW_MAJOR_COORD:.*]] = llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_Y]], {{.*}}{{\[}}%[[VAL_56]] : i32] : vector<2xi32> - // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x1cPU3AS1viiiDv2_i(%[[ROW_MAJOR_BASE]], %[[ROW_MAJOR_WIDTH]], %[[ROW_MAJOR_HEIGHT]], %[[ROW_MAJOR_STRIDE]], %[[ROW_MAJOR_COORD]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () - %rowMajorPtr = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > - triton_intel_gpu.prefetch %rowMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> - - // COM: The memory layout is same for the column major memory and row major memory. The prefetch function should be the same. 
- - // CHECK: %[[COLUMN_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg1, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_OFFSET_Y:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_OFFSET_X:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_HEIGHT_:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_WIDTH:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_COL_STRIDE:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[COLUMN_MAJOR_BASE:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> - // CHECK: %[[VAL_86:.*]] = llvm.mul %[[COLUMN_MAJOR_HEIGHT_]], {{.*}} : i64 - // CHECK: %[[COLUMN_MAJOR_HEIGHT:.*]] = llvm.trunc %[[VAL_86]] : i64 to i32 - // CHECK: %[[COLUMN_MAJOR_WIDTH_:.*]] = llvm.trunc %[[COLUMN_MAJOR_WIDTH]] : i64 to i32 - // CHECK: %[[VAL_90:.*]] = llvm.mul %[[COLUMN_MAJOR_COL_STRIDE]], {{.*}} : i64 - // CHECK: %[[COLUMN_MAJOR_STRIDE:.*]] = llvm.trunc %[[VAL_90]] : i64 to i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_X_:.*]] = llvm.add {{.*}}, %[[COLUMN_MAJOR_OFFSET_X]] : i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_Y_:.*]] = llvm.add {{.*}}, %[[COLUMN_MAJOR_OFFSET_Y]] : i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_Y:.*]] = llvm.trunc %[[COLUMN_MAJOR_WARP_OFF_Y_]] : i32 to i32 - // CHECK: %[[COLUMN_MAJOR_WARP_OFF_X:.*]] = llvm.trunc %[[COLUMN_MAJOR_WARP_OFF_X_]] : i32 to i32 - // CHECK: %[[VAL_108:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_109:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_X]], {{.*}}{{\[}}%[[VAL_109]] : i32] : vector<2xi32> - // CHECK: %[[COLUMN_MAJOR_COORD:.*]] = llvm.insertelement %[[COLUMN_MAJOR_WARP_OFF_Y]], {{.*}}{{\[}}%[[VAL_108]] : i32] : vector<2xi32> - // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(%[[COLUMN_MAJOR_BASE]], %[[COLUMN_MAJOR_HEIGHT]], %[[COLUMN_MAJOR_WIDTH_]], %[[COLUMN_MAJOR_STRIDE]], %[[COLUMN_MAJOR_COORD]]) {{.*}} : (!llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> () - %columnMajorPtr = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%c1_i64, %arg6], [%c0_i32, %c0_i32] {order = array} : > - triton_intel_gpu.prefetch %columnMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, triton_intel_gpu.block_io = "column_major"} : !tt.ptr> + // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: %[[CST_2_I32:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i64) : i64 + // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_15:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_16:.*]] = llvm.zext %[[VAL_15]] : i32 to i64 + // CHECK: %[[VAL_17:.*]] = llvm.trunc %[[VAL_16]] : i64 to i32 + // CHECK: %[[VAL_18:.*]] = llvm.urem %[[VAL_17]], 
%[[CST_1]] : i32 + // CHECK: %[[VAL_19:.*]] = llvm.udiv %[[VAL_17]], %[[CST_1]] : i32 + // CHECK: %[[VAL_20:.*]] = llvm.urem %[[VAL_19]], %[[CST_8]] : i32 + // CHECK: %[[VAL_21:.*]] = llvm.mul %[[BASE_WIDTH]], %[[CST_2]] : i64 + // CHECK: %[[ROW_MAJOR_BASE_WIDTH:.*]] = llvm.trunc %[[VAL_21]] : i64 to i32 + // CHECK: %[[ROW_MAJOR_BASE_HEIGHT:.*]] = llvm.trunc %[[BASE_HEIGHT]] : i64 to i32 + // CHECK: %[[VAL_24:.*]] = llvm.mul %[[ROW_STRIDE]], %[[CST_2]] : i64 + // CHECK: %[[ROW_MAJOR_PITCH:.*]] = llvm.trunc %[[VAL_24]] : i64 to i32 + // CHECK: %[[VAL_26:.*]] = llvm.mul %[[VAL_18]], %[[CST_32]] : i32 + // CHECK: %[[VAL_27:.*]] = llvm.add %[[VAL_26]], %[[CST_0]] : i32 + // CHECK: %[[VAL_28:.*]] = llvm.urem %[[VAL_27]], %[[CST_32]] : i32 + // CHECK: %[[VAL_29:.*]] = llvm.add %[[VAL_28]], %[[CST_0]] : i32 + // CHECK: %[[VAL_30:.*]] = llvm.mul %[[VAL_20]], %[[CST_2_I32]] : i32 + // CHECK: %[[VAL_31:.*]] = llvm.add %[[VAL_30]], %[[CST_0]] : i32 + // CHECK: %[[VAL_32:.*]] = llvm.urem %[[VAL_31]], %[[CST_16]] : i32 + // CHECK: %[[VAL_33:.*]] = llvm.add %[[VAL_32]], %[[CST_0]] : i32 + // CHECK: %[[ROW_MAJOR_OFFSET_Y:.*]] = llvm.trunc %[[VAL_33]] : i32 to i32 + // CHECK: %[[ROW_MAJOR_OFFSET_X:.*]] = llvm.trunc %[[VAL_29]] : i32 to i32 + // CHECK: %[[VAL_36:.*]] = llvm.insertelement %[[ROW_MAJOR_OFFSET_X]], {{.*}} : i32] : vector<2xi32> + // CHECK: %[[ROW_MAJOR_OFFSETS:.*]] = llvm.insertelement %[[ROW_MAJOR_OFFSET_Y]], {{.*}} : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_2r16x2cPU3AS1viiiDv2_i(%[[BASE]], %[[ROW_MAJOR_BASE_WIDTH]], %[[ROW_MAJOR_BASE_HEIGHT]], %[[ROW_MAJOR_PITCH]], %[[ROW_MAJOR_OFFSETS]]) + %rowMajorPtr = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + triton_intel_gpu.prefetch %rowMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + + // CHECK: %[[VAL_32:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_33:.*]] = llvm.zext %[[VAL_32]] : i32 to i64 + // CHECK: %[[VAL_34:.*]] = llvm.trunc %[[VAL_33]] : i64 to i32 + // CHECK: %[[VAL_35:.*]] = llvm.urem %[[VAL_34]], %[[CST_2_I32]] : i32 + // CHECK: %[[VAL_36:.*]] = llvm.udiv %[[VAL_34]], %[[CST_2_I32]] : i32 + // CHECK: %[[VAL_37:.*]] = llvm.urem %[[VAL_36]], %[[CST_4]] : i32 + // CHECK: %[[VAL_38:.*]] = llvm.mul %[[BASE_WIDTH]], %[[CST_2]] : i64 + // CHECK: %[[COL_MAJOR_BASE_WIDTH:.*]] = llvm.trunc %[[VAL_38]] : i64 to i32 + // CHECK: %[[COL_MAJOR_BASE_HEIGHT:.*]] = llvm.trunc %[[BASE_HEIGHT]] : i64 to i32 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[ROW_STRIDE]], %[[CST_2]] : i64 + // CHECK: %[[COL_MAJOR_PITCH:.*]] = llvm.trunc %[[VAL_41]] : i64 to i32 + // CHECK: %[[VAL_43:.*]] = llvm.mul %[[VAL_35]], %[[CST_16]] : i32 + // CHECK: %[[VAL_44:.*]] = llvm.add %[[VAL_43]], %[[CST_0]] : i32 + // CHECK: %[[VAL_45:.*]] = llvm.urem %[[VAL_44]], %[[CST_32]] : i32 + // CHECK: %[[VAL_46:.*]] = llvm.add %[[VAL_45]], %[[CST_0]] : i32 + // CHECK: %[[VAL_47:.*]] = llvm.mul %[[VAL_37]], %[[CST_4]] : i32 + // CHECK: %[[VAL_48:.*]] = llvm.add %[[VAL_47]], %[[CST_0]] : i32 + // CHECK: %[[VAL_49:.*]] = llvm.urem %[[VAL_48]], %[[CST_16]] : i32 + // CHECK: %[[VAL_50:.*]] = llvm.add %[[VAL_49]], %[[CST_0]] : i32 + // CHECK: %[[COL_MAJOR_OFFSET_Y:.*]] = llvm.trunc %[[VAL_50]] : i32 to i32 + // CHECK: %[[COL_MAJOR_OFFSET_X:.*]] = llvm.trunc %[[VAL_46]] : i32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.insertelement %[[COL_MAJOR_OFFSET_X]], {{.*}} : i32] : vector<2xi32> + // 
CHECK: %[[COL_MAJOR_OFFSETS:.*]] = llvm.insertelement %[[COL_MAJOR_OFFSET_Y]], {{.*}} : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x1cPU3AS1viiiDv2_i(%[[BASE]], %[[COL_MAJOR_BASE_WIDTH]], %[[COL_MAJOR_BASE_HEIGHT]], %[[COL_MAJOR_PITCH]], %[[COL_MAJOR_OFFSETS]]) {{.*}} + %columnMajorPtr = tt.make_tensor_ptr %arg0, [%arg4, %arg2], [%c1_i64, %arg5], [%c0_i32, %c0_i32] {order = array} : > + triton_intel_gpu.prefetch %columnMajorPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false, triton_intel_gpu.block_io = "column_major"} : !tt.ptr> + + // COM: The memory is not densely structured, so it is not prefetched to the cache. // CHECK-NOT: block_prefetch - %nonContiguousPtr = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg6, %arg6], [%c0_i32, %c0_i32] {order = array} : > - triton_intel_gpu.prefetch %nonContiguousPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + %nonContiguousPtr = tt.make_tensor_ptr %arg0, [%arg4, %arg2], [%arg5, %arg5], [%c0_i32, %c0_i32] {order = array} : > + triton_intel_gpu.prefetch %nonContiguousPtr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @prefetch_tensor_of_pointers + tt.func public @prefetch_tensor_of_pointers(%tensor_of_ptr: tensor<64x32x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>>) { + // CHECK: %[[MASK:.*]] = llvm.mlir.constant(1 : i8) : i8 + // CHECK: %[[VAL_2:.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE_HEIGHT:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[BASE_WIDTH:.*]] = llvm.mlir.constant(64 : i32) : i32 + // CHECK: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 + + // CHECK: %[[ADDR_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)> + // CHECK: %[[ADDR_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)> + // CHECK: %[[ADDR_16:.*]] = llvm.extractvalue {{.*}}[16] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, 
ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)> + // CHECK: %[[ADDR_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)> + // CHECK: %[[ADDR_48:.*]] = llvm.extractvalue {{.*}}[48] : !llvm.struct<(ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>, ptr<1>)> + // CHECK: %[[VAL_13:.*]] = llvm.ptrtoint %[[ADDR_0]] : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_14:.*]] = llvm.ptrtoint %[[ADDR_1]] : !llvm.ptr<1> to i64 + // CHECK: %[[PITCH:.*]] = llvm.sub %[[VAL_14]], %[[VAL_13]] : i64 + // CHECK: %[[UNIFIED_PITCH:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[PITCH]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFIED_PITCH_I32:.*]] = llvm.trunc %[[UNIFIED_PITCH]] : i64 to i32 + // CHECK: %[[VAL_18:.*]] = llvm.intr.umax(%[[UNIFIED_PITCH_I32]], %[[BASE_WIDTH]]) : (i32, i32) -> i32 + // CHECK: %[[PITCH_IN_BYTES_I32:.*]] = llvm.trunc %[[VAL_18]] : i32 to i32 + + // CHECK: %[[UNIFIED_MASK:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8 + // CHECK: %[[UNIFIED_MASK_I1:.*]] = llvm.trunc %[[UNIFIED_MASK]] : i8 to i1 + // CHECK: %[[OFFSET_Y:.*]] = llvm.select %[[UNIFIED_MASK_I1]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32 + // CHECK: %[[UNIFIED_BASE:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_13]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[VAL_26:.*]] = llvm.inttoptr %[[UNIFIED_BASE]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAL_27:.*]] = llvm.insertelement %[[CST_0]], {{.*}} : vector<2xi32> + // CHECK: %[[OFFSETS:.*]] = llvm.insertelement %[[OFFSET_Y]], {{.*}} : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_26]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[OFFSETS]]) + + // CHECK: %[[VAL_29:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8 + // CHECK: %[[VAL_30:.*]] = llvm.trunc %[[VAL_29]] : i8 to i1 + // CHECK: %[[VAL_31:.*]] = llvm.select %[[VAL_30]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32 + // 
CHECK: %[[VAL_32:.*]] = llvm.ptrtoint %[[ADDR_16]] : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_33:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_32]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[VAL_34:.*]] = llvm.inttoptr %[[VAL_33]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAL_35:.*]] = llvm.insertelement %[[VAL_31]], {{.*}} : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_34]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_35]]) + + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8 + // CHECK: %[[VAL_37:.*]] = llvm.trunc %[[VAL_36]] : i8 to i1 + // CHECK: %[[VAL_38:.*]] = llvm.select %[[VAL_37]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32 + // CHECK: %[[VAL_39:.*]] = llvm.ptrtoint %[[ADDR_32]] : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_40:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_39]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[VAL_41:.*]] = llvm.inttoptr %[[VAL_40]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAL_42:.*]] = llvm.insertelement %[[VAL_38]], {{.*}} : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_41]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_42]]) + + // CHECK: %[[VAL_43:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[MASK]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i8, i32) -> i8 + // CHECK: %[[VAL_44:.*]] = llvm.trunc %[[VAL_43]] : i8 to i1 + // CHECK: %[[VAL_45:.*]] = llvm.select %[[VAL_44]], %[[CST_0]], %[[BASE_HEIGHT]] : i1, i32 + // CHECK: %[[VAL_46:.*]] = llvm.ptrtoint %[[ADDR_48]] : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_47:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_46]], %[[CST_0]]) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[VAL_48:.*]] = llvm.inttoptr %[[VAL_47]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAL_49:.*]] = llvm.insertelement %[[VAL_45]], {{.*}} : i32] : vector<2xi32> + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%[[VAL_48]], %[[BASE_WIDTH]], %[[BASE_HEIGHT]], %[[PITCH_IN_BYTES_I32]], %[[VAL_49]]) + + %mask_tensor = arith.constant dense<1> : tensor<64x32xi1, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>> + triton_intel_gpu.prefetch %tensor_of_ptr, %mask_tensor {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>> + + // CHECK-COUNT-4: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i + + triton_intel_gpu.prefetch %tensor_of_ptr {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : tensor<64x32x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>> + + tt.return + } +} + +// ----- + +// COM: Currently the prefetch operation in this test cannot be lowered correctly, so we check that the test compiles cleanly and no 2D block prefetch operation is generated. 
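+// COM: With this patch a failed prefetch lowering no longer asserts: PrefetchOpConversion emits a warning and erases the op (see the FIXME in LoadStoreOpToLLVM.cpp), so the kernel below still compiles and merely loses the prefetch.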
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}> +module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @kernel + tt.func public @kernel(%arg0 : tensor<128x32x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>) { + // CHECK-NOT: intel_sub_group_2d_block_prefetch + triton_intel_gpu.prefetch %arg0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array, triton_intel_gpu.block_io = "row_major"} : tensor<128x32x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> tt.return } } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 0f2130e640..d29841b5f2 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -316,7 +316,7 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { // Determine whether the given LoadOp can be lowered to using block IO // instructions. - bool isLoadCandidate(triton::LoadOp op) const { + static bool isLoadCandidate(triton::LoadOp op) { Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); if (!blockIOAttr) @@ -332,7 +332,7 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { std::enable_if_t::value, bool> = true> - bool isMemoryRowMajor(OpTy op) const { + static bool isMemoryRowMajor(OpTy op) { Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); assert(blockIOAttr && "Expecting block IO attribute"); @@ -347,7 +347,7 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { return memoryLayoutInfo == "row_major"; } - DpasEncodingAttr::OpIdx getOpIdx(RankedTensorType tensorTy) const { + static DpasEncodingAttr::OpIdx getOpIdx(RankedTensorType tensorTy) { if (hasDpasEncoding(tensorTy)) return DpasEncodingAttr::OpIdx::OperandC; @@ -356,7 +356,7 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { return static_cast(dotLayout.getOpIdx()); } - DpasEncodingAttr getDpasLayout(RankedTensorType tensorTy) const { + static DpasEncodingAttr getDpasLayout(RankedTensorType tensorTy) { Attribute encoding = tensorTy.getEncoding(); return cast( hasDpasEncoding(tensorTy) @@ -382,12 +382,21 @@ struct PrefetchOpConversion LogicalResult matchAndRewrite(triton::gpu::intel::PrefetchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - Value ptr = op.getPtr(); - if (isTensorPointerType(ptr.getType())) - return rewriteTensorPointerPrefetch(op, adaptor, rewriter); + LogicalResult res = + isTensorPointerType(op.getPtr().getType()) + ? rewriteTensorPointerPrefetch(op, adaptor, rewriter) + : rewriteRegularPointerPrefetch(op, adaptor, rewriter); + + // FIXME: the prefetch lowering code should never fail. Currently it does in + // some cases. We should address those cases instead of removing the + // prefetch operation. + if (failed(res)) { + op.emitWarning("Prefetch operation could not be converted to LLVM. 
" + "The operation was erased."); + rewriter.eraseOp(op); + } - llvm_unreachable("Unexpected prefetch operation on 'regular' ptr"); - return failure(); + return success(); } LogicalResult @@ -530,6 +539,220 @@ struct PrefetchOpConversion rewriter.eraseOp(op); return success(); } + + LogicalResult + rewriteRegularPointerPrefetch(triton::gpu::intel::PrefetchOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Attribute blockIOAttr = + op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (!blockIOAttr) + return failure(); + + // Only support rank 2 block pointer, either row major or column major. + StringRef memoryLayoutInfo = cast(blockIOAttr).getValue(); + assert((memoryLayoutInfo == "row_major" || + memoryLayoutInfo == "column_major") && + "Only row_major or column_major is supported"); + + const bool memoryRowMajor = (memoryLayoutInfo == "row_major"); + + // TODO: To support more layouts on memory. + if (!memoryRowMajor) + return failure(); + + auto tensorOfPointers = cast(op.getPtr().getType()); + std::optional encoding = + getDotEncoding(tensorOfPointers); + if (!encoding) + return failure(); + + auto dpasLayout = cast(encoding->getParent()); + SmallVector warpsPerCTA(dpasLayout.getWarpsPerCTA()); + ArrayRef cluster = dpasLayout.getRepCluster(); + SmallVector repCluster{cluster.begin(), cluster.end()}; + ArrayRef tensorShape = tensorOfPointers.getShape(); + DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorOfPointers); + SmallVector repetitions = + dpasLayout.getDPASRepetitions(tensorShape, opIdx); + assert(repetitions.size() == 3 && + "getDPASRepetitions always return rank 3 size"); + SmallVector numReps{repetitions.begin() + 1, repetitions.end()}; + + SmallVector shardTensorShape; + switch (opIdx) { + case DpasEncodingAttr::OpIdx::OperandA: { + shardTensorShape = { + std::min(tensorShape[0], dpasLayout.getShapeA()[0]), + tensorShape[1]}; + warpsPerCTA[1] = 1; + repCluster[1] = 1; + numReps[1] = 1; + } break; + case DpasEncodingAttr::OpIdx::OperandB: { + shardTensorShape = { + tensorShape[0], + std::min(tensorShape[1], dpasLayout.getShapeB()[1])}; + warpsPerCTA[0] = 1; + repCluster[0] = 1; + numReps[0] = 1; + } break; + case DpasEncodingAttr::OpIdx::OperandC: { + llvm_unreachable("unexpected OpIdx::OperandC"); + } break; + } + + auto ptrType = cast(tensorOfPointers.getElementType()); + Type elementType = ptrType.getPointeeType(); + auto tensorType = RankedTensorType::get(shardTensorShape, elementType, + tensorOfPointers.getEncoding()); + + Value mask = op.getMask(); + unsigned maskConstancyHor = std::numeric_limits::max(), + maskConstancyVer = std::numeric_limits::max(); + if (mask) { + // No need to check the constancy of scalar mask. 
+ if (auto maskTy = dyn_cast_or_null<RankedTensorType>(mask.getType())) { + maskConstancyHor = maskConstancyVer = 1; + AxisInfo *axisInfo = + const_cast<ModuleAxisInfoAnalysis &>( + axisAnalysisPass) + .getAxisInfo(mask); + if (axisInfo) { + maskConstancyHor = axisInfo->getConstancy(1); + maskConstancyVer = axisInfo->getConstancy(0); + } + } + } + + SmallVector<unsigned> prefetchShape = + get2DPrefetchShapePerWarp(tensorType); + prefetchShape = {std::min(prefetchShape[0], maskConstancyVer), + std::min(prefetchShape[1], maskConstancyHor)}; + + SmallVector<unsigned> numPrefetchsPerRep = { + mlir::ceil<unsigned>(shardTensorShape[0], prefetchShape[0]), + mlir::ceil<unsigned>(shardTensorShape[1], prefetchShape[1])}; + + Type eltTy = tensorType.getElementType(); + unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); + unsigned tileWidthInElem = prefetchShape[1]; + unsigned tileHeightInElem = prefetchShape[0]; + unsigned vBlocks = 1; + switch (elemSizeInBits) { + case 8: + if (tileWidthInElem == 64) { + // The OCL interface supports 8b_?r32x2c for 64 bytes per row of + // 8-bit elements. + vBlocks = 2; + tileWidthInElem = 32; + } + break; + case 16: + if (tileWidthInElem == 32) { + // The OCL interface supports 16b_?r16x2c for 64 bytes per row of + // 16-bit elements. + vBlocks = 2; + tileWidthInElem = 16; + } + break; + } + + auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + std::map<SmallVector<unsigned>, Value> baseAddrs, masks; + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + + // Get the LLVM values for pointers + SmallVector<Value> ptrElems = unpackLLElements(loc, llPtr, rewriter); + SmallVector<Value> maskElems; + if (llMask) + maskElems = unpackLLElements(loc, llMask, rewriter); + + // Re-arrange the baseAddrs and masks for large 2D block IO. + // Layout is unrelated to the scalar type. + SmallVector<SmallVector<unsigned>> offsets = + emitOffsetForLayout(*encoding, tensorOfPointers); + for (size_t i = 0; i < ptrElems.size(); ++i) { + SmallVector<unsigned> offset = offsets[i]; + baseAddrs[offset] = ptrElems[i]; + if (llMask && maskElems.size() > 1) + masks[offset] = maskElems[i]; + } + + // baseAddrs[{0, 0}] and baseAddrs[{1, 0}] are currently used to calculate + // the pitch. + if (baseAddrs.count({0, 0}) == 0 || baseAddrs.count({1, 0}) == 0) + return failure(); + + Value baseWidth = + b.i32_val(vBlocks * tileWidthInElem * (elemSizeInBits / 8)); + Value baseHeight = b.i32_val(tileHeightInElem); + Value offsetBaseX = b.i32_val(0); + Value offsetBaseY = b.i32_val(0); + Value rowStrideInBytes = b.sub(b.ptrtoint(i64_ty, baseAddrs[{1, 0}]), + b.ptrtoint(i64_ty, baseAddrs[{0, 0}])); + rowStrideInBytes = + targetInfo.shuffleIdx(rewriter, loc, rowStrideInBytes, 0); + rowStrideInBytes = b.umax(b.trunc(i32_ty, rowStrideInBytes), baseWidth); + rowStrideInBytes = b.trunc(i32_ty, rowStrideInBytes); + + for (int row = 0; row < numReps[0]; ++row) { + for (int col = 0; col < numReps[1]; ++col) { + // Prefetch the data for each repetition. + for (int i = 0; i < numPrefetchsPerRep[0]; ++i) + for (int j = 0; j < numPrefetchsPerRep[1]; ++j) { + unsigned offsetN = col * warpsPerCTA[1] * shardTensorShape[1] + + j * prefetchShape[1]; + unsigned offsetM = row * warpsPerCTA[0] * shardTensorShape[0] + + i * prefetchShape[0]; + + Value pred; + if (llMask) + pred = (maskElems.size() > 1) + ? 
targetInfo.shuffleIdx(rewriter, loc, + masks[{offsetM, offsetN}], 0) + : maskElems[0]; + + else + pred = b.int_val(1, 1); + + // If the mask exists and evaluates to false, we set offsetY to be + // equal to baseHeight, which causes the HW to ignore the generated + // prefetch operation (given that the block to be prefetched would + // be outside the baseWidth X baseHeight shape). + Value offsetY = b.select(pred, b.i32_val(0), baseHeight); + Value addr = targetInfo.shuffleIdx( + rewriter, loc, baseAddrs[{offsetM, offsetN}], 0); + + auto newOp = rewriter.create( + loc, + /*ptr*/ addr, + /*base_width*/ baseWidth, + /*base_height*/ baseHeight, + /*base_pitch*/ rowStrideInBytes, + /*x*/ offsetBaseX, + /*y*/ offsetY, + /*elem_size_in_bits*/ elemSizeInBits, + /*tile_width*/ tileWidthInElem, + /*tile_height*/ tileHeightInElem, + /*v_blocks*/ vBlocks, + /*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C); + if (failed(newOp.verify())) { + // Explicitly invoke verifier because `triton_gen` ops are + // immediately lowered further to a builtin call. + return failure(); + } + } + } + } + + rewriter.eraseOp(op); + return success(); + } }; struct LoadOpToBlockIOConversion @@ -693,73 +916,73 @@ struct LoadOpToBlockIOConversion if (isTensorPointerType(ptr.getType())) { // TODO: move the tensor pointer rewrite code here. return failure(); - } else { - Value llPtr = adaptor.getPtr(); - Value llMask = adaptor.getMask(); - Value llOther = adaptor.getOther(); + } - SmallVector ptrElems, maskElems, otherElems; - // Get the LLVM values for pointers - ptrElems = unpackLLElements(loc, llPtr, rewriter); - assert(ptrElems.size() == numElems && - "the number of pointer values is not matched with the number of " - "elements"); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); - // Get the LLVM values for mask - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems && - "the number of mask values is not matched with the number of " - "elements"); - auto axisInfo = const_cast( - axisAnalysisPass) - .getAxisInfo(mask); - if (axisInfo) { - maskConstancyHor = axisInfo->getConstancy(rank - 1); - maskConstancyVer = axisInfo->getConstancy(rank - 2); - } else { - maskConstancyHor = 1; - maskConstancyVer = 1; - } + SmallVector ptrElems, maskElems, otherElems; + // Get the LLVM values for pointers + ptrElems = unpackLLElements(loc, llPtr, rewriter); + assert(ptrElems.size() == numElems && + "the number of pointer values is not matched with the number of " + "elements"); + + // Get the LLVM values for mask + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(maskElems.size() == numElems && + "the number of mask values is not matched with the number of " + "elements"); + auto axisInfo = + const_cast(axisAnalysisPass) + .getAxisInfo(mask); + if (axisInfo) { + maskConstancyHor = axisInfo->getConstancy(rank - 1); + maskConstancyVer = axisInfo->getConstancy(rank - 2); } else { - // no mask - maskConstancyHor = std::numeric_limits::max(); - maskConstancyVer = std::numeric_limits::max(); + maskConstancyHor = 1; + maskConstancyVer = 1; } + } else { + // no mask + maskConstancyHor = std::numeric_limits::max(); + maskConstancyVer = std::numeric_limits::max(); + } - // Check the constancy of the mask support to load the memory in 2D block. 
- if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight)) - return failure(); + // Check that the mask constancy supports loading the memory as a 2D block. + if (!(maskConstancyHor >= instWidth && maskConstancyVer >= instHeight)) + return failure(); - // Get the LLVM values for `other` - DenseElementsAttr constAttr; - if (other && isa<IntegerType>(eltTy) && - matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && - isa<IntegerType>(constAttr.getElementType())) { - otherIsSplatConstInt = true; - splatVal = constAttr.getSplatValue<APInt>().getSExtValue(); - } - if (other) { - otherElems = unpackLLElements(loc, llOther, rewriter); - } + // Get the LLVM values for `other` + DenseElementsAttr constAttr; + if (other && isa<IntegerType>(eltTy) && + matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && + isa<IntegerType>(constAttr.getElementType())) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue<APInt>().getSExtValue(); + } + if (other) { + otherElems = unpackLLElements(loc, llOther, rewriter); + } - // re-arrange the ptrs and masks to for large 2D block IO. - // Layout is unrelated to the scalar type. - SmallVector<SmallVector<unsigned>> offsets = - mlir::emitOffsetForLayout(encoding, tensorType); - for (size_t i = 0; i < ptrElems.size(); ++i) { - SmallVector<unsigned> offset = offsets[i]; - ptrs[offset] = ptrElems[i]; - if (llMask) - masks[offset] = maskElems[i]; - if (otherElems.size()) - others[offset] = otherElems[i]; - } - // ptrs[{0, 0}] and ptrs[{1, 0}] are currently used to calculate the - // pitch. - if (ptrs.count({0, 0}) == 0 || ptrs.count({1, 0}) == 0) - return failure(); + // Re-arrange the ptrs and masks for large 2D block IO. + // Layout is unrelated to the scalar type. + SmallVector<SmallVector<unsigned>> offsets = + mlir::emitOffsetForLayout(encoding, tensorType); + for (size_t i = 0; i < ptrElems.size(); ++i) { + SmallVector<unsigned> offset = offsets[i]; + ptrs[offset] = ptrElems[i]; + if (llMask) + masks[offset] = maskElems[i]; + if (otherElems.size()) + others[offset] = otherElems[i]; + } + // ptrs[{0, 0}] and ptrs[{1, 0}] are currently used to calculate the + // pitch. + if (ptrs.count({0, 0}) == 0 || ptrs.count({1, 0}) == 0) + return failure(); unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN; if (!isTransposeRequired) {