Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 103 additions & 128 deletions src/models/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,167 +58,142 @@ AttributeValue AttributeValue::Strings(const std::string& name, const std::vecto
namespace {

// Helper to create OrtOpAttr from AttributeValue using Model Editor API
OrtOpAttr* CreateOpAttr(const AttributeValue& attr) {
OrtOpAttr* op_attr = nullptr;

std::unique_ptr<OrtOpAttr> CreateOpAttr(const AttributeValue& attr) {
switch (attr.type) {
case AttributeType::INT:
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), &attr.int_value, 1,
OrtOpAttrType::ORT_OP_ATTR_INT, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), &attr.int_value, 1, OrtOpAttrType::ORT_OP_ATTR_INT);
case AttributeType::FLOAT:
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), &attr.float_value, 1,
OrtOpAttrType::ORT_OP_ATTR_FLOAT, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), &attr.float_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT);
case AttributeType::STRING:
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.string_value.c_str(),
static_cast<int>(attr.string_value.size()),
OrtOpAttrType::ORT_OP_ATTR_STRING, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), attr.string_value.c_str(),
static_cast<int>(attr.string_value.size()),
OrtOpAttrType::ORT_OP_ATTR_STRING);
case AttributeType::INTS:
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.ints_value.data(),
static_cast<int>(attr.ints_value.size()),
OrtOpAttrType::ORT_OP_ATTR_INTS, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), attr.ints_value.data(),
static_cast<int>(attr.ints_value.size()),
OrtOpAttrType::ORT_OP_ATTR_INTS);
case AttributeType::FLOATS:
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), attr.floats_value.data(),
static_cast<int>(attr.floats_value.size()),
OrtOpAttrType::ORT_OP_ATTR_FLOATS, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), attr.floats_value.data(),
static_cast<int>(attr.floats_value.size()),
OrtOpAttrType::ORT_OP_ATTR_FLOATS);
case AttributeType::STRINGS: {
std::vector<const char*> string_ptrs;
string_ptrs.reserve(attr.strings_value.size());
for (const auto& str : attr.strings_value) {
string_ptrs.push_back(str.c_str());
}
Ort::ThrowOnError(Ort::api->CreateOpAttr(attr.name.c_str(), string_ptrs.data(),
static_cast<int>(string_ptrs.size()),
OrtOpAttrType::ORT_OP_ATTR_STRINGS, &op_attr));
break;
return OrtOpAttr::Create(attr.name.c_str(), string_ptrs.data(),
static_cast<int>(string_ptrs.size()),
OrtOpAttrType::ORT_OP_ATTR_STRINGS);
}
}

return op_attr;
throw std::runtime_error("CreateOpAttr: Unhandled attribute type");
}

} // anonymous namespace

