
Conversation

@newling
Contributor

@newling newling commented May 19, 2025

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalization patterns to convert the above 3 examples to shape_casts.
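
For reference, a sketch of the shape_cast forms these would canonicalize to (using the same operands and types as above):

%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.shape_cast %arg1 : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.shape_cast %arg2 : vector<1x4xi8> to vector<4xi8>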

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.

@newling newling changed the title [vector][mlir] Canonicalize to shape_cast where possible [wip][vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling changed the title [wip][vector][mlir] Canonicalize to shape_cast where possible [vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling force-pushed the canonicalize_to_shape_cast branch from d546ab3 to 29d41d8 Compare June 5, 2025 18:07
@newling newling force-pushed the canonicalize_to_shape_cast branch from 29d41d8 to f2e5417 Compare June 25, 2025 23:03
@github-actions

github-actions bot commented Jun 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor Author

@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is

vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>

an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?

Contributor

In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)

@newling newling force-pushed the canonicalize_to_shape_cast branch from 7bc5da0 to e673522 Compare June 26, 2025 15:31
Contributor Author

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.
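
For example (a sketch, reusing the example from the code comment removed in this PR):

%0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
%1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>

first becomes, via the new broadcast -> shape_cast canonicalization,

%0 = vector.shape_cast %in : vector<3x4xf32> to vector<1x3x4xf32>
%1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>

and then the shape_cast(shape_cast) fold leaves

%1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>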

Contributor Author

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.
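
For example (a sketch, not taken verbatim from the tests):

%0 = vector.shape_cast %in : vector<4xf32> to vector<4x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>

first becomes, via the new transpose -> shape_cast canonicalization,

%0 = vector.shape_cast %in : vector<4xf32> to vector<4x1xf32>
%1 = vector.shape_cast %0 : vector<4x1xf32> to vector<1x4xf32>

and then folds to

%1 = vector.shape_cast %in : vector<4xf32> to vector<1x4xf32>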

Contributor Author

Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast

Contributor Author

Author note: removed these tests, as the pattern they are testing is removed

Contributor

Shouldn't we keep them? Shouldn't they still be canonicalized?

Contributor Author

I'll add them back; yes, they're still canonicalized.

Contributor Author

Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering

@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalization patterns to convert the above 3 examples to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.


Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+84-53)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-61)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26-41)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+2-2)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir (+162)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+60)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+5-7)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-85)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+4-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicate poison, which cannot be
+    // represented with a shape_cast
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
   //
   if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    if (!isOrderPreserving(transpose)) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]

@newling
Contributor Author

newling commented Jun 26, 2025

Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!

Contributor

@dcaballe dcaballe left a comment

Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.

Contributor

Keep both tests, one with the original shape and one with the new ones?

Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?

Contributor Author

@newling newling Jun 26, 2025

Keep both tests, one with the original shape and one with the new ones?

Makes sense, will do.

Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?

No, because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW, this is slightly related to my comment here, where this would be simpler if ops didn't do implicit shape casting. In this case, if it were something like

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
 %r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
 %out = vector.shape_cast %r : vector<1x1x4xf32> to vector<4xf32>

i.e. if we constrained broadcasts and extracts to be rank-retaining, then this would be canonicalized to

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
 %out = vector.shape_cast %b : vector<1x1x4xf32> to vector<4xf32>

which, if you have faith that the shape_casts will vanish at a later point, is simpler!

p.s. I plan to reply in #145740 later today

Contributor

Shouldn't we keep them? Shouldn't they still be canonicalized?

@banach-space
Contributor

Thanks!

I ran the SME e2e tests and all pass. I wasn't able to cherry-pick this into IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Contributor

@banach-space banach-space left a comment

Thanks!

Comment on lines 810 to 811
Contributor

Why change shapes?

@newling
Contributor Author

newling commented Jun 27, 2025

Thanks!

I ran the SME e2e tests and all pass. I wasn't able to cherry-pick this into IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

I'll give this a spin with IREE

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Yes, I think so. Actually, fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one, named fold.mlir, containing everything in canonicalize.mlir that only depends on 1-time folds).

@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.

@newling newling marked this pull request as draft June 27, 2025 00:51
@banach-space
Contributor

Actually, fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one, named fold.mlir, containing everything in canonicalize.mlir that only depends on 1-time folds).

+1

@newling
Contributor Author

newling commented Aug 6, 2025

This PR is back, and ready for review!

Let me summarize the previous concerns as this is quite old now:

  • @dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
  • @banach-space raised concerns about test naming and dir structure. I would prefer to address these in a later PR as part of a wider canonicalization/folding test refactor.
  • @banach-space noted this would likely cause ripples downstream, and suggested running IREE tests. I have done this, and indeed some lit tests fail. I will take responsibility for fixing these (FYI @Groverkss).

@MaheshRavishankar
Contributor

@dcaballe and @banach-space see this post here https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar. This talks about how vector.transpose captures more information than a vector.shape_cast and how you cannot always go from shape_cast to transpose.

This is exactly the issue with treating vector.shape_cast as "canonical" representation for transposes and hoping that we can lift back to the original representation always.

@banach-space
Contributor

Thanks @MaheshRavishankar , as promised I am returning to this after you've shared your example.

see this post here https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355?u=maheshravishankar. This talks about how vector.transpose captures more information than a vector.shape_cast and how you cannot always go from shape_cast to transpose.

I've extracted this repro as something representative (*):

func.func @transpose_to_shape_cast_1(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [2, 0, 1] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}

// -----

func.func @transpose_to_shape_cast_2(%0 : vector<4x1x1xf32>) -> vector<1x4x1xf32> {
  %res = vector.transpose %0, [1, 0, 2] : vector<4x1x1xf32> to vector<1x4x1xf32>
  return %res : vector<1x4x1xf32>
}

QUESTION/COMMENT:

Aren't the examples above identical operations?

YES - LLVM example!
Let's try these:

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt  repro.mlir -canonicalize -test-lower-to-llvm --split-input-file
# Lower as vector.transpose.
$ mlir-opt  repro.mlir -test-lower-to-llvm --split-input-file

In both cases I get the following (testing using this PR):

module {
  llvm.func @transpose_to_shape_cast_1(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}

// -----
module {
  llvm.func @transpose_to_shape_cast_2(%arg0: !llvm.array<4 x array<1 x vector<1xf32>>>) -> !llvm.array<1 x array<4 x vector<1xf32>>> {
    %0 = llvm.mlir.poison : !llvm.array<1 x array<4 x vector<1xf32>>>
    %1 = llvm.extractvalue %arg0[0, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %2 = llvm.insertvalue %1, %0[0, 0] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %3 = llvm.extractvalue %arg0[1, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %4 = llvm.insertvalue %3, %2[0, 1] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %5 = llvm.extractvalue %arg0[2, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %6 = llvm.insertvalue %5, %4[0, 2] : !llvm.array<1 x array<4 x vector<1xf32>>>
    %7 = llvm.extractvalue %arg0[3, 0] : !llvm.array<4 x array<1 x vector<1xf32>>>
    %res = llvm.insertvalue %7, %6[0, 3] : !llvm.array<1 x array<4 x vector<1xf32>>>
    llvm.return %res : !llvm.array<1 x array<4 x vector<1xf32>>>
  }
}

Note, %res == %arg0, which confirms that we are dealing with a NO-OP.

YES - SPIR-V example!
Let's try these:

# Canonicalize to vector.shape_cast, then lower.
$ mlir-opt  repro.mlir -canonicalize -test-convert-to-spirv --split-input-file
# Lower as vector.transpose.
$ mlir-opt  repro.mlir -test-convert-to-spirv --split-input-file

In both cases I get the following (testing using this PR):

module {
  func.func @transpose_to_shape_cast_1(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}

// -----
module {
  func.func @transpose_to_shape_cast_2(%arg0: vector<1xf32>, %arg1: vector<1xf32>, %arg2: vector<1xf32>, %arg3: vector<1xf32>) -> (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>) {
    return %arg0, %arg1, %arg2, %arg3 : vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>
  }
}

SPIR-V makes it even clearer that we are dealing with a NO-OP 😅

FINAL THOUGHTS

I argue that in all cases we are dealing with one operation for which we have multiple names (vector.transpose [2, 0, 1] vs vector.transpose [1, 0, 2] vs vector.shape_cast). This discussion is merely trying to establish a single name for all of this.

I obviously might be missing something - please correct me if that's the case. I am sharing this to make my mental model clear and to avoid confusion.

-Andrzej

(*) Please provide other examples if this does not capture what you had in mind.

@MaheshRavishankar
Contributor

@banach-space I think I might not have communicated the intent of my example properly. This was more to show why vector.shape_cast is not the canonical representation of transposes with unit dims, and that you cannot always recover the transpose from a shape_cast because it loses information. So a vector.transpose carries more information than a vector.shape_cast.

It of course lowers to the same thing if you are just lowering to LLVM or SPIR-V. I have already agreed that while lowering to LLVM you should lower both these transposes to vector.shape_cast and then just cancel these out, because at the LLVM level this makes no difference. So it is perfectly valid to say that, for the sequence of transformations lowering from the vector dialect to LLVM without any further context-specific analysis, all transposes with unit dims should be lowered to shape_cast. In other words, the "normal form" for the set of passes that lower vector dialect ops to LLVM prefers using shape_cast. I am on board with that. Canonical form is different. As shown above, going from vector.transpose to vector.shape_cast and back is not always possible. You "lose" information when you lower to vector.shape_cast. That makes it non-canonical, since a canonicalizer would not help you reach a "better state" of the program.

@dcaballe
Contributor

dcaballe commented Nov 14, 2025

I think there is a misunderstanding about what we can expect from a canonical form. A canonical form should indeed allow us to convert to any equivalent representation. However, while these representations are semantically equivalent, the canonical form doesn't (and shouldn't need to) preserve information about which specific representation was part of the input IR.

To illustrate this with a simple but realistic example, consider different ways to represent "multiplication by 2" (I haven't checked but this is probably something that LLVM canonicalizes today to a single form):

  1. mul %a, 2
  2. shl %a, 1
  3. add %a, %a

If we choose option 2 as our canonical form, we can certainly convert to both options 1 and 3 from it when needed. What we can't do (and what isn't a requirement for canonical forms) is to automatically know which of these three forms the input IR had without any additional context.

Bringing this back to the vector.transpose example, the same principle applies. We can convert from the vector.shape_cast to either of the two flavors of the vector.transpose operation, but we have to decide which one.
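
For instance (a sketch reusing the shapes from the repro above), the single op

%0 = vector.shape_cast %arg0 : vector<4x1x1xf32> to vector<1x4x1xf32>

could be rewritten back as either vector.transpose %arg0, [2, 0, 1] or vector.transpose %arg0, [1, 0, 2]; both are equivalent, so one of the two has to be picked.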

The key point here is: if preserving exactly the original input representation is important for your use case, then canonicalization is not the right transformation to apply at that stage of your pipeline. That is not the right expectation to have for a canonical form.

@dcaballe
Contributor

I found this "visualization" from Cursor quite illustrative:

Data Layout Diagram

Original Vector: vector<4x1x1xf32>

Let's assume we have 4 elements: [a, b, c, d]

Original layout (vector<4x1x1xf32>):
Dimension 0 (size=4): [a, b, c, d]
Dimension 1 (size=1): [ ]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a]], [[b]], [[c]], [[d]]]

Linear memory layout: [a, b, c, d]

Transpose with permutation [2, 0, 1]: vector<4x1x1xf32> → vector<1x4x1xf32>

After transpose [2, 0, 1]:
- Old dim 2 (size=1) → New dim 0 (size=1)
- Old dim 0 (size=4) → New dim 1 (size=4)  
- Old dim 1 (size=1) → New dim 2 (size=1)

Result layout (vector<1x4x1xf32>):
Dimension 0 (size=1): [ ]
Dimension 1 (size=4): [a, b, c, d]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a], [b], [c], [d]]]

Linear memory layout: [a, b, c, d]  (NO CHANGE!)

Transpose with permutation [1, 0, 2]: vector<4x1x1xf32> → vector<1x4x1xf32>

After transpose [1, 0, 2]:
- Old dim 1 (size=1) → New dim 0 (size=1)
- Old dim 0 (size=4) → New dim 1 (size=4)
- Old dim 2 (size=1) → New dim 2 (size=1)

Result layout (vector<1x4x1xf32>):
Dimension 0 (size=1): [ ]
Dimension 1 (size=4): [a, b, c, d]
Dimension 2 (size=1): [ ]

Conceptual 3D representation:
[[[a], [b], [c], [d]]]

Linear memory layout: [a, b, c, d]  (NO CHANGE!)

Visual Diagram

     Original: vector<4x1x1xf32>
     ┌─────────────────────────┐
     │ [[[a]], [[b]], [[c]], [[d]]] │
     └─────────────────────────┘
              │
              ├── transpose [2,0,1] ──┐
              │                       │
              └── transpose [1,0,2] ──┼─┐
                                      │ │
                                      ▼ ▼
                        Both produce: vector<1x4x1xf32>
                        ┌─────────────────────────┐
                        │ [[[a], [b], [c], [d]]]  │
                        └─────────────────────────┘

     Memory Layout (all cases): [a, b, c, d]
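
In IR terms, with the pattern proposed in this PR, both transposes would canonicalize to the same single op:

%0 = vector.shape_cast %arg0 : vector<4x1x1xf32> to vector<1x4x1xf32>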

Could you think of an example where the actual permutation patterns ([2, 0, 1] vs [1, 0, 2]) are not redundant and lead to semantic differences in the IR?

@dcaballe
Contributor

Could you think of an example where the actual permutation patterns ([2, 0, 1] vs [1, 0, 2]) are not redundant and lead to semantic differences in the IR?

Answering my own question, I can think of one use case: any kind of traversal that needs to track or propagate a property across one of the unit dimensions in the example wouldn't be able to do so with the vector.shape_cast form. The "data layout" information becomes ambiguous.

Conclusion 1: We can't canonicalize a transpose operation to a shape_cast when multiple unit dimensions are transposed. The data layout or dimension mapping across the operation becomes ambiguous with the vector.shape_cast op.

Great, that’s progress! We have identified something technical specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:

  1. Non-ambiguous cases of transposes (i.e., only one unit dim is transposed)
  2. Broadcast (i.e., reshape-like broadcast adding unit dims)

@MaheshRavishankar
Contributor

Great, that’s progress! We have identified something technical specific. I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise? Could we come up with similar examples for:

Non-ambiguous cases of transposes (i.e., only one unit dim is transposed)
Broadcast (i.e., reshape-like broadcast adding unit dims)

I am happy Cursor was able to give you a better explanation of what I was trying to say all this while. Good to have reached this common state.

I think we did discuss this previously. Stating that "certain" transposes/broadcasts are canonically shape_casts, and forcing them to become shape_casts without control, is now creating unnecessary complication in the definition of canonicalization. If some transformation relies on following dimensions through broadcasts/transposes, it now has to look at a shape_cast, decide if it is "convertible to a transpose/broadcast", and then handle that appropriately. This does not seem like a great setup.

@banach-space
Contributor

Thank you for the detailed discussion, @dcaballe - that was very helpful in clarifying the underlying issues.

I suggest that we continue focusing on the technical aspects of the different IR forms. Could we continue this exercise?

Agreed.

That’s one concrete example (*). @MaheshRavishankar, could you help us identify other specific cases so that we can better scope or constrain this change?

All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?

This does not seem like a great setup.

I think we may have to agree to disagree here. That said, as Vector maintainers, we’re committed to supporting all Vector users. If this change ends up causing issues for you, we’ll work with you to address them.

-Andrzej

(*) A transpose operation to a shape cast when multiple unit dimensions are transposed.

@MaheshRavishankar
Contributor

All in all, given the nuances discussed, I don’t see a specific blocker preventing this from being merged - or is there?

We might be having different reads of the blocker. To me this discussion is uncovering more reasons why this change shouldn't be merged (this kind of thing is what I was saying would be an issue from the get-go).

This does not seem like a great setup.

I think we may have to agree to disagree here. That said, as Vector maintainers, we’re committed to supporting all Vector users. If this change ends up causing issues for you, we’ll work with you to address them.

I want to reiterate: this is not about my use case. We can find ways to work around things either way. So I am disagreeing more with the approach here, rather than "this doesn't fit my use case".

@ftynse
Member

ftynse commented Dec 10, 2025

Hi folks, this PR was brought to the attention of the @llvm/mlir-area-team. We think it can benefit from a higher-bandwidth discussion in a call. Could you all please fill in https://www.when2meet.com/?33926948-FNg3Z so we can facilitate that?

@kuhar
Member

kuhar commented Dec 11, 2025

I entered my availability in the tool. However, I think we are missing a step by skipping an RFC, which was requested upthread. This seems to be the first step of the standard LLVM decision-making process: https://github.com/llvm/llvm-www/blob/main/proposals/LP0001-LLVMDecisionMaking.md#proposed-solution.

@ftynse
Member

ftynse commented Dec 11, 2025

An RFC would certainly be welcome, though it is but a tool for attempting to establish consensus. There being or, conversely, not being an RFC shouldn't preclude us from using other means available to find consensus, unless the lack of an RFC is the principal contentious point to start with.

LP0001 was amended by LP0004, which requires area teams to facilitate consensus seeking: https://github.com/llvm/llvm-www/blob/HEAD/proposals/LP0004-project-governance.md#consensus-seeking-decision-making. The area team executes on that.

@kuhar
Member

kuhar commented Dec 11, 2025

My understanding is that LP0004 only reinforces the role of RFCs in decision making; some quotes:

The goal of this proposal is to codify a structure for how decisions are made and who makes the final decision. This proposal builds on LP0001 LLVM Decision Making, and assumes that contentious decisions go through that process. For the context of this proposal and [LP0001], a contentious decision is one where general agreement is not reached through discussion on Discourse. At the core of this proposal is the adoption of consensus-seeking decision making rather than formal consensus methods, and a recognition that decision making isn’t always binary.

Historically LLVM has relied on an informal RFC process. Our project documentation mentions the use of RFCs, but there is no documentation on the RFC process. This can lead to specific ambiguity about community acceptance of RFCs.

A core component of this proposal is a shift to encourage more proposals to use the process defined in LP0001 LLVM Decision Making. This proposal suggests the following guidance for when to use the Proposal process:

Finally, area teams are responsible for facilitating decision making for their area of the project. Facilitating decision making can take any number of forms ranging from contributing to RFC discussions, helping mediate disagreements, or fulfilling roles originally delegated to Chris Lattner in the LLVM Decision Making process.

Area teams should prepare a meeting agenda by collecting all the active RFCs in the community or significant disagreements in pull requests. During the team meeting, the area team should try to identify actionable next steps or information to gather so the RFC or pull request can proceed. An area team may escalate to the project council as needed.

The last quote does mention PRs specifically, but the general framework seems to be built around RFCs and discussion on discourse. In this specific instance, I think the discussion would strongly benefit from a more comprehensive proposal instead of continued back-and-forth that hasn't resulted in any consensus.

@ftynse
Member

ftynse commented Dec 12, 2025

[...] At the core of this proposal is the adoption of consensus-seeking decision making rather than formal consensus methods, and a recognition that decision making isn’t always binary.

(emphasis mine) I'd rather we focus on working towards consensus in accordance with the spirit of these documents, not the procedural minutiae. But we can also discuss that during the call.

Which is now scheduled for Tuesday, Dec 16, 15:30 CET / 2:30pm GMT / 9:30am ET / 6:30am PST, meeting link meet.google.com/xyp-khwh-scr. Let's strive for efficiency and refresh our understanding of the thread before the call so we don't waste time during it.

@ftynse
Member

ftynse commented Dec 12, 2025

To clarify, after holding another discussion with the area team: we are explicitly refraining from making any decision based on the discussion within the area team only, without talking to all parties, this includes mandating an RFC. The outcome of the call may well be that an RFC is required, but it may become clearer by whom (there is disagreement between vector maintainers) and with what scope.

@ftynse
Member

ftynse commented Dec 16, 2025

A brief summary of the call:

  • There is consensus that patterns rewriting extract and broadcast to shape_cast belong to canonicalization. These can be factored out into a new PR and merged following the usual procedure.
  • There is no consensus on transpose to shape_cast rewriting.
    • A specific request is to show why it is undesirable to canonicalize in the other direction, that is, rewrite shape_cast into transpose. Alternatively, show where having shape_cast instead of transpose is beneficial.
    • The pattern itself may make sense as part of converting vector to LLVM dialect so it doesn't just sit there.
  • We will follow up separately on a collaborative definition of what makes something a canonicalization. Specifically, it was raised that, from systems perspective, it may not be desirable to have rewrites between two equivalent forms without deeming one of those directions canonical.

@banach-space
Contributor

Thank you for organising and facilitating this, @ftynse! And thanks to everyone who joined - I found the discussion very constructive. It’s great to see that we’ve identified clear, actionable next steps 🙏

@newling
Contributor Author

newling commented Dec 17, 2025

Thanks @ftynse, and everyone who attended the call, I'm sorry I couldn't make it. I'm happy to create a new PR with 2 of the 3 patterns in this PR, will do so in the coming weeks.

Alternatively, show where having shape_cast instead of transpose is beneficial.

One small advantage is that we get the following simplification for free:

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%2 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>

@matthias-springer
Member

OK, so to summarize, there are cases where transpose => shape_cast is beneficial and cases where shape_cast => transpose is beneficial.

@newling Are there other cases where transpose => shape_cast canonicalization is beneficial? You can write down the same example with transpose / shape_cast inverted and come to the opposite conclusion.

@joker-eph
Collaborator

You can write down the same example with transpose / shape_cast inverted and come to the opposite conclusion.

Are you saying that there are sequences of transposes that can't be all shape_casts that would fold themselves?

@matthias-springer
Member

No. What I meant is, you can make this argument, which is very similar to @newling's but argues for the opposite canonicalization:

%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%1 = vector.shape_cast %0 : vector<4x1xf32> to vector<1x4xf32>
%2 = vector.transpose %1, [1, 0] : vector<1x4xf32> to vector<4x1xf32>

==>

%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
%2 = vector.transpose %1, [1, 0] : vector<1x4xf32> to vector<4x1xf32>

==>

%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>

@Groverkss
Member

Thanks @ftynse, and everyone who attended the call, I'm sorry I couldn't make it. I'm happy to create a new PR with 2 of the 3 patterns in this PR, will do so in the coming weeks.

Alternatively, show where having shape_cast instead of transpose is beneficial.

One small advantage is that we get the following simplification for free:

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

====>

%2 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>

This is not a benefit of having shape_cast here. This is a special case of a more general canonicalization pattern that can be written on shape_cast(transpose).

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>

You can canonicalize shape_cast(transpose) -> transpose(shape_cast) if it works on a subset which doesn't get affected by the shape_cast:

shape_cast : 1x2x2 -> 1x4
transpose: 1x4 -> 4x1
shape_cast: 4x1 -> 4

apply shape_cast(transpose) -> transpose(shape_cast)

shape_cast: 1x2x2 -> 1x4
shape_cast: 1x4 -> 4
transpose: 4 -> 4 // no-op

and you get the same result later:

shape_cast: 1x2x2 -> 4
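
Written out as ops (a sketch; the remaining rank-1 transpose is the identity and folds away):

%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4xf32>
%2 = vector.transpose %1, [0] : vector<4xf32> to vector<4xf32>

which then folds to a single vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>.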

It also works for more interesting cases:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

canonicalizes to:

shape_cast: 1x2x2x2x3 -> 4x2x3
transpose: 4x2x3 -> 4x3x2

This actually shows a good reason not to focus on these special-cased patterns, as they just hide more general patterns.

For shape_cast, you just have to choose a direction to go (either always up or always down, always expand or always collapse), and then you can write patterns like this.

@Groverkss
Member

Groverkss commented Dec 17, 2025

Not only this, this pattern in that form has a bigger problem: it doesn't canonicalize to a single form, depending on what order the patterns are run in. Let's say we implement the transpose -> shape_cast pattern for unit dimensions:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 1x4x3x2
transpose: 1x4x3x2 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

if the transpose -> shape_cast unit dim canonicalization runs first:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 1x4x3x2
shape_cast: 1x4x3x2 -> 4x3x2

You are stuck here now, unless we have the more general pattern i talked about before.

If you canonicalize transpose(transpose) first:

shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2

and you are stuck here. (note this is a different form from the above ordering).

This is also why I was asking for concrete examples of why this is a good form, because without examples, it's hard to understand things better.

@newling
Contributor Author

newling commented Dec 17, 2025

That's a good example of 'getting stuck', @Groverkss. I agree that there are arguments in all directions (no canonicalization vs canonicalize to transpose vs canonicalize to shape_cast). A more advanced canonicalization with multiple 'bubbling' stages (try to get the order shape_cast < broadcast < transpose, then try to get transpose < shape_cast < broadcast, etc.) would result in some nice folding opportunities.

newling added a commit that referenced this pull request Jan 22, 2026
…ctor.shape_cast (#174452)

Based on the original PR
#140583, but without
vector.transpose -> vector.shape_cast.

This PR canonicalizes 

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see
#140583) that the
vector.transpose -> vector.shape_cast needs further consideration before
being added.

---------

Signed-off-by: James Newling <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 22, 2026
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 23, 2026
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 24, 2026
HugoSilvaSantos pushed a commit to HugoSilvaSantos/arm-toolchain that referenced this pull request Jan 27, 2026