@@ -1681,7 +1681,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
16811681 return write;
16821682
16831683 // Compute the mask and mask the write Op.
1684- auto writeMaskType = VectorType::get (vecToStoreShape, builder.getI1Type ());
1684+ auto writeMaskType = VectorType::get (vecToStoreShape, builder.getI1Type (),
1685+ vecToStoreType.getScalableDims ());
16851686
16861687 SmallVector<OpFoldResult> destSizes =
16871688 tensor::getMixedSizes (builder, loc, dest);
@@ -1773,8 +1774,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
17731774 for (auto [idx, size] : enumerate(innerTiles))
17741775 inputShape[innerDimsPos[idx]] *= size;
17751776 auto maskedRead = vector::createReadOrMaskedRead (
1776- rewriter, loc, packOp.getSource (), inputShape, padValue,
1777- useInBoundsInsteadOfMasking);
1777+ rewriter, loc, packOp.getSource (), inputShape,
1778+ /* inputScalableVecSizes= */ {}, padValue, useInBoundsInsteadOfMasking);
17781779
17791780 // Create ShapeCastOp.
17801781 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1812,6 +1813,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18121813static LogicalResult
18131814vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18141815 ArrayRef<int64_t > inputVectorSizes,
1816+ ArrayRef<bool > inputScalableVecDims,
18151817 SmallVectorImpl<Value> &newResults) {
18161818
18171819 // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1829,24 +1831,52 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18291831 auto destSize = unpackOp.getDestRank ();
18301832
18311833 if (!inputVectorSizes.empty ())
1832- assert (inputVectorSizes.size () == destSize &&
1834+ assert (inputVectorSizes.size () == destSize + sourceShape. size () &&
18331835 " Incorrect number of input vector sizes" );
18341836
1835- // vectorSizes is the shape of the vector that will be used to do final
1837+ SmallVector<bool > readScalableVectorFlags;
1838+ SmallVector<bool > writeScalableVectorFlags;
1839+ SmallVector<int64_t > readVectorSizes;
1840+ SmallVector<int64_t > writeVectorSizes;
1841+
1842+ // Split input-vector-sizes into vector sizes for the read and write
1843+ // operations.
1844+ if (!inputVectorSizes.empty ()) {
1845+ readVectorSizes.append (inputVectorSizes.begin (),
1846+ inputVectorSizes.begin () + sourceShape.size ());
1847+ writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1848+ inputVectorSizes.end ());
1849+ }
1850+ if (!inputScalableVecDims.empty ()) {
1851+ readScalableVectorFlags.append (inputScalableVecDims.begin (),
1852+ inputScalableVecDims.begin () +
1853+ sourceShape.size ());
1854+ writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1855+ sourceShape.size (),
1856+ inputScalableVecDims.end ());
1857+ } else {
1858+ readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1859+ writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1860+ }
1861+
1862+ // writeVectorSizes is the shape of the vector that will be used to do final
18361863 // write on the destination tensor. It is set like this: Let's say the
18371864 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18381865 // Thus:
1839- // 1. vectorSizes = sourceShape.take_front(N)
1840- // 2. if outer_dims_perms is present: do that permutation on vectorSizes .
1866+ // 1. writeVectorSizes = sourceShape.take_front(N)
1867+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes .
18411868 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18421869 // innerTiles attribute value.
1843- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1844- if (vectorSizes.empty ()) {
1845- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1870+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1871+ if (writeVectorSizes.empty ()) {
1872+ if (ShapedType::isDynamicShape (sourceShape))
1873+ return failure ();
1874+
1875+ llvm::append_range (writeVectorSizes, sourceShape.take_front (destSize));
18461876 if (!outerDimsPerm.empty ())
1847- applyPermutationToVector (vectorSizes , outerDimsPerm);
1877+ applyPermutationToVector (writeVectorSizes , outerDimsPerm);
18481878 for (auto [i, pos] : llvm::enumerate (innerDimPos))
1849- vectorSizes [pos] *= innerTiles[i];
1879+ writeVectorSizes [pos] *= innerTiles[i];
18501880
18511881 useInBoundsInsteadOfMasking = true ;
18521882 }
@@ -1870,17 +1900,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18701900 // After applying outer_dims_perm: [8, 16]
18711901 // After appending the rest of the sourceShape: [8, 16, 32, 16]
18721902
1873- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1874-
1875- for (auto [index, size] : enumerate(innerTiles)) {
1876- readVectorSizes[innerDimPos[index]] =
1877- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1878- }
1879- if (!outerDimsPerm.empty ()) {
1880- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1903+ if (readVectorSizes.empty ()) {
1904+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1905+ // sizes. Note, this will only work when all sizes are static.
1906+ readVectorSizes = writeVectorSizes;
1907+ for (auto [index, size] : enumerate(innerTiles)) {
1908+ readVectorSizes[innerDimPos[index]] =
1909+ llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1910+ }
1911+ if (!outerDimsPerm.empty ()) {
1912+ applyPermutationToVector (readVectorSizes, outerDimsPerm);
1913+ }
1914+ readVectorSizes.append (sourceShape.begin () + writeVectorSizes.size (),
1915+ sourceShape.end ());
18811916 }
1882- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1883- sourceShape.end ());
18841917
18851918 ReifiedRankedShapedTypeDims reifiedRetShapes;
18861919 LogicalResult status =
@@ -1898,7 +1931,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18981931 // Read result, mask if necessary. If transferReadOp shape is not equal
18991932 // to shape of source, then a mask is necessary.
19001933 Value readResult = vector::createReadOrMaskedRead (
1901- rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1934+ rewriter, loc, unpackOp.getSource (), readVectorSizes,
1935+ readScalableVectorFlags, padValue,
19021936 /* useInBoundsInsteadOfMasking=*/ false );
19031937
19041938 PackingMetadata packMetadata;
@@ -1918,15 +1952,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19181952 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType (
19191953 stripMineTensorType, packMetadata.reassociations );
19201954 mlir::VectorType vecCollapsedType =
1921- VectorType::get (collapsedType.getShape (), collapsedType.getElementType ());
1955+ VectorType::get (collapsedType.getShape (), collapsedType.getElementType (),
1956+ writeScalableVectorFlags);
19221957 vector::ShapeCastOp shapeCastOp = rewriter.create <vector::ShapeCastOp>(
19231958 loc, vecCollapsedType, transposeOp->getResult (0 ));
19241959
1925- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
1960+ // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19261961 // otherwise the validator complains that the mask size is invalid.
1927- SmallVector<int64_t > writeVectorSizes (
1962+ // FIXME: We should not override write-vector-sizes like this.
1963+ SmallVector<int64_t > writeVectorSizesFinal (
19281964 unpackOp.getDestType ().hasStaticShape ()
1929- ? vectorSizes
1965+ ? writeVectorSizes
19301966 : shapeCastOp.getResultVectorType ().getShape ());
19311967 Operation *write = createWriteOrMaskedWrite (
19321968 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
@@ -1956,7 +1992,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19561992 (void )status; // prevent unused variable warning on non-assert builds
19571993 assert (succeeded (status) && " failed to reify result shapes" );
19581994 auto maskedRead = vector::createReadOrMaskedRead (
1959- rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
1995+ rewriter, loc, padOp.getSource (), inputVectorSizes,
1996+ /* inputScalableVecSizes=*/ {}, padValue,
19601997 /* useInBoundsInsteadOfMasking=*/ false );
19611998
19621999 // Create Xfer write Op
@@ -2041,6 +2078,9 @@ static LogicalResult
20412078vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
20422079 ArrayRef<int64_t > inputVectorSizes) {
20432080
2081+ // FIXME!!!
2082+ return success ();
2083+
20442084 if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
20452085 return !getConstantIntValue (res).has_value ();
20462086 })) {
@@ -2291,6 +2331,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
22912331 LDBG (" pad value is not constant: " << packOp << " \n " );
22922332 return failure ();
22932333 }
2334+
22942335 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
22952336 bool satisfyEmptyCond = true ;
22962337 if (inputVectorSizes.empty ()) {
@@ -2369,6 +2410,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
23692410 if (numOfScalableDims == 0 )
23702411 return success ();
23712412
2413+ // FIXME!!!
2414+ return success ();
2415+
2416+ // TODO: Check the following!
23722417 auto linalgOp = dyn_cast<LinalgOp>(op);
23732418
23742419 // Cond 1: There's been no need for scalable vectorisation of
@@ -2469,7 +2514,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
24692514 return success (isElementwise (linalgOp) || isa<linalg::MatmulOp>(op) ||
24702515 isa<linalg::MatmulTransposeAOp>(op) ||
24712516 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2472- isa<linalg::MatvecOp>(op) || hasReductionIterator (linalgOp));
2517+ isa<linalg::MatvecOp>(op) || isa<linalg::UnPackOp>(op) ||
2518+ hasReductionIterator (linalgOp));
24732519}
24742520
24752521LogicalResult mlir::linalg::vectorizeOpPrecondition (
@@ -2598,7 +2644,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
25982644 })
25992645 .Case <linalg::UnPackOp>([&](auto unpackOp) {
26002646 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2601- inputVectorSizes, results);
2647+ inputVectorSizes,
2648+ inputScalableVecDims, results);
26022649 })
26032650 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
26042651 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -2988,7 +3035,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
29883035 SmallVector<Value> readIndices (
29893036 vecType.getRank (), rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
29903037 Value read = mlir::vector::createReadOrMaskedRead (
2991- rewriter, loc, source, vecType.getShape (), padValue,
3038+ rewriter, loc, source, vecType.getShape (), /* inputScalableVecSizes=*/ {},
3039+ padValue,
29923040 /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
29933041
29943042 // Create write
0 commit comments