diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 311e589547b89b..5463a7bd8f4c84 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -432,6 +432,9 @@ struct UnrolledOuterProductGenerator return failure(); int reductionSize = lhsType.getDimSize(reductionDim); + assert(reductionSize > 0 && + "Reduction dim must be a known static size to allow unrolling"); + // Incremental support for masking. if (mask && !maybeMask.has_value()) return failure(); @@ -997,7 +1000,7 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, }); if (lhsType.getScalableDims()[lhsIndex]) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex + diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex << ") is not supported yet"; }); dimSize = lhsType.getDimSize(lhsIndex); @@ -1005,7 +1008,7 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, iterIndex = iMap[1].getDimPosition(rhsIndex); if (rhsType.getScalableDims()[rhsIndex]) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex + diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex << ") is not supported yet"; }); dimSize = rhsType.getDimSize(rhsIndex);