@@ -139,27 +139,27 @@ using CustomVectorizationHook = std::function<VectorizationResult(
139139 Operation *, const BlockAndValueMapping &)>;
140140
141141// / Helper function to vectorize the terminator of a `linalgOp`. New result
142- // / vector values are appended to `results`.
143- // / Return VectorizationStatus::NoReplace to signal the vectorization algorithm
144- // / that it should not try to map produced operations: this is the purpose of
145- // / the `results` argument to capture such values and make them available for
146- // / RAUW to the vectorization algorithm.
147- // / This function is meant to be used as a CustomVectorizationHook.
142+ // / vector values are appended to `newResults`. Return
143+ // / VectorizationStatus::NoReplace to signal the vectorization algorithm that it
144+ // / should not try to map produced operations and instead return the results
145+ // / using the `newResults` vector making them available to the
146+ // / vectorization algorithm for RAUW. This function is meant to be used as a
147+ // / CustomVectorizationHook.
148148static VectorizationResult
149149vectorizeLinalgYield (OpBuilder &builder, Operation *op,
150150 const BlockAndValueMapping &bvm, LinalgOp linalgOp,
151- SmallVectorImpl<Value> &results ) {
151+ SmallVectorImpl<Value> &newResults ) {
152152 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
153153 if (!yieldOp)
154154 return VectorizationResult{VectorizationStatus::Failure, nullptr };
155155 for (auto outputs : llvm::enumerate (yieldOp.values ())) {
156156 // TODO: Scan for an opportunity for reuse.
157157 // TODO: use a map.
158158 Value vectorValue = bvm.lookup (outputs.value ());
159- Value result = buildVectorWrite (builder, vectorValue,
160- linalgOp.getOutput (outputs.index ()));
161- if (result )
162- results .push_back (result );
159+ Value newResult = buildVectorWrite (builder, vectorValue,
160+ linalgOp.getOutput (outputs.index ()));
161+ if (newResult )
162+ newResults .push_back (newResult );
163163 }
164164 return VectorizationResult{VectorizationStatus::NoReplace, nullptr };
165165}
@@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
248248// / TODO: Reuse opportunities for RAR dependencies.
249249// / 4. Register CustomVectorizationHook for YieldOp to capture the results.
250250// / 5. Iteratively call vectorizeOneOp on the region operations.
251- static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric (
252- OpBuilder &builder, LinalgOp linalgOp,
251+ LogicalResult vectorizeAsLinalgGeneric (
252+ OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
253253 ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
254254 // 1. Certain Linalg ops do not have a region but only a region builder.
255255 // If so, build the region so we can vectorize.
@@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
290290 }
291291
292292 // 4. Register CustomVectorizationHook for yieldOp.
293- SmallVector<Value> results;
294293 CustomVectorizationHook vectorizeYield =
295294 [&](Operation *op,
296295 const BlockAndValueMapping &bvm) -> VectorizationResult {
297- return vectorizeLinalgYield (builder, op, bvm, linalgOp, results );
296+ return vectorizeLinalgYield (builder, op, bvm, linalgOp, newResults );
298297 };
299298 // Append the vectorizeYield hook.
300299 auto hooks = llvm::to_vector<4 >(customVectorizationHooks);
@@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
305304 VectorizationResult result = vectorizeOneOp (builder, &op, bvm, hooks);
306305 if (result.status == VectorizationStatus::Failure) {
307306 LLVM_DEBUG (dbgs () << " \n [" DEBUG_TYPE " ]: failed to vectorize: " << op);
308- return llvm::None ;
307+ return failure () ;
309308 }
310309 if (result.status == VectorizationStatus::NewOp) {
311310 LLVM_DEBUG (dbgs () << " \n [" DEBUG_TYPE " ]: new vector op: "
@@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
314313 }
315314 }
316315
317- return VectorizedLinalgOp{{results}} ;
316+ return success () ;
318317}
319318
320319// / Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
355354 return hasOnlyScalarElementwiseOp (linalgOp->getRegion (0 ));
356355}
357356
358- static Optional<VectorizedLinalgOp> vectorizeContraction (OpBuilder &builder,
359- LinalgOp linalgOp ) {
357+ static LogicalResult vectorizeContraction (OpBuilder &builder, LinalgOp linalgOp ,
358+ SmallVectorImpl<Value> &newResults ) {
360359 assert (isaContractionOpInterface (linalgOp) &&
361360 " expected vectorizeContraction preconditions to be met" );
362361 Location loc = linalgOp.getLoc ();
@@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
383382 linalgOp.indexing_maps (), linalgOp.iterator_types ());
384383 return VectorizationResult{VectorizationStatus::NewOp, contract};
385384 };
386- return vectorizeAsLinalgGeneric (builder, linalgOp, {vectorizeContraction});
385+ return vectorizeAsLinalgGeneric (builder, linalgOp, newResults,
386+ {vectorizeContraction});
387387}
388388
389389LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition (Operation *op) {
@@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
400400 return success (isaContractionOpInterface (linalgOp));
401401}
402402
403- Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp (OpBuilder &builder,
404- Operation *op) {
403+ LogicalResult
404+ mlir::linalg::vectorizeLinalgOp (OpBuilder &builder, Operation *op,
405+ SmallVectorImpl<Value> &newResults) {
405406 if (failed (vectorizeLinalgOpPrecondition (op)))
406- return llvm::None ;
407+ return failure () ;
407408
408409 edsc::ScopedContext scope (builder, op->getLoc ());
409410 if (isElementwise (op)) {
410411 LLVM_DEBUG (dbgs () << " \n [" DEBUG_TYPE " ]: "
411412 << " Vectorize linalg op as a generic: " << *op);
412- return vectorizeAsLinalgGeneric (builder, cast<LinalgOp>(op));
413+ return vectorizeAsLinalgGeneric (builder, cast<LinalgOp>(op), newResults );
413414 }
414415
415- return vectorizeContraction (builder, cast<LinalgOp>(op));
416+ return vectorizeContraction (builder, cast<LinalgOp>(op), newResults );
416417}
417418
418419// ----------------------------------------------------------------------------//
0 commit comments