Merged
28 changes: 17 additions & 11 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -854,17 +854,23 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// to work (these are checked by the vectorizer itself).
bool hasVectorizationImpl(Operation *);

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
/// must match the rank of the iteration space of the operation and the sizes
/// must be smaller than or equal to their counterpart iteration space sizes,
/// if static. `inputVectorShapes` also allows the vectorization of operations
/// with dynamic shapes.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false,
bool flatten1DDepthwiseConv = false);
/// Transformation information returned after vectorizing.
struct VectorizationResult {
/// Results of the vectorization transform to replace the original operation.
SmallVector<Value> replacements;
};
Comment on lines +857 to +861
Contributor

  1. Not quite "transformation information" ;-)
  2. Do we need a vector of results? I don't want to sound "limiting" or call for a bigger refactor in this PR, but from what I recall, the vectorizer always gets an Op with one result and returns a vectorized result. Is this true? I just suspect that a single Value would be sufficient. This could be something for a follow-up PR.
  3. VectorizationResult -> VectorizedResult? This name is a bit too close to VectorizationHookResult and the distinction is not well documented. My suggestion is not that much better, but it "hints" it's just the actual vectorized Op (as opposed to "vectorized Op + status flag"). I will try to come back with sth better :)

Contributor Author

You make some good points. I was trying to follow the naming of other similar structs (SCFTilingResult, DropUnitDimsResult, ElementwiseOpFusionResult), but I think it is a little bit weird in this case, because we are only returning the results. I like staying consistent with the other transform names, though. Perhaps we can come up with a better name for VectorizationHookResult instead? I can try to think of something that is less similar to VectorizationResult.

Contributor Author

RE (2): I didn't know about this, but I looked into some of the lit tests, and there does seem to be an example of a linalg op with multiple results:

// Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>,
    %arg1 : tensor<2x4xf32>, %arg2 : tensor<2x4x8xf32>, %arg3 : tensor<2x4xf32>) ->
    (tensor<2x4x8xf32>, tensor<2x4xf32>) {
  %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<2x4x8xf32>, tensor<2x4xf32>)
      outs(%arg2, %arg3 : tensor<2x4x8xf32>, tensor<2x4xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
      %1 = arith.mulf %b0, %b1 : f32
      %2 = arith.addf %1, %b3 : f32
      linalg.yield %1, %2 : f32, f32
  } -> (tensor<2x4x8xf32>, tensor<2x4xf32>)
  return %0#0, %0#1 : tensor<2x4x8xf32>, tensor<2x4xf32>
}

Contributor

Thanks for your replies!

I didn't know about this, but I looked into some of the lit tests, and there does seem to be an example of a linalg op with multiple results

Great find 🙏🏻

I was trying to follow the naming of other similar structs (SCFTilingResult, DropUnitDimsResult, ElementwiseOpFusionResult), but I think it is a little bit weird in this case, because we are only returning the results.

Aligning with other transformations makes sense, but, as you note, it doesn't work particularly well in this case. Why not:

  using vectorizedValues = SmallVector<Value>;

?
Also, the newly introduced/renamed VectorizationHookResult is used more widely than VectorizationResult - I would optimise for what's more commonly used and just keep the old name, VectorizationResult. However, it would be good to add a comment. Something along the lines of:

// Encapsulates vectorisation result for a single Op vectorized using a custom vectorization hook.

WDYT?

In general, my main ask is to make it clear (with comments) what the distinction between the two is. Naming in this case is particularly hard :)

Contributor Author

Why not:
using vectorizedValues = SmallVector<Value>;
?

I made it a struct because it leaves more room for extending the return type with less code churn in downstream projects, although I don't have a strong opinion on this.

Also, the newly introduced/renamed VectorizationHookResult is used more widely than VectorizationResult - I would optimise for what's more commonly used and just keep the old name

IMO, the naming of this new return type is more important to get right, because the VectorizationHookResult is just an internal implementation detail to the file. The new VectorizationResult is exposed as the return type of the public vectorize function, which means there will be more pain in downstream projects any time the type name is changed. A change to VectorizationHookResult is done with a single PR here in mlir, but changing VectorizationResult means every downstream project must integrate the change as well. Again, I don't have a particularly strong opinion here, but it just seems nice to avoid potential extra integration pain downstream.

That said, comments to distinguish the two are always great! I'll try to clarify it in the code comments.

Contributor

These are great points - thanks!

Let’s park the bikeshedding here (which I kicked off 😅). What you’re proposing makes sense to me!

/// Returns a `VectorizationResult` containing the results of the vectorized op,
/// or failure if the transformation fails. If provided, `inputVectorSizes` are
/// used to vectorize this operation. `inputVectorSizes` must match the rank of
/// the iteration space of the operation and the input vector sizes must be
/// greater than or equal to their counterpart iteration space sizes, if static.
/// `inputVectorShapes` also allows the vectorization of operations with dynamic
/// shapes.
FailureOr<VectorizationResult>
vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
20 changes: 13 additions & 7 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3823,9 +3823,14 @@ struct VectorizationPattern : public RewritePattern {
if (!linalg::hasVectorizationImpl(op))
return rewriter.notifyMatchFailure(op,
"Unsupported Op, cannot vectorize");
return vectorize(rewriter, op, /*inputVectorSizes=*/{},
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
flatten1DDepthwiseConv);
FailureOr<VectorizationResult> vectorResults =
vectorize(rewriter, op, /*inputVectorSizes=*/{},
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
flatten1DDepthwiseConv);
if (failed(vectorResults))
return failure();
rewriter.replaceOp(op, vectorResults->replacements);
return success();
}

