diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index b467114c72f7d..363a7c1a1a557 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1935,11 +1935,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, unpackOp.getDestType().hasStaticShape() ? vectorSizes : shapeCastOp.getResultVectorType().getShape()); - Value dest = rewriter.create( - loc, reifiedRetShapes[0], - shapeCastOp.getResult().getType().getElementType()); Operation *write = createWriteOrMaskedWrite( - rewriter, loc, shapeCastOp.getResult(), dest, + rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); newResults.push_back(write->getResult(0)); return success(); diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 6722de817f6bf..11c86f1c31406 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -1158,6 +1158,7 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack +// CHECK-SAME: %[[ARG_0:.*]]: tensor, func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[C0:.*]] = arith.constant 0 // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor @@ -1175,9 +1176,8 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor, %arg1: t // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> -// CHECK: %[[empt0:.*]] = tensor.empty // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1> -// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]] +// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]] // CHECK: return %[[write0]] %ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor -> tensor return %ret : tensor @@ -1193,6 +1193,8 @@ module attributes {transform.with_named_sequence} { // ----- // CHECK-LABEL: func @test_vectorize_unpack +// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32> func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]]= arith.constant 0 : index @@ -1201,15 +1203,14 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2 // CHECK: %[[C32:.*]] = arith.constant 32 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1> - // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32> + // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32> // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32> // CHECK: %[[C01:.*]] = arith.constant 0 : index // CHECK: %[[C256:.*]] = arith.constant 256 : index // CHECK: %[[C128:.*]] = arith.constant 128 : index // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1> - // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32> + // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] { vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<512x128xi1> -> tensor<256x128xf32> // CHECK: return %[[WRIT]] : tensor<256x128xf32> %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> @@ -1225,15 +1226,16 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2 // ----- // CHECK-LABEL: func @test_vectorize_unpack_no_masks +// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32> func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> + // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32> // CHECK: %[[C00:.*]] = arith.constant 0 : index - // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32> + // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> // CHECK: return %[[WRIT]] : tensor<256x128xf32> %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> @@ -1248,16 +1250,17 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: // ----- - // CHECK-LABEL: test_vectorize_unpack_with_outer_perm +// CHECK-LABEL: test_vectorize_unpack_with_outer_perm +// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32> func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> + // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32> // CHECK: %[[C00:.*]] = arith.constant 0 : index - // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32> + // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> // CHECK: return %[[WRIT]] : tensor<256x128xf32> %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> @@ -1327,15 +1330,17 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes +// CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32> func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> + // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32> // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32> // CHECK: %[[C00:.*]] = arith.constant 0 : index - // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32> + // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<256x128xf32>, tensor<256x128xf32> // CHECK: return %[[WRIT]] : tensor<256x128xf32> %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> @@ -1350,15 +1355,17 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, // ----- +// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output +// CHECK-SAME: %[[SRC:.*]]: tensor<8x4x16x16xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<64x127xf32> func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> { // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32> + // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32> // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32> // CHECK: %[[C00:.*]] = arith.constant 0 : index - // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]] + // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]] // CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32> // CHECK: return %[[WRIT]] : tensor<64x127xf32> %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32> @@ -1374,18 +1381,20 @@ func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x // ----- +// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_permute +// CHECK-SAME: %[[SRC:.*]]: tensor<4x7x4xf32> +// CHECK-SAME: %[[DEST:.*]]: tensor<7x16xf32> func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> { %0 = linalg.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32> return %0 : tensor<7x16xf32> } // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32> + // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<4x7x4xf32>, vector<4x7x4xf32> // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32> // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32> - // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32> // CHECK: %[[C00:.*]] = arith.constant 0 : index - // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32> + // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}}} : vector<7x16xf32>, tensor<7x16xf32> // CHECK: return %[[WRIT]] : tensor<7x16xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {