Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,12 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
ArrayRef<int64_t> vectorShape;
if constexpr (std::is_same_v<OpType, ExtractOp>) {
vectorShape = op.getSourceVectorType().getShape();
} else if constexpr (std::is_same_v<OpType, InsertOp>) {
vectorShape = op.getDestVectorType().getShape();
}

// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
Expand All @@ -2012,9 +2018,13 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
Attribute positionAttr = dynamicPositionAttr[index];
Value position = dynamicPosition[index++];
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
staticPosition[i] = attr.getInt();
opChange = true;
continue;
int64_t value = attr.getInt();
// Do not fold if the value is out of bounds.
if (value >= 0 && value < vectorShape[i]) {
staticPosition[i] = attr.getInt();
opChange = true;
continue;
}
}
operands.push_back(position);
}
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3233,3 +3233,41 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}

// -----

// Check that out of bounds indices are not folded for vector.insert

// CHECK-LABEL: @fold_insert_oob
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
// CHECK: return %[[RES]] : vector<4x1x2xi32>
func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
%0 = arith.constant 0 : index
%-2 = arith.constant -2 : index
%2 = arith.constant 2 : index
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32>
return %res : vector<4x1x2xi32>
}

// -----

// Check that out of bounds indices are not folded for vector.extract

// CHECK-LABEL: @fold_extract_oob
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
// CHECK: return %[[RES]] : i32
func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
%0 = arith.constant 0 : index
%-2 = arith.constant -2 : index
%2 = arith.constant 2 : index
%res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32>
return %res : i32
}