diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 60f115ca50da4..14f12c906f11a 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -42,6 +42,35 @@ using namespace onnxruntime; #define LIBRARY_EXTENSION ".so" #endif +/// @brief Gets the path of directory containing the dynamic library that contains the address. +/// @param address An address of a function or variable in the dynamic library. +/// @return The path of the directory containing the dynamic library, or an empty string if the path cannot be determined. +static onnxruntime::PathString GetDynamicLibraryLocationByAddress(const void* address) { +#ifdef _WIN32 + HMODULE moduleHandle; + if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(address), &moduleHandle)) { + return {}; + } + std::wstring buffer; + for (std::uint32_t size{70}; size < 4096; size *= 2) { + buffer.resize(size, L'\0'); + const std::uint32_t requiredSize = ::GetModuleFileNameW(moduleHandle, buffer.data(), size); + if (requiredSize == 0) { + break; + } + if (requiredSize == size) { + continue; + } + buffer.resize(requiredSize); + return {std::move(buffer)}; + } +#else + std::ignore = address; +#endif + return {}; +} + vaip_core::OrtApiForVaip* create_org_api_hook(); struct OrtVitisAIEpAPI { void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); @@ -74,8 +103,20 @@ struct OrtVitisAIEpAPI { // this dll is already linked to the executable, normally a test program handle_ = reinterpret_cast(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll"))); if (!handle_) { + // First try loading with full path + auto library_filename = PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + if (std::filesystem::exists(full_path)) { + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + } else { + // Identify the path of the current dynamic library, and expect that onnxruntime_vitisai_ep is in the same directory. + PathString current_path = GetDynamicLibraryLocationByAddress(reinterpret_cast(create_org_api_hook)); + if (!current_path.empty()) { + const std::filesystem::path parent_path = std::filesystem::path{std::move(current_path)}.parent_path(); + PathString module_relative_full_path = PathString(parent_path / library_filename); + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(module_relative_full_path, true, &handle_)); + } + } } #else auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/vitisai/symbols.def +++ b/onnxruntime/core/providers/vitisai/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 6849bcfc21f88..1ef63588a1685 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -57,9 +57,6 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const } } - // Store pointer to session options as done in SessionOptionsAppendExecutionProvider_VitisAI - provider_options["session_options"] = std::to_string((uintptr_t)(void*)&session_options); - auto ep_instance = std::make_unique(provider_options); ep_instance->SetLogger(reinterpret_cast(&session_logger)); return ep_instance; @@ -89,8 +86,101 @@ struct VitisAI_Provider : Provider { void Initialize() override { initialize_vitisai_ep(); } // Called right before unloading the shared library void Shutdown() override { deinitialize_vitisai_ep(); } + + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t /*num_devices*/, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + auto ep_factory = CreateExecutionProviderFactory(&provider_options); + ep = ep_factory->CreateProvider(session_options, logger); + return Status::OK(); + } } g_provider; +struct VitisAIEpFactory : OrtEpFactory { + VitisAIEpFactory(const OrtApi& ort_api_in) + : ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + + static const char* GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ep_name; + } + + static const char* GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return vendor; + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return hardware_vendor_id; + } + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + VitisAIEpFactory* factory = static_cast(ep_factory); + + for (size_t i = 0; i < num_devices; ++i) { + const OrtHardwareDevice* hardware_device = devices[i]; + const std::uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(hardware_device); + const OrtHardwareDeviceType device_type = factory->ort_api.HardwareDevice_Type(hardware_device); + + if ((vendor_id != VitisAIEpFactory::hardware_vendor_id) || + (device_type != OrtHardwareDeviceType_NPU)) { + continue; + } + + if (num_ep_devices == max_ep_devices) { + return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP devices."); + } + + auto status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, hardware_device, nullptr, nullptr, + &ep_devices[num_ep_devices++]); + if (status != nullptr) { + return status; + } + } + return nullptr; + } + + static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) noexcept { + return CreateStatus(ORT_INVALID_ARGUMENT, "VitisAI EP factory does not support this method."); + } + + static void ReleaseEpImpl(OrtEpFactory*, OrtEp*) noexcept { + // no-op as we never create an EP here. + } + + const OrtApi& ort_api; + static constexpr const char* const ep_name{kVitisAIExecutionProvider}; + static constexpr std::uint32_t hardware_vendor_id{0x1022}; + static constexpr const char* const vendor{"AMD"}; +}; + } // namespace onnxruntime extern "C" { @@ -98,4 +188,21 @@ extern "C" { ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } + +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + factories[0] = std::make_unique(*ort_api).release(); + *num_factories = 1; + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} }