diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 7df3368ad4e0b..1bb7f219c9a45 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -179,7 +179,12 @@ class IExecutionProvider { /** Get the device id of current execution provider */ - virtual int GetDeviceId() const { return default_device_.Id(); }; + virtual int GetDeviceId() const { return default_device_.Id(); } + + /** + * Get the OrtDevice the execution provider was registered with. + */ + const OrtDevice& GetDevice() const { return default_device_; } /** Get execution provider's configuration options. diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e0542768aef2f..0dbf54a3ec99e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1984,13 +1984,15 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { // For now, this function only checks for invalid combination of DML EP with other EPs. // TODO: extend this function to check for other invalid combinations of EPs. common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const { - // DML EP is only allowed with CPU EP + // DML EP is not allowed with other GPU or NPU EPs. + // historical reason for this is unknown. relaxing the limit that it must only be used with the CPU EP to support + // scenarios where alternative EPs are CPU based (e.g. openvino). bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr; if (has_dml_ep) { - const auto& ep_list = execution_providers_.GetIds(); - for (const auto& ep : ep_list) { - if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue; - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP."); + for (const auto& ep : execution_providers_) { + if (ep->Type() != kDmlExecutionProvider && ep->GetDevice().Type() != OrtDevice::CPU) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can only be used with CPU EPs."); + } } } return Status::OK();