private:
@@ -3914,13 +3919,14 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}

if (failed(linalg::vectorize(rewriter, target, vectorSizes,
getScalableSizes(),
getVectorizeNdExtract().value_or(false)))) {
FailureOr<VectorizationResult> vectorResults =
linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
getVectorizeNdExtract().value_or(false));
if (failed(vectorResults)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
rewriter.replaceOp(target, vectorResults->replacements);
}

return DiagnosedSilenceableFailure::success();
116 changes: 56 additions & 60 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -551,9 +551,10 @@ enum class Conv1DOpOrder {
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
};

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate/
enum VectorizationStatus {
/// Helper data structure to represent the result of vectorization for a single
/// operation. In certain specific cases, like terminators, we do not want to
/// propagate.
enum VectorizationHookStatus {
/// Op failed to vectorize.
Failure = 0,
/// Op vectorized and custom function took care of replacement logic
@@ -564,9 +565,12 @@ enum VectorizationStatus {
// TODO: support values if Op vectorized to Many-Ops whose results we need to
// aggregate for replacement.
};
struct VectorizationResult {
/// VectorizationHookResult contains the vectorized op returned from a
/// CustomVectorizationHook. This is an internal implementation detail of
/// linalg vectorization, not to be confused with VectorizationResult.
struct VectorizationHookResult {
/// Return status from vectorizing the current op.
enum VectorizationStatus status = VectorizationStatus::Failure;
enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
/// New vectorized operation to replace the current op.
/// Replacement behavior is specified by `status`.
Operation *newOp;
@@ -728,22 +732,22 @@ using CustomVectorizationPrecondition =
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
std::function<VectorizationResult(Operation *, const IRMapping &)>;
std::function<VectorizationHookResult(Operation *, const IRMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector making them available to the vectorization
/// algorithm for RAUW. This function is meant to be used as a
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations and instead return the
/// results using the `newResults` vector making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
const IRMapping &bvm, VectorizationState &state,
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
@@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
newResults.push_back(newResult);
}

return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
VectorizationState &state,
Operation *op,
LinalgOp linalgOp) {
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
VectorizationState &state,
Operation *op,
LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
// Compute the static loop sizes of the index op.
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
// dimension of the iteration space since the vectorization algorithm in this
// case can handle the broadcast.
if (dim == targetShape.size() - 1)
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
// Otherwise permute the targetShape to move the index dimension last,
// broadcast the one-dimensional index vector to the permuted shape, and
// finally transpose the broadcasted index vector to undo the permutation.
@@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
std::swap(transposition.back(), transposition[dim]);
auto transposeOp =
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}

/// Helper function to check if the tensor.extract can be vectorized by the
@@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
auto loc = extractOp.getLoc();

// Compute the static loop sizes of the extract op.
@@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

LDBG("Vectorised as gather load: " << extractOp << "\n");
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
}

// 2. Handle:
@@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
return VectorizationHookResult{VectorizationHookStatus::NewOp,
maskedReadOp};
}

// 2b. Handle contiguous access.
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
inBounds);

LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
return VectorizationHookResult{VectorizationHookStatus::NewOp,
transferReadOp};
}

/// Emit reduction operations if the shapes of the value to reduce is different
Expand Down Expand Up @@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
/// This function does not update `bvm` but returns a VectorizationHookStatus
/// that instructs the caller what `bvm` update needs to occur.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 1. Try to apply any CustomVectorizationHook.
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
VectorizationResult result = customFunc(op, bvm);
if (result.status == VectorizationStatus::Failure)
VectorizationHookResult result = customFunc(op, bvm);
if (result.status == VectorizationHookStatus::Failure)
continue;
return result;
}
@@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
// Clone so that the constant is not confined to the linalgOp block .
if (isa<arith::ConstantOp, func::ConstantOp>(op))
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
return VectorizationHookResult{VectorizationHookStatus::NewOp,
rewriter.clone(*op)};

// 3. Only ElementwiseMappable are allowed in the generic vectorization.
if (!OpTrait::hasElementwiseMappableTraits(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
Expand All @@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
}

// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
: resultType);
}
// d. Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
return VectorizationHookResult{
VectorizationHookStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
resultTypes, op->getAttrs())};
}
@@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<CustomVectorizationHook> hooks;
// 4a. Register CustomVectorizationHook for yieldOp.
CustomVectorizationHook vectorizeYield =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
};
hooks.push_back(vectorizeYield);

// 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
};
hooks.push_back(vectorizeIndex);

// 4c. Register CustomVectorizationHook for extractOp.
CustomVectorizationHook vectorizeExtract =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);

// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result =
VectorizationHookResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
if (result.status == VectorizationHookStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
if (result.status == VectorizationHookStatus::NewOp) {
Operation *maybeMaskedOp =
state.maskOperation(rewriter, result.newOp, linalgOp);
LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
tensor::InsertSliceOp>(op);
}

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation.
/// `inputVectorSizes` must match the rank of the iteration space of the
/// operation and the input vector sizes must be greater than or equal to
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
FailureOr<VectorizationResult>
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return failure();
}

if (!results.empty())
rewriter.replaceOp(op, results);
else
rewriter.eraseOp(op);

return success();
return VectorizationResult{results};
}

LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,