Skip to content

Commit aa72198

Browse files
committed
Address the remaining comments from HanHan
1 parent b073854 commit aa72198

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,22 +1932,21 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19321932
SmallVector<bool> readScalableVectorFlags;
19331933
SmallVector<bool> writeScalableVectorFlags;
19341934

1935-
// CASE 1.1: Vector sizes are user-specified.
19361935
if (!inputVectorSizes.empty()) {
1937-
readVectorSizes.append(inputVectorSizes.begin(),
1936+
// CASE 1.1: Vector sizes are user-specified.
1937+
readVectorSizes.assign(inputVectorSizes.begin(),
19381938
inputVectorSizes.begin() + sourceShape.size());
1939-
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
1939+
writeVectorSizes.assign(inputVectorSizes.begin() + sourceShape.size(),
19401940
inputVectorSizes.end());
1941-
readScalableVectorFlags.append(inputScalableVecDims.begin(),
1941+
readScalableVectorFlags.assign(inputScalableVecDims.begin(),
19421942
inputScalableVecDims.begin() +
19431943
sourceShape.size());
1944-
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
1944+
writeScalableVectorFlags.assign(inputScalableVecDims.begin() +
19451945
sourceShape.size(),
19461946
inputScalableVecDims.end());
1947-
}
1948-
1949-
// CASE 1. 2: Vector sizes have to be inferred.
1950-
if (writeVectorSizes.empty()) {
1947+
} else {
1948+
// CASE 1.2: Vector sizes are inferred from the static input tensor
1949+
// shapes.
19511950
if (ShapedType::isDynamicShape(destShape) ||
19521951
ShapedType::isDynamicShape(sourceShape))
19531952
return failure();
@@ -2105,12 +2104,11 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21052104

21062105
// The input vector sizes must be equal to:
21072106
// * read-vector-rank + write-vector-rank
2108-
if (!inputVectorSizes.empty()) {
2109-
if (inputVectorSizes.size() !=
2110-
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2111-
LDBG() << "Incorrect number of input vector sizes";
2112-
return failure();
2113-
}
2107+
if (!inputVectorSizes.empty() &&
2108+
(inputVectorSizes.size() !=
2109+
unpackOp.getDestRank() + unpackOp.getSourceRank())) {
2110+
LDBG() << "Incorrect number of input vector sizes";
2111+
return failure();
21142112
}
21152113

21162114
// Check the vector sizes for the read operation.

0 commit comments

Comments
 (0)