diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6eb15280a4aa4..5dcf62f25d221 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6469,6 +6469,17 @@ struct OrtApi { _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + + /** \brief Get ::OrtModelMetadata from an ::OrtGraph + * + * \param[in] graph The OrtGraph instance. + * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..067ef048aebb3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2834,6 +2834,7 @@ struct GraphImpl : Ort::detail::Base { void SetOutputs(std::vector& outputs); void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value void AddNode(Node& node); // Graph takes ownership of Node + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::Graph_GetModelMetadata #endif // !defined(ORT_MINIMAL_BUILD) }; } // namespace detail @@ -2848,6 +2849,7 @@ struct Graph : detail::GraphImpl { Graph(); #endif }; +using ConstGraph = detail::GraphImpl>; namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 705f17c5d6f43..539f43098b19a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2798,6 +2798,13 @@ inline void GraphImpl::AddNode(Node& node) { ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); } +template +inline ModelMetadata GraphImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + template <> inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 504b102e782fd..b99c22edb36c8 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -10,6 +10,7 @@ #include "core/framework/tensor_external_data_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" #define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \ external_type* ToExternal() { return static_cast(this); } \ @@ -301,6 +302,11 @@ struct OrtGraph { /// The graph's name. virtual const std::string& GetName() const = 0; + /// + /// Returns the model's metadata. + /// + /// The model metadata. + virtual std::unique_ptr GetModelMetadata() const = 0; /// /// Returns the model's path, which is empty if unknown. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index eb7fb6937c29e..759a2998ace3a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -20,6 +20,7 @@ #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/graph_viewer.h" #include "core/graph/graph.h" +#include "core/graph/model.h" namespace onnxruntime { @@ -769,6 +770,25 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } +std::unique_ptr EpGraph::GetModelMetadata() const { +#if !defined(ORT_MINIMAL_BUILD) + const auto& model = graph_viewer_.GetGraph().GetModel(); + auto model_metadata = std::make_unique(); + + model_metadata->producer_name = model.ProducerName(); + model_metadata->producer_version = model.ProducerVersion(); + model_metadata->description = model.DocString(); + model_metadata->graph_description = model.GraphDocString(); + model_metadata->domain = model.Domain(); + model_metadata->version = model.ModelVersion(); + model_metadata->custom_metadata_map = model.MetaData(); + model_metadata->graph_name = model.MainGraph().Name(); + return model_metadata; +#else + return nullptr; +#endif +} + const ORTCHAR_T* EpGraph::GetModelPath() const { return graph_viewer_.ModelPath().c_str(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index be78d77360cb8..7f22e265129f7 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -298,6 +298,9 @@ struct EpGraph : public OrtGraph { // Returns the graph's name. const std::string& GetName() const override; + // Returns the graph's metadata + std::unique_ptr GetModelMetadata() const override; + // Returns the model path. const ORTCHAR_T* GetModelPath() const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index d3795d911b22f..e7ffcbc7e4c90 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -13,6 +13,7 @@ #include "core/framework/ort_value.h" #include "core/graph/abi_graph_types.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" namespace onnxruntime { @@ -184,6 +185,9 @@ struct ModelEditorGraph : public OrtGraph { const std::string& GetName() const override { return name; } + std::unique_ptr GetModelMetadata() const override { + return std::make_unique(model_metadata); + } const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); } int64_t GetOnnxIRVersion() const override { @@ -241,6 +245,7 @@ struct ModelEditorGraph : public OrtGraph { std::vector> nodes; std::string name = "ModelEditorGraph"; std::filesystem::path model_path; + ModelMetadata model_metadata; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 88d84e95b406c..3a5bf196117d8 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2626,6 +2626,16 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + *out = reinterpret_cast(graph->GetModelMetadata().release()); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) { API_IMPL_BEGIN if (model_path == nullptr) { @@ -4095,6 +4105,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSyncStream, &OrtApis::CopyTensors, + + &OrtApis::Graph_GetModelMetadata, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3eee174ff81f4..c9cb32732acfe 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -635,6 +635,7 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); +ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 188edad572182..513097aaf7ade 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -914,7 +914,22 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ const ORTCHAR_T* api_model_path = nullptr; ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path)); ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str())); - + // Check the model metadata + Ort::AllocatorWithDefaultOptions default_allocator; + auto ort_cxx_graph = Ort::ConstGraph(&api_graph); + auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata(); + auto& model = graph_viewer.GetGraph().GetModel(); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphNameAllocated(default_allocator).get(), model.MainGraph().Name().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDomainAllocated(default_allocator).get(), model.Domain().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDescriptionAllocated(default_allocator).get(), model.DocString().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphDescriptionAllocated(default_allocator).get(), model.GraphDocString().c_str()), 0); + ASSERT_EQ(ort_cxx_model_metadat.GetVersion(), model.ModelVersion()); + auto model_meta_data = model.MetaData(); + for (auto& [k, v] : model_meta_data) { + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.LookupCustomMetadataMapAllocated(k.c_str(), default_allocator).get(), v.c_str()), 0) + << " key=" << k << "; value=" << v; + } // Check graph inputs. const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers();