diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bedeeb972c3a7..81599e3b811b1 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -902,6 +902,16 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ +/* + * Public enum for compiled model compatibility across EPs. + */ +typedef enum OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, + OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, + OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, + OrtCompiledModelCompatibility_EP_UNSUPPORTED, +} OrtCompiledModelCompatibility; + struct OrtApi { /// \name OrtStatus /// @{ @@ -6480,6 +6490,24 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + + /** \brief Validate a compiled model's compatibility information for one or more EP devices. + * + * \param[in] ep_devices The EP devices to validate against (e.g., from GetEpDevices). + * All devices must belong to the same execution provider. + * \param[in] num_ep_devices The number of EP devices provided. + * \param[in] compatibility_info The compatibility info string produced when the model was compiled. + * \param[out] out_status The resulting compatibility status for the EP devices. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 620cb5fcf13cc..975f6b453a88d 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -482,18 +482,6 @@ typedef enum OrtEpDataLayout { OrtEpDataLayout_Default = OrtEpDataLayout_NCHW, } OrtEpDataLayout; -/** - * \brief Enumeration describing the compatibility state of a compiled model relative to an execution provider. - * - * \since Version 1.23. - */ -typedef enum OrtCompiledModelCompatibility { - OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, - OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, - OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, - OrtCompiledModelCompatibility_EP_UNSUPPORTED, -} OrtCompiledModelCompatibility; - /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -901,20 +889,28 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Validate the compatibility of a compiled model with the execution provider. + /** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices. + * + * Given a compatibility info string produced during model compilation, the EP factory should determine whether the + * compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided + * must belong to the same execution provider instance that this factory creates. 
* - * This function validates if a model produced with the supplied compatibility info string is supported by the underlying EP. - * The EP should check if a compiled model is compatible with the EP and set the model_compatibility parameter accordingly. + * The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when + * evaluating compatibility and set `model_compatibility` accordingly. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] compatibility_info The compatibility information string that will be used - * \param[out] model_compatibility OrtCompiledModelCompatibility enum value describing the compatibility of the model with the EP. + * \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP. + * \param[in] num_devices Number of entries in `devices`. + * \param[in] compatibility_info The compatibility information string produced when the model was compiled. + * \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
*/ ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1f491bc788870..ad0a1ad137f06 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3423,25 +3423,86 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, API_IMPL_END } +// Validate compiled model compatibility info for specific EP device(s) +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status) { + API_IMPL_BEGIN + if (ep_devices == nullptr || num_ep_devices == 0 || compatibility_info == nullptr || out_status == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to GetModelCompatibilityForEpDevices."); + } + + // Validate inputs and ensure all devices belong to the same EP/factory + const OrtEpFactory* first_factory = nullptr; + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_devices contains a null entry."); + } + const OrtEpFactory* f = ep_devices[i]->GetMutableFactory(); + if (i == 0) { + first_factory = f; + } else if (f != first_factory) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All ep_devices must be from the same execution provider."); + } + } + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + OrtStatus* ort_status = nullptr; + OrtEpFactory* factory = ep_devices[0]->GetMutableFactory(); + if (factory && 
factory->ValidateCompiledModelCompatibilityInfo) { + // collect hardware devices corresponding to the ep_devices + InlinedVector<const OrtHardwareDevice*> hardware_devices; + hardware_devices.reserve(num_ep_devices); + for (size_t i = 0; i < num_ep_devices; ++i) { + hardware_devices.push_back(ep_devices[i]->device); + } + ort_status = factory->ValidateCompiledModelCompatibilityInfo(factory, + hardware_devices.data(), + hardware_devices.size(), + compatibility_info, + &status); + } + if (ort_status != nullptr) { + return ToOrtStatus(ToStatusAndRelease(ort_status)); + } + + *out_status = status; + return nullptr; + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RegisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "UnregisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetEpDevices, _In_ const OrtEnv* /*env*/, _Outptr_ const OrtEpDevice* const** /*ep_devices*/, _Out_ size_t* /*num_ep_devices*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetEpDevices is not supported in a minimal build."); + API_IMPL_END +} + +// Minimal build stub for GetModelCompatibilityForEpDevices to satisfy symbol references from the API table 
+ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, + _In_ size_t /*num_ep_devices*/, + _In_ const char* /*compatibility_info*/, + _Out_ OrtCompiledModelCompatibility* /*out_status*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetModelCompatibilityForEpDevices is not supported in a minimal build."); API_IMPL_END } @@ -3453,7 +3514,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* /*ep_option_vals*/, size_t /*num_ep_options*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionOptionsAppendExecutionProvider_V2 is not supported in a minimal build."); API_IMPL_END } @@ -3466,7 +3527,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, _In_ size_t /*num_values*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForInputs is not supported in a minimal build."); API_IMPL_END } @@ -3474,7 +3535,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_ OrtSyncStream** /*ort_stream*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateSyncStreamForEpDevice is not supported in a minimal build."); API_IMPL_END } @@ -3493,7 +3554,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, _In_opt_ OrtSyncStream* /*stream*/, _In_ size_t /*num_tensors*/) { API_IMPL_BEGIN - return 
OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CopyTensors is not supported in a minimal build."); API_IMPL_END } @@ -4108,6 +4169,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::CopyTensors, &OrtApis::Graph_GetModelMetadata, + &OrtApis::GetModelCompatibilityForEpDevices, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b3b0036c68247..e62149d04a16c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -636,6 +636,13 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + +// EP Compatibility Info APIs +ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 23e5e95af2903..093bfce462d32 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -80,9 +80,11 @@ class EpFactoryInternal : public OrtEpFactory { 
return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } - OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { - return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); } // Function ORT calls to release an EP instance. diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 6c55730d83979..f29154d19c53c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,8 +62,13 @@ class EpFactoryInternalImpl { return false; } - virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, - _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo( + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); ORT_UNUSED_PARAMETER(compatibility_info); // Default implementation: mark as not applicable *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 3bfca62a4d011..c8829423fbe26 100644 --- 
a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -668,8 +668,15 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std // Plugin EP did not provide an implementation of this function, so we call a default implementation. return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); } - // Delegate to the EP factory's validation method + // Delegate to the EP factory's validation method, passing hardware devices derived from our ep_devices_ + std::vector<const OrtHardwareDevice*> hardware_devices; + hardware_devices.reserve(ep_devices_.size()); + for (const auto* ep_device : ep_devices_) { + hardware_devices.push_back(ep_device->device); + } ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + hardware_devices.data(), + hardware_devices.size(), compatibility_info.c_str(), &model_compatibility))); return Status::OK(); diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 29793b503c9d1..2cceb1d08d536 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -46,9 +46,12 @@ struct ForwardToFactoryImpl { } static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + size_t num_devices, const char* compatibility_info, OrtCompiledModelCompatibility* model_compatibility) noexcept { - return static_cast<TFactory*>(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + return static_cast<TFactory*>(this_ptr)->ValidateCompiledModelCompatibilityInfo(devices, num_devices, + compatibility_info, model_compatibility); } static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, diff --git 
a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index be97cf2620881..ee82d4683ab73 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -408,3 +408,94 @@ TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { EXPECT_TRUE(has_config); EXPECT_EQ(config_value, "0"); } + +// ----------------------------- +// C API unit tests +// ----------------------------- + +namespace { + +// Helper to create an OrtEnv and fetch a CPU EP device pointer via the C API. +// Returns a pair of (env, cpu_device). Caller releases env via api->ReleaseEnv. +static std::pair<OrtEnv*, const OrtEpDevice*> CreateEnvAndGetCpuEpDevice(const OrtApi* api) { + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpCompatCapiTest", &env)); + EXPECT_NE(env, nullptr); + + const OrtEpDevice* const* devices = nullptr; + size_t num_devices = 0; + EXPECT_EQ(nullptr, api->GetEpDevices(env, &devices, &num_devices)); + EXPECT_GT(num_devices, 0u); + + const OrtEpDevice* cpu_device = nullptr; + for (size_t i = 0; i < num_devices; ++i) { + const char* name = api->EpDevice_EpName(devices[i]); + if (name && std::string(name) == "CPUExecutionProvider") { + cpu_device = devices[i]; + break; + } + } + + // Fallback: just pick the first device if CPU wasn't found (environment-dependent builds). 
+ if (!cpu_device && num_devices > 0) { + cpu_device = devices[0]; + } + + EXPECT_NE(cpu_device, nullptr); + return {env, cpu_device}; +} + +} // namespace + +TEST(EpCompatibilityCapiTest, InvalidArguments) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtCompiledModelCompatibility out_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + + // ep_devices == nullptr + OrtStatus* st = api->GetModelCompatibilityForEpDevices(nullptr, 0, "info", &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // Prepare a valid device + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + // compatibility_info == nullptr + const OrtEpDevice* devices1[] = {device}; + st = api->GetModelCompatibilityForEpDevices(devices1, 1, nullptr, &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // out_status == nullptr + st = api->GetModelCompatibilityForEpDevices(devices1, 1, "some-info", nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + OrtCompiledModelCompatibility out_status = static_cast<OrtCompiledModelCompatibility>(-1); + const OrtEpDevice* devices2[] = {device}; + OrtStatus* st = api->GetModelCompatibilityForEpDevices(devices2, 1, "arbitrary-compat-string", &out_status); + ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : ""); + + // For providers that don't implement validation, API should return EP_NOT_APPLICABLE. 
+ EXPECT_EQ(out_status, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +}