Skip to content

Commit 5deba69

Browse files
committed
[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below. Conceptually, `linalg.unpack` consists of the following high-level steps: 1. _Read_ from the source tensor. 2. Transpose the value read in step (1). 3. _Write_ the value from step (2) into the destination tensor. Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` rejects such cases ATM and the vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases. Example: ```mlir func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> { %vs = vector.vscale %c8 = arith.constant 8 : index %tile_size = arith.muli %vs, %c8 : index %unpack = linalg.unpack %in inner_dims_pos = [0, 1] inner_tiles = [8, %tile_size] into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32> return %unpack : tensor<8x?xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op // \ / \ / // read-sizes write-sizes transform.yield } } ``` Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent 076cde5 commit 5deba69

File tree

4 files changed

+206
-70
lines changed

4 files changed

+206
-70
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ bool isLinearizableVector(VectorType type);
225225
///
226226
/// Note: all read offsets are set to 0.
227227
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
228-
ArrayRef<int64_t> inputVectorSizes, Value padValue,
228+
ArrayRef<int64_t> inputVectorSizes,
229+
ArrayRef<bool> inputScalableVecSizes,
230+
Value padValue,
229231
bool useInBoundsInsteadOfMasking = false);
230232

231233
/// Returns success if `inputVectorSizes` is a valid masking configuraion for

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
17091709
return write;
17101710

17111711
// Compute the mask and mask the write Op.
1712-
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1712+
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1713+
vecToStoreType.getScalableDims());
17131714

17141715
SmallVector<OpFoldResult> destSizes =
17151716
tensor::getMixedSizes(builder, loc, dest);
@@ -1801,8 +1802,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18011802
for (auto [idx, size] : enumerate(innerTiles))
18021803
inputShape[innerDimsPos[idx]] *= size;
18031804
auto maskedRead = vector::createReadOrMaskedRead(
1804-
rewriter, loc, packOp.getSource(), inputShape, padValue,
1805-
useInBoundsInsteadOfMasking);
1805+
rewriter, loc, packOp.getSource(), inputShape,
1806+
/*inputScalableVecSizes=*/{}, padValue, useInBoundsInsteadOfMasking);
18061807

18071808
// Create ShapeCastOp.
18081809
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1828,18 +1829,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18281829
return success();
18291830
}
18301831

