diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 25535408f4528..7f9c470ffec30 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -401,6 +401,38 @@ enum class SliceVerificationResult { SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType); +//===----------------------------------------------------------------------===// +// Convenience wrappers for VectorType +// +// These are provided to allow idiomatic code like: +// * isa(type) +//===----------------------------------------------------------------------===// +/// A vector type containing at least one scalable dimension. +class ScalableVectorType : public VectorType { +public: + using VectorType::VectorType; + + static bool classof(Type type) { + auto vecTy = llvm::dyn_cast(type); + if (!vecTy) + return false; + return vecTy.isScalable(); + } +}; + +/// A vector type with no scalable dimensions. +class FixedVectorType : public VectorType { +public: + using VectorType::VectorType; + + static bool classof(Type type) { + auto vecTy = llvm::dyn_cast(type); + if (!vecTy) + return false; + return !vecTy.isScalable(); + } +}; + //===----------------------------------------------------------------------===// // Deferred Method Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 7db095d0ae5af..b9f8c1ed19470 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td" // Explicitly disallow 0-D vectors for now until we have good enough coverage. def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; -def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">, - CPred<"!::llvm::cast($_self).isScalable()">]>; +def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. // TODO: Remove this when all ops support 0-D vectors. def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">; // Whether a type is a fixed-length VectorType. -def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && - !::llvm::cast($_self).isScalable()}]>; +def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>; // Whether a type is a scalable VectorType. def IsVectorTypeWithAnyDimScalablePred - : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && - ::llvm::cast($_self).isScalable()}]>; + : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>; // Whether a type is a scalable VectorType, with a single trailing scalable dimension. // Examples: diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h index c209f869a579d..1f1d0f7a30669 100644 --- a/mlir/include/mlir/IR/VectorTypes.h +++ b/mlir/include/mlir/IR/VectorTypes.h @@ -10,42 +10,3 @@ // * isa(type) // //===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_VECTORTYPES_H -#define MLIR_IR_VECTORTYPES_H - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Types.h" - -namespace mlir { -namespace vector { - -/// A vector type containing at least one scalable dimension. -class ScalableVectorType : public VectorType { -public: - using VectorType::VectorType; - - static bool classof(Type type) { - auto vecTy = llvm::dyn_cast(type); - if (!vecTy) - return false; - return vecTy.isScalable(); - } -}; - -/// A vector type with no scalable dimensions. -class FixedVectorType : public VectorType { -public: - using VectorType::VectorType; - static bool classof(Type type) { - auto vecTy = llvm::dyn_cast(type); - if (!vecTy) - return false; - return !vecTy.isScalable(); - } -}; - -} // namespace vector -} // namespace mlir - -#endif // MLIR_IR_VECTORTYPES_H diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index fe7646140db7e..5f445231b80fd 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -21,7 +21,6 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/VectorTypes.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/APFloat.h" @@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() { // Note, we could relax this for vectors with 1 scalable dim, e.g.: // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> // However, this would most likely require updating the lowerings to LLVM. - if (isa(type) && - !isa(getValue())) + if (isa(type) && !isa(getValue())) return emitOpError( "intializing scalable vectors with elements attribute is not supported" " unless it's a vector splat");