Skip to content

Commit

Permalink
Fix ElementwiseInputReshape transformation
Browse files Browse the repository at this point in the history
Reshape node always needs to be inserted
in order to preserve ShapeOf nodes (reshapability of a model) that can potentially be above
elementwise node.

Refactor EltwiseInputReshape_test and EltwiseInputNormalization_test since the logic of maintaining reshape for eltwise has been changed.

Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants committed Oct 26, 2020
1 parent 166ab89 commit 4f0479f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 83 deletions.
56 changes: 36 additions & 20 deletions model-optimizer/extensions/middle/EltwiseInputNormalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
nodes_attributes = {
# Placeholder layers
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_2': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_3': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
Expand Down Expand Up @@ -131,16 +133,19 @@ def test_mega_hardcore(self):
#
# REFERENCE GRAPH AFTER TRANSFORMATION
#
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(1,1,64,1)---'--------------------------------'-------------------------------'
# /
# data4(64,1)-------, Reshape(1,1,64,1)
# \/ |
# data3(64,1)------`---->Eltwise3->data(64,1)---'
# data1(1,3,64,64)---------------------,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(64,1)-,- Reshape1(1,1,64,64)--'--------------------------------o-------------------------------'
# | |
# | Reshape(1,1,64,1)
# \/ |
# data3(64,1)----------->Eltwise3->data(64,1)--------------------------'
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'eltwise_1'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
Expand All @@ -163,32 +168,43 @@ def test_mega_hardcore(self):
}, nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
('placeholder_4_data', 'eltwise_3'),
('placeholder_2_data', 'eltwise_3'),
('placeholder_3_data', 'eltwise_3'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_3_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_2'),
('eltwise_3_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'eltwise_2'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_2_data', 'eltwise_4'),
('placeholder_2_data', 'eltwise_4'),
('reshape_1_data', 'eltwise_4'),
('eltwise_4', 'eltwise_4_data'),
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([1, 1, 64, 1]),
'value': np.ones([1, 1, 64, 1])},
'placeholder_2_data': {'shape': np.array([64, 1]),
'value': np.ones([64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'placeholder_4_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},

'reshape_2_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
'reshape_2_const_data': {'value': int64_array([1, 1, 64, 1]),
'shape': int64_array([4])},
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_3_data': {'shape': np.array([64, 1])},
Expand Down
67 changes: 28 additions & 39 deletions model-optimizer/extensions/middle/EltwiseInputReshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

from mo.front.common.layout import get_features_dim, shape_for_layout
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
Expand Down Expand Up @@ -73,50 +74,38 @@ def run_after(self):
return [MiddleStart]

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_data_nodes():
# Get all requested shapes for current node
# This mapping will contain pairs like {shape:[list of consumers nodes]}
mapping = {}
for consumer in node.out_nodes():
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
if 'new_shape' in edge_attrs:
if np.array_equal(edge_attrs['new_shape'], node.shape):
continue
new_shape = tuple([x for x in edge_attrs['new_shape']])
if not new_shape in mapping:
mapping.update({new_shape: [consumer]})
else:
mapping[new_shape].append(consumer)
for node in graph.get_op_nodes():
for out_port_idx in node.out_ports():
mapping = {}
output_port = node.out_port(out_port_idx)
for consumer_port in output_port.get_destinations():
edge_attrs = consumer_port.get_in_edge_attrs()
if 'new_shape' in edge_attrs:
if np.array_equal(edge_attrs['new_shape'], output_port.data.get_shape()):
continue
new_shape = tuple([x for x in edge_attrs['new_shape']])
if not new_shape in mapping:
mapping.update({new_shape: [consumer_port]})
else:
mapping[new_shape].append(consumer_port)

if node.has_valid('value'):
# Check that requested shape are the same
# In case if they are different, we duplicate them
for shape_key in mapping.keys():
shape = list(shape_key)
new_value = np.reshape(node.value, shape)
node_copy = Op.create_input_data_node(graph, node.id + '/copy', value=np.array(new_value))
for consumer in mapping[shape_key]:
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
del edge_attrs['new_shape']

# Remove edge from previous data node and connect new data node with its consumer
graph.remove_edge(node.id, consumer.id)
graph.add_edge(node_copy.id, consumer.id, **edge_attrs)
else:
# Insert Reshape layer between data node and consumer
for shape_key in mapping.keys():
shape = list(shape_key)
new_shape = list(shape_key)
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
reshape = Reshape(graph, attrs={'name': reshape_name})
reshape_dim = Const(graph,
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
reshape_node = create_op_with_const_inputs(graph, Reshape, {1: new_shape},
{'name': reshape_name})
reshape_node.in_port(0).connect(output_port)

# Iterate over consumers and reconnect them to Reshape layer output
for consumer in mapping[shape_key]:
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
for consumer_port in mapping[shape_key]:
edge_attrs = consumer_port.get_in_edge_attrs()
del edge_attrs['new_shape']
consumer_port.connect(reshape_node.out_port(0))

# Reconnect edge from original data node to Reshape output datanode
graph.remove_edge(node.id, consumer.id)
graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)
# Adjust shape and value for Reshape output
output_port_value = output_port.data.get_value()
if output_port_value is not None:
reshape_node.out_port(0).data.set_value(np.reshape(output_port_value, new_shape))
else:
reshape_node.out_port(0).data.set_shape(new_shape)
64 changes: 45 additions & 19 deletions model-optimizer/extensions/middle/EltwiseInputReshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,13 @@ def test2_not_constant(self):
self.assertTrue(flag, resp)

def test3_constant(self):
# ,--------------->consumer3 data-->consumer3
# data---(new_shape1)-->consumer1 => data-->consumer1
# `-(new_shape2)-->consumer2 data-->consumer2
# ,--------------->consumer3 ,------------>consumer3
# data---(new_shape1)-->consumer1 => data--->reshape1-->consumer1
# `-(new_shape2)-->consumer2 `->reshape2-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
Expand All @@ -165,17 +166,32 @@ def test3_constant(self):
nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1'),
('placeholder_2_data', 'consumer_2'),
('placeholder_3_data', 'consumer_3'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'consumer_1'),
('placeholder_1_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'consumer_2'),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3, 1, 1]),
'value': np.ones([1, 3, 1, 1])},
'placeholder_2_data': {'shape': int64_array([1, 1, 3]), 'value': np.ones([1, 1, 3])},
'placeholder_3_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])},

'reshape_2_const': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])},
'reshape_2_const_data': {'value': int64_array([1, 1, 3]),
'shape': int64_array([3])},
'reshape_2_data': {'shape': int64_array([1, 1, 3])},
}, nodes_with_edges_only=True)