1831-
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
1832-
/// Vector::TransferReadOp - Reads a vector from the source tensor
1833-
/// vector::TransposeOp - Transpose the Source tensor
1834-
/// ShapeCastOp - Reshape the data based on the target.
1835-
/// vector::TransferWriteOp. - Write the result vector back to the destination
1836-
/// tensor.
1837-
/// If the vector sizes are not provided:
1832+
/// Vectorize `linalg.unpack %src into %dest` as:
1833+
/// // Reads a vector from the source tensor
1834+
/// %read = vector.transfer_read %src
1835+
/// // Transpose %read as specified in `outer_dims_perm` attribute
1836+
/// %tr = vector.transpose %read
1837+
/// // Reshape the data based on the target
1838+
/// %sc = vector.shape_cast %tr
1839+
/// // Write the result vector to the destination tensor.
1840+
/// vector.transfer_write %sc into %dest
1841+
///
1842+
/// If the vector sizes are not provided:
18381843
/// * the vector sizes are determined by the input operand and attributes,
18391844
/// * update the inBounds attribute instead of masking.
18401845
static LogicalResult
18411846
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18421847
ArrayRef<int64_t> inputVectorSizes,
1848+
ArrayRef<bool> inputScalableVecDims,
18431849
SmallVectorImpl<Value> &newResults) {
18441850

18451851
// TODO: Introduce a parent class that will handle the insertion point update.
@@ -1856,25 +1862,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18561862

18571863
auto destSize = unpackOp.getDestRank();
18581864

1859-
if (!inputVectorSizes.empty())
1860-
assert(inputVectorSizes.size() == destSize &&
1865+
if (!inputVectorSizes.empty()) {
1866+
assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
18611867
"Incorrect number of input vector sizes");
1868+
}
1869+
1870+
SmallVector<bool> readScalableVectorFlags;
1871+
SmallVector<bool> writeScalableVectorFlags;
1872+
SmallVector<int64_t> readVectorSizes;
1873+
SmallVector<int64_t> writeVectorSizes;
18621874

1863-
// vectorSizes is the shape of the vector that will be used to do final
1875+
// Split input-vector-sizes into vector sizes for the read and write
1876+
// operations.
1877+
if (!inputVectorSizes.empty()) {
1878+
readVectorSizes.append(inputVectorSizes.begin(),
1879+
inputVectorSizes.begin() + sourceShape.size());
1880+
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
1881+
inputVectorSizes.end());
1882+
}
1883+
if (!inputScalableVecDims.empty()) {
1884+
readScalableVectorFlags.append(inputScalableVecDims.begin(),
1885+
inputScalableVecDims.begin() +
1886+
sourceShape.size());
1887+
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
1888+
sourceShape.size(),
1889+
inputScalableVecDims.end());
1890+
} else {
1891+
readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
1892+
writeScalableVectorFlags = SmallVector<bool>(destSize, false);
1893+
}
1894+
1895+
// writeVectorSizes is the shape of the vector that will be used to do final
18641896
// write on the destination tensor. It is set like this: Let's say the
18651897
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18661898
// Thus:
1867-
// 1. vectorSizes = sourceShape.take_front(N)
1868-
// 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1899+
// 1. writeVectorSizes = sourceShape.take_front(N)
1900+
// 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
18691901
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18701902
// innerTiles attribute value.
1871-
SmallVector<int64_t> vectorSizes(inputVectorSizes);
1872-
if (vectorSizes.empty()) {
1873-
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1903+
// SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1904+
if (writeVectorSizes.empty()) {
1905+
if (ShapedType::isDynamicShape(sourceShape))
1906+
return failure();
1907+
1908+
llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
18741909
if (!outerDimsPerm.empty())
1875-
applyPermutationToVector(vectorSizes, outerDimsPerm);
1910+
applyPermutationToVector(writeVectorSizes, outerDimsPerm);
18761911
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1877-
vectorSizes[pos] *= innerTiles[i];
1912+
writeVectorSizes[pos] *= innerTiles[i];
18781913

18791914
useInBoundsInsteadOfMasking = true;
18801915
}
@@ -1898,17 +1933,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18981933
// After applying outer_dims_perm: [8, 16]
18991934
// After appending the rest of the sourceShape: [8, 16, 32, 16]
19001935

1901-
SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1902-
1903-
for (auto [index, size] : enumerate(innerTiles)) {
1904-
readVectorSizes[innerDimPos[index]] =
1905-
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1906-
}
1907-
if (!outerDimsPerm.empty()) {
1908-
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1936+
if (readVectorSizes.empty()) {
1937+
// Compute read-vector-sizes based on the write-vector-sizes and inner tile
1938+
// sizes. Note, this will only work when all sizes are static.
1939+
readVectorSizes = writeVectorSizes;
1940+
for (auto [index, size] : enumerate(innerTiles)) {
1941+
readVectorSizes[innerDimPos[index]] =
1942+
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1943+
}
1944+
if (!outerDimsPerm.empty()) {
1945+
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1946+
}
1947+
readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
1948+
sourceShape.end());
19091949
}
1910-
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1911-
sourceShape.end());
19121950

19131951
ReifiedRankedShapedTypeDims reifiedRetShapes;
19141952
LogicalResult status =
@@ -1926,7 +1964,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19261964
// Read result, mask if necessary. If transferReadOp shape is not equal
19271965
// to shape of source, then a mask is necessary.
19281966
Value readResult = vector::createReadOrMaskedRead(
1929-
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1967+
rewriter, loc, unpackOp.getSource(), readVectorSizes,
1968+
readScalableVectorFlags, padValue,
19301969
/*useInBoundsInsteadOfMasking=*/false);
19311970

19321971
PackingMetadata packMetadata;
@@ -1946,15 +1985,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19461985
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
19471986
stripMineTensorType, packMetadata.reassociations);
19481987
mlir::VectorType vecCollapsedType =
1949-
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1988+
VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
1989+
writeScalableVectorFlags);
19501990
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
19511991
loc, vecCollapsedType, transposeOp->getResult(0));
19521992

