Skip to content

Commit b8a783b

Browse files
committed
[mlir] Return vectorized values instead of replacing
Signed-off-by: Max Dawkins <[email protected]>
1 parent 7aecd7e commit b8a783b

File tree

3 files changed

+64
-61
lines changed

3 files changed

+64
-61
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -771,17 +771,24 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
771771
/// to work (these are checked by the vectorizer itself).
772772
bool hasVectorizationImpl(Operation *);
773773

774+
/// Transformation information returned after vectorizing.
775+
struct VectorizationResult {
776+
/// Results of the vectorization transform to replace the original operation.
777+
SmallVector<Value> replacements;
778+
};
774779
/// Emit a suitable vector form for an operation. If provided,
775-
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
776-
/// must match the rank of the iteration space of the operation and the sizes
777-
/// must be smaller or equal than their counterpart interation space sizes, if
778-
/// static. `inputVectorShapes` also allows the vectorization of operations with
779-
/// dynamic shapes.
780-
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
781-
ArrayRef<int64_t> inputVectorSizes = {},
782-
ArrayRef<bool> inputScalableVecDims = {},
783-
bool vectorizeNDExtract = false,
784-
bool flatten1DDepthwiseConv = false);
780+
/// `inputVectorSizes` are used to vectorize this operation.
781+
/// `inputVectorSizes` must match the rank of the iteration space of the
782+
/// operation and the input vector sizes must be greater than or equal to
783+
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
784+
/// also allows the vectorization of operations with dynamic shapes. Returns
785+
/// a VectorizationResult containing the results of the vectorized op, or
786+
/// failure if the transformation fails.
787+
FailureOr<VectorizationResult>
788+
vectorize(RewriterBase &rewriter, Operation *op,
789+
ArrayRef<int64_t> inputVectorSizes = {},
790+
ArrayRef<bool> inputScalableVecDims = {},
791+
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
785792

786793
/// Emit a suitable vector form for a Copy op with fully static shape.
787794
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3609,9 +3609,14 @@ struct VectorizationPattern : public RewritePattern {
36093609
if (!linalg::hasVectorizationImpl(op))
36103610
return rewriter.notifyMatchFailure(op,
36113611
"Unsupported Op, cannot vectorize");
3612-
return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3613-
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3614-
flatten1DDepthwiseConv);
3612+
FailureOr<VectorizationResult> vectorResults =
3613+
vectorize(rewriter, op, /*inputVectorSizes=*/{},
3614+
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3615+
flatten1DDepthwiseConv);
3616+
if (failed(vectorResults))
3617+
return failure();
3618+
rewriter.replaceOp(op, vectorResults->replacements);
3619+
return success();
36153620
}
36163621

36173622
private:
@@ -3700,13 +3705,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
37003705
return mlir::emitSilenceableFailure(target->getLoc())
37013706
<< "Unsupported Op, cannot vectorize";
37023707
}
3703-
3704-
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3705-
getScalableSizes(),
3706-
getVectorizeNdExtract().value_or(false)))) {
3708+
FailureOr<VectorizationResult> vectorResults =
3709+
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
3710+
getVectorizeNdExtract().value_or(false));
3711+
if (failed(vectorResults)) {
37073712
return mlir::emitSilenceableFailure(target->getLoc())
37083713
<< "Attempted to vectorize, but failed";
37093714
}
3715+
rewriter.replaceOp(target, vectorResults->replacements);
37103716
}
37113717

