Skip to content

Conversation

@hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Feb 18, 2025

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

There are clone methods for shape and element type, but not for
encodings. The revision adds two clone method to RankedTensorType:
- dropEncoding(): Return a clone of this type without the encoding.
- cloneWithEncoding(Attribute encoding): Return a clone of this type
  with the given new encoding and the same shape and element type as
  this type.

Signed-off-by: hanhanW <[email protected]>
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Feb 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Han-Chung Wang (hanhanW)

Changes

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

Full diff: https://github.com/llvm/llvm-project/pull/127709.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+11)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+14)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     RankedTensorType clone(::mlir::Type elementType) {
       return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
     }
+
+    /// Return a clone of this type without the encoding.
+    RankedTensorType dropEncoding() {
+      return RankedTensorType::get(getShape(), getElementType());
+    }
+
+    /// Return a clone of this type with the given new encoding and the same
+    /// shape and element type as this type.
+    RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+      return RankedTensorType::get(getShape(), getElementType(), encoding);
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
   view = mlir::cast<TensorWithString>(viewCreated);
   EXPECT_EQ(view.getName(), "bob");
+
+  // Verify encoding clone methods.
+  EXPECT_EQ(unitEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(unitAttr));
+  EXPECT_EQ(stringEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(stringAttr));
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
 }
 
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

Full diff: https://github.com/llvm/llvm-project/pull/127709.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+11)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+14)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     RankedTensorType clone(::mlir::Type elementType) {
       return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
     }
+
+    /// Return a clone of this type without the encoding.
+    RankedTensorType dropEncoding() {
+      return RankedTensorType::get(getShape(), getElementType());
+    }
+
+    /// Return a clone of this type with the given new encoding and the same
+    /// shape and element type as this type.
+    RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+      return RankedTensorType::get(getShape(), getElementType(), encoding);
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
   view = mlir::cast<TensorWithString>(viewCreated);
   EXPECT_EQ(view.getName(), "bob");
+
+  // Verify encoding clone methods.
+  EXPECT_EQ(unitEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(unitAttr));
+  EXPECT_EQ(stringEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(stringAttr));
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
 }
 
 } // namespace

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hanhanW hanhanW merged commit 28d7671 into llvm:main Feb 28, 2025
12 checks passed
@hanhanW hanhanW deleted the add-methods-to-ranked-tensor-type branch February 28, 2025 01:59
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
…m#127709)

There are clone methods for shape and element type, but not for
encodings. The revision adds two clone method to RankedTensorType:
- dropEncoding(): Return a clone of this type without the encoding.
- cloneWithEncoding(Attribute encoding): Return a clone of this type
with the given new encoding and the same shape and element type as this
type.

Signed-off-by: hanhanW <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:ods mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants