Skip to content

Commit

Permalink
add if in python_api
Browse files Browse the repository at this point in the history
  • Loading branch information
evolosen committed Oct 11, 2021
1 parent ef33e30 commit fca6e5a
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 57 deletions.
53 changes: 52 additions & 1 deletion runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Factory functions for all ngraph ops."""
from functools import partial
from typing import Callable, Iterable, List, Optional, Set, Union
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple

import numpy as np
from ngraph.impl import Node, Shape
Expand Down Expand Up @@ -367,3 +367,54 @@ def random_uniform(
"op_seed": op_seed,
}
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)

@nameable_op
def if_op(
condition: NodeInput,
inputs: List[Node],
bodies: Tuple(GraphBody, GraphBody)
input_desc: Tuple(List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]),
output_desc: Tuple(List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]),
name: Optional[str] = None,
) -> Node:
"""Perform recurrent execution of the network described in the body, iterating through the data.
@param trip_count: A scalar or 1D tensor with 1 element specifying
maximum number of iterations.
@param execution_condition: A scalar or 1D tensor with 1 element
specifying whether to execute the first iteration or not.
@param inputs: The provided to TensorIterator operator.
@param graph_body: The graph representing the body we execute.
@param slice_input_desc: The descriptors describing sliced inputs, that is nodes
representing tensors we iterate through, processing single
data slice in one iteration.
@param merged_input_desc: The descriptors describing merged inputs, that is nodes
representing variables with initial value at first iteration,
which may be changing through iterations.
@param invariant_input_desc: The descriptors describing invariant inputs, that is nodes
representing variable with persistent value through all
iterations.
@param body_output_desc: The descriptors describing body outputs from specified
iteration.
@param concat_output_desc: The descriptors describing specified output values through
all the iterations concatenated into one node.
@param body_condition_output_idx: Determines the purpose of the corresponding result in
the graph_body. This result will determine the dynamic
exit condition. If the value of this result is False,
then iterations stop.
@param current_iteration_input_idx: Determines the purpose of the corresponding parameter
in the graph_body. This parameter will be used as
an iteration counter. Optional.
@return: The new node which performs Loop.
"""
attributes = {
"then_body": bodies[0].serialize(),
"else_body": bodies[1].serialize(),
"then_inputs": [desc.serialize() for desc in input_desc[0]],
"else_inputs": [desc.serialize() for desc in input_desc[1]],
"then_outputs": [desc.serialize() for desc in output_desc[0]],
"else_outputs": [desc.serialize() for desc in output_desc[1]]
}
return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
attributes)

Original file line number Diff line number Diff line change
Expand Up @@ -24,69 +24,96 @@ util::DictAttributeDeserializer::DictAttributeDeserializer(
void util::DictAttributeDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
if (m_attributes.contains(name)) {
if (const auto& a = ngraph::as_type<
ngraph::AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>>(
ngraph::AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>>>(
&adapter)) {
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> input_descs;
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
for (py::handle h : slice_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
}
std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>> input_descs;

for (py::handle h : merged_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
}
if (name == "input_descriptions") {
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
for (py::handle h : slice_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
}

for (py::handle h : invariant_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
for (py::handle h : merged_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
}

for (py::handle h : invariant_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
}
} else if (name == "then_inputs" || name == "else_inputs") {
const py::list& input_desc = m_attributes[name.c_str()].cast<py::list>();
for (py::handle h : input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ngraph::op::util::MultiSubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
}
} else {
NGRAPH_CHECK(false, "Input descriptions is not supported with name ", name);
}
a->set(input_descs);
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> output_descs;
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
for (py::handle h : body_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
}
std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>> output_descs;
if (name == "output_descriptions") {
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
for (py::handle h : body_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
}

for (py::handle h : concat_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
for (py::handle h : concat_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
}
} else if (name == "then_outputs" || name == "else_outputs") {
const py::list& output_desc = m_attributes[name.c_str()].cast<py::list>();
for (py::handle h : output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ngraph::op::util::MultiSubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>());
output_descs.push_back(body_output);
}
} else {
NGRAPH_CHECK(false, "Output descriptions is not supported with name ", name);
}
a->set(output_descs);
} else if (const auto& a =
Expand Down Expand Up @@ -241,7 +268,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
if (m_attributes.contains(name)) {
if (name == "body") {
if (name == "body" || name == "then_body" || name == "else_body") {
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
Expand Down

0 comments on commit fca6e5a

Please sign in to comment.