diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 28ce4439fdc7e..e2b2aff2011fe 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -203,415 +203,331 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, #define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ + Ort::Status _status{(fn)}; \ + if (!_status.IsOK()) { \ + return _status; \ } \ } while (0) #define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) + ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ +#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{msg, ORT_FAIL}; \ + } \ } while (0) namespace OrtEpUtils { -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::GraphProto& 
graph_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. 
- auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; + try { + Ort::ConstGraph ort_graph{&graph}; + // + // Set GraphProto metadata + // + auto graph_name = ort_graph.GetName(); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + std::vector graph_inputs = ort_graph.GetInputs(); + std::vector graph_outputs = ort_graph.GetOutputs(); + + for (const auto& ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. 
+ for (const auto& ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. 
+ std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&value_infos, + &initializer_value_infos](Ort::ConstValueInfo ort_value_info, + /*out*/ std::optional& value_name_out) { + auto value_name = ort_value_info.GetName(); + + if (value_name_out) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return; // Already processed this OrtValueInfo. + } - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. 
- for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + bool is_required_graph_input = ort_value_info.IsRequiredGraphInput(); + bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput(); + bool is_graph_output = ort_value_info.IsGraphOutput(); + bool is_constant_initializer = ort_value_info.IsConstantInitializer(); + bool is_from_outer_scope = ort_value_info.IsFromOuterScope(); + + // Don't add graph inputs or graph outputs to GraphProto's 
list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, ort_value_info); + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo. + } + }; + + std::vector nodes = ort_graph.GetNodes(); + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. + for (const auto& ort_node : nodes) { + onnx::NodeProto* node_proto = graph_proto.add_node(); + + std::string node_name = ort_node.GetName(); + std::string node_domain = ort_node.GetDomain(); + std::string node_op_type = ort_node.GetOperatorType(); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + // Handle node attributes + std::vector ort_attrs = ort_node.GetAttributes(); + for (const auto& attr : ort_attrs) { + OrtOpAttrType attr_type = attr.GetType(); if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } - if (!attr_type_status.IsOK()) { - // Unsupported attribute type. 
- return attr_type_status; - } - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto)); } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + // Handle node subgraphs + std::vector ort_subgraphs = ort_node.GetSubgraphs(); + for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) { onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - attr_proto->set_name(subgraph_attr_name); attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { + // Handle node inputs + std::vector ort_inputs = ort_node.GetInputs(); + for (const auto& vi : ort_inputs) { + if (vi == nullptr) { // missing optional input. 
node_proto->add_input(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_input(*value_name); } - } - - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + // Handle implicit inputs to this node. + std::vector ort_implicit_inputs = ort_node.GetImplicitInputs(); + for (const auto& vi : ort_implicit_inputs) { + assert(vi != nullptr); + std::optional value_name; + collect_value_info(vi, value_name); } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { + // Handle node outputs + std::vector ort_outputs = ort_node.GetOutputs(); + for (const auto& vi : ort_outputs) { + if (vi == nullptr) { // missing optional output. node_proto->add_output(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_output(*value_name); } } - } - - // Add value_infos to GraphProto as ValueInfoProto objects. 
- for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. - std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const auto& [value_name, value_info] : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); } - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + // Add initializers to GraphProto as TensorProto objects. 
+ for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + Ort::ConstValue ort_value{nullptr}; + ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value)); - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; + assert(ort_value.IsTensor()); + const void* data = ort_value.GetTensorRawData(); + const size_t data_bytes = ort_value.GetTensorSizeInBytes(); - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* 
location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - length_entry->set_key("length"); - length_entry->set_value(std::to_string(data_bytes)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; } -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). 
- const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. - size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } + try { + // Check that OrtGraph is a top-level graph (no parent node). + Ort::ConstGraph ort_graph{&graph}; + Ort::ConstNode parent_node = ort_graph.GetParentNode(); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = ort_graph.GetOnnxIRVersion(); + model_proto.set_ir_version(ir_version); + + // Set operator sets. 
+ std::vector op_sets = ort_graph.GetOperatorSets(); + ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set."); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (const auto& op_set : op_sets) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(op_set.domain); + operator_set->set_version(op_set.version); + } - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func)); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + } catch (const Ort::Exception& ex) { + return Ort::Status(ex); + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_EP_FAIL); + } return Ort::Status{nullptr}; } -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + try { + Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); + ONNXType ort_onnx_type = ort_type_info.GetONNXType(); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); - const 
OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType ort_elem_type = ort_type_shape.GetElementType(); - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + size_t num_dims = ort_type_shape.GetDimensionsCount(); + std::vector ort_dims = ort_type_shape.GetShape(); - elem_type = ort_elem_type; - dims = std::move(ort_dims); + elem_type = ort_elem_type; + dims = std::move(ort_dims); - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_EP_FAIL}; } - return Ort::Status{nullptr}; } // Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). 
-static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - std::vector ort_dims; std::vector ort_dim_syms; ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; @@ -620,9 +536,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, ort_elem_type, ort_dims, ort_dim_syms)); - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); + value_info_proto.set_name(ort_value_info.GetName()); onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); @@ -652,213 +566,149 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) { + try { + std::string attr_name = attr.GetName(); + attr_proto.set_name(attr_name); - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + OrtOpAttrType attr_type = attr.GetType(); - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - 
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + int64_t i_val = 0; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto.set_i(i_val); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + std::vector i_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals)); + auto* ints = attr_proto.mutable_ints(); + ints->Assign(i_vals.begin(), i_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. 
- char* at = chars.data(); - char* end = at + chars.size(); - - while (at < end) { - char* str_begin = at; - - while (*at && at < end) { - at++; - } - - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. - } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + float f_val = 0.0f; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto.set_f(f_val); + break; } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + std::vector f_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals)); + auto* floats = attr_proto.mutable_floats(); + floats->Assign(f_vals.begin(), f_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + std::string* str = attr_proto.mutable_s(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + std::vector result; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result)); + auto* strs = attr_proto.mutable_strings(); + strs->Assign(result.begin(), result.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. 
+ + Ort::Value tensor; + ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor)); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { - 
attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - - onnx::TensorProto tensor_proto; - - // TensorProto as an attribute value doesn't require a name. - - OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); + auto shape = type_shape_info.GetShape(); - Ort::Value tensor(ort_value); + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } - // Get tensor type and shape info - Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + const void* data = tensor.GetTensorRawData(); + const size_t data_bytes = tensor.GetTensorSizeInBytes(); - // Get tensor type - ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); - size_t element_size = 0; - switch (element_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); - element_size = sizeof(float); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); - element_size = sizeof(uint8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); - element_size = sizeof(int8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); - element_size = sizeof(uint16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); - element_size = sizeof(int16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); - element_size = sizeof(int32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); - element_size = sizeof(int64_t); - 
break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); - element_size = sizeof(bool); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); - element_size = sizeof(double); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); - element_size = sizeof(uint32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); - element_size = sizeof(uint64_t); - break; - } - default: { - std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; } - - auto shape = type_shape_info.GetShape(); - - for (auto& dim : shape) { - tensor_proto.add_dims(dim); + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); } - - size_t element_count = type_shape_info.GetElementCount(); - size_t data_bytes = element_count * element_size; - const void* data = tensor.GetTensorData(); - - // Copy the Ortvalue to TensorProto as raw data - tensor_proto.set_raw_data(data, data_bytes); - - *(attr_proto.mutable_t()) = std::move(tensor_proto); - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 
13675ab447ab1..8de42151c99df 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -52,6 +52,7 @@ namespace Ort { * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() */ struct Exception : std::exception { + Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {} Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} OrtErrorCode GetOrtErrorCode() const { return code_; } @@ -612,33 +613,35 @@ namespace detail { inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); } ORT_DEFINE_RELEASE(Allocator); -ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(CustomOpDomain); -ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(Env); -ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(ExternalInitializerInfo); +ORT_DEFINE_RELEASE(Graph); +ORT_DEFINE_RELEASE(IoBinding); +ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(KeyValuePairs); ORT_DEFINE_RELEASE(LoraAdapter); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(Model); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(Node); +ORT_DEFINE_RELEASE(Op); +ORT_DEFINE_RELEASE(OpAttr); +ORT_DEFINE_RELEASE(PrepackedWeightsContainer); +ORT_DEFINE_RELEASE(RunOptions); ORT_DEFINE_RELEASE(Session); ORT_DEFINE_RELEASE(SessionOptions); -ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(SequenceTypeInfo); -ORT_DEFINE_RELEASE(MapTypeInfo); -ORT_DEFINE_RELEASE(TypeInfo); -ORT_DEFINE_RELEASE(Value); -ORT_DEFINE_RELEASE(ModelMetadata); -ORT_DEFINE_RELEASE(IoBinding); -ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(Status); ORT_DEFINE_RELEASE(SyncStream); -ORT_DEFINE_RELEASE(OpAttr); -ORT_DEFINE_RELEASE(Op); -ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(ThreadingOptions); 
+ORT_DEFINE_RELEASE(TypeInfo); +ORT_DEFINE_RELEASE(Value); ORT_DEFINE_RELEASE(ValueInfo); -ORT_DEFINE_RELEASE(Node); -ORT_DEFINE_RELEASE(Graph); -ORT_DEFINE_RELEASE(Model); -ORT_DEFINE_RELEASE(KeyValuePairs); -ORT_DEFINE_RELEASE(PrepackedWeightsContainer); + ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); @@ -764,6 +767,7 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; struct EpDevice; +struct ExternalInitializerInfo; struct Graph; struct Model; struct Node; @@ -881,6 +885,36 @@ struct PrepackedWeightsContainer : detail::Base { PrepackedWeightsContainer(); }; +namespace detail { +template +struct ConstExternalInitializerInfoImpl : Base { + using B = Base; + using B::B; + + // Wraps OrtApi::ExternalInitializerInfo_GetFilePath + const std::basic_string GetFilePath() const; + // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset + int64_t GetFileOffset() const; + // Wraps OrtApi::ExternalInitializerInfo_GetByteSize + size_t GetByteSize() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstExternalInitializerInfo = + detail::ConstExternalInitializerInfoImpl>; + +struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl { + using Base = detail::ConstExternalInitializerInfoImpl; + using Base::Base; + + explicit ExternalInitializerInfo(std::nullptr_t) {} + explicit ExternalInitializerInfo(OrtExternalInitializerInfo* p) + : detail::ConstExternalInitializerInfoImpl{p} {} + + ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; } +}; + namespace detail { template struct KeyValuePairsImpl : Ort::detail::Base { @@ -2429,15 +2463,46 @@ struct ArenaCfg : detail::Base { // Custom OPs (only needed to implement custom OPs) // +namespace detail { +// Need to define a templated ConstOpAttr with const members +template +struct ConstOpAttrImpl : Base { + using 
B = detail::Base; + using B::B; + + // Wraps OrtApi::OpAttr_GetName + std::string GetName() const; + // Wraps OrtApi::OpAttr_GetType + OrtOpAttrType GetType() const; + + // Wraps OrtApi::ReadAttr for a single value + // This does not support Tensor Attribute + // Call GetTensorAttributeAsOrtValue() instead. + template + Status GetValue(R& out) const; + + // Wraps OrtApi::ReadAttr for an array of values + template + Status GetValueArray(std::vector& out) const; + // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue + Status GetTensorAttributeAsOrtValue(Value&) const; +}; +} // namespace detail + +using ConstOpAttr = detail::ConstOpAttrImpl>; + /// /// This struct provides life time management for custom op attribute /// -struct OpAttr : detail::Base { - using Base = detail::Base; +struct OpAttr : detail::ConstOpAttrImpl { + using Base = detail::ConstOpAttrImpl; using Base::Base; + OpAttr() = default; // Enable storing it in the container for resize() explicit OpAttr(std::nullptr_t) {} OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); + + ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; } }; /** @@ -2783,7 +2848,7 @@ struct ShapeInferContext { Strings GetAttrStrings(const char* attr_name); private: - const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + ConstOpAttr GetAttrHdl(const char* attr_name) const; const OrtApi* ort_api_; OrtShapeInferContext* ctx_; std::vector input_shapes_; @@ -2934,48 +2999,114 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; +// Forward declaration to resolve circular dependency +// on ConstNode +struct ValueInfoConsumerProducerInfo; + namespace detail { template -struct ValueInfoImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstValueInfoImpl : Base { + using B = Base; using B::B; - std::string Name() const; + /// < A wrapper around OrtApi::GetValueInfoName + std::string GetName() const; + /// < A wrapper around OrtApi::GetValueInfoTypeInfo 
ConstTypeInfo TypeInfo() const; + ///< Wraps OrtApi::ValueInfo_GetProducerNode + ValueInfoConsumerProducerInfo GetProducerNode() const; + /// < A wrapper around OrtApi::ValueInfo_GetValueConsumers + std::vector GetConsumers() const; + /// < A wrapper around OrtApi::ValueInfo_GetInitializerValue + Status GetInitializer(ConstValue& value) const; + /// < A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo + Status GetExternalInitializerInfo(ExternalInitializerInfo& info) const; + /// < A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput + bool IsRequiredGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput + bool IsOptionalGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsGraphOutput + bool IsGraphOutput() const; + /// < A wrapper around OrtApi::ValueInfo_IsConstantInitializer + bool IsConstantInitializer() const; + /// < A wrapper around OrtApi::ValueInfo_IsFromOuterScope + bool IsFromOuterScope() const; }; } // namespace detail // Const object holder that does not own the underlying object -using ConstValueInfo = detail::ValueInfoImpl>; +using ConstValueInfo = detail::ConstValueInfoImpl>; /** \brief Wrapper around ::OrtValueInfo * */ -struct ValueInfo : detail::ValueInfoImpl { +struct ValueInfo : detail::ConstValueInfoImpl { + ValueInfo() = default; // Same thing as with nullptr explicit ValueInfo(std::nullptr_t) {} ///< No instance is created /// Take ownership of a pointer created by C API - explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl{p} {} +#if !defined(ORT_MINIMAL_BUILD) // Create ValueInfo for a tensor explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); - +#endif ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } }; +// Forward declaration +struct AttrNameSubgraph; + namespace detail { +// Forward decl template -struct NodeImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct 
ConstGraphImpl; + +template +struct ConstNodeImpl : Base { + using B = Base; using B::B; + + // GetInputs() const; + // GetOutputs() const; + // GetImplicitInputs() const; + // GetAttributes() const; + // GetSubgraphs() const; + // > GetGraph() const; + // >; + /** \brief Wrapper around ::OrtNode * */ -struct Node : detail::NodeImpl { - explicit Node(std::nullptr_t) {} ///< No instance is created - explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API +struct Node : detail::ConstNodeImpl { + Node() = default; // Same thing as with nullptr + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : ConstNodeImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) Node(const std::string& operator_name, const std::string& operator_domain, @@ -3002,22 +3133,78 @@ struct Node : detail::NodeImpl { #endif // !defined(ORT_MINIMAL_BUILD) }; +// Return struct for some of ValueInfo APIs. +// Must be declared after ConstNode is available. 
+struct ValueInfoConsumerProducerInfo { + ConstNode node; + // either producer output or consumer output index + // producer is unsigned only, output can be -1 + int64_t index; +}; + +// Represents a return value for Graph::GetOperatorSets() +struct OperatorSet { + std::string domain; + int64_t version; +}; + namespace detail { template -struct GraphImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstGraphImpl : Base { + using B = Base; + using B::B; + + // GetModelPath() const; + // GetOperatorSets() const; + // GetInputs() const; + // GetOutputs() const; + // GetInitializers() const; + // GetNodes() const; + // & nodes) const; + // +struct GraphImpl : ConstGraphImpl { + using B = ConstGraphImpl; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // & inputs); + // & outputs); + // >; + +// Return value for Node API +// Must be declared after ConstGraph +struct AttrNameSubgraph { + std::string attr_name; + ConstGraph sub_graph; +}; + /** \brief Wrapper around ::OrtGraph * */ @@ -3025,25 +3212,26 @@ struct Graph : detail::GraphImpl { explicit Graph(std::nullptr_t) {} ///< No instance is created explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + // >; namespace detail { template -struct ModelImpl : Ort::detail::Base { +struct ModelImpl : detail::Base { using B = Ort::detail::Base; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // >; +using UnownedModel = detail::ModelImpl>; /** \brief Wrapper around ::OrtModel * @@ -3055,10 +3243,9 @@ struct Model : detail::ModelImpl { explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + //< Wraps GetModelEditorApi().CreateModel() explicit Model(const std::vector& opsets); #endif - - ConstModel GetConst() const { return ConstModel{this->p_}; } }; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git 
a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 05c86ae4e0c58..2f01332686a6c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -571,6 +571,24 @@ inline PrepackedWeightsContainer::PrepackedWeightsContainer() { ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_)); } +namespace detail { + +template +inline const std::basic_string ConstExternalInitializerInfoImpl::GetFilePath() const { + return GetApi().ExternalInitializerInfo_GetFilePath(this->p_); +} + +template +inline int64_t ConstExternalInitializerInfoImpl::GetFileOffset() const { + return GetApi().ExternalInitializerInfo_GetFileOffset(this->p_); +} + +template +inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const { + return GetApi().ExternalInitializerInfo_GetByteSize(this->p_); +} +} // namespace detail + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -1759,7 +1777,7 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat #if !defined(ORT_MINIMAL_BUILD) inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { - ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model, options, &this->p_)); } // static @@ -2475,6 +2493,171 @@ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); } +namespace detail { + +template +constexpr OrtOpAttrType TypeToAttrType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_INT; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOAT; +} + 
+template +inline constexpr OrtOpAttrType TypeToAttrsType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_INTS; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOATS; +} + +inline Status CheckAttrType(const OrtOpAttr* attr, OrtOpAttrType requested_type) { + OrtOpAttrType type; + Ort::Status status(GetApi().OpAttr_GetType(attr, &type)); + if (!status.IsOK()) return status; + if (requested_type != type) { + std::string msg = "Attribute type mismatch: expected " + std::to_string(requested_type) + + ", but got " + std::to_string(type); + return Ort::Status(msg.c_str(), OrtErrorCode::ORT_INVALID_ARGUMENT); + } + return Ort::Status{}; +} + +inline size_t GetDataSize(const OrtOpAttr* attr, OrtOpAttrType attr_type) { + size_t result{}; + // Ignore the status here because we check the data type so the error should only be about + // the size + [[maybe_unused]] Status status{GetApi().ReadOpAttr(attr, attr_type, nullptr, 0, &result)}; + return result; +} + +template +Ort::Status GetNumericValue(const OrtOpAttr* attr, T& out) { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + size_t size{}; + return Ort::Status{GetApi().ReadOpAttr(attr, TypeToAttrType(), &out, sizeof(out), &size)}; +} + +template +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, T& out) { + return GetNumericValue(attr, out); + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + // Api deficiency when it comes to value arrays. 
It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + constexpr auto deduced_type = TypeToAttrsType(); + auto status = CheckAttrType(attr, deduced_type); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, deduced_type); + std::vector result; + if (size > 0) { + result.resize(size / sizeof(T)); + status = Status{GetApi().ReadOpAttr( + attr, deduced_type, result.data(), size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } +}; + +// Create GetValueImpl specializations for std::string +template <> +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, std::string& out) { + // Api deficiency when it comes to value arrays. It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + std::string result; + if (size > 0) { + result.resize(size); + status = Status{GetApi().ReadOpAttr( + attr, OrtOpAttrType::ORT_OP_ATTR_STRING, result.data(), size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + if (!status.IsOK()) return status; + + size_t total_buffer_size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + + // Create a temporary buffer to hold the string data + std::vector buffer(total_buffer_size); + status = Status{GetApi().ReadOpAttr(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS, buffer.data(), + total_buffer_size, &total_buffer_size)}; + if (!status.IsOK()) return status; + + std::vector result; + if (total_buffer_size > 
0) { + const char* data = buffer.data(); + const char* end = data + total_buffer_size; + while (data < end) { + result.emplace_back(data); + data += result.back().size() + 1; // Move past the null terminator + } + } + out.swap(result); + return status; + } +}; + +template +template +inline Status ConstOpAttrImpl::GetValue(R& out) const { + return GetValueImpl::GetValue(this->p_, out); +} + +template +template +inline Status ConstOpAttrImpl::GetValueArray(std::vector& out) const { + return GetValueImpl::GetValues(this->p_, out); +} + +template +inline Status ConstOpAttrImpl::GetTensorAttributeAsOrtValue(Value& out) const { + OrtValue* tensor_value = nullptr; + auto status = Status(GetApi().OpAttr_GetTensorAttributeAsOrtValue(this->p_, &tensor_value)); + if (!status.IsOK()) return status; + out = Value{tensor_value}; + return status; +} + +template +inline std::string ConstOpAttrImpl::GetName() const { + const char* name = nullptr; + ThrowOnError(GetApi().OpAttr_GetName(this->p_, &name)); + if (name != nullptr) { + return name; + } + return {}; +} + +template +inline OrtOpAttrType ConstOpAttrImpl::GetType() const { + OrtOpAttrType type; + ThrowOnError(GetApi().OpAttr_GetType(this->p_, &type)); + return type; +} +} // namespace detail + inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } @@ -2775,115 +2958,69 @@ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shap } inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); - return i; + auto attr = GetAttrHdl(attr_name); + int64_t value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting int attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } 
+ return value; } inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); - if (status) { - size_t num_i = out / sizeof(int64_t); - ShapeInferContext::Ints ints(num_i, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); - return ints; - } else { - if (out == 0u) { - return {}; - } - return {i}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Ints result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting ints attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); - return f; + auto attr = GetAttrHdl(attr_name); + float value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting float attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } + return value; } inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float 
f = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); - if (status) { - size_t num_f = out / sizeof(float); - ShapeInferContext::Floats floats(num_f, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); - return floats; - } else { - if (out == 0u) { - return {}; - } - return {f}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Floats result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting floats attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); - return std::string{chars.data(), out}; - } else { - return {c}; + auto attr = GetAttrHdl(attr_name); + std::string value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting string attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return value; } inline ShapeInferContext::Strings 
ShapeInferContext::GetAttrStrings(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); - ShapeInferContext::Strings strings; - char* char_st = chars.data(); - char* char_ed = char_st + out; - while (char_st < char_ed) { - strings.emplace_back(char_st); - while (*char_st != '\0') { - char_st++; - } - char_st++; - } - return strings; - } else { - if (out == 0u) { - return {}; - } - return {std::string{c}}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Strings result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting strings attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } -inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { +inline ConstOpAttr ShapeInferContext::GetAttrHdl(const char* attr_name) const { const OrtOpAttr* attr_hdl = {}; Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); - return attr_hdl; + return ConstOpAttr{attr_hdl}; } namespace detail { @@ -2897,6 +3034,136 @@ inline std::vector StringsToCharPtrs(const std::vector } } // namespace detail +namespace 
detail { +template +inline size_t ConstNodeImpl::GetId() const { + size_t id; + ThrowOnError(GetApi().Node_GetId(this->p_, &id)); + return id; +} + +template +inline std::string ConstNodeImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Node_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::string ConstNodeImpl::GetOperatorType() const { + const char* type; + ThrowOnError(GetApi().Node_GetOperatorType(this->p_, &type)); + return std::string(type); +} + +template +inline std::string ConstNodeImpl::GetDomain() const { + const char* domain; + ThrowOnError(GetApi().Node_GetDomain(this->p_, &domain)); + return std::string(domain); +} + +template +inline int ConstNodeImpl::GetSinceVersion() const { + int since_version; + ThrowOnError(GetApi().Node_GetSinceVersion(this->p_, &since_version)); + return since_version; +} + +template +inline std::vector ConstNodeImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetImplicitInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumImplicitInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + 
ThrowOnError(GetApi().Node_GetImplicitInputs(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetAttributes() const { + static_assert(sizeof(const OrtOpAttr*) == sizeof(ConstOpAttr), "Must be the same size"); + size_t num_attrs; + ThrowOnError(GetApi().Node_GetNumAttributes(this->p_, &num_attrs)); + std::vector attrs; + if (num_attrs > 0) { + attrs.resize(num_attrs); + ThrowOnError(GetApi().Node_GetAttributes(this->p_, reinterpret_cast(attrs.data()), num_attrs)); + } + return attrs; +} + +template +inline Status ConstNodeImpl::GetAttributeByName(const std::string& name, ConstOpAttr& out) const { + const OrtOpAttr* attr = nullptr; + auto status = Status(GetApi().Node_GetAttributeByName(this->p_, name.c_str(), &attr)); + out = ConstOpAttr{attr}; + return status; +} + +template +inline std::vector ConstNodeImpl::GetSubgraphs() const { + size_t num_graphs; + ThrowOnError(GetApi().Node_GetNumSubgraphs(this->p_, &num_graphs)); + std::vector result; + if (num_graphs > 0) { + std::vector sub_graphs(num_graphs); + std::vector attr_names(num_graphs); + ThrowOnError(GetApi().Node_GetSubgraphs(this->p_, sub_graphs.data(), num_graphs, attr_names.data())); + result.reserve(num_graphs); + for (size_t i = 0; i < num_graphs; ++i) { + result.push_back({std::string(attr_names[i]), ConstGraph{sub_graphs[i]}}); + } + } + return result; +} + +template +inline ConstGraph ConstNodeImpl::GetGraph() const { + const OrtGraph* graph; + ThrowOnError(GetApi().Node_GetGraph(this->p_, &graph)); + return ConstGraph{graph}; +} + +template +inline std::string ConstNodeImpl::GetEpName() const { + const char* name; + ThrowOnError(GetApi().Node_GetEpName(this->p_, &name)); + return std::string(name); +} + +} // namespace detail + #if !defined(ORT_MINIMAL_BUILD) // static inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, @@ -2938,97 +3205,294 @@ inline Node::Node(const std::string& 
operator_name, const std::string& operator_ std::vector empty_attributes; Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); } +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) -inline Graph::Graph() { - ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +namespace detail { +template +inline std::string ConstValueInfoImpl::GetName() const { + const char* p = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &p)); + return std::string(p); } -inline Model::Model(const std::vector& opsets) { - std::vector domains; - std::vector versions; - domains.reserve(opsets.size()); - versions.reserve(opsets.size()); +template +inline ConstTypeInfo ConstValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} - for (const auto& pair : opsets) { - domains.push_back(pair.first.c_str()); - versions.push_back(pair.second); +template +inline ValueInfoConsumerProducerInfo ConstValueInfoImpl::GetProducerNode() const { + ValueInfoConsumerProducerInfo info; + const OrtNode* producer; + size_t index; + ThrowOnError(GetApi().ValueInfo_GetValueProducer(this->p_, &producer, &index)); + info.node = ConstNode(producer); + info.index = static_cast(index); + return info; +} + +template +inline std::vector ConstValueInfoImpl::GetConsumers() const { + size_t num = 0; + ThrowOnError(GetApi().ValueInfo_GetValueNumConsumers(this->p_, &num)); + std::vector out; + if (num > 0) { + std::vector nodes(num); + std::vector indices(num); + ThrowOnError(GetApi().ValueInfo_GetValueConsumers(this->p_, nodes.data(), indices.data(), num)); + out.reserve(num); + for (size_t i = 0; i < num; ++i) { + out.push_back({ConstNode{nodes[i]}, indices[i]}); + } } + return out; +} - 
ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +template +inline Status ConstValueInfoImpl::GetInitializer(ConstValue& value) const { + const OrtValue* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetInitializerValue(this->p_, &out)); + if (!status.IsOK()) return status; + value = ConstValue{out}; + return status; } -inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { - ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +template +inline Status ConstValueInfoImpl::GetExternalInitializerInfo(ExternalInitializerInfo& info) const { + OrtExternalInitializerInfo* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetExternalInitializerInfo(this->p_, &out)); + if (!status.IsOK()) return status; + info = ExternalInitializerInfo{out}; + return status; } -#endif // !defined(ORT_MINIMAL_BUILD) -namespace detail { -template <> -inline std::string ValueInfoImpl::Name() const { - const char* name = nullptr; - ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); - return name; +template +inline bool ConstValueInfoImpl::IsRequiredGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsRequiredGraphInput(this->p_, &out)); + return out; } -template <> -inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { - const OrtTypeInfo* type_info = nullptr; - ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); - return ConstTypeInfo{type_info}; +template +inline bool ConstValueInfoImpl::IsOptionalGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsOptionalGraphInput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsGraphOutput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsGraphOutput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsConstantInitializer() const { + bool out = false; + 
ThrowOnError(GetApi().ValueInfo_IsConstantInitializer(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsFromOuterScope() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsFromOuterScope(this->p_, &out)); + return out; +} + +template +inline ModelMetadata ConstGraphImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + +template +inline std::string ConstGraphImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Graph_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::basic_string ConstGraphImpl::GetModelPath() const { + const ORTCHAR_T* path; + ThrowOnError(GetApi().Graph_GetModelPath(this->p_, &path)); + return std::basic_string(path); +} + +template +inline int64_t ConstGraphImpl::GetOnnxIRVersion() const { + int64_t version; + ThrowOnError(GetApi().Graph_GetOnnxIRVersion(this->p_, &version)); + return version; +} + +template +inline std::vector ConstGraphImpl::GetOperatorSets() const { + size_t num_opsets; + ThrowOnError(GetApi().Graph_GetNumOperatorSets(this->p_, &num_opsets)); + std::vector result; + if (num_opsets > 0) { + std::vector domains; + std::vector versions; + domains.resize(num_opsets); + versions.resize(num_opsets); + ThrowOnError(GetApi().Graph_GetOperatorSets(this->p_, domains.data(), versions.data(), num_opsets)); + result.reserve(num_opsets); + for (size_t i = 0; i < num_opsets; ++i) { + result.push_back({domains[i], versions[i]}); + } + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; 
+} + +template +inline std::vector ConstGraphImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInitializers() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInitializers(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInitializers(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetNodes() const { + static_assert(sizeof(const OrtNode*) == sizeof(ConstNode)); + size_t num_nodes; + ThrowOnError(GetApi().Graph_GetNumNodes(this->p_, &num_nodes)); + std::vector result; + if (num_nodes > 0) { + result.resize(num_nodes); + ThrowOnError(GetApi().Graph_GetNodes(this->p_, reinterpret_cast(result.data()), num_nodes)); + } + return result; +} + +template +inline ConstNode ConstGraphImpl::GetParentNode() const { + const OrtNode* parent; + ThrowOnError(GetApi().Graph_GetParentNode(this->p_, &parent)); + return ConstNode{parent}; +} + +template +inline Graph ConstGraphImpl::GetGraphView(const std::vector& nodes) const { + OrtGraph* graph_viewer; + std::vector inputs_ptrs; + inputs_ptrs.reserve(nodes.size()); + std::transform(nodes.begin(), nodes.end(), std::back_inserter(inputs_ptrs), + [](ConstNode n) -> const OrtNode* { return n; }); + ThrowOnError(GetApi().Graph_GetGraphView(this->p_, inputs_ptrs.data(), + nodes.size(), &graph_viewer)); + return Graph{graph_viewer}; } #if !defined(ORT_MINIMAL_BUILD) -template <> -inline void GraphImpl::SetInputs(std::vector& inputs) { +template 
+inline void GraphImpl::SetInputs(std::vector& inputs) { std::vector inputs_ptrs; inputs_ptrs.reserve(inputs.size()); std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphInputs(this->p_, inputs_ptrs.data(), inputs_ptrs.size())); // Graph now owns the inputs std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::SetOutputs(std::vector& outputs) { +template +inline void GraphImpl::SetOutputs(std::vector& outputs) { std::vector outputs_ptrs; outputs_ptrs.reserve(outputs.size()); std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphOutputs(this->p_, outputs_ptrs.data(), outputs_ptrs.size())); // Graph now owns the outputs std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { +template +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { // Graph takes ownership of `initializer` - ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); + // On error the ownership is not transferred. 
+ ThrowOnError(GetModelEditorApi().AddInitializerToGraph(this->p_, name.c_str(), initializer, data_is_external)); + initializer.release(); } -template <> -inline void GraphImpl::AddNode(Node& node) { +template +inline void GraphImpl::AddNode(Node& node) { // Graph takes ownership of `node` - ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); + ThrowOnError(GetModelEditorApi().AddNodeToGraph(this->p_, node.release())); } template -inline ModelMetadata GraphImpl::GetModelMetadata() const { - OrtModelMetadata* out; - ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); - return ModelMetadata{out}; -} - -template <> -inline void ModelImpl::AddGraph(Graph& graph) { +inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` - ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); + ThrowOnError(GetModelEditorApi().AddGraphToModel(this->p_, graph.release())); } #endif // !defined(ORT_MINIMAL_BUILD) } // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} +#endif + } // namespace Ort diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 287eba05a0595..e4265713d2d0a 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -33,95 +33,88 @@ struct MulKernel { return iter != float_initializers.end() ? 
&iter->second : nullptr; } - OrtStatus* GetInputDataAndShape(OrtKernelContext* kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - const OrtValue* input = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, index, &input)); - - OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); - - size_t num_elems = 0; - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); - - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); - - shape.resize(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, shape.data(), shape.size())); - - const void* raw_data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorData(input, &raw_data)); - - const float* float_data = static_cast(raw_data); + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); data = gsl::span(float_data, num_elems); - return nullptr; + shape = type_shape.GetShape(); } - OrtStatus* Compute(OrtKernelContext* kernel_context) { + OrtStatus* Compute(OrtKernelContext* kernel_ctx) { 
RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 1, input1, shape1)); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input1, shape1)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". 
+ + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); + input0 = gsl::span(const_input0->data); + shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); + shape1 = const_input1->shape; + } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; shape1 = const_input1->shape; } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. - const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); - - input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); - shape0 = const_input0->shape; - shape1 = const_input1->shape; - } - RETURN_IF(shape0 != shape1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. 
+ if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); - RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - OrtValue* output = nullptr; - float* output_data = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, shape0.data(), shape0.size(), &output)); - RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -183,178 +176,175 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept return ep->name_.c_str(); } -OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { - size_t num_initializers = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumInitializers(graph, &num_initializers)); - - std::vector initializers(num_initializers); - RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, initializers.data(), initializers.size())); - - for (const OrtValueInfo* initializer : initializers) { - bool is_constant = false; - RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); - - if (is_constant) { - const char* name = nullptr; - const OrtValue* value = nullptr; - 
OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - size_t num_elems = 0; +OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* ort_graph) { + Ort::ConstGraph graph{ort_graph}; - RETURN_IF_ERROR(ort_api.GetValueInfoName(initializer, &name)); - RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer, &value)); - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(value, &type_shape)); - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + try { + std::vector initializers = graph.GetInitializers(); - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 initializers"); + for (const auto& initializer : initializers) { + const bool is_constant = initializer.IsConstantInitializer(); - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + if (is_constant) { + auto name = initializer.GetName(); + Ort::ConstValue value; + auto status = initializer.GetInitializer(value); + if (!status.IsOK()) + return status.release(); - std::vector dims(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, dims.data(), dims.size())); + auto type_shape = value.GetTensorTypeAndShapeInfo(); + const size_t num_elems = type_shape.GetElementCount(); + const ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release(); - const float* data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(value), (void**)&data)); + std::vector dims = type_shape.GetShape(); + const float* data = value.GetTensorData(); - FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + 
num_elems)}; - float_initializers_.emplace(name, std::move(ep_initializer)); + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + float_initializers_.emplace(std::move(name), std::move(ep_initializer)); + } } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, +OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, OrtEpGraphSupportInfo* graph_support_info) noexcept { - ExampleEp* ep = static_cast(this_ptr); + try { + ExampleEp* ep = static_cast(this_ptr); - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graph, &num_nodes)); + Ort::ConstGraph graph{ort_graph}; + std::vector nodes = graph.GetNodes(); + if (nodes.empty()) { + return nullptr; // No nodes to process + } - if (num_nodes == 0) { - return nullptr; // No nodes to process - } + std::vector supported_nodes; - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); - - std::vector supported_nodes; - - for (const OrtNode* node : nodes) { - const char* op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - - if (std::strncmp(op_type, "Mul", 4) == 0) { - // Check that Mul has inputs/output of type float - size_t num_inputs = 0; - size_t num_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.Node_GetNumInputs(node, &num_inputs)); - RETURN_IF_ERROR(ep->ort_api.Node_GetNumOutputs(node, &num_outputs)); - RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - - std::vector inputs(num_inputs); - std::vector outputs(num_outputs); - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); - 
RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - - std::array is_float = {false, false, false}; - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[0], is_float[0])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[1], is_float[1])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, outputs[0], is_float[2])); - if (!is_float[0] || !is_float[1] || !is_float[2]) { - continue; // Input or output is not of type float - } + for (const auto& node : nodes) { + auto op_type = node.GetOperatorType(); - supported_nodes.push_back(node); // Only support a single Mul for now. - break; - } - } + if (op_type == "Mul") { + // Check that Mul has inputs/output of type float + std::vector inputs = node.GetInputs(); + std::vector outputs = node.GetOutputs(); + + RETURN_IF(inputs.size() != 2 || outputs.size() != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; + std::array is_float = {false, false, false}; + IsFloatTensor(inputs[0], is_float[0]); + IsFloatTensor(inputs[1], is_float[1]); + IsFloatTensor(outputs[0], is_float[2]); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size(), &node_fusion_options)); + supported_nodes.push_back(node); // Only support a single Mul for now.
+ break; + } + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, +OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** ort_graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { - ExampleEp* ep = static_cast(this_ptr); - const OrtApi& ort_api = ep->ort_api; - - if (count != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); - } - - // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. - // So, this EP saves constant initializers so that they're available during inference, but an actual EP - // implementation could transfer the weights to device memory. 
- ep->SaveConstantInitializers(graphs[0]); - - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graphs[0], &num_nodes)); - - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], nodes.data(), nodes.size())); - - if (num_nodes != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + try { + if (count != 1) { + Ort::Status status("Expected to compile a single graph", ORT_EP_FAIL); + return status.release(); + } - const char* node_op_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &node_op_type)); + ExampleEp* ep = static_cast(this_ptr); - if (std::strncmp(node_op_type, "Mul", 4) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + Ort::ConstGraph graph{ort_graphs[0]}; - // Now we know we're compiling a single Mul node. Create a computation kernel. - std::array node_inputs = {}; - std::array node_input_names = {}; + // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. + // So, this EP saves constant initializers so that they're available during inference, but an actual EP + // implementation could transfer the weights to device memory. 
+ ep->SaveConstantInitializers(graph); - RETURN_IF_ERROR(ort_api.Node_GetInputs(nodes[0], node_inputs.data(), node_inputs.size())); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + std::vector nodes = graph.GetNodes(); + if (nodes.size() != 1) { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - const char* ep_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); - if (std::strncmp(ep_name, "example_ep", 11) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); - } + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Mul") { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - // Associate the name of the fused node with our MulKernel. - const char* fused_node_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); + // Now we know we're compiling a single Mul node. Create a computation kernel. + std::vector node_inputs = nodes[0].GetInputs(); + std::array node_input_names; + node_input_names[0] = node_inputs[0].GetName(); + node_input_names[1] = node_inputs[1].GetName(); + + Ort::ConstNode fused_node{fused_nodes[0]}; + auto ep_name = fused_node.GetEpName(); + if (ep_name != "example_ep") { + Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); + return status.release(); + } - ep->kernels_.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + // Associate the name of the fused node with our MulKernel. 
+ auto fused_node_name = fused_node.GetName(); + ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, ep->float_initializers_, node_input_names[0], node_input_names[1])); - // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); - node_compute_infos[0] = node_compute_info.release(); + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); - // Create EpContext nodes for the fused nodes we compiled. - if (ep->config_.enable_ep_context) { - assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), - gsl::span(ep_context_nodes, count))); + // Create EpContext nodes for the fused nodes we compiled. + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -375,69 +365,74 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // cannot currently run the EPContext model. OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { - assert(fused_nodes.size() == ep_context_nodes.size()); + try { + assert(fused_nodes.size() == ep_context_nodes.size()); - // Helper to collect input or output names from an array of OrtValueInfo instances. - auto collect_input_output_names = [&](gsl::span value_infos, - std::vector& result) -> OrtStatus* { - size_t num_values = value_infos.size(); - std::vector value_names(num_values); + // Helper to collect input or output names from an array of OrtValueInfo instances. 
+ auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) { + std::vector value_names; + value_names.reserve(value_infos.size()); - for (size_t i = 0; i < num_values; ++i) { - const OrtValueInfo* value_info = value_infos[i]; - RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); - } + for (const auto vi : value_infos) { + value_names.push_back(vi.GetName()); + } - result = std::move(value_names); - return nullptr; - }; - - // Create an "EPContext" node for every fused node. - for (size_t i = 0; i < fused_nodes.size(); ++i) { - const OrtNode* fused_node = fused_nodes[i]; - const char* fused_node_name = nullptr; - - RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); - - size_t num_fused_node_inputs = 0; - size_t num_fused_node_outputs = 0; - RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_fused_node_inputs)); - RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_fused_node_outputs)); - - std::vector fused_node_inputs(num_fused_node_inputs); - std::vector fused_node_outputs(num_fused_node_outputs); - RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, fused_node_inputs.data(), fused_node_inputs.size())); - RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, fused_node_outputs.data(), fused_node_outputs.size())); - - std::vector input_names; - std::vector output_names; - - RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); - RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); - - int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; - - // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. 
- std::array attributes = {}; - DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - - std::string ep_ctx = "binary_data"; - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", ep_ctx.c_str(), static_cast(ep_ctx.length()), - ORT_OP_ATTR_STRING, &attributes[0])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, static_cast(strlen(fused_node_name)), - ORT_OP_ATTR_STRING, &attributes[4])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), static_cast(this->name_.length()), - ORT_OP_ATTR_STRING, &attributes[5])); - - RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, - input_names.data(), input_names.size(), - output_names.data(), output_names.size(), - attributes.data(), attributes.size(), - &ep_context_nodes[i])); + result = std::move(value_names); + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + Ort::ConstNode fused_node{fused_nodes[i]}; + auto fused_node_name = fused_node.GetName(); + + std::vector fused_node_inputs = fused_node.GetInputs(); + std::vector fused_node_outputs = fused_node.GetOutputs(); + + std::vector input_names; + std::vector output_names; + + collect_input_output_names(fused_node_inputs, /*out*/ input_names); + collect_input_output_names(fused_node_outputs, /*out*/ output_names); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes. 
+ std::array attributes = {}; + std::string ep_ctx = "binary_data"; + attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), + ORT_OP_ATTR_STRING); + + attributes[1] = Ort::OpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT); + attributes[2] = Ort::OpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT); + attributes[3] = Ort::OpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING); + attributes[4] = Ort::OpAttr("partition_name", fused_node_name.data(), static_cast(fused_node_name.size()), + ORT_OP_ATTR_STRING); + + attributes[5] = Ort::OpAttr("source", this->name_.data(), static_cast(this->name_.size()), + ORT_OP_ATTR_STRING); + + std::vector c_input_names; + std::transform(input_names.begin(), input_names.end(), std::back_inserter(c_input_names), + [](const std::string& s) { return s.c_str(); }); + std::vector c_output_names; + std::transform(output_names.begin(), output_names.end(), std::back_inserter(c_output_names), + [](const std::string& s) { return s.c_str(); }); + + OrtOpAttr** op_attrs = reinterpret_cast(attributes.data()); + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name.c_str(), + c_input_names.data(), c_input_names.size(), + c_output_names.data(), c_output_names.size(), + op_attrs, attributes.size(), + &ep_context_nodes[i])); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index fa6eb24c5cc04..279925a7ec3e1 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -54,7 +54,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); - OrtStatus* ExampleEp::SaveConstantInitializers(const 
OrtGraph* graph); + OrtStatus* SaveConstantInitializers(const OrtGraph* graph); ExampleEpFactory& factory_; std::string name_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index b6f982a422b6a..c14bdc1b52093 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_factory.h" // To make symbols visible on macOS/iOS @@ -21,6 +25,9 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const const OrtEpApi* ep_api = ort_api->GetEpApi(); const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + // Manual init for the C++ API + Ort::InitApi(ort_api); + // Factory could use registration_name or define its own EP name. 
std::unique_ptr factory = std::make_unique(registration_name, ApiPtrs{*ort_api, *ep_api, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc index 549551931c647..263b4d208bd91 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -5,48 +5,33 @@ #include -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, +OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& /* ort_api */, const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val, /*out*/ std::string& config_val) { - int has_config = 0; - RETURN_IF_ERROR(ort_api.HasSessionConfigEntry(&session_options, config_key, &has_config)); - - if (has_config != 1) { - config_val = default_val; - return nullptr; + try { + Ort::ConstSessionOptions sess_opt{&session_options}; + config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); } - size_t size = 0; - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, nullptr, &size)); - - config_val.resize(size); - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, config_val.data(), &size)); - config_val.resize(size - 1); // remove the terminating '\0' - return nullptr; } -OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result) { +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { result = false; - const OrtTypeInfo* type_info = nullptr; - RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); - - ONNXType onnx_type = ONNX_TYPE_UNKNOWN; - RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(type_info, &onnx_type)); + auto type_info = value_info.TypeInfo(); + ONNXType onnx_type = type_info.GetONNXType(); if 
(onnx_type != ONNX_TYPE_TENSOR) { - return nullptr; + return; } - const OrtTensorTypeAndShapeInfo* type_shape = nullptr; - RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return nullptr; + return; } - result = true; - return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index 99ebee9ff64de..e8c086d38a7cb 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -107,4 +107,4 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio /*out*/ std::string& config_val); // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. -OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result); +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 513097aaf7ade..7e6d157799d86 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -100,30 +100,20 @@ TEST(EpGraphTest, GetAttributeByName) { // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not // have statically computable default values, so will not be filled in by Graph::Resolve(). 
const OrtGraph& ort_graph = test_graph->GetOrtGraph(); - const OrtApi& ort_api = Ort::GetApi(); - - size_t num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - ASSERT_EQ(num_nodes, 1); + Ort::ConstGraph graph{&ort_graph}; - std::vector nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + auto nodes = graph.GetNodes(); + ASSERT_EQ(nodes.size(), 1); - const OrtNode* conv_node = nodes[0]; - const char* op_type = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); - ASSERT_STREQ(op_type, "Conv"); + auto conv_node = nodes[0]; + auto op_type = conv_node.GetOperatorType(); + ASSERT_EQ(op_type, "Conv"); - size_t num_attrs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); - ASSERT_EQ(num_attrs, 2); + auto attrs = conv_node.GetAttributes(); + ASSERT_EQ(attrs.size(), 2); - std::vector attrs(num_attrs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); - for (const OrtOpAttr* attr : attrs) { - const char* attr_name_cstr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); - std::string_view attr_name = attr_name_cstr; + for (const auto& attr : attrs) { + auto attr_name = attr.GetName(); ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set } @@ -131,9 +121,8 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. // { - const OrtOpAttr* attr = nullptr; - Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; - ASSERT_TRUE(status.IsOK()); + Ort::ConstOpAttr attr; + auto status = conv_node.GetAttributeByName("dilations", attr); ASSERT_EQ(attr, nullptr); } @@ -141,8 +130,8 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 2: Get attribute that does not exist in operator schema. 
Should get a ORT_NOT_FOUND error. // { - const OrtOpAttr* attr = nullptr; - Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + Ort::ConstOpAttr attr; + Ort::Status status = conv_node.GetAttributeByName("_does_not_exist_", attr); ASSERT_FALSE(status.IsOK()); ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); ASSERT_EQ(attr, nullptr); @@ -152,23 +141,14 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 3: Get attribute that is known to be set. // { - const OrtOpAttr* attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + Ort::ConstOpAttr attr; + ASSERT_ORTSTATUS_OK(conv_node.GetAttributeByName("auto_pad", attr)); ASSERT_NE(attr, nullptr); - OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); - ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); - + OrtOpAttrType type = attr.GetType(); + ASSERT_EQ(ORT_OP_ATTR_STRING, type); std::string auto_pad_val; - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
- size_t total_attr_bytes = 0; - Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; - auto_pad_val.resize(total_attr_bytes); - - ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, - &total_attr_bytes)); + ASSERT_ORTSTATUS_OK(attr.GetValue(auto_pad_val)); ASSERT_EQ(auto_pad_val, "NOTSET"); } } @@ -229,14 +209,10 @@ TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { std::string ext_ini_file_path = "conv_qdq_ext_ini_serialized.bin"; std::filesystem::remove(ext_ini_file_path); std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* /* value_info */, const void* data, size_t bytes, bool& is_external, std::string& location, int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - if (bytes <= 127) { is_external = false; // Keep small initializers stored inside the TensorProto. 
return Ort::Status{nullptr}; @@ -442,13 +418,13 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } for (size_t i = 0; i < api_num_initializers; ++i) { - const OrtValue* ort_value = nullptr; - const void* ort_value_data = nullptr; - const char* value_name = nullptr; + std::string value_name; + Ort::ConstValue ort_value; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_initializers[i], &value_name)); - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_initializers[i], &ort_value)); - ASSERT_ORTSTATUS_OK(ort_api.GetTensorData(ort_value, &ort_value_data)); + Ort::ConstValueInfo vi(api_initializers[i]); + value_name = vi.GetName(); + ASSERT_ORTSTATUS_OK(vi.GetInitializer(ort_value)); + const void* ort_value_data = ort_value.GetTensorRawData(); auto iter = tensor_proto_map.find(value_name); ASSERT_NE(iter, tensor_proto_map.end()); @@ -723,25 +699,21 @@ static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtVa static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, const ONNX_NAMESPACE::TensorProto* tensor_proto, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(api_value_info); + std::string api_initializer_name = vi.GetName(); // Check external initializer info (if any). 
- OrtExternalInitializerInfo* api_ext_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetExternalInitializerInfo(api_value_info, &api_ext_info)); - DeferOrtRelease defer_release_info(&api_ext_info, ort_api.ReleaseExternalInitializerInfo); + Ort::ExternalInitializerInfo api_ext_info{nullptr}; + auto external_status = vi.GetExternalInitializerInfo(api_ext_info); std::unique_ptr ext_info = nullptr; bool has_ext_info = graph_viewer.GetGraph().GetExternalInitializerInfo(api_initializer_name, ext_info, true); if (has_ext_info) { ASSERT_NE(api_ext_info, nullptr); - const ORTCHAR_T* api_ext_file_path = ort_api.ExternalInitializerInfo_GetFilePath(api_ext_info); - int64_t api_ext_file_offset = ort_api.ExternalInitializerInfo_GetFileOffset(api_ext_info); - size_t api_ext_byte_size = ort_api.ExternalInitializerInfo_GetByteSize(api_ext_info); + const std::basic_string api_ext_file_path = api_ext_info.GetFilePath(); + int64_t api_ext_file_offset = api_ext_info.GetFileOffset(); + size_t api_ext_byte_size = api_ext_info.GetByteSize(); ASSERT_EQ(PathString(api_ext_file_path), ext_info->GetRelPath()); ASSERT_EQ(api_ext_file_offset, static_cast(ext_info->GetOffset())); @@ -751,61 +723,49 @@ static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, ASSERT_FALSE(utils::HasExternalDataInFile(*tensor_proto)); } - const OrtValue* api_initializer_value = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + Ort::ConstValue api_initializer_value; + ASSERT_ORTSTATUS_OK(vi.GetInitializer(api_initializer_value)); ASSERT_NE(api_initializer_value, nullptr); // Check initializer type. 
const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); auto type_info = OrtTypeInfo::FromTypeProto(type_proto); - const OrtTypeInfo* api_type_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(api_value_info, &api_type_info)); + Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, type_info.get()); } -static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, +static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, const InitializedTensorSet& initializer_tensor_protos, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - for (size_t i = 0; i < initializer_value_infos.size(); i++) { - const OrtValueInfo* api_value_info = initializer_value_infos[i]; - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(initializer_value_infos[i]); + std::string api_initializer_name = vi.GetName(); auto tensor_proto_iter = initializer_tensor_protos.find(api_initializer_name); ASSERT_NE(tensor_proto_iter, initializer_tensor_protos.end()); const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; ASSERT_NE(tensor_proto, nullptr); - - CheckInitializerValueInfo(api_value_info, tensor_proto, graph_viewer); + CheckInitializerValueInfo(vi, tensor_proto, graph_viewer); } } // Checks that the OrtValueInfos obtained from the public C API are "equivalent" to the NodeArgs // in the original graph. 
-static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, +static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, gsl::span node_args) { ASSERT_EQ(value_infos.size(), node_args.size()); - const OrtApi& ort_api = Ort::GetApi(); const auto& graph_viewer_inputs = graph_viewer.GetInputsIncludingInitializers(); const auto& graph_viewer_outputs = graph_viewer.GetOutputs(); for (size_t i = 0; i < value_infos.size(); i++) { const NodeArg* node_arg = node_args[i]; - const OrtValueInfo* value_info = value_infos[i]; + Ort::ConstValueInfo vi(value_infos[i]); if (node_arg->Exists()) { const auto& value_name = node_arg->Name(); - - ASSERT_NE(value_info, nullptr); - - const char* api_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(value_info, &api_name)); + std::string api_name = vi.GetName(); ASSERT_EQ(std::string(api_name), value_name); bool is_graph_input = std::any_of(graph_viewer_inputs.begin(), graph_viewer_inputs.end(), @@ -825,64 +785,52 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanName()); - bool api_is_outer_scope = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsFromOuterScope(value_info, &api_is_outer_scope)); + bool api_is_outer_scope = vi.IsFromOuterScope(); ASSERT_EQ(api_is_outer_scope, is_outer_scope); - bool api_is_const_initializer = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsConstantInitializer(value_info, &api_is_const_initializer)); + bool api_is_const_initializer = vi.IsConstantInitializer(); ASSERT_EQ(api_is_const_initializer, is_const_initializer); if (is_const_initializer || api_is_opt_graph_input) { - CheckInitializerValueInfo(value_info, initializer, graph_viewer); + CheckInitializerValueInfo(vi, initializer, graph_viewer); } else { auto node_arg_type_info = OrtTypeInfo::FromTypeProto(*node_arg->TypeAsProto()); - const OrtTypeInfo* api_type_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(value_info, &api_type_info)); + 
Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, node_arg_type_info.get()); } - CheckValueInfoProducer(graph_viewer, value_info, node_arg); - CheckValueInfoConsumers(graph_viewer, value_info, node_arg); + CheckValueInfoProducer(graph_viewer, vi, node_arg); + CheckValueInfoConsumers(graph_viewer, vi, node_arg); } else { - ASSERT_EQ(value_info, nullptr); // A missing optional input has a null OrtValueInfo. + ASSERT_EQ(vi, nullptr); // A missing optional input has a null OrtValueInfo. } } } // Checks the Graph_GetSubgraph C API static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + Ort::ConstGraph ort_graph{&api_graph}; // Get all the nodes - size_t num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + std::vector nodes = ort_graph.GetNodes(); // Select a half of nodes to create a OrtGraph size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); - std::vector selected_nodes(num_selected_nodes); + std::vector selected_nodes(num_selected_nodes); for (size_t i = 0; i < num_selected_nodes; i++) { selected_nodes[i] = nodes[i]; } - OrtGraph* sub_graph; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. 
@@ -892,31 +840,25 @@ static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - const char* graph_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + auto graph_name = ort_graph.GetName(); std::string name = graph_name; name += "_half.onnx"; // Dump the graph for debugging // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); // model_proto->SerializeToOstream(&dump); - - ort_api.ReleaseGraph(sub_graph); } // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + auto ort_cxx_graph = Ort::ConstGraph(&api_graph); // Check the path to model. const std::filesystem::path& model_path = graph_viewer.ModelPath(); - const ORTCHAR_T* api_model_path = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path)); + const auto api_model_path = ort_cxx_graph.GetModelPath(); ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str())); // Check the model metadata Ort::AllocatorWithDefaultOptions default_allocator; - auto ort_cxx_graph = Ort::ConstGraph(&api_graph); auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata(); auto& model = graph_viewer.GetGraph().GetModel(); ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0); @@ -933,42 +875,30 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check graph inputs. 
const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); - size_t api_num_graph_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInputs(&api_graph, &api_num_graph_inputs)); - ASSERT_EQ(api_num_graph_inputs, graph_input_node_args.size()); + std::vector api_graph_inputs = ort_cxx_graph.GetInputs(); + ASSERT_EQ(api_graph_inputs.size(), graph_input_node_args.size()); - std::vector api_graph_inputs(api_num_graph_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, api_graph_inputs.data(), api_graph_inputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_inputs, graph_input_node_args); // Check graph outputs. const auto& graph_output_node_args = graph_viewer.GetOutputs(); - size_t api_num_graph_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumOutputs(&api_graph, &api_num_graph_outputs)); - ASSERT_EQ(api_num_graph_outputs, graph_output_node_args.size()); + std::vector api_graph_outputs = ort_cxx_graph.GetOutputs(); + ASSERT_EQ(api_graph_outputs.size(), graph_output_node_args.size()); - std::vector api_graph_outputs(api_num_graph_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, api_graph_outputs.data(), api_graph_outputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_outputs, graph_output_node_args); // Check graph initializers const auto& graph_initializers = graph_viewer.GetAllInitializedTensors(); - size_t api_num_initializers = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&api_graph, &api_num_initializers)); - ASSERT_EQ(api_num_initializers, graph_initializers.size()); - - std::vector api_initializers(api_num_initializers); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); + std::vector api_initializers = ort_cxx_graph.GetInitializers(); + ASSERT_EQ(api_initializers.size(), graph_initializers.size()); CheckInitializerValueInfosCApi(api_initializers, graph_initializers, graph_viewer); // Check if it has a 
parent node. const Node* parent_node = graph_viewer.ParentNode(); const bool has_parent_node = parent_node != nullptr; - const OrtNode* api_parent_node = nullptr; - - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetParentNode(&api_graph, &api_parent_node)); + Ort::ConstNode api_parent_node = ort_cxx_graph.GetParentNode(); const bool api_has_parent_node = api_parent_node != nullptr; ASSERT_EQ(api_has_parent_node, has_parent_node); @@ -977,79 +907,56 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check all nodes. - size_t api_num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &api_num_nodes)); - ASSERT_EQ(api_num_nodes, graph_viewer.NumberOfNodes()); - - std::vector api_nodes(api_num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, api_nodes.data(), api_nodes.size())); + std::vector api_nodes = ort_cxx_graph.GetNodes(); + ASSERT_EQ(api_nodes.size(), graph_viewer.NumberOfNodes()); std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); - for (size_t node_idx = 0; node_idx < api_num_nodes; node_idx++) { + for (size_t node_idx = 0; node_idx < api_nodes.size(); node_idx++) { // Check basic node properties. 
const Node* node = graph_viewer.GetNode(node_indices[node_idx]); - const OrtNode* api_node = api_nodes[node_idx]; + Ort::ConstNode api_node = api_nodes[node_idx]; CheckNode(node, api_node); - int api_since_version = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSinceVersion(api_node, &api_since_version)); + const int api_since_version = api_node.GetSinceVersion(); ASSERT_EQ(api_since_version, node->SinceVersion()); // Check node inputs const auto input_node_args = node->InputDefs(); - size_t api_node_num_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumInputs(api_node, &api_node_num_inputs)); - ASSERT_EQ(api_node_num_inputs, input_node_args.size()); - - std::vector api_node_inputs(api_node_num_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, api_node_inputs.data(), api_node_inputs.size())); + std::vector api_node_inputs = api_node.GetInputs(); + ASSERT_EQ(api_node_inputs.size(), input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_inputs, input_node_args); // Check node outputs const auto output_node_args = node->OutputDefs(); - size_t api_node_num_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumOutputs(api_node, &api_node_num_outputs)); - ASSERT_EQ(api_node_num_outputs, output_node_args.size()); - - std::vector api_node_outputs(api_node_num_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, api_node_outputs.data(), api_node_outputs.size())); + std::vector api_node_outputs = api_node.GetOutputs(); + ASSERT_EQ(api_node_outputs.size(), output_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); // Check node attributes const auto& node_attrs = node->GetAttributes(); if (!node_attrs.empty()) { - size_t api_num_node_attributes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(api_node, &api_num_node_attributes)); - - std::vector api_node_attributes(api_num_node_attributes); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, api_node_attributes.data(), 
api_node_attributes.size())); + std::vector api_node_attributes = api_node.GetAttributes(); size_t attr_idx = 0; for (const auto& node_attr : node_attrs) { - const OrtOpAttr* api_node_attr = api_node_attributes[attr_idx]; + auto api_node_attr = api_node_attributes[attr_idx]; ASSERT_NE(api_node_attr, nullptr); - api_node_attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(api_node, node_attr.first.c_str(), &api_node_attr)); + auto status = api_node.GetAttributeByName(node_attr.first, api_node_attr); + ASSERT_TRUE(status.IsOK()); ASSERT_NE(api_node_attr, nullptr); - const char* api_node_attr_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(api_node_attr, &api_node_attr_name)); - ASSERT_STREQ(api_node_attr_name, node_attr.first.c_str()); - - OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + auto api_node_attr_name = api_node_attr.GetName(); + ASSERT_EQ(api_node_attr_name, node_attr.first); + // XXX: Investigate why not // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. // TODO: Once we add support for ORT_OP_ATTR_TENSOR, we should be able to just fail if OpAttr_GetType // returns an error. - OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); - if (status != nullptr) { - Ort::GetApi().ReleaseStatus(status); - continue; - } + OrtOpAttrType api_node_attr_type = api_node_attr.GetType(); ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type(); switch (node_attr_type) { @@ -1091,7 +998,7 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. 
- ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); + FAIL() << "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."; } attr_idx++; } @@ -1105,41 +1012,19 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); - size_t api_num_node_implicit_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumImplicitInputs(api_node, &api_num_node_implicit_inputs)); - ASSERT_EQ(api_num_node_implicit_inputs, implicit_input_node_args.size()); - - std::vector api_node_implicit_inputs(api_num_node_implicit_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, api_node_implicit_inputs.data(), - api_node_implicit_inputs.size())); - + std::vector api_node_implicit_inputs = api_node.GetImplicitInputs(); + ASSERT_EQ(api_node_implicit_inputs.size(), implicit_input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_implicit_inputs, implicit_input_node_args); // Recursively check subgraphs. - size_t api_num_node_subgraphs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); - ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); - - std::vector api_node_subgraphs(api_num_node_subgraphs); - std::vector api_subgraph_attr_names(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), - api_subgraph_attr_names.data())); - - for (const auto& [attr_name, subgraph] : node_subgraphs_map) { - // find index of this subgraph. 
- size_t api_subgraph_idx = api_num_node_subgraphs; - for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { - if (api_subgraph_attr_names[subgraph_idx] == attr_name) { - api_subgraph_idx = subgraph_idx; - break; - } - } - ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); - - // Recursively check the subgraph - auto subgraph_viewer = std::make_unique(*subgraph); - const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; - CheckGraphCApi(*subgraph_viewer, *api_subgraph); + std::vector api_node_subgraphs = api_node.GetSubgraphs(); + ASSERT_EQ(api_node_subgraphs.size(), node_subgraphs_map.size()); + + for (const auto& name_subgraph : api_node_subgraphs) { + auto hit = node_subgraphs_map.find(name_subgraph.attr_name); + ASSERT_NE(node_subgraphs_map.end(), hit); + auto subgraph_viewer = std::make_unique(*hit->second); + CheckGraphCApi(*subgraph_viewer, *name_subgraph.sub_graph); } } } diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc index 63652d8835e77..2e2bce97f0cb9 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -56,19 +56,19 @@ static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_ // Sum the number of inputs with a producer node. 
num_input_edges = 0; - for (const OrtValueInfo* input : inputs) { + for (const OrtValueInfo* ort_input : inputs) { + Ort::ConstValueInfo input{ort_input}; if (input == nullptr) continue; // Skip missing optional input - const OrtNode* producer_node = nullptr; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueProducer(input, &producer_node, /*output_index*/ nullptr)); - num_input_edges += static_cast(producer_node != nullptr); + auto producer_info = input.GetProducerNode(); + num_input_edges += static_cast(producer_info.node != nullptr); } return Ort::Status{nullptr}; } // Get all output nodes that consume an output from the given node. -static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { +static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { const OrtApi& ort_api = Ort::GetApi(); size_t num_outputs = 0; @@ -77,23 +77,17 @@ static Ort::Status GetOutputNodes(const OrtNode* node, std::vector outputs(num_outputs); RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - std::vector output_nodes; + std::vector output_nodes; output_nodes.reserve(num_outputs); // May have more than `num_outputs` // Gather the OrtNode consumers of every output. 
- for (const OrtValueInfo* output : outputs) { + for (const OrtValueInfo* ort_output : outputs) { + Ort::ConstValueInfo output{ort_output}; if (output == nullptr) continue; // Skip missing optional output - size_t num_consumers = 0; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueNumConsumers(output, &num_consumers)); - - std::vector node_consumers(num_consumers, nullptr); - std::vector input_indices(num_consumers, 0); - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueConsumers(output, node_consumers.data(), - input_indices.data(), num_consumers)); - - for (const OrtNode* consumer : node_consumers) { - output_nodes.push_back(consumer); + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); } } @@ -108,77 +102,85 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, const std::function& comp) { const OrtApi& ort_api = Ort::GetApi(); - // Get all nodes - size_t num_nodes = 0; - RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + try { + // Get all nodes + size_t num_nodes = 0; + RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); - if (num_nodes == 0) { - return Ort::Status{nullptr}; // Nothing to sort. - } + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } - std::vector nodes(num_nodes); - RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + std::vector nodes(num_nodes); + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); - // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. - size_t max_node_id = 0; - for (const OrtNode* node : nodes) { - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - max_node_id = std::max(max_node_id, node_id); - } + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. 
+ size_t max_node_id = 0; + for (const OrtNode* node : nodes) { + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } - std::vector in_degree(max_node_id + 1, 0); - std::vector topo_order; - VisitorPriorityQueue to_visit(comp); + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); - topo_order.reserve(num_nodes); + topo_order.reserve(num_nodes); - // Initialize in_degree and initial nodes to visit first. - for (const OrtNode* node : nodes) { - size_t input_edge_count = 0; - RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + // Initialize in_degree and initial nodes to visit first. + for (const OrtNode* node : nodes) { + size_t input_edge_count = 0; + RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - in_degree[node_id] = input_edge_count; - if (input_edge_count == 0) { - to_visit.push(node); + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } } - } - while (!to_visit.empty()) { - const OrtNode* current_node = to_visit.top(); - to_visit.pop(); + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); - if (!current_node) continue; + if (!current_node) continue; - if (enter) { - enter(current_node); - } + if (enter) { + enter(current_node); + } - std::vector output_nodes; - RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); + std::vector output_nodes; + RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); - for (const OrtNode* output_node : output_nodes) { - size_t output_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + for (const auto& output_node : output_nodes) { + size_t output_node_id = 
0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); - auto& node_in_degree = in_degree[output_node_id]; - node_in_degree--; + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; - if (node_in_degree == 0) { - to_visit.push(output_node); + if (node_in_degree == 0) { + to_visit.push(output_node); + } } - } - size_t current_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, &current_node_id)); - topo_order.push_back(current_node_id); - } + size_t current_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, &current_node_id)); + topo_order.push_back(current_node_id); + } - if (num_nodes != topo_order.size()) { - return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status; + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status; } return Ort::Status{nullptr}; diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 0fe747cdd84e5..cffa0efc39d45 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -420,7 +420,7 @@ TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { // typically this isn't needed. we replace this input but need to read info from it later on in the test // validation so we save the info locally to keep it accessible. 
- auto orig_input_name = graph_inputs[0].Name(); + auto orig_input_name = graph_inputs[0].GetName(); auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const std::string new_input_name = "Int64Input"; @@ -589,7 +589,7 @@ TEST(ModelEditorAPITest, InvalidModelEdit) { Node node("Cast", domain, "NewInputNode", {new_input_name}, // the existing node will now consume the output from the Cast instead of a graph input - {graph_inputs[0].Name()}, + {graph_inputs[0].GetName()}, attributes); graph.AddNode(node);