diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 971c78f4a86a7..00cdb13feb29b 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -53,6 +53,7 @@ namespace arith { class ConstantIntOp : public arith::ConstantOp { public: using arith::ConstantOp::ConstantOp; + static ::mlir::TypeID resolveTypeID() { return TypeID::get(); } /// Build a constant int op that produces an integer of the specified width. static void build(OpBuilder &builder, OperationState &result, int64_t value, @@ -74,6 +75,7 @@ class ConstantIntOp : public arith::ConstantOp { class ConstantFloatOp : public arith::ConstantOp { public: using arith::ConstantOp::ConstantOp; + static ::mlir::TypeID resolveTypeID() { return TypeID::get(); } /// Build a constant float op that produces a float of the specified type. static void build(OpBuilder &builder, OperationState &result, @@ -90,7 +92,7 @@ class ConstantFloatOp : public arith::ConstantOp { class ConstantIndexOp : public arith::ConstantOp { public: using arith::ConstantOp::ConstantOp; - + static ::mlir::TypeID resolveTypeID() { return TypeID::get(); } /// Build a constant int op that produces an index. static void build(OpBuilder &builder, OperationState &result, int64_t value); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h index db27f2c6fc49b..128eacdbe6ab7 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -252,21 +252,21 @@ class TransformDialectExtension template void TransformDialect::addOperationIfNotRegistered() { - StringRef name = OpTy::getOperationName(); std::optional opName = - RegisteredOperationName::lookup(name, getContext()); + RegisteredOperationName::lookup(TypeID::get(), getContext()); if (!opName) { addOperations(); #ifndef NDEBUG + StringRef name = OpTy::getOperationName(); detail::checkImplementsTransformOpInterface(name, getContext()); #endif // NDEBUG return; } - if (opName->getTypeID() == TypeID::get()) + if (LLVM_LIKELY(opName->getTypeID() == TypeID::get())) return; - reportDuplicateOpRegistration(name); + reportDuplicateOpRegistration(OpTy::getOperationName()); } template diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 43b6d2b384169..3beade017d1ab 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -490,7 +490,7 @@ class OpBuilder : public Builder { template RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) { std::optional opName = - RegisteredOperationName::lookup(OpT::getOperationName(), ctx); + RegisteredOperationName::lookup(TypeID::get(), ctx); if (LLVM_UNLIKELY(!opName)) { llvm::report_fatal_error( "Building op `" + OpT::getOperationName() + diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index bd68c27445744..c177ae3594d11 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1729,8 +1729,7 @@ class Op : public OpState, public Traits... { template static void attachInterface(MLIRContext &context) { std::optional info = - RegisteredOperationName::lookup(ConcreteType::getOperationName(), - &context); + RegisteredOperationName::lookup(TypeID::get(), &context); if (!info) llvm::report_fatal_error( "Attempting to attach an interface to an unregistered operation " + diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index f2aa6cee84030..90e63ff8fcb38 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName { static std::optional lookup(StringRef name, MLIRContext *ctx); + /// Lookup the registered operation information for the given operation. + /// Returns std::nullopt if the operation isn't registered. + static std::optional lookup(TypeID typeID, + MLIRContext *ctx); + /// Register a new operation in a Dialect object. /// This constructor is used by Dialect objects when they register the list /// of operations they contain. diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index e1e6d14231d9f..214b354c5347e 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -183,7 +183,8 @@ class MLIRContextImpl { llvm::StringMap> operations; /// A vector of operation info specifically for registered operations. - llvm::StringMap registeredOperations; + llvm::DenseMap registeredOperations; + llvm::StringMap registeredOperationsByName; /// This is a sorted container of registered operations for a deterministic /// and efficient `getRegisteredOperations` implementation. @@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) { // Check the registered info map first. In the overwhelmingly common case, // the entry will be in here and it also removes the need to acquire any // locks. - auto registeredIt = ctxImpl.registeredOperations.find(name); - if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) { + auto registeredIt = ctxImpl.registeredOperationsByName.find(name); + if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) { impl = registeredIt->second.impl; return; } @@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) { //===----------------------------------------------------------------------===// std::optional -RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) { +RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) { auto &impl = ctx->getImpl(); - auto it = impl.registeredOperations.find(name); + auto it = impl.registeredOperations.find(typeID); if (it != impl.registeredOperations.end()) + return it->second; + return std::nullopt; +} + +std::optional +RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) { + auto &impl = ctx->getImpl(); + auto it = impl.registeredOperationsByName.find(name); + if (it != impl.registeredOperationsByName.end()) return it->getValue(); return std::nullopt; } @@ -945,11 +955,16 @@ void RegisteredOperationName::insert( // Update the registered info for this operation. auto emplaced = ctxImpl.registeredOperations.try_emplace( - name, RegisteredOperationName(impl)); + impl->getTypeID(), RegisteredOperationName(impl)); assert(emplaced.second && "operation name registration must be successful"); + auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace( + name, RegisteredOperationName(impl)); + (void)emplacedByName; + assert(emplacedByName.second && + "operation name registration must be successful"); // Add emplaced operation name to the sorted operations container. - RegisteredOperationName &value = emplaced.first->getValue(); + RegisteredOperationName &value = emplaced.first->second; ctxImpl.sortedRegisteredOperations.insert( llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value, [](auto &lhs, auto &rhs) {