Skip to content

Commit

Permalink
Fix the NHWC->NCHW transformation for dynamic weights (#2848)
Browse files Browse the repository at this point in the history
* Fix the NHWC->NCHW transformation when weights and data comes from same input

* Simplify code
  • Loading branch information
mvafin authored Nov 6, 2020
1 parent 15d7919 commit 9f0b26e
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 27 deletions.
105 changes: 78 additions & 27 deletions model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from typing import Set

from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
mark_as_correct_data_layout
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern

Expand All @@ -34,8 +36,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
1. Prevents from adding Transpose operations before and after "reinterp_shape" like operations which change rank of
the input and output tensors of this layout agnostic op.
2. Disable attributes permutation for all intermediate ops between these "reinterp_shape" nodes.
For now the transformation is triggered for MatMul operation only getting input as 4D or 5D tensors.
3. Marks nodes along the weight path of convolutions as in correct layout to not permute them from NHWC to NCHW
"""
enabled = True
graph_condition = [lambda graph: graph.graph['layout'] == 'NHWC']
Expand Down Expand Up @@ -69,7 +70,7 @@ def bfs(self, start_nodes: list, visited: set, condition: callable = None, forwa
:param start_nodes: Nodes to start search from
:param visited: set of already visited nodes where traversing should not happen
:param condition: function getting a Node as input and returning whether the node should be included into the
resukt or not. If the value is None then the node is added unconditionally.
result or not. If the value is None then the node is added unconditionally.
:param forward: boolean flag specifying the traverse direction
:return: the list of Nodes visited
"""
Expand Down Expand Up @@ -127,8 +128,16 @@ def find_and_replace_pattern(self, graph: Graph):
mark_as_correct_data_layout(visited_node)
visited_node['nchw_layout'] = True

for node in self.get_ports_and_nodes_on_weights(graph)[1]:
mark_as_correct_data_layout(node)
_, nodes_weigths, nodes_in_weights = self.get_ports_and_nodes_on_weights(graph)
for node in nodes_weigths:
if node in nodes_in_weights:
for ind, port in node.in_ports().items():
if ind not in nodes_in_weights[node]:
mark_input_as_in_correct_layout(node, ind)
for ind, port in node.out_ports().items():
mark_output_as_in_correct_layout(node, ind)
else:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True
Expand All @@ -140,8 +149,39 @@ def find_and_replace_pattern(self, graph: Graph):
node.out_node()['nchw_layout'] = True

@staticmethod
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port],
visited_ports: Set[Port] = None, visited_nodes: Set[Node] = None):
def get_weighted_layer_type_to_in_weights_port():
get_weights_port_index = lambda node: node.weights_index if node.has_valid('weights_index') else 1
weighted_layer_type_to_in_weights_port = {
'Convolution': get_weights_port_index,
'DeformableConvolution': get_weights_port_index,
'Deconvolution': get_weights_port_index,
'BinaryConvolution': get_weights_port_index,
}
return weighted_layer_type_to_in_weights_port

@staticmethod
def insert_permute_inputs_before_dynamic_weights_subgraph(dynamic_subgraphs: Set[Node] = None):
"""
The function inserts permutations on input nodes in the weights subgraph
:param dynamic_subgraphs: Set of Nodes belonging to weight path subgraphs
:return: the list of Nodes which are inputs to weight path subgraphs
"""
dynamic_in_nodes = dict()
for node in dynamic_subgraphs:
node_type = node.soft_get('type')
if node_type not in ['Const', 'Parameter', 'ShapeOf']:
idx_lst = list()
for idx in [idx for idx, port in node.in_ports().items() if
not port.disconnected() and port.get_source().node not in dynamic_subgraphs]:
PermuteInputs().set_input_permutation(node.in_node(idx), node, 'input:{}'.format(idx),
'transpose_nchw_to_nhwc')
idx_lst.append(idx)
if len(idx_lst):
dynamic_in_nodes[node] = idx_lst
return dynamic_in_nodes

@staticmethod
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port], port_condition=None):
""""
Returns all intermediate ports and nodes of such a sub-graph:
Expand All @@ -153,14 +193,14 @@ def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port]
\/ \/
in_ports
"""
if visited_ports is None:
visited_ports = set()
if visited_nodes is None:
visited_nodes = set()
visited_ports = set()
visited_nodes = set()

deque_of_in_ports = deque(in_ports)
while len(deque_of_in_ports):
in_port = deque_of_in_ports.popleft()
if in_port.get_source() is None:
continue
source_node = in_port.get_source().node
if in_port in visited_ports: # do not check visited_nodes as search is based on ports
continue
Expand All @@ -169,40 +209,51 @@ def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port]
if not len(in_port.get_source().node.in_ports()): # for Constants and Parameters to be visited
visited_nodes.add(in_port.get_source().node)
continue
deque_of_in_ports.extend([port for port in source_node.in_ports().values() if not port.disconnected()])
for idx, port in source_node.in_ports().items():
if not port.disconnected() and (not port_condition or port_condition(source_node, idx)):
deque_of_in_ports.append(port)
visited_nodes.add(source_node)
return visited_ports, visited_nodes

