From 0226e9b72cdfac21bc0758144ff1ecd14e8bc63e Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 16:05:14 +0000 Subject: [PATCH 1/7] [GEN] Enhance matrix operators verification Signed-off-by: Whitney Tsang --- .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 93 ++++++++++++++----- 1 file changed, 69 insertions(+), 24 deletions(-) diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index d9d2009a1b..45f29ecc4d 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -22,12 +22,18 @@ using namespace mlir::triton; // Utility functions //===----------------------------------------------------------------------===// -template static LogicalResult verifyInput(Op op) { +template static LogicalResult verifyMatrixInput(Op op) { static_assert(llvm::is_one_of::value, "Unexpected template parameter"); + std::optional width = getConstantIntValue(op.getBaseWidth()); + std::optional pitch = getConstantIntValue(op.getBasePitch()); + if (pitch && width && *pitch < *width) + return op->emitOpError( + "4th operand (base pitch) should be >= 2nd operand (base width)"); + if (op.getElemSizeInBits() != 8 && op.getElemSizeInBits() != 16 && op.getElemSizeInBits() != 32) return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32"); @@ -36,43 +42,61 @@ template static LogicalResult verifyInput(Op op) { return op->emitOpError( "transpose and vnni transform are mutually exclusive"); - std::optional width = getConstantIntValue(op.getBaseWidth()); - std::optional pitch = getConstantIntValue(op.getBasePitch()); - if (pitch && width && *pitch < *width) - return op->emitOpError( - "4th operand (base pitch) should be >= 2nd operand (base width)"); + if (op.getTranspose() && op.getElemSizeInBits() != 32) + return op->emitOpError("transpose is only supported for 32 bit elements"); + + if (op.getVnniTransform() && op.getElemSizeInBits() == 32) + return op->emitOpError("vnni transform is only supported for 8 and 16 bit " + "elements"); - uint32_t TileHeight = op.getTileHeight(); - if (TileHeight != 1 && TileHeight != 2 && TileHeight != 4 && - TileHeight != 8 && TileHeight != 16 && TileHeight != 32) + uint32_t tileHeight = op.getTileHeight(); + if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 && + tileHeight != 8 && tileHeight != 16 && tileHeight != 32) return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32"); - uint32_t TileWidth = op.getTileWidth(); + uint32_t vBlocks = op.getVBlocks(); + if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8) + return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8"); + + uint32_t tileWidth = op.getTileWidth(); switch (op.getElemSizeInBits()) { - case 32: - if (TileWidth != 8 && TileWidth != 16) - return op->emitOpError( - "tile_width for 32 bit elements should be equal " - "to systolic depth, i.e., 8 elements, for matrix A or subgroup size, " - "i.e., 16 elements, for matrix B"); - break; case 16: - if (TileWidth != 16) + if (tileWidth != 16) return op->emitOpError("tile_width for 16 bit elements should be equal " "to systolic depth times 2, i.e., 16 elements"); break; case 8: - if (TileWidth != 32) + if (tileWidth != 32) return op->emitOpError("tile_width for 8 bit elements should be equal " "to systolic depth times 4, i.e., 32 elements"); break; - default: - return op->emitOpError("element size should be 8, 16 or 32 bits"); } return success(); } +template static LogicalResult verifyMatrixReadInput(Op op) { + static_assert(llvm::is_one_of::value, + "Unexpected template parameter"); + + uint32_t tileWidth = op.getTileWidth(); + if (op.getVnniTransform()) { + if (tileWidth != 16) + return op->emitOpError("tile_width for vnni transform should be equal " + "to subgroup size, i.e., 16 elements"); + return success(); + } + + // When reading matrix B of 32 bit elements, it does not need to be vnni transformed. + if (op.getElemSizeInBits() == 32 && tileWidth != 8 && tileWidth != 16) + return op->emitOpError("tile_width for 32 bit elements should be equal " + "to systolic depth, i.e., 8 elements, for matrix A or " + "subgroup size, i.e., 16 elements, for matrix B"); + + return success(); +} + //===----------------------------------------------------------------------===// // gen.sub_group_reduce //===----------------------------------------------------------------------===// @@ -173,7 +197,18 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { //===----------------------------------------------------------------------===// LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() { - return verifyInput(*this); + if (verifyMatrixInput(*this).failed()) + return failure(); + + VectorType resTy = getRes().getType(); + unsigned resSize = + resTy.getNumElements() * resTy.getElementType().getIntOrFloatBitWidth(); + unsigned subgroupSize = TritonGEN::getSubgroupSize(*this); + if (resSize * subgroupSize != + getElemSizeInBits() * getTileHeight() * getTileWidth() * getVBlocks()) + return emitOpError("result size does not match the expected size"); + + return verifyMatrixReadInput(*this); } //===----------------------------------------------------------------------===// @@ -181,7 +216,14 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() { //===----------------------------------------------------------------------===// LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() { - return verifyInput(*this); + if (verifyMatrixInput(*this).failed()) + return failure(); + + if (getElemSizeInBits() == 32 && getTileWidth() != 8) + return emitOpError("tile_width for 32 bit elements should be equal " + "to systolic depth, i.e., 8 elements"); + + return success(); } //===----------------------------------------------------------------------===// @@ -189,5 +231,8 @@ LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() { //===----------------------------------------------------------------------===// LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() { - return verifyInput(*this); + if (verifyMatrixInput(*this).failed()) + return failure(); + + return verifyMatrixReadInput(*this); } From 7f34699cba4fa492a94770f83c8416bbe9d02cf2 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 16:57:31 +0000 Subject: [PATCH 2/7] [GEN] Enhance matrix operators verification Signed-off-by: Whitney Tsang --- .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 45f29ecc4d..410ce81c39 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -40,13 +40,13 @@ template static LogicalResult verifyMatrixInput(Op op) { if (op.getTranspose() && op.getVnniTransform()) return op->emitOpError( - "transpose and vnni transform are mutually exclusive"); + "transpose and vnni_transform are mutually exclusive"); if (op.getTranspose() && op.getElemSizeInBits() != 32) return op->emitOpError("transpose is only supported for 32 bit elements"); if (op.getVnniTransform() && op.getElemSizeInBits() == 32) - return op->emitOpError("vnni transform is only supported for 8 and 16 bit " + return op->emitOpError("vnni_transform is only supported for 8 and 16 bit " "elements"); uint32_t tileHeight = op.getTileHeight(); @@ -59,6 +59,14 @@ template static LogicalResult verifyMatrixInput(Op op) { return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8"); uint32_t tileWidth = op.getTileWidth(); + if (op.getVnniTransform()) { + if (tileWidth != 16) + return op->emitOpError( + "tile_width when vnni_transform is true should be equal " + "to subgroup size (16 elements)"); + return success(); + } + switch (op.getElemSizeInBits()) { case 16: if (tileWidth != 16) @@ -81,18 +89,10 @@ template static LogicalResult verifyMatrixReadInput(Op op) { "Unexpected template parameter"); uint32_t tileWidth = op.getTileWidth(); - if (op.getVnniTransform()) { - if (tileWidth != 16) - return op->emitOpError("tile_width for vnni transform should be equal " - "to subgroup size, i.e., 16 elements"); - return success(); - } - - // When reading matrix B of 32 bit elements, it does not need to be vnni transformed. if (op.getElemSizeInBits() == 32 && tileWidth != 8 && tileWidth != 16) - return op->emitOpError("tile_width for 32 bit elements should be equal " - "to systolic depth, i.e., 8 elements, for matrix A or " - "subgroup size, i.e., 16 elements, for matrix B"); + return op->emitOpError("tile_width for 32 bit elements should be equal to " + "systolic depth (8 elements) for matrix A and the " + "subgroup size (16 elements) for matrix B"); return success(); } @@ -203,10 +203,13 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() { VectorType resTy = getRes().getType(); unsigned resSize = resTy.getNumElements() * resTy.getElementType().getIntOrFloatBitWidth(); - unsigned subgroupSize = TritonGEN::getSubgroupSize(*this); - if (resSize * subgroupSize != - getElemSizeInBits() * getTileHeight() * getTileWidth() * getVBlocks()) - return emitOpError("result size does not match the expected size"); + constexpr unsigned subgroupSize = 16; + unsigned expectedSize = getElemSizeInBits() * getTileHeight() * + getTileWidth() * getVBlocks() / subgroupSize; + if (resSize != expectedSize) + return emitOpError() << "result size of " << resSize + << " bits does not match the expected size of " + << expectedSize << " bits"; return verifyMatrixReadInput(*this); } From 9e29c48b65d3f72d8058f20f6e3783f7c6a2949f Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 19:09:36 +0000 Subject: [PATCH 3/7] fix existing tests Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 26 +++++++++---------- test/TritonGEN/tritongen-to-llvm.mlir | 6 ++--- .../lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 11 ++++++-- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 2e682fb7c9..a0962362b8 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -179,15 +179,15 @@ llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8 llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi16> llvm.return } // ----- llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockload' op transpose and vnni transform are mutually exclusive}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + // expected-error @+1 {{'triton_gen.2Dblockload' op transpose and vnni_transform are mutually exclusive}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=2, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi32> llvm.return } @@ -197,15 +197,15 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y %base_width = llvm.mlir.constant(4 : i32) : i32 %base_pitch = llvm.mlir.constant(2 : i32) : i32 // expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be >= 2nd operand (base width)}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=2, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi32> llvm.return } // ----- llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements, for matrix A or subgroup size, i.e., 16 elements, for matrix B}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=5, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<5xf32> + // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi32> llvm.return } @@ -213,7 +213,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 16 bit elements should be equal to systolic depth times 2, i.e., 16 elements}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xf16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi16> llvm.return } @@ -221,7 +221,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 8 bit elements should be equal to systolic depth times 4, i.e., 32 elements}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi8> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=2, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi16> llvm.return } @@ -229,7 +229,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xf32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi32> llvm.return } @@ -244,7 +244,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height // ----- llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { - // expected-error @+1 {{'triton_gen.2Dblockstore' op transpose and vnni transform are mutually exclusive}} + // expected-error @+1 {{'triton_gen.2Dblockstore' op transpose and vnni_transform are mutually exclusive}} triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) llvm.return } @@ -262,7 +262,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, % // ----- llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xf32>) { - // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements, for matrix A or subgroup size, i.e., 16 elements, for matrix B}} + // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements}} triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xf32>) llvm.return } @@ -302,7 +302,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei // ----- llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op transpose and vnni transform are mutually exclusive}} + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op transpose and vnni_transform are mutually exclusive}} triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } @@ -320,7 +320,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %x : i32 // ----- llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements, for matrix A or subgroup size, i.e., 16 elements, for matrix B}} + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=5, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir index 51d1f6278b..151aaada2d 100644 --- a/test/TritonGEN/tritongen-to-llvm.mlir +++ b/test/TritonGEN/tritongen-to-llvm.mlir @@ -313,7 +313,7 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_ // ----- -// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8f32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xf32> +// CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32(i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32> llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { @@ -329,8 +329,8 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_hei // CHECK-DAG: [[WIDTH:%.*]] = llvm.sub %arg1, [[ONE]] : i32 // CHECK-DAG: [[HEIGHT:%.*]] = llvm.sub %arg2, [[ONE]] : i32 // CHECK-DAG: [[PITCH:%.*]] = llvm.sub %arg3, [[ONE]] : i32 - // CHECK-NEXT: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8f32([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], %arg4, %arg5, [[CST_32]], [[CST_8a]], [[CST_8b]], [[CST_1]], [[CST_FALSE_1]], [[CST_FALSE_2]], [[ZERO]]) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<8xf32> - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xf32> + // CHECK-NEXT: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32([[PTR]], [[WIDTH]], [[HEIGHT]], [[PITCH]], %arg4, %arg5, [[CST_32]], [[CST_8a]], [[CST_8b]], [[CST_1]], [[CST_FALSE_1]], [[CST_FALSE_2]], [[ZERO]]) : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> vector<4xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> llvm.return } diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index 410ce81c39..f7eac7d3f1 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -201,8 +201,15 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() { return failure(); VectorType resTy = getRes().getType(); - unsigned resSize = - resTy.getNumElements() * resTy.getElementType().getIntOrFloatBitWidth(); + unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth(); + if (getElemSizeInBits() == 32 || getVnniTransform()) { + if (resElemTySize != 32) + return emitOpError() << "expecting result element type to be 32 bits"; + } else if (resElemTySize != 16) { + return emitOpError() << "expecting result element type to be 16 bits"; + } + + unsigned resSize = resTy.getNumElements() * resElemTySize; constexpr unsigned subgroupSize = 16; unsigned expectedSize = getElemSizeInBits() * getTileHeight() * getTileWidth() * getVBlocks() / subgroupSize; From f272d0479a87e45e9428fb847a6b0b4c27d9653a Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 19:39:11 +0000 Subject: [PATCH 4/7] Add new 2dblockload invalid tests Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 82 ++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 13 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index a0962362b8..7f884c6e1b 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -177,9 +177,19 @@ llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8 // ----- +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) { + %base_width = llvm.mlir.constant(4 : i32) : i32 + %base_pitch = llvm.mlir.constant(2 : i32) : i32 + // expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be >= 2nd operand (base width)}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return } @@ -187,25 +197,47 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op transpose and vnni_transform are mutually exclusive}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=2, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi32> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=16, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return } // ----- -llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) { - %base_width = llvm.mlir.constant(4 : i32) : i32 - %base_pitch = llvm.mlir.constant(2 : i32) : i32 - // expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be >= 2nd operand (base width)}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=2, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi32> +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op transpose is only supported for 32 bit elements}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return } // ----- llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi32> + // expected-error @+1 {{'triton_gen.2Dblockload' op vnni_transform is only supported for 8 and 16 bit elements}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<64xi16> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op expecting v_blocks to be 1, 2, 4, or 8}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=6, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<48xi16> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width when vnni_transform is true should be equal to subgroup size (16 elements)}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return } @@ -213,7 +245,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 16 bit elements should be equal to systolic depth times 2, i.e., 16 elements}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi16> llvm.return } @@ -221,15 +253,39 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 8 bit elements should be equal to systolic depth times 4, i.e., 32 elements}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=2, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi16> llvm.return } // ----- llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi32> + // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op expecting result element type to be 32 bits}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op expecting result element type to be 16 bits}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi8> + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockload' op result size of 256 bits does not match the expected size of 128 bits}} + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi16> llvm.return } From 64173aeb081f5ad07713b35ada0da524eade0d3f Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 20:02:47 +0000 Subject: [PATCH 5/7] Add new 2dblockprefetch invalid tests Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 58 +++++++++++++++++++++------ 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 7f884c6e1b..29c8755de9 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -349,9 +349,19 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height // ----- +llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) { + %base_width = llvm.mlir.constant(4 : i32) : i32 + %base_pitch = llvm.mlir.constant(2 : i32) : i32 + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be >= 2nd operand (base width)}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } @@ -359,25 +369,47 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockprefetch' op transpose and vnni_transform are mutually exclusive}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=16, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } // ----- -llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) { - %base_width = llvm.mlir.constant(4 : i32) : i32 - %base_pitch = llvm.mlir.constant(2 : i32) : i32 - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be >= 2nd operand (base width)}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) +llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op transpose is only supported for 32 bit elements}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } // ----- llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=5, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op vnni_transform is only supported for 8 and 16 bit elements}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting v_blocks to be 1, 2, 4, or 8}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=6, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width when vnni_transform is true should be equal to subgroup size (16 elements)}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } @@ -385,7 +417,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 16 bit elements should be equal to systolic depth times 2, i.e., 16 elements}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } @@ -393,14 +425,14 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 8 bit elements should be equal to systolic depth times 4, i.e., 32 elements}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } // ----- llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} - triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } From 4c14bc0783d660ca38bb910327bdf09177454add Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Fri, 7 Jun 2024 20:14:29 +0000 Subject: [PATCH 6/7] Add new 2dblockstore invalid tests Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 62 +++++++++++++++++++-------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 29c8755de9..91b307ed99 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -291,59 +291,83 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { + %base_width = llvm.mlir.constant(4 : i32) : i32 + %base_pitch = llvm.mlir.constant(2 : i32) : i32 + // expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be >= 2nd operand (base width)}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<32xi16>) { // expected-error @+1 {{'triton_gen.2Dblockstore' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=64, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<32xi16>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<16xi8>) { // expected-error @+1 {{'triton_gen.2Dblockstore' op transpose and vnni_transform are mutually exclusive}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=16, tile_height=16, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xi8>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<4xi32>) { - %base_width = llvm.mlir.constant(4 : i32) : i32 - %base_pitch = llvm.mlir.constant(2 : i32) : i32 - // expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be >= 2nd operand (base width)}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi32>) +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.2Dblockstore' op vnni_transform is only supported for 8 and 16 bit elements}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi32>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xf32>) { - // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xf32>) +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<64xi8>) { + // expected-error @+1 {{'triton_gen.2Dblockstore' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<64xi8>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xf16>) { +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { + // expected-error @+1 {{'triton_gen.2Dblockstore' op expecting v_blocks to be 1, 2, 4, or 8}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=6, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { + // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width when vnni_transform is true should be equal to subgroup size (16 elements)}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>) + llvm.return +} + +// ----- + +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 16 bit elements should be equal to systolic depth times 2, i.e., 16 elements}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xf16>) + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi16>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi8>) { +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 8 bit elements should be equal to systolic depth times 4, i.e., 32 elements}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=4, tile_height=1, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi8>) + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>) llvm.return } // ----- -llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xf32>) { - // expected-error @+1 {{'triton_gen.2Dblockstore' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}} - triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=8, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xf32>) +llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) { + // expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 32 bit elements should be equal to systolic depth, i.e., 8 elements}} + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi32>) llvm.return } From 23638f444f8c07637ddedb3862e72e4794ff9c6c Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Sat, 8 Jun 2024 01:13:00 +0000 Subject: [PATCH 7/7] address review comments Signed-off-by: Whitney Tsang --- test/TritonGEN/tritongen-invalid.mlir | 6 +++--- third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir index 91b307ed99..4c271f3593 100644 --- a/test/TritonGEN/tritongen-invalid.mlir +++ b/test/TritonGEN/tritongen-invalid.mlir @@ -189,7 +189,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { // expected-error @+1 {{'triton_gen.2Dblockload' op expecting 'elem_size_in_bits' to be 8, 16, or 32}} - %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16> llvm.return } @@ -260,7 +260,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height // ----- llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} + // expected-error @+1 {{'triton_gen.2Dblockload' op tile_width for 32 bit elements should be equal to either be 8 or 16}} %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi32> llvm.return } @@ -456,7 +456,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei // ----- llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { - // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to systolic depth (8 elements) for matrix A and the subgroup size (16 elements) for matrix B}} + // expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to either be 8 or 16}} triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) llvm.return } diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index f7eac7d3f1..31d89bb263 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -90,9 +90,8 @@ template static LogicalResult verifyMatrixReadInput(Op op) { uint32_t tileWidth = op.getTileWidth(); if (op.getElemSizeInBits() == 32 && tileWidth != 8 && tileWidth != 16) - return op->emitOpError("tile_width for 32 bit elements should be equal to " - "systolic depth (8 elements) for matrix A and the " - "subgroup size (16 elements) for matrix B"); + return op->emitOpError( + "tile_width for 32 bit elements should be equal to either be 8 or 16"); return success(); }