Skip to content

Commit

Permalink
Fix code after review #2
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants committed Nov 9, 2020
1 parent 72abca8 commit 8c8e8e9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
15 changes: 12 additions & 3 deletions model-optimizer/extensions/middle/EltwiseInputReshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def find_and_replace_pattern(self, graph: Graph):
class EltwiseInputReshape(MiddleReplacementPattern):
# This pass should be called directly from pipeline before layout change and other permutations
enabled = False
force_shape_inference = True

def find_and_replace_pattern(self, graph: Graph):
# Generate a map for producers of eltwise nodes with non-normalized shapes
Expand Down Expand Up @@ -97,11 +96,21 @@ def find_and_replace_pattern(self, graph: Graph):
for unsqueeze_dims in mapping[producer_port].keys():
unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseReshape'
unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))},
{'name': unsqueeze_name,
'override_output_shape': True})
{'name': unsqueeze_name})

unsqueeze_node.in_port(0).connect(producer_port)

# Insert Reshape with determined output shape between the current producer and eltwise node
for consumer_port in mapping[producer_port][unsqueeze_dims]:
consumer_port.connect(unsqueeze_node.out_port(0))

# The shape and value adjustments must be explicitly done within the transformation
# since the transformation is called from Fusing transformation that excludes
# automatic call of shape inference pass
producer_port_value = producer_port.data.get_value()
producer_port_shape = producer_port.data.get_shape()
new_shape = np.insert(producer_port_shape, np.zeros_like(unsqueeze_dims), 1)
if producer_port_value is not None:
unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape))
else:
unsqueeze_node.out_port(0).data.set_shape(new_shape)
9 changes: 0 additions & 9 deletions model-optimizer/extensions/middle/EltwiseInputReshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
from mo.front.common.partial_infer.utils import int64_array
from mo.middle.passes.eliminate import shape_inference
from mo.middle.passes.eliminate_test import build_graph
from mo.utils.ir_engine.compare_graphs import compare_graphs

Expand Down Expand Up @@ -86,7 +85,6 @@ def test1_not_constant(self):
'placeholder_3_data': {'shape': np.array([64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
shape_inference(graph)

graph_ref = build_graph(nodes_attributes,
[
Expand Down Expand Up @@ -120,7 +118,6 @@ def test1_not_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -216,7 +213,6 @@ def test_mega_hardcore(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -278,7 +274,6 @@ def test2_not_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -334,7 +329,6 @@ def test3_not_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -397,7 +391,6 @@ def test4_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -451,7 +444,6 @@ def test5_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
Expand Down Expand Up @@ -495,7 +487,6 @@ def test6_not_constant(self):

pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
shape_inference(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)

0 comments on commit 8c8e8e9

Please sign in to comment.