namespace GraphBuilder {

// Build complete ONNX model using the Model Editor API
OrtModel* Build(const ModelConfig& config) {
const auto& model_editor_api = Ort::GetModelEditorApi();

OrtGraph* graph = nullptr;
OrtModel* model = nullptr;
std::vector<OrtOpAttr*> node_attributes;

try {
// Create graph
Ort::ThrowOnError(model_editor_api.CreateGraph(&graph));

// Create input ValueInfos
std::vector<OrtValueInfo*> graph_inputs;
for (const auto& input : config.inputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, input.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, input.shape.data(), input.shape.size()));

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(input.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);

graph_inputs.push_back(value_info);
}
std::unique_ptr<OrtModel> Build(const ModelConfig& config) {
// Create graph using RAII wrapper
auto graph = OrtGraph::Create();

// Create input ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_inputs;
for (const auto& input : config.inputs) {
auto tensor_info = OrtTensorTypeAndShapeInfo::Create();
tensor_info->SetElementType(input.elem_type);
tensor_info->SetDimensions(input.shape.data(), input.shape.size());

auto value_info = OrtValueInfo::Create(input.name.c_str(), tensor_info.get());
graph_inputs.push_back(std::move(value_info));
}

// Create output ValueInfos
std::vector<OrtValueInfo*> graph_outputs;
for (const auto& output : config.outputs) {
OrtTensorTypeAndShapeInfo* tensor_info = nullptr;
Ort::ThrowOnError(Ort::api->CreateTensorTypeAndShapeInfo(&tensor_info));
Ort::ThrowOnError(Ort::api->SetTensorElementType(tensor_info, output.elem_type));
Ort::ThrowOnError(Ort::api->SetDimensions(tensor_info, output.shape.data(), output.shape.size()));
// Create output ValueInfos
std::vector<std::unique_ptr<OrtValueInfo>> graph_outputs;
for (const auto& output : config.outputs) {
auto tensor_info = OrtTensorTypeAndShapeInfo::Create();
tensor_info->SetElementType(output.elem_type);
tensor_info->SetDimensions(output.shape.data(), output.shape.size());

OrtTypeInfo* type_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_info, &type_info));
Ort::api->ReleaseTensorTypeAndShapeInfo(tensor_info);
auto value_info = OrtValueInfo::Create(output.name.c_str(), tensor_info.get());
graph_outputs.push_back(std::move(value_info));
}

OrtValueInfo* value_info = nullptr;
Ort::ThrowOnError(model_editor_api.CreateValueInfo(output.name.c_str(), type_info, &value_info));
Ort::api->ReleaseTypeInfo(type_info);
// Set graph inputs and outputs (graph takes ownership of ValueInfos)
std::vector<OrtValueInfo*> input_ptrs;
input_ptrs.reserve(graph_inputs.size());
for (auto& vi : graph_inputs) {
input_ptrs.push_back(vi.get());
}

graph_outputs.push_back(value_info);
}
std::vector<OrtValueInfo*> output_ptrs;
output_ptrs.reserve(graph_outputs.size());
for (auto& vi : graph_outputs) {
output_ptrs.push_back(vi.get());
}

// Set graph inputs and outputs (graph takes ownership of ValueInfos)
Ort::ThrowOnError(model_editor_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size()));
Ort::ThrowOnError(model_editor_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size()));
graph->SetInputs(input_ptrs.data(), input_ptrs.size());
// Release ownership since graph took it
for (auto& vi : graph_inputs) vi.release();
Comment thread
baijumeswani marked this conversation as resolved.
graph->SetOutputs(output_ptrs.data(), output_ptrs.size());
for (auto& vi : graph_outputs) vi.release();

// Create node attributes
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}
// Create node attributes
std::vector<std::unique_ptr<OrtOpAttr>> node_attributes;
for (const auto& attr : config.attributes) {
node_attributes.push_back(CreateOpAttr(attr));
}
Comment thread
qjia7 marked this conversation as resolved.
Comment thread
qjia7 marked this conversation as resolved.

// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}
// Extract raw pointers for CreateNode (which stores references, not copies)
std::vector<OrtOpAttr*> node_attr_ptrs;
node_attr_ptrs.reserve(node_attributes.size());
for (auto& attr : node_attributes) {
node_attr_ptrs.push_back(attr.get());
}

std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}
// Create input/output name vectors
std::vector<const char*> input_names;
for (const auto& input : config.inputs) {
input_names.push_back(input.name.c_str());
}

// Create node
OrtNode* node = nullptr;
Ort::ThrowOnError(model_editor_api.CreateNode(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attributes.empty() ? nullptr : node_attributes.data(),
node_attributes.size(),
&node));

// Add node to graph (graph takes ownership of node)
Ort::ThrowOnError(model_editor_api.AddNodeToGraph(graph, node));
// Release node attributes - CreateNode made its own copy
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
node_attributes.clear();
// Create model with opset
const char* domain_name = "";
Ort::ThrowOnError(model_editor_api.CreateModel(&domain_name, &config.opset_version, 1, &model));
std::vector<const char*> output_names;
for (const auto& output : config.outputs) {
output_names.push_back(output.name.c_str());
}

// Add graph to model (model takes ownership of graph)
Ort::ThrowOnError(model_editor_api.AddGraphToModel(model, graph));
graph = nullptr; // model now owns graph
// Create node using RAII wrapper
auto node = OrtNode::Create(
config.op_type.c_str(),
"", // empty domain = ONNX domain
(config.op_type + "_node").c_str(),
input_names.data(),
input_names.size(),
output_names.data(),
output_names.size(),
node_attr_ptrs.empty() ? nullptr : node_attr_ptrs.data(),
node_attr_ptrs.size());

