@@ -551,10 +551,9 @@ enum class Conv1DOpOrder {
551551 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552552};
553553
554- // / Helper data structure to represent the result of vectorization for a single
555- // / operation. In certain specific cases, like terminators, we do not want to
556- // / propagate.
557- enum VectorizationHookStatus {
554+ // / Helper data structure to represent the result of vectorization.
555+ // / In certain specific cases, like terminators, we do not want to propagate/
556+ enum VectorizationStatus {
558557 // / Op failed to vectorize.
559558 Failure = 0 ,
560559 // / Op vectorized and custom function took care of replacement logic
@@ -565,12 +564,9 @@ enum VectorizationHookStatus {
565564 // TODO: support values if Op vectorized to Many-Ops whose results we need to
566565 // aggregate for replacement.
567566};
568- // / VectorizationHookResult contains the vectorized op returned from a
569- // / CustomVectorizationHook. This is an internal implementation detail of
570- // / linalg vectorization, not to be confused with VectorizationResult.
571- struct VectorizationHookResult {
567+ struct VectorizationResult {
572568 // / Return status from vectorizing the current op.
573- enum VectorizationHookStatus status = VectorizationHookStatus ::Failure;
569+ enum VectorizationStatus status = VectorizationStatus ::Failure;
574570 // / New vectorized operation to replace the current op.
575571 // / Replacement behavior is specified by `status`.
576572 Operation *newOp;
@@ -732,22 +728,22 @@ using CustomVectorizationPrecondition =
732728// assuming all its vectorized operands are already in the IRMapping.
733729// Return nullptr if the Operation cannot be vectorized.
734730using CustomVectorizationHook =
735- std::function<VectorizationHookResult (Operation *, const IRMapping &)>;
731+ std::function<VectorizationResult (Operation *, const IRMapping &)>;
736732
737733// / Helper function to vectorize the terminator of a `linalgOp`. New result
738734// / vector values are appended to `newResults`. Return
739- // / VectorizationHookStatus ::NoReplace to signal the vectorization algorithm
740- // / that it should not try to map produced operations and instead return the
741- // / results using the `newResults` vector making them available to the
742- // / vectorization algorithm for RAUW. This function is meant to be used as a
735+ // / VectorizationStatus ::NoReplace to signal the vectorization algorithm that it
736+ // / should not try to map produced operations and instead return the results
737+ // / using the `newResults` vector making them available to the vectorization
738+ // / algorithm for RAUW. This function is meant to be used as a
743739// / CustomVectorizationHook.
744- static VectorizationHookResult
740+ static VectorizationResult
745741vectorizeLinalgYield (RewriterBase &rewriter, Operation *op,
746742 const IRMapping &bvm, VectorizationState &state,
747743 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
748744 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
749745 if (!yieldOp)
750- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
746+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
751747 for (const auto &output : llvm::enumerate (yieldOp.getValues ())) {
752748 // TODO: Scan for an opportunity for reuse.
753749 // TODO: use a map.
@@ -759,20 +755,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
759755 newResults.push_back (newResult);
760756 }
761757
762- return VectorizationHookResult{VectorizationHookStatus ::NoReplace, nullptr };
758+ return VectorizationResult{VectorizationStatus ::NoReplace, nullptr };
763759}
764760
765761// / Helper function to vectorize the index operations of a `linalgOp`. Return
766- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
762+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
767763// / should map the produced operations. This function is meant to be used as a
768764// / CustomVectorizationHook.
769- static VectorizationHookResult vectorizeLinalgIndex (RewriterBase &rewriter,
770- VectorizationState &state,
771- Operation *op,
772- LinalgOp linalgOp) {
765+ static VectorizationResult vectorizeLinalgIndex (RewriterBase &rewriter,
766+ VectorizationState &state,
767+ Operation *op,
768+ LinalgOp linalgOp) {
773769 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
774770 if (!indexOp)
775- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
771+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
776772 auto loc = indexOp.getLoc ();
777773 // Compute the static loop sizes of the index op.
778774 ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
@@ -786,7 +782,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
786782 // dimension of the iteration space since the vectorization algorithm in this
787783 // case can handle the broadcast.
788784 if (dim == targetShape.size () - 1 )
789- return VectorizationHookResult{VectorizationHookStatus ::NewOp, indexSteps};
785+ return VectorizationResult{VectorizationStatus ::NewOp, indexSteps};
790786 // Otherwise permute the targetShape to move the index dimension last,
791787 // broadcast the one-dimensional index vector to the permuted shape, and
792788 // finally transpose the broadcasted index vector to undo the permutation.
@@ -804,7 +800,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
804800 std::swap (transposition.back (), transposition[dim]);
805801 auto transposeOp =
806802 rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
807- return VectorizationHookResult{VectorizationHookStatus ::NewOp, transposeOp};
803+ return VectorizationResult{VectorizationStatus ::NewOp, transposeOp};
808804}
809805
810806// / Helper function to check if the tensor.extract can be vectorized by the
@@ -1102,15 +1098,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
11021098}
11031099
11041100// / Helper function to vectorize the tensor.extract operations. Returns
1105- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
1101+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
11061102// / should map the produced operations. This function is meant to be used as a
11071103// / CustomVectorizationHook.
1108- static VectorizationHookResult
1104+ static VectorizationResult
11091105vectorizeTensorExtract (RewriterBase &rewriter, VectorizationState &state,
11101106 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11111107 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11121108 if (!extractOp)
1113- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1109+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
11141110 auto loc = extractOp.getLoc ();
11151111
11161112 // Compute the static loop sizes of the extract op.
@@ -1142,7 +1138,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11421138 gatherOp = state.maskOperation (rewriter, gatherOp, linalgOp);
11431139
11441140 LDBG (" Vectorised as gather load: " << extractOp << " \n " );
1145- return VectorizationHookResult{VectorizationHookStatus ::NewOp, gatherOp};
1141+ return VectorizationResult{VectorizationStatus ::NewOp, gatherOp};
11461142 }
11471143
11481144 // 2. Handle:
@@ -1206,8 +1202,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12061202 mlir::vector::maskOperation (rewriter, transferReadOp, allTrue);
12071203
12081204 LDBG (" Vectorised as scalar broadcast load: " << extractOp << " \n " );
1209- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1210- maskedReadOp};
1205+ return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
12111206 }
12121207
12131208 // 2b. Handle contiguous access.
@@ -1233,8 +1228,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12331228 inBounds);
12341229
12351230 LDBG (" Vectorised as contiguous load: " << extractOp);
1236- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1237- transferReadOp};
1231+ return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
12381232}
12391233
12401234// / Emit reduction operations if the shapes of the value to reduce is different
@@ -1274,9 +1268,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12741268// / This function assumes all operands of `op` have been vectorized and are in
12751269// / the `bvm` mapping. As a consequence, this function is meant to be called on
12761270// / a topologically-sorted list of ops.
1277- // / This function does not update `bvm` but returns a VectorizationHookStatus
1278- // / that instructs the caller what `bvm` update needs to occur.
1279- static VectorizationHookResult
1271+ // / This function does not update `bvm` but returns a VectorizationStatus that
1272+ // / instructs the caller what `bvm` update needs to occur.
1273+ static VectorizationResult
12801274vectorizeOneOp (RewriterBase &rewriter, VectorizationState &state,
12811275 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12821276 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1285,8 +1279,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12851279 // 1. Try to apply any CustomVectorizationHook.
12861280 if (!customVectorizationHooks.empty ()) {
12871281 for (auto &customFunc : customVectorizationHooks) {
1288- VectorizationHookResult result = customFunc (op, bvm);
1289- if (result.status == VectorizationHookStatus ::Failure)
1282+ VectorizationResult result = customFunc (op, bvm);
1283+ if (result.status == VectorizationStatus ::Failure)
12901284 continue ;
12911285 return result;
12921286 }
@@ -1295,12 +1289,11 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12951289 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
12961290 // Clone so that the constant is not confined to the linalgOp block .
12971291 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1298- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1299- rewriter.clone (*op)};
1292+ return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone (*op)};
13001293
13011294 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
13021295 if (!OpTrait::hasElementwiseMappableTraits (op))
1303- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1296+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
13041297
13051298 // 4 . Check if the operation is a reduction.
13061299 SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1323,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13231316 reduceIfNeeded (rewriter, linalgOp, op, reductionOperands[0 ].first ,
13241317 reductionOperands[0 ].second , bvm);
13251318 if (reduceOp)
1326- return VectorizationHookResult{VectorizationHookStatus ::NewOp, reduceOp};
1319+ return VectorizationResult{VectorizationStatus ::NewOp, reduceOp};
13271320 }
13281321
13291322 // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1363,8 +1356,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13631356 : resultType);
13641357 }
13651358 // d. Build and return the new op.
1366- return VectorizationHookResult {
1367- VectorizationHookStatus ::NewOp,
1359+ return VectorizationResult {
1360+ VectorizationStatus ::NewOp,
13681361 rewriter.create (op->getLoc (), op->getName ().getIdentifier (), vecOperands,
13691362 resultTypes, op->getAttrs ())};
13701363}
@@ -1468,34 +1461,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14681461 SmallVector<CustomVectorizationHook> hooks;
14691462 // 4a. Register CustomVectorizationHook for yieldOp.
14701463 CustomVectorizationHook vectorizeYield =
1471- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1464+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14721465 return vectorizeLinalgYield (rewriter, op, bvm, state, linalgOp, newResults);
14731466 };
14741467 hooks.push_back (vectorizeYield);
14751468
14761469 // 4b. Register CustomVectorizationHook for indexOp.
14771470 CustomVectorizationHook vectorizeIndex =
1478- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1471+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14791472 return vectorizeLinalgIndex (rewriter, state, op, linalgOp);
14801473 };
14811474 hooks.push_back (vectorizeIndex);
14821475
14831476 // 4c. Register CustomVectorizationHook for extractOp.
14841477 CustomVectorizationHook vectorizeExtract =
1485- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1478+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14861479 return vectorizeTensorExtract (rewriter, state, op, linalgOp, bvm);
14871480 };
14881481 hooks.push_back (vectorizeExtract);
14891482
14901483 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14911484 for (Operation &op : block->getOperations ()) {
1492- VectorizationHookResult result =
1485+ VectorizationResult result =
14931486 vectorizeOneOp (rewriter, state, linalgOp, &op, bvm, hooks);
1494- if (result.status == VectorizationHookStatus ::Failure) {
1487+ if (result.status == VectorizationStatus ::Failure) {
14951488 LDBG (" failed to vectorize: " << op << " \n " );
14961489 return failure ();
14971490 }
1498- if (result.status == VectorizationHookStatus ::NewOp) {
1491+ if (result.status == VectorizationStatus ::NewOp) {
14991492 Operation *maybeMaskedOp =
15001493 state.maskOperation (rewriter, result.newOp , linalgOp);
15011494 LDBG (" New vector op: " << *maybeMaskedOp << " \n " );
@@ -2532,11 +2525,17 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25322525 tensor::InsertSliceOp>(op);
25332526}
25342527
2535- FailureOr<VectorizationResult>
2536- mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2537- ArrayRef<int64_t > inputVectorSizes,
2538- ArrayRef<bool > inputScalableVecDims,
2539- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2528+ // / Emit a suitable vector form for an operation. If provided,
2529+ // / `inputVectorSizes` are used to vectorize this operation.
2530+ // / `inputVectorSizes` must match the rank of the iteration space of the
2531+ // / operation and the input vector sizes must be greater than or equal to
2532+ // / their counterpart iteration space sizes, if static. `inputVectorShapes`
2533+ // / also allows the vectorization of operations with dynamic shapes.
2534+ LogicalResult mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2535+ ArrayRef<int64_t > inputVectorSizes,
2536+ ArrayRef<bool > inputScalableVecDims,
2537+ bool vectorizeNDExtract,
2538+ bool flatten1DDepthwiseConv) {
25402539 LDBG (" Attempting to vectorize:\n " << *op << " \n " );
25412540 LDBG (" Input vector sizes: " );
25422541 LLVM_DEBUG (llvm::interleaveComma (inputVectorSizes, llvm::dbgs ()));
@@ -2618,7 +2617,12 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26182617 return failure ();
26192618 }
26202619
2621- return VectorizationResult{results};
2620+ if (!results.empty ())
2621+ rewriter.replaceOp (op, results);
2622+ else
2623+ rewriter.eraseOp (op);
2624+
2625+ return success ();
26222626}
26232627
26242628LogicalResult mlir::linalg::vectorizeCopy (RewriterBase &rewriter,
0 commit comments