diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b2575..e974e5bd046a98 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/IR/EnumAttr.td"
 
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Vector_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 // The "kind" of combining function for contractions and reductions.
 def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
 def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,42 @@ def Vector_PrintPunctuation : EnumAttr
 
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+  let summary = "strided vector slice";
+
+  let description = [{
+    An attribute that represents a strided slice of a vector.
+
+    *Examples:*
+
+    Without sizes:
+
+    `{offsets = [0, 0, 2], strides = [1, 1]}`
+
+    With sizes (used for extract_strided_slice):
+
+    `{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}`
+
+    TODO: Come up with a range syntax (similar to Python slices).
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$offsets,
+    OptionalArrayRefParameter<"int64_t">:$sizes,
+    ArrayRefParameter<"int64_t">:$strides
+  );
+
+  let builders = [
+    AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+      return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+    }]>
+  ];
+
+  let assemblyFormat = [{
+    `{` `offsets` `=` `[` $offsets `]` `,`
+    (`sizes` `=` `[` $sizes^ `]` `,`)?
+    `strides` `=` `[` $strides `]` `}`
+  }];
+}
+
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
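(Editor's sketch, not part of the patch: the printed form the new attribute yields inside the two ops, lifted from the op documentation updated below; SSA names are illustrative. Note the fixed key order — offsets, then sizes, then strides — required by the `assemblyFormat`, which is why the linearize.mlir tests at the end of the patch reorder their attribute keys.)

```mlir
// insert_strided_slice: the optional `sizes` group is omitted.
%i = vector.insert_strided_slice %src, %dst
    {offsets = [0, 0, 2], strides = [1, 1]}
    : vector<2x4xf32> into vector<16x4x8xf32>

// extract_strided_slice: `sizes` is present.
%e = vector.extract_strided_slice %v
    {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
    : vector<4x8x16xf32> to vector<2x4x16xf32>
```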
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd19d356a6739d..5f9b4f6b29b0f8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest,
+               Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector:$res)> {
   let summary = "strided_slice operation";
   let description = [{
@@ -1060,13 +1060,13 @@ def Vector_InsertStridedSliceOp :
 
     ```mlir
     %2 = vector.insert_strided_slice %0, %1
-        {offsets = [0, 0, 2], strides = [1, 1]}:
-      vector<2x4xf32> into vector<16x4x8xf32>
+        {offsets = [0, 0, 2], strides = [1, 1]}
+        : vector<2x4xf32> into vector<16x4x8xf32>
     ```
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $source `,` $dest $strided_slice attr-dict `:` type($source) `into` type($dest)
   }];
 
   let builders = [
@@ -1081,10 +1081,13 @@ def Vector_InsertStridedSliceOp :
       return ::llvm::cast<VectorType>(getDest().getType());
     }
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
 
   let hasFolder = 1;
@@ -1182,8 +1185,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector)> {
   let summary = "extract_strided_slice operation";
   let description = [{
@@ -1201,12 +1203,8 @@ def Vector_ExtractStridedSliceOp :
 
     ```mlir
     %1 = vector.extract_strided_slice %0
-        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
-      vector<4x8x16xf32> to vector<2x4x16xf32>
-
-    // TODO: Evolve to a range form syntax similar to:
-    %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
+        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
+        : vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
   }];
   let builders = [
@@ -1217,17 +1215,20 @@ def Vector_ExtractStridedSliceOp :
     VectorType getSourceVectorType() {
      return ::llvm::cast<VectorType>(getVector().getType());
     }
-    void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$vector $strided_slice attr-dict `:` type($vector) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
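(Editor's sketch, not part of the patch: what a C++ call site looks like with the `ArrayRef<int64_t>`-based builders that this change reroutes through `StridedSliceAttr`; the helper name and the concrete slice values are placeholders, while the builder signatures themselves are the ones rewritten in VectorOps.cpp further down.)

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Extracts a 2x4 slice of `src` and inserts it back into `dst`.
static Value buildSliceRoundTrip(OpBuilder &rewriter, Location loc, Value src,
                                 Value dst) {
  SmallVector<int64_t> offsets = {0, 2};
  SmallVector<int64_t> sizes = {2, 4};
  SmallVector<int64_t> strides = {1, 1};

  // The result type is inferred from the source type and `sizes`.
  auto extract = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, src, offsets, sizes, strides);

  // insert_strided_slice carries no sizes; the attribute's `sizes` stay empty.
  auto insert = rewriter.create<vector::InsertStridedSliceOp>(
      loc, extract, dst, /*offsets=*/ArrayRef<int64_t>{0, 0},
      /*strides=*/ArrayRef<int64_t>{1, 1});
  return insert.getResult();
}
```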
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index e059d31ca5842f..682414c63c06a7 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
 static LogicalResult
 convertExtractStridedSlice(RewriterBase &rewriter,
                            vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
   auto sourceVector = it->second;
 
   // offset and sizes at warp-level of onwership.
-  SmallVector<int64_t> offsets;
-  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  ArrayRef<int64_t> offsets = op.getOffsets();
 
-  SmallVector<int64_t> sizes;
-  populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839b..4d4e5ebb4f4282 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
   return cast<IntegerAttr>(attr[0]).getInt();
 }
-static uint64_t getFirstIntValue(ArrayAttr attr) {
-  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
   auto attr = foldResults[0].dyn_cast<Attribute>();
   if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    int64_t offset = extractOp.getOffsets().front();
+    int64_t size = extractOp.getSizes().front();
+    int64_t stride = extractOp.getStrides().front();
     if (stride != 1)
       return failure();
 
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    uint64_t stride = getFirstIntValue(insertOp.getStrides());
+    uint64_t stride = insertOp.getStrides().front();
     if (stride != 1)
       return failure();
-    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+    uint64_t offset = insertOp.getOffsets().front();
 
     if (isa<VectorType>(srcVector.getType())) {
       assert(!isa<VectorType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 70fd9bc0a1e68f..2b0e6445dfda1c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
     if (failed(ext))
       return failure();
 
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
         op.getStrides());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab39..3a0d30098c369a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   // Trim offsets for dimensions fully extracted.
-  auto sliceOffsets =
-      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+  SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
   while (!sliceOffsets.empty()) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
           insertOp.getSourceVectorType().getRank();
   if (destinationRank > insertOp.getSourceVectorType().getRank())
     return Value();
-  auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+  ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
   ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
-  if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
-        return llvm::cast<IntegerAttr>(attr).getInt() != 1;
-      }))
+  if (insertOp.hasNonUnitStrides())
     return Value();
   bool disjoint = false;
   SmallVector<int64_t, 4> offsetDiffs;
@@ -2195,12 +2185,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(foldExtractFromFromElements);
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
-}
-
 //===----------------------------------------------------------------------===//
 // FmaOp
 //===----------------------------------------------------------------------===//
@@ -2907,26 +2891,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> strides) {
-  result.addOperands({source, dest});
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
-                                                        ArrayAttr arrayAttr,
-                                                        ArrayRef<int64_t> shape,
-                                                        StringRef attrName) {
-  if (arrayAttr.size() > shape.size())
-    return op.emitOpError("expected ")
-           << attrName << " attribute of rank no greater than vector rank";
-  return success();
+  build(builder, result, source, dest,
+        StridedSliceAttr::get(builder.getContext(), offsets, strides));
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2900,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
-                                  int64_t max, StringRef attrName,
-                                  bool halfOpen = true) {
-  for (auto attr : arrayAttr) {
-    auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
                          int64_t max, StringRef arrayName,
+                          bool halfOpen = true) {
+  for (int64_t val : array) {
     auto upper = max;
     if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
-      return op.emitOpError("expected ") << attrName << " to be confined to ["
+      return op.emitOpError("expected ") << arrayName << " to be confined to ["
                                          << min << ", " << upper << ")";
   }
   return success();
@@ -2954,13 +2919,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
 // If `halfOpen` is true then the admissible interval is [min, max).
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
-                                  ArrayRef<int64_t> shape, StringRef attrName,
-                                  bool halfOpen = true, int64_t min = 0) {
-  for (auto [index, attrDimPair] :
-       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
-    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
-    int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+                          ArrayRef<int64_t> shape, StringRef attrName,
+                          bool halfOpen = true, int64_t min = 0) {
+  for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+    int64_t val, max;
+    std::tie(val, max) = dimPair;
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
@@ -2977,40 +2941,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
 // the admissible interval is [min, max].
 template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
-    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
-    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+    OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+    ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
     bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= shape.size());
-  assert(arrayAttr2.size() <= shape.size());
-  for (auto [index, it] :
-       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
-    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
-    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
-    int64_t max = std::get<2>(it);
+  assert(array1.size() <= shape.size());
+  assert(array2.size() <= shape.size());
+  for (auto [index, it] : llvm::enumerate(llvm::zip(array1, array2, shape))) {
+    int64_t val1, val2, max;
+    std::tie(val1, val2, max) = it;
     if (!halfOpen)
       max += 1;
     if (val1 + val2 < 0 || val1 + val2 >= max)
       return op.emitOpError("expected sum(")
-             << attrName1 << ", " << attrName2 << ") dimension " << index
+             << arrayName1 << ", " << arrayName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
   }
   return success();
 }
 
-static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
-                                  MLIRContext *context) {
-  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
-  });
-  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
-}
-
 LogicalResult InsertStridedSliceOp::verify() {
   auto sourceVectorType = getSourceVectorType();
   auto destVectorType = getDestVectorType();
-  auto offsets = getOffsetsAttr();
-  auto strides = getStridesAttr();
+  auto offsets = getOffsets();
+  auto strides = getStrides();
+  if (!getStridedSlice().getSizes().empty())
+    return emitOpError("slice sizes not supported");
   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
         "expected offsets of same size as destination vector rank");
@@ -3025,18 +2981,14 @@ LogicalResult InsertStridedSliceOp::verify() {
   SmallVector<int64_t, 4> sourceShapeAsDestShape(
       destShape.size() - sourceShape.size(), 0);
   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
-  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
-  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
-  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
-                                               offName)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
-                                               /*max=*/1, stridesName,
-                                               /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(
-          *this, offsets,
-          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
-          offName, "source vector shape",
-          /*halfOpen=*/false, /*min=*/1)))
+  if (failed(isIntArrayConfinedToShape(*this, offsets, destShape, "offsets")) ||
+      failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+                                       /*max=*/1, "strides",
+                                       /*halfOpen=*/false)) ||
+      failed(isSumOfIntArrayConfinedToShape(*this, offsets,
+                                            sourceShapeAsDestShape, destShape,
+                                            "offsets", "source vector shape",
+                                            /*halfOpen=*/false, /*min=*/1)))
     return failure();
 
   unsigned rankDiff = destShape.size() - sourceShape.size();
@@ -3161,7 +3113,7 @@ class InsertStridedSliceConstantFolder final
     VectorType sliceVecTy = sourceValue.getType();
     ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
     int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
-    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
+    ArrayRef<int64_t> offsets = op.getOffsets();
     SmallVector<int64_t> destStrides = computeStrides(destTy.getShape());
 
     // Calcualte the destination element indices by enumerating all slice
@@ -3336,14 +3288,15 @@ Type OuterProductOp::getExpectedMaskType() {
 // 2. Add sizes from 'vectorType' for remaining dims.
 // Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
-                                          ArrayAttr offsets, ArrayAttr sizes,
-                                          ArrayAttr strides) {
+                                          ArrayRef<int64_t> offsets,
+                                          ArrayRef<int64_t> sizes,
+                                          ArrayRef<int64_t> strides) {
   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
   SmallVector<int64_t, 4> shape;
   shape.reserve(vectorType.getRank());
   unsigned idx = 0;
   for (unsigned e = offsets.size(); idx < e; ++idx)
-    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
+    shape.push_back(sizes[idx]);
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
@@ -3356,51 +3309,48 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                   ArrayRef<int64_t> sizes,
                                   ArrayRef<int64_t> strides) {
   result.addOperands(source);
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(
-      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
-                                    offsetsAttr, sizesAttr, stridesAttr));
-  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
-                      sizesAttr);
-  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
+  auto stridedSliceAttr =
+      StridedSliceAttr::get(builder.getContext(), offsets, sizes, strides);
+  result.addTypes(inferStridedSliceOpResultType(
+      llvm::cast<VectorType>(source.getType()), offsets, sizes, strides));
+  result.addAttribute(
+      ExtractStridedSliceOp::getStridedSliceAttrName(result.name),
+      stridedSliceAttr);
 }
 
 LogicalResult ExtractStridedSliceOp::verify() {
   auto type = getSourceVectorType();
-  auto offsets = getOffsetsAttr();
-  auto sizes = getSizesAttr();
-  auto strides = getStridesAttr();
+  auto offsets = getOffsets();
+  auto sizes = getSizes();
+  auto strides = getStrides();
   if (offsets.size() != sizes.size() || offsets.size() != strides.size())
     return emitOpError(
         "expected offsets, sizes and strides attributes of same size");
 
   auto shape = type.getShape();
-  auto offName = getOffsetsAttrName();
-  auto sizesName = getSizesAttrName();
-  auto stridesName = getStridesAttrName();
-  if (failed(
-          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
-      failed(
-          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
-                                                stridesName)) ||
-      failed(
-          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
-                                               /*halfOpen=*/false,
-                                               /*min=*/1)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
-                                               /*max=*/1, stridesName,
-                                               /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
-                                                    shape, offName, sizesName,
-                                                    /*halfOpen=*/false)))
+  auto isIntArraySmallerThanShape = [&](ArrayRef<int64_t> array,
+                                        StringRef arrayName) -> LogicalResult {
+    if (array.size() > shape.size())
+      return emitOpError("expected ")
+             << arrayName << " to have rank no greater than vector rank";
+    return success();
+  };
+
+  if (failed(isIntArraySmallerThanShape(offsets, "offsets")) ||
+      failed(isIntArraySmallerThanShape(sizes, "sizes")) ||
+      failed(isIntArraySmallerThanShape(strides, "strides")) ||
+      failed(isIntArrayConfinedToShape(*this, offsets, shape, "offsets")) ||
+      failed(isIntArrayConfinedToShape(*this, sizes, shape, "sizes",
+                                       /*halfOpen=*/false,
+                                       /*min=*/1)) ||
+      failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+                                       /*max=*/1, "strides",
+                                       /*halfOpen=*/false)) ||
+      failed(isSumOfIntArrayConfinedToShape(*this, offsets, sizes, shape,
+                                            "offsets", "sizes",
+                                            /*halfOpen=*/false))) {
     return failure();
+  }
 
   auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                   offsets, sizes, strides);
@@ -3410,7 +3360,7 @@ LogicalResult ExtractStridedSliceOp::verify() {
   for (unsigned idx = 0; idx < sizes.size(); ++idx) {
     if (type.getScalableDims()[idx]) {
       auto inputDim = type.getShape()[idx];
-      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      auto inputSize = sizes[idx];
       if (inputDim != inputSize)
         return emitOpError("expected size at idx=")
                << idx
@@ -3428,20 +3378,16 @@ LogicalResult ExtractStridedSliceOp::verify() {
 // extracted vector is a subset of one of the vector inserted.
 static LogicalResult
 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
-  // Helper to extract integer out of ArrayAttr.
-  auto getElement = [](ArrayAttr array, int idx) {
-    return llvm::cast<IntegerAttr>(array[idx]).getInt();
-  };
-  ArrayAttr extractOffsets = op.getOffsets();
-  ArrayAttr extractStrides = op.getStrides();
-  ArrayAttr extractSizes = op.getSizes();
+  ArrayRef<int64_t> extractOffsets = op.getOffsets();
+  ArrayRef<int64_t> extractStrides = op.getStrides();
+  ArrayRef<int64_t> extractSizes = op.getSizes();
   auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
   while (insertOp) {
     if (op.getSourceVectorType().getRank() !=
         insertOp.getSourceVectorType().getRank())
       return failure();
-    ArrayAttr insertOffsets = insertOp.getOffsets();
-    ArrayAttr insertStrides = insertOp.getStrides();
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
+    ArrayRef<int64_t> insertStrides = insertOp.getStrides();
     // If the rank of extract is greater than the rank of insert, we are likely
     // extracting a partial chunk of the vector inserted.
     if (extractOffsets.size() > insertOffsets.size())
@@ -3450,12 +3396,12 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
-      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
+      if (extractStrides[dim] != insertStrides[dim])
         return failure();
-      int64_t start = getElement(insertOffsets, dim);
+      int64_t start = insertOffsets[dim];
       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
-      int64_t offset = getElement(extractOffsets, dim);
-      int64_t size = getElement(extractSizes, dim);
+      int64_t offset = extractOffsets[dim];
+      int64_t size = extractSizes[dim];
       // Check if the start of the extract offset is in the interval inserted.
       if (start <= offset && offset < end) {
         // If the extract interval overlaps but is not fully included we may
@@ -3473,7 +3419,9 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
       op.setOperand(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(op.getContext());
-      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
+      auto stridedSliceAttr = StridedSliceAttr::get(
+          op.getContext(), offsetDiffs, op.getSizes(), op.getStrides());
+      op.setStridedSliceAttr(stridedSliceAttr);
       return success();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep looking
@@ -3496,11 +3444,6 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getResult();
   return {};
 }
-
-void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getOffsets(), results);
-}
-
 namespace {
 
 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
@@ -3524,11 +3467,8 @@ class StridedSliceConstantMaskFolder final
     // Gather constant mask dimension sizes.
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     // Gather strided slice offsets and sizes.
-    SmallVector<int64_t, 4> sliceOffsets;
-    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
-                               sliceOffsets);
-    SmallVector<int64_t, 4> sliceSizes;
-    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+    ArrayRef<int64_t> sliceOffsets = extractStridedSliceOp.getOffsets();
+    ArrayRef<int64_t> sliceSizes = extractStridedSliceOp.getSizes();
 
     // Compute slice of vector mask region.
     SmallVector<int64_t> sliceMaskDimSizes;
@@ -3620,10 +3560,10 @@ class StridedSliceNonSplatConstantFolder final
 
     // Expand offsets and sizes to match the vector rank.
     SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
+    copy(extractStridedSliceOp.getOffsets(), offsets.begin());
 
     SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+    copy(extractStridedSliceOp.getSizes(), sizes.begin());
 
     // Calculate the slice elements by enumerating all slice positions and
     // linearizing them. The enumeration order is lexicographic which yields a
@@ -3686,10 +3626,9 @@ class StridedSliceBroadcast final
     bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
     if (!lowerDimMatch && !isScalarSrc) {
       source = rewriter.create<ExtractStridedSliceOp>(
-          op->getLoc(), source,
-          getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+          op->getLoc(), source, op.getOffsets().drop_front(rankDiff),
+          op.getSizes().drop_front(rankDiff),
+          op.getStrides().drop_front(rankDiff));
     }
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index a1f67bd0e9ed35..9d18fe74178c2f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -130,23 +130,16 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
     VectorType initialValueType = scanOp.getInitialValueType();
     int64_t initialValueRank = initialValueType.getRank();
 
-    SmallVector<int64_t> reductionShape(destShape);
-    reductionShape[reductionDim] = 1;
-    VectorType reductionType = VectorType::get(reductionShape, elType);
     SmallVector<int64_t> offsets(destRank, 0);
     SmallVector<int64_t> strides(destRank, 1);
     SmallVector<int64_t> sizes(destShape);
     sizes[reductionDim] = 1;
-    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
-    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
 
     Value lastOutput, lastInput;
     for (int i = 0; i < destShape[reductionDim]; i++) {
       offsets[reductionDim] = i;
-      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
-          scanStrides);
+          loc, scanOp.getSource(), offsets, sizes, strides);
       Value output;
       if (i == 0) {
         if (inclusive) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 42ac717b44c4b9..0b8a2ab6b2fa0b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -71,11 +71,6 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
 
     VectorType oldDstType = extractOp.getType();
-    VectorType newDstType =
-        VectorType::get(oldDstType.getShape().drop_front(dropCount),
-                        oldDstType.getElementType(),
-                        oldDstType.getScalableDims().drop_front(dropCount));
-
     Location loc = extractOp.getLoc();
 
     Value newSrcVector = rewriter.create<vector::ExtractOp>(
@@ -83,15 +78,12 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
-    auto newOffsets = rewriter.getArrayAttr(
-        extractOp.getOffsets().getValue().drop_front(dropCount));
-    auto newSizes = rewriter.getArrayAttr(
-        extractOp.getSizes().getValue().drop_front(dropCount));
-    auto newStrides = rewriter.getArrayAttr(
-        extractOp.getStrides().getValue().drop_front(dropCount));
+    auto newOffsets = extractOp.getOffsets().drop_front(dropCount);
+    auto newSizes = extractOp.getSizes().drop_front(dropCount);
+    auto newStrides = extractOp.getStrides().drop_front(dropCount);
 
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+        loc, newSrcVector, newOffsets, newSizes, newStrides);
 
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
                                                      newExtractOp);
@@ -126,13 +118,11 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     Value newDstVector = rewriter.create<vector::ExtractOp>(
         loc, insertOp.getDest(), splatZero(dstDropCount));
 
-    auto newOffsets = rewriter.getArrayAttr(
-        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
-    auto newStrides = rewriter.getArrayAttr(
-        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
+    auto newOffsets = insertOp.getOffsets().take_back(newDstType.getRank());
+    auto newStrides = insertOp.getStrides().take_back(newSrcType.getRank());
 
     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
-        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+        loc, newSrcVector, newDstVector, newOffsets, newStrides);
 
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
                                                      newInsertOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..4de58ed7526a9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -63,7 +63,7 @@ class DecomposeDifferentRankInsertStridedSlice
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
 
-    if (op.getOffsets().getValue().empty())
+    if (op.getOffsets().empty())
       return failure();
 
     auto loc = op.getLoc();
@@ -76,21 +76,17 @@ class DecomposeDifferentRankInsertStridedSlice
     // Extract / insert the subvector of matching rank and InsertStridedSlice
     // on it.
     Value extracted = rewriter.create<ExtractOp>(
-        loc, op.getDest(),
-        getI64SubArray(op.getOffsets(), /*dropFront=*/0,
-                       /*dropBack=*/rankRest));
+        loc, op.getDest(), op.getOffsets().drop_back(rankRest));
     // A different pattern will kick in for InsertStridedSlice with matching
     // ranks.
     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
-        loc, op.getSource(), extracted,
-        getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
-        getI64SubArray(op.getStrides(), /*dropFront=*/0));
-
-    rewriter.replaceOpWithNewOp<InsertOp>(
-        op, stridedSliceInnerOp.getResult(), op.getDest(),
-        getI64SubArray(op.getOffsets(), /*dropFront=*/0,
-                       /*dropBack=*/rankRest));
+        loc, op.getSource(), extracted, op.getOffsets().drop_front(rankDiff),
+        op.getStrides());
+
+    rewriter.replaceOpWithNewOp<InsertOp>(op, stridedSliceInnerOp.getResult(),
+                                          op.getDest(),
+                                          op.getOffsets().drop_back(rankRest));
     return success();
   }
 };
@@ -119,7 +115,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
 
-    if (op.getOffsets().getValue().empty())
+    if (op.getOffsets().empty())
       return failure();
 
     int64_t srcRank = srcType.getRank();
@@ -133,11 +129,9 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
       return success();
     }
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
     int64_t size = srcType.getShape().front();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t stride = op.getStrides().front();
 
     auto loc = op.getLoc();
     Value res = op.getDest();
@@ -181,9 +175,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
         //    smaller rank.
         extractedSource = rewriter.create<InsertStridedSliceOp>(
-            loc, extractedSource, extractedDest,
-            getI64SubArray(op.getOffsets(), /* dropFront=*/1),
-            getI64SubArray(op.getStrides(), /* dropFront=*/1));
+            loc, extractedSource, extractedDest, op.getOffsets().drop_front(1),
+            op.getStrides().drop_front(1));
       }
       // 4. Insert the extractedSource into the res vector.
       res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -205,18 +198,16 @@ class Convert1DExtractStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
 
-    assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+    assert(!op.getOffsets().empty() && "Unexpected empty offsets");
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
 
     // Single offset can be more efficiently shuffled.
-    if (op.getOffsets().getValue().size() != 1)
+    if (op.getOffsets().size() != 1)
       return failure();
 
     SmallVector<int64_t, 4> offsets;
@@ -248,14 +239,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
       return failure();
 
     // Only handle 1-D cases.
-    if (op.getOffsets().getValue().size() != 1)
+    if (op.getOffsets().size() != 1)
       return failure();
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     Location loc = op.getLoc();
     SmallVector<Value> elements;
@@ -294,13 +283,11 @@ class DecomposeNDExtractStridedSlice
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
 
-    assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+    assert(!op.getOffsets().empty() && "Unexpected empty offsets");
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     auto loc = op.getLoc();
     auto elemType = dstType.getElementType();
@@ -308,7 +295,7 @@ class DecomposeNDExtractStridedSlice
 
     // Single offset can be more efficiently shuffled. It's handled in
     // Convert1DExtractStridedSliceIntoShuffle.
-    if (op.getOffsets().getValue().size() == 1)
+    if (op.getOffsets().size() == 1)
       return failure();
 
     // Extract/insert on a lower ranked extract strided slice op.
@@ -319,9 +306,8 @@ class DecomposeNDExtractStridedSlice
          off += stride, ++idx) {
       Value one = extractOne(rewriter, loc, op.getVector(), off);
       Value extracted = rewriter.create<ExtractStridedSliceOp>(
-          loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
-          getI64SubArray(op.getSizes(), /* dropFront=*/1),
-          getI64SubArray(op.getStrides(), /* dropFront=*/1));
+          loc, one, op.getOffsets().drop_front(), op.getSizes().drop_front(),
+          op.getStrides().drop_front());
       res = insertOne(rewriter, loc, extracted, res, idx);
     }
     rewriter.replaceOp(op, res);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 868397f2daaae4..dbcdc6ea8f31ac 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -160,10 +160,10 @@ struct LinearizeVectorExtractStridedSlice final
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
-    ArrayAttr offsets = extractOp.getOffsets();
-    ArrayAttr sizes = extractOp.getSizes();
-    ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
+    ArrayRef<int64_t> offsets = extractOp.getOffsets();
+    ArrayRef<int64_t> sizes = extractOp.getSizes();
+    ArrayRef<int64_t> strides = extractOp.getStrides();
+    if (strides[0] != 1)
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
     Value srcVector = adaptor.getVector();
@@ -185,8 +185,8 @@ struct LinearizeVectorExtractStridedSlice final
     }
     // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
-    for (Attribute size : sizes) {
-      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+    for (int64_t size : sizes) {
+      nExtractedSlices *= size;
     }
     // Compute the strides of the source vector considering first k dimensions.
     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
@@ -202,8 +202,7 @@ struct LinearizeVectorExtractStridedSlice final
     llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
     // Compute extractedStrides.
     for (int i = kD - 2; i >= 0; --i) {
-      extractedStrides[i] =
-          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+      extractedStrides[i] = extractedStrides[i + 1] * sizes[i + 1];
     }
     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
     // and compute the multi-dimensional index and the corresponding linearized
@@ -220,9 +219,7 @@ struct LinearizeVectorExtractStridedSlice final
       // i.e. shift the multiDimIndex by the offsets.
       int64_t linearizedIndex = 0;
       for (int64_t j = 0; j < kD; ++j) {
-        linearizedIndex +=
-            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
-            sourceStrides[j];
+        linearizedIndex += (offsets[j] + multiDimIndex[j]) * sourceStrides[j];
       }
       // Fill the indices array form linearizedIndex to linearizedIndex +
       // extractGranularitySize.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8e..75820162dd9d54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -548,13 +548,6 @@ struct ReorderElementwiseOpsOnTranspose final
   }
 };
 
-// Returns the values in `arrayAttr` as an integer vector.
-static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
-                      [](IntegerAttr attr) { return attr.getInt(); }));
-}
-
 // Shuffles vector.bitcast op after vector.extract op.
 //
 // This transforms IR like:
@@ -661,8 +654,7 @@ struct BubbleDownBitCastForStridedSliceExtract
       return failure();
 
     // Only accept all one strides for now.
-    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
-                     [](const APInt &val) { return !val.isOne(); }))
+    if (extractOp.hasNonUnitStrides())
       return failure();
 
     unsigned rank = extractOp.getSourceVectorType().getRank();
@@ -673,34 +665,24 @@ struct BubbleDownBitCastForStridedSliceExtract
    // are selecting the full range for the last bitcasted dimension; other
    // dimensions aren't affected. Otherwise, we need to scale down the last
    // dimension's offset given we are extracting from less elements now.
-    ArrayAttr newOffsets = extractOp.getOffsets();
+    SmallVector<int64_t> newOffsets(extractOp.getOffsets());
     if (newOffsets.size() == rank) {
-      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
-      if (offsets.back() % expandRatio != 0)
+      if (newOffsets.back() % expandRatio != 0)
         return failure();
-      offsets.back() = offsets.back() / expandRatio;
-      newOffsets = rewriter.getI64ArrayAttr(offsets);
+      newOffsets.back() = newOffsets.back() / expandRatio;
     }
 
     // Similarly for sizes.
-    ArrayAttr newSizes = extractOp.getSizes();
+    SmallVector<int64_t> newSizes(extractOp.getSizes());
     if (newSizes.size() == rank) {
-      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
-      if (sizes.back() % expandRatio != 0)
+      if (newSizes.back() % expandRatio != 0)
         return failure();
-      sizes.back() = sizes.back() / expandRatio;
-      newSizes = rewriter.getI64ArrayAttr(sizes);
+      newSizes.back() = newSizes.back() / expandRatio;
     }
 
-    SmallVector<int64_t> dims =
-        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
-    dims.back() = dims.back() / expandRatio;
-    VectorType newExtractType =
-        VectorType::get(dims, castSrcType.getElementType());
-
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
-        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
-        newSizes, extractOp.getStrides());
+        extractOp.getLoc(), castOp.getSource(), newOffsets, newSizes,
+        extractOp.getStrides());
 
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
         extractOp, extractOp.getType(), newExtractOp);
@@ -818,8 +800,7 @@ struct BubbleUpBitCastForStridedSliceInsert
       return failure();
 
     // Only accept all one strides for now.
-    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
-                     [](const APInt &val) { return !val.isOne(); }))
+    if (insertOp.hasNonUnitStrides())
       return failure();
 
     unsigned rank = insertOp.getSourceVectorType().getRank();
@@ -836,13 +817,11 @@ struct BubbleUpBitCastForStridedSliceInsert
     if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
       return failure();
 
-    ArrayAttr newOffsets = insertOp.getOffsets();
+    SmallVector<int64_t> newOffsets(insertOp.getOffsets());
     assert(newOffsets.size() == rank);
-    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
-    if (offsets.back() % shrinkRatio != 0)
+    if (newOffsets.back() % shrinkRatio != 0)
       return failure();
-    offsets.back() = offsets.back() / shrinkRatio;
-    newOffsets = rewriter.getI64ArrayAttr(offsets);
+    newOffsets.back() = newOffsets.back() / shrinkRatio;
 
     SmallVector<int64_t> srcDims =
         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
@@ -863,7 +842,7 @@ struct BubbleUpBitCastForStridedSliceInsert
         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
 
     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
-        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+        bitcastOp, newCastSrcOp, newCastDstOp, newOffsets,
         insertOp.getStrides());
 
     return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 10ba895a1b3a4d..3d1862a9a889b2 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -683,7 +683,7 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected offsets attribute of rank no greater than vector rank}}
+  // expected-error@+1 {{op expected offsets to have rank no greater than vector rank}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 916e3e5fd2529d..93a7836b0192e8 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -172,18 +172,18 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
   // BW-0: return %[[RES]] : vector<2x2xf32>
 
-  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+  %0 = vector.extract_strided_slice %arg0 { offsets = [0, 4], sizes = [2, 2], strides = [1, 1] }
      : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
 
 // ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
-  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 0], sizes = [2, 8], strides = [1, 1] } : vector<4x[8]xf32> to vector<2x[8]xf32>
   // ALL: return %[[RES]] : vector<2x[8]xf32>
   return %0 : vector<2x[8]xf32>
 }
@@ -206,7 +206,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
   // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
   // BW-0: return %[[RES]] : vector<1x4x2xf32>
 
-  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], sizes = [1, 4], strides = [1, 1] }
     : vector<2x8x2xf32> to vector<1x4x2xf32>
   return %0 : vector<1x4x2xf32>
 }
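(Editor's note: for downstream code, the migration implied by the hunks above is mechanical — `getOffsets()`, `getSizes()`, and `getStrides()` now return `ArrayRef<int64_t>` rather than `ArrayAttr`, so values are used directly instead of being unpacked from `IntegerAttr`s. A small illustrative helper, not part of the patch; the function name and `sourceStrides` parameter are hypothetical.)

```cpp
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Returns the linearized start offset of the slice, given row-major strides
// for the source vector. Only the leading dimensions covered by the
// attribute contribute, matching the op's semantics.
static int64_t getLinearStartOffset(vector::ExtractStridedSliceOp op,
                                    ArrayRef<int64_t> sourceStrides) {
  int64_t linear = 0;
  // getOffsets() now hands back ArrayRef<int64_t> from the strided_slice
  // attribute; no cast<IntegerAttr>(...).getInt() round trip is needed.
  for (auto [offset, stride] :
       llvm::zip_first(op.getOffsets(), sourceStrides))
    linear += offset * stride;
  return linear;
}
```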