1953-
// writeVectorSizes had to match the shapecast shape for dynamic sizes,
1993+
// writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19541994
// otherwise the validator complains that the mask size is invalid.
1955-
SmallVector<int64_t> writeVectorSizes(
1995+
// FIXME: We should not override write-vector-sizes like this.
1996+
SmallVector<int64_t> writeVectorSizesFinal(
19561997
unpackOp.getDestType().hasStaticShape()
1957-
? vectorSizes
1998+
? writeVectorSizes
19581999
: shapeCastOp.getResultVectorType().getShape());
19592000
Operation *write = createWriteOrMaskedWrite(
19602001
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1984,7 +2025,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19842025
(void)status; // prevent unused variable warning on non-assert builds
19852026
assert(succeeded(status) && "failed to reify result shapes");
19862027
auto maskedRead = vector::createReadOrMaskedRead(
1987-
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2028+
rewriter, loc, padOp.getSource(), inputVectorSizes,
2029+
/*inputScalableVecSizes=*/{}, padValue,
19882030
/*useInBoundsInsteadOfMasking=*/false);
19892031

19902032
// Create Xfer write Op
@@ -2069,6 +2111,9 @@ static LogicalResult
20692111
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
20702112
ArrayRef<int64_t> inputVectorSizes) {
20712113

2114+
// FIXME!!!
2115+
return success();
2116+
20722117
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
20732118
return !getConstantIntValue(res).has_value();
20742119
})) {
@@ -2319,6 +2364,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
23192364
LDBG("pad value is not constant: " << packOp << "\n");
23202365
return failure();
23212366
}
2367+
23222368
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
23232369
bool satisfyEmptyCond = true;
23242370
if (inputVectorSizes.empty()) {
@@ -2397,6 +2443,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
23972443
if (numOfScalableDims == 0)
23982444
return success();
23992445

2446+
// FIXME!!!
2447+
return success();
2448+
2449+
// TODO: Check the following!
24002450
auto linalgOp = dyn_cast<LinalgOp>(op);
24012451

24022452
// Cond 1: There's been no need for scalable vectorisation of
@@ -2498,7 +2548,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
24982548
isa<linalg::MatmulTransposeAOp>(op) ||
24992549
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
25002550
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2501-
hasReductionIterator(linalgOp));
2551+
isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
25022552
}
25032553

25042554
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2627,7 +2677,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
26272677
})
26282678
.Case<linalg::UnPackOp>([&](auto unpackOp) {
26292679
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2630-
inputVectorSizes, results);
2680+
inputVectorSizes,
2681+
inputScalableVecDims, results);
26312682
})
26322683
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
26332684
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3017,7 +3068,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
30173068
SmallVector<Value> readIndices(
30183069
vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
30193070
Value read = mlir::vector::createReadOrMaskedRead(
3020-
rewriter, loc, source, vecType.getShape(), padValue,
3071+
rewriter, loc, source, vecType.getShape(), /*inputScalableVecSizes=*/{},
3072+
padValue,
30213073
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
30223074

30233075
// Create write

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ bool vector::isLinearizableVector(VectorType type) {
319319
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
320320
Value source,
321321
ArrayRef<int64_t> inputVectorSizes,
322+
ArrayRef<bool> inputScalableVecSizes,
322323
Value padValue,
323324
bool useInBoundsInsteadOfMasking) {
324325
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
@@ -327,7 +328,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
327328
auto sourceShape = sourceShapedType.getShape();
328329
assert(sourceShape.size() == inputVectorSizes.size() &&
329330
"expected same ranks.");
330-
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
331+
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
332+
inputScalableVecSizes);
331333
assert(padValue.getType() == sourceShapedType.getElementType() &&
332334
"expected same pad element type to match source element type");
333335
int64_t readRank = inputVectorSizes.size();
@@ -354,7 +356,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
354356
SmallVector<OpFoldResult> mixedSourceDims =
355357
tensor::getMixedSizes(builder, loc, source);
356358

357-
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
359+
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
360+
inputScalableVecSizes);
358361
Value mask =
359362
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
360363
return mlir::vector::maskOperation(builder, transferReadOp, mask)

0 commit comments

Comments
 (0)