@@ -1806,7 +1806,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18061806 inputShape[innerDimsPos[idx]] *= size;
18071807 auto maskedRead = vector::createReadOrMaskedRead (
18081808 rewriter, loc, packOp.getSource (), inputShape, padValue,
1809- useInBoundsInsteadOfMasking);
1809+ useInBoundsInsteadOfMasking,
1810+ /* inputScalableVecSizes=*/ {});
18101811
18111812 // Create ShapeCastOp.
18121813 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1832,18 +1833,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18321833 return success ();
18331834}
18341835
1835- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1836- // / Vector::TransferReadOp - Reads a vector from the source tensor
1837- // / vector::TransposeOp - Transpose the Source tensor
1838- // / ShapeCastOp - Reshape the data based on the target.
1839- // / vector::TransferWriteOp. - Write the result vector back to the destination
1840- // / tensor.
1841- // / If the vector sizes are not provided:
1836+ // / Vectorize `linalg.unpack %src into %dest` as:
1837+ // / // Reads a vector from the source tensor
1838+ // / %read = vector.transfer_read %src
1839+ // / // Transpose %read as specified in `outer_dims_perm` attribute
1840+ // / %tr = vector.transpose %read
1841+ // / // Reshape the data based on the target
1842+ // / %sc = vector.shape_cast %tr
1843+ // / // Write the result vector to the destination tensor.
1844+ // / vector.transfer_write %sc into %dest
1845+ // /
1846+ // / If the vector sizes are not provided:
18421847// / * the vector sizes are determined by the input operand and attributes,
18431848// / * update the inBounds attribute instead of masking.
18441849static LogicalResult
18451850vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18461851 ArrayRef<int64_t > inputVectorSizes,
1852+ ArrayRef<bool > inputScalableVecDims,
18471853 SmallVectorImpl<Value> &newResults) {
18481854
18491855 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18601866
18611867 auto destSize = unpackOp.getDestRank ();
18621868
1863- if (!inputVectorSizes.empty ())
1864- assert (inputVectorSizes.size () == destSize &&
1869+ if (!inputVectorSizes.empty ()) {
1870+ assert (inputVectorSizes.size () == destSize + sourceShape. size () &&
18651871 " Incorrect number of input vector sizes" );
1872+ }
1873+
1874+ SmallVector<bool > readScalableVectorFlags;
1875+ SmallVector<bool > writeScalableVectorFlags;
1876+ SmallVector<int64_t > readVectorSizes;
1877+ SmallVector<int64_t > writeVectorSizes;
18661878
1867- // vectorSizes is the shape of the vector that will be used to do final
1879+ // Split input-vector-sizes into vector sizes for the read and write
1880+ // operations.
1881+ if (!inputVectorSizes.empty ()) {
1882+ readVectorSizes.append (inputVectorSizes.begin (),
1883+ inputVectorSizes.begin () + sourceShape.size ());
1884+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1885+ inputVectorSizes.end ());
1886+ }
1887+ if (!inputScalableVecDims.empty ()) {
1888+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1889+ inputScalableVecDims.begin () +
1890+ sourceShape.size ());
1891+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1892+ sourceShape.size (),
1893+ inputScalableVecDims.end ());
1894+ } else {
1895+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1896+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1897+ }
1898+
1899+ // writeVectorSizes is the shape of the vector that will be used to do final
18681900 // write on the destination tensor. It is set like this: Let's say the
18691901 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18701902 // Thus:
1871- // 1. vectorSizes = sourceShape.take_front(N)
1872- // 2. if outer_dims_perms is present: do that permutation on vectorSizes .
1903+ // 1. writeVectorSizes = sourceShape.take_front(N)
1904+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes .
18731905 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18741906 // innerTiles attribute value.
1875- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1876- if (vectorSizes.empty ()) {
1877- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1907+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1908+ if (writeVectorSizes.empty ()) {
1909+ if (ShapedType::isDynamicShape (sourceShape))
1910+ return failure ();
1911+
1912+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
18781913 if (!outerDimsPerm.empty ())
1879- applyPermutationToVector (vectorSizes , outerDimsPerm);
1914+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
18801915 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1881- vectorSizes [pos] *= innerTiles[i];
1916+ writeVectorSizes [pos] *= innerTiles[i];
18821917
18831918 useInBoundsInsteadOfMasking = true ;
18841919 }
@@ -1902,17 +1937,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19021937 // After applying outer_dims_perm: [8, 16]
19031938 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19041939
1905- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1906-
1907- for (auto [index, size] : enumerate(innerTiles)) {
1908- readVectorSizes[innerDimPos[index]] =
1909- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1910- }
1911- if (!outerDimsPerm.empty ()) {
1912- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1940+ if (readVectorSizes.empty ()) {
1941+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1942+ // sizes. Note, this will only work when all sizes are static.
1943+ readVectorSizes = writeVectorSizes;
1944+ for (auto [index, size] : enumerate(innerTiles)) {
1945+ readVectorSizes[innerDimPos[index]] =
1946+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1947+ }
1948+ if (!outerDimsPerm.empty ()) {
1949+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1950+ }
1951+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1952+ sourceShape.end ());
19131953 }
1914- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1915- sourceShape.end ());
19161954
19171955 ReifiedRankedShapedTypeDims reifiedRetShapes;
19181956 LogicalResult status =
@@ -1931,7 +1969,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19311969 // to shape of source, then a mask is necessary.
19321970 Value readResult = vector::createReadOrMaskedRead (
19331971 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1934- /* useInBoundsInsteadOfMasking=*/ false );
1972+ /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags );
19351973
19361974 PackingMetadata packMetadata;
19371975 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1950,15 +1988,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19501988 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
19511989 stripMineTensorType, packMetadata.reassociations );
19521990 mlir::VectorType vecCollapsedType =
1953- VectorType::get (collapsedType.getShape (), collapsedType.getElementType ());
1991+ VectorType::get (collapsedType.getShape (), collapsedType.getElementType (),
1992+ writeScalableVectorFlags);
19541993 vector::ShapeCastOp shapeCastOp = rewriter.create <vector::ShapeCastOp>(
19551994 loc, vecCollapsedType, transposeOp->getResult (0 ));
19561995
1957- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
1996+ // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19581997 // otherwise the validator complains that the mask size is invalid.
1959- SmallVector<int64_t > writeVectorSizes (
1998+ // FIXME: We should not override write-vector-sizes like this.
1999+ SmallVector<int64_t > writeVectorSizesFinal (
19602000 unpackOp.getDestType ().hasStaticShape ()
1961- ? vectorSizes
2001+ ? writeVectorSizes
19622002 : shapeCastOp.getResultVectorType ().getShape ());
19632003 Operation *write = createWriteOrMaskedWrite (
19642004 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
@@ -1989,7 +2029,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19892029 assert (succeeded (status) && " failed to reify result shapes" );
19902030 auto maskedRead = vector::createReadOrMaskedRead (
19912031 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1992- /* useInBoundsInsteadOfMasking=*/ false );
2032+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
19932033
19942034 // Create Xfer write Op
19952035 Value dest = rewriter.create <tensor::EmptyOp>(
@@ -2073,6 +2113,9 @@ static LogicalResult
20732113vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
20742114 ArrayRef<int64_t > inputVectorSizes) {
20752115
2116+ // FIXME!!!
2117+ return success ();
2118+
20762119 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
20772120 return !getConstantIntValue (res).has_value ();
20782121 })) {
@@ -2409,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24092452 LDBG (" pad value is not constant: " << packOp << " \n " );
24102453 return failure ();
24112454 }
2455+
24122456 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24132457 bool satisfyEmptyCond = true ;
24142458 if (inputVectorSizes.empty ()) {
@@ -2487,12 +2531,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
24872531 if (numOfScalableDims == 0 )
24882532 return success ();
24892533
2534+ // TODO: Check the following!
24902535 auto linalgOp = dyn_cast<LinalgOp>(op);
24912536
2492- // Cond 1: There's been no need for scalable vectorisation of
2493- // non-linalg Ops so far
2494- if (!linalgOp)
2495- return failure ();
2537+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2538+ // exception of UnpackOp for which there is a dedicated hook.
2539+ if (!linalgOp) {
2540+ return isa<linalg::UnPackOp>(op) ? success () : failure ();
2541+ }
24962542
24972543 // Cond 2: There's been no need for more than 2 scalable dims so far
24982544 if (numOfScalableDims > 2 )
@@ -2588,7 +2634,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25882634 isa<linalg::MatmulTransposeAOp>(op) ||
25892635 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
25902636 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2591- hasReductionIterator (linalgOp));
2637+ isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
25922638}
25932639
25942640LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2723,7 +2769,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27232769 })
27242770 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27252771 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2726- inputVectorSizes, results);
2772+ inputVectorSizes,
2773+ inputScalableVecDims, results);
27272774 })
27282775 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27292776 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31143161 vecType.getRank (), rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
31153162 Value read = mlir::vector::createReadOrMaskedRead (
31163163 rewriter, loc, source, vecType.getShape (), padValue,
3117- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3164+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3165+ /* inputScalableVecSizes=*/ {});
31183166
31193167 // Create write
31203168 auto writeIndices =
0 commit comments