// Add node to graph (graph takes ownership of node)
graph->AddNode(node.get());
for (auto& attr : node_attributes) {
attr.release();
}
node.release();

return model;
// Create model with opset using RAII wrapper
const char* domain_name = "";
int opset = config.opset_version;
auto model = OrtModel::Create(&domain_name, &opset, 1);

} catch (...) {
// Clean up on error
for (auto* attr : node_attributes) {
Ort::api->ReleaseOpAttr(attr);
}
if (graph != nullptr) {
Ort::api->ReleaseGraph(graph);
}
if (model != nullptr) {
Ort::api->ReleaseModel(model);
}
throw;
}
// Add graph to model (model takes ownership of graph)
model->AddGraph(graph.get());
// Release ownership - model now owns the entire graph structure
graph.release();

return model;
}

} // namespace GraphBuilder
Expand Down
3 changes: 1 addition & 2 deletions src/models/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ namespace GraphBuilder {

// Build a complete ONNX model using the Model Editor API
// Returns an OrtModel that can be used to create sessions
// Caller is responsible for calling Ort::api->ReleaseModel() when done
OrtModel* Build(const ModelConfig& config);
std::unique_ptr<OrtModel> Build(const ModelConfig& config);

} // namespace GraphBuilder

Expand Down
7 changes: 2 additions & 5 deletions src/models/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,10 @@ OrtSession* GetOrCreateSession(
}

// Build model using Model Editor API
OrtModel* model = GraphBuilder::Build(config);
auto model = GraphBuilder::Build(config);

// Create session from model
auto session = CreateSession(model, ep_name, session_config_keys, session_config_values);

// Release the model - session has its own copy
Ort::api->ReleaseModel(model);
auto session = CreateSession(model.get(), ep_name, session_config_keys, session_config_values);

OrtSession* session_ptr = session.get();
cache.sessions_[key] = std::move(session);
Expand Down
52 changes: 52 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,11 @@ struct OrtMemoryInfo {
*
*/
struct OrtTensorTypeAndShapeInfo {
static std::unique_ptr<OrtTensorTypeAndShapeInfo> Create();

void SetElementType(ONNXTensorElementDataType type);
void SetDimensions(const int64_t* dim_values, size_t dim_count);

ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount

Expand Down Expand Up @@ -1316,6 +1321,53 @@ struct OrtOpAttr {
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtGraph used in Model Editor API
/// </summary>
struct OrtGraph {
static std::unique_ptr<OrtGraph> Create();

void SetInputs(OrtValueInfo** inputs, size_t input_count);
void SetOutputs(OrtValueInfo** outputs, size_t output_count);
void AddNode(OrtNode* node);

static void operator delete(void* p) { Ort::api->ReleaseGraph(reinterpret_cast<OrtGraph*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtModel used in Model Editor API
/// </summary>
struct OrtModel {
static std::unique_ptr<OrtModel> Create(const char** domain_names, const int* opset_versions, size_t num_domains);

void AddGraph(OrtGraph* graph);

static void operator delete(void* p) { Ort::api->ReleaseModel(reinterpret_cast<OrtModel*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtValueInfo used in Model Editor API
/// </summary>
struct OrtValueInfo {
static std::unique_ptr<OrtValueInfo> Create(const char* name, const OrtTensorTypeAndShapeInfo* tensor_info);
static void operator delete(void* p) { Ort::api->ReleaseValueInfo(reinterpret_cast<OrtValueInfo*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This struct provides life time management for OrtNode used in Model Editor API
/// </summary>
struct OrtNode {
static std::unique_ptr<OrtNode> Create(const char* op_type, const char* domain, const char* name,
const char** input_names, size_t num_inputs,
const char** output_names, size_t num_outputs,
OrtOpAttr** attributes, size_t num_attributes);
static void operator delete(void* p) { Ort::api->ReleaseNode(reinterpret_cast<OrtNode*>(p)); }
Ort::Abstract make_abstract;
};

/// <summary>
/// This class wraps a raw pointer OrtKernelContext* that is being passed
/// to the custom kernel Compute() method. Use it to safely access context
Expand Down
Loading
Loading