diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 9db2a742a7d55..189438e9ad528 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -854,17 +854,23 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/); /// to work (these are checked by the vectorizer itself). bool hasVectorizationImpl(Operation *); -/// Emit a suitable vector form for an operation. If provided, -/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes` -/// must match the rank of the iteration space of the operation and the sizes -/// must be smaller or equal than their counterpart interation space sizes, if -/// static. `inputVectorShapes` also allows the vectorization of operations with -/// dynamic shapes. -LogicalResult vectorize(RewriterBase &rewriter, Operation *op, - ArrayRef<int64_t> inputVectorSizes = {}, - ArrayRef<bool> inputScalableVecDims = {}, - bool vectorizeNDExtract = false, - bool flatten1DDepthwiseConv = false); +/// Transformation information returned after vectorizing. +struct VectorizationResult { + /// Results of the vectorization transform to replace the original operation. + SmallVector<Value> replacements; +}; +/// Returns a `VectorizationResult` containing the results of the vectorized op, +/// or failure if the transformation fails. If provided, `inputVectorSizes` are +/// used to vectorize this operation. `inputVectorSizes` must match the rank of +/// the iteration space of the operation and the input vector sizes must be +/// greater than or equal to their counterpart iteration space sizes, if static. +/// `inputVectorSizes` also allows the vectorization of operations with dynamic +/// shapes. 
+FailureOr<VectorizationResult> +vectorize(RewriterBase &rewriter, Operation *op, + ArrayRef<int64_t> inputVectorSizes = {}, + ArrayRef<bool> inputScalableVecDims = {}, + bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false); /// Emit a suitable vector form for a Copy op with fully static shape. LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 2355edea2df6c..2b78e31558ea2 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3823,9 +3823,14 @@ struct VectorizationPattern : public RewritePattern { if (!linalg::hasVectorizationImpl(op)) return rewriter.notifyMatchFailure(op, "Unsupported Op, cannot vectorize"); - return vectorize(rewriter, op, /*inputVectorSizes=*/{}, - /*inputScalableVecDims=*/{}, vectorizeNDExtract, - flatten1DDepthwiseConv); + FailureOr<VectorizationResult> vectorResults = + vectorize(rewriter, op, /*inputVectorSizes=*/{}, + /*inputScalableVecDims=*/{}, vectorizeNDExtract, + flatten1DDepthwiseConv); + if (failed(vectorResults)) + return failure(); + rewriter.replaceOp(op, vectorResults->replacements); + return success(); } private: @@ -3914,13 +3919,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply( return mlir::emitSilenceableFailure(target->getLoc()) << "Unsupported Op, cannot vectorize"; } - - if (failed(linalg::vectorize(rewriter, target, vectorSizes, - getScalableSizes(), - getVectorizeNdExtract().value_or(false)))) { + FailureOr<VectorizationResult> vectorResults = + linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(), + getVectorizeNdExtract().value_or(false)); + if (failed(vectorResults)) { return mlir::emitSilenceableFailure(target->getLoc()) << "Attempted to vectorize, but failed"; } + rewriter.replaceOp(target, vectorResults->replacements); } return DiagnosedSilenceableFailure::success(); diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index ff8e0b8977ae8..e6a19fb5f57be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -551,9 +551,10 @@ enum class Conv1DOpOrder { Nwc // Corresponds to operation that traverses the input in (n, w, c) order. }; -/// Helper data structure to represent the result of vectorization. -/// In certain specific cases, like terminators, we do not want to propagate/ -enum VectorizationStatus { +/// Helper data structure to represent the result of vectorization for a single +/// operation. In certain specific cases, like terminators, we do not want to +/// propagate. +enum VectorizationHookStatus { /// Op failed to vectorize. Failure = 0, /// Op vectorized and custom function took care of replacement logic @@ -564,9 +565,12 @@ enum VectorizationStatus { // TODO: support values if Op vectorized to Many-Ops whose results we need to // aggregate for replacement. }; -struct VectorizationResult { +/// VectorizationHookResult contains the vectorized op returned from a +/// CustomVectorizationHook. This is an internal implementation detail of +/// linalg vectorization, not to be confused with VectorizationResult. +struct VectorizationHookResult { /// Return status from vectorizing the current op. - enum VectorizationStatus status = VectorizationStatus::Failure; + enum VectorizationHookStatus status = VectorizationHookStatus::Failure; /// New vectorized operation to replace the current op. /// Replacement behavior is specified by `status`. Operation *newOp; @@ -728,22 +732,22 @@ using CustomVectorizationPrecondition = // assuming all its vectorized operands are already in the IRMapping. // Return nullptr if the Operation cannot be vectorized. using CustomVectorizationHook = - std::function; + std::function; /// Helper function to vectorize the terminator of a `linalgOp`. 
New result /// vector values are appended to `newResults`. Return -/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it -/// should not try to map produced operations and instead return the results -/// using the `newResults` vector making them available to the vectorization -/// algorithm for RAUW. This function is meant to be used as a +/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm +/// that it should not try to map produced operations and instead return the +/// results using the `newResults` vector making them available to the +/// vectorization algorithm for RAUW. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult +static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); if (!yieldOp) - return VectorizationResult{VectorizationStatus::Failure, nullptr}; + return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr}; for (const auto &output : llvm::enumerate(yieldOp.getValues())) { // TODO: Scan for an opportunity for reuse. // TODO: use a map. @@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, newResults.push_back(newResult); } - return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; + return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr}; } /// Helper function to vectorize the index operations of a `linalgOp`. Return -/// VectorizationStatus::NewOp to signal the vectorization algorithm that it +/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. 
-static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, - VectorizationState &state, - Operation *op, - LinalgOp linalgOp) { +static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, + VectorizationState &state, + Operation *op, + LinalgOp linalgOp) { IndexOp indexOp = dyn_cast(op); if (!indexOp) - return VectorizationResult{VectorizationStatus::Failure, nullptr}; + return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr}; auto loc = indexOp.getLoc(); // Compute the static loop sizes of the index op. ArrayRef targetShape = state.getCanonicalVecShape(); @@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, // dimension of the iteration space since the vectorization algorithm in this // case can handle the broadcast. if (dim == targetShape.size() - 1) - return VectorizationResult{VectorizationStatus::NewOp, indexSteps}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps}; // Otherwise permute the targetShape to move the index dimension last, // broadcast the one-dimensional index vector to the permuted shape, and // finally transpose the broadcasted index vector to undo the permutation. @@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, std::swap(transposition.back(), transposition[dim]); auto transposeOp = rewriter.create(loc, broadCastOp, transposition); - return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp}; } /// Helper function to check if the tensor.extract can be vectorized by the @@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, } /// Helper function to vectorize the tensor.extract operations. 
Returns -/// VectorizationStatus::NewOp to signal the vectorization algorithm that it +/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult +static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { tensor::ExtractOp extractOp = dyn_cast(op); if (!extractOp) - return VectorizationResult{VectorizationStatus::Failure, nullptr}; + return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr}; auto loc = extractOp.getLoc(); // Compute the static loop sizes of the extract op. @@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); LDBG("Vectorised as gather load: " << extractOp << "\n"); - return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp}; } // 2. Handle: @@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, mlir::vector::maskOperation(rewriter, transferReadOp, allTrue); LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n"); - return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, + maskedReadOp}; } // 2b. Handle contiguous access. 
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, inBounds); LDBG("Vectorised as contiguous load: " << extractOp); - return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, + transferReadOp}; } /// Emit reduction operations if the shapes of the value to reduce is different @@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, /// This function assumes all operands of `op` have been vectorized and are in /// the `bvm` mapping. As a consequence, this function is meant to be called on /// a topologically-sorted list of ops. -/// This function does not update `bvm` but returns a VectorizationStatus that -/// instructs the caller what `bvm` update needs to occur. -static VectorizationResult +/// This function does not update `bvm` but returns a VectorizationHookStatus +/// that instructs the caller what `bvm` update needs to occur. +static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef customVectorizationHooks) { @@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { for (auto &customFunc : customVectorizationHooks) { - VectorizationResult result = customFunc(op, bvm); - if (result.status == VectorizationStatus::Failure) + VectorizationHookResult result = customFunc(op, bvm); + if (result.status == VectorizationHookStatus::Failure) continue; return result; } @@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, // 2. Constant ops don't get vectorized but rather broadcasted at their users. // Clone so that the constant is not confined to the linalgOp block . 
if (isa(op)) - return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, + rewriter.clone(*op)}; // 3. Only ElementwiseMappable are allowed in the generic vectorization. if (!OpTrait::hasElementwiseMappableTraits(op)) - return VectorizationResult{VectorizationStatus::Failure, nullptr}; + return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr}; // 4 . Check if the operation is a reduction. SmallVector> reductionOperands; @@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first, reductionOperands[0].second, bvm); if (reduceOp) - return VectorizationResult{VectorizationStatus::NewOp, reduceOp}; + return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp}; } // 5. Generic vectorization path for ElementwiseMappable ops. @@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, : resultType); } // d. Build and return the new op. - return VectorizationResult{ - VectorizationStatus::NewOp, + return VectorizationHookResult{ + VectorizationHookStatus::NewOp, rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands, resultTypes, op->getAttrs())}; } @@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, SmallVector hooks; // 4a. Register CustomVectorizationHook for yieldOp. CustomVectorizationHook vectorizeYield = - [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { + [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults); }; hooks.push_back(vectorizeYield); // 4b. Register CustomVectorizationHook for indexOp. 
CustomVectorizationHook vectorizeIndex = - [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { + [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { return vectorizeLinalgIndex(rewriter, state, op, linalgOp); }; hooks.push_back(vectorizeIndex); // 4c. Register CustomVectorizationHook for extractOp. CustomVectorizationHook vectorizeExtract = - [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { + [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); }; hooks.push_back(vectorizeExtract); // 5. Iteratively call `vectorizeOneOp` to each op in the slice. for (Operation &op : block->getOperations()) { - VectorizationResult result = + VectorizationHookResult result = vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); - if (result.status == VectorizationStatus::Failure) { + if (result.status == VectorizationHookStatus::Failure) { LDBG("failed to vectorize: " << op << "\n"); return failure(); } - if (result.status == VectorizationStatus::NewOp) { + if (result.status == VectorizationHookStatus::NewOp) { Operation *maybeMaskedOp = state.maskOperation(rewriter, result.newOp, linalgOp); LDBG("New vector op: " << *maybeMaskedOp << "\n"); @@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) { tensor::InsertSliceOp>(op); } -/// Emit a suitable vector form for an operation. If provided, -/// `inputVectorSizes` are used to vectorize this operation. -/// `inputVectorSizes` must match the rank of the iteration space of the -/// operation and the input vector sizes must be greater than or equal to -/// their counterpart iteration space sizes, if static. `inputVectorShapes` -/// also allows the vectorization of operations with dynamic shapes. 
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, - ArrayRef inputVectorSizes, - ArrayRef inputScalableVecDims, - bool vectorizeNDExtract, - bool flatten1DDepthwiseConv) { +FailureOr +mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, + bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { LDBG("Attempting to vectorize:\n" << *op << "\n"); LDBG("Input vector sizes: "); LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); @@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, return failure(); } - if (!results.empty()) - rewriter.replaceOp(op, results); - else - rewriter.eraseOp(op); - - return success(); + return VectorizationResult{results}; } LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,