Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Update representation of insert/extract_strided_slice #101850

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Aug 3, 2024

This commit updates the representation of both extract_strided_slice and insert_strided_slice to primitive arrays of int64_ts, rather than ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate conversions between IntegerAttr and int64_t.

This is done by adding a new StridedSliceAttr which matches the previous syntax and can be used for both operations.

It may also be possible to explore alternate slice syntax for the StridedSliceAttr in future.

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 3, 2024

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This commit updates the representation of both extract_strided_slice and insert_strided_slice to primitive arrays of int64_ts, rather than ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate conversions between IntegerAttr and int64_t.

Because previously the offsets, strides, and sizes were in the attribute dictionary (with no special syntax), simply replacing the attribute types with DenseI64ArrayAttr would be a syntax break.

So since a syntax break is mostly unavoidable this commit also tackles a long-standing TODO:

// TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
  : vector<4x8x16xf32> to vector<2x4x16xf32>

This is done by introducing a new StridedSliceAttr attribute that can be used for both operations, with syntax based on the above example (see the attribute documentation VectorAttributes.td for a full syntax overview).

With this:

extract_strided_slice goes from:

%1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
     : vector<4x8x16xf32> to vector<2x4x16xf32>

To:

%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
     : vector<4x8x16xf32> to vector<2x4x16xf32>

(matching the TODO)


And insert_strided_slice goes from:

%2 = vector.insert_strided_slice %0, %1
     {offsets = [0, 0, 2], strides = [1, 1]}
     : vector<2x4xf32> into vector<16x4x8xf32>

To:

%2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
     : vector<2x4xf32> into vector<16x4x8xf32>

(inspired by the TODO)


Almost all test changes were done automatically via auto-upgrade-insert-extract-slice.py, available at: https://gist.github.com/MacDue/ca84d3ec19cf83ae71aab2be8f09c3c5 (use at your own risk).

This PR is split into multiple commits to make the changes more understandable.

  • The first commit is code changes
  • The second commit is automatic test changes
  • The final commit is manual test changes

Patch is 354.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101850.diff

