@@ -1850,6 +1850,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() ==
+               unpackOp.getDestRank() + unpackOp.getSourceRank() &&
+           "Invalid number of input vector sizes!");
+    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
+           "Incompatible number of vector sizes and vector scalable flags!");
+  }

   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
@@ -1865,44 +1872,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,

   auto destSize = unpackOp.getDestRank();

-  if (!inputVectorSizes.empty()) {
-    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
-           "Incorrect number of input vector sizes");
-  }
-
-  SmallVector<bool> readScalableVectorFlags;
-  SmallVector<bool> writeScalableVectorFlags;
+  // 1. Obtain vector sizes for the read and write operations.
   SmallVector<int64_t> readVectorSizes;
   SmallVector<int64_t> writeVectorSizes;
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;

-  // Split input-vector-sizes into vector sizes for the read and write
-  // operations.
+  // CASE 1: Vector sizes are user-specified.
+  // 1.0 This is the trivial case: simply split the input vector sizes.
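+  // For example (hypothetical sizes, assuming rank(source) = 4 and
+  // rank(dest) = 2): inputVectorSizes = [8, 8, 32, 16, 256, 128] would be
+  // split into readVectorSizes = [8, 8, 32, 16] and
+  // writeVectorSizes = [256, 128].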
   if (!inputVectorSizes.empty()) {
     readVectorSizes.append(inputVectorSizes.begin(),
                            inputVectorSizes.begin() + sourceShape.size());
     writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
                             inputVectorSizes.end());
-  }
-  if (!inputScalableVecDims.empty()) {
     readScalableVectorFlags.append(inputScalableVecDims.begin(),
                                    inputScalableVecDims.begin() +
                                        sourceShape.size());
     writeScalableVectorFlags.append(inputScalableVecDims.begin() +
                                         sourceShape.size(),
                                     inputScalableVecDims.end());
-  } else {
-    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
-    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
   }

-  // writeVectorSizes is the shape of the vector that will be used to do final
-  // write on the destination tensor. It is set like this: Let's say the
-  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
-  // Thus:
-  // 1. writeVectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
-  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
-  //    innerTiles attribute value.
+  // CASE 2: Vector sizes have to be inferred.
+  //
+  // 1.1 Infer vector sizes for the write operation.
+  //
+  // Let:
+  //    * rank(source tensor) = 'M'
+  //    * rank(dest tensor) = 'N',
+  // and N <= M. The steps are:
+  //  1. writeVectorSizes = sourceShape.take_front(N)
+  //  2. Multiply all the locations in writeVectorSizes pointed to by
+  //     inner_dims_pos by the corresponding values from the `inner_tiles`
+  //     attribute.
+  //  3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
+  //
+  // Note, this will only work when all sizes are static!
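+  // For example (hypothetical shapes): for a source tensor of shape
+  // 8x8x32x16 with inner_dims_pos = [0, 1], inner_tiles = [32, 16] and no
+  // outer_dims_perm, the inferred writeVectorSizes would be
+  // [8 * 32, 8 * 16] = [256, 128].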
   if (writeVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
       return failure();
@@ -1916,28 +1920,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
     useInBoundsInsteadOfMasking = true;
   }

-  // readVectorSizes is the size of tensor used to read and apply mask. It is
-  // set like this: Let's say the vectorSize (VS) array is size 'N' and
-  // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
-  // size M-N
-  // Thus:
-  // - initially: readVectorSizes = vectorInputSizes
-  // - Divide all the readMaskShape locations pointed by innerDimPos
-  //   by the innerTileSize attribute value.
-  // - if outer_dims_perms is present: do that permutation on readVectorSizes.
-  // - Append the remaining shape from SS
-  // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
-  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
-  // 128] and outer_dims_perm is [1, 0] then read shape is:
-  //   ReadVectorSizes(initial): [512, 128]
-  //   Final Value(after innerDim Adjustment): [512/32, 128/16]
-  //                                           = [16, 8]
-  //   After applying outer_dims_perm: [8, 16]
-  //   After appending the rest of the sourceShape: [8, 16, 32, 16]
-
+  // 1.2 Infer vector sizes for the read operation.
+  //
+  // The steps are:
+  //  1. readVectorSizes = writeVectorSizes
+  //  2. Take readVectorSizes from 1. and divide all locations pointed to by
+  //     the inner_dims_pos attribute by the corresponding `inner_tiles`
+  //     values.
+  //  3. If outer_dims_perm is present, permute readVectorSizes accordingly.
+  //  4. Append the remaining sizes from the source tensor.
+  //
+  // Note, this will only work when all sizes are static!
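+  // Continuing the example above (hypothetical shapes): for a source tensor
+  // of shape 8x8x32x16 with inner_dims_pos = [0, 1], inner_tiles = [32, 16],
+  // no outer_dims_perm and writeVectorSizes = [256, 128]:
+  //   * readVectorSizes (initial): [256, 128]
+  //   * after dividing by inner_tiles: [256/32, 128/16] = [8, 8]
+  //   * outer_dims_perm is absent, so no permutation is applied
+  //   * after appending the remaining source sizes: [8, 8, 32, 16]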
   if (readVectorSizes.empty()) {
-    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
-    // sizes. Note, this will only work when all sizes are static.
     readVectorSizes = writeVectorSizes;
     for (auto [index, size] : enumerate(innerTiles)) {
       readVectorSizes[innerDimPos[index]] =