Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented Jan 5, 2026

Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.

This PR canonicalizes

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see #140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added.

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2026

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.

This PR canonicalizes

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see #140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added.


Patch is 22.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174452.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+72-24)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+11-11)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+1-1)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir (+130)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+3-5)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..27dd4e2e034fc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract to vector with fewer elements");
+
+    // Negative values in `position` means that the extacted value is poison.
+    // There is a vector.extract folder for this.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "leaving for extract poison folder");
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -3076,13 +3113,43 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// Replace `vector.broadcast` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to a greater number of elements");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6479,10 +6546,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using Base::Base;
@@ -6497,22 +6561,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..fdc8e487ff6bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // CHECK-LABEL: negative_fold_extract_broadcast
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // CHECK-LABEL: fold_extract_broadcastlike_shape_cast
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
-//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
 //       CHECK:   return %[[R]] : vector<1x1xf32>
 func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
   -> vector<1x1xf32> {
@@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
-// Check the case where the same dimension is both broadcasted and sliced 
+// Check the case where the same dimension is both broadcasted and sliced
 // CHECK-LABEL: func @extract_strided_broadcast5
 //  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
 //       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
 //  CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
 // CHECK-NOT: vector.transfer_write
 // CHECK-NOT: vector.transfer_read
-// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
 // CHECK: return %[[RET]]
 func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
     %vec: vector<4x8xf32>,
@@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
 
 // CHECK-LABEL: func @insert_extract_to_broadcast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
 func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
@@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
 
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
+  //  CHECK-NEXT:   %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32>
   //  CHECK-NEXT:   return %0 : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
@@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 // +---------------------------------------------------------------------------
 
 // CHECK-LABEL: transpose_from_elements_1d
-// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
 func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
   %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
@@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
 }
 
 // CHECK-LABEL: transpose_from_elements_2d
-// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
 func.func @transpose_from_elements_2d(
   %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
   %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
@@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d(
 }
 
 // CHECK-LABEL: transpose_from_elements_3d
-// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
 func.func @transpose_from_elements_3d(
   %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
   %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index f43328f621787..aa6539e466c95 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+//       CHECK:       %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
 //       CHECK:       return %[[EXTRACT]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
new file mode 100644
index 0000000000000..86fc6a5c0d6fc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize |  FileCheck %s
+
+// This file contains tests where a vector.shape_cast is the result
+// of canonicalization.
+
+// **--------------------------------------------------------** //
+//   Tests of BroadcastToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+  return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+    return %1 : vector<1x1xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------** //
+//   Tests of ExtractToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. We could
+// convert this to shape_cast (would be a legal transform with poison)
+// but we conservatively choose not to.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+//  CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+//       CHECK:   %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+//       CHECK:   return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+  %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
+  return %r : vector<12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shape_cast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+  %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
+  %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
+  return %1: vector<1xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..8206c1a3ec865 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK:   %[[S:.*]] = vector.transfer_read %[[SRC]][]
-//  CHECK:   %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+//  CHECK:   %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
     %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
       tensor<f32>, vector<1xf32>
 
@@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
   // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
@@ -385,8 +384,7 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2026

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Based on the original PR #140583, but without vector.transpose -> vector.shape_cast.

This PR canonicalizes

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see #140583) that the vector.transpose -> vector.shape_cast needs further consideration before being added.


Patch is 22.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174452.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+72-24)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+11-11)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+1-1)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir (+130)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+3-5)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 12bdc9646ee84..27dd4e2e034fc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2359,11 +2359,48 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// Replace `vector.extract` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract to vector with fewer elements");
+
+    // Negative values in `position` means that the extacted value is poison.
+    // There is a vector.extract folder for this.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "leaving for extract poison folder");
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -3076,13 +3113,43 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// Replace `vector.broadcast` with `vector.shape_cast`.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// The canonical form of vector operations that reshape vectors is shape_cast.
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast to a greater number of elements");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6479,10 +6546,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using Base::Base;
@@ -6497,22 +6561,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..fdc8e487ff6bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -814,7 +814,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // CHECK-LABEL: negative_fold_extract_broadcast
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -931,7 +931,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // CHECK-LABEL: fold_extract_broadcastlike_shape_cast
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
-//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
 //       CHECK:   return %[[R]] : vector<1x1xf32>
 func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
   -> vector<1x1xf32> {
@@ -1561,7 +1561,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
-// Check the case where the same dimension is both broadcasted and sliced 
+// Check the case where the same dimension is both broadcasted and sliced
 // CHECK-LABEL: func @extract_strided_broadcast5
 //  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
 //       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -1918,7 +1918,7 @@ func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
 //  CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
 // CHECK-NOT: vector.transfer_write
 // CHECK-NOT: vector.transfer_read
-// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: %[[RET:.+]] = vector.shape_cast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
 // CHECK: return %[[RET]]
 func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
     %vec: vector<4x8xf32>,
@@ -2197,8 +2197,8 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
 
 // CHECK-LABEL: func @insert_extract_to_broadcast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
 func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
@@ -2565,7 +2565,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -3047,7 +3047,7 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
 
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
+  //  CHECK-NEXT:   %0 = vector.shape_cast {{.*}} : vector<1x1x1xf32> to vector<1xf32>
   //  CHECK-NEXT:   return %0 : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
@@ -3554,7 +3554,7 @@ func.func @from_elements_index_to_i64_conversion() -> vector<3xi64> {
 // +---------------------------------------------------------------------------
 
 // CHECK-LABEL: transpose_from_elements_1d
-// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0:.*]]: i32, %[[EL_1:.*]]: i32
 func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
   %v = vector.from_elements %el_0, %el_1 : vector<2xi32>
   %t = vector.transpose %v, [0] : vector<2xi32> to vector<2xi32>
@@ -3565,7 +3565,7 @@ func.func @transpose_from_elements_1d(%el_0: i32, %el_1: i32) -> vector<2xi32> {
 }
 
 // CHECK-LABEL: transpose_from_elements_2d
-// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0:.*]]: i32, %[[EL_0_1:.*]]: i32, %[[EL_0_2:.*]]: i32, %[[EL_1_0:.*]]: i32, %[[EL_1_1:.*]]: i32, %[[EL_1_2:.*]]: i32
 func.func @transpose_from_elements_2d(
   %el_0_0: i32, %el_0_1: i32, %el_0_2: i32,
   %el_1_0: i32, %el_1_1: i32, %el_1_2: i32
@@ -3579,7 +3579,7 @@ func.func @transpose_from_elements_2d(
 }
 
 // CHECK-LABEL: transpose_from_elements_3d
-// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32 
+// CHECK-SAME:  %[[EL_0_0_0:.*]]: i32, %[[EL_0_0_1:.*]]: i32, %[[EL_0_1_0:.*]]: i32, %[[EL_0_1_1:.*]]: i32, %[[EL_0_2_0:.*]]: i32, %[[EL_0_2_1:.*]]: i32, %[[EL_1_0_0:.*]]: i32, %[[EL_1_0_1:.*]]: i32, %[[EL_1_1_0:.*]]: i32, %[[EL_1_1_1:.*]]: i32, %[[EL_1_2_0:.*]]: i32, %[[EL_1_2_1:.*]]: i32
 func.func @transpose_from_elements_3d(
   %el_0_0_0: i32, %el_0_0_1: i32, %el_0_1_0: i32, %el_0_1_1: i32, %el_0_2_0: i32, %el_0_2_1: i32,
   %el_1_0_0: i32, %el_1_0_1: i32, %el_1_1_0: i32, %el_1_1_1: i32, %el_1_2_0: i32, %el_1_2_1: i32
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index f43328f621787..aa6539e466c95 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,7 +81,7 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
+//       CHECK:       %[[EXTRACT:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
 //       CHECK:       return %[[EXTRACT]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
new file mode 100644
index 0000000000000..86fc6a5c0d6fc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize |  FileCheck %s
+
+// This file contains tests where a vector.shape_cast is the result
+// of canonicalization.
+
+// **--------------------------------------------------------** //
+//   Tests of BroadcastToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// broadcast can only be transformed to a shape_cast if the number of elements is
+// unchanged by the broadcast.
+// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar
+// cannot be transformed to a shape_cast.
+// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return
+func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
+  %0 = vector.broadcast %arg0 : i8 to vector<1xi8>
+  return %0 : vector<1xi8>
+}
+
+// -----
+
+// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
+  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
+  return %1 : vector<1x2x1xf32>
+}
+
+// -----
+
+// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
+func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
+    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
+    return %1 : vector<1x1xf32>
+}
+
+// -----
+
+// **--------------------------------------------------------** //
+//   Tests of ExtractToShapeCast
+// **--------------------------------------------------------** //
+
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. We could
+// convert this to shape_cast (would be a legal transform with poison)
+// but we conservatively choose not to.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+//  CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+//       CHECK:   %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+//       CHECK:   return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+  %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
+  return %r : vector<12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shape_cast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+  %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
+  %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
+  return %1: vector<1xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..8206c1a3ec865 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK:   %[[S:.*]] = vector.transfer_read %[[SRC]][]
-//  CHECK:   %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
+//  CHECK:   %[[V:.*]] = vector.shape_cast %[[S]] : vector<f32> to vector<1xf32>
     %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
       tensor<f32>, vector<1xf32>
 
@@ -369,8 +369,7 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
   // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
@@ -385,8 +384,7 ...
[truncated]

@github-actions
Copy link

github-actions bot commented Jan 5, 2026

🐧 Linux x64 Test Results

  • 7377 tests passed
  • 599 tests skipped

✅ The build succeeded and all tests passed.

@github-actions
Copy link

github-actions bot commented Jan 5, 2026

🪟 Windows x64 Test Results

  • 3425 tests passed
  • 412 tests skipped

✅ The build succeeded and all tests passed.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving based on my review of the original PR (#140583).

Given #140583 (comment), I don’t expect further objections. That said, please wait for at least one more +1, ideally from someone who raised objections in the previous discussion (CC @Groverkss, @kuhar, @MaheshRavishankar).

Many thanks for working on this!

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just some nits

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but I'd also like @Groverkss to take a look

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks!

@banach-space
Copy link
Contributor

LGTM but I'd also like @Groverkss to take a look

It's been over a week - I suggest that we land this, Kunwar can leave a post-commit review.

@newling
Copy link
Contributor Author

newling commented Jan 19, 2026

Just a heads up that we'll land this in 2 days if there are no objections (@Groverkss)

@newling newling merged commit 962a9a3 into llvm:main Jan 22, 2026
11 checks passed
amd-eochoalo added a commit to iree-org/iree that referenced this pull request Jan 22, 2026
Reverts carried forward:
* Local revert of llvm/llvm-project#169614 due
to #22649

Other changes:
* Fixes lit tests to account for
llvm/llvm-project#174452
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 23, 2026
…ctor.shape_cast (llvm#174452)

Based on the original PR
llvm#140583, but without
vector.transpose -> vector.shape_cast.

This PR canonicalizes 

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see
llvm#140583) that the
vector.transpose -> vector.shape_cast needs further consideration before
being added.

---------

Signed-off-by: James Newling <[email protected]>
Harrish92 pushed a commit to Harrish92/llvm-project that referenced this pull request Jan 24, 2026
…ctor.shape_cast (llvm#174452)

Based on the original PR
llvm#140583, but without
vector.transpose -> vector.shape_cast.

This PR canonicalizes 

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

to shape_cast. It was decided (see
llvm#140583) that the
vector.transpose -> vector.shape_cast needs further consideration before
being added.

---------

Signed-off-by: James Newling <[email protected]>
keshavvinayak01 pushed a commit to iree-org/iree that referenced this pull request Jan 27, 2026
Reverts carried forward:
* Local revert of llvm/llvm-project#169614 due
to #22649

Other changes:
* Fixes lit tests to account for
llvm/llvm-project#174452

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants