From 607d0f49e4b94418062a3aa9b1368e5fe7f0048d Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 5 Jun 2024 00:19:25 +0000 Subject: [PATCH] [GEN] Extend Matrix2DBlockLoadOp OCL usages Signed-off-by: Whitney Tsang --- .../tritongpu_to_llvm_intel_block_ptr.mlir | 15 +- .../tritongen-2Dblockload-to-llvm.mlir | 271 ++++++++++++++++++ test/TritonGEN/tritongen-to-llvm.mlir | 42 --- test/TritonIntelGPU/load-to-llvm-2dload.mlir | 8 +- test/TritonIntelGPU/prefetch-to-llvm.mlir | 10 +- .../TritonGENToLLVM/TritonGENToLLVMPass.cpp | 38 ++- 6 files changed, 313 insertions(+), 71 deletions(-) create mode 100644 test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir index 9cb6ab39d7..0cf60e4597 100644 --- a/test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir +++ b/test/Conversion/intel/tritongpu_to_llvm_intel_block_ptr.mlir @@ -3,8 +3,8 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xi32>) // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} - // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<32xi32> - // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<64xi16> + // CHECK-DAG: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} + // CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) { @@ -66,13 +66,14 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa %62 = arith.cmpi slt, %40, %c4096_i32 : i32 cf.cond_br %62, ^bb2, ^bb3 ^bb2: - // CHECK: [[A_PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 - // CHECK: [[A:%.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v64i16([[A_PTR]], {{.*}} -> vector<64xi16> + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(%arg0, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[A_PTR:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK: [[A:%.*]] = llvm.load [[A_PTR]] : !llvm.ptr -> vector<64xi16> // CHECK-NEXT: [[castA:%.*]] = llvm.bitcast [[A]] : vector<64xi16> to vector<64xf16> - // CHECK: [[B_PTR:%.*]] = llvm.ptrtoint %arg1 : !llvm.ptr<1> to i64 - // CHECK: [[B0:%.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32([[B_PTR]], {{.*}} -> vector<32xi32> + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK: [[B0:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32> // CHECK-NEXT: [[castB:%.*]] = llvm.bitcast [[B0]] : vector<32xi32> to vector<64xf16> - // CHECK: [[B1:%.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i32({{.*}} -> vector<32xi32> + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[B_PTR:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK: [[B1:%.*]] = llvm.load [[B_PTR]] : !llvm.ptr -> vector<32xi32> // CHECK: [[subA1:%.*]] = llvm.shufflevector [[castA]], [[castA]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<64xf16> // CHECK: [[subB1:%.*]] = llvm.shufflevector [[castB]], [[castB]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<64xf16> // CHECK-NEXT: [[castDotA1:%.*]] = llvm.bitcast [[subA1]] : vector<8xf16> to vector<8xi16> diff --git a/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir b/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir new file mode 100644 index 0000000000..5b855218b4 --- /dev/null +++ b/test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir @@ -0,0 +1,271 @@ +// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s + +// CHECK: llvm.func spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { + // CHECK: [[EIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[DEST:%.*]] = llvm.alloca [[EIGHT]] x i16 : (i32) -> !llvm.ptr + // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> + // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> + // CHECK-NEXT: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, [[COORD1]], [[DEST]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + +// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<16xi16> + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { + // CHECK-DAG: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64 + // CHECK-DAG: [[CST_1:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[CST_8:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: [[CST_16:%.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-DAG: [[CST_32:%.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: [[CST_FALSE_1:%.*]] = llvm.mlir.constant(false) : i1 + // CHECK-DAG: [[CST_FALSE_2:%.*]] = llvm.mlir.constant(false) : i1 + // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[WIDTH:%.*]] = llvm.sub %arg1, [[ONE]] : i32 + // CHECK-DAG: [[HEIGHT:%.*]] = llvm.sub %arg2, [[ONE]] : i32 + // CHECK-DAG: [[PITCH:%.*]] = llvm.sub %arg3, [[ONE]] : i32 + // CHECK-NEXT: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i16([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], %arg4, %arg5, [[CST_8]], [[CST_32]], [[CST_16]], [[CST_1]], [[CST_FALSE_1]], [[CST_FALSE_2]], [[ZERO]]) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<16xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16 + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_16r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32 + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<4xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32 + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_16r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32 + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_32r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_16r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_32r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<64xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<64xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_16r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<64xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<64xi16> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_16r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_32r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x4cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=4, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 9e927b515d..5787545ed8 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -297,48 +297,6 @@ llvm.func @triton_gen.dpas.f32(%c : vector<8xf32>, %a : vector<4xf32>, %b : vect // ----- -// CHECK: llvm.func spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} - -llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { - // CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-NEXT: [[DEST:%.*]] = llvm.alloca [[C16]] x i16 : (i32) -> !llvm.ptr - // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> - // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> - // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> - // CHECK-NEXT: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, [[COORD1]], [[DEST]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () - // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> - llvm.return -} - -// ----- - -// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32> - -llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { - // CHECK-DAG: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr to i64 - // CHECK-DAG: [[CST_32:%.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK-DAG: [[CST_8a:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-DAG: [[CST_8b:%.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-DAG: [[CST_1:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-DAG: [[CST_FALSE_1:%.*]] = llvm.mlir.constant(false) : i1 - // CHECK-DAG: [[CST_FALSE_2:%.*]] = llvm.mlir.constant(false) : i1 - // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-DAG: [[WIDTH:%.*]] = llvm.sub %arg1, [[ONE]] : i32 - // CHECK-DAG: [[HEIGHT:%.*]] = llvm.sub %arg2, [[ONE]] : i32 - // CHECK-DAG: [[PITCH:%.*]] = llvm.sub %arg3, [[ONE]] : i32 - // CHECK-NEXT: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], %arg4, %arg5, [[CST_32]], [[CST_8a]], [[CST_8b]], [[CST_1]], [[CST_FALSE_1]], [[CST_FALSE_2]], [[ZERO]]) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32> - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> - llvm.return -} - -// ----- - // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8f32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32, vector<8xf32>) llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xf32>) { diff --git a/test/TritonIntelGPU/load-to-llvm-2dload.mlir b/test/TritonIntelGPU/load-to-llvm-2dload.mlir index d4c2f0f848..4aaa2ca823 100644 --- a/test/TritonIntelGPU/load-to-llvm-2dload.mlir +++ b/test/TritonIntelGPU/load-to-llvm-2dload.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} -// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32> -// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16> +// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} +// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}> #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], A = [8, 16], B = [16, 16], C = [8, 16]}> #dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}> @@ -14,8 +14,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war %c1_i64 = arith.constant 1 : i64 %ptrA = tt.make_tensor_ptr %arg0, [%arg2, %arg4], [%arg5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > %ptrB = tt.make_tensor_ptr %arg1, [%arg4, %arg3], [%arg7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > - // CHECK-COUNT-4: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16{{.*}} -> vector<8xi16> - // CHECK-COUNT-4: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32{{.*}} -> vector<8xi32> + // CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-COUNT-4: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () // CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> %A = tt.load %ptrA {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> %B = tt.load %ptrB {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> diff --git a/test/TritonIntelGPU/prefetch-to-llvm.mlir b/test/TritonIntelGPU/prefetch-to-llvm.mlir index 445296a452..da21907381 100644 --- a/test/TritonIntelGPU/prefetch-to-llvm.mlir +++ b/test/TritonIntelGPU/prefetch-to-llvm.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]} -// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi32> -// CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xi16> +// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} +// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]} // CHECK-DAG: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}> #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], A = [8, 16], B = [16, 16], C = [8, 16]}> @@ -12,9 +12,9 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war tt.func public @matmul_with_prefetch(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) { // CHECK-LABEL: @matmul_with_prefetch // CHECK-COUNT-2: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid{{.*}} -> () - // CHECK-COUNT-1: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i16{{.*}} -> vector<8xi16> - // CHECK-COUNT-1: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i32{{.*}} -> vector<8xi32> - // CHECK-COUNT-1: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> %C = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas> %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 1c0e43ad93..0d59557651 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -251,16 +251,23 @@ static LLVM::CallOp createGenISADPAS(TritonGEN::MatrixDPASOp op, } static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) { - if (op.getVnniTransform() || op.getTranspose()) + // intel_sub_group_2d_block_read_32b_8r8x1c is expected to be lowered to + // llvm.genx.GenISA.LSC2DBlockRead.v4i32, but it is incorrectly lowered to + // llvm.genx.GenISA.LSC2DBlockRead.v8i32. + if (op.getElemSizeInBits() == 32 && op.getTileHeight() == 8 && + op.getTileWidth() == 8 && op.getVBlocks() == 1) return false; - if (op.getElemSizeInBits() == 32) + // Missing intel_sub_group_2d_block_read_32b_8r16x1c and + // intel_sub_group_2d_block_read_32b_16r16x1c. + if (op.getElemSizeInBits() == 32 && op.getTileWidth() == 16 && + op.getVBlocks() == 1) return false; - if (op.getTileHeight() > 8) - return false; - - if (op.getVBlocks() != 2) + // Missing intel_sub_group_2d_block_read_8b_16r32x1c and + // intel_sub_group_2d_block_read_8b_32r32x1c. + if (op.getElemSizeInBits() == 8 && op.getTileHeight() > 8 && + op.getTileWidth() == 32 && op.getVBlocks() == 1) return false; if (op.getCacheControl() != TritonGEN::LoadCacheControl::DEFAULT) @@ -280,13 +287,18 @@ static Value createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, auto dest = rewriter.create( loc, ptr_ty(context), resType.getElementType(), i32_val(resType.getNumElements())); - std::string fnName = "intel_sub_group_2d_block_read_" + - std::to_string(op.getElemSizeInBits()) + "b_" + - std::to_string(op.getTileHeight()) + "r" + - std::to_string(op.getTileWidth()) + "x" + - std::to_string(op.getVBlocks()) + "c"; - fnName = - "_Z" + std::to_string(fnName.size()) + fnName + "PU3AS1viiiDv2_iPt"; + std::string fnName = "intel_sub_group_2d_block_read_"; + if (op.getVnniTransform()) + fnName += "transform_"; + else if (op.getTranspose()) + fnName += "transpose_"; + fnName += std::to_string(op.getElemSizeInBits()) + "b_" + + std::to_string(op.getTileHeight()) + "r" + + std::to_string(op.getTileWidth()) + "x" + + std::to_string(op.getVBlocks()) + "c"; + fnName = "_Z" + std::to_string(fnName.size()) + fnName + "PU3AS1viiiDv2_iP"; + fnName += + (resType.getElementType().getIntOrFloatBitWidth() == 32) ? "j" : "t"; VectorType vecType = vec_ty(i32_ty, 2); Value byteCoord = insert_element( vecType, insert_element(vecType, undef(vecType), op.getX(), i32_val(0)),