@@ -1900,6 +1900,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19001900 ArrayRef<int64_t > inputVectorSizes,
19011901 ArrayRef<bool > inputScalableVecDims,
19021902 SmallVectorImpl<Value> &newResults) {
1903+ if (!inputVectorSizes.empty ()) {
1904+ assert (inputVectorSizes.size () ==
1905+ unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1906+ " Invalid number of input vector sizes!" );
1907+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1908+ " Incompatible number of vector sizes and vector scalable flags!" );
1909+ }
19031910
19041911 // TODO: Introduce a parent class that will handle the insertion point update.
19051912 OpBuilder::InsertionGuard g (rewriter);
@@ -1915,44 +1922,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19151922
19161923 auto destSize = unpackOp.getDestRank ();
19171924
1918- if (!inputVectorSizes.empty ()) {
1919- assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
1920- " Incorrect number of input vector sizes" );
1921- }
1922-
1923- SmallVector<bool > readScalableVectorFlags;
1924- SmallVector<bool > writeScalableVectorFlags;
1925+ // 1. Obtain vector sizes for the read and write operations.
19251926 SmallVector<int64_t > readVectorSizes;
19261927 SmallVector<int64_t > writeVectorSizes;
1928+ SmallVector<bool > readScalableVectorFlags;
1929+ SmallVector<bool > writeScalableVectorFlags;
19271930
1928- // Split input-vector-sizes into vector sizes for the read and write
1929- // operations .
1931+ // CASE 1: Vector sizes are user-specified.
1932+ // 1.0 This is the trivial case, simply split the input vector sizes.
19301933 if (!inputVectorSizes.empty ()) {
19311934 readVectorSizes.append (inputVectorSizes.begin (),
19321935 inputVectorSizes.begin () + sourceShape.size ());
19331936 writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
19341937 inputVectorSizes.end ());
1935- }
1936- if (!inputScalableVecDims.empty ()) {
19371938 readScalableVectorFlags.append (inputScalableVecDims.begin (),
19381939 inputScalableVecDims.begin () +
19391940 sourceShape.size ());
19401941 writeScalableVectorFlags.append (inputScalableVecDims.begin () +
19411942 sourceShape.size (),
19421943 inputScalableVecDims.end ());
1943- } else {
1944- readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1945- writeScalableVectorFlags = SmallVector<bool >(destSize, false );
19461944 }
19471945
1948- // writeVectorSizes is the shape of the vector that will be used to do final
1949- // write on the destination tensor. It is set like this: Let's say the
1950- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1951- // Thus:
1952- // 1. writeVectorSizes = sourceShape.take_front(N)
1953- // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1954- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1955- // innerTiles attribute value.
1946+ // CASE 2: Vector sizes have to be inferred.
1947+ //
1948+ // 1.1 Infer vector sizes for the write operation.
1949+ //
1950+ // Let:
1951+ // * rank(source tensor) = 'M'
1952+ // * rank(dest tensor) = 'N',
1953+ // and N <= M. The steps are:
1954+ // 1. writeVectorSizes = sourceShape.take_front(N)
1955+ // 2. Multiply all the locations in writeVectorSizes pointed by inner_dims_pos
1956+ // by the corresponding values from the `inner_tiles` attribute value.
1957+ // 3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
1958+ //
1959+ // Note, this will only work when all sizes are static!
19561960 if (writeVectorSizes.empty ()) {
19571961 if (ShapedType::isDynamicShape (sourceShape))
19581962 return failure ();
@@ -1966,28 +1970,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19661970 useInBoundsInsteadOfMasking = true ;
19671971 }
19681972
1969- // readVectorSizes is the size of tensor used to read and apply mask. It is
1970- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1971- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1972- // size M-N
1973- // Thus:
1974- // - initially: readVectorSizes = vectorInputSizes
1975- // - Divide all the readMaskShape locations pointed by innerDimPos
1976- // by the innerTileSize attribute value.
1977- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1978- // - Append the remaining shape from SS
1979- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1980- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1981- // 128] and outer_dims_perm is [1, 0] then read shape is:
1982- // ReadVectorSizes(initial): [512, 128]
1983- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1984- // = [16, 8]
1985- // After applying outer_dims_perm: [8, 16]
1986- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1987-
1973+ // 1.2 Infer vector sizes for the read operation.
1974+ //
1975+ // The steps are:
1976+ // 1. readVectorSizes = vectorInputSizes
1977+ // 2. Take readVectorSizes from 1. and divide all locations pointed by
1978+ // the inner_dims_pos attribute by the `inner_tiles` attribute value.
1979+ // 3. If outer_dims_perm is present, permute readVectorSizes accordingly.
1980+ // 4. Append the remaining sizes from the source tensor.
1981+ //
1982+ // Note, this will only work when all sizes are static!
19881983 if (readVectorSizes.empty ()) {
1989- // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1990- // sizes. Note, this will only work when all sizes are static.
19911984 readVectorSizes = writeVectorSizes;
19921985 for (auto [index, size] : enumerate(innerTiles)) {
19931986 readVectorSizes[innerDimPos[index]] =
0 commit comments