Skip to content

Commit a0f6542

Browse files
Groverksslialan
authored and committed
Revert "[mlir] Return vectorized values instead of replacing (llvm#144158)"
This reverts commit 4d21da0.
1 parent d4eef14 commit a0f6542

File tree

3 files changed

+78
-86
lines changed

3 files changed

+78
-86
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -854,23 +854,17 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
854854
/// to work (these are checked by the vectorizer itself).
855855
bool hasVectorizationImpl(Operation *);
856856

857-
/// Transformation information returned after vectorizing.
858-
struct VectorizationResult {
859-
/// Results of the vectorization transform to replace the original operation.
860-
SmallVector<Value> replacements;
861-
};
862-
/// Returns a `VectorizationResult` containing the results of the vectorized op,
863-
/// or failure if the transformation fails. If provided, `inputVectorSizes` are
864-
/// used to vectorize this operation. `inputVectorSizes` must match the rank of
865-
/// the iteration space of the operation and the input vector sizes must be
866-
/// greater than or equal to their counterpart iteration space sizes, if static.
867-
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
868-
/// shapes.
869-
FailureOr<VectorizationResult>
870-
vectorize(RewriterBase &rewriter, Operation *op,
871-
ArrayRef<int64_t> inputVectorSizes = {},
872-
ArrayRef<bool> inputScalableVecDims = {},
873-
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
857+
/// Emit a suitable vector form for an operation. If provided,
858+
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
859+
/// must match the rank of the iteration space of the operation and the sizes
860+
/// must be smaller or equal than their counterpart interation space sizes, if
861+
/// static. `inputVectorShapes` also allows the vectorization of operations with
862+
/// dynamic shapes.
863+
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
864+
ArrayRef<int64_t> inputVectorSizes = {},
865+
ArrayRef<bool> inputScalableVecDims = {},
866+
bool vectorizeNDExtract = false,
867+
bool flatten1DDepthwiseConv = false);
874868

875869
/// Emit a suitable vector form for a Copy op with fully static shape.
876870
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3823,14 +3823,9 @@ struct VectorizationPattern : public RewritePattern {
38233823
if (!linalg::hasVectorizationImpl(op))
38243824
return rewriter.notifyMatchFailure(op,
38253825
"Unsupported Op, cannot vectorize");
3826-
FailureOr<VectorizationResult> vectorResults =
3827-
vectorize(rewriter, op, /*inputVectorSizes=*/{},
3828-
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3829-
flatten1DDepthwiseConv);
3830-
if (failed(vectorResults))
3831-
return failure();
3832-
rewriter.replaceOp(op, vectorResults->replacements);
3833-
return success();
3826+
return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3827+
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
3828+
flatten1DDepthwiseConv);
38343829
}
38353830

38363831
private:
@@ -3919,14 +3914,13 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
39193914
return mlir::emitSilenceableFailure(target->getLoc())
39203915
<< "Unsupported Op, cannot vectorize";
39213916
}
3922-
FailureOr<VectorizationResult> vectorResults =
3923-
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
3924-
getVectorizeNdExtract().value_or(false));
3925-
if (failed(vectorResults)) {
3917+
3918+
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3919+
getScalableSizes(),
3920+
getVectorizeNdExtract().value_or(false)))) {
39263921
return mlir::emitSilenceableFailure(target->getLoc())
39273922
<< "Attempted to vectorize, but failed";
39283923
}
3929-
rewriter.replaceOp(target, vectorResults->replacements);
39303924
}
39313925

39323926
return DiagnosedSilenceableFailure::success();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -551,10 +551,9 @@ enum class Conv1DOpOrder {
551551
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552552
};
553553

