diff --git a/src/models/graph_builder.cpp b/src/models/graph_builder.cpp index 1ace418692..4e5083dabb 100644 --- a/src/models/graph_builder.cpp +++ b/src/models/graph_builder.cpp @@ -58,47 +58,36 @@ AttributeValue AttributeValue::Strings(const std::string& name, const std::vecto namespace { // Helper to create OrtOpAttr from AttributeValue using Model Editor API -OrtOpAttr* CreateOpAttr(const AttributeValue& attr) { - OrtOpAttr* op_attr = nullptr; - +std::unique_ptr CreateOpAttr(const AttributeValue& attr) { switch (attr.type) { case AttributeType::INT: - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), &attr.int_value, 1, - OrtOpAttrType::ORT_OP_ATTR_INT, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), &attr.int_value, 1, OrtOpAttrType::ORT_OP_ATTR_INT); case AttributeType::FLOAT: - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), &attr.float_value, 1, - OrtOpAttrType::ORT_OP_ATTR_FLOAT, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), &attr.float_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT); case AttributeType::STRING: - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.string_value.c_str(), - static_cast(attr.string_value.size()), - OrtOpAttrType::ORT_OP_ATTR_STRING, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), attr.string_value.c_str(), + static_cast(attr.string_value.size()), + OrtOpAttrType::ORT_OP_ATTR_STRING); case AttributeType::INTS: - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.ints_value.data(), - static_cast(attr.ints_value.size()), - OrtOpAttrType::ORT_OP_ATTR_INTS, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), attr.ints_value.data(), + static_cast(attr.ints_value.size()), + OrtOpAttrType::ORT_OP_ATTR_INTS); case AttributeType::FLOATS: - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.floats_value.data(), - static_cast(attr.floats_value.size()), - OrtOpAttrType::ORT_OP_ATTR_FLOATS, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), attr.floats_value.data(), + static_cast(attr.floats_value.size()), + OrtOpAttrType::ORT_OP_ATTR_FLOATS); case AttributeType::STRINGS: { std::vector string_ptrs; string_ptrs.reserve(attr.strings_value.size()); for (const auto& str : attr.strings_value) { string_ptrs.push_back(str.c_str()); } - Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), string_ptrs.data(), - static_cast(string_ptrs.size()), - OrtOpAttrType::ORT_OP_ATTR_STRINGS, &op_attr)); - break; + return OrtOpAttr::Create(attr.name.c_str(), string_ptrs.data(), + static_cast(string_ptrs.size()), + OrtOpAttrType::ORT_OP_ATTR_STRINGS); } } - - return op_attr; + throw std::runtime_error("CreateOpAttr: Unhandled attribute type"); } } // anonymous namespace @@ -106,119 +95,105 @@ OrtOpAttr* CreateOpAttr(const AttributeValue& attr) { namespace GraphBuilder { // Build complete ONNX model using the Model Editor API -OrtModel* Build(const ModelConfig& config) { - const auto& model_editor_api = Ort::GetModelEditorApi(); - - OrtGraph* graph = nullptr; - OrtModel* model = nullptr; - std::vector node_attributes; - - try { - // Create graph - Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); - - // Create input ValueInfos - std::vector graph_inputs; - for (const auto& input : config.inputs) { - OrtTensorTypeAndShapeInfo* tensor_info = nullptr; - Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info)); - Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, input.elem_type)); - Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, input.shape.data(), input.shape.size())); - - OrtTypeInfo* type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info)); - Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info); - - OrtValueInfo* value_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateValueInfo(input.name.c_str(), type_info, &value_info)); - Ort::api->ReleaseTypeInfo(type_info); - - graph_inputs.push_back(value_info); - } +std::unique_ptr Build(const ModelConfig& config) { + // Create graph using RAII wrapper + auto graph = OrtGraph::Create(); + + // Create input ValueInfos + std::vector> graph_inputs; + for (const auto& input : config.inputs) { + auto tensor_info = OrtTensorTypeAndShapeInfo::Create(); + tensor_info->SetElementType(input.elem_type); + tensor_info->SetDimensions(input.shape.data(), input.shape.size()); + + auto value_info = OrtValueInfo::Create(input.name.c_str(), tensor_info.get()); + graph_inputs.push_back(std::move(value_info)); + } - // Create output ValueInfos - std::vector graph_outputs; - for (const auto& output : config.outputs) { - OrtTensorTypeAndShapeInfo* tensor_info = nullptr; - Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info)); - Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, output.elem_type)); - Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, output.shape.data(), output.shape.size())); + // Create output ValueInfos + std::vector> graph_outputs; + for (const auto& output : config.outputs) { + auto tensor_info = OrtTensorTypeAndShapeInfo::Create(); + tensor_info->SetElementType(output.elem_type); + tensor_info->SetDimensions(output.shape.data(), output.shape.size()); - OrtTypeInfo* type_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info)); - Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info); + auto value_info = OrtValueInfo::Create(output.name.c_str(), tensor_info.get()); + graph_outputs.push_back(std::move(value_info)); + } - OrtValueInfo* value_info = nullptr; - Ort::ThrowOnError(model_editor_api.CreateValueInfo(output.name.c_str(), type_info, &value_info)); - Ort::api->ReleaseTypeInfo(type_info); + // Set graph inputs and outputs (graph takes ownership of ValueInfos) + std::vector input_ptrs; + input_ptrs.reserve(graph_inputs.size()); + for (auto& vi : graph_inputs) { + input_ptrs.push_back(vi.get()); + } - graph_outputs.push_back(value_info); - } + std::vector output_ptrs; + output_ptrs.reserve(graph_outputs.size()); + for (auto& vi : graph_outputs) { + output_ptrs.push_back(vi.get()); + } - // Set graph inputs and outputs (graph takes ownership of ValueInfos) - Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); - Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + graph->SetInputs(input_ptrs.data(), input_ptrs.size()); + // Release ownership since graph took it + for (auto& vi : graph_inputs) vi.release(); + graph->SetOutputs(output_ptrs.data(), output_ptrs.size()); + for (auto& vi : graph_outputs) vi.release(); - // Create node attributes - for (const auto& attr : config.attributes) { - node_attributes.push_back(CreateOpAttr(attr)); - } + // Create node attributes + std::vector> node_attributes; + for (const auto& attr : config.attributes) { + node_attributes.push_back(CreateOpAttr(attr)); + } - // Create input/output name vectors - std::vector input_names; - for (const auto& input : config.inputs) { - input_names.push_back(input.name.c_str()); - } + // Extract raw pointers for CreateNode (which stores references, not copies) + std::vector node_attr_ptrs; + node_attr_ptrs.reserve(node_attributes.size()); + for (auto& attr : node_attributes) { + node_attr_ptrs.push_back(attr.get()); + } - std::vector output_names; - for (const auto& output : config.outputs) { - output_names.push_back(output.name.c_str()); - } + // Create input/output name vectors + std::vector input_names; + for (const auto& input : config.inputs) { + input_names.push_back(input.name.c_str()); + } - // Create node - OrtNode* node = nullptr; - Ort::ThrowOnError(model_editor_api.CreateNode( - config.op_type.c_str(), - "", // empty domain = ONNX domain - (config.op_type + "_node").c_str(), - input_names.data(), - input_names.size(), - output_names.data(), - output_names.size(), - node_attributes.empty() ? nullptr : node_attributes.data(), - node_attributes.size(), - &node)); - - // Add node to graph (graph takes ownership of node) - Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node)); - // Release node attributes - CreateNode made its own copy - for (auto* attr : node_attributes) { - Ort::api->ReleaseOpAttr(attr); - } - node_attributes.clear(); - // Create model with opset - const char* domain_name = ""; - Ort::ThrowOnError(model_editor_api.CreateModel(&domain_name, &config.opset_version, 1, &model)); + std::vector output_names; + for (const auto& output : config.outputs) { + output_names.push_back(output.name.c_str()); + } - // Add graph to model (model takes ownership of graph) - Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph)); - graph = nullptr; // model now owns graph + // Create node using RAII wrapper + auto node = OrtNode::Create( + config.op_type.c_str(), + "", // empty domain = ONNX domain + (config.op_type + "_node").c_str(), + input_names.data(), + input_names.size(), + output_names.data(), + output_names.size(), + node_attr_ptrs.empty() ? nullptr : node_attr_ptrs.data(), + node_attr_ptrs.size()); + + // Add node to graph (graph takes ownership of node) + graph->AddNode(node.get()); + for (auto& attr : node_attributes) { + attr.release(); + } + node.release(); - return model; + // Create model with opset using RAII wrapper + const char* domain_name = ""; + int opset = config.opset_version; + auto model = OrtModel::Create(&domain_name, &opset, 1); - } catch (...) { - // Clean up on error - for (auto* attr : node_attributes) { - Ort::api->ReleaseOpAttr(attr); - } - if (graph != nullptr) { - Ort::api->ReleaseGraph(graph); - } - if (model != nullptr) { - Ort::api->ReleaseModel(model); - } - throw; - } + // Add graph to model (model takes ownership of graph) + model->AddGraph(graph.get()); + // Release ownership - model now owns the entire graph structure + graph.release(); + + return model; } } // namespace GraphBuilder diff --git a/src/models/graph_builder.h b/src/models/graph_builder.h index 925d01017d..e8c5524e79 100644 --- a/src/models/graph_builder.h +++ b/src/models/graph_builder.h @@ -72,8 +72,7 @@ namespace GraphBuilder { // Build a complete ONNX model using the Model Editor API // Returns an OrtModel that can be used to create sessions -// Caller is responsible for calling Ort::api->ReleaseModel() when done -OrtModel* Build(const ModelConfig& config); +std::unique_ptr Build(const ModelConfig& config); } // namespace GraphBuilder diff --git a/src/models/graph_executor.cpp b/src/models/graph_executor.cpp index c742f3eb5b..f4a8e5738e 100644 --- a/src/models/graph_executor.cpp +++ b/src/models/graph_executor.cpp @@ -83,13 +83,10 @@ OrtSession* GetOrCreateSession( } // Build model using Model Editor API - OrtModel* model = GraphBuilder::Build(config); + auto model = GraphBuilder::Build(config); // Create session from model - auto session = CreateSession(model, ep_name, session_config_keys, session_config_values); - - // Release the model - session has its own copy - Ort::api->ReleaseModel(model); + auto session = CreateSession(model.get(), ep_name, session_config_keys, session_config_values); OrtSession* session_ptr = session.get(); cache.sessions_[key] = std::move(session); diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h index e9b8424679..116ca18ee9 100644 --- a/src/models/onnxruntime_api.h +++ b/src/models/onnxruntime_api.h @@ -798,6 +798,11 @@ struct OrtMemoryInfo { * */ struct OrtTensorTypeAndShapeInfo { + static std::unique_ptr Create(); + + void SetElementType(ONNXTensorElementDataType type); + void SetDimensions(const int64_t* dim_values, size_t dim_count); + ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount @@ -1316,6 +1321,53 @@ struct OrtOpAttr { Ort::Abstract make_abstract; }; +/// +/// This struct provides life time management for OrtGraph used in Model Editor API +/// +struct OrtGraph { + static std::unique_ptr Create(); + + void SetInputs(OrtValueInfo** inputs, size_t input_count); + void SetOutputs(OrtValueInfo** outputs, size_t output_count); + void AddNode(OrtNode* node); + + static void operator delete(void* p) { Ort::api->ReleaseGraph(reinterpret_cast(p)); } + Ort::Abstract make_abstract; +}; + +/// +/// This struct provides life time management for OrtModel used in Model Editor API +/// +struct OrtModel { + static std::unique_ptr Create(const char** domain_names, const int* opset_versions, size_t num_domains); + + void AddGraph(OrtGraph* graph); + + static void operator delete(void* p) { Ort::api->ReleaseModel(reinterpret_cast(p)); } + Ort::Abstract make_abstract; +}; + +/// +/// This struct provides life time management for OrtValueInfo used in Model Editor API +/// +struct OrtValueInfo { + static std::unique_ptr Create(const char* name, const OrtTensorTypeAndShapeInfo* tensor_info); + static void operator delete(void* p) { Ort::api->ReleaseValueInfo(reinterpret_cast(p)); } + Ort::Abstract make_abstract; +}; + +/// +/// This struct provides life time management for OrtNode used in Model Editor API +/// +struct OrtNode { + static std::unique_ptr Create(const char* op_type, const char* domain, const char* name, + const char** input_names, size_t num_inputs, + const char** output_names, size_t num_outputs, + OrtOpAttr** attributes, size_t num_attributes); + static void operator delete(void* p) { Ort::api->ReleaseNode(reinterpret_cast(p)); } + Ort::Abstract make_abstract; +}; + /// /// This class wraps a raw pointer OrtKernelContext* that is being passed /// to the custom kernel Compute() method. Use it to safely access context diff --git a/src/models/onnxruntime_inline.h b/src/models/onnxruntime_inline.h index 80fed5d631..7d4609725d 100644 --- a/src/models/onnxruntime_inline.h +++ b/src/models/onnxruntime_inline.h @@ -1338,6 +1338,76 @@ inline std::unique_ptr OrtOpAttr::Create(const char* name, const void return std::unique_ptr{p}; } +inline std::unique_ptr OrtGraph::Create() { + OrtGraph* p; + Ort::ThrowOnError(Ort::GetModelEditorApi().CreateGraph(&p)); + return std::unique_ptr{p}; +} + +inline void OrtGraph::SetInputs(OrtValueInfo** inputs, size_t input_count) { + Ort::ThrowOnError(Ort::GetModelEditorApi().SetGraphInputs(this, inputs, input_count)); +} + +inline void OrtGraph::SetOutputs(OrtValueInfo** outputs, size_t output_count) { + Ort::ThrowOnError(Ort::GetModelEditorApi().SetGraphOutputs(this, outputs, output_count)); +} + +inline void OrtGraph::AddNode(OrtNode* node) { + Ort::ThrowOnError(Ort::GetModelEditorApi().AddNodeToGraph(this, node)); +} + +inline std::unique_ptr OrtModel::Create(const char** domain_names, const int* opset_versions, size_t num_domains) { + OrtModel* p; + Ort::ThrowOnError(Ort::GetModelEditorApi().CreateModel(domain_names, opset_versions, num_domains, &p)); + return std::unique_ptr{p}; +} + +inline void OrtModel::AddGraph(OrtGraph* graph) { + Ort::ThrowOnError(Ort::GetModelEditorApi().AddGraphToModel(this, graph)); +} + +inline std::unique_ptr OrtValueInfo::Create(const char* name, const OrtTensorTypeAndShapeInfo* tensor_info) { + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtTypeInfo* type_info; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info)); + + OrtValueInfo* p; + auto status = model_editor_api.CreateValueInfo(name, type_info, &p); + Ort::api->ReleaseTypeInfo(type_info); + Ort::ThrowOnError(status); + + return std::unique_ptr{p}; +} + +inline std::unique_ptr OrtNode::Create(const char* op_type, const char* domain, const char* name, + const char** input_names, size_t num_inputs, + const char** output_names, size_t num_outputs, + OrtOpAttr** attributes, size_t num_attributes) { + OrtNode* p; + Ort::ThrowOnError(Ort::GetModelEditorApi().CreateNode( + op_type, domain, name, + input_names, num_inputs, + output_names, num_outputs, + attributes, num_attributes, + &p)); + return std::unique_ptr{p}; +} + +inline std::unique_ptr OrtTensorTypeAndShapeInfo::Create() { + OrtTensorTypeAndShapeInfo* p; + Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&p)); + return std::unique_ptr{p}; +} + +inline void OrtTensorTypeAndShapeInfo::SetElementType(ONNXTensorElementDataType type) { + Ort::ThrowOnError(Ort::api->SetTensorElementType(this, type)); +} + +inline void OrtTensorTypeAndShapeInfo::SetDimensions(const int64_t* dim_values, size_t dim_count) { + Ort::ThrowOnError(Ort::api->SetDimensions(this, dim_values, dim_count)); +} + inline std::unique_ptr OrtKernelInfo::Clone() const { OrtKernelInfo* p; Ort::ThrowOnError(Ort::api->CopyKernelInfo(this, &p));