Skip to content

Commit

Permalink
Perform full normalization based on shapes of all inputs to eltwise
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants committed Nov 13, 2020
1 parent 1eaa548 commit 5bb7b96
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 12 deletions.
2 changes: 1 addition & 1 deletion model-optimizer/extensions/front/mxnet/MXRepeatReplacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def mxrepeat_decomposition(node: Node):

axis = get_canonical_axis_index_node(input_rank, node.axis)
unsqueeze_axis = create_op_node_with_second_input(
graph, Add, int64_array(1), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)
graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))
Expand Down
43 changes: 34 additions & 9 deletions model-optimizer/extensions/middle/EltwiseInputReshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def find_and_replace_pattern(self, graph: Graph):


class EltwiseInputReshape(MiddleReplacementPattern):
# TODO: implement the full normalization of eltwise input shapes executed within the single call
# This pass should be called directly from pipeline before layout change and other permutations
enabled = False

def find_and_replace_pattern(self, graph: Graph):
def find_and_replace_pattern(self, graph: Graph, is_first_pass=False):
# Generate a map for producers of eltwise nodes with non-normalized shapes
# and in this map every producer has another map that reflects normalized shape
# to a list of eltwise consumers
Expand All @@ -81,14 +82,21 @@ def find_and_replace_pattern(self, graph: Graph):
consumer_port = eltwise_node.in_port(in_port_idx)
producer_port = consumer_port.get_source()
producer_shape = producer_port.data.get_shape()
if len(producer_shape) != len(eltwise_shape):

producer_data_node = eltwise_node.in_node(in_port_idx)
edge_attrs = graph.get_edge_data(producer_data_node.id, eltwise_node.id)[0]
if is_first_pass and 'unsqueeze_dims' in edge_attrs and len(edge_attrs['unsqueeze_dims']) > 0:
unsqueeze_dims = tuple([x for x in edge_attrs['unsqueeze_dims']])
elif not is_first_pass and len(producer_shape) != len(eltwise_shape):
unsqueeze_dims = tuple(np.arange(len(eltwise_shape) - len(producer_shape), dtype=np.int64))
if not producer_port in mapping:
mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}})
elif not unsqueeze_dims in mapping[producer_port]:
mapping[producer_port].update({unsqueeze_dims: [consumer_port]})
else:
mapping[producer_port][unsqueeze_dims].append(consumer_port)
else:
continue
if not producer_port in mapping:
mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}})
elif not unsqueeze_dims in mapping[producer_port]:
mapping[producer_port].update({unsqueeze_dims: [consumer_port]})
else:
mapping[producer_port][unsqueeze_dims].append(consumer_port)

# Walk through each produced in the map and insert Reshape nodes between a producer and eltwise nodes
for producer_port in mapping.keys():
Expand All @@ -109,8 +117,25 @@ def find_and_replace_pattern(self, graph: Graph):
# automatic call of shape inference pass
producer_port_value = producer_port.data.get_value()
producer_port_shape = producer_port.data.get_shape()
new_shape = np.insert(producer_port_shape, np.zeros_like(unsqueeze_dims), 1)
new_shape = producer_port_shape
for unsqueeze_dim in unsqueeze_dims:
new_shape = np.insert(new_shape, unsqueeze_dim, 1)
if producer_port_value is not None:
unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape))
else:
unsqueeze_node.out_port(0).data.set_shape(new_shape)


class EltwiseInputReshapeFirstPass(MiddleReplacementPattern):
# TODO: The full normalization must be performed within the single call of EltwiseInputReshape
# and eltwise partial_infer must not sets edges with unsqueeze_dims attribute
# Now this pass executes a query from eltwise partial_infer for input shapes normalization
enabled = True
force_clean_up = True

def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]

def find_and_replace_pattern(self, graph: Graph):
EltwiseInputReshape().find_and_replace_pattern(graph, is_first_pass=True)
10 changes: 8 additions & 2 deletions model-optimizer/mo/front/common/partial_infer/eltwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def eltwise_infer(node, op=None, **kwargs):
if len(shape) != max_dims and len(shape) > 0 and axis is not None:
new_shape = shape

# Compute unsqueeze_dims
num_unsqueeze_dims = max_dims - axis - len(shape)
unsqueeze_dims = int64_array([])
if num_unsqueeze_dims > 0:
unsqueeze_dims = np.arange(len(shape), len(shape) + num_unsqueeze_dims, dtype=np.int64)

# Extend shape with 1's
for cnt in range(axis + len(shape), max_dims):
new_shape = np.append(new_shape, 1)
Expand All @@ -58,8 +64,8 @@ def eltwise_infer(node, op=None, **kwargs):
edge_attrs = node.graph.get_edge_data(inputs[id].id, node.id)[0]

nx.set_edge_attributes(G=node.graph,
values={(inputs[id].id, node.id, 0): new_shape},
name='new_shape')
values={(inputs[id].id, node.id, 0): unsqueeze_dims},
name='unsqueeze_dims')

# Reshape value to correctly calculate output shape
if values[id] is not None:
Expand Down

0 comments on commit 5bb7b96

Please sign in to comment.