Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,38 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);

//===----------------------------------------------------------------------===//
// Convenience wrappers for VectorType
//
// These are provided to allow idiomatic code like:
// * isa<vector::ScalableVectorType>(type)
//===----------------------------------------------------------------------===//
/// A vector type containing at least one scalable dimension.
class ScalableVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return vecTy.isScalable();
}
};

/// A vector type with no scalable dimensions.
class FixedVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return !vecTy.isScalable();
}
};

//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 4 additions & 7 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;

// Whether a type is a fixed-length VectorType.
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;

// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
: CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
Expand Down
39 changes: 0 additions & 39 deletions mlir/include/mlir/IR/VectorTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_VECTORTYPES_H
#define MLIR_IR_VECTORTYPES_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace vector {

/// A vector type containing at least one scalable dimension.
class ScalableVectorType : public VectorType {
public:
using VectorType::VectorType;

static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return vecTy.isScalable();
}
};

/// A vector type with no scalable dimensions.
class FixedVectorType : public VectorType {
public:
using VectorType::VectorType;
static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
return false;
return !vecTy.isScalable();
}
};

} // namespace vector
} // namespace mlir

#endif // MLIR_IR_VECTORTYPES_H
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/APFloat.h"
Expand Down Expand Up @@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
if (isa<vector::ScalableVectorType>(type) &&
!isa<SplatElementsAttr>(getValue()))
if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
Expand Down
Loading