Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ ORT_DEFINE_RELEASE(ValueInfo);
ORT_DEFINE_RELEASE(Node);
ORT_DEFINE_RELEASE(Graph);
ORT_DEFINE_RELEASE(Model);
ORT_DEFINE_RELEASE(KeyValuePairs)
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);

#undef ORT_DEFINE_RELEASE
Expand Down Expand Up @@ -675,6 +676,7 @@ struct AllocatedFree {

struct AllocatorWithDefaultOptions;
struct Env;
struct EpDevice;
struct Graph;
struct Model;
struct Node;
Expand Down Expand Up @@ -737,6 +739,77 @@ struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
};

namespace detail {
template <typename T>
struct KeyValuePairsImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

const char* GetValue(const char* key) const;

// get the pairs in unordered_map. needs to copy to std::string so the hash works as expected
std::unordered_map<std::string, std::string> GetKeyValuePairs() const;
// get the pairs in two vectors. entries will be 1:1 between keys and values. avoids copying to std::string
void GetKeyValuePairs(std::vector<const char*>& keys, std::vector<const char*>& values) const;
};
} // namespace detail

// Const object holder that does not own the underlying object
using ConstKeyValuePairs = detail::KeyValuePairsImpl<Ort::detail::Unowned<const OrtKeyValuePairs>>;

/** \brief Wrapper around ::OrtKeyValuePair */
struct KeyValuePairs : detail::KeyValuePairsImpl<OrtKeyValuePairs> {
explicit KeyValuePairs(std::nullptr_t) {} ///< No instance is created
/// Take ownership of a pointer created by C API
explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl<OrtKeyValuePairs>{p} {}

explicit KeyValuePairs();
explicit KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs);

void Add(const char* key, const char* value);
void Remove(const char* key);

ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; }
};

namespace detail {
template <typename T>
struct HardwareDeviceImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

OrtHardwareDeviceType Type() const;
uint32_t VendorId() const;
uint32_t DeviceId() const;
const char* Vendor() const;
ConstKeyValuePairs Metadata() const;
};
} // namespace detail

/** \brief Wrapper around ::OrtHardwareDevice
* \remarks HardwareDevice is always read-only for API users.
*/
using ConstHardwareDevice = detail::HardwareDeviceImpl<Ort::detail::Unowned<const OrtHardwareDevice>>;

namespace detail {
template <typename T>
struct EpDeviceImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

const char* EpName() const;
const char* EpVendor() const;
ConstKeyValuePairs EpMetadata() const;
ConstKeyValuePairs EpOptions() const;
ConstHardwareDevice Device() const;
};
} // namespace detail

/** \brief Wrapper around ::OrtEpDevice
* \remarks EpDevice is always read-only for API users.
*/
using ConstEpDevice = detail::EpDeviceImpl<Ort::detail::Unowned<const OrtEpDevice>>;

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down Expand Up @@ -768,7 +841,14 @@ struct Env : detail::Base<OrtEnv> {

Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator

Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info,
const std::unordered_map<std::string, std::string>& options,
const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2

Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string<ORTCHAR_T>& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary
Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary

std::vector<ConstEpDevice> GetEpDevices() const;
};

