diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bf1dd6e20ce64..f22a4980865d0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6008,6 +6008,18 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); + /** \brief Returns the execution provider type (name) that this node is assigned to run on. + * Returns NULL if the node has not been assigned to any execution provider yet. + * + * \param[in] node The OrtNode instance. + * \param[out] out Output execution provider type and can be NULL if node has not been assigned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); + /// @} /// \name OrtRunOptions diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8583fac30cfbf..c6328618e1575 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -276,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } +const std::string& EpNode::GetEpType() const { + return node_.GetExecutionProviderType(); +} + // // EpValueInfo // diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 12fa082d3f354..e60bbc6d4dcda 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -208,6 +208,9 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the execution provider that this node is assigned to run on. + const std::string& GetEpType() const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 18b545483b38b..38f3c1de636af 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2967,6 +2967,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const char** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpType."); + } + + *out = ep_node->GetEpType().c_str(); + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3648,6 +3665,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, + &OrtApis::Node_GetEpType, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 75db44cb9e9ff..afc6b8d2b70a5 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -678,6 +678,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); +ORT_API_STATUS_IMPL(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index b498c40079f48..a5b46c74ecc21 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + const char* ep_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetEpType(fused_nodes[0], &ep_type)); + if (std::strncmp(ep_type, "example_ep", 11) != 0) { + return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); + } + // Associate the name of the fused node with our MulKernel. const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name));