Skip to content

Commit c1a4cd5

Browse files
author
Tobias Gysi
committed
[mlir][linalg] refactor the result handling during vectorization.
Return the vectorization results using a vector passed by reference instead of returning them embedded in a structure. Differential Revision: https://reviews.llvm.org/D98182
1 parent e31c77b commit c1a4cd5

File tree

3 files changed

+32
-35
lines changed

3 files changed

+32
-35
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,8 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
263263
OperationFolder *folder = nullptr);
264264

265265
/// Emit a suitable vector form for a Linalg op with fully static shape.
266-
struct VectorizedLinalgOp {
267-
SmallVector<Value> tensorResults;
268-
VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
269-
};
270-
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
271-
Operation *op);
266+
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
267+
SmallVectorImpl<Value> &newResults);
272268

273269
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
274270
template <typename LoopTy>

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
468468
return failure();
469469
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
470470
return failure();
471-
Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
472-
if (!res)
471+
SmallVector<Value> newResults;
472+
if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
473473
return failure();
474-
if (!res->tensorResults.empty())
475-
rewriter.replaceOp(op, res->tensorResults);
474+
if (!newResults.empty())
475+
rewriter.replaceOp(op, newResults);
476476
else
477477
rewriter.eraseOp(op);
478478
return success();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,27 @@ using CustomVectorizationHook = std::function<VectorizationResult(
139139
Operation *, const BlockAndValueMapping &)>;
140140

141141
/// Helper function to vectorize the terminator of a `linalgOp`. New result
142-
/// vector values are appended to `results`.
143-
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
144-
/// that it should not try to map produced operations: this is the purpose of
145-
/// the `results` argument to capture such values and make them available for
146-
/// RAUW to the vectorization algorithm.
147-
/// This function is meant to be used as a CustomVectorizationHook.
142+
/// vector values are appended to `newResults`. Return
143+
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
144+
/// should not try to map produced operations and instead return the results
145+
/// using the `newResults` vector making them available to the
146+
/// vectorization algorithm for RAUW. This function is meant to be used as a
147+
/// CustomVectorizationHook.
148148
static VectorizationResult
149149
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
150150
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
151-
SmallVectorImpl<Value> &results) {
151+
SmallVectorImpl<Value> &newResults) {
152152
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
153153
if (!yieldOp)
154154
return VectorizationResult{VectorizationStatus::Failure, nullptr};
155155
for (auto outputs : llvm::enumerate(yieldOp.values())) {
156156
// TODO: Scan for an opportunity for reuse.
157157
// TODO: use a map.
158158
Value vectorValue = bvm.lookup(outputs.value());
159-
Value result = buildVectorWrite(builder, vectorValue,
160-
linalgOp.getOutput(outputs.index()));
161-
if (result)
162-
results.push_back(result);
159+
Value newResult = buildVectorWrite(builder, vectorValue,
160+
linalgOp.getOutput(outputs.index()));
161+
if (newResult)
162+
newResults.push_back(newResult);
163163
}
164164
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
165165
}
@@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
248248
/// TODO: Reuse opportunities for RAR dependencies.
249249
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
250250
/// 5. Iteratively call vectorizeOneOp on the region operations.
251-
static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
252-
OpBuilder &builder, LinalgOp linalgOp,
251+
LogicalResult vectorizeAsLinalgGeneric(
252+
OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
253253
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
254254
// 1. Certain Linalg ops do not have a region but only a region builder.
255255
// If so, build the region so we can vectorize.
@@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
290290
}
291291

292292
// 4. Register CustomVectorizationHook for yieldOp.
293-
SmallVector<Value> results;
294293
CustomVectorizationHook vectorizeYield =
295294
[&](Operation *op,
296295
const BlockAndValueMapping &bvm) -> VectorizationResult {
297-
return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
296+
return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
298297
};
299298
// Append the vectorizeYield hook.
300299
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
305304
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
306305
if (result.status == VectorizationStatus::Failure) {
307306
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
308-
return llvm::None;
307+
return failure();
309308
}
310309
if (result.status == VectorizationStatus::NewOp) {
311310
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
314313
}
315314
}
316315

317-
return VectorizedLinalgOp{{results}};
316+
return success();
318317
}
319318

320319
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
355354
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
356355
}
357356

358-
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
359-
LinalgOp linalgOp) {
357+
static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
358+
SmallVectorImpl<Value> &newResults) {
360359
assert(isaContractionOpInterface(linalgOp) &&
361360
"expected vectorizeContraction preconditions to be met");
362361
Location loc = linalgOp.getLoc();
@@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
383382
linalgOp.indexing_maps(), linalgOp.iterator_types());
384383
return VectorizationResult{VectorizationStatus::NewOp, contract};
385384
};
386-
return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
385+
return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
386+
{vectorizeContraction});
387387
}
388388

389389
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
400400
return success(isaContractionOpInterface(linalgOp));
401401
}
402402

403-
Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
404-
Operation *op) {
403+
LogicalResult
404+
mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
405+
SmallVectorImpl<Value> &newResults) {
405406
if (failed(vectorizeLinalgOpPrecondition(op)))
406-
return llvm::None;
407+
return failure();
407408

408409
edsc::ScopedContext scope(builder, op->getLoc());
409410
if (isElementwise(op)) {
410411
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
411412
<< "Vectorize linalg op as a generic: " << *op);
412-
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
413+
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
413414
}
414415

415-
return vectorizeContraction(builder, cast<LinalgOp>(op));
416+
return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
416417
}
417418

418419
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)