[mlir][vector] Simplify createReadOrMaskedRead#163736
[mlir][vector] Simplify createReadOrMaskedRead#163736banach-space merged 3 commits intollvm:mainfrom
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/163736.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..2e6fab30e5120 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -219,18 +219,16 @@ bool isLinearizableVector(VectorType type);
/// Creates a TransferReadOp from `source`.
///
-/// The shape of the vector to read is specified via `inputVectorSizes`. If the
-/// shape of the output vector differs from the shape of the value being read,
-/// masking is used to avoid out-of-bounds accesses. Set
+/// If the shape of vector to read differs from the shape of the value being
+/// read, masking is used to avoid out-of-bounds accesses. Set
/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
/// instead of explicit masks.
///
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
- ArrayRef<int64_t> inputVectorSizes,
+ VectorType &vecToReadTy,
std::optional<Value> padValue = std::nullopt,
- bool useInBoundsInsteadOfMasking = false,
- ArrayRef<bool> inputScalableVecDims = {});
+ bool useInBoundsInsteadOfMasking = false);
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c15e6330fd8c3..32246e3a11433 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1886,9 +1886,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
// Create masked TransferReadOp.
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
- useInBoundsInsteadOfMasking,
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, packOp.getSource(), readVecType, padValue,
+ useInBoundsInsteadOfMasking);
// Create ShapeCastOp.
auto expandedVecType =
@@ -1975,9 +1974,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
}
// -- Generate the read operation --
+ VectorType readVecType =
+ VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
+ readScalableVectorFlags);
Value readResult = vector::createReadOrMaskedRead(
- rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
- useInBoundsInsteadOfMasking, readScalableVectorFlags);
+ rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
+ useInBoundsInsteadOfMasking);
// -- Generate the transpose operation --
PackingMetadata packMetadata;
@@ -2023,9 +2025,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
+ auto readType = VectorType::get(inputVectorSizes, padValue.getType());
auto maskedRead = vector::createReadOrMaskedRead(
- rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
+ rewriter, loc, padOp.getSource(), readType, padValue,
+ /*useInBoundsInsteadOfMasking=*/false);
// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2220,9 +2223,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, opOperand.get(), readType.getShape(),
+ rewriter, loc, opOperand.get(), readType,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
- /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+ /*useInBoundsInsteadOfMasking=*/false);
vecOperands.push_back(read);
}
@@ -3163,9 +3166,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
SmallVector<Value> readIndices(
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
- rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
- /*inputScalableVecSizes=*/{});
+ rewriter, loc, source, vecType, padValue,
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..d73ccd1c41b66 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -317,51 +317,51 @@ bool vector::isLinearizableVector(VectorType type) {
}
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
- Value source,
- ArrayRef<int64_t> inputVectorSizes,
+ Value source, VectorType &vecToReadTy,
std::optional<Value> padValue,
- bool useInBoundsInsteadOfMasking,
- ArrayRef<bool> inputScalableVecDims) {
- assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
+ bool useInBoundsInsteadOfMasking) {
+ assert(!llvm::is_contained(vecToReadTy.getScalableDims(),
+ ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
- assert(sourceShape.size() == inputVectorSizes.size() &&
+
+ int64_t vecToReadRank = vecToReadTy.getRank();
+ auto vecToReadShape = vecToReadTy.getShape();
+
+ assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
- inputScalableVecDims);
assert((!padValue.has_value() ||
padValue.value().getType() == sourceShapedType.getElementType()) &&
"expected same pad element type to match source element type");
- int64_t readRank = inputVectorSizes.size();
+
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
- SmallVector<bool> inBoundsVal(readRank, true);
+ SmallVector<bool> inBoundsVal(vecToReadRank, true);
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the read indices.
- for (unsigned i = 0; i < readRank; i++)
- inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
+ for (unsigned i = 0; i < vecToReadRank; i++)
+ inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
auto transferReadOp = vector::TransferReadOp::create(
builder, loc,
- /*vectorType=*/vectorType,
+ /*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
+ /*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
- if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
+ if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
+ useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
isa<MemRefType>(source.getType())
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
- inputScalableVecDims);
+ auto maskType = vecToReadTy.clone(builder.getI1Type());
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
|
a26b730 to
bea41f5
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
0ad9ece to
98a3fac
Compare
rengolin
left a comment
There was a problem hiding this comment.
Cleaner code is always good. 😃
This looks good to me, but it's a change of API, so best to leave a bit longer to make sure users have time to react.
|
@arun-thmn FYI |
dcaballe
left a comment
There was a problem hiding this comment.
LGTM! Add [NFC] before landing?
Simplify `createReadOrMaskedRead` to only require _one_ argument to specify the vector type to read (passed as `VectorType`) instead of passing vector-sizes and scalable-flags independently (i.e. _two_ arguments).
424bf12 to
a24d857
Compare
| /// Note: all read offsets are set to 0. | ||
| Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, | ||
| ArrayRef<int64_t> inputVectorSizes, | ||
| const VectorType &vecToReadTy, |
There was a problem hiding this comment.
Does it start enforcing users to create VectorType when they don't have to? E.g., I searched the use in IREE, and we have one use. The type is ShapedType; this API change forces the use to create VectorType when they dont have to.
(I don't have a strong opinion, and I can fix IREE side. I mainly want to provide a data point that users will have to create VectorType when they don't care about scalable flags.)
There was a problem hiding this comment.
Thanks for the feedback!
I mainly want to provide a data point that users will have to create VectorType when they don't care about scalable flags
My thinking here was that createReadOrMaskedRead always creates an instance of VectorType, so:
- When users do require
VectorType, we make sure that we create it only once and then sizes are easier to track (i.e. "Where are these sizes coming from?") - When users do not require
VectorType(like in your case), my PR creates extra burden on the users, but the number of instances ofVectorTypedoes not change.
To make it easier for you, how about adding an overload:
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes,
std::optional<Value> padValue = std::nullopt,
bool useInBoundsInsteadOfMasking = false,
ArrayRef<bool> inputScalableVecDims = {}) {
VectorType readVecType =
VectorType::get(inputVectorSizes, source.getElementType(),
readScalableVectorFlags);
createReadOrMaskedRead(builder, loc, source, readVecType, padValue, useInBoundsInsteadOfMasking);
}
WDYT?
Simplify
createReadOrMaskedReadto only require one argument tospecify the vector type to read (passed as
VectorType) instead ofpassing vector-sizes and scalable-flags independently (i.e. two
arguments).