-
Notifications
You must be signed in to change notification settings - Fork 314
AMD RyzenAI EP Support #1935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AMD RyzenAI EP Support #1935
Changes from 3 commits
b4f80f2
b36b96d
4002f84
0d5b950
bfbd8f4
bd1e012
025cdba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| #include "../generators.h" | ||
| #include "../search.h" | ||
| #include "interface.h" | ||
| #include <filesystem> | ||
| #include <mutex> | ||
| #include <span> | ||
|
|
||
| #if !defined(_WIN32) | ||
| #include <dlfcn.h> | ||
| #endif | ||
|
|
||
| namespace Generators { | ||
| namespace RyzenAI { | ||
|
|
||
| static constexpr auto ep_path_env_key_ = "RYZENAI_EP_PATH"; | ||
| static constexpr auto ep_name_ = "RyzenAILightExecutionProvider"; | ||
| #if defined(_WIN32) | ||
| static constexpr auto ep_filename_ = "onnxruntime_providers_ryzenai.dll"; | ||
| #else | ||
| static constexpr auto ep_filename_ = "onnxruntime_providers_ryzenai.so"; | ||
| #endif | ||
| static constexpr auto func_custom_ops_ = "RyzenAI_RegisterCustomOps"; | ||
| static constexpr auto func_shutdown_ = "RyzenAI_Shutdown"; | ||
|
|
||
| static Ort::Allocator* ort_allocator_{}; | ||
|
|
||
| struct Memory : DeviceBuffer { | ||
| Memory(size_t size) : owned_{true} { | ||
| size_in_bytes_ = size; | ||
| p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_)); | ||
| } | ||
|
|
||
| Memory(void* p, size_t size) : owned_{false} { | ||
| size_in_bytes_ = size; | ||
| p_cpu_ = p_device_ = static_cast<uint8_t*>(p); | ||
| } | ||
|
|
||
| ~Memory() override { | ||
| if (owned_) | ||
| ort_allocator_->Free(p_device_); | ||
| } | ||
|
|
||
| const char* GetType() const override { return "RyzenAI"; } | ||
|
|
||
| void AllocateCpu() override {} | ||
| void CopyDeviceToCpu() override {} | ||
| void CopyCpuToDevice() override {} | ||
|
|
||
| void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override { | ||
| CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes); | ||
| } | ||
|
|
||
| void Zero() override { | ||
| memset(p_device_, 0, size_in_bytes_); | ||
| } | ||
|
baijumeswani marked this conversation as resolved.
|
||
|
|
||
| bool owned_; | ||
| }; | ||
|
|
||
| struct Interface : RyzenAIInterface { | ||
| Interface() { | ||
| // If already loaded then nothing to do | ||
| #if defined(_WIN32) | ||
| if (GetModuleHandleA(ep_filename_)) | ||
| return; | ||
| #else | ||
| if (dlopen(ep_filename_, RTLD_NOLOAD | RTLD_NOW)) | ||
| return; | ||
| #endif | ||
|
|
||
| std::error_code ec; | ||
|
|
||
| ep_path_ = GetEnv(ep_path_env_key_); | ||
|
|
||
| #if defined(_WIN32) | ||
| if (ep_path_.empty()) { | ||
| wchar_t buffer[MAX_PATH + 1] = {0}; | ||
| const auto len = sizeof(buffer) / sizeof(buffer[0]); | ||
|
|
||
| if (MEMORY_BASIC_INFORMATION mbi; VirtualQuery(Ort::api->RegisterExecutionProviderLibrary, &mbi, sizeof(mbi))) | ||
| if (HMODULE mod = (HMODULE)mbi.AllocationBase; GetModuleFileNameW(mod, buffer, len)) | ||
| if (const auto dir = std::filesystem::path{buffer}.remove_filename(); !dir.empty()) | ||
| if (auto path = dir / ep_filename_; std::filesystem::exists(path, ec)) | ||
| ep_path_ = std::move(path); | ||
| } | ||
| #endif // _WIN32 | ||
|
|
||
| if (ep_path_.empty()) | ||
| ep_path_ = std::filesystem::current_path(ec) / ep_filename_; | ||
|
|
||
| Ort::ThrowOnError(Ort::api->RegisterExecutionProviderLibrary(GetOrtGlobals()->env_.get(), ep_name_, ep_path_.native().c_str())); | ||
| } | ||
|
|
||
| ~Interface() { | ||
| // TODO: make it linux compatible | ||
| #if defined(_WIN32) | ||
| if (const auto mod = GetModuleHandleA(ep_filename_)) | ||
| if (const auto func = reinterpret_cast<void (*)()>(GetProcAddress(mod, func_shutdown_))) | ||
| func(); | ||
| #endif // _WIN32 | ||
| } | ||
|
|
||
| void SetupProvider(OrtSessionOptions& session_options, const ProviderOptions& provider_options) override { | ||
| std::vector<const OrtEpDevice*> supported_devices; | ||
|
|
||
| { | ||
| const OrtEpDevice* const* devices = nullptr; | ||
| size_t ndevices = 0; | ||
|
|
||
| Ort::ThrowOnError(Ort::api->GetEpDevices(&GetOrtEnv(), &devices, &ndevices)); | ||
|
|
||
| for (const auto& device : std::span{devices, devices + ndevices}) | ||
|
Check failure on line 112 in src/ryzenai/interface.cpp
|
||
| if (std::string_view{ep_name_} == Ort::api->EpDevice_EpName(device) && | ||
| OrtHardwareDeviceType_NPU == Ort::api->HardwareDevice_Type(Ort::api->EpDevice_Device(device))) | ||
| supported_devices.push_back(device); | ||
| } | ||
|
|
||
| if (supported_devices.empty()) | ||
| throw std::runtime_error{"No RyzenAI devices detected"}; | ||
|
|
||
| { | ||
| std::vector<const char*> ep_keys, ep_values; | ||
|
|
||
| for (auto& option : provider_options) { | ||
| ep_keys.emplace_back(option.first.c_str()); | ||
| ep_values.emplace_back(option.second.c_str()); | ||
| } | ||
|
|
||
| // this call merges provider_options into session_options | ||
| Ort::ThrowOnError(Ort::api->SessionOptionsAppendExecutionProvider_V2(&session_options, | ||
| &GetOrtEnv(), supported_devices.data(), supported_devices.size(), | ||
| ep_keys.data(), ep_values.data(), ep_keys.size())); | ||
| } | ||
|
|
||
| Ort::ThrowOnError(Ort::api->RegisterCustomOpsUsingFunction(&session_options, func_custom_ops_)); | ||
| } | ||
|
|
||
| DeviceType GetType() const override { return DeviceType::RyzenAI; } | ||
|
|
||
| void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override { | ||
| assert(!ort_allocator_); | ||
| ort_allocator_ = &allocator; | ||
| } | ||
|
|
||
| Ort::Allocator& GetAllocator() override { | ||
| return *ort_allocator_; | ||
| } | ||
|
|
||
| std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override { | ||
| return std::make_shared<Memory>(size); | ||
| } | ||
|
|
||
| std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override { | ||
| return std::make_shared<Memory>(p, size); | ||
| } | ||
|
|
||
| std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); } | ||
| std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); } | ||
|
|
||
| void Synchronize() override {} | ||
|
|
||
| private: | ||
| std::filesystem::path ep_path_; | ||
| }; | ||
|
|
||
| static std::unique_ptr<Interface> interface_; | ||
|
|
||
| } // namespace RyzenAI | ||
|
|
||
| void RyzenAIInterface::Shutdown() { | ||
|
baijumeswani marked this conversation as resolved.
|
||
| RyzenAI::interface_.reset(); | ||
| } | ||
|
|
||
| RyzenAIInterface* GetRyzenAIInterface() { | ||
| static std::once_flag once; | ||
|
|
||
| std::call_once(once, []() { | ||
| RyzenAI::interface_ = std::make_unique<RyzenAI::Interface>(); | ||
| }); | ||
|
|
||
| return RyzenAI::interface_.get(); | ||
| } | ||
|
|
||
| } // namespace Generators | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| #pragma once | ||
|
|
||
| namespace Generators { | ||
|
|
||
| struct RyzenAIInterface : DeviceInterface { | ||
| using ProviderOptions = std::vector<std::pair<std::string, std::string>>; | ||
|
|
||
| virtual void SetupProvider(OrtSessionOptions&, const ProviderOptions&) = 0; | ||
|
|
||
| static void Shutdown(); | ||
| }; | ||
|
|
||
| RyzenAIInterface* GetRyzenAIInterface(); | ||
|
|
||
| } // namespace Generators |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -92,6 +92,7 @@ enum struct DeviceType { | |
| QNN, | ||
| OpenVINO, | ||
| NvTensorRtRtx, | ||
| RyzenAI, | ||
| MAX | ||
| }; | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.