@@ -551,10 +551,9 @@ enum class Conv1DOpOrder {
551551 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552552};
553553
554- // / Helper data structure to represent the result of vectorization for a single
555- // / operation. In certain specific cases, like terminators, we do not want to
556- // / propagate.
557- enum VectorizationHookStatus {
554+ // / Helper data structure to represent the result of vectorization.
555+ // / In certain specific cases, like terminators, we do not want to propagate/
556+ enum VectorizationStatus {
558557 // / Op failed to vectorize.
559558 Failure = 0 ,
560559 // / Op vectorized and custom function took care of replacement logic
@@ -565,12 +564,9 @@ enum VectorizationHookStatus {
565564 // TODO: support values if Op vectorized to Many-Ops whose results we need to
566565 // aggregate for replacement.
567566};
568- // / VectorizationHookResult contains the vectorized op returned from a
569- // / CustomVectorizationHook. This is an internal implementation detail of
570- // / linalg vectorization, not to be confused with VectorizationResult.
571- struct VectorizationHookResult {
567+ struct VectorizationResult {
572568 // / Return status from vectorizing the current op.
573- enum VectorizationHookStatus status = VectorizationHookStatus ::Failure;
569+ enum VectorizationStatus status = VectorizationStatus ::Failure;
574570 // / New vectorized operation to replace the current op.
575571 // / Replacement behavior is specified by `status`.
576572 Operation *newOp;
@@ -732,22 +728,22 @@ using CustomVectorizationPrecondition =
732728// assuming all its vectorized operands are already in the IRMapping.
733729// Return nullptr if the Operation cannot be vectorized.
734730using CustomVectorizationHook =
735- std::function<VectorizationHookResult (Operation *, const IRMapping &)>;
731+ std::function<VectorizationResult (Operation *, const IRMapping &)>;
736732
737733// / Helper function to vectorize the terminator of a `linalgOp`. New result
738734// / vector values are appended to `newResults`. Return
739- // / VectorizationHookStatus ::NoReplace to signal the vectorization algorithm
740- // / that it should not try to map produced operations and instead return the
741- // / results using the `newResults` vector making them available to the
742- // / vectorization algorithm for RAUW. This function is meant to be used as a
735+ // / VectorizationStatus ::NoReplace to signal the vectorization algorithm that it
736+ // / should not try to map produced operations and instead return the results
737+ // / using the `newResults` vector making them available to the vectorization
738+ // / algorithm for RAUW. This function is meant to be used as a
743739// / CustomVectorizationHook.
744- static VectorizationHookResult
740+ static VectorizationResult
745741vectorizeLinalgYield (RewriterBase &rewriter, Operation *op,
746742 const IRMapping &bvm, VectorizationState &state,
747743 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
748744 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
749745 if (!yieldOp)
750- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
746+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
751747 for (const auto &output : llvm::enumerate (yieldOp.getValues ())) {
752748 // TODO: Scan for an opportunity for reuse.
753749 // TODO: use a map.
@@ -759,20 +755,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
759755 newResults.push_back (newResult);
760756 }
761757
762- return VectorizationHookResult{VectorizationHookStatus ::NoReplace, nullptr };
758+ return VectorizationResult{VectorizationStatus ::NoReplace, nullptr };
763759}
764760
765761// / Helper function to vectorize the index operations of a `linalgOp`. Return
766- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
762+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
767763// / should map the produced operations. This function is meant to be used as a
768764// / CustomVectorizationHook.
769- static VectorizationHookResult vectorizeLinalgIndex (RewriterBase &rewriter,
770- VectorizationState &state,
771- Operation *op,
772- LinalgOp linalgOp) {
765+ static VectorizationResult vectorizeLinalgIndex (RewriterBase &rewriter,
766+ VectorizationState &state,
767+ Operation *op,
768+ LinalgOp linalgOp) {
773769 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
774770 if (!indexOp)
775- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
771+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
776772 auto loc = indexOp.getLoc ();
777773 // Compute the static loop sizes of the index op.
778774 ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
@@ -786,7 +782,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
786782 // dimension of the iteration space since the vectorization algorithm in this
787783 // case can handle the broadcast.
788784 if (dim == targetShape.size () - 1 )
789- return VectorizationHookResult{VectorizationHookStatus ::NewOp, indexSteps};
785+ return VectorizationResult{VectorizationStatus ::NewOp, indexSteps};
790786 // Otherwise permute the targetShape to move the index dimension last,
791787 // broadcast the one-dimensional index vector to the permuted shape, and
792788 // finally transpose the broadcasted index vector to undo the permutation.
@@ -804,7 +800,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
804800 std::swap (transposition.back (), transposition[dim]);
805801 auto transposeOp =
806802 rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
807- return VectorizationHookResult{VectorizationHookStatus ::NewOp, transposeOp};
803+ return VectorizationResult{VectorizationStatus ::NewOp, transposeOp};
808804}
809805
810806// / Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,15 +1096,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
11001096}
11011097
11021098// / Helper function to vectorize the tensor.extract operations. Returns
1103- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
1099+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
11041100// / should map the produced operations. This function is meant to be used as a
11051101// / CustomVectorizationHook.
1106- static VectorizationHookResult
1102+ static VectorizationResult
11071103vectorizeTensorExtract (RewriterBase &rewriter, VectorizationState &state,
11081104 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11091105 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11101106 if (!extractOp)
1111- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1107+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
11121108 auto loc = extractOp.getLoc ();
11131109
11141110 // Compute the static loop sizes of the extract op.
@@ -1140,7 +1136,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11401136 gatherOp = state.maskOperation (rewriter, gatherOp, linalgOp);
11411137
11421138 LDBG (" Vectorised as gather load: " << extractOp << " \n " );
1143- return VectorizationHookResult{VectorizationHookStatus ::NewOp, gatherOp};
1139+ return VectorizationResult{VectorizationStatus ::NewOp, gatherOp};
11441140 }
11451141
11461142 // 2. Handle:
@@ -1204,8 +1200,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12041200 mlir::vector::maskOperation (rewriter, transferReadOp, allTrue);
12051201
12061202 LDBG (" Vectorised as scalar broadcast load: " << extractOp << " \n " );
1207- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1208- maskedReadOp};
1203+ return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
12091204 }
12101205
12111206 // 2b. Handle contiguous access.
@@ -1231,8 +1226,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12311226 inBounds);
12321227
12331228 LDBG (" Vectorised as contiguous load: " << extractOp);
1234- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1235- transferReadOp};
1229+ return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
12361230}
12371231
12381232// / Emit reduction operations if the shapes of the value to reduce is different
@@ -1272,9 +1266,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12721266// / This function assumes all operands of `op` have been vectorized and are in
12731267// / the `bvm` mapping. As a consequence, this function is meant to be called on
12741268// / a topologically-sorted list of ops.
1275- // / This function does not update `bvm` but returns a VectorizationHookStatus
1276- // / that instructs the caller what `bvm` update needs to occur.
1277- static VectorizationHookResult
1269+ // / This function does not update `bvm` but returns a VectorizationStatus that
1270+ // / instructs the caller what `bvm` update needs to occur.
1271+ static VectorizationResult
12781272vectorizeOneOp (RewriterBase &rewriter, VectorizationState &state,
12791273 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12801274 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1283,8 +1277,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12831277 // 1. Try to apply any CustomVectorizationHook.
12841278 if (!customVectorizationHooks.empty ()) {
12851279 for (auto &customFunc : customVectorizationHooks) {
1286- VectorizationHookResult result = customFunc (op, bvm);
1287- if (result.status == VectorizationHookStatus ::Failure)
1280+ VectorizationResult result = customFunc (op, bvm);
1281+ if (result.status == VectorizationStatus ::Failure)
12881282 continue ;
12891283 return result;
12901284 }
@@ -1293,12 +1287,11 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12931287 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
12941288 // Clone so that the constant is not confined to the linalgOp block .
12951289 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1296- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1297- rewriter.clone (*op)};
1290+ return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone (*op)};
12981291
12991292 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
13001293 if (!OpTrait::hasElementwiseMappableTraits (op))
1301- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1294+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
13021295
13031296 // 4 . Check if the operation is a reduction.
13041297 SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1321,7 +1314,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13211314 reduceIfNeeded (rewriter, linalgOp, op, reductionOperands[0 ].first ,
13221315 reductionOperands[0 ].second , bvm);
13231316 if (reduceOp)
1324- return VectorizationHookResult{VectorizationHookStatus ::NewOp, reduceOp};
1317+ return VectorizationResult{VectorizationStatus ::NewOp, reduceOp};
13251318 }
13261319
13271320 // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1361,8 +1354,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13611354 : resultType);
13621355 }
13631356 // d. Build and return the new op.
1364- return VectorizationHookResult {
1365- VectorizationHookStatus ::NewOp,
1357+ return VectorizationResult {
1358+ VectorizationStatus ::NewOp,
13661359 rewriter.create (op->getLoc (), op->getName ().getIdentifier (), vecOperands,
13671360 resultTypes, op->getAttrs ())};
13681361}
@@ -1466,34 +1459,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14661459 SmallVector<CustomVectorizationHook> hooks;
14671460 // 4a. Register CustomVectorizationHook for yieldOp.
14681461 CustomVectorizationHook vectorizeYield =
1469- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1462+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14701463 return vectorizeLinalgYield (rewriter, op, bvm, state, linalgOp, newResults);
14711464 };
14721465 hooks.push_back (vectorizeYield);
14731466
14741467 // 4b. Register CustomVectorizationHook for indexOp.
14751468 CustomVectorizationHook vectorizeIndex =
1476- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1469+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14771470 return vectorizeLinalgIndex (rewriter, state, op, linalgOp);
14781471 };
14791472 hooks.push_back (vectorizeIndex);
14801473
14811474 // 4c. Register CustomVectorizationHook for extractOp.
14821475 CustomVectorizationHook vectorizeExtract =
1483- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1476+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14841477 return vectorizeTensorExtract (rewriter, state, op, linalgOp, bvm);
14851478 };
14861479 hooks.push_back (vectorizeExtract);
14871480
14881481 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14891482 for (Operation &op : block->getOperations ()) {
1490- VectorizationHookResult result =
1483+ VectorizationResult result =
14911484 vectorizeOneOp (rewriter, state, linalgOp, &op, bvm, hooks);
1492- if (result.status == VectorizationHookStatus ::Failure) {
1485+ if (result.status == VectorizationStatus ::Failure) {
14931486 LDBG (" failed to vectorize: " << op << " \n " );
14941487 return failure ();
14951488 }
1496- if (result.status == VectorizationHookStatus ::NewOp) {
1489+ if (result.status == VectorizationStatus ::NewOp) {
14971490 Operation *maybeMaskedOp =
14981491 state.maskOperation (rewriter, result.newOp , linalgOp);
14991492 LDBG (" New vector op: " << *maybeMaskedOp << " \n " );
@@ -2530,11 +2523,17 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25302523 tensor::InsertSliceOp>(op);
25312524}
25322525
2533- FailureOr<VectorizationResult>
2534- mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2535- ArrayRef<int64_t > inputVectorSizes,
2536- ArrayRef<bool > inputScalableVecDims,
2537- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2526+ // / Emit a suitable vector form for an operation. If provided,
2527+ // / `inputVectorSizes` are used to vectorize this operation.
2528+ // / `inputVectorSizes` must match the rank of the iteration space of the
2529+ // / operation and the input vector sizes must be greater than or equal to
2530+ // / their counterpart iteration space sizes, if static. `inputVectorShapes`
2531+ // / also allows the vectorization of operations with dynamic shapes.
2532+ LogicalResult mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2533+ ArrayRef<int64_t > inputVectorSizes,
2534+ ArrayRef<bool > inputScalableVecDims,
2535+ bool vectorizeNDExtract,
2536+ bool flatten1DDepthwiseConv) {
25382537 LDBG (" Attempting to vectorize:\n " << *op << " \n " );
25392538 LDBG (" Input vector sizes: " );
25402539 LLVM_DEBUG (llvm::interleaveComma (inputVectorSizes, llvm::dbgs ()));
@@ -2616,7 +2615,12 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26162615 return failure ();
26172616 }
26182617
2619- return VectorizationResult{results};
2618+ if (!results.empty ())
2619+ rewriter.replaceOp (op, results);
2620+ else
2621+ rewriter.eraseOp (op);
2622+
2623+ return success ();
26202624}
26212625
26222626LogicalResult mlir::linalg::vectorizeCopy (RewriterBase &rewriter,
0 commit comments