37123718
return DiagnosedSilenceableFailure::success();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ enum VectorizationStatus {
563563
// TODO: support values if Op vectorized to Many-Ops whose results we need to
564564
// aggregate for replacement.
565565
};
566-
struct VectorizationResult {
566+
struct VectorizationHookResult {
567567
/// Return status from vectorizing the current op.
568568
enum VectorizationStatus status = VectorizationStatus::Failure;
569569
/// New vectorized operation to replace the current op.
@@ -727,7 +727,7 @@ using CustomVectorizationPrecondition =
727727
// assuming all its vectorized operands are already in the IRMapping.
728728
// Return nullptr if the Operation cannot be vectorized.
729729
using CustomVectorizationHook =
730-
std::function<VectorizationResult(Operation *, const IRMapping &)>;
730+
std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
731731

732732
/// Helper function to vectorize the terminator of a `linalgOp`. New result
733733
/// vector values are appended to `newResults`. Return
@@ -736,13 +736,13 @@ using CustomVectorizationHook =
736736
/// using the `newResults` vector making them available to the vectorization
737737
/// algorithm for RAUW. This function is meant to be used as a
738738
/// CustomVectorizationHook.
739-
static VectorizationResult
739+
static VectorizationHookResult
740740
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
741741
const IRMapping &bvm, VectorizationState &state,
742742
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
743743
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
744744
if (!yieldOp)
745-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
745+
return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
746746
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
747747
// TODO: Scan for an opportunity for reuse.
748748
// TODO: use a map.
@@ -754,20 +754,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
754754
newResults.push_back(newResult);
755755
}
756756

757-
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
757+
return VectorizationHookResult{VectorizationStatus::NoReplace, nullptr};
758758
}
759759

760760
/// Helper function to vectorize the index operations of a `linalgOp`. Return
761761
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
762762
/// should map the produced operations. This function is meant to be used as a
763763
/// CustomVectorizationHook.
764-
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
765-
VectorizationState &state,
766-
Operation *op,
767-
LinalgOp linalgOp) {
764+
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
765+
VectorizationState &state,
766+
Operation *op,
767+
LinalgOp linalgOp) {
768768
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
769769
if (!indexOp)
770-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
770+
return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
771771
auto loc = indexOp.getLoc();
772772
// Compute the static loop sizes of the index op.
773773
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -781,7 +781,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
781781
// dimension of the iteration space since the vectorization algorithm in this
782782
// case can handle the broadcast.
783783
if (dim == targetShape.size() - 1)
784-
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
784+
return VectorizationHookResult{VectorizationStatus::NewOp, indexSteps};
785785
// Otherwise permute the targetShape to move the index dimension last,
786786
// broadcast the one-dimensional index vector to the permuted shape, and
787787
// finally transpose the broadcasted index vector to undo the permutation.
@@ -799,7 +799,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
799799
std::swap(transposition.back(), transposition[dim]);
800800
auto transposeOp =
801801
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
802-
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
802+
return VectorizationHookResult{VectorizationStatus::NewOp, transposeOp};
803803
}
804804

805805
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,12 +1100,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
11001100
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
11011101
/// should map the produced operations. This function is meant to be used as a
11021102
/// CustomVectorizationHook.
1103-
static VectorizationResult
1103+
static VectorizationHookResult
11041104
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11051105
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11061106
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11071107
if (!extractOp)
1108-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
1108+
return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
11091109
auto loc = extractOp.getLoc();
11101110

11111111
// Compute the static loop sizes of the extract op.
@@ -1137,7 +1137,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11371137
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
11381138

11391139
LDBG("Vectorised as gather load: " << extractOp << "\n");
1140-
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1140+
return VectorizationHookResult{VectorizationStatus::NewOp, gatherOp};
11411141
}
11421142

11431143
// 2. Handle:
@@ -1201,7 +1201,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12011201
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
12021202

12031203
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1204-
return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1204+
return VectorizationHookResult{VectorizationStatus::NewOp, maskedReadOp};
12051205
}
12061206

12071207
// 2b. Handle contiguous access.
@@ -1227,7 +1227,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12271227
inBounds);
12281228

12291229
LDBG("Vectorised as contiguous load: " << extractOp);
1230-
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1230+
return VectorizationHookResult{VectorizationStatus::NewOp, transferReadOp};
12311231
}
12321232