554-
/// Helper data structure to represent the result of vectorization for a single
555-
/// operation. In certain specific cases, like terminators, we do not want to
556-
/// propagate.
557-
enum VectorizationHookStatus {
554+
/// Helper data structure to represent the result of vectorization.
555+
/// In certain specific cases, like terminators, we do not want to propagate/
556+
enum VectorizationStatus {
558557
/// Op failed to vectorize.
559558
Failure = 0,
560559
/// Op vectorized and custom function took care of replacement logic
@@ -565,12 +564,9 @@ enum VectorizationHookStatus {
565564
// TODO: support values if Op vectorized to Many-Ops whose results we need to
566565
// aggregate for replacement.
567566
};
568-
/// VectorizationHookResult contains the vectorized op returned from a
569-
/// CustomVectorizationHook. This is an internal implementation detail of
570-
/// linalg vectorization, not to be confused with VectorizationResult.
571-
struct VectorizationHookResult {
567+
struct VectorizationResult {
572568
/// Return status from vectorizing the current op.
573-
enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
569+
enum VectorizationStatus status = VectorizationStatus::Failure;
574570
/// New vectorized operation to replace the current op.
575571
/// Replacement behavior is specified by `status`.
576572
Operation *newOp;
@@ -732,22 +728,22 @@ using CustomVectorizationPrecondition =
732728
// assuming all its vectorized operands are already in the IRMapping.
733729
// Return nullptr if the Operation cannot be vectorized.
734730
using CustomVectorizationHook =
735-
std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
731+
std::function<VectorizationResult(Operation *, const IRMapping &)>;
736732

737733
/// Helper function to vectorize the terminator of a `linalgOp`. New result
738734
/// vector values are appended to `newResults`. Return
739-
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
740-
/// that it should not try to map produced operations and instead return the
741-
/// results using the `newResults` vector making them available to the
742-
/// vectorization algorithm for RAUW. This function is meant to be used as a
735+
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
736+
/// should not try to map produced operations and instead return the results
737+
/// using the `newResults` vector making them available to the vectorization
738+
/// algorithm for RAUW. This function is meant to be used as a
743739
/// CustomVectorizationHook.
744-
static VectorizationHookResult
740+
static VectorizationResult
745741
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
746742
const IRMapping &bvm, VectorizationState &state,
747743
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
748744
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
749745
if (!yieldOp)
750-
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
746+
return VectorizationResult{VectorizationStatus::Failure, nullptr};
751747
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
752748
// TODO: Scan for an opportunity for reuse.
753749
// TODO: use a map.
@@ -759,20 +755,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
759755
newResults.push_back(newResult);
760756
}
761757

762-
return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
758+
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
763759
}
764760

765761
/// Helper function to vectorize the index operations of a `linalgOp`. Return
766-
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
762+
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
767763
/// should map the produced operations. This function is meant to be used as a
768764
/// CustomVectorizationHook.
769-
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
770-
VectorizationState &state,
771-
Operation *op,
772-
LinalgOp linalgOp) {
765+
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
766+
VectorizationState &state,
767+
Operation *op,
768+
LinalgOp linalgOp) {
773769
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
774770
if (!indexOp)
775-
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
771+
return VectorizationResult{VectorizationStatus::Failure, nullptr};
776772
auto loc = indexOp.getLoc();
777773
// Compute the static loop sizes of the index op.
778774
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -786,7 +782,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
786782
// dimension of the iteration space since the vectorization algorithm in this
787783
// case can handle the broadcast.
788784
if (dim == targetShape.size() - 1)
789-
return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
785+
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
790786
// Otherwise permute the targetShape to move the index dimension last,
791787
// broadcast the one-dimensional index vector to the permuted shape, and
792788
// finally transpose the broadcasted index vector to undo the permutation.
@@ -804,7 +800,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
804800
std::swap(transposition.back(), transposition[dim]);
805801
auto transposeOp =
806802
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
807-
return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
803+
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
808804
}
809805

810806
/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1102,15 +1098,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
11021098
}
11031099

11041100
/// Helper function to vectorize the tensor.extract operations. Returns
1105-
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1101+
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
11061102
/// should map the produced operations. This function is meant to be used as a
11071103
/// CustomVectorizationHook.
1108-
static VectorizationHookResult
1104+
static VectorizationResult
11091105
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11101106
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11111107
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11121108
if (!extractOp)
1113-
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1109+
return VectorizationResult{VectorizationStatus::Failure, nullptr};
11141110
auto loc = extractOp.getLoc();
11151111

11161112
// Compute the static loop sizes of the extract op.
@@ -1142,7 +1138,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11421138
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
11431139

11441140
LDBG("Vectorised as gather load: " << extractOp << "\n");
1145-
return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1141+
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
11461142
}
11471143

11481144
// 2. Handle:
@@ -1206,8 +1202,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12061202
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
12071203

12081204
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1209-
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1210-
maskedReadOp};
1205+
return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
12111206
}
12121207

12131208
// 2b. Handle contiguous access.
@@ -1233,8 +1228,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12331228
inBounds);
12341229

12351230
LDBG("Vectorised as contiguous load: " << extractOp);
1236-
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1237-
transferReadOp};
1231+
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
12381232
}
12391233

