[vector][mlir] Canonicalize to shape_cast where possible #140583
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is
vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?
In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)
Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.
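A sketch of those two steps on a small example (the shapes below are my own, chosen only for illustration):

// Starting point: a broadcast that preserves the element count, feeding a shape_cast.
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<1x1x4xf32>

// Step 1: BroadcastToShapeCast rewrites the element-count-preserving broadcast as a shape_cast.
%0 = vector.shape_cast %arg0 : vector<4xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<1x1x4xf32>

// Step 2: the two shape_casts fold into one.
%1 = vector.shape_cast %arg0 : vector<4xf32> to vector<1x1x4xf32>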
Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.
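A sketch of those two steps (the shapes are illustrative assumptions, not taken from the tests):

// Starting point: a shape_cast feeding an order-preserving transpose.
%0 = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

// Step 1 (new): TransposeToShapeCast rewrites the transpose as a shape_cast.
%0 = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1xf32>
%1 = vector.shape_cast %0 : vector<4x1xf32> to vector<1x4xf32>

// Step 2: shape_cast(shape_cast) folds to a single shape_cast.
%1 = vector.shape_cast %arg0 : vector<4xf32> to vector<1x4xf32>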
Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast
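As one such special case (borrowing the example from the Transpose2DWithUnitDimToShapeCast documentation in this patch), a 2-D transpose that only moves a unit dim is order preserving, so the general pattern produces the same rewrite:

%0 = vector.transpose %arg0, [1, 0] : vector<4x1xi32> to vector<1x4xi32>
// TransposeToShapeCast rewrites this directly to:
%0 = vector.shape_cast %arg0 : vector<4x1xi32> to vector<1x4xi32>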
Author note: removed these tests, as the pattern they are testing is removed
Shouldn't we keep them? Shouldn't they still be canonicalized?
I'll add them back, yes they're still canonicalized
Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777). For example, these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalization patterns to convert the above 3 examples to shape_casts. I've added some more comments as review comments. I'm happy to split this PR up and add the new patterns separately.

Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ // Negative values in `position` indicates poison, which cannot be
+ // represented with a shape_cast
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
//
- // Example of what NOT to fold:
+ // Example of what not to fold:
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
//
if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transposeOp,
- PatternRewriter &rewriter) const override {
- auto shapeCastOp =
- transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return failure();
- if (!isOrderPreserving(transposeOp))
- return failure();
-
- VectorType resultType = transposeOp.getType();
-
- // We don't need to check isValidShapeCast at this point, because it is
- // guaranteed that merging the transpose into the the shape_cast is a valid
- // shape_cast, because the transpose just inserts/removes ones.
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
- shapeCastOp.getSource());
- return success();
- }
-};
-
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+
+ if (!isOrderPreserving(transpose)) {
+ return rewriter.notifyMatchFailure(
+ transpose, "not order preserving, so not semantically a 'copy'");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+ FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransposeLowering vectorTransposeLowering;
};
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-/// vector<1x4xi32>
-/// with:
-/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-/// BEFORE:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// AFTER:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
- : public OpRewritePattern<vector::TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
- LogicalResult matchAndRewrite(vector::TransposeOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getVector();
- VectorType resType = op.getResultVectorType();
-
- // Set up convenience transposition table.
- ArrayRef<int64_t> transp = op.getPermutation();
-
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
- return failure();
- }
-};
-
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
- patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
- benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
vectorTransposeLowering, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
// -----
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
{
- // CHECK: vector.transpose
+ // CHECK: vector.shape_cast
+ // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
// -----
// CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-NEXT: return [[ARG]]
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
- // CHECK-NOT: transpose
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
- // CHECK-NEXT: return [[ARG]]
return %0 : vector<4x3x2xf32>
}
// -----
+// CHECK-LABEL: transpose_0D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+ %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+ return %0 : vector<i8>
+}
+
+// -----
+
// CHECK-LABEL: transpose_2D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
+
// CHECK-LABEL: negative_fold_extract_broadcast
-// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
- %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// rank(extract_output) < rank(broadcast_input)
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
return %r : vector<4xf32>
}
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
// -----
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
- %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
- %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
- return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
- %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
- %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
- return %1 : vector<1x1xf32>
-}
-
-// -----
-
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
// -----
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
// CHECK-LABEL: func.func @extract_from_broadcast
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
- // CHECK-NEXT: return %0 : vector<1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-// CHECK: return %[[EXTRACT]] : vector<2xi8>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: return %[[SC]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]
Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!
dcaballe left a comment
Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.
Keep both tests, one with the original shape and one with the new ones?
Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?
Keep both tests, one with the original shape and one with the new ones?
Makes sense, will do.
Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?
No, because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW, slightly related to my comment here: this would be simpler if ops didn't do implicit shape casting. In this case, if it was something like
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
%s = vector.shape_cast %r : vector<1x1x4> to vector<4>
i.e., if we constrained broadcasts and extracts to be rank-retaining, then this would be canonicalized to
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
%s = vector.shape_cast %b : vector<1x1x4> to vector<4>
which, if you have faith that the shape_casts will vanish at a later point, is simpler!
p.s. I plan to reply in #145740 later today
Shouldn't we keep them? Shouldn't they still be canonicalized?
Thanks! I ran the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues. @newling, why not name all "folding" patterns as
banach-space left a comment
Thanks!
Why change shapes?
I'll give this a spin with IREE
Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.
+1
This PR is back, and ready for review! Let me summarize the previous concerns, as this is quite old now: @dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
@dcaballe and @banach-space see this post here https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar . This talks about how vector.transpose captures more information than a vector.shape_cast, and how you cannot always go from shape_cast to transpose. This is exactly the issue with treating vector.shape_cast as the "canonical" representation for transposes and hoping that we can always lift back to the original representation.
Thanks @MaheshRavishankar , as promised I am returning to this after you've shared your example.
I've extracted this repro as something representative (*):

func.func @transpose_to_shape_cast_1(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
%res = vector.transpose %0, [2, 0, 1] : vector<4x1x1xf32> to vector<1x4x1xf32>
return %res : vector<1x4x1xf32>
}
// -----
func.func @transpose_to_shape_cast_2(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
%res = vector.transpose %0, [1, 0, 2] : vector<4x1x1xf32> to vector<1x4x1xf32>
return %res : vector<1x4x1xf32>
}

QUESTION/COMMENT: Aren't the examples above identical operations?

YES - LLVM example!

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt repro.mlir -canonicalize -test-lower-to-llvm --split-input-file
# Lower as vector.transpose.
$ mlir-opt repro.mlir -test-lower-to-llvm --split-input-file

In both cases I get the following (testing using this PR):

module {
llvm.func @transpose_to_shape_cast_1(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
%0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
%1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
%3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
%5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
%7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%8 = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
}
}
// -----
module {
llvm.func @transpose_to_shape_cast_2(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
%0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
%1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
%3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
%5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
%7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
%res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
}
}

Note,

YES - SPIR-V example!

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt repro.mlir -canonicalize -test-convert-to-spirv --split-input-file
# Lower as vector.transpose.
$ mlir-opt repro.mlir -test-convert-to-spirv --split-input-file

In both cases I get the following (testing using this PR):

module {
func.func @transpose_to_shape_cast_1(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
}
}
// -----
module {
func.func @transpose_to_shape_cast_2(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
}
}

SPIR-V makes it even clearer that we are dealing with a NO-OP 😅

FINAL THOUGHTS

I argue that in all cases we are dealing with one operation for which we have multiple names (vector.transpose and vector.shape_cast). I obviously might be missing something - please correct me if that's the case. I am sharing this to make my mental model clear and to avoid confusion.

-Andrzej

(*) Please provide other examples if this does not capture what you had in mind.
@banach-space I think I might not have communicated the intent of my example properly. This was more to show why

It of course lowers to the same thing if you are just lowering to LLVM or SPIR-V. I have already agreed that while lowering to LLVM you should lower both these transposes to
I think there is a misunderstanding about what we can expect from a canonical form. A canonical form should indeed allow us to convert to any equivalent representation. However, while these representations are semantically equivalent, the canonical form doesn't (and shouldn't need to) preserve information about which specific representation was part of the input IR. To illustrate this with a simple but realistic example, consider different ways to represent "multiplication by 2" (I haven't checked but this is probably something that LLVM canonicalizes today to a single form):
If we choose option 2 as our canonical form, we can certainly convert to both options 1 and 3 from it when needed. What we can't do (and what isn't a requirement for canonical forms) is to automatically know which of these three forms the input IR had without any additional context.

Bringing this back to this discussion, the key point here is: if preserving exactly the original input representation is important for your use case, then canonicalization is not the right transformation to apply at that stage of your pipeline. That is not the right expectation to have for a canonical form.
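A minimal sketch of what three such equivalent forms look like in MLIR arith (which computation corresponds to which "option" number above is my assumption; the point is only that they are interchangeable):

func.func @times_two_forms(%x: i32) -> (i32, i32, i32) {
  %c1 = arith.constant 1 : i32
  %c2 = arith.constant 2 : i32
  // Multiply by 2.
  %mul = arith.muli %x, %c2 : i32
  // Add the value to itself.
  %add = arith.addi %x, %x : i32
  // Shift left by 1.
  %shl = arith.shli %x, %c1 : i32
  return %mul, %add, %shl : i32, i32, i32
}

Whichever of these is picked as the canonical form can be rewritten into the other two on demand, but nothing in the canonicalized IR records which of the three the input originally used.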
I found this "visualization" from Cursor quite illustrative:

[Data Layout Diagram - Original Vector]
Answering my own question, I can think of one use case: any kind of traversal that needs to track or propagate a property across one of the unit dimensions in the example wouldn't be able to do so with the shape_cast.

Conclusion 1: We can't canonicalize a transpose operation to a shape cast when multiple unit dimensions are transposed. The data layout or dimension mapping across the operation becomes ambiguous with the shape_cast (see the sketch at the end of this comment).

Great, that's progress! We have identified something technically specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:
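A minimal sketch of that ambiguity, reusing the shapes from the repro earlier in this thread:

// Two different permutations that only move unit dimensions...
%a = vector.transpose %0, [2, 0, 1] : vector<4x1x1xf32> to vector<1x4x1xf32>
%b = vector.transpose %0, [1, 0, 2] : vector<4x1x1xf32> to vector<1x4x1xf32>
// ...would both be rewritten to the same op, discarding the permutation:
%c = vector.shape_cast %0 : vector<4x1x1xf32> to vector<1x4x1xf32>
// A traversal that tracks which unit dimension ends up where cannot
// recover [2, 0, 1] vs [1, 0, 2] from the shape_cast alone.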
I am happy Cursor was able to give you a better explanation of what I was trying to say all this while. Good to have reached this common state. I think we did discuss this previously. Stating that "certain" transpose/broadcasts are canonically shape_casts, and forcing them to then become shape_casts without control, is now creating unnecessary complication in the definition of canonicalization. If some transformation is relying on following dimensions through broadcasts/transposes, it now has to look at a shape_cast, decide if this is "convertible to a transpose/broadcast", and then handle that appropriately. This does not seem like a great setup.
Thank you for the detailed discussion, @dcaballe - that was very helpful in clarifying the underlying issues.
Agreed. That’s one concrete example (*). @MaheshRavishankar, could you help us identify other specific cases so that we can better scope or constrain this change? All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?
I think we may have to agree to disagree here. That said, as

-Andrzej

(*) A transpose operation to a shape cast when multiple unit dimensions are transposed.
We might be having different reads of the blocker. To me this discussion is uncovering more reasons why this change shouldn't be merged (this kind of thing is what I was saying would be an issue from the get-go).

I want to reiterate: this is not about my use case. We can find ways to work around things either way. So I am disagreeing more with the approach here, rather than "this doesn't fit my use case".
Hi folks, this PR was brought to the attention of the @llvm/mlir-area-team. We think it can benefit from a higher-bandwidth discussion in a call. Could you all please fill in https://www.when2meet.com/?33926948-FNg3Z so we can facilitate that?
I entered my availability in the tool. However, I think we are missing a step by skipping an RFC, which was requested upthread. This seems to be the first step of the standard LLVM decision-making process: https://github.com/llvm/llvm-www/blob/main/proposals/LP0001-LLVMDecisionMaking.md#proposed-solution.
An RFC would certainly be welcome, though it is but a tool to attempt establishing consensus. There being or, conversely, not being an RFC shouldn't preclude us from using other means available to find consensus. Unless the lack of an RFC is the principal contentious point to start with. LP0001 was amended by LP0004, which requires area teams to facilitate consensus seeking: https://github.com/llvm/llvm-www/blob/HEAD/proposals/LP0004-project-governance.md#consensus-seeking-decision-making. The area team executes on that.
My understanding is that LP0004 only reinforces the role of RFCs in decision making; some quotes:
The last quote does mention PRs specifically, but the general framework seems to be built around RFCs and discussion on Discourse. In this specific instance, I think the discussion would strongly benefit from a more comprehensive proposal instead of continued back-and-forth that hasn't resulted in any consensus.
(emphasis mine) I'd rather us focus on working towards consensus in accordance with the spirit of these documents, not the procedural minutiae. But we can also discuss that during the call, which is now scheduled for Tuesday, Dec 16, 15:30 CET / 2:30pm GMT / 9:30am ET / 6:30am PST, meeting link meet.google.com/xyp-khwh-scr. Let's strive for efficiency and refresh our understanding of the thread before the call so we don't waste time during it.
To clarify, after holding another discussion with the area team: we are explicitly refraining from making any decision based on the discussion within the area team only, without talking to all parties; this includes mandating an RFC. The outcome of the call may well be that an RFC is required, but it may become clearer by whom (there is disagreement between vector maintainers) and with what scope.
A brief summary of the call:
Thank you for organising and facilitating this, @ftynse! And thanks to everyone who joined - I found the discussion very constructive. It's great to see that we've identified clear, actionable next steps 🙏
Thanks @ftynse, and everyone who attended the call; I'm sorry I couldn't make it. I'm happy to create a new PR with 2 of the 3 patterns in this PR, will do so in the coming weeks.

One small advantage is that we get the following simplification for free:

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%2 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>
OK, so to summarize, there are cases where transpose => shape_cast is beneficial and cases where shape_cast => transpose is beneficial.
@newling Are there other cases where transpose => shape_cast canonicalization is beneficial? You can write down the same example with transpose / shape_cast inverted and come to the opposite conclusion.
Are you saying that there are sequences of transposes that can't all be shape_casts that would fold themselves?
No. What I meant is, you can make this argument, which is very similar to @newling's but argues for the opposite canonicalization:

==>

==>
This is not a benefit of having shape_cast here. This is a special case of a more general canonicalization pattern that can be written on shape_cast(transpose). You can canonicalize shape_cast(transpose) -> transpose(shape_cast) if it works on a subset which doesn't get affected by the shape_cast:

shape_cast: 1x2x2 -> 1x4
transpose: 1x4 -> 4x1
shape_cast: 4x1 -> 4

apply shape_cast(transpose) -> transpose(shape_cast):

shape_cast: 1x2x2 -> 1x4
shape_cast: 1x4 -> 4
transpose: 4 -> 4 // no-op

and you get the same result later:

shape_cast: 1x2x2 -> 4

It also works for more interesting cases:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

canonicalizes to:

shape_cast: 1x2x2x2x3 -> 4x2x3
transpose: 4x2x3 -> 4x3x2

This actually shows a good reason not to focus on these special-cased patterns, as they just hide more general patterns. For shape_cast, you just have to choose a direction to go (either always up / always down, or always expand / always collapse) and you can write patterns like this.
Not only this, this pattern in that form has a bigger problem: it doesn't canonicalize to a single form; the result depends on what order patterns are run in. Let's say we implement the transpose -> shape_cast pattern for unit dimensions:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 1x4x3x2
transpose: 1x4x3x2 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

If the transpose -> shape_cast unit-dim canonicalization runs first:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 1x4x3x2
shape_cast: 1x4x3x2 -> 4x3x2

You are stuck here now, unless we have the more general pattern I talked about before. If you canonicalize transpose(transpose) first:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

and you are stuck here. (Note this is a different form from the above ordering.) This is also why I was asking for concrete examples of why this is a good form, because without examples, it's hard to understand things better.
That's a good example of 'getting stuck' @Groverkss. I agree that there are arguments in all directions (no canonicalize vs canonicalize to transpose vs canonicalize to shape_cast). A more advanced canonicalization with multiple 'bubbling' stages (try to get order |
…ctor.shape_cast (#174452) Based on the original PR #140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see #140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>
…dcast to vector.shape_cast (#174452) Based on the original PR llvm/llvm-project#140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see llvm/llvm-project#140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>
…ctor.shape_cast (llvm#174452) Based on the original PR llvm#140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see llvm#140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>
…ctor.shape_cast (llvm#174452) Based on the original PR llvm#140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see llvm#140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>
…ctor.shape_cast (#174452) Based on the original PR llvm/llvm-project#140583, but without vector.transpose -> vector.shape_cast. This PR canonicalizes %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> to shape_cast. It was decided (see llvm/llvm-project#140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added. --------- Signed-off-by: James Newling <[email protected]>