/** \brief Custom Op Domain
Expand Down Expand Up @@ -919,7 +999,7 @@ struct ConstSessionOptionsImpl : Base<T> {

std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const;
};

template <typename T>
Expand Down Expand Up @@ -981,6 +1061,11 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options = {});

SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
const KeyValuePairs& ep_options);
SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
const std::unordered_map<std::string, std::string>& ep_options);

SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
Expand Down
191 changes: 190 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,120 @@
return *this;
}

namespace detail {
template <typename T>
inline const char* KeyValuePairsImpl<T>::GetValue(const char* key) const {
return GetApi().GetKeyValue(this->p_, key);
}

template <typename T>
inline std::unordered_map<std::string, std::string> KeyValuePairsImpl<T>::GetKeyValuePairs() const {
std::unordered_map<std::string, std::string> out;

size_t num_pairs = 0;
const char* const* keys = nullptr;
const char* const* values = nullptr;
GetApi().GetKeyValuePairs(this->p_, &keys, &values, &num_pairs);
if (num_pairs > 0) {
out.reserve(num_pairs);
for (size_t i = 0; i < num_pairs; ++i) {
out.emplace(keys[i], values[i]);
}
}

return out;
}

template <typename T>
inline void KeyValuePairsImpl<T>::GetKeyValuePairs(std::vector<const char*>& keys,
std::vector<const char*>& values) const {
keys.clear();
values.clear();

size_t num_pairs = 0;
const char* const* keys_ptr = nullptr;
const char* const* values_ptr = nullptr;
GetApi().GetKeyValuePairs(this->p_, &keys_ptr, &values_ptr, &num_pairs);
if (num_pairs > 0) {
keys.resize(num_pairs);
values.resize(num_pairs);
std::copy(keys_ptr, keys_ptr + num_pairs, keys.begin());
std::copy(values_ptr, values_ptr + num_pairs, values.begin());
}
}
} // namespace detail

inline KeyValuePairs::KeyValuePairs() {
GetApi().CreateKeyValuePairs(&p_);
}

inline KeyValuePairs::KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs) {
GetApi().CreateKeyValuePairs(&p_);
for (const auto& kv : kv_pairs) {
GetApi().AddKeyValuePair(this->p_, kv.first.c_str(), kv.second.c_str());
}
}

inline void KeyValuePairs::Add(const char* key, const char* value) {
GetApi().AddKeyValuePair(this->p_, key, value);
}

inline void KeyValuePairs::Remove(const char* key) {
GetApi().RemoveKeyValuePair(this->p_, key);
}

namespace detail {
template <typename T>
inline OrtHardwareDeviceType HardwareDeviceImpl<T>::Type() const {
return GetApi().HardwareDevice_Type(this->p_);
}

template <typename T>
inline uint32_t HardwareDeviceImpl<T>::VendorId() const {
return GetApi().HardwareDevice_VendorId(this->p_);
}

template <typename T>
inline uint32_t HardwareDeviceImpl<T>::DeviceId() const {
return GetApi().HardwareDevice_DeviceId(this->p_);
}

template <typename T>
inline const char* HardwareDeviceImpl<T>::Vendor() const {
return GetApi().HardwareDevice_Vendor(this->p_);
}

template <typename T>
inline ConstKeyValuePairs HardwareDeviceImpl<T>::Metadata() const {
return ConstKeyValuePairs{GetApi().HardwareDevice_Metadata(this->p_)};
}

template <typename T>
inline const char* EpDeviceImpl<T>::EpName() const {
return GetApi().EpDevice_EpName(this->p_);
}

template <typename T>
inline const char* EpDeviceImpl<T>::EpVendor() const {
return GetApi().EpDevice_EpVendor(this->p_);
}

template <typename T>
inline ConstKeyValuePairs EpDeviceImpl<T>::EpMetadata() const {
return ConstKeyValuePairs(GetApi().EpDevice_EpMetadata(this->p_));
}

template <typename T>
inline ConstKeyValuePairs EpDeviceImpl<T>::EpOptions() const {
return ConstKeyValuePairs(GetApi().EpDevice_EpOptions(this->p_));
}

template <typename T>
inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_));
}
} // namespace detail

inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
if (strcmp(logid, "onnxruntime-node") == 0) {
Expand Down Expand Up @@ -551,6 +665,33 @@
return *this;
}

inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name,
const std::basic_string<ORTCHAR_T>& path) {
ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str()));
return *this;
}

inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) {
ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name));
return *this;
}

inline std::vector<ConstEpDevice> Env::GetEpDevices() const {
size_t num_devices = 0;
const OrtEpDevice* const* device_ptrs = nullptr;
ThrowOnError(GetApi().GetEpDevices(p_, &device_ptrs, &num_devices));

std::vector<ConstEpDevice> devices;
if (num_devices > 0) {
devices.reserve(num_devices);
for (size_t i = 0; i < num_devices; ++i) {
devices.emplace_back(device_ptrs[i]);
}
}

return devices;
}

inline CustomOpDomain::CustomOpDomain(const char* domain) {
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
}
Expand Down Expand Up @@ -717,7 +858,8 @@
}

template <typename T>
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key,
const std::string& def) const {
if (!this->HasConfigEntry(config_key)) {
return def;
}
Expand Down Expand Up @@ -955,6 +1097,53 @@
return *this;
}

namespace {

Check warning on line 1100 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use unnamed namespaces in header files. See https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces for more information. [build/namespaces_headers] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:1100: Do not use unnamed namespaces in header files. See https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces for more information. [build/namespaces_headers] [4]
template <typename T>
void SessionOptionsAppendEP(detail::SessionOptionsImpl<T>& session_options,
Env& env, const std::vector<ConstEpDevice>& ep_devices,
const std::vector<const char*>& ep_options_keys,
const std::vector<const char*>& ep_options_values) {
std::vector<const OrtEpDevice*> ep_devices_ptrs;
ep_devices_ptrs.reserve(ep_devices.size());
for (const auto& ep_device : ep_devices) {
ep_devices_ptrs.push_back(ep_device);
}

ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_V2(
session_options, env, ep_devices_ptrs.data(), ep_devices_ptrs.size(),
ep_options_keys.data(), ep_options_values.data(), ep_options_keys.size()));
}
} // namespace

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
Env& env, const std::vector<ConstEpDevice>& ep_devices, const KeyValuePairs& ep_options) {
std::vector<const char*> ep_options_keys, ep_options_values;
ep_options.GetKeyValuePairs(ep_options_keys, ep_options_values);

SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);

return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
Env& env, const std::vector<ConstEpDevice>& ep_devices,
const std::unordered_map<std::string, std::string>& ep_options) {
std::vector<const char*> ep_options_keys, ep_options_values;
ep_options_keys.reserve(ep_options.size());
ep_options_values.reserve(ep_options.size());

for (const auto& [key, value] : ep_options) {
ep_options_keys.push_back(key.c_str());
ep_options_values.push_back(value.c_str());
}

SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);

return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/core/session/abi_key_value_pairs.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ struct OrtKeyValuePairs {
Sync();
}
void Add(const char* key, const char* value) {
return Add(std::string(key), std::string(value));
// ignore if either are nullptr.
if (key && value) {
Add(std::string(key), std::string(value));
}
}

void Add(const std::string& key, const std::string& value) {
if (key.empty()) { // ignore empty keys
return;
}

auto iter_inserted = entries.insert({key, value});
bool inserted = iter_inserted.second;
if (inserted) {
Expand All @@ -37,6 +44,10 @@ struct OrtKeyValuePairs {

// we don't expect this to be common. reconsider using std::vector if it turns out to be.
void Remove(const char* key) {
if (key == nullptr) {
return;
}

auto iter = entries.find(key);
if (iter != entries.end()) {
auto key_iter = std::find(keys.begin(), keys.end(), iter->first.c_str());
Expand Down
Loading
Loading