12401234
/// Emit reduction operations if the shapes of the value to reduce is different
@@ -1274,9 +1268,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12741268
/// This function assumes all operands of `op` have been vectorized and are in
12751269
/// the `bvm` mapping. As a consequence, this function is meant to be called on
12761270
/// a topologically-sorted list of ops.
1277-
/// This function does not update `bvm` but returns a VectorizationHookStatus
1278-
/// that instructs the caller what `bvm` update needs to occur.
1279-
static VectorizationHookResult
1271+
/// This function does not update `bvm` but returns a VectorizationStatus that
1272+
/// instructs the caller what `bvm` update needs to occur.
1273+
static VectorizationResult
12801274
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12811275
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12821276
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1285,8 +1279,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12851279
// 1. Try to apply any CustomVectorizationHook.
12861280
if (!customVectorizationHooks.empty()) {
12871281
for (auto &customFunc : customVectorizationHooks) {
1288-
VectorizationHookResult result = customFunc(op, bvm);
1289-
if (result.status == VectorizationHookStatus::Failure)
1282+
VectorizationResult result = customFunc(op, bvm);
1283+
if (result.status == VectorizationStatus::Failure)
12901284
continue;
12911285
return result;
12921286
}
@@ -1295,12 +1289,11 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12951289
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
12961290
// Clone so that the constant is not confined to the linalgOp block .
12971291
if (isa<arith::ConstantOp, func::ConstantOp>(op))
1298-
return VectorizationHookResult{VectorizationHookStatus::NewOp,
1299-
rewriter.clone(*op)};
1292+
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
13001293

13011294
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
13021295
if (!OpTrait::hasElementwiseMappableTraits(op))
1303-
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1296+
return VectorizationResult{VectorizationStatus::Failure, nullptr};
13041297

13051298
// 4 . Check if the operation is a reduction.
13061299
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1323,7 +1316,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13231316
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
13241317
reductionOperands[0].second, bvm);
13251318
if (reduceOp)
1326-
return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1319+
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
13271320
}
13281321

13291322
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1363,8 +1356,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13631356
: resultType);
13641357
}
13651358
// d. Build and return the new op.
1366-
return VectorizationHookResult{
1367-
VectorizationHookStatus::NewOp,
1359+
return VectorizationResult{
1360+
VectorizationStatus::NewOp,
13681361
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
13691362
resultTypes, op->getAttrs())};
13701363
}
@@ -1468,34 +1461,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14681461
SmallVector<CustomVectorizationHook> hooks;
14691462
// 4a. Register CustomVectorizationHook for yieldOp.
14701463
CustomVectorizationHook vectorizeYield =
1471-
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1464+
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14721465
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
14731466
};
14741467
hooks.push_back(vectorizeYield);
14751468

14761469
// 4b. Register CustomVectorizationHook for indexOp.
14771470
CustomVectorizationHook vectorizeIndex =
1478-
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1471+
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14791472
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
14801473
};
14811474
hooks.push_back(vectorizeIndex);
14821475

14831476
// 4c. Register CustomVectorizationHook for extractOp.
14841477
CustomVectorizationHook vectorizeExtract =
1485-
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1478+
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
14861479
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
14871480
};
14881481
hooks.push_back(vectorizeExtract);
14891482

14901483
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14911484
for (Operation &op : block->getOperations()) {
1492-
VectorizationHookResult result =
1485+
VectorizationResult result =
14931486
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1494-
if (result.status == VectorizationHookStatus::Failure) {
1487+
if (result.status == VectorizationStatus::Failure) {
14951488
LDBG("failed to vectorize: " << op << "\n");
14961489
return failure();
14971490
}
1498-
if (result.status == VectorizationHookStatus::NewOp) {
1491+
if (result.status == VectorizationStatus::NewOp) {
14991492
Operation *maybeMaskedOp =
15001493
state.maskOperation(rewriter, result.newOp, linalgOp);
15011494
LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2532,11 +2525,17 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25322525
tensor::InsertSliceOp>(op);
25332526
}
25342527

2535-
FailureOr<VectorizationResult>
2536-
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2537-
ArrayRef<int64_t> inputVectorSizes,
2538-
ArrayRef<bool> inputScalableVecDims,
2539-
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2528+
/// Emit a suitable vector form for an operation. If provided,
2529+
/// `inputVectorSizes` are used to vectorize this operation.
2530+
/// `inputVectorSizes` must match the rank of the iteration space of the
2531+
/// operation and the input vector sizes must be greater than or equal to
2532+
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
2533+
/// also allows the vectorization of operations with dynamic shapes.
2534+
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2535+
ArrayRef<int64_t> inputVectorSizes,
2536+
ArrayRef<bool> inputScalableVecDims,
2537+
bool vectorizeNDExtract,
2538+
bool flatten1DDepthwiseConv) {
25402539
LDBG("Attempting to vectorize:\n" << *op << "\n");
25412540
LDBG("Input vector sizes: ");
25422541
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2618,7 +2617,12 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26182617
return failure();
26192618
}
26202619

2621-
return VectorizationResult{results};
2620+
if (!results.empty())
2621+
rewriter.replaceOp(op, results);
2622+
else
2623+
rewriter.eraseOp(op);
2624+
2625+
return success();
26222626
}
26232627

26242628
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,

0 commit comments

Comments (0)