Skip to content

Commit

Permalink
Merge EltwiseInputNormalization and EltwiseInputReshape transformations
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants committed Oct 29, 2020
1 parent fa891f6 commit 187fd1c
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 369 deletions.
1 change: 0 additions & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ extensions/middle/DeleteControlFlowEdges.py
extensions/middle/DeleteNotExecutable.py
extensions/middle/DilatedConvolution.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputNormalization.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/FakeSplitOutputs.py
extensions/middle/FusedBatchNormNonConstant.py
Expand Down
48 changes: 0 additions & 48 deletions model-optimizer/extensions/middle/EltwiseInputNormalization.py

This file was deleted.

218 changes: 0 additions & 218 deletions model-optimizer/extensions/middle/EltwiseInputNormalization_test.py

This file was deleted.

25 changes: 12 additions & 13 deletions model-optimizer/extensions/middle/EltwiseInputReshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.op import Op
from mo.ops.reshape import Reshape


Expand Down Expand Up @@ -66,24 +65,26 @@ def find_and_replace_pattern(self, graph: Graph):


class EltwiseInputReshape(MiddleReplacementPattern):
enabled = True
force_clean_up = True

def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
# This pass should be called directly from pipeline before layout change and other permutations
enabled = False

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes():
for out_port_idx in node.out_ports():
mapping = {}
output_port = node.out_port(out_port_idx)
output_port_shape = output_port.data.get_shape()
for consumer_port in output_port.get_destinations():
edge_attrs = consumer_port.get_in_edge_attrs()
if 'new_shape' in edge_attrs:
if np.array_equal(edge_attrs['new_shape'], output_port.data.get_shape()):
if consumer_port.node.has_valid('is_eltwise') and consumer_port.node['is_eltwise'] == True:
consumer_output_shape = consumer_port.node.out_port(0).data.get_shape()
if np.array_equal(consumer_output_shape, output_port_shape):
continue
new_shape = tuple([x for x in edge_attrs['new_shape']])
# Set edge attribute new_shape for further transformation pass
new_shape = output_port_shape
if len(output_port_shape) != len(consumer_output_shape):
for x in range(len(consumer_output_shape) - len(output_port_shape)):
new_shape = np.insert(new_shape, 0, 1)
new_shape = tuple([x for x in new_shape])
if not new_shape in mapping:
mapping.update({new_shape: [consumer_port]})
else:
Expand All @@ -99,8 +100,6 @@ def find_and_replace_pattern(self, graph: Graph):

# Iterate over consumers and reconnect them to Reshape layer output
for consumer_port in mapping[shape_key]:
edge_attrs = consumer_port.get_in_edge_attrs()
del edge_attrs['new_shape']
consumer_port.connect(reshape_node.out_port(0))

# Adjust shape and value for Reshape output
Expand Down
Loading

0 comments on commit 187fd1c

Please sign in to comment.