Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,16 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n
*
* \nosubgrouping
*/
/*
* Public enum for compiled model compatibility across EPs.
*/
typedef enum OrtCompiledModelCompatibility {
OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0,
OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL,
OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION,
OrtCompiledModelCompatibility_EP_UNSUPPORTED,
} OrtCompiledModelCompatibility;

struct OrtApi {
/// \name OrtStatus
/// @{
Expand Down Expand Up @@ -6480,6 +6490,24 @@ struct OrtApi {
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);

/** \brief Validate a compiled model's compatibility information for one or more EP devices.
*
* \param[in] ep_devices The EP devices to validate against (e.g., from GetEpDevices).
* All devices must belong to the same execution provider.
* \param[in] num_ep_devices The number of EP devices provided.
* \param[in] compatibility_info The compatibility info string produced when the model was compiled.
* \param[out] out_status The resulting compatibility status for the EP devices.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(GetModelCompatibilityForEpDevices,
_In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices,
_In_ size_t num_ep_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* out_status);
};

/*
Expand Down
30 changes: 13 additions & 17 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,6 @@ typedef enum OrtEpDataLayout {
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
} OrtEpDataLayout;

/**
* \brief Enumeration describing the compatibility state of a compiled model relative to an execution provider.
*
* \since Version 1.23.
*/
typedef enum OrtCompiledModelCompatibility {
OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0,
OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL,
OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION,
OrtCompiledModelCompatibility_EP_UNSUPPORTED,
} OrtCompiledModelCompatibility;

/**
* \brief The OrtEp struct provides functions to implement for an execution provider.
* \since Version 1.22.
Expand Down Expand Up @@ -901,20 +889,28 @@ struct OrtEpFactory {
*/
ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr);

/** \brief Validate the compatibility of a compiled model with the execution provider.
/** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices.
*
* Given a compatibility info string produced during model compilation, the EP factory should determine whether the
* compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided
* must belong to the same execution provider instance that this factory creates.
*
* This function validates if a model produced with the supplied compatibility info string is supported by the underlying EP.
* The EP should check if a compiled model is compatible with the EP and set the model_compatibility parameter accordingly.
* The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when
* evaluating compatibility and set `model_compatibility` accordingly.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] compatibility_info The compatibility information string that will be used
* \param[out] model_compatibility OrtCompiledModelCompatibility enum value describing the compatibility of the model with the EP.
* \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP.
* \param[in] num_devices Number of entries in `devices`.
* \param[in] compatibility_info The compatibility information string produced when the model was compiled.
* \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility);

Expand Down
76 changes: 69 additions & 7 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3423,25 +3423,86 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env,
API_IMPL_END
}

// Validate compiled model compatibility info for specific EP device(s)
ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices,
_In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices,
_In_ size_t num_ep_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* out_status) {
API_IMPL_BEGIN
if (ep_devices == nullptr || num_ep_devices == 0 || compatibility_info == nullptr || out_status == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to GetModelCompatibilityForEpDevices.");
}

// Validate inputs and ensure all devices belong to the same EP/factory
const OrtEpFactory* first_factory = nullptr;
for (size_t i = 0; i < num_ep_devices; ++i) {
if (ep_devices[i] == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_devices contains a null entry.");
}
const OrtEpFactory* f = ep_devices[i]->GetMutableFactory();
if (i == 0) {
first_factory = f;
} else if (f != first_factory) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All ep_devices must be from the same execution provider.");
}
}

OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
OrtStatus* ort_status = nullptr;
OrtEpFactory* factory = ep_devices[0]->GetMutableFactory();
if (factory && factory->ValidateCompiledModelCompatibilityInfo) {
// collect hardware devices corresponding to the ep_devices
InlinedVector<const OrtHardwareDevice*> hardware_devices;
hardware_devices.reserve(num_ep_devices);
for (size_t i = 0; i < num_ep_devices; ++i) {
hardware_devices.push_back(ep_devices[i]->device);
}
ort_status = factory->ValidateCompiledModelCompatibilityInfo(factory,
hardware_devices.data(),
hardware_devices.size(),
compatibility_info,
&status);
}
if (ort_status != nullptr) {
return ToOrtStatus(ToStatusAndRelease(ort_status));
}

*out_status = status;
return nullptr;
API_IMPL_END
}

#else // defined(ORT_MINIMAL_BUILD)
ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/,
const ORTCHAR_T* /*path*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RegisterExecutionProviderLibrary is not supported in a minimal build.");
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/,
_In_ const char* /*registration_name*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "UnregisterExecutionProviderLibrary is not supported in a minimal build.");
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::GetEpDevices, _In_ const OrtEnv* /*env*/,
_Outptr_ const OrtEpDevice* const** /*ep_devices*/, _Out_ size_t* /*num_ep_devices*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetEpDevices is not supported in a minimal build.");
API_IMPL_END
}

// Minimal build stub for GetModelCompatibilityForEpDevices to satisfy symbol references from the API table
ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices,
_In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/,
_In_ size_t /*num_ep_devices*/,
_In_ const char* /*compatibility_info*/,
_Out_ OrtCompiledModelCompatibility* /*out_status*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetModelCompatibilityForEpDevices is not supported in a minimal build.");
API_IMPL_END
}

Expand All @@ -3453,7 +3514,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS
_In_reads_(num_op_options) const char* const* /*ep_option_vals*/,
size_t /*num_ep_options*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionOptionsAppendExecutionProvider_V2 is not supported in a minimal build.");
API_IMPL_END
}

Expand All @@ -3466,15 +3527,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession*
_Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/,
_In_ size_t /*num_values*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForInputs is not supported in a minimal build.");
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* /*ep_device*/,
_In_opt_ const OrtKeyValuePairs* /*stream_options*/,
_Outptr_ OrtSyncStream** /*ort_stream*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateSyncStreamForEpDevice is not supported in a minimal build.");
API_IMPL_END
}

Expand All @@ -3493,7 +3554,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/,
_In_opt_ OrtSyncStream* /*stream*/,
_In_ size_t /*num_tensors*/) {
API_IMPL_BEGIN
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CopyTensors is not supported in a minimal build.");
API_IMPL_END
}

