diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 58dec6091f27f6..d4294b4dd9fd4e 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect { }]; } +//===----------------------------------------------------------------------===// +// ArmSVE type definitions +//===----------------------------------------------------------------------===// + +def SVBool : ScalableVectorOfRankAndLengthAndType< + [1], [16], [I1]>; + +def SVEPredicate : ScalableVectorOfRankAndLengthAndType< + [1], [16, 8, 4, 2, 1], [I1]>; + //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fdiv">, Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; +def ConvertFromSvboolIntrOp : + ArmSVE_IntrOp<"convert.from.svbool", + [TypeIs<"res", SVEPredicate>], + /*overloadedOperands=*/[], + /*overloadedResults=*/[0]>, + Arguments<(ins SVBool:$svbool)>; + +def ConvertToSvboolIntrOp : + ArmSVE_IntrOp<"convert.to.svbool", + [TypeIs<"res", SVBool>], + /*overloadedOperands=*/[0], + /*overloadedResults=*/[]>, + Arguments<(ins SVEPredicate:$mask)>; + #endif // ARMSVE_OPS diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir index 999df8079e0727..172a2f7d12d440 100644 --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 { %0 = "llvm.intr.vscale"() : () -> i64 llvm.return %0 : i64 } + +// CHECK-LABEL: @arm_sve_convert_from_svbool( +// CHECK-SAME: %[[SVBOOL:[0-9]+]]) +llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) { + // CHECK: %[[RES0:.*]] = call @llvm.aarch64.sve.convert.from.svbool.nxv8i1( %[[SVBOOL]]) + %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1) + : (vector<[16]xi1>) -> vector<[8]xi1> + // CHECK: %[[RES1:.*]] = call @llvm.aarch64.sve.convert.from.svbool.nxv4i1( %[[SVBOOL]]) + %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1) + : (vector<[16]xi1>) -> vector<[4]xi1> + // CHECK: %[[RES2:.*]] = call @llvm.aarch64.sve.convert.from.svbool.nxv2i1( %[[SVBOOL]]) + %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1) + : (vector<[16]xi1>) -> vector<[2]xi1> + // CHECK: %[[RES3:.*]] = call @llvm.aarch64.sve.convert.from.svbool.nxv1i1( %[[SVBOOL]]) + %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1) + : (vector<[16]xi1>) -> vector<[1]xi1> + llvm.return +} + +// CHECK-LABEL: arm_sve_convert_to_svbool( +// CHECK-SAME: %[[P8:[0-9]+]], +// CHECK-SAME: %[[P4:[0-9]+]], +// CHECK-SAME: %[[P2:[0-9]+]], +// CHECK-SAME: %[[P1:[0-9]+]]) +llvm.func @arm_sve_convert_to_svbool( + %nxv8i1 : vector<[8]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv2i1 : vector<[2]xi1>, + %nxv1i1 : vector<[1]xi1> +) { + // CHECK-NEXT: %[[RES0:.*]] = call @llvm.aarch64.sve.convert.to.svbool.nxv8i1( %[[P8]]) + %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1) + : (vector<[8]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES1:.*]] = call @llvm.aarch64.sve.convert.to.svbool.nxv4i1( %[[P4]]) + %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1) + : (vector<[4]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES2:.*]] = call @llvm.aarch64.sve.convert.to.svbool.nxv2i1( %[[P2]]) + %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1) + : (vector<[2]xi1>) -> vector<[16]xi1> + // CHECK-NEXT: %[[RES3:.*]] = call @llvm.aarch64.sve.convert.to.svbool.nxv1i1( %[[P1]]) + %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1) + : (vector<[1]xi1>) -> vector<[16]xi1> + llvm.return +}