@@ -1872,19 +1872,12 @@ static VectorType getCollapsedVecType(VectorType type,
18721872 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
18731873}
18741874
1875- // / Vectorize `linalg.unpack` into :
1875+ // / Vectorize `linalg.unpack` as :
18761876// / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
18771877// /
1878- // / The input-vector-sizes specify both the read and the write vector
1879- // / sizes and are passed as one array covering both operations, i.e.:
1880- // /
1881- // / input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1882- // / \ / \ /
1883- // / read-sizes write-sizes
1884- // /
1885- // / (for brefity, in the diagram,
1886- // / * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1887- // / )
1878+ // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1879+ // / for the xfer_read operation). This is sufficient to infer the other vector
1880+ // / sizes required here.
18881881// /
18891882// / If the vector sizes are not provided:
18901883// / * the vector sizes are determined by the operands,
@@ -1907,8 +1900,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19071900 ArrayRef<bool > inputScalableVecDims,
19081901 SmallVectorImpl<Value> &newResults) {
19091902 if (!inputVectorSizes.empty ()) {
1910- assert (inputVectorSizes.size () ==
1911- unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1903+ assert (inputVectorSizes.size () == unpackOp.getSourceRank () &&
19121904 " Invalid number of input vector sizes!" );
19131905 assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
19141906 " Incompatible number of vector sizes and vector scalable flags!" );
@@ -1928,22 +1920,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19281920
19291921 // 1. Obtain vector sizes for the read and write operations.
19301922 SmallVector<int64_t > readVectorSizes;
1931- SmallVector<int64_t > writeVectorSizes;
19321923 SmallVector<bool > readScalableVectorFlags;
1933- SmallVector<bool > writeScalableVectorFlags;
19341924
19351925 if (!inputVectorSizes.empty ()) {
19361926 // CASE 1.1: Vector sizes are user-specified.
19371927 readVectorSizes.assign (inputVectorSizes.begin (),
19381928 inputVectorSizes.begin () + sourceShape.size ());
1939- writeVectorSizes.assign (inputVectorSizes.begin () + sourceShape.size (),
1940- inputVectorSizes.end ());
19411929 readScalableVectorFlags.assign (inputScalableVecDims.begin (),
19421930 inputScalableVecDims.begin () +
19431931 sourceShape.size ());
1944- writeScalableVectorFlags.assign (inputScalableVecDims.begin () +
1945- sourceShape.size (),
1946- inputScalableVecDims.end ());
19471932 } else {
19481933 // CASE 1.2: Vector sizes are inferred from the static input tensor
19491934 // shapes.
@@ -1952,7 +1937,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19521937 return failure ();
19531938
19541939 readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1955- writeVectorSizes.assign (destShape.begin (), destShape.end ());
19561940 useInBoundsInsteadOfMasking = true ;
19571941 }
19581942
@@ -2102,31 +2086,21 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21022086 unpackOp.getSourceType ().hasStaticShape ())
21032087 return success ();
21042088
2105- // The input vector sizes must be equal to:
2106- // * read-vector-rank + write-vector-rank
2089+ // The number of input vector sizes must be equal to:
2090+ // * read-vector-rank
21072091 if (!inputVectorSizes.empty () &&
2108- (inputVectorSizes.size () !=
2109- unpackOp.getDestRank () + unpackOp.getSourceRank ())) {
2092+ (inputVectorSizes.size () != unpackOp.getSourceRank ())) {
21102093 LDBG () << " Incorrect number of input vector sizes" ;
21112094 return failure ();
21122095 }
21132096
21142097 // Check the vector sizes for the read operation.
21152098 if (failed (vector::isValidMaskedInputVector (
2116- unpackOp.getSourceType ().getShape (),
2117- inputVectorSizes.take_front (unpackOp.getSourceRank ())))) {
2099+ unpackOp.getSourceType ().getShape (), inputVectorSizes))) {
21182100 LDBG () << " Invalid vector sizes for the read operation" ;
21192101 return failure ();
21202102 }
21212103
2122- // Check the vector sizes for the write operation.
2123- if (failed (vector::isValidMaskedInputVector (
2124- unpackOp.getDestType ().getShape (),
2125- inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2126- LDBG () << " Invalid vector sizes for the write operation" ;
2127- return failure ();
2128- }
2129-
21302104 return success ();
21312105}
21322106
0 commit comments