diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0477815f329bf..a7acdd2c018ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1093,6 +1093,11 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
+  // 0a. Is the result a 0-D vector? If yes, there are no iteration dimensions
+  // so the tensor.extract is a single scalar load regardless of the index.
+  if (resType.getRank() == 0)
+    return VectorMemoryAccessKind::ScalarBroadcast;
+
   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
   // otherwise.
   bool isOutput1DVector =
@@ -1254,19 +1259,22 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
         /*padding=*/std::nullopt, permutationMap, inBounds);
 
-    // Mask this broadcasting xfer_read here rather than relying on the generic
-    // path (the generic path assumes identity masking map, which wouldn't be
-    // valid here).
-    SmallVector<int64_t> readMaskShape = {1};
-    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
-    auto allTrue = vector::ConstantMaskOp::create(
-        rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
-    auto *maskedReadOp =
-        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+    Operation *readOrMaskedReadOp = transferReadOp;
+    if (dstRank > 0) {
+      // Mask this broadcasting xfer_read here rather than relying on the
+      // generic path (the generic path assumes identity masking map, which
+      // wouldn't be valid here).
+      SmallVector<int64_t> readMaskShape = {1};
+      auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+      auto allTrue = vector::ConstantMaskOp::create(
+          rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+      readOrMaskedReadOp =
+          mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+    }
 
     LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
     return VectorizationHookResult{VectorizationHookStatus::NewOp,
-                                   maskedReadOp};
+                                   readOrMaskedReadOp};
   }
 
   // 2b. Handle contiguous access.
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index d7722eac2b91f..e04a3f1a83d35 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -684,3 +684,39 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]], %[[C0]]], %[[PV]] : tensor<3x3x3xf32>, vector<f32>
 // CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<3x1x1xf32>
 // CHECK: vector.transfer_write %[[READ_BCAST]], %[[INIT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+
+// -----
+
+// Rank-0 linalg.generic with tensor.extract using a data-dependent index.
+// The tensor.extract should be classified as ScalarBroadcast (not Gather),
+// producing a vector.transfer_read of a 0-D vector.
+
+func.func @rank0_tensor_extract_data_dependent_index(
+    %src: tensor<2xi64>,
+    %idx_tensor: tensor<i64>) -> tensor<i64> {
+
+  %init = tensor.empty() : tensor<i64>
+  %res = linalg.generic {
+    indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+    iterator_types = []}
+    ins(%idx_tensor : tensor<i64>) outs(%init : tensor<i64>) {
+  ^bb0(%in: i64, %out: i64):
+    %idx = arith.index_cast %in : i64 to index
+    %val = tensor.extract %src[%idx] : tensor<2xi64>
+    linalg.yield %val : i64
+  } -> tensor<i64>
+
+  return %res : tensor<i64>
+}
+
+// CHECK-LABEL: func.func @rank0_tensor_extract_data_dependent_index(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2xi64>,
+// CHECK-SAME: %[[IDX_TENSOR:.*]]: tensor<i64>) -> tensor<i64> {
+// CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<i64>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : i64
+// CHECK: %[[READ_IDX:.*]] = vector.transfer_read %[[IDX_TENSOR]][], %[[PAD]] : tensor<i64>, vector<i64>
+// CHECK: %[[SCALAR_IDX:.*]] = vector.extract %[[READ_IDX]][] : i64 from vector<i64>
+// CHECK: %[[INDEX:.*]] = arith.index_cast %[[SCALAR_IDX]] : i64 to index
+// CHECK: %[[READ_VAL:.*]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %{{.*}} : tensor<2xi64>, vector<i64>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ_VAL]], %[[INIT]][] : vector<i64>, tensor<i64>
+// CHECK: return %[[WRITE]] : tensor<i64>