@@ -2132,24 +2132,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
21322132 return success ();
21332133}
21342134
2135- // / Need to check if the inner-tiles are static/constant.
2135+ // // This hook considers two cases:
2136+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2137+ // / infered. This is only possible when all shapes are static.
2138+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2139+ // / carry out basic sanity-checking.
21362140static LogicalResult
21372141vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21382142 ArrayRef<int64_t > inputVectorSizes) {
2143+ // If there are no input vector sizes and all shapes are static, there is
2144+ // nothing left to check.
2145+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2146+ unpackOp.getSourceType ().hasStaticShape ())
2147+ return success ();
21392148
2140- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2141- return !getConstantIntValue (res).has_value ();
2142- })) {
2143- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2149+ // The input vector sizes must be equal to:
2150+ // * read-vector-rank + write-vector-rank
2151+ if (!inputVectorSizes.empty ()) {
2152+ if (inputVectorSizes.size () !=
2153+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2154+ LDBG (" Incorrect number of input vector sizes" );
2155+ return failure ();
2156+ }
2157+ }
2158+
2159+ // Check the vector sizes for the write operation.
2160+ if (failed (vector::isValidMaskedInputVector (
2161+ unpackOp.getDestType ().getShape (),
2162+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2163+ LDBG (" Incorrect number of input vector sizes" );
21442164 return failure ();
21452165 }
2146- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2147- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2148- unpackOp. getDestType (). hasStaticShape () &&
2149- unpackOp.getSourceType ().hasStaticShape ();
2150- if (!satisfyEmptyCond &&
2151- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2166+
2167+ // Check the vector sizes for the read operation.
2168+ if ( failed ( vector::isValidMaskedInputVector (
2169+ unpackOp.getSourceType ().getShape (),
2170+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2171+ LDBG ( " Incorrect number of input vector sizes " );
21522172 return failure ();
2173+ }
21532174
21542175 return success ();
21552176}
0 commit comments