diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 1a6962717c680..f2cf1ac9fad8b 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -28,6 +28,7 @@ function(get_c_cxx_api_headers HEADERS_VAR) "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_env_config_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" ) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index f7776b1f77383..02d923b9cbc10 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -597,6 +597,7 @@ set (onnxruntime_shared_lib_test_SRC if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc) + list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_env_creation.cc) endif() if(onnxruntime_RUN_ONNX_TESTS) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 3a423a64b9047..e5c325f599435 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "core/common/common.h" @@ -20,6 +21,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" +#include "core/session/abi_key_value_pairs.h" #include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" @@ -51,11 +53,13 @@ class Environment { 
@param tp_options optional set of parameters controlling the number of intra and inter op threads for the global threadpools. @param create_global_thread_pools determine if this function will create the global threadpools or not. + @param config_entries Application-specified configuration entries. */ static Status Create(std::unique_ptr logging_manager, std::unique_ptr& environment, const OrtThreadingOptions* tp_options = nullptr, - bool create_global_thread_pools = false); + bool create_global_thread_pools = false, + const OrtKeyValuePairs* config_entries = nullptr); /** * Set the global threading options for the environment, if no global thread pools have been created yet. @@ -170,6 +174,17 @@ class Environment { // return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator Status GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator); + /// + /// Returns a copy of the configuration entries set by the application on environment creation. + /// + /// Primarily used by EP libraries to retrieve environment-level configurations, but could be used + /// more generally to specify global settings. + /// + /// Refer to OrtApi::CreateEnvWithOptions(). + /// + /// + OrtKeyValuePairs GetConfigEntries() const; + ~Environment(); private: @@ -177,7 +192,8 @@ class Environment { Status Initialize(std::unique_ptr logging_manager, const OrtThreadingOptions* tp_options = nullptr, - bool create_global_thread_pools = false); + bool create_global_thread_pools = false, + const OrtKeyValuePairs* config_entries = nullptr); Status RegisterAllocatorImpl(AllocatorPtr allocator); Status UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found = true); @@ -186,6 +202,13 @@ class Environment { const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator, bool replace_existing); + // Inserts (or assigns) a config entry into `config_entries_`. Locks `config_entries_mutex_`. 
+ void InsertOrAssignConfigEntry(std::string key, std::string value); + + // Removes a config entry from `config_entries_`. Does nothing if the key does not exist. + // Locks `config_entries_mutex_`. + void RemoveConfigEntry(const std::string& key); + std::unique_ptr logging_manager_; std::unique_ptr intra_op_thread_pool_; std::unique_ptr inter_op_thread_pool_; @@ -254,6 +277,20 @@ class Environment { DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances #endif // !defined(ORT_MINIMAL_BUILD) + + // Application-specified environment configuration entries + // The environment may add or remove an entry on EP library registration and unregistration, respectively. + OrtKeyValuePairs config_entries_; + mutable std::shared_mutex config_entries_mutex_; // Should be locked when accessing config_entries_ + + // Tracks the number of registered EP libraries that can create virtual devices. + // It is incremented when an EP library is registered with a name that ends in ".virtual". + // It is decremented when that EP library is unregistered. + // If it reaches 0, the config entry "allow_virtual_devices" is removed. + // + // This starts at 1 if user created an OrtEnv with the config "allow_virtual_devices" set to "1" + // to prevent removal of the config entry in that case. + size_t num_allow_virtual_device_uses_{}; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5acac571f3f3b..410b63147a8fe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -965,14 +965,6 @@ typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options */ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); -/** \brief The C API - * - * All C API functions are defined inside this structure as pointers to functions. 
- * Call OrtApiBase::GetApi to get a pointer to it - * - * \nosubgrouping - */ - /** \addtogroup Global * @{ */ @@ -1056,6 +1048,101 @@ typedef enum OrtCompiledModelCompatibility { OrtCompiledModelCompatibility_EP_UNSUPPORTED, } OrtCompiledModelCompatibility; +/** \brief Configuration options for creating an OrtEnv. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtEnvCreationOptions { + uint32_t version; ///< Must be set to ORT_API_VERSION + + /** \brief The logging severity level for the environment. Must be set to a value from OrtLoggingLevel. + * + * \note Logging messages which are less severe than the `logging_severity_level` are not emitted. + * + * \note Serves as the default logging severity level for session creation and runs. + * Use ::SetSessionLogSeverityLevel() to set a logging severity level for the creation of a specific session. + * Use ::RunOptionsSetRunLogSeverityLevel() to set a logging severity level for a specific session run. + * + * \since Version 1.24. + */ + int32_t logging_severity_level; + + /** \brief The log identifier. Must be set to a valid UTF-8 null-terminated string. + * + * \note This string identifier is copied by ORT. + * + * \since Version 1.24. + */ + const char* log_id; + + /** \brief Optional custom logging function. May be set to NULL. + * + * \note The OrtEnvCreationOptions::custom_logging_param is provided as the first argument to this logging function. + * This allows passing custom state into the logging function. + * + * \note This function is only called when a message's severity meets or exceeds the set logging severity level. + * + * \since Version 1.24. + */ + OrtLoggingFunction custom_logging_function; + + /** \brief Optional state to pass as the first argument to OrtEnvCreationOptions::custom_logging_function. + * May be set to NULL. + * + * \since Version 1.24. 
+ */ + void* custom_logging_param; + + /** \brief Optional threading options for creating an environment with global thread pools shared across sessions. + * May be set to NULL. + * + * \note The OrtThreadingOptions instance is copied by ORT. + * + * \note Use OrtApi::CreateThreadingOptions() to create an instance of OrtThreadingOptions. + * + * \note Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use its own + * thread pools. + * + * \since Version 1.24. + */ + const OrtThreadingOptions* threading_options; + + /** \brief Optional environment configuration entries represented as string key-value pairs. May be set to NULL. + * + * \note The OrtKeyValuePairs instance is copied by ORT. + * + * \note Refer to onnxruntime_env_config_keys.h for common config entry keys and their supported values. + * + * \note An application provides environment-level configuration options for execution provider libraries by + * using keys with the prefix 'ep_factory.<ep_name>.'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents + * a key named 'some_ep_key' that is meant to be consumed by an execution provider named 'my_ep'. Refer to + * the specific execution provider's documentation for valid keys and values. + * + * \note An application may separately set session-level configuration options for execution providers via other APIs + * such as SessionOptionsAppendExecutionProvider_V2, which store configuration entries within OrtSessionOptions. + * If an environment-level configuration conflicts with a session-level configuration, then + * precedence is determined by the execution provider library itself. + * + * \since Version 1.24. + */ + const OrtKeyValuePairs* config_entries; + + // + // End of fields available in ORT 1.24 + // + +} OrtEnvCreationOptions; + +/** \brief The C API + * + * All C API functions are defined inside this structure as pointers to functions. 
+ * Call OrtApiBase::GetApi to get a pointer to it + * + * \nosubgrouping + */ struct OrtApi { /// \name OrtStatus /// @{ @@ -6912,6 +6999,20 @@ struct OrtApi { ORT_CLASS_RELEASE(DeviceEpIncompatibilityDetails); /// @} + + /** \brief Create an OrtEnv instance with the given options. + * + * \note If an OrtEnv instance already exists (created by a previous call to any env creation function), this + * function returns that same instance and all arguments to this function are ignored. + * + * \param[in] options The OrtEnvCreationOptions instance that contains creation options. + * \param[out] out Output parameter set to the new OrtEnv instance. Must be freed with OrtApi::ReleaseEnv. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d24594f590619..901f7f10f3754 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1185,6 +1185,9 @@ struct Env : detail::Base { Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); + /// \brief Wraps OrtApi::CreateEnvWithOptions + explicit Env(const OrtEnvCreationOptions* options); + /// \brief C Interop Helper explicit Env(OrtEnv* p) : Base{p} {} @@ -3431,5 +3434,8 @@ struct SharedPrePackedWeightCacheImpl : Ort::detail::Base { */ using UnownedSharedPrePackedWeightCache = detail::SharedPrePackedWeightCacheImpl>; + +/// \brief Wraps OrtEpApi::GetEnvConfigEntries() +Ort::KeyValuePairs GetEnvConfigEntries(); } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h 
b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 18299d2e49343..267838e41887e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -784,6 +784,15 @@ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction loggin } } +inline Env::Env(const OrtEnvCreationOptions* options) { + ThrowOnError(GetApi().CreateEnvWithOptions(options, &p_)); + if (strcmp(options->log_id, "onnxruntime-node") == 0) { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS)); + } else { + ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS)); + } +} + inline Env& Env::EnableTelemetryEvents() { ThrowOnError(GetApi().EnableTelemetryEvents(p_)); return *this; @@ -3779,4 +3788,11 @@ inline Status SharedPrePackedWeightCacheImpl::StoreWeightData(void** buffer_d num_buffers)}; } } // namespace detail + +inline Ort::KeyValuePairs GetEnvConfigEntries() { + OrtKeyValuePairs* entries = nullptr; + Ort::ThrowOnError(GetEpApi().GetEnvConfigEntries(&entries)); + + return Ort::KeyValuePairs{entries}; +} } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_env_config_keys.h b/include/onnxruntime/core/session/onnxruntime_env_config_keys.h new file mode 100644 index 0000000000000..b603765ad3428 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_env_config_keys.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// This file contains well-known keys for OrtEnv configuration entries, which may be used to configure EPs or +// other global settings. +// Refer to OrtEnvCreationOptions::config_entries and OrtApi::CreateEnvWithOptions. +// This file does NOT specify all available keys as EPs may accept custom entries with the prefix "ep_factory.<ep_name>.". 
+ +// Key for a boolean option that, when enabled, allows EP factories to create virtual OrtHardwareDevice +// instances via OrtEpApi::CreateHardwareDevice(). +// +// This config entry is automatically set to "1" by ORT if an application registers an EP library with a registration +// name that ends in the suffix ".virtual". See OrtApi::RegisterExecutionProviderLibrary(). +// +// Note: A virtual OrtHardwareDevice does not represent actual hardware on the device, and is identified via the +// metadata entry "is_virtual" with a value of "1". +// +// Allowed values: +// - "0": Default. Creation of virtual devices is not allowed. +// This is the assumed default value if this key is not present in the environment's configuration entries. +// - "1": Creation of virtual devices is allowed. +static const char* const kOrtEnvAllowVirtualDevices = "allow_virtual_devices"; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6bb454cd47623..b64e13531c260 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1429,6 +1429,23 @@ struct OrtEpApi { _Outptr_ OrtKernelImpl** kernel_out); ORT_CLASS_RELEASE(KernelImpl); + + /** \brief Gets a new OrtKeyValuePairs instance containing a copy of all configuration entries set on the environment. + * + * \note An application provides environment-level configuration options for execution provider libraries by + * using keys with the prefix 'ep_factory.<ep_name>.'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents + * a key named 'some_ep_key' that is meant to be consumed by an execution provider named 'my_ep'. Refer to + * the specific execution provider's documentation for valid keys and values. + * + * \note Refer to onnxruntime_env_config_keys.h for common configuration entry keys and their supported values. 
+ * + * \param[out] config_entries Output parameter set to the OrtKeyValuePairs instance containing all configuration entries. + * Must be released via OrtApi::ReleaseKeyValuePairs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entries); }; /** @@ -1982,35 +1999,6 @@ struct OrtEpFactory { _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStreamImpl** stream); - /** \brief Set environment options on this EP factory. - * - * Environment options can be set by ORT after calling the library's 'CreateEpFactories' function to - * create EP factories. - * - * Supported options: - * "allow_virtual_devices": Allows EP factory to specify OrtEpDevice instances that use custom - * virtual OrtHardwareDevices, which can be created via OrtEpApi::CreateHardwareDevice(). - * - * A virtual OrtHardwareDevice does not represent actual hardware on the device, and is identified - * via the metadata entry "is_virtual" with a value of "1". - * Refer to onnxruntime_ep_device_ep_metadata_keys.h for well-known OrtHardwareDevice metadata keys. - * - * Allowed values: - * -# "0": Default. Creation of virtual devices is not allowed. - * -# "1": Creation of virtual devices is allowed. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] options The configuration options. - * - * \note Implementation of this function is optional. - * An EP factory should only implement this if it needs to handle any environment options. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); - /** \brief Check for known incompatibility reasons between a hardware device and this execution provider. 
* * This function allows an execution provider to check if a specific hardware device is compatible diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index cd8a799115ce6..523ff8eaf13b8 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,6 +16,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_env_config_keys.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/plugin_ep/ep_library_plugin.h" @@ -120,9 +121,11 @@ std::unordered_set::const_iterator FindExistingAllocator(const st Status Environment::Create(std::unique_ptr logging_manager, std::unique_ptr& environment, const OrtThreadingOptions* tp_options, - bool create_global_thread_pools) { + bool create_global_thread_pools, + const OrtKeyValuePairs* config_entries) { environment = std::make_unique(); - auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools); + auto status = environment->Initialize(std::move(logging_manager), tp_options, create_global_thread_pools, + config_entries); return status; } @@ -242,11 +245,23 @@ Status Environment::CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, co Status Environment::Initialize(std::unique_ptr logging_manager, const OrtThreadingOptions* tp_options, - bool create_global_thread_pools) { + bool create_global_thread_pools, + const OrtKeyValuePairs* config_entries) { auto status = Status::OK(); logging_manager_ = std::move(logging_manager); + if (config_entries != nullptr) { + config_entries_ = *config_entries; + + const auto& config_map = config_entries_.Entries(); + + if (auto iter = config_map.find(kOrtEnvAllowVirtualDevices); + iter != config_map.end() && iter->second == "1") { + 
num_allow_virtual_device_uses_ = 1; + } + } + // create thread pools if (create_global_thread_pools) { ORT_RETURN_IF_ERROR(SetGlobalThreadingOptions(*tp_options)); @@ -474,6 +489,21 @@ Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocat return Status::OK(); } +OrtKeyValuePairs Environment::GetConfigEntries() const { + std::shared_lock lock{config_entries_mutex_}; + return config_entries_; // copy +} + +void Environment::InsertOrAssignConfigEntry(std::string key, std::string value) { + std::lock_guard lock{config_entries_mutex_}; + config_entries_.Add(std::move(key), std::move(value)); +} + +void Environment::RemoveConfigEntry(const std::string& key) { + std::lock_guard lock{config_entries_mutex_}; + config_entries_.Remove(key.c_str()); +} + #if !defined(ORT_MINIMAL_BUILD) // @@ -496,6 +526,14 @@ Status CreateDataTransferForFactory(OrtEpFactory& ep_factory, return Status::OK(); } + +bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { + constexpr std::string_view suffix{".virtual"}; + + return lib_registration_name.size() >= suffix.size() && + lib_registration_name.compare(lib_registration_name.size() - suffix.size(), + suffix.size(), suffix) == 0; +} } // namespace Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, @@ -576,6 +614,19 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra std::vector internal_factories = {}; std::unique_ptr ep_library; + // An application can allow EP libraries to create virtual devices by using an EP library registration name that + // ends in the suffix ".virtual". If so, ORT automatically sets the config key "allow_virtual_devices" to "1" + // in the environment. We track the number of libraries that use virtual devices to be able to remove + // "allow_virtual_devices" from the config entries when the last library is unregistered. 
In practice, + // we expect only one such library to be registered for cross-compilation. + if (AreVirtualDevicesAllowed(registration_name)) { + if (num_allow_virtual_device_uses_ == 0) { + InsertOrAssignConfigEntry(kOrtEnvAllowVirtualDevices, "1"); + } + + num_allow_virtual_device_uses_ += 1; + } + // This will create an EpLibraryPlugin or an EpLibraryProviderBridge depending on what the library supports. ORT_RETURN_IF_ERROR(LoadPluginOrProviderBridge(registration_name, lib_path, ep_library, internal_factories)); @@ -583,20 +634,31 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra return RegisterExecutionProviderLibrary(registration_name, std::move(ep_library), internal_factories); } -Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_name) { +Status Environment::UnregisterExecutionProviderLibrary(const std::string& registration_name) { std::lock_guard lock{mutex_}; - if (ep_libraries_.count(ep_name) == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution provider library: ", ep_name, " was not registered."); + if (ep_libraries_.count(registration_name) == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution provider library: ", registration_name, + " was not registered."); } auto status = Status::OK(); ORT_TRY { - auto ep_info = std::move(ep_libraries_[ep_name]); + auto ep_info = std::move(ep_libraries_[registration_name]); + + // Clean up environment config entry that may have been added to enable virtual devices. + if (AreVirtualDevicesAllowed(registration_name)) { + num_allow_virtual_device_uses_ -= 1; + + if (num_allow_virtual_device_uses_ == 0) { + RemoveConfigEntry(kOrtEnvAllowVirtualDevices); + } + } + // remove from map and global list of OrtEpDevice* before unloading so we don't get a leftover entry if // something goes wrong in any of the following steps.. 
- ep_libraries_.erase(ep_name); + ep_libraries_.erase(registration_name); for (auto* data_transfer : ep_info->data_transfers) { ORT_RETURN_IF_ERROR(data_transfer_mgr_.UnregisterDataTransfer(data_transfer)); @@ -629,8 +691,8 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unregister EP library: ", ep_name, " with error: ", - ex.what()); + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to unregister EP library: ", registration_name, + " with error: ", ex.what()); }); } @@ -708,29 +770,6 @@ const std::vector& GetSortedHardwareDevices() { static const auto sorted_devices = SortDevicesByType(); return sorted_devices; } - -bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { - constexpr std::string_view suffix{".virtual"}; - - return lib_registration_name.size() >= suffix.size() && - lib_registration_name.compare(lib_registration_name.size() - suffix.size(), - suffix.size(), suffix) == 0; -} - -Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) { - // OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24 - if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) { - return Status::OK(); - } - - // We only set one option now but this can be generalized if necessary. - OrtKeyValuePairs options; - options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? 
"1" : "0"); - - ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options))); - - return Status::OK(); -} } // namespace const std::vector& Environment::GetSortedOrtHardwareDevices() const { @@ -851,8 +890,6 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u auto& factory = *factory_ptr; - ORT_RETURN_IF_ERROR(SetEpFactoryEnvironmentOptions(factory, instance.library->RegistrationName())); - std::array ep_devices{nullptr}; size_t num_ep_devices = 0; ORT_RETURN_IF_ERROR(ToStatusAndRelease( diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 8cca7f2872c44..afb17f867fc00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -196,7 +196,9 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction loggi API_IMPL_BEGIN OrtEnv::LoggingManagerConstructionInfo lm_info{logging_function, logger_param, logging_level, logid}; Status status; - *out = OrtEnv::GetInstance(lm_info, status); + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status); + + *out = ort_env.release(); return ToOrtStatus(status); API_IMPL_END } @@ -206,7 +208,9 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel logging_level, API_IMPL_BEGIN OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, logging_level, logid}; Status status; - *out = OrtEnv::GetInstance(lm_info, status); + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status); + + *out = ort_env.release(); return ToOrtStatus(status); API_IMPL_END } @@ -216,7 +220,9 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithGlobalThreadPools, OrtLoggingLevel log API_IMPL_BEGIN OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, logging_level, logid}; Status status; - *out = OrtEnv::GetInstance(lm_info, status, tp_options); + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status, tp_options); + + *out = 
ort_env.release(); return ToOrtStatus(status); API_IMPL_END } @@ -227,7 +233,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtL API_IMPL_BEGIN OrtEnv::LoggingManagerConstructionInfo lm_info{logging_function, logger_param, logging_level, logid}; Status status; - *out = OrtEnv::GetInstance(lm_info, status, tp_options); + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status, tp_options); + + *out = ort_env.release(); + return ToOrtStatus(status); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out) { + API_IMPL_BEGIN + if (options == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CreateEnvWithOptions requires a valid (non-null) OrtEnvCreationOptions argument"); + } + + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CreateEnvWithOptions requires a valid (non-null) output parameter into which to " + "store the new OrtEnv instance"); + } + + // Both this API function and OrtEnvCreationOptions were added in ORT 1.24, so check that the user + // filled out the version correctly. 
+ if (options->version < 24) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CreateEnvWithOptions requires a OrtEnvCreationOptions argument with the version set " + "equal to ORT_API_VERSION"); + } + + if (options->logging_severity_level < OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE || + options->logging_severity_level > OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CreateEnvWithOptions requires a OrtEnvCreationOptions argument " + "with a valid logging severity level value from the OrtLoggingLevel enumeration"); + } + + if (options->log_id == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CreateEnvWithOptions requires a OrtEnvCreationOptions argument " + "with a valid (non-null) log identifier string"); + } + + OrtLoggingLevel logging_severity_level = static_cast(options->logging_severity_level); + OrtEnv::LoggingManagerConstructionInfo lm_info(options->custom_logging_function, + options->custom_logging_param, + logging_severity_level, + options->log_id); + Status status; + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status, options->threading_options, options->config_entries); + + *out = ort_env.release(); return ToOrtStatus(status); API_IMPL_END } @@ -4291,6 +4346,7 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::GetInteropApi, &OrtApis::SessionGetEpDeviceForOutputs, + &OrtApis::GetNumHardwareDevices, &OrtApis::GetHardwareDevices, &OrtApis::GetHardwareDeviceEpIncompatibilityDetails, @@ -4298,6 +4354,8 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::DeviceEpIncompatibilityDetails_GetNotes, &OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, &OrtApis::ReleaseDeviceEpIncompatibilityDetails, + + &OrtApis::CreateEnvWithOptions, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index a93d853592dea..a38ee0c1eab11 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -782,4 +782,6 @@ ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, _In_ size_t num_outputs); +// OrtEnv +ORT_API_STATUS_IMPL(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index f07f3a1530b7d..12757a3a662a6 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -52,9 +52,11 @@ OrtEnv::~OrtEnv() { #endif } -OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, - onnxruntime::common::Status& status, - const OrtThreadingOptions* tp_options) { +/*static*/ +OrtEnvPtr OrtEnv::GetOrCreateInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const OrtThreadingOptions* tp_options, + const OrtKeyValuePairs* config_entries) { std::lock_guard lock(m_); if (!p_instance_) { std::unique_ptr lmgr; @@ -76,14 +78,13 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf LoggingManager::InstanceType::Default, &name); + const bool create_global_thread_pools = tp_options != nullptr; std::unique_ptr env; - if (!tp_options) { - status = onnxruntime::Environment::Create(std::move(lmgr), env); - } else { - status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, true); - } + status = onnxruntime::Environment::Create(std::move(lmgr), env, tp_options, + create_global_thread_pools, config_entries); + if (!status.IsOK()) { - return nullptr; + return OrtEnvPtr(nullptr, OrtEnv::Release); } // Use 'new' to allocate OrtEnv, as it will be managed by p_instance_ // and deleted in 
ReleaseEnv or leaked if g_is_process_shutting_down is true. @@ -91,9 +92,10 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf } ++ref_count_; - return p_instance_; + return OrtEnvPtr(p_instance_, OrtEnv::Release); } +/*static*/ void OrtEnv::Release(OrtEnv* env_ptr) { if (!env_ptr) { return; // nothing to release @@ -131,6 +133,17 @@ void OrtEnv::Release(OrtEnv* env_ptr) { delete instance_to_delete; } +/*static*/ +OrtEnvPtr OrtEnv::TryGetInstance() { + std::lock_guard lock(m_); + + if (p_instance_) { + ++ref_count_; + } + + return OrtEnvPtr(p_instance_, OrtEnv::Release); +} + onnxruntime::logging::LoggingManager* OrtEnv::GetLoggingManager() const { return value_->GetLoggingManager(); } diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 94c8e0a6ea2e8..2f6c270e9c5e7 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -14,6 +14,9 @@ namespace onnxruntime { class Environment; } +// Managed pointer type for OrtEnv that calls OrtEnv::Release as its deleter. +using OrtEnvPtr = std::unique_ptr; + struct OrtEnv { public: struct LoggingManagerConstructionInfo { @@ -31,9 +34,24 @@ struct OrtEnv { const char* logid{}; }; - static OrtEnv* GetInstance(const LoggingManagerConstructionInfo& lm_info, - onnxruntime::common::Status& status, - const OrtThreadingOptions* tp_options = nullptr); + /// + /// Gets or creates the global OrtEnv instance. Arguments are ignored if the instance has already been created. + /// + /// Configuration for the logging manager. + /// Output parameter that indicates if an error occurred during environment creation. + /// Optional threading options. + /// Optional configuration entries. + /// The OrtEnv instance. 
+ static OrtEnvPtr GetOrCreateInstance(const LoggingManagerConstructionInfo& lm_info, + onnxruntime::common::Status& status, + const OrtThreadingOptions* tp_options = nullptr, + const OrtKeyValuePairs* config_entries = nullptr); + + /// + /// Gets the global OrtEnv instance. Returns nullptr if the instance has not yet been created. + /// + /// The OrtEnv instance or nullptr. + static OrtEnvPtr TryGetInstance(); static void Release(OrtEnv* env_ptr); @@ -58,7 +76,7 @@ struct OrtEnv { // Using a smart pointer like std::unique_ptr would complicate this specific // shutdown scenario, as it would attempt to deallocate the memory even if // Release() hasn't been called or if a leak is desired. - // Management is handled by GetInstance() and Release(), with ref_count_ + // Management is handled by GetOrCreateInstance(), TryGetInstance(), and Release(), with ref_count_ // tracking active users. It is set to nullptr when the last reference is released // (and not shutting down). static OrtEnv* p_instance_; diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 21e6ae1525838..c1d93a7620e74 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -21,8 +21,10 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" +#include "core/session/environment.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" +#include "core/session/ort_env.h" #include "core/session/plugin_ep/ep_kernel_registration.h" #include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" #include "core/session/utils.h" @@ -784,6 +786,26 @@ ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl) { } } +ORT_API_STATUS_IMPL(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entries) { + API_IMPL_BEGIN + OrtEnvPtr ort_env = OrtEnv::TryGetInstance(); + + if (ort_env == 
nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "OrtEnv instance does not exist"); + } + + if (config_entries == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "GetEnvConfigEntries requires a valid (non-null) output parameter into which to store " + "the new OrtKeyValuePairs instance"); + } + + auto entries_unique_ptr = std::make_unique(ort_env->GetEnvironment().GetConfigEntries()); + *config_entries = entries_unique_ptr.release(); + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -845,6 +867,7 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::CreateLoopKernel, &OrtExecutionProviderApi::CreateScanKernel, &OrtExecutionProviderApi::ReleaseKernelImpl, + &OrtExecutionProviderApi::GetEnvConfigEntries, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 853342c2c3a53..23bbe23026f9a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -116,4 +116,7 @@ ORT_API_STATUS_IMPL(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In ORT_API_STATUS_IMPL(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper, _Outptr_ OrtKernelImpl** kernel_out); ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl); + +// Env config entries +ORT_API_STATUS_IMPL(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entries); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index a13645e293844..fe36c9ea0cdd1 100644 --- 
a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -32,7 +32,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; - OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; OrtEpFactory::GetHardwareDeviceIncompatibilityDetails = Forward::GetHardwareDeviceIncompatibilityDetails; } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 6f4a37f44fb44..ae09c763bbcbf 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -87,10 +87,6 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } - OrtStatus* SetEnvironmentOptions(_In_ const OrtKeyValuePairs* options) noexcept { - return impl_->SetEnvironmentOptions(options); - } - OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { return impl_->CreateExternalResourceImporterForDevice(ep_device, importer); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 7f42cdda33a96..f562ee73f2aaa 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -83,11 +83,6 @@ class EpFactoryInternalImpl { "CreateSyncStreamForDevice is not implemented for this EP factory."); } - virtual 
OrtStatus* SetEnvironmentOptions(const OrtKeyValuePairs* /*options*/) noexcept { - // Default implementation does not handle any options. - return nullptr; - } - virtual OrtStatus* CreateExternalResourceImporterForDevice( _In_ const OrtEpDevice* /*ep_device*/, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 27c453b500017..ce9d06da75cb3 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -82,11 +82,6 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream); } - static OrtStatus* ORT_API_CALL SetEnvironmentOptions(_In_ OrtEpFactory* this_ptr, - _In_ const OrtKeyValuePairs* options) noexcept { - return static_cast(this_ptr)->SetEnvironmentOptions(options); - } - static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDevice( _In_ OrtEpFactory* this_ptr, _In_ const OrtEpDevice* ep_device, diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 5bf16439c0917..e1c883f960dde 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -35,7 +35,7 @@ static Status CreateOrtEnv() { Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON); OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "Default"}; Status status; - ort_env = OrtEnv::GetInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr); + ort_env = OrtEnv::GetOrCreateInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr).release(); if (!status.IsOK()) return status; // Keep the ort_env alive, don't free it. It's ok to leak the memory. 
#if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc index d841e70187f70..19d4df64cf2e8 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.cc @@ -12,19 +12,19 @@ EpFactoryVirtualGpu::EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, const OrtModelEditorApi& model_editor_api, + bool allow_virtual_devices, const OrtLogger& /*default_logger*/) : OrtEpFactory{}, ort_api_(ort_api), ep_api_(ep_api), model_editor_api_(model_editor_api), - allow_virtual_devices_{false} { + allow_virtual_devices_{allow_virtual_devices} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; - SetEnvironmentOptions = SetEnvironmentOptionsImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -69,19 +69,6 @@ const char* ORT_API_CALL EpFactoryVirtualGpu::GetVersionImpl(const OrtEpFactory* return factory->ep_version_.c_str(); } -/*static*/ -OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, - const OrtKeyValuePairs* options) noexcept { - auto* factory = static_cast(this_ptr); - const char* value = factory->ort_api_.GetKeyValue(options, "allow_virtual_devices"); - - if (value != nullptr) { - factory->allow_virtual_devices_ = strcmp(value, "1") == 0; - } - - return nullptr; -} - /*static*/ OrtStatus* ORT_API_CALL EpFactoryVirtualGpu::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* /*devices*/, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h index 
1d708d9b40963..fa2542e2ef8fc 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_factory.h @@ -16,7 +16,7 @@ class EpFactoryVirtualGpu : public OrtEpFactory { public: EpFactoryVirtualGpu(const OrtApi& ort_api, const OrtEpApi& ep_api, const OrtModelEditorApi& model_editor_api, - const OrtLogger& default_logger); + bool allow_virtual_devices, const OrtLogger& default_logger); ~EpFactoryVirtualGpu(); const OrtApi& GetOrtApi() const { return ort_api_; } @@ -66,9 +66,6 @@ class EpFactoryVirtualGpu : public OrtEpFactory { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; - static OrtStatus* ORT_API_CALL SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, - const OrtKeyValuePairs* options) noexcept; - const OrtApi& ort_api_; const OrtEpApi& ep_api_; const OrtModelEditorApi& model_editor_api_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc index 1e438e156828d..07c97697267c5 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_virt_gpu/ep_lib_entry.cc @@ -3,10 +3,9 @@ #include -#define ORT_API_MANUAL_INIT -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT +#include "core/session/onnxruntime_env_config_keys.h" +#include "../plugin_ep_utils.h" #include "ep_factory.h" // To make symbols visible on macOS/iOS @@ -23,6 +22,7 @@ extern "C" { EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, const OrtLogger* default_logger, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + EXCEPTION_TO_RETURNED_STATUS_BEGIN const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); const OrtEpApi* ep_api = ort_api->GetEpApi(); const OrtModelEditorApi* model_editor_api = 
ort_api->GetModelEditorApi(); @@ -30,18 +30,30 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* /*registration_name*/, co // Manual init for the C++ API Ort::InitApi(ort_api); - std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, *model_editor_api, - *default_logger); - if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP factory. Need at least one."); } + Ort::KeyValuePairs env_configs = Ort::GetEnvConfigEntries(); + + // Extract a config that determines whether creating virtual hardware devices is allowed. + // An application can allow an EP library to create virtual devices in two ways: + // 1. Use an EP library registration name that ends in the suffix ".virtual". If so, ORT will automatically + // set the config key "allow_virtual_devices" to "1" in the environment. + // 2. Directly set the config key "allow_virtual_devices" to "1" when creating the + // OrtEnv via OrtApi::CreateEnvWithOptions(). + const char* config_value = env_configs.GetValue(kOrtEnvAllowVirtualDevices); + const bool allow_virtual_devices = config_value != nullptr && strcmp(config_value, "1") == 0; + + std::unique_ptr factory = std::make_unique(*ort_api, *ep_api, *model_editor_api, + allow_virtual_devices, *default_logger); + factories[0] = factory.release(); *num_factories = 1; return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { diff --git a/onnxruntime/test/autoep/library/plugin_ep_utils.h b/onnxruntime/test/autoep/library/plugin_ep_utils.h index f7b8dc4d2be0d..7cd60b5afd9d4 100644 --- a/onnxruntime/test/autoep/library/plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/plugin_ep_utils.h @@ -103,6 +103,15 @@ struct FloatInitializer { std::vector data; }; +// Returns a lower case version of the input string. 
+inline std::string GetLowercaseString(std::string str) { + // https://en.cppreference.com/w/cpp/string/byte/tolower + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { + return static_cast<char>(std::tolower(c)); + }); + return str; +} + // Returns an entry in the session option configurations, or a default value if not present. inline OrtStatus* GetSessionConfigEntryOrDefault(const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val, diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 7415c5e138874..7b6679ffaf462 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -7,12 +7,15 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/onnxruntime_env_config_keys.h" #include "test/autoep/test_autoep_utils.h" #include "test/util/include/api_asserts.h" #include "test/util/include/asserts.h" extern std::unique_ptr<Ort::Env> ort_env; +extern "C" void ortenv_setup(); +extern "C" void ortenv_teardown(); namespace onnxruntime { namespace test { @@ -94,8 +97,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { const std::string& registration_name = "example_plugin_ep_virt_gpu"; const std::string& ep_name = Utils::example_ep_virt_gpu_info.ep_name; - auto get_plugin_ep_devices = [&]() -> std::vector<Ort::ConstEpDevice> { - std::vector<Ort::ConstEpDevice> all_ep_devices = ort_env->GetEpDevices(); + auto get_plugin_ep_devices = [&](Ort::Env& env) -> std::vector<Ort::ConstEpDevice> { + std::vector<Ort::ConstEpDevice> all_ep_devices = env.GetEpDevices(); std::vector<Ort::ConstEpDevice> ep_devices; std::copy_if(all_ep_devices.begin(), all_ep_devices.end(), std::back_inserter(ep_devices), @@ -123,7 +126,7 @@ TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); // Find ep devices for this EP. Should not get any. 
- std::vector<Ort::ConstEpDevice> ep_devices = get_plugin_ep_devices(); + std::vector<Ort::ConstEpDevice> ep_devices = get_plugin_ep_devices(*ort_env); ASSERT_EQ(ep_devices.size(), 0); ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); @@ -138,7 +141,7 @@ TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { ort_env->RegisterExecutionProviderLibrary(registration_name_for_virtual_devices.c_str(), library_path.c_str()); // Find ep devices for this EP. Should get a virtual gpu. - std::vector<Ort::ConstEpDevice> ep_devices = get_plugin_ep_devices(); + std::vector<Ort::ConstEpDevice> ep_devices = get_plugin_ep_devices(*ort_env); ASSERT_EQ(ep_devices.size(), 1); auto virt_gpu_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), @@ -166,6 +169,43 @@ TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { ort_env->UnregisterExecutionProviderLibrary(registration_name_for_virtual_devices.c_str()); } + + // Test using OrtApi::CreateEnvWithOptions to explicitly set a config that enables virtual devices. + // The EP should return an OrtEpDevice for a virtual GPU. + + ortenv_teardown(); // Release current OrtEnv as we need to recreate it. + + auto run_test = [&]() -> void { + // Create OrtEnv with config entry to enable virtual devices. + Ort::KeyValuePairs env_configs; + env_configs.Add(kOrtEnvAllowVirtualDevices, "1"); + + OrtEnvCreationOptions env_options{}; + env_options.version = ORT_API_VERSION; + env_options.logging_severity_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO; + env_options.log_id = "LoadUnloadPluginVirtGpuLibraryCxxApi"; + env_options.config_entries = env_configs.GetConst(); + + Ort::Env tmp_env(&env_options); + + // Register EP library. It should be able to extract the env config entry that enables virtual devices. + tmp_env.RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + // Find ep devices for this EP. Should get a virtual gpu. 
+ std::vector<Ort::ConstEpDevice> ep_devices = get_plugin_ep_devices(tmp_env); + ASSERT_EQ(ep_devices.size(), 1); + + auto virt_gpu_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [](Ort::ConstEpDevice& ep_device) { + return ep_device.Device().Type() == OrtHardwareDeviceType_GPU; + }); + + ASSERT_TRUE(is_hw_device_virtual(virt_gpu_ep_device->Device())); + tmp_env.UnregisterExecutionProviderLibrary(registration_name.c_str()); + }; + + EXPECT_NO_FATAL_FAILURE(run_test()); + ortenv_setup(); // Restore OrtEnv } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/test_env_creation.cc b/onnxruntime/test/shared_lib/test_env_creation.cc new file mode 100644 index 0000000000000..a7b5087fe373e --- /dev/null +++ b/onnxruntime/test/shared_lib/test_env_creation.cc @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/util/include/api_asserts.h" + +extern std::unique_ptr<Ort::Env> ort_env; +extern "C" void ortenv_setup(); +extern "C" void ortenv_teardown(); + +TEST(EnvCreation, CreateEnvWithOptions) { + const OrtApi& ort_api = Ort::GetApi(); + + // Basic error checking when user passes an invalid version for OrtEnvCreationOptions + { + OrtEnv* test_env = nullptr; + OrtEnvCreationOptions options{}; + options.version = 0; // Invalid! 
+ options.logging_severity_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; + options.log_id = "test logger"; + + Ort::Status status{ort_api.CreateEnvWithOptions(&options, &test_env)}; + + ASSERT_EQ(status.GetErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("version set equal to ORT_API_VERSION")); + } + + // Basic error checking when user passes an invalid log identifier to the API function + { + OrtEnv* test_env = nullptr; + OrtEnvCreationOptions options{}; + options.version = ORT_API_VERSION; + options.logging_severity_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; + options.log_id = nullptr; // Invalid! + + Ort::Status status{ort_api.CreateEnvWithOptions(&options, &test_env)}; + + ASSERT_EQ(status.GetErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("valid (non-null) log identifier string")); + } + + // Basic error checking when user passes an invalid logging severity level + { + OrtEnv* test_env = nullptr; + OrtEnvCreationOptions options{}; + options.version = ORT_API_VERSION; + options.logging_severity_level = 100; // Invalid! + options.log_id = "EnvCreation.CreateEnvWithOptions"; + + Ort::Status status{ort_api.CreateEnvWithOptions(&options, &test_env)}; + + ASSERT_EQ(status.GetErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(status.GetErrorMessage(), testing::HasSubstr("valid logging severity level value from " + "the OrtLoggingLevel enumeration")); + } + + // Create an OrtEnv with configuration entries. Use the CXX API. + + ortenv_teardown(); // Release current OrtEnv as we need to recreate it. + + auto run_test = [&]() -> void { + // Create OrtEnv with some dummy config entry. 
+ Ort::KeyValuePairs env_configs; + env_configs.Add("some_key", "some_val"); + + OrtEnvCreationOptions options{}; + options.version = ORT_API_VERSION; + options.logging_severity_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE; + options.log_id = "EnvCreation.CreateEnvWithOptions_2"; + options.config_entries = env_configs.GetConst(); + + Ort::Env tmp_env(&options); + + // Use EP API to retrieve environment configs and check contents + Ort::KeyValuePairs env_configs_2 = Ort::GetEnvConfigEntries(); + + auto configs_expected = env_configs.GetKeyValuePairs(); + auto configs_actual = env_configs_2.GetKeyValuePairs(); + ASSERT_EQ(configs_actual, configs_expected); + }; + + EXPECT_NO_FATAL_FAILURE(run_test()); + ortenv_setup(); // Restore OrtEnv +} diff --git a/orttraining/orttraining/python/orttraining_pybind_common.h b/orttraining/orttraining/python/orttraining_pybind_common.h index 6304fc4ab11ad..50922e638b457 100644 --- a/orttraining/orttraining/python/orttraining_pybind_common.h +++ b/orttraining/orttraining/python/orttraining_pybind_common.h @@ -20,7 +20,7 @@ using ExecutionProviderLibInfoMap = std::unordered_map ort_env); + ORTTrainingPythonEnv(OrtEnvPtr ort_env); const OrtEnv& GetORTEnv() const; OrtEnv& GetORTEnv(); @@ -46,7 +46,7 @@ class ORTTrainingPythonEnv { std::string GetExecutionProviderMapKey(const std::string& provider_type, size_t hash); - std::unique_ptr ort_env_; + OrtEnvPtr ort_env_; // NOTE: the EPs in the following map probably depends on dynamic EP DLLs that are going to be unloaded by OrtEnv's destructor if we delete OrtEnv ExecutionProviderMap execution_provider_instances_map_; std::vector available_training_eps_; diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 3d611a0881fdf..67e6e90726d38 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -109,7 +109,7 @@ bool 
GetProviderInstanceHash(const std::string& type, return false; } -ORTTrainingPythonEnv::ORTTrainingPythonEnv(std::unique_ptr ort_env) : ort_env_(std::move(ort_env)) { +ORTTrainingPythonEnv::ORTTrainingPythonEnv(OrtEnvPtr ort_env) : ort_env_(std::move(ort_env)) { const auto& builtinEPs = GetAvailableExecutionProviderNames(); available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end()); } @@ -173,7 +173,7 @@ static Status CreateOrtEnv() { Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON); OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "Default"}; Status status; - std::unique_ptr ort_env(OrtEnv::GetInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr)); + OrtEnvPtr ort_env = OrtEnv::GetOrCreateInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr); if (!status.IsOK()) return status; #if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD) if (!InitProvidersSharedLibrary()) {