pattern = EltwiseInputReshape()
Expand All @@ -185,12 +201,13 @@ def test3_constant(self):
self.assertTrue(flag, resp)

def test4_constant(self):
# ,--------------->consumer3 ,-->consumer3
# data---(new_shape1)-->consumer1 => data-->consumer1
# `-(new_shape2)-->consumer2 `->consumer2
# ,-(new_shape)-->consumer3 ,-->consumer3
# data---(new_shape)-->consumer1 => data-->reshape---->consumer1
# `-(new_shape)-->consumer2 `-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([3, 1, 1])}),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([3, 1, 1])}),
('placeholder_1_data', 'consumer_3', {'new_shape': int64_array([3, 1, 1])}),
('consumer_1', 'concat'),
Expand All @@ -201,14 +218,23 @@ def test4_constant(self):
nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1'),
('placeholder_1_data', 'consumer_2'),
('placeholder_1_data', 'consumer_3'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'consumer_1'),
('reshape_1_data', 'consumer_2'),
('reshape_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([3, 1, 1]), 'value': np.ones([3, 1, 1])}
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'reshape_1_const': {'value': int64_array([3, 1, 1]), 'shape': int64_array([3])},
'reshape_1_const_data': {'value': int64_array([3, 1, 1]),
'shape': int64_array([3])},
'reshape_1_data': {'shape': int64_array([3, 1, 1])},
}, nodes_with_edges_only=True)

pattern = EltwiseInputReshape()
Expand Down
Loading

0 comments on commit 4f0479f

Please sign in to comment.