@@ -563,7 +563,7 @@ enum VectorizationStatus {
563563 // TODO: support values if Op vectorized to Many-Ops whose results we need to
564564 // aggregate for replacement.
565565};
566- struct VectorizationResult {
566+ struct VectorizationHookResult {
567567 // / Return status from vectorizing the current op.
568568 enum VectorizationStatus status = VectorizationStatus::Failure;
569569 // / New vectorized operation to replace the current op.
@@ -727,7 +727,7 @@ using CustomVectorizationPrecondition =
727727// assuming all its vectorized operands are already in the IRMapping.
728728// Return nullptr if the Operation cannot be vectorized.
729729using CustomVectorizationHook =
730- std::function<VectorizationResult (Operation *, const IRMapping &)>;
730+ std::function<VectorizationHookResult (Operation *, const IRMapping &)>;
731731
732732// / Helper function to vectorize the terminator of a `linalgOp`. New result
733733// / vector values are appended to `newResults`. Return
@@ -736,13 +736,13 @@ using CustomVectorizationHook =
736736// / using the `newResults` vector making them available to the vectorization
737737// / algorithm for RAUW. This function is meant to be used as a
738738// / CustomVectorizationHook.
739- static VectorizationResult
739+ static VectorizationHookResult
740740vectorizeLinalgYield (RewriterBase &rewriter, Operation *op,
741741 const IRMapping &bvm, VectorizationState &state,
742742 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
743743 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
744744 if (!yieldOp)
745- return VectorizationResult {VectorizationStatus::Failure, nullptr };
745+ return VectorizationHookResult {VectorizationStatus::Failure, nullptr };
746746 for (const auto &output : llvm::enumerate (yieldOp.getValues ())) {
747747 // TODO: Scan for an opportunity for reuse.
748748 // TODO: use a map.
@@ -754,20 +754,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
754754 newResults.push_back (newResult);
755755 }
756756
757- return VectorizationResult {VectorizationStatus::NoReplace, nullptr };
757+ return VectorizationHookResult {VectorizationStatus::NoReplace, nullptr };
758758}
759759
760760// / Helper function to vectorize the index operations of a `linalgOp`. Return
761761// / VectorizationStatus::NewOp to signal the vectorization algorithm that it
762762// / should map the produced operations. This function is meant to be used as a
763763// / CustomVectorizationHook.
764- static VectorizationResult vectorizeLinalgIndex (RewriterBase &rewriter,
765- VectorizationState &state,
766- Operation *op,
767- LinalgOp linalgOp) {
764+ static VectorizationHookResult vectorizeLinalgIndex (RewriterBase &rewriter,
765+ VectorizationState &state,
766+ Operation *op,
767+ LinalgOp linalgOp) {
768768 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
769769 if (!indexOp)
770- return VectorizationResult {VectorizationStatus::Failure, nullptr };
770+ return VectorizationHookResult {VectorizationStatus::Failure, nullptr };
771771 auto loc = indexOp.getLoc ();
772772 // Compute the static loop sizes of the index op.
773773 ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
@@ -781,7 +781,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
781781 // dimension of the iteration space since the vectorization algorithm in this
782782 // case can handle the broadcast.
783783 if (dim == targetShape.size () - 1 )
784- return VectorizationResult {VectorizationStatus::NewOp, indexSteps};
784+ return VectorizationHookResult {VectorizationStatus::NewOp, indexSteps};
785785 // Otherwise permute the targetShape to move the index dimension last,
786786 // broadcast the one-dimensional index vector to the permuted shape, and
787787 // finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +799,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
799799 std::swap (transposition.back (), transposition[dim]);
800800 auto transposeOp =
801801 rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
802- return VectorizationResult {VectorizationStatus::NewOp, transposeOp};
802+ return VectorizationHookResult {VectorizationStatus::NewOp, transposeOp};
803803}
804804
805805// / Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,12 +1100,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
11001100// / VectorizationStatus::NewOp to signal the vectorization algorithm that it
11011101// / should map the produced operations. This function is meant to be used as a
11021102// / CustomVectorizationHook.
1103- static VectorizationResult
1103+ static VectorizationHookResult
11041104vectorizeTensorExtract (RewriterBase &rewriter, VectorizationState &state,
11051105 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11061106 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11071107 if (!extractOp)
1108- return VectorizationResult {VectorizationStatus::Failure, nullptr };
1108+ return VectorizationHookResult {VectorizationStatus::Failure, nullptr };
11091109 auto loc = extractOp.getLoc ();
11101110
11111111 // Compute the static loop sizes of the extract op.
@@ -1137,7 +1137,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11371137 gatherOp = state.maskOperation (rewriter, gatherOp, linalgOp);
11381138
11391139 LDBG (" Vectorised as gather load: " << extractOp << " \n " );
1140- return VectorizationResult {VectorizationStatus::NewOp, gatherOp};
1140+ return VectorizationHookResult {VectorizationStatus::NewOp, gatherOp};
11411141 }
11421142
11431143 // 2. Handle:
@@ -1201,7 +1201,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12011201 mlir::vector::maskOperation (rewriter, transferReadOp, allTrue);
12021202
12031203 LDBG (" Vectorised as scalar broadcast load: " << extractOp << " \n " );
1204- return VectorizationResult {VectorizationStatus::NewOp, maskedReadOp};
1204+ return VectorizationHookResult {VectorizationStatus::NewOp, maskedReadOp};
12051205 }
12061206
12071207 // 2b. Handle contiguous access.
@@ -1227,7 +1227,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12271227 inBounds);
12281228
12291229 LDBG (" Vectorised as contiguous load: " << extractOp);
1230- return VectorizationResult {VectorizationStatus::NewOp, transferReadOp};
1230+ return VectorizationHookResult {VectorizationStatus::NewOp, transferReadOp};
12311231}
12321232
12331233// / Emit reduction operations if the shapes of the value to reduce is different
@@ -1269,7 +1269,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12691269// / a topologically-sorted list of ops.
12701270// / This function does not update `bvm` but returns a VectorizationStatus that
12711271// / instructs the caller what `bvm` update needs to occur.
1272- static VectorizationResult
1272+ static VectorizationHookResult
12731273vectorizeOneOp (RewriterBase &rewriter, VectorizationState &state,
12741274 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12751275 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1278,7 +1278,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12781278 // 1. Try to apply any CustomVectorizationHook.
12791279 if (!customVectorizationHooks.empty ()) {
12801280 for (auto &customFunc : customVectorizationHooks) {
1281- VectorizationResult result = customFunc (op, bvm);
1281+ VectorizationHookResult result = customFunc (op, bvm);
12821282 if (result.status == VectorizationStatus::Failure)
12831283 continue ;
12841284 return result;
@@ -1288,11 +1288,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12881288 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
12891289 // Clone so that the constant is not confined to the linalgOp block .
12901290 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1291- return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone (*op)};
1291+ return VectorizationHookResult{VectorizationStatus::NewOp,
1292+ rewriter.clone (*op)};
12921293
12931294 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
12941295 if (!OpTrait::hasElementwiseMappableTraits (op))
1295- return VectorizationResult {VectorizationStatus::Failure, nullptr };
1296+ return VectorizationHookResult {VectorizationStatus::Failure, nullptr };
12961297
12971298 // 4 . Check if the operation is a reduction.
12981299 SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1315,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13151316 reduceIfNeeded (rewriter, linalgOp, op, reductionOperands[0 ].first ,
13161317 reductionOperands[0 ].second , bvm);
13171318 if (reduceOp)
1318- return VectorizationResult {VectorizationStatus::NewOp, reduceOp};
1319+ return VectorizationHookResult {VectorizationStatus::NewOp, reduceOp};
13191320 }
13201321
13211322 // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1355,7 +1356,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13551356 : resultType);
13561357 }
13571358 // d. Build and return the new op.
1358- return VectorizationResult {
1359+ return VectorizationHookResult {
13591360 VectorizationStatus::NewOp,
13601361 rewriter.create (op->getLoc (), op->getName ().getIdentifier (), vecOperands,
13611362 resultTypes, op->getAttrs ())};
@@ -1460,28 +1461,28 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14601461 SmallVector<CustomVectorizationHook> hooks;
14611462 // 4a. Register CustomVectorizationHook for yieldOp.
14621463 CustomVectorizationHook vectorizeYield =
1463- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1464+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14641465 return vectorizeLinalgYield (rewriter, op, bvm, state, linalgOp, newResults);
14651466 };
14661467 hooks.push_back (vectorizeYield);
14671468
14681469 // 4b. Register CustomVectorizationHook for indexOp.
14691470 CustomVectorizationHook vectorizeIndex =
1470- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1471+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14711472 return vectorizeLinalgIndex (rewriter, state, op, linalgOp);
14721473 };
14731474 hooks.push_back (vectorizeIndex);
14741475
14751476 // 4c. Register CustomVectorizationHook for extractOp.
14761477 CustomVectorizationHook vectorizeExtract =
1477- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1478+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14781479 return vectorizeTensorExtract (rewriter, state, op, linalgOp, bvm);
14791480 };
14801481 hooks.push_back (vectorizeExtract);
14811482
14821483 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14831484 for (Operation &op : block->getOperations ()) {
1484- VectorizationResult result =
1485+ VectorizationHookResult result =
14851486 vectorizeOneOp (rewriter, state, linalgOp, &op, bvm, hooks);
14861487 if (result.status == VectorizationStatus::Failure) {
14871488 LDBG (" failed to vectorize: " << op << " \n " );
@@ -2522,17 +2523,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25222523 tensor::InsertSliceOp>(op);
25232524}
25242525
2525- // / Emit a suitable vector form for an operation. If provided,
2526- // / `inputVectorSizes` are used to vectorize this operation.
2527- // / `inputVectorSizes` must match the rank of the iteration space of the
2528- // / operation and the input vector sizes must be greater than or equal to
2529- // / their counterpart iteration space sizes, if static. `inputVectorShapes`
2530- // / also allows the vectorization of operations with dynamic shapes.
2531- LogicalResult mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2532- ArrayRef<int64_t > inputVectorSizes,
2533- ArrayRef<bool > inputScalableVecDims,
2534- bool vectorizeNDExtract,
2535- bool flatten1DDepthwiseConv) {
2526+ FailureOr<VectorizationResult>
2527+ mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2528+ ArrayRef<int64_t > inputVectorSizes,
2529+ ArrayRef<bool > inputScalableVecDims,
2530+ bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
25362531 LDBG (" Attempting to vectorize:\n " << *op << " \n " );
25372532 LDBG (" Input vector sizes: " );
25382533 LLVM_DEBUG (llvm::interleaveComma (inputVectorSizes, llvm::dbgs ()));
@@ -2614,12 +2609,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26142609 return failure ();
26152610 }
26162611
2617- if (!results.empty ())
2618- rewriter.replaceOp (op, results);
2619- else
2620- rewriter.eraseOp (op);
2621-
2622- return success ();
2612+ return VectorizationResult ({results});
26232613}
26242614
26252615LogicalResult mlir::linalg::vectorizeCopy (RewriterBase &rewriter,
0 commit comments