diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 15c467b21c81e..772762f7bd54b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -524,6 +524,40 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
 
   if (!mask) {
     LDBG() << "No mask required";
+    if (assumeDynamicDimsMatchVecSizes) {
+      llvm::TypeSwitch<Operation *>(opToMask)
+          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
+              [&](auto xferOp) {
+                // For vector.transfer_read and vector.transfer_write, there is
+                // also the `in-bounds` attribute that has to be set explicitly
+                // to true. Otherwise, "out-of-bounds" access will be assumed
+                // and masks will be generated while lowering these.
+                LDBG() << "Assuming dynamic dimensions match vector sizes and "
+                          "setting their in-bounds to true!";
+                SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
+                ShapedType xferType = xferOp.getShapedType();
+                AffineMap permMap = xferOp.getPermutationMap();
+                // Only set the in-bounds values to true for dynamic dims.
+                // Different mechanisms will set these accordingly for the
+                // static dims.
+                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
+                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
+                  // Skip broadcast dimensions.
+                  if (!dimExpr)
+                    continue;
+                  unsigned pos = dimExpr.getPosition();
+                  if (xferType.isDynamicDim(pos))
+                    inBoundsMap[i] = true;
+                }
+                rewriter.modifyOpInPlace(xferOp, [&]() {
+                  xferOp.setInBoundsAttr(
+                      rewriter.getBoolArrayAttr(inBoundsMap));
+                });
+              })
+          .Default([](Operation *op) {
+            // No-op if the operation is not an xfer read or write.
+          });
+    }
     return opToMask;
   }
 
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 62bf1f55c9af2..11bea8d92432c 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -918,12 +918,17 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1
 // CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
 // CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
 // CHECK-NOT: mask
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK-SAME: memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// `in-bounds` is set to true for dynamic dims when the assume is used; in-bounds for static dims is inferred elsewhere.
+// CHECK-SAME: in_bounds = [false, false, false, false, true, false]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
+// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
 // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
 // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
+// CHECK-SAME: in_bounds = [false, false, false, true]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -1011,12 +1016,17 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem
 // CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>,
 // CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) {
 // CHECK-NOT: mask
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK-SAME: memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// `in-bounds` is set to true for dynamic dims when the assume is used; in-bounds for static dims is inferred elsewhere.
+// CHECK-SAME: in_bounds = [false, false, false, false, false, true, false]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]
+// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32>
 // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32>
 // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]
+// CHECK-SAME: in_bounds = [false, false, false, false, true]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
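Note (illustrative sketch, not part of the patch): when the `assumeDynamicDimsMatchVecSizes` path above is taken for an unmasked transfer, only the in-bounds entries corresponding to dynamic dims are forced to true; static dims are left as-is for other mechanisms to tighten. For a hypothetical 2-D read from `memref<8x?xf32>` vectorised with sizes `[8, [4]]`, the resulting IR would look roughly like:

  // Hypothetical IR; %src : memref<8x?xf32>, %c0 : index and %pad : f32 are assumed to exist.
  // Only the trailing (dynamic) dim is marked in-bounds by this change.
  %v = vector.transfer_read %src[%c0, %c0], %pad
         {in_bounds = [false, true]} : memref<8x?xf32>, vector<8x[4]xf32>

This mirrors the CHECK lines above, where the static dims keep `false` and only the dynamic dim positions read `true`.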