diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 5866dd3e83624..4905df2a71867 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -611,3 +611,53 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return &the_global_api; } } + +struct ExternalEpLibaray { + ExternalEpLibaray(const std::string& libray_name) : libray_name_{libray_name} { + Ensure(); + } + onnxruntime::Provider* (*get_provider_api)(); + void (*create_ep_factories)(void*, const OrtApiBase*, void*, OrtEpFactory**, size_t, size_t*); + void (*set_session_option)(OrtSessionOptions*); + + void Ensure() { + if (handle_) + return; + auto& env = Provider_GetHost()->Env__Default(); + auto library_filename = PathString(LIBRARY_PREFIX) + PathString(libray_name_.begin(), libray_name_.end()) + LIBRARY_EXTENSION; + auto full_path = env.GetRuntimePath() + library_filename; + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "GetProvider", (void**)&get_provider_api)); + } + + void Clear() { + if (handle_) { + auto& env = Provider_GetHost()->Env__Default(); + auto status = env.UnloadDynamicLibrary(handle_); + vai_assert(status.IsOK(), status.ErrorMessage()); + handle_ = nullptr; + } + } + + private: + std::string libray_name_; + void* handle_{}; +}; +static std::unordered_map> g_external_ep_libaries; + +std::unique_ptr +CreateExecutionProviderFromAnotherEp(const std::string& lib, const OrtSessionOptions& session_options, + std::unordered_map& provider_options) { + auto it = g_external_ep_libaries.find(lib); + if (it == g_external_ep_libaries.end()) { + it = g_external_ep_libaries.emplace(lib, std::make_unique(lib)).first; + } + auto ep_lib = it->second.get(); + auto get_provider_func = ep_lib->get_provider_api; + auto provider = get_provider_func(); + std::unique_ptr ret; + provider->Initialize(); + std::ignore = provider->CreateIExecutionProvider(nullptr, nullptr, 0, const_cast(provider_options), session_options, *((OrtLogger*)nullptr), ret); + + return ret; +} \ No newline at end of file diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 7791ea430054a..567f2cb4b39e3 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -6,10 +6,12 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/provider_options.h" +#include "core/framework/execution_provider.h" #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" #include +#include void initialize_vitisai_ep(); void deinitialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); @@ -40,3 +42,6 @@ using EventInfo = std::tuple< void profiler_collect( std::vector& api_events, std::vector& kernel_events); +std::unique_ptr +CreateExecutionProviderFromAnotherEp(const std::string& lib, const OrtSessionOptions& session_options, + std::unordered_map& provider_options); diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 50f924e468ed0..e1a3ca43e162e 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -7,7 +7,6 @@ #include #include #include - #include "vaip/global_api.h" #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" @@ -57,6 +56,10 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const } } + auto it = provider_options.find("external_ep_libray"); + if (it != provider_options.end()) { + return CreateExecutionProviderFromAnotherEp(it->second, session_options, provider_options); + } auto ep_instance = std::make_unique(provider_options); ep_instance->SetLogger(reinterpret_cast(&session_logger)); return ep_instance;