Expand Down Expand Up @@ -4108,6 +4169,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::CopyTensors,

&OrtApis::Graph_GetModelMetadata,
&OrtApis::GetModelCompatibilityForEpDevices,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,13 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i
// OrtGraph
ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);
ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);

// EP Compatibility Info APIs
ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices,
_In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices,
_In_ size_t num_ep_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* out_status);
ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path);
ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version);
ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets);
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ class EpFactoryInternal : public OrtEpFactory {
return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream);
}

OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info,
OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility);
}

// Function ORT calls to release an EP instance.
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ class EpFactoryInternalImpl {
return false;
}

virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
ORT_UNUSED_PARAMETER(devices);
ORT_UNUSED_PARAMETER(num_devices);
ORT_UNUSED_PARAMETER(compatibility_info);
// Default implementation: mark as not applicable
*model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,15 @@ Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std
// Plugin EP did not provide an implementation of this function, so we call a default implementation.
return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
}
// Delegate to the EP factory's validation method
// Delegate to the EP factory's validation method, passing hardware devices derived from our ep_devices_
std::vector<const OrtHardwareDevice*> hardware_devices;
hardware_devices.reserve(ep_devices_.size());
for (const auto* ep_device : ep_devices_) {
hardware_devices.push_back(ep_device->device);
}
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_,
hardware_devices.data(),
hardware_devices.size(),
compatibility_info.c_str(),
&model_compatibility)));
return Status::OK();
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ struct ForwardToFactoryImpl {
}

static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
size_t num_devices,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept {
return static_cast<TFactory*>(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
return static_cast<TFactory*>(this_ptr)->ValidateCompiledModelCompatibilityInfo(devices, num_devices,
compatibility_info, model_compatibility);
}

static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr,
Expand Down
91 changes: 91 additions & 0 deletions onnxruntime/test/framework/ep_compatibility_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,94 @@ TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) {
EXPECT_TRUE(has_config);
EXPECT_EQ(config_value, "0");
}

// -----------------------------
// C API unit tests
// -----------------------------

namespace {

// Helper to create an OrtEnv and fetch a CPU EP device pointer via the C API.
// Returns a pair of (env, cpu_device). Caller releases env via api->ReleaseEnv.
static std::pair<OrtEnv*, const OrtEpDevice*> CreateEnvAndGetCpuEpDevice(const OrtApi* api) {
OrtEnv* env = nullptr;
EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpCompatCapiTest", &env));
EXPECT_NE(env, nullptr);

const OrtEpDevice* const* devices = nullptr;
size_t num_devices = 0;
EXPECT_EQ(nullptr, api->GetEpDevices(env, &devices, &num_devices));
EXPECT_GT(num_devices, 0u);

const OrtEpDevice* cpu_device = nullptr;
for (size_t i = 0; i < num_devices; ++i) {
const char* name = api->EpDevice_EpName(devices[i]);
if (name && std::string(name) == "CPUExecutionProvider") {
cpu_device = devices[i];
break;
}
}

// Fallback: just pick the first device if CPU wasn't found (environment-dependent builds).
if (!cpu_device && num_devices > 0) {
cpu_device = devices[0];
}

EXPECT_NE(cpu_device, nullptr);
return {env, cpu_device};
}

} // namespace

TEST(EpCompatibilityCapiTest, InvalidArguments) {
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
ASSERT_NE(api, nullptr);

OrtCompiledModelCompatibility out_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;

// ep_devices == nullptr
OrtStatus* st = api->GetModelCompatibilityForEpDevices(nullptr, 0, "info", &out_status);
ASSERT_NE(st, nullptr);
EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT);
api->ReleaseStatus(st);

// Prepare a valid device
auto [env, device] = CreateEnvAndGetCpuEpDevice(api);
ASSERT_NE(env, nullptr);
ASSERT_NE(device, nullptr);

// compatibility_info == nullptr
const OrtEpDevice* devices1[] = {device};
st = api->GetModelCompatibilityForEpDevices(devices1, 1, nullptr, &out_status);
ASSERT_NE(st, nullptr);
EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT);
api->ReleaseStatus(st);

// out_status == nullptr
st = api->GetModelCompatibilityForEpDevices(devices1, 1, "some-info", nullptr);
ASSERT_NE(st, nullptr);
EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT);
api->ReleaseStatus(st);

api->ReleaseEnv(env);
}

TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) {
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
ASSERT_NE(api, nullptr);

auto [env, device] = CreateEnvAndGetCpuEpDevice(api);
ASSERT_NE(env, nullptr);
ASSERT_NE(device, nullptr);

OrtCompiledModelCompatibility out_status = static_cast<OrtCompiledModelCompatibility>(-1);
const OrtEpDevice* devices2[] = {device};
OrtStatus* st = api->GetModelCompatibilityForEpDevices(devices2, 1, "arbitrary-compat-string", &out_status);
ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : "");

// For providers that don't implement validation, API should return EP_NOT_APPLICABLE.
EXPECT_EQ(out_status, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE);
api->ReleaseStatus(st);

api->ReleaseEnv(env);
}
Loading