Commit 953eaf0

[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:

1. _Read_ from the source tensor.
2. Transpose the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks down when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` rejects such cases at the moment, and vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op
    //                                              \         /  \     /
    //                                               read-sizes  write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent be6bed4 commit 953eaf0

4 files changed: +209 −80 lines changed


mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
                              bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> scalableDims = {});
+                             ArrayRef<bool> inputScalableVecDims = {});

 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 89 additions & 41 deletions
```diff
@@ -1806,7 +1806,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1832,18 +1833,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }

-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-///   If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
 /// * the vector sizes are determined by the input operand and attributes,
 /// * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {

   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,

   auto destSize = unpackOp.getDestRank();

-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
           "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;

-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];

     useInBoundsInsteadOfMasking = true;
   }
@@ -1902,17 +1937,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]

-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());

   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
@@ -1931,7 +1969,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1950,15 +1988,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));

-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1989,7 +2029,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
@@ -2073,6 +2113,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {

+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2409,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2487,12 +2531,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();

+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);

-  // Cond 1: There's been no need for scalable vectorisation of
-  //         non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }

   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
@@ -2588,7 +2634,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
           isa<linalg::MatmulTransposeAOp>(op) ||
           isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
           isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-          hasReductionIterator(linalgOp));
+          isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2723,7 +2769,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});

   // Create write
   auto writeIndices =
```

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 12 additions & 10 deletions
```diff
@@ -281,14 +281,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
-  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
-  auto it = llvm::find(scalableDimsToUnroll, true);
-  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  auto inputScalableVecDimsToUnroll =
+      vType.getScalableDims().drop_back(targetRank);
+  auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+  auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
   if (firstScalableDim == 0)
     return {};
   // All scalable dimensions should be removed now.
-  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
-  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+  inputScalableVecDimsToUnroll =
+      inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -321,15 +323,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> scalableDims) {
+                                     ArrayRef<bool> inputScalableVecDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -358,8 +360,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);

-  auto maskType =
-      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecDims);
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
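The renamed logic in `createUnrollIterator` above only unrolls leading fixed-size dimensions, stopping at the first scalable one. The dimension-selection step can be modeled on its own in plain C++ (the helper name is illustrative only):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative model: given the scalable-dim flags that remain after
// dropping the trailing targetRank dims, return how many leading
// fixed-size dims can be unrolled (everything before the first scalable
// dim). Zero means nothing can be unrolled.
static size_t unrollableLeadingDims(const std::vector<bool> &scalableDims) {
  auto it = std::find(scalableDims.begin(), scalableDims.end(), true);
  return static_cast<size_t>(it - scalableDims.begin());
}
```

So a shape like `[4, 2, [8], 16]` would allow unrolling the two leading fixed dims, while a leading scalable dim blocks unrolling entirely, matching the `firstScalableDim == 0` early return in the diff.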
