@@ -2106,24 +2106,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
21062106 return success ();
21072107}
21082108
2109- // / Need to check if the inner-tiles are static/constant.
2109+ // // This hook considers two cases:
2110+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2111+ // / infered. This is only possible when all shapes are static.
2112+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2113+ // / carry out basic sanity-checking.
21102114static LogicalResult
21112115vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21122116 ArrayRef<int64_t > inputVectorSizes) {
2117+ // If there are no input vector sizes and all shapes are static, there is
2118+ // nothing left to check.
2119+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2120+ unpackOp.getSourceType ().hasStaticShape ())
2121+ return success ();
21132122
2114- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2115- return !getConstantIntValue (res).has_value ();
2116- })) {
2117- LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
2123+ // The input vector sizes must be equal to:
2124+ // * read-vector-rank + write-vector-rank
2125+ if (!inputVectorSizes.empty ()) {
2126+ if (inputVectorSizes.size () !=
2127+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2128+ LDBG (" Incorrect number of input vector sizes" );
2129+ return failure ();
2130+ }
2131+ }
2132+
2133+ // Check the vector sizes for the write operation.
2134+ if (failed (vector::isValidMaskedInputVector (
2135+ unpackOp.getDestType ().getShape (),
2136+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2137+ LDBG (" Incorrect number of input vector sizes" );
21182138 return failure ();
21192139 }
2120- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2121- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2122- unpackOp. getDestType (). hasStaticShape () &&
2123- unpackOp.getSourceType ().hasStaticShape ();
2124- if (!satisfyEmptyCond &&
2125- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2140+
2141+ // Check the vector sizes for the read operation.
2142+ if ( failed ( vector::isValidMaskedInputVector (
2143+ unpackOp.getSourceType ().getShape (),
2144+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2145+ LDBG ( " Incorrect number of input vector sizes " );
21262146 return failure ();
2147+ }
21272148
21282149 return success ();
21292150}
0 commit comments