diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 085f879c2d0e6..ae23a4158155f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2364,11 +2364,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } +/// Replace `vector.extract` with `vector.shape_cast`. +/// +/// BEFORE: +/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// +/// The canonical form of vector operations that reshape vectors is shape_cast. +struct ExtractToShapeCast final : OpRewritePattern { + using Base::Base; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + VectorType sourceType = extractOp.getSourceVectorType(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!outType) + return failure(); + + if (sourceType.getNumElements() != outType.getNumElements()) + return rewriter.notifyMatchFailure( + extractOp, "extract to vector with fewer elements"); + + // Negative values in `position` means that the extacted value is poison. + // There is a vector.extract folder for this. + if (llvm::any_of(extractOp.getMixedPosition(), + [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) + return rewriter.notifyMatchFailure(extractOp, + "leaving for extract poison folder"); + + rewriter.replaceOpWithNewOp(extractOp, outType, + extractOp.getSource()); + + return success(); + } +}; + } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); } @@ -3081,13 +3118,43 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; + +/// Replace `vector.broadcast` with `vector.shape_cast`. +/// +/// BEFORE: +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// AFTER: +/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// +/// The canonical form of vector operations that reshape vectors is shape_cast. +struct BroadcastToShapeCast final + : public OpRewritePattern { + using Base::Base; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, + PatternRewriter &rewriter) const override { + + auto sourceType = dyn_cast(broadcast.getSourceType()); + if (!sourceType) { + return rewriter.notifyMatchFailure( + broadcast, "source is a scalar, shape_cast doesn't support scalar"); + } + + VectorType outType = broadcast.getType(); + if (sourceType.getNumElements() != outType.getNumElements()) { + return rewriter.notifyMatchFailure( + broadcast, "broadcast to a greater number of elements"); + } + + rewriter.replaceOpWithNewOp(broadcast, outType, + broadcast.getSource()); + return success(); + } +}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - // BroadcastToShapeCast is not a default canonicalization, it is opt-in by - // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// @@ -6552,10 +6619,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final } }; -/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either -/// i) Y = ShapeCast(X), or -/// ii) Y = Broadcast(X) -/// If both (i) and (ii) are possible, (i) is chosen. +/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X) class ShapeCastBroadcastFolder final : public OpRewritePattern { public: using Base::Base; @@ -6570,22 +6634,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { auto srcVectorType = dyn_cast(broadcastOp.getSourceType()); bool srcIsScalar = !srcVectorType; - // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). - // Example: - // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> - // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> - // to - // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> - if (srcVectorType) { - if (srcVectorType.getNumElements() == - shapeCastOp.getResultVectorType().getNumElements()) { - rewriter.replaceOpWithNewOp( - shapeCastOp, shapeCastOp.getResultVectorType(), - broadcastOp.getSource()); - return success(); - } - } - // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X) // Example // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32> diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir index 4bb40bef9fba2..4660cc75a1940 100644 --- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir @@ -24,7 +24,7 @@ func.func @conv(%arg0: tensor<1x1080x1962x48xi32>, %arg1: tensor<1x43x48xi32>) - // Loop over the Filter width dim // CHECK: scf.for %{{.*}} = %[[C0]] to %[[C_43]] step %[[C1]] {{.*}} -> (tensor<1x1x4x?xi32>) { // CHECK-NOT: vector.mask -// CHECK: vector.broadcast {{.*}} : vector<[4]xi32> to vector<1x4x[4]xi32> +// CHECK: vector.broadcast {{.*}} : vector<1x[4]xi32> to vector<1x4x[4]xi32> // CHECK-NEXT: arith.muli {{.*}} : vector<1x4x[4]xi32> // CHECK-NEXT: arith.addi {{.*}} : vector<1x4x[4]xi32> // CHECK-NOT: vector.mask diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a30eda1e06cf8..8126389212ce6 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector, // CHECK-LABEL: negative_fold_extract_broadcast // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> -// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32> +// CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32> func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> @@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // CHECK-LABEL: fold_extract_broadcastlike_shape_cast // CHECK-SAME: %[[A:.*]]: vector<1xf32> -// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32> +// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32> // CHECK: return %[[R]] : vector<1x1xf32> func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index) -> vector<1x1xf32> { @@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> { // ----- -// Check the case where the same dimension is both broadcasted and sliced +// Check the case where the same dimension is both broadcasted and sliced // CHECK-LABEL: func @extract_strided_broadcast5 // CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>) // CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32> @@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>, // CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>) // CHECK-NOT: vector.transfer_write // CHECK-NOT: vector.transfer_read -// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32> +// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32> // CHECK: return %[[RET]] func.func @store_to_load_tensor_forwarding_unit_dim_broadcast( %vec: vector<4x8xf32>, @@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> { // CHECK-LABEL: func @insert_extract_to_broadcast // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) -// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32> -// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> // CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { @@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { - // CHECK: vector.broadcast %{{.*}} : vector to vector<1xi32> + // CHECK: vector.shape_cast %{{.*}} : vector to vector<1xi32> %shuffle = vector.shuffle %v0, %v1 [0] : vector, vector return %shuffle : vector<1xi32> } @@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> - // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32> + // CHECK-NEXT: %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32> // CHECK-NEXT: return %0 : vector<1xf32> %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> return %1: vector<1xf32> @@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> { // +--------------------------------------------------------------------------- // CHECK-LABEL: transpose_from_elements_1d -// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 +// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> { %v = vector.from_elements %el_0, %el_1 : vector<2xi32> %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32> @@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> { } // CHECK-LABEL: transpose_from_elements_2d -// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 +// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 func.func @transpose_from_elements_2d( %el_0_0: i32, %el_0_1: i32, %el_0_2: i32, %el_1_0: i32, %el_1_1: i32, %el_1_2: i32 @@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d( } // CHECK-LABEL: transpose_from_elements_3d -// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 +// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 func.func @transpose_from_elements_3d( %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32, %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32 diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index f43328f621787..aa6539e466c95 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( // CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8> +// CHECK: %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> // CHECK: return %[[EXTRACT]] : vector<2xi8> func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir new file mode 100644 index 0000000000000..1d082179207dc --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir @@ -0,0 +1,130 @@ +// RUN: mlir-opt %s --split-input-file --canonicalize | FileCheck %s + +// This file contains tests where a vector.shape_cast is the result +// of canonicalization. + +// **--------------------------------------------------------** // +// Tests of BroadcastToShapeCast +// **--------------------------------------------------------** // + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// broadcast can only be transformed to a shape_cast if the number of elements is +// unchanged by the broadcast. +// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar +// cannot be transformed to a shape_cast. +// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { + %0 = vector.broadcast %arg0 : i8 to vector<1xi8> + return %0 : vector<1xi8> +} + +// ----- + +// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast +// CHECK-NOT: vector.broadcast +// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32> +func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32> + %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32> + return %1 : vector<1x2x1xf32> +} + +// ----- + +// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible +// CHECK-NOT: vector.broadcast +// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32> +func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32> + %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32> + return %1 : vector<1x1xf32> +} + +// ----- + +// **--------------------------------------------------------** // +// Tests of ExtractToShapeCast +// **--------------------------------------------------------** // + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. We could +// convert this to shape_cast (would be a legal transform with poison) +// but we conservatively choose not to. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: fold_extract_shapecast_to_shapecast +// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>) +// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32> +// CHECK: return %[[R]] +func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> { + %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32> + %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32> + return %r : vector<12xf32> +} + +// ----- + +// CHECK-LABEL: func @insert_extract_to_shape_cast +// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) +// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> +func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, + %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { + %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> + %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> + return %0, %1 : vector<4xf32>, vector<1x1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @extract_from_broadcast +func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { + %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> + // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> + // CHECK-NEXT: return %[[RES]] : vector<1xf32> + %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> + return %1: vector<1xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 45afbffc1be48..8206c1a3ec865 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor) -> vector<1xf32> { %f0 = arith.constant 0.0 : f32 // CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][] -// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector to vector<1xf32> +// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector to vector<1xf32> %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} : tensor, vector<1xf32> @@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor( %c0 = arith.constant 0 : index %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32> + // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32> // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor return %res : tensor @@ -385,8 +384,7 @@ func.func @transfer_write_broadcast_unit_dim_memref( %c0 = arith.constant 0 : index vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32> + // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32> // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref return diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 135db02d543ef..2d0330043db06 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1534,7 +1534,7 @@ func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1x // CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x // CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]] // CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}] -// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32> +// CHECK-PROP: %[[EXTRACT:.*]] = vector.shape_cast %[[GATHER]] : vector<1x64xi32> to vector<64xi32> // CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex> // CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex> // CHECK-PROP: gpu.yield %[[EXTRACTELT]] : index @@ -1571,7 +1571,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1 // CHECK-PROP-LABEL: func @dont_fold_vector_broadcast( // CHECK-PROP: %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>) // CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32> +// CHECK-PROP: %[[broadcast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32> // CHECK-PROP: gpu.yield %[[broadcast]] : vector<1x64xf32> // CHECK-PROP: vector.print %[[r]] : vector<1x2xf32> func.func @dont_fold_vector_broadcast(%laneid: index) {