@@ -1709,7 +1709,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
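+  // Note: the mask type must carry the same scalable-dim flags as the vector
+  // being stored, e.g. a vector<4x[4]xf32> store is masked with a
+  // vector<4x[4]xi1> mask.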
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
       tensor::getMixedSizes(builder, loc, dest);
@@ -1801,8 +1802,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      rewriter, loc, packOp.getSource(), inputShape,
+      /*inputScalableVecSizes=*/{}, padValue, useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1828,18 +1829,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-///   If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Read a vector from the source tensor.
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified by the `outer_dims_perm` attribute.
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target.
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector into the destination tensor.
+///   vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
 /// * the vector sizes are determined by the input operand and attributes,
 /// * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
 
   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1856,25 +1862,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   auto destSize = unpackOp.getDestRank();
 
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
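+    // The caller provides vector sizes for both the read and the write
+    // operations, concatenated: source-rank sizes for the read followed by
+    // dest-rank sizes for the write.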
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
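+  // For example, for an unpack op with a rank-4 source and a rank-2 dest,
+  // inputVectorSizes = [1, 1, 16, 2, 16, 2] is split into
+  // readVectorSizes = [1, 1, 16, 2] and writeVectorSizes = [16, 2].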
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
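+  // Split the scalable-dim flags the same way; when none were provided,
+  // default to "not scalable" for every dimension.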
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perm is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
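+  // If no write-vector-sizes were provided, infer them from the source shape.
+  // This only works for static shapes; for dynamic shapes the caller must
+  // supply vector sizes.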
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
@@ -1898,17 +1933,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18981933 // After applying outer_dims_perm: [8, 16]
18991934 // After appending the rest of the sourceShape: [8, 16, 32, 16]
19001935
1901- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1902-
1903- for (auto [index, size] : enumerate(innerTiles)) {
1904- readVectorSizes[innerDimPos[index]] =
1905- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1906- }
1907- if (!outerDimsPerm.empty ()) {
1908- applyPermutationToVector (readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner
+    // tile sizes. Note that this only works when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
@@ -1926,7 +1964,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // Read result, mask if necessary. If transferReadOp shape is not equal
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes,
+      readScalableVectorFlags, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   PackingMetadata packMetadata;
@@ -1946,15 +1985,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
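+  // Propagate the scalable-dim flags of the write to the collapsed vector
+  // type.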
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
          : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1984,7 +2025,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
+      rewriter, loc, padOp.getSource(), inputVectorSizes,
+      /*inputScalableVecSizes=*/{}, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
@@ -2069,6 +2111,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
+  // FIXME: The precondition checks below are temporarily bypassed while
+  // scalable vectorization of linalg.unpack is a work in progress.
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2319,6 +2364,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2397,6 +2443,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();
 
+  // FIXME: The checks below are temporarily bypassed while scalable
+  // vectorization of linalg.unpack is a work in progress.
+  return success();
+
+  // TODO: Re-enable and audit the checks that follow.
   auto linalgOp = dyn_cast<LinalgOp>(op);
 
   // Cond 1: There's been no need for scalable vectorisation of
@@ -2498,7 +2548,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-                 hasReductionIterator(linalgOp));
+                 isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2627,7 +2677,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3017,7 +3068,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
+      rewriter, loc, source, vecType.getShape(), /*inputScalableVecSizes=*/{},
+      padValue,
       /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write