diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 580fbfbdba0b0..d3dbb8c065ec8 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -79,7 +79,10 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_vitisai_ep_with_error_handling)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func); std::vector>* (*compile_onnx_model_vitisai_ep_v3)( - const std::filesystem::path& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func); + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func); + std::vector>* (*compile_onnx_model_vitisai_ep_v4)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options, void* status, vaip_core::error_report_func func, const onnxruntime::logging::Logger& logger); + void (*vaip_execution_provider_deletor)(std::vector>*) noexcept = [](std::vector>* p) noexcept { delete p; }; uint32_t (*vaip_get_version)(); void (*create_ep_context_nodes)( const std::vector>& eps, @@ -126,7 +129,8 @@ struct OrtVitisAIEpAPI { auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_error_handling", (void**)&compile_onnx_model_vitisai_ep_with_error_handling); auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); auto status3 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_v3", (void**)&compile_onnx_model_vitisai_ep_v3); - if ((!status1.IsOK()) && (!status2.IsOK()) && (!status3.IsOK())) { + auto status4 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_v4", (void**)&compile_onnx_model_vitisai_ep_v4); + if ((!status1.IsOK()) && (!status2.IsOK()) && (!status3.IsOK()) && (!status4.IsOK())) { ::onnxruntime::LogRuntimeError(0, status2, __FILE__, static_cast(__FUNCTION__), __LINE__); ORT_THROW(status2); } @@ -137,6 +141,15 @@ struct OrtVitisAIEpAPI { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_on_run_start", (void**)&vitisai_ep_on_run_start)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "vitisai_ep_set_ep_dynamic_options", (void**)&vitisai_ep_set_ep_dynamic_options)); std::ignore = env.GetSymbolFromLibrary(handle_, "deinitialize_onnxruntime_vitisai_ep", (void**)&deinitialize_onnxruntime_vitisai_ep); + { + typedef void* (*vaip_get_execution_provider_deletor_func_t)(); + vaip_get_execution_provider_deletor_func_t vaip_get_execution_provider_deletor = nullptr; + auto status = env.GetSymbolFromLibrary(handle_, "vaip_get_execution_provider_deletor", + (void**)&vaip_get_execution_provider_deletor); + if (status.IsOK()) { + vaip_execution_provider_deletor = reinterpret_cast(vaip_get_execution_provider_deletor()); + }; + } } void Clear() { if (handle_) { @@ -174,10 +187,19 @@ void change_status_with_error(void* status_ptr, int error_code, const char* erro vaip_core::DllSafe>> compile_onnx_model( const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options) { auto model_path = graph_viewer.ModelPath(); - if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_v3) { + auto vaip_execution_provider_deletor = s_library_vitisaiep.vaip_execution_provider_deletor; + if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_v4) { + Status status = Status::OK(); + auto status_ptr = reinterpret_cast(&status); + auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_v4(model_path.u8string(), graph_viewer.GetGraph(), options, status_ptr, change_status_with_error, logger), vaip_execution_provider_deletor); + if (!status.IsOK()) { + ORT_THROW(status); + } + return ret; + } else if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_v3) { Status status = Status::OK(); auto status_ptr = reinterpret_cast(&status); - auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_v3(model_path, graph_viewer.GetGraph(), options, status_ptr, change_status_with_error)); + auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_v3(model_path.u8string(), graph_viewer.GetGraph(), options, status_ptr, change_status_with_error), vaip_execution_provider_deletor); if (!status.IsOK()) { ORT_THROW(status); } @@ -185,13 +207,13 @@ vaip_core::DllSafe>> c } else if (s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling) { Status status = Status::OK(); auto status_ptr = reinterpret_cast(&status); - auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path.u8string(), graph_viewer.GetGraph(), options, status_ptr, change_status_with_error)); + auto ret = vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_vitisai_ep_with_error_handling(model_path.u8string(), graph_viewer.GetGraph(), options, status_ptr, change_status_with_error), vaip_execution_provider_deletor); if (!status.IsOK()) { ORT_THROW(status); } return ret; } else { - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path.u8string(), graph_viewer.GetGraph(), options)); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path.u8string(), graph_viewer.GetGraph(), options), vaip_execution_provider_deletor); } } diff --git a/onnxruntime/core/providers/vitisai/include/vaip/dll_safe.h b/onnxruntime/core/providers/vitisai/include/vaip/dll_safe.h index 27bc3ab63187c..a18902c5404be 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/dll_safe.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/dll_safe.h @@ -17,7 +17,9 @@ class DllSafe { : value_{value}, deleter_{[](T* value) noexcept { std::default_delete()(value); }} {} - + explicit DllSafe(T* value, void (*deleter)(T*) noexcept) + : value_{value}, deleter_{deleter} { + } explicit DllSafe(T&& value) : DllSafe(new T(std::move(value))) {} explicit DllSafe(const T& value) : DllSafe(new T(value)) {}