Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/migraphx_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const {
int current_device;
auto hip_err = hipGetDevice(&current_device);
if (hip_err == hipSuccess) {
ORT_ENFORCE(current_device == Info().id);
ORT_ENFORCE(current_device == Info().device.Id());
}
#endif
}
Expand Down
85 changes: 74 additions & 11 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,14 @@
const OrtSessionOptions& session_options,
const OrtLogger& logger,
std::unique_ptr<IExecutionProvider>& ep) override {
ORT_UNUSED_PARAMETER(num_devices);
const ConfigOptions* config_options = &session_options.GetConfigOptions();

std::array<const void*, 2> configs_array = {&provider_options, config_options};
const void* arg = reinterpret_cast<const void*>(&configs_array);
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
OrtMIGraphXProviderOptions migraphx_options;
UpdateProviderOptions(&migraphx_options, provider_options);

auto ep_factory = CreateExecutionProviderFactory(&migraphx_options);
ep = ep_factory->CreateProvider(session_options, logger);

return Status::OK();
Expand All @@ -181,26 +184,47 @@
const char* ep_name,
OrtHardwareDeviceType hw_type,
const OrtLogger& default_logger_in)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} {
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetVendorId = GetVendorIdImpl;
GetVersion = GetVersionImpl;

GetSupportedDevices = GetSupportedDevicesImpl;
CreateEp = CreateEpImpl;
ReleaseEp = ReleaseEpImpl;

CreateAllocator = CreateAllocatorImpl;
ReleaseAllocator = ReleaseAllocatorImpl;
CreateDataTransfer = CreateDataTransferImpl;

IsStreamAware = IsStreamAwareImpl;
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
}

// Returns the name for the EP. Each unique factory configuration must have a unique name.
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->ep_name.c_str();
}

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor.c_str();
}

static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->vendor_id;
}

static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const MigraphXEpFactory*>(this_ptr);
return factory->version.c_str();
}

// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
// An EP created with this factory is expected to be able to execute a model with *all* supported
// hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among
Expand All @@ -212,14 +236,14 @@
size_t num_devices,
OrtEpDevice** ep_devices,
size_t max_ep_devices,
size_t* p_num_ep_devices) {
size_t* p_num_ep_devices) noexcept {
size_t& num_ep_devices = *p_num_ep_devices;
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
const OrtHardwareDevice& device = *devices[i];
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type) {
// factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type &&
factory->ort_api.HardwareDevice_VendorId(&device) == 0x1002) {
OrtKeyValuePairs* ep_options = nullptr;
factory->ort_api.CreateKeyValuePairs(&ep_options);
ORT_API_RETURN_IF_ERROR(
Expand All @@ -237,20 +261,59 @@
_In_ size_t /*num_devices*/,
_In_ const OrtSessionOptions* /*session_options*/,
_In_ const OrtLogger* /*logger*/,
_Out_ OrtEp** /*ep*/) {
_Out_ OrtEp** /*ep*/) noexcept {
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method.");
}

static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept {
// no-op as we never create an EP here.
}

static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr,
const OrtMemoryInfo* /*memory_info*/,
const OrtKeyValuePairs* /*allocator_options*/,
OrtAllocator** allocator) noexcept {
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

*allocator = nullptr;
return factory->ort_api.CreateStatus(
ORT_INVALID_ARGUMENT,
"CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice.");
}

static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept {
// should never be called as we don't implement CreateAllocator
}

static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/,
OrtDataTransferImpl** data_transfer) noexcept {
*data_transfer = nullptr; // not implemented
return nullptr;
}

static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
return false;
}

static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr,
const OrtMemoryDevice* /*memory_device*/,
const OrtKeyValuePairs* /*stream_options*/,
OrtSyncStreamImpl** stream) noexcept {
auto* factory = static_cast<MigraphXEpFactory*>(this_ptr);

*stream = nullptr;
return factory->ort_api.CreateStatus(
ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false.");
}

const OrtApi& ort_api;
const OrtLogger& default_logger;
const std::string ep_name;
const std::string vendor{"AMD"};
const std::string version{"1.0.0"}; // MigraphX EP version

Check warning on line 313 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc:313: Add #include <string> for string [build/include_what_you_use] [4]

const uint32_t vendor_id{0x1002};
// Not using AMD vendor id 0x1002 so that OrderDevices in provider_policy_context.cc will default dml ep
const uint32_t vendor_id{0x9999};
const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice
};

Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/test/providers/migraphx/migraphx_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) {
ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true);
}

static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
// Access the underlying InferenceSession.
const OrtSession* ort_session = session;
const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
bool has_ep = false;

for (const auto& provider : s->GetRegisteredProviderTypes()) {
if (provider == ep_name) {
has_ep = true;
break;
}
}
return has_ep;
}

#if defined(WIN32)
// Tests autoEP feature to automatically select an EP that supports the GPU.
// Currently only works on Windows.
TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {
PathString model_name = ORT_TSTR("migraphx_basic_test.onnx");

Expand All @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) {

env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider);
}
#endif

} // namespace test
} // namespace onnxruntime
Loading