[mlir][vector] Canonicalize vector.extract and vector.broadcast to vector.shape_cast #174452
Signed-off-by: James Newling <[email protected]>
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-vector
Author: James Newling (newling)
Changes
Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.
This PR canonicalizes
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
to shape_cast.
It was decided (see #140583) that the vector.transpose -> vector.shape_cast canonicalization needs further consideration before being added.
Patch is 22.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174452.diff
6 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..27dd4e2e034fc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return rewriter.notifyMatchFailure(
+ extractOp, "extract to vector with fewer elements");
+
+ // Negative values in `position` mean that the extracted value is poison.
+ // There is a vector.extract folder for this.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return rewriter.notifyMatchFailure(extractOp,
+ "leaving for extract poison folder");
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getSource());
+
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -3076,13 +3113,43 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// Replace `vector.broadcast` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements()) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast to a greater number of elements");
+ }
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -6479,10 +6546,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using Base::Base;
@@ -6497,22 +6561,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..fdc8e487ff6bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// CHECK-LABEL: negative_fold_extract_broadcast
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
-// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[R]] : vector<1x1xf32>
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
-> vector<1x1xf32> {
@@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
-// Check the case where the same dimension is both broadcasted and sliced
+// Check the case where the same dimension is both broadcasted and sliced
// CHECK-LABEL: func @extract_strided_broadcast5
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
// CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
-// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
// CHECK: return %[[RET]]
func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
%vec: vector<4x8xf32>,
@@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
// CHECK-LABEL: func @insert_extract_to_broadcast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
@@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
+ // CHECK-NEXT: %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32>
// CHECK-NEXT: return %0 : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
@@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
// +---------------------------------------------------------------------------
// CHECK-LABEL: transpose_from_elements_1d
-// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
%v = vector.from_elements %el_0, %el_1 : vector<2xi32>
%t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
@@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
}
// CHECK-LABEL: transpose_from_elements_2d
-// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
+// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
func.func @transpose_from_elements_2d(
%el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
%el_1_0: i32, %el_1_1: i32, %el_1_2: i32
@@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d(
}
// CHECK-LABEL: transpose_from_elements_3d
-// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
+// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
func.func @transpose_from_elements_3d(
%el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
%el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index f43328f621787..aa6539e466c95 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+// CHECK: %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[EXTRACT]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
new file mode 100644
index 0000000000000..86fc6a5c0d6fc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// This file contains tests where a vector.shape_cast is the result
+// of canonicalization.
+
+// **--------------------------------------------------------** //
+// Tests of BroadcastToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+ return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+ return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+ %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+ return %0 : vector<1xi8>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+ return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+ return %1 : vector<1x1xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------** //
+// Tests of ExtractToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @extract_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. We could
+// convert this to shape_cast (would be a legal transform with poison)
+// but we conservatively choose not to.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+ %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+// CHECK: return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+ %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+ %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
+ return %r : vector<12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shape_cast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+ %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
+ %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
+ return %1: vector<1xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..8206c1a3ec865 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
%f0 = arith.constant 0.0 : f32
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
-// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
tensor<f32>, vector<1xf32>
@@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
return %res : tensor<?x?x?x?xf32>
@@ -385,8 +384,7 ...
[truncated]
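For readers skimming the diff, the match conditions of the two new patterns can be modelled outside of MLIR. The following is only a rough sketch under stated assumptions: shapes are plain `std::vector<int64_t>`, a scalar is modelled as `std::nullopt`, a dynamic extract index is modelled as an empty `std::optional`, and the names `Shape`, `Position`, `numElements`, `broadcastIsShapeCast`, `extractIsShapeCast` and `checkExamples` are hypothetical helpers that do not appear in the patch. It mirrors the checks in `BroadcastToShapeCast` and `ExtractToShapeCast` as shown above, not the actual MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for VectorType shapes and extract positions.
using Shape = std::vector<int64_t>;
using Position = std::vector<std::optional<int64_t>>; // nullopt = dynamic index

// Product of all dimensions; a rank-0 shape (e.g. vector<f32>) has one element.
int64_t numElements(const Shape &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Models BroadcastToShapeCast: the source must be a vector (not a scalar)
// and the broadcast must not increase the number of elements.
bool broadcastIsShapeCast(const std::optional<Shape> &source,
                          const Shape &result) {
  if (!source)
    return false; // shape_cast does not support scalars
  return numElements(*source) == numElements(result);
}

// Models ExtractToShapeCast: the result must be a vector, the element count
// must be unchanged, and every position must be the constant 0.
bool extractIsShapeCast(const Shape &source,
                        const std::optional<Shape> &result,
                        const Position &position) {
  if (!result)
    return false; // extracting a scalar
  if (numElements(source) != numElements(*result))
    return false; // extract to a vector with fewer elements
  for (const std::optional<int64_t> &p : position)
    if (!p || *p != 0)
      return false; // dynamic or non-zero index: left to other folders
  return true;
}

// Cases taken from the tests in vector-to-shape-cast.mlir.
void checkExamples() {
  assert(broadcastIsShapeCast(Shape{4}, Shape{1, 1, 4}));
  assert(!broadcastIsShapeCast(Shape{1, 4}, Shape{2, 3, 4}));
  assert(!broadcastIsShapeCast(std::nullopt, Shape{1}));
  assert(extractIsShapeCast(Shape{1, 4}, Shape{4}, Position{0}));
  assert(!extractIsShapeCast(Shape{1, 4}, Shape{4}, Position{std::nullopt}));
}
```

Note that in this model a rank-0 source such as `vector<i32>` counts as a vector with one element, which is consistent with the updated `shuffle_canonicalize_0d` test expecting `vector.shape_cast ... : vector<i32> to vector<1xi32>`.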
@llvm/pr-subscribers-mlir
Author: James Newling (newling)
Changes
Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.
This PR canonicalizes
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
to shape_cast.
It was decided (see #140583) that the vector.transpose -> vector.shape_cast canonicalization needs further consideration before being added.
Patch is 22.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174452.diff
6 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..27dd4e2e034fc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return rewriter.notifyMatchFailure(
+ extractOp, "extract to vector with fewer elements");
+
+ // Negative values in `position` mean that the extracted value is poison.
+ // There is a vector.extract folder for this.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return rewriter.notifyMatchFailure(extractOp,
+ "leaving for extract poison folder");
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getSource());
+
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -3076,13 +3113,43 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// Replace `vector.broadcast` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements()) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "broadcast to a greater number of elements");
+ }
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -6479,10 +6546,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using Base::Base;
@@ -6497,22 +6561,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..fdc8e487ff6bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// CHECK-LABEL: negative_fold_extract_broadcast
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
-// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[R]] : vector<1x1xf32>
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
-> vector<1x1xf32> {
@@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
-// Check the case where the same dimension is both broadcasted and sliced
+// Check the case where the same dimension is both broadcasted and sliced
// CHECK-LABEL: func @extract_strided_broadcast5
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
// CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
-// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
// CHECK: return %[[RET]]
func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
%vec: vector<4x8xf32>,
@@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
// CHECK-LABEL: func @insert_extract_to_broadcast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
@@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
+ // CHECK-NEXT: %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32>
// CHECK-NEXT: return %0 : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
@@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
// +---------------------------------------------------------------------------
// CHECK-LABEL: transpose_from_elements_1d
-// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
+// CHECK-SAME: %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
%v = vector.from_elements %el_0, %el_1 : vector<2xi32>
%t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
@@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
}
// CHECK-LABEL: transpose_from_elements_2d
-// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
+// CHECK-SAME: %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
func.func @transpose_from_elements_2d(
%el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
%el_1_0: i32, %el_1_1: i32, %el_1_2: i32
@@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d(
}
// CHECK-LABEL: transpose_from_elements_3d
-// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
+// CHECK-SAME: %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
func.func @transpose_from_elements_3d(
%el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
%el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index f43328f621787..aa6539e466c95 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+// CHECK: %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[EXTRACT]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
new file mode 100644
index 0000000000000..86fc6a5c0d6fc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// This file contains tests where a vector.shape_cast is the result
+// of canonicalization.
+
+// **--------------------------------------------------------** //
+// Tests of BroadcastToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+ return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+ return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+ %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+ return %0 : vector<1xi8>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+ %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+ return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+// CHECK-NOT: vector.broadcast
+// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+ %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+ return %1 : vector<1x1xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------** //
+// Tests of ExtractToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @extract_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. We could
+// convert this to shape_cast (would be a legal transform with poison)
+// but we conservatively choose not to.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+ %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+// CHECK: return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+ %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+ %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
+ return %r : vector<12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shape_cast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
+ %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+ %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+ return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+ %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
+ %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
+ return %1: vector<1xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..8206c1a3ec865 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
%f0 = arith.constant 0.0 : f32
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
-// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
tensor<f32>, vector<1xf32>
@@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
return %res : tensor<?x?x?x?xf32>
@@ -385,8 +384,7 ...
[truncated]
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed.
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed.
Approving based on my review of the original PR (#140583).
Given #140583 (comment), I don’t expect further objections. That said, please wait for at least one more +1, ideally from someone who raised objections in the previous discussion (CC @Groverkss, @kuhar, @MaheshRavishankar).
Many thanks for working on this!
kuhar
left a comment
just some nits
mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
kuhar
left a comment
LGTM but I'd also like @Groverkss to take a look
dcaballe
left a comment
Great, thanks!
It's been over a week - I suggest that we land this, Kunwar can leave a post-commit review.
Just a heads up that we'll land this in 2 days if there are no objections (@Groverkss)
Reverts carried forward: * Local revert of llvm/llvm-project#169614 due to #22649 Other changes: * Fixes lit tests to account for llvm/llvm-project#174452
…ctor.shape_cast (llvm#174452) Based on the original PR llvm#140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see llvm#140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>
Reverts carried forward: * Local revert of llvm/llvm-project#169614 due to #22649 Other changes: * Fixes lit tests to account for llvm/llvm-project#174452 Signed-off-by: Keshav Vinayak Jha <[email protected]>
Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.
This PR canonicalizes
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
to shape_cast. It was decided (see #140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added.
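
For readers skimming the tests, the conditions they exercise can be sketched in plain C++. This is not the code added to VectorOps.cpp: the function names below are hypothetical, shapes are modelled as lists of static dimension sizes, and the checks are only what the test comments suggest (element count unchanged, no scalar operand or result, no dynamic extract position). The actual patterns operate on MLIR types and may check more.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A static vector shape; an empty list models a rank-0 vector such as vector<f32>.
using Shape = std::vector<int64_t>;

// Number of elements in a vector of the given shape (1 for rank 0).
int64_t numElements(const Shape &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Hypothetical helper: could `vector.broadcast src -> dst` be written as a
// shape_cast? `src` is std::nullopt when the broadcast operand is a scalar.
bool broadcastIsShapeCast(const std::optional<Shape> &src, const Shape &dst) {
  if (!src)
    return false; // shape_cast does not accept scalar inputs
  // A broadcast that duplicates data changes the element count.
  return numElements(*src) == numElements(dst);
}

// Hypothetical helper: could `vector.extract src[...]` with `numIndices`
// positions be written as a shape_cast?
bool extractIsShapeCast(const Shape &src, std::size_t numIndices,
                        bool hasDynamicIndex) {
  if (hasDynamicIndex)
    return false; // conservatively skipped, as in the negative test
  if (numIndices >= src.size())
    return false; // result would be a scalar, not a vector
  // The result shape is the source shape with the indexed leading dims dropped.
  Shape result(src.begin() + static_cast<std::ptrdiff_t>(numIndices),
               src.end());
  return numElements(src) == numElements(result);
}
```

Under this model, `broadcastIsShapeCast(Shape{4}, Shape{1, 1, 4})` and `extractIsShapeCast(Shape{1, 4}, 1, false)` both hold, matching the two examples in the description, while broadcasting `vector<1x4xi8>` to `vector<2x3x4xi8>` does not because the element count grows.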