diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a2937b6e82a27..a9deb2dd3e341 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -559,6 +559,7 @@ ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE(Node); ORT_DEFINE_RELEASE(Graph); ORT_DEFINE_RELEASE(Model); +ORT_DEFINE_RELEASE(KeyValuePairs) ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); #undef ORT_DEFINE_RELEASE @@ -675,6 +676,7 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; +struct EpDevice; struct Graph; struct Model; struct Node; @@ -737,6 +739,77 @@ struct ThreadingOptions : detail::Base { ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); }; +namespace detail { +template +struct KeyValuePairsImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + const char* GetValue(const char* key) const; + + // get the pairs in unordered_map. needs to copy to std::string so the hash works as expected + std::unordered_map GetKeyValuePairs() const; + // get the pairs in two vectors. entries will be 1:1 between keys and values. 
avoids copying to std::string + void GetKeyValuePairs(std::vector& keys, std::vector& values) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstKeyValuePairs = detail::KeyValuePairsImpl>; + +/** \brief Wrapper around ::OrtKeyValuePairs */ +struct KeyValuePairs : detail::KeyValuePairsImpl { + explicit KeyValuePairs(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl{p} {} + + explicit KeyValuePairs(); + explicit KeyValuePairs(const std::unordered_map& kv_pairs); + + void Add(const char* key, const char* value); + void Remove(const char* key); + + ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } +}; + +namespace detail { +template +struct HardwareDeviceImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + OrtHardwareDeviceType Type() const; + uint32_t VendorId() const; + uint32_t DeviceId() const; + const char* Vendor() const; + ConstKeyValuePairs Metadata() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtHardwareDevice + * \remarks HardwareDevice is always read-only for API users. + */ +using ConstHardwareDevice = detail::HardwareDeviceImpl>; + +namespace detail { +template +struct EpDeviceImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + const char* EpName() const; + const char* EpVendor() const; + ConstKeyValuePairs EpMetadata() const; + ConstKeyValuePairs EpOptions() const; + ConstHardwareDevice Device() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtEpDevice + * \remarks EpDevice is always read-only for API users. + */ +using ConstEpDevice = detail::EpDeviceImpl>; + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. 
@@ -768,7 +841,14 @@ struct Env : detail::Base { Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator - Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, + const std::unordered_map& options, + const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary + Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary + + std::vector GetEpDevices() const; }; /** \brief Custom Op Domain @@ -919,7 +999,7 @@ struct ConstSessionOptionsImpl : Base { std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry - std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); + std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const; }; template @@ -981,6 +1061,11 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); + SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, + const KeyValuePairs& ep_options); + SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, + const std::unordered_map& ep_options); + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps 
OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e41ef005349ac..57b4f1b3ead66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -479,6 +479,120 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom return *this; } +namespace detail { +template +inline const char* KeyValuePairsImpl::GetValue(const char* key) const { + return GetApi().GetKeyValue(this->p_, key); +} + +template +inline std::unordered_map KeyValuePairsImpl::GetKeyValuePairs() const { + std::unordered_map out; + + size_t num_pairs = 0; + const char* const* keys = nullptr; + const char* const* values = nullptr; + GetApi().GetKeyValuePairs(this->p_, &keys, &values, &num_pairs); + if (num_pairs > 0) { + out.reserve(num_pairs); + for (size_t i = 0; i < num_pairs; ++i) { + out.emplace(keys[i], values[i]); + } + } + + return out; +} + +template +inline void KeyValuePairsImpl::GetKeyValuePairs(std::vector& keys, + std::vector& values) const { + keys.clear(); + values.clear(); + + size_t num_pairs = 0; + const char* const* keys_ptr = nullptr; + const char* const* values_ptr = nullptr; + GetApi().GetKeyValuePairs(this->p_, &keys_ptr, &values_ptr, &num_pairs); + if (num_pairs > 0) { + keys.resize(num_pairs); + values.resize(num_pairs); + std::copy(keys_ptr, keys_ptr + num_pairs, keys.begin()); + std::copy(values_ptr, values_ptr + num_pairs, values.begin()); + } +} +} // namespace detail + +inline KeyValuePairs::KeyValuePairs() { + GetApi().CreateKeyValuePairs(&p_); 
+} + +inline KeyValuePairs::KeyValuePairs(const std::unordered_map& kv_pairs) { + GetApi().CreateKeyValuePairs(&p_); + for (const auto& kv : kv_pairs) { + GetApi().AddKeyValuePair(this->p_, kv.first.c_str(), kv.second.c_str()); + } +} + +inline void KeyValuePairs::Add(const char* key, const char* value) { + GetApi().AddKeyValuePair(this->p_, key, value); +} + +inline void KeyValuePairs::Remove(const char* key) { + GetApi().RemoveKeyValuePair(this->p_, key); +} + +namespace detail { +template +inline OrtHardwareDeviceType HardwareDeviceImpl::Type() const { + return GetApi().HardwareDevice_Type(this->p_); +} + +template +inline uint32_t HardwareDeviceImpl::VendorId() const { + return GetApi().HardwareDevice_VendorId(this->p_); +} + +template +inline uint32_t HardwareDeviceImpl::DeviceId() const { + return GetApi().HardwareDevice_DeviceId(this->p_); +} + +template +inline const char* HardwareDeviceImpl::Vendor() const { + return GetApi().HardwareDevice_Vendor(this->p_); +} + +template +inline ConstKeyValuePairs HardwareDeviceImpl::Metadata() const { + return ConstKeyValuePairs{GetApi().HardwareDevice_Metadata(this->p_)}; +} + +template +inline const char* EpDeviceImpl::EpName() const { + return GetApi().EpDevice_EpName(this->p_); +} + +template +inline const char* EpDeviceImpl::EpVendor() const { + return GetApi().EpDevice_EpVendor(this->p_); +} + +template +inline ConstKeyValuePairs EpDeviceImpl::EpMetadata() const { + return ConstKeyValuePairs(GetApi().EpDevice_EpMetadata(this->p_)); +} + +template +inline ConstKeyValuePairs EpDeviceImpl::EpOptions() const { + return ConstKeyValuePairs(GetApi().EpDevice_EpOptions(this->p_)); +} + +template +inline ConstHardwareDevice EpDeviceImpl::Device() const { + return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_)); +} +} // namespace detail + inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); if (strcmp(logid, "onnxruntime-node") == 0) { 
@@ -551,6 +665,33 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, return *this; } +inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, + const std::basic_string& path) { + ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str())); + return *this; +} + +inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) { + ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name)); + return *this; +} + +inline std::vector Env::GetEpDevices() const { + size_t num_devices = 0; + const OrtEpDevice* const* device_ptrs = nullptr; + ThrowOnError(GetApi().GetEpDevices(p_, &device_ptrs, &num_devices)); + + std::vector devices; + if (num_devices > 0) { + devices.reserve(num_devices); + for (size_t i = 0; i < num_devices; ++i) { + devices.emplace_back(device_ptrs[i]); + } + } + + return devices; +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } @@ -717,7 +858,8 @@ inline bool ConstSessionOptionsImpl::HasConfigEntry(const char* config_key) c } template -inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, const std::string& def) { +inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const char* config_key, + const std::string& def) const { if (!this->HasConfigEntry(config_key)) { return def; } @@ -955,6 +1097,53 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider( return *this; } +namespace { +template +void SessionOptionsAppendEP(detail::SessionOptionsImpl& session_options, + Env& env, const std::vector& ep_devices, + const std::vector& ep_options_keys, + const std::vector& ep_options_values) { + std::vector ep_devices_ptrs; + ep_devices_ptrs.reserve(ep_devices.size()); + for (const auto& ep_device : ep_devices) { + ep_devices_ptrs.push_back(ep_device); + } + + 
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_V2( + session_options, env, ep_devices_ptrs.data(), ep_devices_ptrs.size(), + ep_options_keys.data(), ep_options_values.data(), ep_options_keys.size())); +} +} // namespace + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( + Env& env, const std::vector& ep_devices, const KeyValuePairs& ep_options) { + std::vector ep_options_keys, ep_options_values; + ep_options.GetKeyValuePairs(ep_options_keys, ep_options_values); + + SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values); + + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( + Env& env, const std::vector& ep_devices, + const std::unordered_map& ep_options) { + std::vector ep_options_keys, ep_options_values; + ep_options_keys.reserve(ep_options.size()); + ep_options_values.reserve(ep_options.size()); + + for (const auto& [key, value] : ep_options) { + ep_options_keys.push_back(key.c_str()); + ep_options_values.push_back(value.c_str()); + } + + SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 28de183fde405..3242be817881a 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -19,10 +19,17 @@ struct OrtKeyValuePairs { Sync(); } void Add(const char* key, const char* value) { - return Add(std::string(key), std::string(value)); + // ignore if either are nullptr. 
+ if (key && value) { + Add(std::string(key), std::string(value)); + } } void Add(const std::string& key, const std::string& value) { + if (key.empty()) { // ignore empty keys + return; + } + auto iter_inserted = entries.insert({key, value}); bool inserted = iter_inserted.second; if (inserted) { @@ -37,6 +44,10 @@ struct OrtKeyValuePairs { // we don't expect this to be common. reconsider using std::vector if it turns out to be. void Remove(const char* key) { + if (key == nullptr) { + return; + } + auto iter = entries.find(key); if (iter != entries.end()) { auto key_iter = std::find(keys.begin(), keys.end(), iter->first.c_str()); diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index b5d9c81f250c2..a7fa1bccf3210 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -58,7 +58,7 @@ template & model_uri, const std::string& ep_to_select, std::optional library_path, - const OrtKeyValuePairs& provider_options, + const Ort::KeyValuePairs& ep_options, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, @@ -75,13 +75,15 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod if (auto_select) { // manually specify EP to select for now - ASSERT_ORTSTATUS_OK(Ort::GetApi().AddSessionConfigEntry(session_options, "test.ep_to_select", - ep_to_select.c_str())); + session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); + // add the provider options to the session options with the required prefix const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); - for (const auto& [key, value] : provider_options.entries) { + std::vector keys, values; + ep_options.GetKeyValuePairs(keys, values); + for (size_t i = 0, end = keys.size(); i < end; ++i) { // add the default value with prefix - session_options.AddConfigEntry((option_prefix + key).c_str(), 
value.c_str()); + session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); } } else { std::vector devices; @@ -92,9 +94,17 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod DefaultDeviceSelection(ep_to_select, devices); } - ASSERT_ORTSTATUS_OK(Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( - session_options, env, devices.data(), devices.size(), - provider_options.keys.data(), provider_options.values.data(), provider_options.entries.size())); + // C API. Test the C++ API because if it works the C API must also work. + // ASSERT_ORTSTATUS_OK(Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( + // session_options, env, devices.data(), devices.size(), + // provider_options.keys.data(), provider_options.values.data(), provider_options.entries.size())); + std::vector ep_devices; + ep_devices.reserve(devices.size()); + for (const auto* device : devices) { + ep_devices.emplace_back(device); + } + + session_options.AppendExecutionProvider_V2(*ort_env, ep_devices, ep_options); } // if session creation passes, model loads fine @@ -115,7 +125,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod namespace { void RunBasicTest(const std::string& ep_name, std::optional library_path, - const OrtKeyValuePairs& provider_options = {}, + const Ort::KeyValuePairs& provider_options = Ort::KeyValuePairs{}, const std::function&)>& select_devices = nullptr) { const auto run_test = [&](bool auto_select) { std::vector> inputs(1); @@ -149,7 +159,7 @@ TEST(AutoEpSelection, CpuEP) { #if defined(USE_CUDA) TEST(AutoEpSelection, CudaEP) { - OrtKeyValuePairs provider_options; + Ort::KeyValuePairs provider_options; provider_options.Add("prefer_nhwc", "1"); RunBasicTest(kCudaExecutionProvider, "onnxruntime_providers_cuda", provider_options); } @@ -157,7 +167,7 @@ TEST(AutoEpSelection, CudaEP) { #if defined(USE_DML) TEST(AutoEpSelection, DmlEP) { - OrtKeyValuePairs provider_options; + Ort::KeyValuePairs provider_options; 
provider_options.Add("disable_metacommands", "true"); // checking options are passed through const auto select_devices = [&](std::vector& devices) { @@ -204,16 +214,71 @@ TEST(AutoEpSelection, WebGpuEP) { } #endif -TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { +// tests for AutoEP selection related things in the API that aren't covered by the other tests. +TEST(AutoEpSelection, MiscApiTests) { + const OrtApi* c_api = &Ort::GetApi(); + + // nullptr and empty input to OrtKeyValuePairs + { + OrtKeyValuePairs* kvps = nullptr; + c_api->CreateKeyValuePairs(&kvps); + c_api->AddKeyValuePair(kvps, "key1", nullptr); // should be ignored + c_api->AddKeyValuePair(kvps, nullptr, "value1"); // should be ignored + c_api->RemoveKeyValuePair(kvps, nullptr); // should be ignored + + c_api->AddKeyValuePair(kvps, "", "value2"); // empty key should be ignored + ASSERT_EQ(c_api->GetKeyValue(kvps, ""), nullptr); + + c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is allowed + ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + } + + // construct KVP from std::unordered_map + { + std::unordered_map kvps; + kvps["key1"] = "value1"; + kvps["key2"] = "value2"; + Ort::KeyValuePairs ort_kvps(kvps); + ASSERT_EQ(ort_kvps.GetValue("key1"), std::string("value1")); + ASSERT_EQ(ort_kvps.GetValue("key2"), std::string("value2")); + } + + std::vector ep_devices = ort_env->GetEpDevices(); + + // explicit EP selection with Ort::KeyValuePairs for options + { + Ort::SessionOptions session_options; + Ort::KeyValuePairs ep_options; + ep_options.Add("option1", "true"); + session_options.AppendExecutionProvider_V2(*ort_env, {ep_devices[0]}, ep_options); + } + + // explicit EP selection with std::unordered_map for options + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + ep_options["option1"] = "true"; + session_options.AppendExecutionProvider_V2(*ort_env, {ep_devices[0]}, ep_options); + } +} + +namespace { +struct ExamplePluginInfo { + const std::filesystem::path library_path = 
#if _WIN32 - std::filesystem::path library_path = "example_plugin_ep.dll"; + "example_plugin_ep.dll"; #else - std::filesystem::path library_path = "libexample_plugin_ep.so"; + "libexample_plugin_ep.so"; #endif - const std::string registration_name = "example_ep"; +}; - Ort::SessionOptions session_options; +static const ExamplePluginInfo example_plugin_info; +} // namespace + +TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; OrtEnv* c_api_env = *ort_env; const OrtApi* c_api = &Ort::GetApi(); @@ -238,6 +303,48 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(c_api_env, registration_name.c_str())); } + +TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; + + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // should be one device for the example EP + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [&registration_name](Ort::ConstEpDevice& device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. 
expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue("version"), "0.1"); + + auto options = test_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); + + // the CPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + } // namespace test } // namespace onnxruntime