diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..1cac133ab0c2c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const { int current_device; auto hip_err = hipGetDevice(¤t_device); if (hip_err == hipSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } #endif } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 84c7fe3e4d4ab..923a39a7d2903 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -151,11 +151,14 @@ struct MIGraphX_Provider : Provider { const OrtSessionOptions& session_options, const OrtLogger& logger, std::unique_ptr& ep) override { + ORT_UNUSED_PARAMETER(num_devices); const ConfigOptions* config_options = &session_options.GetConfigOptions(); std::array configs_array = {&provider_options, config_options}; - const void* arg = reinterpret_cast(&configs_array); - auto ep_factory = CreateExecutionProviderFactory(&provider_options); + OrtMIGraphXProviderOptions migraphx_options; + UpdateProviderOptions(&migraphx_options, provider_options); + + auto ep_factory = CreateExecutionProviderFactory(&migraphx_options); ep = ep_factory->CreateProvider(session_options, logger); return Status::OK(); @@ -181,26 +184,47 @@ struct MigraphXEpFactory : OrtEpFactory { const char* ep_name, OrtHardwareDeviceType hw_type, const OrtLogger& default_logger_in) - : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} { + : ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. // Ex: a factory that supports NPU should have a different than a factory that supports GPU. - static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->version.c_str(); + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among @@ -212,14 +236,14 @@ struct MigraphXEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type) { - // factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + factory->ort_api.HardwareDevice_VendorId(&device) == 0x1002) { OrtKeyValuePairs* ep_options = nullptr; factory->ort_api.CreateKeyValuePairs(&ep_options); ORT_API_RETURN_IF_ERROR( @@ -237,20 +261,59 @@ struct MigraphXEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto* factory = static_cast(this_ptr); + + *allocator = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, + "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { + // should never be called as we don't implement CreateAllocator + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // not implemented + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + auto* factory = static_cast(this_ptr); + + *stream = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + } + const OrtApi& ort_api; const OrtLogger& default_logger; const std::string ep_name; const std::string vendor{"AMD"}; + const std::string version{"1.0.0"}; // MigraphX EP version - const uint32_t vendor_id{0x1002}; + // Not using AMD vendor id 0x1002 so that OrderDevices in provider_policy_context.cc will default dml ep + const uint32_t vendor_id{0x9999}; const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice }; diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index 5de7885a9452a..761ddf1975d15 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) { ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); } +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + +#if defined(WIN32) +// Tests autoEP feature to automatically select an EP that supports the GPU. +// Currently only works on Windows. TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("migraphx_basic_test.onnx"); @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider); } +#endif } // namespace test } // namespace onnxruntime