Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,8 @@ class Node {
@remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */
const ONNX_NAMESPACE::OpSchema* Op() const noexcept { return op_; }

Status InstantiateFunctionBody();

Status GetInstantiateFunctionBody(std::unique_ptr<Function>& output) const;
/** Create a copy of the called op's FunctionProto if it has one. Returns true if successful. */
bool TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& func_proto) const;

bool CanBeInlined() const;

Expand Down Expand Up @@ -1289,9 +1288,9 @@ class Graph {
*/
Graph(Graph& parent_graph, const Node& parent_node, ONNX_NAMESPACE::GraphProto& subgraph_proto);

Graph(const Model& owning_model,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
ONNX_NAMESPACE::GraphProto& subgraph_proto,
Graph(const Model& owning_model,
IOnnxRuntimeOpSchemaCollectionPtr schema_registry,
ONNX_NAMESPACE::GraphProto& subgraph_proto,
const std::unordered_map<std::string, int>& domain_version_map,
const logging::Logger& logger,
bool strict_shape_type_inference);
Expand Down Expand Up @@ -1571,7 +1570,7 @@ class Graph {
IOnnxRuntimeOpSchemaCollectionPtr schema_registry_;

//Currently to make the ORT in-memory graph work, we have to create a temporary op schema
//for the fused kernel. I really don't like it. but for short-term solution, let's host
//for the fused kernel. I really don't like it. but for short-term solution, let's host
//those schemas here.
InlinedVector<std::unique_ptr<ONNX_NAMESPACE::OpSchema>> fused_schemas_containers_;
#endif // !defined(ORT_MINIMAL_BUILD)
Expand Down
405 changes: 138 additions & 267 deletions onnxruntime/core/graph/function_utils.cc

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions onnxruntime/core/graph/function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <string>

#include "core/common/common.h"
#include "onnx/onnx_pb.h"
#include "core/graph/graph.h"
Expand Down Expand Up @@ -38,7 +40,7 @@ std::unique_ptr<ONNX_NAMESPACE::OpSchema> CreateSchema(const std::string& functi
const logging::Logger& logger,
bool allow_released_opsets_only);

/** Get the unique id for function. This is used as a key to find the
/** Get the unique id for function. This is used as a key to find the
* relevant model-local function from its container.
* @param function_domain Domain for the function.
* @param function_name Name of the function. Name should match the OpType of the node which references the function.
inline std::string GetFunctionIdentifier(std::string_view function_domain, std::string_view function_name) {
  // NOTE: std::string_view::data() is NOT guaranteed to be NUL-terminated, so the previous
  // `function_domain.data() + std::string(":") + function_name.data()` could read past the end
  // of either view. Build the result from the views' explicit (pointer, length) ranges instead.
  std::string identifier;
  identifier.reserve(function_domain.size() + 1 + function_name.size());
  identifier.append(function_domain).append(1, ':').append(function_name);
  return identifier;
}

Status Instantiate(onnxruntime::Graph& graph,
const onnxruntime::NodeIndex node_index,
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto,
std::unique_ptr<Function>& output);
void Specialize(ONNX_NAMESPACE::FunctionProto& called_function, const ONNX_NAMESPACE::NodeProto calling_node,
const onnxruntime::NodeAttributes& attr_map, std::string unique_prefix);

void Specialize(ONNX_NAMESPACE::FunctionProto& called_function, Node& calling_node, std::string unique_prefix);

}

Expand Down
213 changes: 104 additions & 109 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,12 @@ bool Node::CanBeInlined() const {
return func_body_ || func_template_ || op_ && (op_->HasFunction() || op_->HasContextDependentFunction());
}

Status Node::GetInstantiateFunctionBody(std::unique_ptr<Function>& output) const {
// Initialize function body
bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_proto) const {
if (func_template_) {
return function_utils::Instantiate(*graph_, index_, *func_template_->onnx_func_proto_, output);
} else if (op_ && (op_->HasFunction() || op_->HasContextDependentFunction())) {
// This node has a schema defined function proto. If it is a context dependent function
// then build it otherwise fetch the FunctionProto from schema.
ONNX_NAMESPACE::FunctionProto onnx_function_proto;
onnx_function_proto = *func_template_->onnx_func_proto_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy probably is fine as there is no initializers in function proto. but do we really need a copy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I thought about it earlier. But decided this was fine for current use, because we end up changing the names/attributes when inlining. So, we end up with a copy anyway. (The current implementation reduces the number of intermediate-representations used during inlining from two to one. Reducing the one to zero is possible, but would complicate the code, so settled on this.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see.

return true;
} else if (op_) {
// Check if this node has a schema defined function proto.
if (op_->HasContextDependentFunction()) {
NodeProto node_proto;
ToProto(node_proto);
Expand All @@ -582,27 +580,13 @@ Status Node::GetInstantiateFunctionBody(std::unique_ptr<Function>& output) const
input_types.emplace_back();
}
ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
if (!op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto)) {
// I don't know why, but the existing behavior is to ignore the failure here.
// Keep the same behavior.
return Status::OK();
}
} else {
return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
} else if (op_->HasFunction()) {
onnx_function_proto = *(op_->GetFunction());
return true;
}
return function_utils::Instantiate(*graph_, index_, onnx_function_proto, output);
} else {
return Status::OK();
}
}

Status Node::InstantiateFunctionBody() {
if (nullptr != func_body_) {
// already instantiated.
return Status::OK();
}

return GetInstantiateFunctionBody(func_body_);
return false;
}

void Node::SetFunctionTemplate(const FunctionTemplate& func_template) {
Expand Down Expand Up @@ -3989,113 +3973,124 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph,
return fused_node;
}

Status Graph::InlineFunction(Node& node) {
// Remove the function node, add the nodes in function's subgraph into the
// main graph.
if (!node.GetFunctionBody())
ORT_RETURN_IF_ERROR(node.InstantiateFunctionBody());
const Graph& subgraph = node.GetFunctionBody()->Body();
auto output_edges = node.GetRelationships().output_edges;
Status Graph::InlineFunction(Node& callnode) {
const auto& model_path = ModelPath();
auto output_edges = callnode.GetRelationships().output_edges;
for (const auto& output_edge : output_edges) {
RemoveEdge(node.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex());
RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex());
}

// Map of function input outputs to nodes input/outputs
std::unordered_map<std::string, NodeArg*> remap_input_output;
// Set of node input output names as these names need to be preserved during inlining
std::unordered_set<std::string> func_input_output_names;
// create a uniq_identifier to append to every node name and intermediate input/outputs
// to make sure there are no unintended duplicates
std::stringstream ss;
ss << "_" << static_cast<const void*>(&callnode) << "_";
auto uniq_identifier = ss.str();

// Replace a (function-call) node by an inlined graph.
if (!callnode.GetFunctionBody()) {
// This is the normal use-case: inlining a FunctionProto (representing
// a model-local function or a schema-defined function).
FunctionProto inlined_fp;
ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined.");
function_utils::Specialize(inlined_fp, callnode, uniq_identifier);

auto to_node_arg = [this](const std::string& name) {
return &this->GetOrCreateNodeArg(name, nullptr);
};

for (const auto& inlined_node : inlined_fp.node()) {
if (inlined_node.op_type() == kConstant) {
// Copy constant nodes _value to name_to_initial_tensor_
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(inlined_node, model_path, *tensor, inlined_node.output(0)));
name_to_initial_tensor_[tensor->name()] = tensor;
} else {
InlinedVector<onnxruntime::NodeArg*> inputs;
InlinedVector<onnxruntime::NodeArg*> outputs;

for (const auto& tensor_name : inlined_node.input())
inputs.push_back(to_node_arg(tensor_name));

for (size_t i = 0; i < subgraph.GetInputsIncludingInitializers().size(); ++i) {
auto* input = subgraph.GetInputsIncludingInitializers()[i];
if (input->Name() != node.MutableInputDefs()[i]->Name()) {
remap_input_output[input->Name()] = node.MutableInputDefs()[i];
for (const auto& tensor_name : inlined_node.output())
outputs.push_back(to_node_arg(tensor_name));

onnxruntime::NodeAttributes new_attr_map;
new_attr_map.reserve(inlined_node.attribute_size());
for (const auto& node_attr : inlined_node.attribute()) {
onnx::AttributeProto attr_copy = node_attr;
new_attr_map[node_attr.name()] = std::move(attr_copy);
}
AddNode(inlined_node.name(), inlined_node.op_type(),
inlined_node.doc_string(), inputs, outputs, &new_attr_map, inlined_node.domain());
}
}
func_input_output_names.insert(input->Name());
}

for (size_t i = 0; i < subgraph.GetOutputs().size(); ++i) {
auto* output = subgraph.GetOutputs()[i];
if (output->Name() != node.MutableOutputDefs()[i]->Name()) {
remap_input_output[output->Name()] = node.MutableOutputDefs()[i];
} else {
// Uncommon scenario. Inlining a node representing a fused sub-graph.
// TODO: Unclear that this feature is needed. Can this be removed?
const Graph& subgraph = callnode.GetFunctionBody()->Body();

// Map of function input outputs to nodes input/outputs
std::unordered_map<std::string, NodeArg*> remap_input_output;
// Set of node input output names as these names need to be preserved during inlining
std::unordered_set<std::string> func_input_output_names;

for (size_t i = 0; i < subgraph.GetInputsIncludingInitializers().size(); ++i) {
auto* input = subgraph.GetInputsIncludingInitializers()[i];
if (input->Name() != callnode.MutableInputDefs()[i]->Name()) {
remap_input_output[input->Name()] = callnode.MutableInputDefs()[i];
}
func_input_output_names.insert(input->Name());
}
func_input_output_names.insert(output->Name());
}

// create a uniq_identifier to append to every node name and intermediate input/outputs
// to make sure there are no unintended duplicates
std::stringstream ss;
ss << static_cast<const void*>(&node);
auto uniq_identifier = ss.str();
for (size_t i = 0; i < subgraph.GetOutputs().size(); ++i) {
auto* output = subgraph.GetOutputs()[i];
if (output->Name() != callnode.MutableOutputDefs()[i]->Name()) {
remap_input_output[output->Name()] = callnode.MutableOutputDefs()[i];
}
func_input_output_names.insert(output->Name());
}

const auto& model_path = ModelPath();
for (auto& init : subgraph.name_to_initial_tensor_) {
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
*tensor = *init.second;
tensor->set_name(tensor->name() + uniq_identifier);
name_to_initial_tensor_[tensor->name()] = tensor;
}
for (const auto& subgraph_node : subgraph.Nodes()) {
if (subgraph_node.OpType() == kConstant) {
// Copy constant nodes _value to name_to_initial_tensor_
ONNX_NAMESPACE::NodeProto subgraph_node_proto{};
subgraph_node.ToProto(subgraph_node_proto);
for (auto& init : subgraph.name_to_initial_tensor_) {
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0) + uniq_identifier));
*tensor = *init.second;
tensor->set_name(tensor->name() + uniq_identifier);
name_to_initial_tensor_[tensor->name()] = tensor;
} else {
std::vector<NodeArg*> inputs, outputs;
for (auto* input : subgraph_node.InputDefs()) {
if (input->Name().empty()) {
// This is a missing (optional) input. No need to rename.
}
for (const auto& subgraph_node : subgraph.Nodes()) {
if (subgraph_node.OpType() == kConstant) {
// Copy constant nodes _value to name_to_initial_tensor_
ONNX_NAMESPACE::NodeProto subgraph_node_proto{};
subgraph_node.ToProto(subgraph_node_proto);
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0)));
name_to_initial_tensor_[tensor->name()] = tensor;
} else {
std::vector<NodeArg*> inputs, outputs;
for (auto* input : subgraph_node.InputDefs()) {
auto& n_input = GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&n_input);
} else if (func_input_output_names.find(input->Name()) != func_input_output_names.end()) {
auto it = remap_input_output.find(input->Name());
if (it != remap_input_output.end()) {
// This is a function input/output and needs to be remapped to node input for correctness
inputs.push_back(it->second);
} else {
// This is a function input/output so preserve the existing name
auto& n_input = GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&n_input);
}
} else {
// This is an intermediate input. Add a unique identifier as suffix to make sure
// there is no name collision with names in parent graph
auto& n_input = GetOrCreateNodeArg(input->Name() + uniq_identifier, input->TypeAsProto());
inputs.push_back(&n_input);
}
}
for (auto* output : subgraph_node.OutputDefs()) {
if (output->Name().empty()) {
// Create empty arg (no renaming) for missing optional-outputs
for (auto* output : subgraph_node.OutputDefs()) {
auto& n_output = GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&n_output);
} else if (func_input_output_names.find(output->Name()) != func_input_output_names.end()) {
auto it = remap_input_output.find(output->Name());
if (it != remap_input_output.end()) {
outputs.push_back(it->second);
} else {
auto& n_output = GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&n_output);
}
} else {
auto& n_output = GetOrCreateNodeArg(output->Name() + uniq_identifier, output->TypeAsProto());
outputs.push_back(&n_output);
}
}

AddNode(subgraph_node.Name() + uniq_identifier, subgraph_node.OpType(), subgraph_node.Description(),
inputs,
outputs,
&subgraph_node.GetAttributes(),
subgraph_node.Domain());
AddNode(subgraph_node.Name() + uniq_identifier, subgraph_node.OpType(), subgraph_node.Description(),
inputs,
outputs,
&subgraph_node.GetAttributes(),
subgraph_node.Domain());
}
}
}

RemoveNode(node.Index());
RemoveNode(callnode.Index());

// std::cout << "Graph after inlining\n\n" << *this << std::endl << std::flush;

ORT_RETURN_IF_ERROR(this->Resolve());

return Status::OK();
}

Expand Down
Loading