@@ -1857,6 +1857,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18571857 ArrayRef<int64_t > inputVectorSizes,
18581858 ArrayRef<bool > inputScalableVecDims,
18591859 SmallVectorImpl<Value> &newResults) {
1860+ if (!inputVectorSizes.empty ()) {
1861+ assert (inputVectorSizes.size () ==
1862+ unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1863+ " Invalid number of input vector sizes!" );
1864+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1865+ " Incompatible number of vector sizes and vector scalable flags!" );
1866+ }
18601867
18611868 // TODO: Introduce a parent class that will handle the insertion point update.
18621869 OpBuilder::InsertionGuard g (rewriter);
@@ -1872,44 +1879,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18721879
18731880 auto destSize = unpackOp.getDestRank ();
18741881
1875- if (!inputVectorSizes.empty ()) {
1876- assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
1877- " Incorrect number of input vector sizes" );
1878- }
1879-
1880- SmallVector<bool > readScalableVectorFlags;
1881- SmallVector<bool > writeScalableVectorFlags;
1882+ // 1. Obtain vector sizes for the read and write operations.
18821883 SmallVector<int64_t > readVectorSizes;
18831884 SmallVector<int64_t > writeVectorSizes;
1885+ SmallVector<bool > readScalableVectorFlags;
1886+ SmallVector<bool > writeScalableVectorFlags;
18841887
1885- // Split input-vector-sizes into vector sizes for the read and write
1886- // operations .
1888+ // CASE 1: Vector sizes are user-specified.
1889+ // 1.0 This is the trivial case, simply split the input vector sizes.
18871890 if (!inputVectorSizes.empty ()) {
18881891 readVectorSizes.append (inputVectorSizes.begin (),
18891892 inputVectorSizes.begin () + sourceShape.size ());
18901893 writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
18911894 inputVectorSizes.end ());
1892- }
1893- if (!inputScalableVecDims.empty ()) {
18941895 readScalableVectorFlags.append (inputScalableVecDims.begin (),
18951896 inputScalableVecDims.begin () +
18961897 sourceShape.size ());
18971898 writeScalableVectorFlags.append (inputScalableVecDims.begin () +
18981899 sourceShape.size (),
18991900 inputScalableVecDims.end ());
1900- } else {
1901- readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1902- writeScalableVectorFlags = SmallVector<bool >(destSize, false );
19031901 }
19041902
1905- // writeVectorSizes is the shape of the vector that will be used to do final
1906- // write on the destination tensor. It is set like this: Let's say the
1907- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1908- // Thus:
1909- // 1. writeVectorSizes = sourceShape.take_front(N)
1910- // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1911- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1912- // innerTiles attribute value.
1903+ // CASE 2: Vector sizes have to be inferred.
1904+ //
1905+ // 1.1 Infer vector sizes for the write operation.
1906+ //
1907+ // Let:
1908+ // * rank(source tensor) = 'M'
1909+ // * rank(dest tensor) = 'N',
1910+ // and N <= M. The steps are:
1911+ // 1. writeVectorSizes = sourceShape.take_front(N)
1912+ // 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
1913+ // by the corresponding values from the `inner_tiles` attribute value.
1914+ // 3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
1915+ //
1916+ // Note, this will only work when all sizes are static!
19131917 if (writeVectorSizes.empty ()) {
19141918 if (ShapedType::isDynamicShape (sourceShape))
19151919 return failure ();
@@ -1923,28 +1927,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19231927 useInBoundsInsteadOfMasking = true ;
19241928 }
19251929
1926- // readVectorSizes is the size of tensor used to read and apply mask. It is
1927- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1928- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1929- // size M-N
1930- // Thus:
1931- // - initially: readVectorSizes = vectorInputSizes
1932- // - Divide all the readMaskShape locations pointed by innerDimPos
1933- // by the innerTileSize attribute value.
1934- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1935- // - Append the remaining shape from SS
1936- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1937- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1938- // 128] and outer_dims_perm is [1, 0] then read shape is:
1939- // ReadVectorSizes(initial): [512, 128]
1940- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1941- // = [16, 8]
1942- // After applying outer_dims_perm: [8, 16]
1943- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1944-
1930+ // 1.2 Infer vector sizes for the read operation.
1931+ //
1932+ // The steps are:
1933+ // 1. readVectorSizes = vectorInputSizes
1934+ // 2. Take readVectorSizes from 1. and divide all locations pointed by
1935+ // the inner_dims_pos attribute by the `inner_tiles` attribute value.
1936+ // 3. If outer_dims_perm is present, permute readVectorSizes accordingly.
1937+ // 4. Append the remaining sizes from the source tensor.
1938+ //
1939+ // Note, this will only work when all sizes are static!
19451940 if (readVectorSizes.empty ()) {
1946- // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1947- // sizes. Note, this will only work when all sizes are static.
19481941 readVectorSizes = writeVectorSizes;
19491942 for (auto [index, size] : enumerate(innerTiles)) {
19501943 readVectorSizes[innerDimPos[index]] =
0 commit comments