12331233
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1269,7 +1269,7 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12691269
/// a topologically-sorted list of ops.
12701270
/// This function does not update `bvm` but returns a VectorizationStatus that
12711271
/// instructs the caller what `bvm` update needs to occur.
1272-
static VectorizationResult
1272+
static VectorizationHookResult
12731273
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12741274
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12751275
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1278,7 +1278,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12781278
// 1. Try to apply any CustomVectorizationHook.
12791279
if (!customVectorizationHooks.empty()) {
12801280
for (auto &customFunc : customVectorizationHooks) {
1281-
VectorizationResult result = customFunc(op, bvm);
1281+
VectorizationHookResult result = customFunc(op, bvm);
12821282
if (result.status == VectorizationStatus::Failure)
12831283
continue;
12841284
return result;
@@ -1288,11 +1288,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12881288
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
12891289
// Clone so that the constant is not confined to the linalgOp block .
12901290
if (isa<arith::ConstantOp, func::ConstantOp>(op))
1291-
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1291+
return VectorizationHookResult{VectorizationStatus::NewOp,
1292+
rewriter.clone(*op)};
12921293

12931294
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
12941295
if (!OpTrait::hasElementwiseMappableTraits(op))
1295-
return VectorizationResult{VectorizationStatus::Failure, nullptr};
1296+
return VectorizationHookResult{VectorizationStatus::Failure, nullptr};
12961297

12971298
// 4 . Check if the operation is a reduction.
12981299
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1315,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13151316
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
13161317
reductionOperands[0].second, bvm);
13171318
if (reduceOp)
1318-
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1319+
return VectorizationHookResult{VectorizationStatus::NewOp, reduceOp};
13191320
}
13201321

13211322
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1355,7 +1356,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13551356
: resultType);
13561357
}
13571358
// d. Build and return the new op.
1358-
return VectorizationResult{
1359+
return VectorizationHookResult{
13591360
VectorizationStatus::NewOp,
13601361
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
13611362
resultTypes, op->getAttrs())};
@@ -1460,28 +1461,28 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14601461
SmallVector<CustomVectorizationHook> hooks;
14611462
// 4a. Register CustomVectorizationHook for yieldOp.
14621463
CustomVectorizationHook vectorizeYield =
1463-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1464+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14641465
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
14651466
};
14661467
hooks.push_back(vectorizeYield);
14671468

14681469
// 4b. Register CustomVectorizationHook for indexOp.
14691470
CustomVectorizationHook vectorizeIndex =
1470-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1471+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14711472
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
14721473
};
14731474
hooks.push_back(vectorizeIndex);
14741475

14751476
// 4c. Register CustomVectorizationHook for extractOp.
14761477
CustomVectorizationHook vectorizeExtract =
1477-
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1478+
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14781479
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
14791480
};
14801481
hooks.push_back(vectorizeExtract);
14811482

14821483
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14831484
for (Operation &op : block->getOperations()) {
1484-
VectorizationResult result =
1485+
VectorizationHookResult result =
14851486
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
14861487
if (result.status == VectorizationStatus::Failure) {
14871488
LDBG("failed to vectorize: " << op << "\n");
@@ -2522,17 +2523,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25222523
tensor::InsertSliceOp>(op);
25232524
}
25242525

2525-
/// Emit a suitable vector form for an operation. If provided,
2526-
/// `inputVectorSizes` are used to vectorize this operation.
2527-
/// `inputVectorSizes` must match the rank of the iteration space of the
2528-
/// operation and the input vector sizes must be greater than or equal to
2529-
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
2530-
/// also allows the vectorization of operations with dynamic shapes.
2531-
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2532-
ArrayRef<int64_t> inputVectorSizes,
2533-
ArrayRef<bool> inputScalableVecDims,
2534-
bool vectorizeNDExtract,
2535-
bool flatten1DDepthwiseConv) {
2526+
FailureOr<VectorizationResult>
2527+
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2528+
ArrayRef<int64_t> inputVectorSizes,
2529+
ArrayRef<bool> inputScalableVecDims,
2530+
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
25362531
LDBG("Attempting to vectorize:\n" << *op << "\n");
25372532
LDBG("Input vector sizes: ");
25382533
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2614,12 +2609,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26142609
return failure();
26152610
}
26162611

2617-
if (!results.empty())
2618-
rewriter.replaceOp(op, results);
2619-
else
2620-
rewriter.eraseOp(op);
2621-
2622-
return success();
2612+
return VectorizationResult({results});
26232613
}
26242614

26252615
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,

0 commit comments

Comments
 (0)