Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 89 additions & 26 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,19 @@ struct Exception : std::exception {
throw Ort::Exception(string, code)
#endif

// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
// it's in a template so that we can define a global variable in a header and make
// it transparent to the users of the API.
template <typename T>
struct Global {
static const OrtApi* api_;
};

// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
template <typename T>
#ifdef ORT_API_MANUAL_INIT
const OrtApi* Global<T>::api_{};
inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }

// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
// required by C++ APIs.
// If the macro ORT_API_MANUAL_INIT is defined, no static initialization
// will be performed. Instead, users must call InitApi() before using the
// ORT C++ APIs..
//
// InitApi() sets the global API object using the default initialization
// logic. Users call this to initialize the ORT C++ APIs at a time that
// makes sense in their program.
inline void InitApi() noexcept;

// InitApi(const OrtApi*) is used by custom operator libraries that are not
// linked to onnxruntime. It sets the global API object, which is required
// by the ORT C++ APIs.
//
// Example mycustomop.cc:
//
Expand All @@ -107,22 +104,88 @@ inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(OR
// // ...
// }
//
inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
#else
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
// Please define ORT_API_MANUAL_INIT if it conerns you.
#pragma warning(disable : 26426)
inline void InitApi(const OrtApi* api) noexcept;
#endif
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)

namespace detail {
// This is used internally by the C++ API. This class holds the global
// variable that points to the OrtApi.
struct Global {
static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept {
// This block-level static will be initialized once when this function is
// first executed, delaying the call to DefaultInit() until it is first needed.
//
// When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls
// OrtGetApiBase()->GetApi(), which may result in a shared library being
// loaded.
//
// Using a block-level static instead of a class-level static helps
// avoid issues with static initialization order and dynamic libraries
// loading other dynamic libraries.
//
// This makes it safe to include the C++ API headers in a shared library
// that is delay loaded or delay loads its dependencies.
//
// This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when
// initializing static members, however.
static const OrtApi* api = DefaultInit();

if (newValue) {
api = newValue;
}

return api;
}

private:
// Has different definitions based on ORT_API_MANUAL_INIT
static const OrtApi* DefaultInit() noexcept;

#ifdef ORT_API_MANUAL_INIT
// Public APIs to set the OrtApi* to use.
friend void ::Ort::InitApi() noexcept;
friend void ::Ort::InitApi(const OrtApi*) noexcept;
#endif
};
} // namespace detail

#ifdef ORT_API_MANUAL_INIT

// See comments on declaration above for usage.
inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); }
inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); }

#ifdef _MSC_VER
// If you get a linker error about a mismatch here, you are trying to
// link two compilation units that have different definitions for
// ORT_API_MANUAL_INIT together. All compilation units must agree on the
// definition of ORT_API_MANUAL_INIT.
#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled")
#endif

inline const OrtApi* detail::Global::DefaultInit() noexcept {
// When ORT_API_MANUAL_INIT is defined, there's no default init that can
// be done.
return nullptr;
}

#else // ORT_API_MANUAL_INIT

#ifdef _MSC_VER
// If you get a linker error about a mismatch here, you are trying to link
// two compilation units that have different definitions for
// ORT_API_MANUAL_INIT together. All compilation units must agree on the
// definition of ORT_API_MANUAL_INIT.
#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled")
#endif

inline const OrtApi* detail::Global::DefaultInit() noexcept {
return OrtGetApiBase()->GetApi(ORT_API_VERSION);
}
#endif // ORT_API_MANUAL_INIT

/// This returns a reference to the ORT C API.
inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); }

/// <summary>
/// This function returns the onnxruntime version string
Expand Down
2 changes: 1 addition & 1 deletion js/node/src/inference_session_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
// create ONNX runtime env
Ort::InitApi();
ORT_NAPI_THROW_ERROR_IF(
Ort::Global<void>::api_ == nullptr, env,
&Ort::GetApi() == nullptr, env,
"Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
"ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ std::once_flag init;
} // namespace

void InitProviderOrtApi() {
std::call_once(init, []() { Ort::Global<void>::api_ = Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION); });
std::call_once(init, []() { Ort::InitApi(Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION)); });
}

} // namespace onnxruntime
} // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ int vitisai_ep_set_ep_dynamic_options(
struct MyCustomOpKernel : OpKernel {
MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
op_kernel_ =
op_.CreateKernel(&op_, Ort::Global<void>::api_, reinterpret_cast<const OrtKernelInfo*>(&info));
op_.CreateKernel(&op_, &Ort::GetApi(), reinterpret_cast<const OrtKernelInfo*>(&info));
}

~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); }
Expand Down Expand Up @@ -332,8 +332,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
InitProviderOrtApi();
set_version_info(the_global_api);
the_global_api.host_ = Provider_GetHost();
assert(Ort::Global<void>::api_ != nullptr);
the_global_api.ort_api_ = Ort::Global<void>::api_;
assert(&Ort::GetApi() != nullptr);
the_global_api.ort_api_ = &Ort::GetApi();
the_global_api.model_load = [](const std::string& filename) -> Model* {
auto model_proto = ONNX_NAMESPACE::ModelProto::Create();
auto& logger = logging::LoggingManager::DefaultLogger();
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/autoep/library/ep_arena.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ limitations under the License.
#include <mutex>
#include <set>

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include "ep_allocator.h"
#include "example_plugin_ep_utils.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
}

OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);
Ort::InitApi(api->GetApi(ORT_API_VERSION));
OrtStatus* result = nullptr;

ORT_TRY {
Expand Down
Loading