@@ -1812,7 +1812,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18121812 inputShape[innerDimsPos[idx]] *= size;
18131813 auto maskedRead = vector::createReadOrMaskedRead (
18141814 rewriter, loc, packOp.getSource (), inputShape, padValue,
1815- useInBoundsInsteadOfMasking);
1815+ useInBoundsInsteadOfMasking,
1816+ /* inputScalableVecSizes=*/ {});
18161817
18171818 // Create ShapeCastOp.
18181819 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1838,18 +1839,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18381839 return success ();
18391840}
18401841
1841- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1842- // / Vector::TransferReadOp - Reads a vector from the source tensor
1843- // / vector::TransposeOp - Transpose the Source tensor
1844- // / ShapeCastOp - Reshape the data based on the target.
1845- // / vector::TransferWriteOp. - Write the result vector back to the destination
1846- // / tensor.
1847- // / If the vector sizes are not provided:
1842+ // / Vectorize `linalg.unpack %src into %dest` as:
1843+ // / // Reads a vector from the source tensor
1844+ // / %read = vector.transfer_read %src
1845+ // / // Transpose %read as specified in `outer_dims_perm` attribute
1846+ // / %tr = vector.transpose %read
1847+ // / // Reshape the data based on the target
1848+ // / %sc = vector.shape_cast %tr
1849+ // / // Write the result vector to the destination tensor.
1850+ // / vector.transfer_write %sc into %dest
1851+ // /
1852+ // / If the vector sizes are not provided:
18481853// / * the vector sizes are determined by the input operand and attributes,
18491854// / * update the inBounds attribute instead of masking.
18501855static LogicalResult
18511856vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18521857 ArrayRef<int64_t > inputVectorSizes,
1858+ ArrayRef<bool > inputScalableVecDims,
18531859 SmallVectorImpl<Value> &newResults) {
18541860
18551861 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1866,25 +1872,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18661872
18671873 auto destSize = unpackOp.getDestRank ();
18681874
1869- if (!inputVectorSizes.empty ())
1870- assert (inputVectorSizes.size () == destSize &&
1875+ if (!inputVectorSizes.empty ()) {
1876+ assert (inputVectorSizes.size () == destSize + sourceShape. size () &&
18711877 " Incorrect number of input vector sizes" );
1878+ }
1879+
1880+ SmallVector<bool > readScalableVectorFlags;
1881+ SmallVector<bool > writeScalableVectorFlags;
1882+ SmallVector<int64_t > readVectorSizes;
1883+ SmallVector<int64_t > writeVectorSizes;
18721884
1873- // vectorSizes is the shape of the vector that will be used to do final
1885+ // Split input-vector-sizes into vector sizes for the read and write
1886+ // operations.
1887+ if (!inputVectorSizes.empty ()) {
1888+ readVectorSizes.append (inputVectorSizes.begin (),
1889+ inputVectorSizes.begin () + sourceShape.size ());
1890+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1891+ inputVectorSizes.end ());
1892+ }
1893+ if (!inputScalableVecDims.empty ()) {
1894+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1895+ inputScalableVecDims.begin () +
1896+ sourceShape.size ());
1897+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1898+ sourceShape.size (),
1899+ inputScalableVecDims.end ());
1900+ } else {
1901+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1902+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1903+ }
1904+
1905+ // writeVectorSizes is the shape of the vector that will be used to do final
18741906 // write on the destination tensor. It is set like this: Let's say the
18751907 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18761908 // Thus:
1877- // 1. vectorSizes = sourceShape.take_front(N)
1878- // 2. if outer_dims_perms is present: do that permutation on vectorSizes .
1909+ // 1. writeVectorSizes = sourceShape.take_front(N)
1910+ // / 2. if outer_dims_perm is present: do that permutation on writeVectorSizes.
18791911 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18801912 // innerTiles attribute value.
1881- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1882- if (vectorSizes.empty ()) {
1883- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1913+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1914+ if (writeVectorSizes.empty ()) {
1915+ if (ShapedType::isDynamicShape (sourceShape))
1916+ return failure ();
1917+
1918+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
18841919 if (!outerDimsPerm.empty ())
1885- applyPermutationToVector (vectorSizes , outerDimsPerm);
1920+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
18861921 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1887- vectorSizes [pos] *= innerTiles[i];
1922+ writeVectorSizes [pos] *= innerTiles[i];
18881923
18891924 useInBoundsInsteadOfMasking = true ;
18901925 }
@@ -1908,17 +1943,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19081943 // After applying outer_dims_perm: [8, 16]
19091944 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19101945
1911- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1912-
1913- for (auto [index, size] : enumerate(innerTiles)) {
1914- readVectorSizes[innerDimPos[index]] =
1915- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1916- }
1917- if (!outerDimsPerm.empty ()) {
1918- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1946+ if (readVectorSizes.empty ()) {
1947+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1948+ // sizes. Note, this will only work when all sizes are static.
1949+ readVectorSizes = writeVectorSizes;
1950+ for (auto [index, size] : enumerate(innerTiles)) {
1951+ readVectorSizes[innerDimPos[index]] =
1952+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1953+ }
1954+ if (!outerDimsPerm.empty ()) {
1955+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1956+ }
1957+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1958+ sourceShape.end ());
19191959 }
1920- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1921- sourceShape.end ());
19221960
19231961 Location loc = unpackOp->getLoc ();
19241962
@@ -1930,7 +1968,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19301968 // to shape of source, then a mask is necessary.
19311969 Value readResult = vector::createReadOrMaskedRead (
19321970 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1933- /* useInBoundsInsteadOfMasking=*/ false );
1971+ /* useInBoundsInsteadOfMasking=*/ false , readScalableVectorFlags );
19341972
19351973 PackingMetadata packMetadata;
19361974 SmallVector<int64_t > lastDimToInsertPosPerm =
@@ -1949,15 +1987,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19491987 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
19501988 stripMineTensorType, packMetadata.reassociations );
19511989 mlir::VectorType vecCollapsedType =
1952- VectorType::get (collapsedType.getShape (), collapsedType.getElementType ());
1990+ VectorType::get (collapsedType.getShape (), collapsedType.getElementType (),
1991+ writeScalableVectorFlags);
19531992 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
19541993 rewriter, loc, vecCollapsedType, transposeOp->getResult (0 ));
19551994
1956- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
1995+ // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19571996 // otherwise the validator complains that the mask size is invalid.
1958- SmallVector<int64_t > writeVectorSizes (
1997+ // FIXME: We should not override write-vector-sizes like this.
1998+ SmallVector<int64_t > writeVectorSizesFinal (
19591999 unpackOp.getDestType ().hasStaticShape ()
1960- ? vectorSizes
2000+ ? writeVectorSizes
19612001 : shapeCastOp.getResultVectorType ().getShape ());
19622002 Operation *write = createWriteOrMaskedWrite (
19632003 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
@@ -1988,7 +2028,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19882028 assert (succeeded (status) && " failed to reify result shapes" );
19892029 auto maskedRead = vector::createReadOrMaskedRead (
19902030 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1991- /* useInBoundsInsteadOfMasking=*/ false );
2031+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
19922032
19932033 // Create Xfer write Op
19942034 Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2072,6 +2112,9 @@ static LogicalResult
20722112vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
20732113 ArrayRef<int64_t > inputVectorSizes) {
20742114
2115+ // FIXME!!!
2116+ return success ();
2117+
20752118 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
20762119 return !getConstantIntValue (res).has_value ();
20772120 })) {
@@ -2408,6 +2451,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24082451 LDBG (" pad value is not constant: " << packOp << " \n " );
24092452 return failure ();
24102453 }
2454+
24112455 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24122456 bool satisfyEmptyCond = true ;
24132457 if (inputVectorSizes.empty ()) {
@@ -2486,12 +2530,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
24862530 if (numOfScalableDims == 0 )
24872531 return success ();
24882532
2533+ // TODO: Check the following!
24892534 auto linalgOp = dyn_cast<LinalgOp>(op);
24902535
2491- // Cond 1: There's been no need for scalable vectorisation of
2492- // non-linalg Ops so far
2493- if (!linalgOp)
2494- return failure ();
2536+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2537+ // exception of UnpackOp for which there is a dedicated hook.
2538+ if (!linalgOp) {
2539+ return isa<linalg::UnPackOp>(op) ? success () : failure ();
2540+ }
24952541
24962542 // Cond 2: There's been no need for more than 2 scalable dims so far
24972543 if (numOfScalableDims > 2 )
@@ -2587,7 +2633,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25872633 isa<linalg::MatmulTransposeAOp>(op) ||
25882634 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
25892635 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2590- hasReductionIterator (linalgOp));
2636+ isa<linalg::UnPackOp>(op) || hasReductionIterator (linalgOp));
25912637}
25922638
25932639LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2722,7 +2768,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27222768 })
27232769 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27242770 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2725- inputVectorSizes, results);
2771+ inputVectorSizes,
2772+ inputScalableVecDims, results);
27262773 })
27272774 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27282775 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31143161 vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
31153162 Value read = mlir::vector::createReadOrMaskedRead (
31163163 rewriter, loc, source, vecType.getShape (), padValue,
3117- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3164+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3165+ /* inputScalableVecSizes=*/ {});
31183166
31193167 // Create write
31203168 auto writeIndices =
0 commit comments