From ca143d6a228c0dc28482a7818b5e8eea72263e8d Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 12 Apr 2025 20:15:50 +0100 Subject: [PATCH 1/3] [mlir][vector] Prevent folding of OOB values in insert/extract Out of bound position values should not be folded in vector.extract and vector.insert operations, as only in bounds constants and -1 are valid. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 +++++++-- mlir/test/Dialect/Vector/canonicalize.mlir | 38 ++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5a3983699d5a3..0031608e2c9d5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1996,6 +1996,12 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, std::vector staticPosition = op.getStaticPosition().vec(); OperandRange dynamicPosition = op.getDynamicPosition(); ArrayRef dynamicPositionAttr = adaptor.getDynamicPosition(); + ArrayRef vectorShape; + if constexpr (std::is_same_v) { + vectorShape = op.getSourceVectorType().getShape(); + } else if constexpr (std::is_same_v) { + vectorShape = op.getDestVectorType().getShape(); + } // If the dynamic operands is empty, it is returned directly. if (!dynamicPosition.size()) @@ -2012,9 +2018,13 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, Attribute positionAttr = dynamicPositionAttr[index]; Value position = dynamicPosition[index++]; if (auto attr = mlir::dyn_cast_if_present(positionAttr)) { - staticPosition[i] = attr.getInt(); - opChange = true; - continue; + int64_t value = attr.getInt(); + // Do not fold if the value is out of bounds. + if (value >= 0 && value < vectorShape[i]) { + staticPosition[i] = attr.getInt(); + opChange = true; + continue; + } } operands.push_back(position); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b7db8ec834be7..b0f502a0b7c36 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3233,3 +3233,41 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3 %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32> return %res : vector<4x1xi32> } + +// ----- + +// Check that out of bounds indices are not folded for vector.insert + +// CHECK-LABEL: @fold_insert_oob +// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> { +// CHECK: %[[OOB1:.*]] = arith.constant -2 : index +// CHECK: %[[OOB2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL:.*]] = arith.constant 1 : i32 +// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32> +// CHECK: return %[[RES]] : vector<4x1x2xi32> +func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> { + %0 = arith.constant 0 : index + %-2 = arith.constant -2 : index + %2 = arith.constant 2 : index + %1 = arith.constant 1 : i32 + %res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32> + return %res : vector<4x1x2xi32> +} + +// ----- + +// Check that out of bounds indices are not folded for vector.extract + +// CHECK-LABEL: @fold_extract_oob +// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 { +// CHECK: %[[OOB1:.*]] = arith.constant -2 : index +// CHECK: %[[OOB2:.*]] = arith.constant 2 : index +// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32> +// CHECK: return %[[RES]] : i32 +func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 { + %0 = arith.constant 0 : index + %-2 = arith.constant -2 : index + %2 = arith.constant 2 : index + %res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32> + return %res : i32 +} From 8d7718a717e7f9bfe0ffe5d8668f0cf72ef496f4 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 17 Apr 2025 05:45:16 +0100 Subject: [PATCH 2/3] Address comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +- mlir/test/Dialect/Vector/canonicalize.mlir | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 0031608e2c9d5..95f9d8a134de4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1999,7 +1999,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, ArrayRef vectorShape; if constexpr (std::is_same_v) { vectorShape = op.getSourceVectorType().getShape(); - } else if constexpr (std::is_same_v) { + } else { vectorShape = op.getDestVectorType().getShape(); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index b0f502a0b7c36..ec2f823f4c701 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3236,7 +3236,7 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3 // ----- -// Check that out of bounds indices are not folded for vector.insert +// Check that out of bounds indices are not folded for vector.insert. // CHECK-LABEL: @fold_insert_oob // CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> { @@ -3246,17 +3246,17 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3 // CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32> // CHECK: return %[[RES]] : vector<4x1x2xi32> func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> { - %0 = arith.constant 0 : index - %-2 = arith.constant -2 : index - %2 = arith.constant 2 : index - %1 = arith.constant 1 : i32 - %res = vector.insert %1, %arg[%0, %-2, %2] : i32 into vector<4x1x2xi32> + %c0 = arith.constant 0 : index + %c-2 = arith.constant -2 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : i32 + %res = vector.insert %c1, %arg[%c0, %c-2, %c2] : i32 into vector<4x1x2xi32> return %res : vector<4x1x2xi32> } // ----- -// Check that out of bounds indices are not folded for vector.extract +// Check that out of bounds indices are not folded for vector.extract. // CHECK-LABEL: @fold_extract_oob // CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 { @@ -3265,9 +3265,9 @@ func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> { // CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32> // CHECK: return %[[RES]] : i32 func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 { - %0 = arith.constant 0 : index - %-2 = arith.constant -2 : index - %2 = arith.constant 2 : index - %res = vector.extract %arg[%0, %-2, %2] : i32 from vector<4x1x2xi32> + %c0 = arith.constant 0 : index + %c-2 = arith.constant -2 : index + %c2 = arith.constant 2 : index + %res = vector.extract %arg[%c0, %c-2, %c2] : i32 from vector<4x1x2xi32> return %res : i32 } From 34ab7330bb7751c0016b543b49be7e77f855e6b3 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Fri, 18 Apr 2025 05:09:09 +0200 Subject: [PATCH 3/3] Update mlir/lib/Dialect/Vector/IR/VectorOps.cpp Co-authored-by: Mehdi Amini --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 95f9d8a134de4..71077d4943aa5 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1997,11 +1997,10 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, OperandRange dynamicPosition = op.getDynamicPosition(); ArrayRef dynamicPositionAttr = adaptor.getDynamicPosition(); ArrayRef vectorShape; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) vectorShape = op.getSourceVectorType().getShape(); - } else { + else vectorShape = op.getDestVectorType().getShape(); - } // If the dynamic operands is empty, it is returned directly. if (!dynamicPosition.size())