diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index d9238e41a28cc..ef1f0bf9f8d0e 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1210,6 +1210,12 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
 }
 
 void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
+  if (info_.custom_op_domain_list.empty()) {
+    common::Status status = CreateTensorRTCustomOpDomainList(info_);
+    if (!status.IsOK()) {
+      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
+    }
+  }
   custom_op_domain_list = info_.custom_op_domain_list;
 }
 
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 3bf6bc05a65df..24c391ee11b84 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -197,7 +197,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   Status ReplayGraph() override;
 
  private:
-  TensorrtExecutionProviderInfo info_;
+  mutable TensorrtExecutionProviderInfo info_;
   bool external_stream_ = false;
   cudaStream_t stream_ = nullptr;
   int max_partition_iterations_ = 1000;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index b5dbe1ac459b1..d7e13df000272 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -75,11 +75,6 @@ struct Tensorrt_Provider : Provider {
     info.device_id = device_id;
     info.has_trt_options = false;
 
-    common::Status status = CreateTensorRTCustomOpDomainList(info);
-    if (!status.IsOK()) {
-      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
-    }
-
     return std::make_shared<TensorrtProviderFactory>(info);
   }
 
@@ -121,11 +116,6 @@ struct Tensorrt_Provider : Provider {
     info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
     info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;
 
-    common::Status status = CreateTensorRTCustomOpDomainList(info);
-    if (!status.IsOK()) {
-      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
-    }
-
     return std::make_shared<TensorrtProviderFactory>(info);
   }
 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index b4d47652942b7..f49d45c36b9b6 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -613,9 +613,35 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
   }
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
-  // Create Custom Op if EP requests it
+  // Register Custom Op if EP requests it
   std::vector<OrtCustomOpDomain*> custom_op_domains;
-  p_exec_provider->GetCustomOpDomainList(custom_op_domains);
+  std::vector<OrtCustomOpDomain*> candidate_custom_op_domains;
+  p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains);
+
+  auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type());
+
+  // Register the custom op domain only if it has not been registered before
+  if (registry_kernels.empty()) {
+    custom_op_domains = candidate_custom_op_domains;
+  } else {
+    for (auto candidate_custom_op_domain : candidate_custom_op_domains) {
+      for (auto registry_kernel : registry_kernels) {
+        const auto& kernel_map = registry_kernel->GetKernelCreateMap();
+        bool need_register = true;
+        // If the kernel registry is the EP's custom op registry, we only need to check the first kernel,
+        // because all kernels in one kernel registry should have the same domain name.
+        for (auto iter = kernel_map.begin(); iter != kernel_map.end(); iter++) {
+          if (iter->second.kernel_def->Domain() == candidate_custom_op_domain->domain_) {
+            need_register = false;
+            break;
+          }
+        }
+        if (need_register) {
+          custom_op_domains.push_back(candidate_custom_op_domain);
+        }
+      }
+    }
+  }
 
   if (!custom_op_domains.empty()) {
     if (AddCustomOpDomains(custom_op_domains) != Status::OK()) {
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index d950223f2d108..d307f79c372ed 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -1625,6 +1625,28 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op
 
 }  // namespace onnxruntime
 
+void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) {
+  auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
+    for (auto ptr : domains) {
+      if (domain_name == ptr->domain_) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  std::vector<OrtCustomOpDomain*> custom_op_domains;
+  onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
+  provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
+  for (auto ptr : custom_op_domains) {
+    if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) {
+      options->custom_op_domains_.push_back(ptr);
+    } else {
+      LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session options.";
+    }
+  }
+}
+
 ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
   API_IMPL_BEGIN
   auto factory = onnxruntime::DnnlProviderFactoryCreator::Create(use_arena);
@@ -1646,13 +1668,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
 
   options->provider_factories.push_back(factory);
 
-  std::vector<OrtCustomOpDomain*> custom_op_domains;
   std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
-  onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
-  provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
-  for (auto ptr : custom_op_domains) {
-    options->custom_op_domains_.push_back(ptr);
-  }
+  AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);
 
   return nullptr;
 API_IMPL_END
@@ -1679,12 +1696,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
 
   options->provider_factories.push_back(factory);
 
-  std::vector<OrtCustomOpDomain*> custom_op_domains;
-  onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
-  provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, "");
-  for (auto ptr : custom_op_domains) {
-    options->custom_op_domains_.push_back(ptr);
-  }
+  AddTensorRTCustomOpDomainToSessionOption(options, "");
 
   return nullptr;
 API_IMPL_END
@@ -1788,13 +1800,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
 
   options->provider_factories.push_back(factory);
 
-  std::vector<OrtCustomOpDomain*> custom_op_domains;
   std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths;
-  onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
-  provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
-  for (auto ptr : custom_op_domains) {
-    options->custom_op_domains_.push_back(ptr);
-  }
+  AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);
 
   return nullptr;
 API_IMPL_END
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 35e03bf9eacd5..a72f563601512 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -433,6 +433,15 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*
 #ifdef USE_TENSORRT
 void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
   if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) {
+    auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
+      for (auto ptr : domains) {
+        if (domain_name == ptr->domain_) {
+          return true;
+        }
+      }
+      return false;
+    };
+
     std::string trt_extra_plugin_lib_paths = "";
     const auto it = options.find("trt_extra_plugin_lib_paths");
     if (it != options.end()) {
@@ -441,7 +450,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
     std::vector<OrtCustomOpDomain*> domain_list;
     tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
     for (auto ptr : domain_list) {
-      so.custom_op_domains_.push_back(ptr);
+      if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
+        so.custom_op_domains_.push_back(ptr);
+      } else {
+        LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session options.";
+      }
     }
   } else {
     ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 1d954fe4370ad..d8628c4288206 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -298,6 +298,20 @@ def test_set_providers_with_options(self):
             self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path))
             self.assertEqual(option["trt_force_sequential_engine_build"], "1")
 
+            from onnxruntime.capi import _pybind_state as C
+
+            session_options = C.get_default_session_options()
+
+            # TRT plugins registered as a custom op domain should be added to the session options only once, regardless of how many sessions are created.
+            sess1 = onnxrt.InferenceSession(
+                get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"]
+            )
+            sess2 = onnxrt.InferenceSession(
+                get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"]
+            )
+            self.assertIn("TensorrtExecutionProvider", sess1.get_providers())
+            self.assertIn("TensorrtExecutionProvider", sess2.get_providers())
+
             # We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability
             """
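
Reviewer note: the C API paths in provider_bridge_ort.cc, the Python binding in onnxruntime_pybind_state.cc, and InferenceSession::RegisterExecutionProvider all apply the same check-before-append pattern, so registering the TensorRT EP against a shared session-options object becomes idempotent. The standalone sketch below distills that pattern; it is not ONNX Runtime code: OrtCustomOpDomain is reduced to a stand-in struct, AddDomainsOnce is a hypothetical helper, and "trt.plugins" is only an illustrative domain name.

// Minimal, self-contained sketch of the "register only once" pattern used in
// this patch. The types below are simplified stand-ins, not the real ONNX
// Runtime classes.
#include <iostream>
#include <string>
#include <vector>

struct OrtCustomOpDomain {
  std::string domain_;
};

// Mirrors the is_already_in_domains lambda added in provider_bridge_ort.cc and
// onnxruntime_pybind_state.cc: a linear scan over the domains already attached
// to the session options.
bool IsAlreadyInDomains(const std::string& domain_name,
                        const std::vector<OrtCustomOpDomain*>& domains) {
  for (const auto* ptr : domains) {
    if (domain_name == ptr->domain_) {
      return true;
    }
  }
  return false;
}

// Append candidate domains, skipping any whose name is already present, so
// that repeated calls (one per session creation) leave the list unchanged.
void AddDomainsOnce(const std::vector<OrtCustomOpDomain*>& candidates,
                    std::vector<OrtCustomOpDomain*>& session_domains) {
  for (auto* candidate : candidates) {
    if (!IsAlreadyInDomains(candidate->domain_, session_domains)) {
      session_domains.push_back(candidate);
    } else {
      std::cerr << "The custom op domain name " << candidate->domain_
                << " is already in session options.\n";
    }
  }
}

int main() {
  OrtCustomOpDomain trt_plugins{"trt.plugins"};  // illustrative domain name
  std::vector<OrtCustomOpDomain*> session_domains;

  // Simulate two session creations sharing one session-options object.
  AddDomainsOnce({&trt_plugins}, session_domains);  // appended
  AddDomainsOnce({&trt_plugins}, session_domains);  // skipped with a warning

  std::cout << "registered domains: " << session_domains.size() << "\n";  // 1
  return 0;
}

A linear scan is sufficient here because a session carries at most a handful of custom-op domains, which is presumably why the patch favors this simple check over a set keyed on domain name.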