@staticmethod
def is_not_weight_port(node: Node, idx: int):
w_types_to_in_port_dict = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
node_type = node.soft_get('type')
return node_type in w_types_to_in_port_dict.keys() and idx != w_types_to_in_port_dict[node_type](node)

@staticmethod
def get_ports_and_nodes_on_weights(graph):
get_weights_port_index = lambda node: node.weights_index if node.has_valid('weights_index') else 1
weighted_layer_type_to_in_weights_port = {
'Convolution': get_weights_port_index,
'DeformableConvolution': get_weights_port_index,
'Deconvolution': get_weights_port_index,
'BinaryConvolution': get_weights_port_index,
}
nodes = graph.get_op_nodes()
weighted_types = list(weighted_layer_type_to_in_weights_port.keys())

# collect all input ports with weights
weight_ports = set()
result_ports = set()
start_ports = set()
w_types_to_in_port_dict = MarkSubGraphsWithCorrectLayout.get_weighted_layer_type_to_in_weights_port()
for node in nodes:
node_type = node.soft_get('type', 'unknown')
if node_type not in weighted_types:
if node_type in ['Const', 'Parameter', 'ShapeOf']:
if node_type not in w_types_to_in_port_dict.keys():
if node_type in ['Const', 'Parameter', 'ShapeOf', 'ExtractImagePatches']:
start_ports.add(node.out_port(0))
continue
weight_port_idx = weighted_layer_type_to_in_weights_port[node_type](node)
weight_port_idx = w_types_to_in_port_dict[node_type](node)
assert node.is_in_port_connected(weight_port_idx), \
'Unexpected port configuration of {} node with name=`{}`'.format(node_type,
node.soft_get('name', node.id))
weight_ports.add(node.in_port(weight_port_idx))
for result in graph.get_op_nodes(type='Result'):
result_ports.update(result.in_ports().values())

# collect all sub-graphs that start with Constant/Parameter/ShapeOf and end at in_port as weights
ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
return ports, nodes
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches and end at in_port as
# weights
ports_w, nodes_w = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)
# collect all sub-graphs that start with Constant/Parameter/ShapeOf/ExtractImagePatches, end at Result nodes and
# not contains branches that end as weights
ports_d, nodes_d = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(
result_ports, start_ports, MarkSubGraphsWithCorrectLayout.is_not_weight_port)
nodes_dif = nodes_w.difference(nodes_d)
nodes_in_w = MarkSubGraphsWithCorrectLayout.insert_permute_inputs_before_dynamic_weights_subgraph(nodes_dif)
return ports_w.difference(ports_d), nodes_dif, nodes_in_w

@staticmethod
def get_ports_and_nodes_on_shape_subgraphs(graph):
Expand Down
19 changes: 19 additions & 0 deletions model-optimizer/mo/graph/perm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,23 @@ def transpose(op_node: Node, port_info: str, input_port: int):
op_node.in_port(input_port).get_connection().insert_node(transpose)


def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
graph = op_node.graph
permutation_data_node = get_node_with_permutation(op_node, port_info)
rank = len(permutation_data_node.shape)
assert rank >= 4, 'Rank must be 4D or higher for HCHW to HHWC permutation on node {}.'.format(op_node.id)

perm = list(range(rank))
perm.insert(1, perm.pop())
perm = int64_array(perm)

transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose'
from mo.front.tf.graph_utils import create_op_with_const_inputs # avoiding recursive imports
transpose = create_op_with_const_inputs(
graph, Transpose, {1: perm}, {'name': transpose_name, 'override_output_shape': True})
op_node.in_port(input_port).get_connection().insert_node(transpose)


class PermuteInputs:
common_inv_permutation = lambda node, port_info, input_port: axis(node, port_info, input_port)

Expand All @@ -179,6 +196,8 @@ class PermuteInputs:
'order': lambda node, port_info, input_port: order(node, port_info, input_port),
'shape': lambda node, port_info, input_port: shape(node, port_info, input_port),
'transpose': lambda node, port_info, input_port: transpose(node, port_info, input_port),
'transpose_nchw_to_nhwc': lambda node, port_info, input_port: transpose_nchw_to_nhwc(node, port_info,
input_port),
}

def set_input_permutation(self, node1: Node, node2: Node, port_info: str, permutation_rule: str):
Expand Down

0 comments on commit 9f0b26e

Please sign in to comment.