Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSVE] Add convert.from/to.svbool intrinsics #68418

Merged
merged 1 commit into from
Oct 10, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Oct 6, 2023

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399

@MacDue
Copy link
Member Author

MacDue commented Oct 6, 2023

cc @c-rhodes, @banach-space

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not suppor this for any type smaller
than an svbool, which is vector<[16]xi1>).
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 9, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-llvm

Changes

These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399


Full diff: https://github.com/llvm/llvm-project/pull/68418.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+24)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+44)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 58dec6091f27f6e..d4294b4dd9fd4e8 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVBool : ScalableVectorOfRankAndLengthAndType<
+  [1], [16], [I1]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
   Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
+def ConvertFromSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.from.svbool",
+    [TypeIs<"res", SVEPredicate>],
+    /*overloadedOperands=*/[],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins SVBool:$svbool)>;
+
+def ConvertToSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.to.svbool",
+    [TypeIs<"res", SVBool>],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[]>,
+    Arguments<(ins SVEPredicate:$mask)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 999df8079e0727a..172a2f7d12d440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
   %0 = "llvm.intr.vscale"() : () -> i64
   llvm.return %0 : i64
 }
+
+// CHECK-LABEL: @arm_sve_convert_from_svbool(
+// CHECK-SAME:                               <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
+llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
+  // CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[8]xi1>
+  // CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[4]xi1>
+  // CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[2]xi1>
+  // CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[1]xi1>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_convert_to_svbool(
+// CHECK-SAME:                            <vscale x 8 x i1> %[[P8:[0-9]+]],
+// CHECK-SAME:                            <vscale x 4 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:                            <vscale x 2 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:                            <vscale x 1 x i1> %[[P1:[0-9]+]])
+llvm.func @arm_sve_convert_to_svbool(
+                                       %nxv8i1  : vector<[8]xi1>,
+                                       %nxv4i1  : vector<[4]xi1>,
+                                       %nxv2i1  : vector<[2]xi1>,
+                                       %nxv1i1  : vector<[1]xi1>
+) {
+  // CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
+  %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
+    : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
+  %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
+    : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
+  %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
+    : (vector<[2]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
+  %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
+    : (vector<[1]xi1>) -> vector<[16]xi1>
+  llvm.return
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, ta!

@MacDue MacDue merged commit 3d70ba6 into llvm:main Oct 10, 2023
7 checks passed
MacDue added a commit that referenced this pull request Oct 12, 2023
This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

Depends on #68418
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants