Skip to content

Commit f5a4127

Browse files
committed
fixup! fixup! [mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
Fix pre-condition calculation
1 parent c4502e0 commit f5a4127

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,24 +2132,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
21322132
return success();
21332133
}
21342134

2135-
/// Need to check if the inner-tiles are static/constant.
2135+
//// This hook considers two cases:
2136+
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2137+
/// infered. This is only possible when all shapes are static.
2138+
/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2139+
/// carry out basic sanity-checking.
21362140
static LogicalResult
21372141
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21382142
ArrayRef<int64_t> inputVectorSizes) {
2143+
// If there are no input vector sizes and all shapes are static, there is
2144+
// nothing left to check.
2145+
if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2146+
unpackOp.getSourceType().hasStaticShape())
2147+
return success();
21392148

2140-
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2141-
return !getConstantIntValue(res).has_value();
2142-
})) {
2143-
LDBG() << "Inner-tiles must be constant: " << unpackOp;
2149+
// The input vector sizes must be equal to:
2150+
// * read-vector-rank + write-vector-rank
2151+
if (!inputVectorSizes.empty()) {
2152+
if (inputVectorSizes.size() !=
2153+
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2154+
LDBG("Incorrect number of input vector sizes");
2155+
return failure();
2156+
}
2157+
}
2158+
2159+
// Check the vector sizes for the write operation.
2160+
if (failed(vector::isValidMaskedInputVector(
2161+
unpackOp.getDestType().getShape(),
2162+
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2163+
LDBG("Incorrect number of input vector sizes");
21442164
return failure();
21452165
}
2146-
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2147-
bool satisfyEmptyCond = inputVectorSizes.empty() &&
2148-
unpackOp.getDestType().hasStaticShape() &&
2149-
unpackOp.getSourceType().hasStaticShape();
2150-
if (!satisfyEmptyCond &&
2151-
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2166+
2167+
// Check the vector sizes for the read operation.
2168+
if (failed(vector::isValidMaskedInputVector(
2169+
unpackOp.getSourceType().getShape(),
2170+
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2171+
LDBG("Incorrect number of input vector sizes");
21522172
return failure();
2173+
}
21532174

21542175
return success();
21552176
}

0 commit comments

Comments
 (0)