diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5c9c1a0ae163f..9a0bcb53c9ad7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1011,6 +1011,8 @@ struct ProviderHost { virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; // We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0; + virtual bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool check_outer_scope) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, NodeAttributes&& attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0; @@ -1074,6 +1076,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::TensorProto* GraphViewer__GetConstantInitializer(const GraphViewer* p, const std::string& name, bool check_outer_scope) const = 0; + virtual bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) = 0; virtual const Node* GraphViewer__ParentNode(const GraphViewer* p) = 0; virtual int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept = 0; virtual int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 23fbead1e9707..19b4636c3766d 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1041,6 +1041,10 @@ struct Graph final { Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) { return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value, + bool check_outer_scope = false) const { + return g_host->Graph__GetOrtValueInitializer(this, tensor_name, ort_value, check_outer_scope); + } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); } Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); } @@ -1124,6 +1128,9 @@ class GraphViewer final { bool check_outer_scope = true) const { return g_host->GraphViewer__GetConstantInitializer(this, name, check_outer_scope); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value) const { + return g_host->GraphViewer__GetOrtValueInitializer(this, tensor_name, ort_value); + } const Node* ParentNode() const { return g_host->GraphViewer__ParentNode(this); } int NumberOfNodes() const noexcept { return g_host->GraphViewer__NumberOfNodes(this); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ee59ff2ab4932..41cf8be1d1412 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1258,6 +1258,10 @@ struct ProviderHostImpl : ProviderHost { void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); } Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); } + bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool check_outer_scope) override { + return p->GetOrtValueInitializer(tensor_name, value, check_outer_scope); + } Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) override { return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain); } @@ -1356,6 +1360,10 @@ struct ProviderHostImpl : ProviderHost { bool check_outer_scope) const override { return p->GetConstantInitializer(name, check_outer_scope); } + bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) override { + return p->GetOrtValueInitializer(tensor_name, value); + } const Node* GraphViewer__ParentNode(const GraphViewer* p) override { return p->ParentNode(); } int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept override { return p->NumberOfNodes(); } int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept override { return p->MaxNodeIndex(); }