@@ -1897,6 +1897,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18971897 ArrayRef<int64_t > inputVectorSizes,
18981898 ArrayRef<bool > inputScalableVecDims,
18991899 SmallVectorImpl<Value> &newResults) {
1900+ if (!inputVectorSizes.empty ()) {
1901+ assert (inputVectorSizes.size () ==
1902+ unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1903+ " Invalid number of input vector sizes!" );
1904+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1905+ " Incompatible number of vector sizes and vector scalable flags!" );
1906+ }
19001907
19011908 // TODO: Introduce a parent class that will handle the insertion point update.
19021909 OpBuilder::InsertionGuard g (rewriter);
@@ -1912,44 +1919,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19121919
19131920 auto destSize = unpackOp.getDestRank ();
19141921
1915- if (!inputVectorSizes.empty ()) {
1916- assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
1917- " Incorrect number of input vector sizes" );
1918- }
1919-
1920- SmallVector<bool > readScalableVectorFlags;
1921- SmallVector<bool > writeScalableVectorFlags;
1922+ // 1. Obtain vector sizes for the read and write operations.
19221923 SmallVector<int64_t > readVectorSizes;
19231924 SmallVector<int64_t > writeVectorSizes;
1925+ SmallVector<bool > readScalableVectorFlags;
1926+ SmallVector<bool > writeScalableVectorFlags;
19241927
1925- // Split input-vector-sizes into vector sizes for the read and write
1926- // operations .
1928+ // CASE 1: Vector sizes are user-specified.
1929+ // 1.0 This is the trivial case, simply split the input vector sizes.
19271930 if (!inputVectorSizes.empty ()) {
19281931 readVectorSizes.append (inputVectorSizes.begin (),
19291932 inputVectorSizes.begin () + sourceShape.size ());
19301933 writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
19311934 inputVectorSizes.end ());
1932- }
1933- if (!inputScalableVecDims.empty ()) {
19341935 readScalableVectorFlags.append (inputScalableVecDims.begin (),
19351936 inputScalableVecDims.begin () +
19361937 sourceShape.size ());
19371938 writeScalableVectorFlags.append (inputScalableVecDims.begin () +
19381939 sourceShape.size (),
19391940 inputScalableVecDims.end ());
1940- } else {
1941- readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1942- writeScalableVectorFlags = SmallVector<bool >(destSize, false );
19431941 }
19441942
1945- // writeVectorSizes is the shape of the vector that will be used to do final
1946- // write on the destination tensor. It is set like this: Let's say the
1947- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1948- // Thus:
1949- // 1. writeVectorSizes = sourceShape.take_front(N)
1950- // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1951- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1952- // innerTiles attribute value.
1943+ // CASE 2: Vector sizes have to be inferred.
1944+ //
1945+ // 1.1 Infer vector sizes for the write operation.
1946+ //
1947+ // Let:
1948+ // * rank(source tensor) = 'M'
1949+ // * rank(dest tensor) = 'N',
1950+ // and N <= M. The steps are:
1951+ // 1. writeVectorSizes = sourceShape.take_front(N)
1952+ // 2. Multiply all the locations in writeVectorSizes pointed by inner_dims_pos
1953+ // by the corresponding values from the `inner_tiles` attribute.
1954+ // 3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
1955+ //
1956+ // Note, this will only work when all sizes are static!
19531957 if (writeVectorSizes.empty ()) {
19541958 if (ShapedType::isDynamicShape (sourceShape))
19551959 return failure ();
@@ -1963,28 +1967,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19631967 useInBoundsInsteadOfMasking = true ;
19641968 }
19651969
1966- // readVectorSizes is the size of tensor used to read and apply mask. It is
1967- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1968- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1969- // size M-N
1970- // Thus:
1971- // - initially: readVectorSizes = vectorInputSizes
1972- // - Divide all the readMaskShape locations pointed by innerDimPos
1973- // by the innerTileSize attribute value.
1974- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1975- // - Append the remaining shape from SS
1976- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1977- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1978- // 128] and outer_dims_perm is [1, 0] then read shape is:
1979- // ReadVectorSizes(initial): [512, 128]
1980- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1981- // = [16, 8]
1982- // After applying outer_dims_perm: [8, 16]
1983- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1984-
1970+ // 1.2 Infer vector sizes for the read operation.
1971+ //
1972+ // The steps are:
1973+ // 1. readVectorSizes = inputVectorSizes
1974+ // 2. Take readVectorSizes from 1. and divide all locations pointed by
1975+ // the inner_dims_pos attribute by the `inner_tiles` attribute value.
1976+ // 3. If outer_dims_perm is present, permute readVectorSizes accordingly.
1977+ // 4. Append the remaining sizes from the source tensor.
1978+ //
1979+ // Note, this will only work when all sizes are static!
19851980 if (readVectorSizes.empty ()) {
1986- // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1987- // sizes. Note, this will only work when all sizes are static.
19881981 readVectorSizes = writeVectorSizes;
19891982 for (auto [index, size] : enumerate(innerTiles)) {
19901983 readVectorSizes[innerDimPos[index]] =
0 commit comments