44 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td (+64)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+19-20)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-10)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+5-8)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+184-149)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (+1-8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+7-17)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+29-43)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+8-11)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+14-35)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+1-1)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+15-15)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+54-54)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1)
  • (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+12-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10-10)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+4-4)
  • (modified) mlir/test/Dialect/Arith/emulate-wide-int.mlir (+19-19)
  • (modified) mlir/test/Dialect/Arith/int-narrowing.mlir (+16-16)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+83-83)
  • (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+2-2)
  • (modified) mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir (+10-10)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+19-19)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+96-96)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+63-63)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+17-24)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+7-9)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir (+12-12)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+10-10)
  • (modified) mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-scan-transforms.mlir (+20-20)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-unroll.mlir (+102-102)
  • (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+36-36)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+70-70)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/contraction.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/extract-strided-slice.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/insert-strided-slice.mlir (+4-4)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/transpose.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..7fa20b950e7c6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/IR/EnumAttr.td"
 
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Vector_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 // The "kind" of combining function for contractions and reductions.
 def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
 def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,63 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
   let assemblyFormat = "`<` $value `>`";
 }
 
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+  let summary = "strided vector slice";
+
+  let description = [{
+    An attribute that represents a strided slice of a vector.
+
+    *Syntax:*
+
+    ```
+    offset = integer-literal
+    stride = integer-literal
+    size = integer-literal
+    offset-list = offset (`,` offset)*
+
+    // Without sizes (used for insert_strided_slice)
+    strided-slice-without-sizes = offset-list? (`[` offset `:` stride `]`)+
+
+    // With sizes (used for extract_strided_slice)
+    strided-slice-with-sizes = (`[` offset `:` size `:` stride `]`)+
+    ```
+
+    *Examples:*
+
+    Without sizes:
+
+    `[0:1][4:2]`
+
+    - The first dimension starts at offset 0 and is strided by 1
+    - The second dimension starts at offset 4 and is strided by 2
+
+    `[0, 1, 2][3:1][4:8]`
+
+    - The first three dimensions are indexed without striding (offsets 0, 1, 2)
+    - The fourth dimension starts at offset 3 and is strided by 1
+    - The fifth dimension starts at offset 4 and is strided by 8
+
+    With sizes (used for extract_strided_slice)
+
+    `[0:2:4][2:4:3]`
+
+    - The first dimension starts at offset 0, has size 2, and is strided by 4
+    - The second dimension starts at offset 2, has size 4, and is strided by 3
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$offsets,
+    OptionalArrayRefParameter<"int64_t">:$sizes,
+    ArrayRefParameter<"int64_t">:$strides
+  );
+
+  let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+      return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..45edb75c1989a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest,
+               Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector:$res)> {
   let summary = "strided_slice operation";
   let description = [{
@@ -1059,14 +1059,13 @@ def Vector_InsertStridedSliceOp :
     Example:
 
     ```mlir
-    %2 = vector.insert_strided_slice %0, %1
-        {offsets = [0, 0, 2], strides = [1, 1]}:
-      vector<2x4xf32> into vector<16x4x8xf32>
+    %2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
+      : vector<2x4xf32> into vector<16x4x8xf32>
     ```
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $source `,` $dest `` $strided_slice attr-dict `:` type($source) `into` type($dest)
   }];
 
   let builders = [
@@ -1081,10 +1080,13 @@ def Vector_InsertStridedSliceOp :
       return ::llvm::cast<VectorType>(getDest().getType());
     }
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
 
   let hasFolder = 1;
@@ -1298,8 +1300,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector)> {
   let summary = "extract_strided_slice operation";
   let description = [{
@@ -1316,13 +1317,8 @@ def Vector_ExtractStridedSliceOp :
     Example:
 
     ```mlir
-    %1 = vector.extract_strided_slice %0
-        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
-      vector<4x8x16xf32> to vector<2x4x16xf32>
-
-    // TODO: Evolve to a range form syntax similar to:
     %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
+      : vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
   }];
   let builders = [
@@ -1333,17 +1329,20 @@ def Vector_ExtractStridedSliceOp :
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
-    void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$vector `` $strided_slice attr-dict `:` type($vector) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 0150ff667e4ef..a2647e2b647c1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
 static LogicalResult
 convertExtractStridedSlice(RewriterBase &rewriter,
                            vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
   auto sourceVector = it->second;
 
   // offset and sizes at warp-level of onwership.
-  SmallVector<int64_t> offsets;
-  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  ArrayRef<int64_t> offsets = op.getOffsets();
 
-  SmallVector<int64_t> sizes;
-  populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
   return cast<IntegerAttr>(attr[0]).getInt();
 }
-static uint64_t getFirstIntValue(ArrayAttr attr) {
-  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
   auto attr = foldResults[0].dyn_cast<Attribute>();
   if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    int64_t offset = extractOp.getOffsets().front();
+    int64_t size = extractOp.getSizes().front();
+    int64_t stride = extractOp.getStrides().front();
     if (stride != 1)
       return failure();
 
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    uint64_t stride = getFirstIntValue(insertOp.getStrides());
+    uint64_t stride = insertOp.getStrides().front();
     if (stride != 1)
       return failure();
-    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+    uint64_t offset = insertOp.getOffsets().front();
 
     if (isa<spirv::ScalarType>(srcVector.getType())) {
       assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index e2d42e961c576..941644e1116fc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -550,11 +550,8 @@ struct ExtensionOverExtractStridedSlice final
     if (failed(ext))
       return failure();
 
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
         op.getStrides());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..dda6b916176fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   // Trim offsets for dimensions fully extracted.
-  auto sliceOffsets =
-      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+  SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
   while (!sliceOffsets.empty()) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
-    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
     ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
-    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
-        }))
+    if (insertOp.hasNonUnitStrides())
       return Value();
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
@@ -2899,6 +2889,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+  SmallVector<int64_t> offsets;
+  SmallVector<int64_t> sizes;
+  SmallVector<int64_t> strides;
+  bool parsedNonStridedOffsets = false;
+  while (succeeded(parser.parseOptionalLSquare())) {
+    int64_t offset = 0;
+    if (parser.parseInteger(offset))
+      return {};
+    if (parser.parseOptionalColon()) {
+      // Case 1: [Offset, ...]
+      if (!strides.empty() || parsedNonStridedOffsets) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "expected slice stride or size");
+        return {};
+      }
+      offsets.push_back(offset);
+      if (succeeded(parser.parseOptionalComma())) {
+        if (parser.parseCommaSeparatedList(
+                AsmParser::Delimiter::None, [&]() -> ParseResult {
+                  if (parser.parseInteger(offset))
+                    return failure();
+                  offsets.push_back(offset);
+                  return success();
+                })) {
+          return {};
+        }
+      }
+      if (parser.parseRSquare())
+        return {};
+      parsedNonStridedOffsets = true;
+      continue;
+    }
+    int64_t sizeOrStide = 0;
+    if (parser.parseInteger(sizeOrStide)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected slice stride or size");
+      return {};
+    }
+    if (parser.parseOptionalColon()) {
+      // Case 2: [Offset:Stride]
+      if (!sizes.empty() || parser.parseRSquare()) {
+        parser.emitError(parser.getCurrentLocation(), "expected slice size");
+        return {};
+      }
+      offsets.push_back(offset);
+      strides.push_back(sizeOrStide);
+      continue;
+    }
+    // Case 3: [Offset:Size:Stride]
+    if (sizes.size() < strides.size()) {
+      parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+      return {};
+    }
+    int64_t stride = 0;
+    if (parser.parseInteger(stride) || parser.parseRSquare()) {
+      parser.emitError(parser.getCurrentLocation(), "expected slice stride");
+      return {};
+    }
+    offsets.push_back(offset);
+    sizes.push_back(sizeOrStide);
+    strides.push_back(stride);
+  }
+  return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+  ArrayRef<int64_t> offsets = getOffsets();
+  ArrayRef<int64_t> sizes = getSizes();
+  ArrayRef<int64_t> strides = getStrides();
+  int nonStridedOffsets = offsets.size() - strides.size();
+  if (nonStridedOffsets > 0) {
+    printer << '[';
+    llvm::interleaveComma(offsets.take_front(nonStridedOffsets), printer);
+    printer << ']';
+  }
+  for (int d = nonStridedOffsets, e = offsets.size(); d < e; ++d) {
+    int strideIdx = d - nonStridedOffsets;
+    printer << '[' << offsets[d] << ':';
+    if (!sizes.empty())
+      printer << sizes[strideIdx] << ':';
+    printer << strides[strideIdx] << ']';
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
@@ -2907,26 +2986,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> strides) {
-  result.addOperands({source, dest});
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
-                                                        ArrayAttr arrayAttr,
-                                                        ArrayRef<int64_t> shape,
-                                                        StringRef attrName) {
-  if (arrayAttr.size() > shape.size())
-    return op.emitOpError("expected ")
-           << attrName << " attribute of rank no greater than vector rank";
-  return success();
+  build(builder, result, source, dest,
+        StridedSliceAttr::get(builder.getContext(), offsets, strides));
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2995,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
-                                  int64_t max, StringRef attrName,
-                                  bool halfOpen = true) {
-  for (auto attr : arrayAttr) {
-    auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+                          int64_t max, StringRef arrayName,
+                          bool halfOpen = true) {
+  for (int64_t val : array) {
     auto upper = max;
     if (!halfOpen)
       upper += 1;
     if (val < min || val >= upper)
-      return op.emitOpError("expected ") << attrName << " to be confined to ["
+      return op.emitOpError("expected ") << arrayName << " to be confined to ["
                                          << min << ", " << upper << ")";
   }
   return success();
@@ -2954,13 +3014,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
-                                  ArrayRef<int64_t> shape, StringRef attrName,
-                                  bool halfOpen = true, int64_t min = 0) {
-  for (auto [index, attrDimPair] :
-       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
-    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
-    int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+                          ArrayRef<int64_t> shape, StringRef attrName,
+                          bool halfOpen = true, int64_t min = 0) {
+  for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+    int64_t val, max;
+    std::tie(val, max) = dimPair;
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
@@ -2977,40 +3036,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
 // the admissible interval is [min, max].
 template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
-    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
-    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+    OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+    ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
     bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= shape.size());
-  assert(arrayAttr2.size() <= sh...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 3, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Benjamin Maxwell (MacDue)

Changes

This commit updates the representation of both extract_strided_slice and insert_strided_slice to primitive arrays of int64_ts, rather than ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate conversions between IntegerAttr and int64_t.

Because previously the offsets, strides, and sizes were in the attribute dictionary (with no special syntax), simply replacing the attribute types with DenseI64ArrayAttr would be a syntax break.

So since a syntax break is mostly unavoidable this commit also tackles a long-standing TODO:

// TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
  : vector&lt;4x8x16xf32&gt; to vector&lt;2x4x16xf32&gt;

This is done by introducing a new StridedSliceAttr attribute that can be used for both operations, with syntax based on the above example (see the attribute documentation VectorAttributes.td for a full syntax overview).

With this:

extract_strided_slice goes from:

%1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
     : vector&lt;4x8x16xf32&gt; to vector&lt;2x4x16xf32&gt;

To:

%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
     : vector&lt;4x8x16xf32&gt; to vector&lt;2x4x16xf32&gt;

(matching the TODO)


And insert_strided_slice goes from:

%2 = vector.insert_strided_slice %0, %1
     {offsets = [0, 0, 2], strides = [1, 1]}
     : vector&lt;2x4xf32&gt; into vector&lt;16x4x8xf32&gt;

To:

%2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
     : vector&lt;2x4xf32&gt; into vector&lt;16x4x8xf32&gt;

(inspired by the TODO)


Almost all test changes were done automatically via auto-upgrade-insert-extract-slice.py, available at: https://gist.github.com/MacDue/ca84d3ec19cf83ae71aab2be8f09c3c5 (use at your own risk).

This PR is split into multiple commits to make the changes more understandable.

  • The first commit is code changes
  • The second commit is automatic test changes
  • The final commit is manual test changes

Patch is 354.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101850.diff

44 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td (+64)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+19-20)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-10)
  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+5-8)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+184-149)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (+1-8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+7-17)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+29-43)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+8-11)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+14-35)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+1-1)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+15-15)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+54-54)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1)
  • (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+12-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10-10)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+4-4)
  • (modified) mlir/test/Dialect/Arith/emulate-wide-int.mlir (+19-19)
  • (modified) mlir/test/Dialect/Arith/int-narrowing.mlir (+16-16)
  • (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+83-83)
  • (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+2-2)
  • (modified) mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir (+10-10)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+19-19)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+96-96)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+63-63)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+17-24)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+7-9)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir (+12-12)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+10-10)
  • (modified) mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-scan-transforms.mlir (+20-20)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-unroll.mlir (+102-102)
  • (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+36-36)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+70-70)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/contraction.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/extract-strided-slice.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/insert-strided-slice.mlir (+4-4)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/transpose.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..7fa20b950e7c6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/IR/EnumAttr.td"
 
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Vector_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 // The "kind" of combining function for contractions and reductions.
 def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
 def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,63 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
   let assemblyFormat = "`<` $value `>`";
 }
 
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+  let summary = "strided vector slice";
+
+  let description = [{
+    An attribute that represents a strided slice of a vector.
+
+    *Syntax:*
+
+    ```
+    offset = integer-literal
+    stride = integer-literal
+    size = integer-literal
+    offset-list = offset (`,` offset)*
+
+    // Without sizes (used for insert_strided_slice)
+    strided-slice-without-sizes = offset-list? (`[` offset `:` stride `]`)+
+
+    // With sizes (used for extract_strided_slice)
+    strided-slice-with-sizes = (`[` offset `:` size `:` stride `]`)+
+    ```
+
+    *Examples:*
+
+    Without sizes:
+
+    `[0:1][4:2]`
+
+    - The first dimension starts at offset 0 and is strided by 1
+    - The second dimension starts at offset 4 and is strided by 2
+
+    `[0, 1, 2][3:1][4:8]`
+
+    - The first three dimensions are indexed without striding (offsets 0, 1, 2)
+    - The fourth dimension starts at offset 3 and is strided by 1
+    - The fifth dimension starts at offset 4 and is strided by 8
+
+    With sizes (used for extract_strided_slice)
+
+    `[0:2:4][2:4:3]`
+
+    - The first dimension starts at offset 0, has size 2, and is strided by 4
+    - The second dimension starts at offset 2, has size 4, and is strided by 3
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$offsets,
+    OptionalArrayRefParameter<"int64_t">:$sizes,
+    ArrayRefParameter<"int64_t">:$strides
+  );
+
+  let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+      return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..45edb75c1989a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest,
+               Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector:$res)> {
   let summary = "strided_slice operation";
   let description = [{
@@ -1059,14 +1059,13 @@ def Vector_InsertStridedSliceOp :
     Example:
 
     ```mlir
-    %2 = vector.insert_strided_slice %0, %1
-        {offsets = [0, 0, 2], strides = [1, 1]}:
-      vector<2x4xf32> into vector<16x4x8xf32>
+    %2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
+      : vector<2x4xf32> into vector<16x4x8xf32>
     ```
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $source `,` $dest `` $strided_slice attr-dict `:` type($source) `into` type($dest)
   }];
 
   let builders = [
@@ -1081,10 +1080,13 @@ def Vector_InsertStridedSliceOp :
       return ::llvm::cast<VectorType>(getDest().getType());
     }
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
 
   let hasFolder = 1;
@@ -1298,8 +1300,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector)> {
   let summary = "extract_strided_slice operation";
   let description = [{
@@ -1316,13 +1317,8 @@ def Vector_ExtractStridedSliceOp :
     Example:
 
     ```mlir
-    %1 = vector.extract_strided_slice %0
-        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
-      vector<4x8x16xf32> to vector<2x4x16xf32>
-
-    // TODO: Evolve to a range form syntax similar to:
     %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
+      : vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
   }];
   let builders = [
@@ -1333,17 +1329,20 @@ def Vector_ExtractStridedSliceOp :
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
-    void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$vector `` $strided_slice attr-dict `:` type($vector) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 0150ff667e4ef..a2647e2b647c1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
 static LogicalResult
 convertExtractStridedSlice(RewriterBase &rewriter,
                            vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
   auto sourceVector = it->second;
 
   // offset and sizes at warp-level of onwership.
-  SmallVector<int64_t> offsets;
-  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  ArrayRef<int64_t> offsets = op.getOffsets();
 
-  SmallVector<int64_t> sizes;
-  populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
   return cast<IntegerAttr>(attr[0]).getInt();
 }
-static uint64_t getFirstIntValue(ArrayAttr attr) {
-  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
   auto attr = foldResults[0].dyn_cast<Attribute>();
   if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    int64_t offset = extractOp.getOffsets().front();
+    int64_t size = extractOp.getSizes().front();
+    int64_t stride = extractOp.getStrides().front();
     if (stride != 1)
       return failure();
 
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    uint64_t stride = getFirstIntValue(insertOp.getStrides());
+    uint64_t stride = insertOp.getStrides().front();
     if (stride != 1)
       return failure();
-    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+    uint64_t offset = insertOp.getOffsets().front();
 
     if (isa<spirv::ScalarType>(srcVector.getType())) {
       assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index e2d42e961c576..941644e1116fc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -550,11 +550,8 @@ struct ExtensionOverExtractStridedSlice final
     if (failed(ext))
       return failure();
 
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
         op.getStrides());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..dda6b916176fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   // Trim offsets for dimensions fully extracted.
-  auto sliceOffsets =
-      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+  SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
   while (!sliceOffsets.empty()) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
-    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
     ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
-    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
-        }))
+    if (insertOp.hasNonUnitStrides())
       return Value();
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
@@ -2899,6 +2889,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+  SmallVector<int64_t> offsets;
+  SmallVector<int64_t> sizes;
+  SmallVector<int64_t> strides;
+  bool parsedNonStridedOffsets = false;
+  while (succeeded(parser.parseOptionalLSquare())) {
+    int64_t offset = 0;
+    if (parser.parseInteger(offset))
+      return {};
+    if (parser.parseOptionalColon()) {
+      // Case 1: [Offset, ...]
+      if (!strides.empty() || parsedNonStridedOffsets) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "expected slice stride or size");
+        return {};
+      }
+      offsets.push_back(offset);
+      if (succeeded(parser.parseOptionalComma())) {
+        if (parser.parseCommaSeparatedList(
+                AsmParser::Delimiter::None, [&]() -> ParseResult {
+                  if (parser.parseInteger(offset))
+                    return failure();
+                  offsets.push_back(offset);
+                  return success();
+                })) {
+          return {};
+        }
+      }
+      if (parser.parseRSquare())
+        return {};
+      parsedNonStridedOffsets = true;
+      continue;
+    }
+    int64_t sizeOrStide = 0;
+    if (parser.parseInteger(sizeOrStide)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected slice stride or size");
+      return {};
+    }
+    if (parser.parseOptionalColon()) {
+      // Case 2: [Offset:Stride]
+      if (!sizes.empty() || parser.parseRSquare()) {
+        parser.emitError(parser.getCurrentLocation(), "expected slice size");
+        return {};
+      }
+      offsets.push_back(offset);
+      strides.push_back(sizeOrStide);
+      continue;
+    }
+    // Case 3: [Offset:Size:Stride]
+    if (sizes.size() < strides.size()) {
+      parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+      return {};
+    }
+    int64_t stride = 0;
+    if (parser.parseInteger(stride) || parser.parseRSquare()) {
+      parser.emitError(parser.getCurrentLocation(), "expected slice stride");
+      return {};
+    }
+    offsets.push_back(offset);
+    sizes.push_back(sizeOrStide);
+    strides.push_back(stride);
+  }
+  return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+  ArrayRef<int64_t> offsets = getOffsets();
+  ArrayRef<int64_t> sizes = getSizes();
+  ArrayRef<int64_t> strides = getStrides();
+  int nonStridedOffsets = offsets.size() - strides.size();
+  if (nonStridedOffsets > 0) {
+    printer << '[';
+    llvm::interleaveComma(offsets.take_front(nonStridedOffsets), printer);
+    printer << ']';
+  }
+  for (int d = nonStridedOffsets, e = offsets.size(); d < e; ++d) {
+    int strideIdx = d - nonStridedOffsets;
+    printer << '[' << offsets[d] << ':';
+    if (!sizes.empty())
+      printer << sizes[strideIdx] << ':';
+    printer << strides[strideIdx] << ']';
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
@@ -2907,26 +2986,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> strides) {
-  result.addOperands({source, dest});
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
-                                                        ArrayAttr arrayAttr,
-                                                        ArrayRef<int64_t> shape,
-                                                        StringRef attrName) {
-  if (arrayAttr.size() > shape.size())
-    return op.emitOpError("expected ")
-           << attrName << " attribute of rank no greater than vector rank";
-  return success();
+  build(builder, result, source, dest,
+        StridedSliceAttr::get(builder.getContext(), offsets, strides));
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2995,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
-                                  int64_t max, StringRef attrName,
-                                  bool halfOpen = true) {
-  for (auto attr : arrayAttr) {
-    auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+                          int64_t max, StringRef arrayName,
+                          bool halfOpen = true) {
+  for (int64_t val : array) {
     auto upper = max;
     if (!halfOpen)
       upper += 1;
     if (val < min || val >= upper)
-      return op.emitOpError("expected ") << attrName << " to be confined to ["
+      return op.emitOpError("expected ") << arrayName << " to be confined to ["
                                          << min << ", " << upper << ")";
   }
   return success();
@@ -2954,13 +3014,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
-                                  ArrayRef<int64_t> shape, StringRef attrName,
-                                  bool halfOpen = true, int64_t min = 0) {
-  for (auto [index, attrDimPair] :
-       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
-    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
-    int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+                          ArrayRef<int64_t> shape, StringRef attrName,
+                          bool halfOpen = true, int64_t min = 0) {
+  for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+    int64_t val, max;
+    std::tie(val, max) = dimPair;
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
@@ -2977,40 +3036,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
 // the admissible interval is [min, max].
 template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
-    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
-    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+    OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+    ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
     bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= shape.size());
-  assert(arrayAttr2.size() <= sh...
[truncated]

@MacDue
Copy link
Member Author

MacDue commented Aug 3, 2024

This is the most ambitious change based on #100997.

Note: This PR is not urgent! It's just something I think would be nice to have 🙂

Copy link

github-actions bot commented Aug 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td Outdated Show resolved Hide resolved
printer << sizes[strideIdx] << ':';
printer << strides[strideIdx] << ']';
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be an assumption about the size of all the arrays, but I don't see the invariant enforced in a verifier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add a verifier for the attribute, as the constraints are enforced within the ops (right now). I think some of those constraints could be factored out to the StridedSliceAttr though 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going to leave this for a future patch now it's not needed for the printer/parser (which just uses the current syntax).

This commit updates the representation of both extract_strided_slice and
insert_strided_slice to primitive arrays of int64_ts, rather than
ArrayAttrs of IntegerAttrs. This prevents a lot of boilerplate
conversions between IntegerAttr and int64_t.

This is done by adding a new `StridedSliceAttr` which matches the
previous syntax and can be used for both operations.

It may also be possible to explore alternate slice syntax for the
`StridedSliceAttr` in future.
@MacDue MacDue changed the title [mlir][vector] Update syntax and representation of insert/extract_strided_slice [mlir][vector] Update representation of insert/extract_strided_slice Aug 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants