Skip to content

Commit

Permalink
[mlir][ArmSVE] Add convert_to/from_svbool ops (#68586)
Browse files Browse the repository at this point in the history
This adds slightly higher-level ops for converting masks between svbool
and SVE predicate types. The main reason to use these over the
intrinsics is these ops support vectors of masks (via unrolling).

E.g.

```
// Convert a svbool mask to a mask of SVE predicates:
%svbool = vector.load %memref[%c0, %c0]
                       : memref<2x?xi1>, vector<2x[16]xi1>
%mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1>
// => Results in vector<2x[8]xi1>
```
Or:
```
// Convert a mask of SVE predicates to a svbool mask:
%mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

Depends on #68418
  • Loading branch information
MacDue authored Oct 12, 2023
1 parent 1c12dcc commit b833bcb
Show file tree
Hide file tree
Showing 9 changed files with 448 additions and 5 deletions.
84 changes: 84 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def ArmSVE_Dialect : Dialect {
This dialect contains the definitions necessary to target specific Arm SVE
scalable vector operations.
}];

let dependentDialects = ["vector::VectorDialect"];
}

//===----------------------------------------------------------------------===//
Expand All @@ -40,6 +42,13 @@ def SVBool : ScalableVectorOfRankAndLengthAndType<
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;

// Generalizations of SVBool and SVEPredicate to ranks >= 1.
// These are masks with a single trailing scalable dimension.
def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
[16], [I1]>;
def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
[16, 8, 4, 2, 1], [I1]>;

//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -236,6 +245,81 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;

def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
[Pure, SvboolTypeConstraint<"result", "source">]>
{
let summary = "Convert a svbool type to a SVE predicate type";
let description = [{
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
`vector<2x3x[16]xi1>`) to SVE predicate types. Note: Only the trailing
dimension can be scalable.

Example 1: Convert a 1-D svbool mask to a SVE predicate.
```mlir
%source = vector.load %memref[%c0] : memref<?xi1>, vector<[16]xi1>
%result = arm_sve.convert_from_svbool %source : vector<[4]xi1>
```

Example 2: Convert a 2-D svbool mask to a mask of SVE predicates.
```mlir
%source = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1>
%result = arm_sve.convert_from_svbool %source : vector<2x[8]xi1>
```

---

A `svbool` is the smallest SVE predicate type that has a in-memory
representation (and maps to a full predicate register). In MLIR `svbool` is
represented as `vector<[16]xi1>`. Smaller SVE predicate types
(`vector<[1|2|4|8]xi1>`) must be stored as a `svbool` then converted back to
the original predicate type after loading.
}];
let arguments = (ins SVBoolMask:$source);
let results = (outs SVEPredicateMask:$result);
let assemblyFormat = "$source attr-dict `:` type($result)";
}

def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
[Pure, SvboolTypeConstraint<"source", "result">]>
{
let summary = "Convert a SVE predicate type to a svbool type";
let description = [{
Converts SVE predicate types (or vectors of predicate types, e.g.
`vector<4x[4]xi1>`) to svbool types. Note: Only the trailing dimension can
be scalable.

Example 1: Convert a 1-D SVE predicate to a svbool mask.
```mlir
%source = vector.create_mask %dim_size : vector<[4]xi1>
%result = arm_sve.convert_to_svbool %source : vector<[4]xi1>
// => Results in vector<[16]xi1>
```

Example 2: Convert a 2-D mask of SVE predicates to a svbool mask.
```mlir
%source = vector.create_mask %c2, %dim_size : vector<2x[2]xi1>
%result = arm_sve.convert_to_svbool %source : vector<2x[2]xi1>
// => Results in vector<2x[16]xi1>
```

---

A `svbool` is the smallest SVE predicate type that has a in-memory
representation (and maps to a full predicate register). In MLIR `svbool` is
represented as `vector<[16]xi1>`. Smaller SVE predicate types
(`vector<[1|2|4|8]xi1>`) must be converted to a `svbool` before they can be
stored.
}];
let arguments = (ins SVEPredicateMask:$source);
let results = (outs SVBoolMask:$result);
let assemblyFormat = "$source attr-dict `:` type($source)";
}

def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;

Expand Down
74 changes: 74 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
// Valid:
// - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
// Invalid
// - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
CPred<"!llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
]>;

// Whether a type is a VectorType and all dimensions are scalable.
def allDimsScalableVectorTypePred : And<[
IsVectorTypePred,
Expand Down Expand Up @@ -404,6 +417,15 @@ class ScalableVectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsScalableVectorTypePred,
"scalable vector", "::mlir::VectorType">;

// Any vector with a single trailing scalable dimension, with an element type in
// the `allowedTypes` list.
//
// Note: This Similar to ScalableVectorOf, with the extra requirement that only
// the trailing dim is scalable.
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred,
"trailing scalable vector", "::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
Expand Down Expand Up @@ -481,6 +503,40 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;

// Normalizes an index so the indices in both directions have the same value.
// For example, when indexing forwards index 2 is the third element. When
// indexing in reverse the third element is -3. This helper would map both of
// these to the "normalized" index of 3. This makes the bounds checking in
// IsNthDimSizeIsOneOfPred simpler (see first CPred).
class NormalizeIndex<int value> {
int ret = !if(!lt(value, 0),
!sub(0, value) /* -value if negative */,
!add(value, 1) /* value + 1 if positive*/);
}

// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.
//
// Examples:
// IsNthDimSizeIsOneOfPred<0, {2, 3, 4}>
// - Accepts any shape where the first dim is 2, 3, or 4.
// * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
// IsNthDimSizeIsOneOfPred<-1, {16}>
// - Accepts any shape where the last dim is 16.
// * This means shapes like 2x16, 16, 1x2x3x4x16, etc
// IsNthDimSizeIsOneOfPred<-2, {10, 5}>
// - Accepts any shape where the second to last dim is 10 or 5.
// * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
: And<[
CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
CPred<"::llvm::is_contained(ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
# "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
# !if(!lt(n, 0),
"::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
"" # n)
# "))">]>;

// Whether the shape of a vector matches the given `shape` list.
class IsVectorOfShape<list<int> shape>
: CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
Expand Down Expand Up @@ -546,6 +602,24 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
IsNthDimSizeIsOneOfPred<n, allowedSizes>,
" with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
"::mlir::ShapedType">;

// Any scalable vector with a single trailing scalable dimensions, where the
// size of the trailing dimension is in `allowedTrailingSizes` list, and the
// type is in the `allowedTypes` list.
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
list<Type> allowedTypes> : AllOfType<
[VectorWithTrailingDimScalableOf<allowedTypes>,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
VectorWithTrailingDimScalableOf<allowedTypes>.summary #
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;

def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRVectorDialect
MLIRSideEffectInterfaces
)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRArmSVETransforms
LINK_LIBS PUBLIC
MLIRArmSVEDialect
MLIRFuncDialect
MLIRVectorDialect
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
Expand Down
85 changes: 82 additions & 3 deletions mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

Expand Down Expand Up @@ -66,6 +68,77 @@ using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;

namespace {

/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
///
/// Example:
/// ```
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
/// : (vector<[4]xi1>) -> vector<[16]xi1>
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Op convertOp, typename Op::Adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = convertOp.getLoc();

auto source = convertOp.getSource();
VectorType sourceType = source.getType();
VectorType resultType = convertOp.getResult().getType();

Value result = rewriter.create<arith::ConstantOp>(
loc, resultType, rewriter.getZeroAttr(resultType));

// We want to iterate over the input vector in steps of the trailing
// dimension. So this creates tile shape where all leading dimensions are 1,
// and the trailing dimension step is the size of the dimension.
SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
tileShape.back() = sourceType.getShape().back();

// Iterate over all scalable mask/predicate slices of the source vector.
for (SmallVector<int64_t> index :
StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
auto extractOrInsertPosition = ArrayRef(index).drop_back();
auto sourceVector = rewriter.create<vector::ExtractOp>(
loc, source, extractOrInsertPosition);
auto convertedType =
VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
.setDim(0, resultType.getShape().back());
auto convertedVector =
rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
extractOrInsertPosition);
}

rewriter.replaceOp(convertOp, result);
return success();
}
};

using ConvertToSvboolOpLowering =
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

} // namespace

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Expand All @@ -88,7 +161,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering>(converter);
ScalableMaskedDivFOpLowering,
ConvertToSvboolOpLowering,
ConvertFromSvboolOpLowering>(converter);
// clang-format on
}

Expand All @@ -107,7 +182,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFIntrOp,
ScalableMaskedSDivIIntrOp,
ScalableMaskedUDivIIntrOp,
ScalableMaskedDivFIntrOp>();
ScalableMaskedDivFIntrOp,
ConvertToSvboolIntrOp,
ConvertFromSvboolIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
Expand All @@ -120,6 +197,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp>();
ScalableMaskedDivFOp,
ConvertToSvboolOp,
ConvertFromSvboolOp>();
// clang-format on
}
Loading

0 comments on commit b833bcb

Please sign in to comment.