Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Arith/IR/Arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ namespace arith {
class ConstantIntOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }

/// Build a constant int op that produces an integer of the specified width.
static void build(OpBuilder &builder, OperationState &result, int64_t value,
Expand All @@ -74,6 +75,7 @@ class ConstantIntOp : public arith::ConstantOp {
class ConstantFloatOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }

/// Build a constant float op that produces a float of the specified type.
static void build(OpBuilder &builder, OperationState &result,
Expand All @@ -90,7 +92,7 @@ class ConstantFloatOp : public arith::ConstantOp {
class ConstantIndexOp : public arith::ConstantOp {
public:
using arith::ConstantOp::ConstantOp;

static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant int op that produces an index.
static void build(OpBuilder &builder, OperationState &result, int64_t value);

Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,21 @@ class TransformDialectExtension

template <typename OpTy>
void TransformDialect::addOperationIfNotRegistered() {
StringRef name = OpTy::getOperationName();
std::optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(name, getContext());
RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
if (!opName) {
addOperations<OpTy>();
#ifndef NDEBUG
StringRef name = OpTy::getOperationName();
detail::checkImplementsTransformOpInterface(name, getContext());
#endif // NDEBUG
return;
}

if (opName->getTypeID() == TypeID::get<OpTy>())
if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
return;

reportDuplicateOpRegistration(name);
reportDuplicateOpRegistration(OpTy::getOperationName());
}

template <typename Type>
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ class OpBuilder : public Builder {
template <typename OpT>
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
std::optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
if (LLVM_UNLIKELY(!opName)) {
llvm::report_fatal_error(
"Building op `" + OpT::getOperationName() +
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -1729,8 +1729,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
template <typename... Models>
static void attachInterface(MLIRContext &context) {
std::optional<RegisteredOperationName> info =
RegisteredOperationName::lookup(ConcreteType::getOperationName(),
&context);
RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
if (!info)
llvm::report_fatal_error(
"Attempting to attach an interface to an unregistered operation " +
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName {
static std::optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx);

/// Lookup the registered operation information for the given operation.
/// Returns std::nullopt if the operation isn't registered.
static std::optional<RegisteredOperationName> lookup(TypeID typeID,
MLIRContext *ctx);

/// Register a new operation in a Dialect object.
/// This constructor is used by Dialect objects when they register the list
/// of operations they contain.
Expand Down
29 changes: 22 additions & 7 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ class MLIRContextImpl {
llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;

/// A vector of operation info specifically for registered operations.
llvm::StringMap<RegisteredOperationName> registeredOperations;
llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
llvm::StringMap<RegisteredOperationName> registeredOperationsByName;

/// This is a sorted container of registered operations for a deterministic
/// and efficient `getRegisteredOperations` implementation.
Expand Down Expand Up @@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
// Check the registered info map first. In the overwhelmingly common case,
// the entry will be in here and it also removes the need to acquire any
// locks.
auto registeredIt = ctxImpl.registeredOperations.find(name);
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
impl = registeredIt->second.impl;
return;
}
Expand Down Expand Up @@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
//===----------------------------------------------------------------------===//

std::optional<RegisteredOperationName>
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
auto &impl = ctx->getImpl();
auto it = impl.registeredOperations.find(name);
auto it = impl.registeredOperations.find(typeID);
if (it != impl.registeredOperations.end())
return it->second;
return std::nullopt;
}

std::optional<RegisteredOperationName>
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
auto &impl = ctx->getImpl();
auto it = impl.registeredOperationsByName.find(name);
if (it != impl.registeredOperationsByName.end())
return it->getValue();
return std::nullopt;
}
Expand Down Expand Up @@ -945,11 +955,16 @@ void RegisteredOperationName::insert(

// Update the registered info for this operation.
auto emplaced = ctxImpl.registeredOperations.try_emplace(
name, RegisteredOperationName(impl));
impl->getTypeID(), RegisteredOperationName(impl));
assert(emplaced.second && "operation name registration must be successful");
auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
name, RegisteredOperationName(impl));
(void)emplacedByName;
assert(emplacedByName.second &&
"operation name registration must be successful");

// Add emplaced operation name to the sorted operations container.
RegisteredOperationName &value = emplaced.first->getValue();
RegisteredOperationName &value = emplaced.first->second;
ctxImpl.sortedRegisteredOperations.insert(
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
[](auto &lhs, auto &rhs) {
Expand Down