
Commit e316f03

srkreddy1238 authored and tqchen committed
[NNVM][TENSORFLOW] Cleanup redundant code. (#1551)
1 parent 42ec2e0 commit e316f03

File tree

1 file changed: +83 -157 lines


nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 83 additions & 157 deletions
@@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
             custom_check=_dimension_constraint())(inputs, attr)
     return _impl
 
-def _conv():
-    def _impl(inputs, attr, params):
-        attr['data_format'] = attr['data_format'].decode("utf-8")
-
-        # Extract kernel shape from params
-        conv_param_weights = params[inputs[1].list_output_names()[0]]
-
-        if attr['data_format'] == 'NHWC':
-            attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
-            attr['channels'] = conv_param_weights.shape[3]
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
-        elif attr['data_format'] == 'NCHW':
-            attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
-            attr['channels'] = conv_param_weights.shape[1]
-            if 'dilations' in attr:
-                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
-        else:
-            raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
-
-        # Fix strides
-        attr['strides'] = (attr['strides'][1], attr['strides'][2])
-
-        # Fix padding
-        input_shapes = attr['_input_shapes'][inputs[0]]
-        attr['padding'] = attr['padding'].decode("utf-8")
-
-        if attr['padding'] == 'VALID':
-            attr['padding'] = [0, 0]
-        elif attr['padding'] == 'SAME':
-            stride_h, stride_w = attr['strides']
-            kernel_h, kernel_w = attr['kernel_shape']
-            if attr['data_format'] == 'NHWC':
-                in_h = input_shapes[0][1]
-                in_w = input_shapes[0][2]
-            else:
-                in_h = input_shapes[0][2]
-                in_w = input_shapes[0][3]
-
-            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
-            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
-
-            if attr['data_format'] == 'NHWC':
-                inputs[0] = _sym.pad(data=inputs[0],
-                                     pad_width=((0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1]),
-                                                (0, 0)))
-            else:
-                inputs[0] = _sym.pad(data=inputs[0],
-                                     pad_width=((0, 0),
-                                                (0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1])))
-
-            attr['padding'] = [0, 0]
-
-        else:
-            raise TypeError("Unsupported padding type : {}".format(attr['padding']))
-
-        if 'kernel_layout' not in attr:
-            attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
-
-        return AttrCvt(
-            op_name=_dimension_picker('conv'),
-            transforms={
-                'kernel_shape': 'kernel_size',
-                'data_format': 'layout',
-                'dilations': ('dilation', (0, 0)),
-                'group': ('groups', 1)},
-            extras={'use_bias': len(inputs) == 3},
-            custom_check=_dimension_constraint())(inputs, attr)
-    return _impl
-
-def _depthwise_conv():
+def _conv(opname):
     def _impl(inputs, attr, params):
         attr['data_format'] = attr['data_format'].decode("utf-8")
         input_shapes = attr['_input_shapes'][inputs[0]]
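Note: both the removed `_conv()` and the retained shared body compute TensorFlow-style 'SAME' padding through `_get_pad_pair`, which is defined elsewhere in tensorflow.py and untouched by this commit. A minimal sketch of what that helper computes, assuming TensorFlow's standard SAME rule:

```python
# Sketch only: the real _get_pad_pair lives outside this diff.
# TensorFlow 'SAME' pads so that out = ceil(in / stride), splitting the
# total padding with the smaller half in front.
def _get_pad_pair(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]
```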
@@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
         if attr['data_format'] == 'NHWC':
             kernel_h, kernel_w, _, depth_mult = conv_param_weights.shape
             attr['kernel_shape'] = (conv_param_weights.shape[0], conv_param_weights.shape[1])
-            attr['channels'] = input_shapes[0][3] * depth_mult
+            if opname == 'conv':
+                attr['channels'] = conv_param_weights.shape[3]
+            else:
+                attr['channels'] = input_shapes[0][3] * depth_mult
+
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
         elif attr['data_format'] == 'NCHW':
             depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
             attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
-            attr['channels'] = input_shapes[0][1] * depth_mult
+            if opname == 'conv':
+                attr['channels'] = conv_param_weights.shape[1]
+            else:
+                attr['channels'] = input_shapes[0][1] * depth_mult
+
             if 'dilations' in attr:
                 attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
         else:
             raise TypeError("Unsupported data format type : {}".format(attr['data_format']))
 
+
+        if opname == 'depthwise':
+            attr['groups'] = attr['channels']
+
         # Fix strides
         attr['strides'] = (attr['strides'][1], attr['strides'][2])
 
-        # Fix groups
-        attr['groups'] = attr['channels']
-
         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
 
@@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
             raise TypeError("Unsupported padding type : {}".format(attr['padding']))
 
         if 'kernel_layout' not in attr:
-            attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
+            if opname == 'conv':
+                attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
+            else:
+                attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
 
         return AttrCvt(
             op_name=_dimension_picker('conv'),
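Note: the layout split above is the subtle part of the merge: regular conv kernels default to HWIO (output channels in the last axis) while depthwise kernels are HWOI (channel multiplier last). A worked NHWC shape example, with illustrative values not taken from the commit:

```python
import numpy as np

x_nhwc = np.zeros((1, 224, 224, 32))   # input, 32 channels
w_conv = np.zeros((3, 3, 32, 64))      # HWIO: out channels = shape[3] = 64
w_dw = np.zeros((3, 3, 32, 2))         # (kh, kw, in_c, depth_mult)

channels_conv = w_conv.shape[3]                 # 64
channels_dw = x_nhwc.shape[3] * w_dw.shape[3]   # 32 * 2 = 64
groups_dw = channels_dw                         # depthwise: groups == channels
```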
@@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
     'CheckNumerics' : _check_numerics(),
     'Concat' : _concat(),
     'ConcatV2' : _concatV2(),
-    'Conv2D' : _conv(),
+    'Conv2D' : _conv('conv'),
     'DecodeJpeg' : _decode_image(),
     'ExpandDims' : _expand_dims(),
     'Identity' : _identity(),
@@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
     'Squeeze' : _squeeze(),
     'FusedBatchNorm' : _fused_batch_norm(),
     'Relu6' : _relu6(),
-    'DepthwiseConv2dNative' : _depthwise_conv(),
+    'DepthwiseConv2dNative' : _conv('depthwise'),
     'Shape' : _shape(),
     'Sigmoid' : AttrCvt('sigmoid'),
     'Fill' : _fill(),
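Note: these two map entries capture the whole cleanup: one converter factory parameterized by opname replaces two near-duplicate functions. A self-contained sketch of the pattern, with hypothetical names rather than the frontend's actual code:

```python
def make_converter(opname):
    # The closure captures opname; shared logic lives once and branches
    # only where the two TensorFlow ops genuinely differ.
    def _impl(attr):
        if opname == 'depthwise':
            attr['groups'] = attr['channels']
        return attr
    return _impl

convert_map = {
    'Conv2D': make_converter('conv'),
    'DepthwiseConv2dNative': make_converter('depthwise'),
}

print(convert_map['DepthwiseConv2dNative']({'channels': 64}))
# {'channels': 64, 'groups': 64}
```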
@@ -895,28 +833,6 @@ def _get_abs_layer_name(node):
                              params, num_layers)
     return sym
 
-
-def _parse_import_prerequisites(graph):
-    """ Calculate the named preconditions from TensorFlow `graph`.
-    Return prerequisites for parsing:
-    a. Set of operator names which don't have their mapping in TVM, i.e.
-        which are not supported
-    """
-    missing_operators = set()
-    for node in graph.node:
-        if node.op == "Placeholder":
-            pass
-        elif node.op == "Const":
-            pass
-        else:
-            if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
-                pass
-            else:
-                missing_operators.add(node.op)
-
-    return missing_operators
-
-
 class GraphProto(object):
     """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
     Definition:
@@ -925,12 +841,8 @@ class GraphProto(object):
     def __init__(self):
         self._nodes = {}
         self._params = {}
-        self._renames = {}
-        self._replacements = {}
         self._output_shapes = {}
-        self._num_input = 0
         self._num_param = 0
-        self._input_node = ''
         self._num_rnn_layer = False
 
     def from_tensorflow(self, graph):
@@ -969,7 +881,7 @@ def from_tensorflow(self, graph):
             raise ImportError(
                 "Unable to import tensorflow which is required {}".format(e))
 
-        missing_operators = _parse_import_prerequisites(graph)
+        missing_operators = self._parse_import_prerequisites(graph)
 
         if missing_operators:
             raise NotImplementedError( \
@@ -979,58 +891,42 @@ def from_tensorflow(self, graph):
         for node in graph.node:
             # Tensorflow doesn't have seperate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.
+
             input_shapes = {}
+
+            attr = self._parse_attr(node.attr)
+
+            #Variable converted to Const will not have only value attr
+            if 'value' in attr and node.op == 'Const':
+                tensor_value = attr['value']
+                self._output_shapes[node.name] = \
+                    [tensor_util.TensorShapeProtoToList( \
+                        tensor_value.tensor_shape)]
+            elif '_output_shapes' in attr:
+                self._output_shapes[node.name] = \
+                    [tensor_util.TensorShapeProtoToList(shape) \
+                    for shape in attr['_output_shapes']]
+            else:
+                raise NotImplementedError( \
+                    "Please freeze the graph with add_shapes=True")
+
             if node.op == "Placeholder":
-                self._input_node = node.name
-                self._num_input += 1
+                self._nodes[node.name] = _sym.Variable(name=node.name,
+                                                       shape=self._output_shapes[node.name][0])
 
-                try:
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList(shape) \
-                        for shape in self._parse_attr(node.attr)['_output_shapes']]
-                    self._nodes[node.name] = _sym.Variable(name=node.name,
-                                                           shape=self._output_shapes[node.name][0])
-                    input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
-                except KeyError:
-                    raise NotImplementedError( \
-                        "Please freeze the graph with add_shapes=True")
+                #input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
             elif node.op == "Const":
-                if self._input_node == '':
-                    self._input_node = node.name
-                    self._num_input += 1
-                    self._nodes[node.name] = _sym.Variable(name=node.name)
-                else:
-                    # Rest all nodes are Param nodes, lets parse
-                    self._num_param += 1
-                    for key, value in node.attr.items():
-                        self._parse_param(key, value, node.name)
-                    if node.name not in self._nodes:
-                        raise NotImplementedError( \
-                            "Const {} couldn't be converted to Param.".format(node.name))
-                    attr = self._parse_attr(node.attr)
-                    #Variable converted to Const will not have only value attr
-                    if 'value' in attr:
-                        tensor_value = attr['value']
-                        self._output_shapes[node.name] = \
-                            [tensor_util.TensorShapeProtoToList( \
-                                tensor_value.tensor_shape)]
-                    elif '_output_shapes' in attr:
-                        self._output_shapes[node.name] = \
-                            [tensor_util.TensorShapeProtoToList(shape) \
-                            for shape in self._parse_attr(node.attr)['_output_shapes']]
-                    else:
+                # All Const nodes are Param nodes, lets parse
+                self._num_param += 1
+                for key, value in node.attr.items():
+                    self._parse_param(key, value, node.name)
+                if node.name not in self._nodes:
                     raise NotImplementedError( \
-                            "Please freeze the graph with add_shapes=True")
-            else:
+                        "Const {} couldn't be converted to Param.".format(node.name))
+
                 attr = self._parse_attr(node.attr)
-                try:
-                    self._output_shapes[node.name] = \
-                        [tensor_util.TensorShapeProtoToList(shape) \
-                        for shape in attr['_output_shapes']]
-                except KeyError:
-                    raise NotImplementedError( \
-                        "Please freeze the graph with add_shapes=True")
 
+            else:
                 # Pass the parsed shapes instead
                 attr["_output_shapes"] = self._output_shapes[node.name]
 
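Note: shape extraction is now hoisted ahead of the per-op branches, so the importer fails fast on any node without shape information. For reference, a GraphDef carries the `_output_shapes` attribute when exported with `add_shapes=True`; a sketch using the TF 1.x API of this era, where the session setup and the 'output' node name are assumptions:

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # add_shapes=True attaches '_output_shapes' to every NodeDef,
    # which is what from_tensorflow() reads above.
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    frozen = graph_util.convert_variables_to_constants(
        sess, graph_def, ['output'])
```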
@@ -1045,11 +941,12 @@ def from_tensorflow(self, graph):
             if ":" in node.input[0]:
                 in_name, _ = node.input[0].split(':')
                 node.input[0] = in_name
+
+            # Fill shapes for all inputs in a list
             try:
                 inputs = [self._nodes[i] for i in node.input]
                 for i in node.input:
-                    if i not in self._params:
-                        input_shapes[self._nodes[i]] = self._output_shapes[i]
+                    input_shapes[self._nodes[i]] = self._output_shapes[i]
                 attr['_input_shapes'] = input_shapes
             except KeyError:
                 # TODO: Need to find clean way to handle '^CheckNumerics'
@@ -1061,18 +958,40 @@ def from_tensorflow(self, graph):
             # Assuming only one output.
             self._nodes[node.name] = op
             node_output = op
+
         # Assume the final node is the output node
         out = node_output
 
         #Add the RNN outputs also with 'head' nodes of the nnvm graph
         if self._num_rnn_layer:
             out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
             out = [out, out_rnn]
+
         if isinstance(out, list):
             out = _sym.Group(out)
 
         return out, self._params
 
+    def _parse_import_prerequisites(self, graph):
+        """ Calculate the named preconditions from TensorFlow `graph`.
+        Return prerequisites for parsing:
+        a. Set of operator names which don't have their mapping in TVM, i.e.
+            which are not supported
+        """
+        missing_operators = set()
+        for node in graph.node:
+            if node.op == "Placeholder":
+                pass
+            elif node.op == "Const":
+                pass
+            else:
+                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
+                    pass
+                else:
+                    missing_operators.add(node.op)
+
+        return missing_operators
+
     def _parse_param(self, key, value, name):
         try:
             from tensorflow.python.framework import tensor_util
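Note: the membership test `any([node.op in t for t in [...]])` treats each mapping purely as a container of supported op names (the inner list comprehension is redundant but harmless). A condensed equivalent with hypothetical names:

```python
def is_supported(op_name, tables):
    # True when any lookup table (a list, or a dict's keys) knows this op.
    return any(op_name in t for t in tables)

print(is_supported('Conv2D', [['Identity'], {'Conv2D': None}]))  # True
```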
@@ -1082,6 +1001,13 @@ def _parse_param(self, key, value, name):
 
         if key == 'value':
             np_array = tensor_util.MakeNdarray(value.tensor)
+
+            if np_array.dtype == np.dtype(object):
+                # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
+                # Just leave it as placeholder.
+                self._nodes[name] = _sym.Variable(name=name)
+                return
+
             array_ndim = len(np_array.shape)
             if array_ndim == 0:
                 new_array = np.empty([1], dtype=np_array.dtype)
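Note: the new early return guards against string tensors, which `MakeNdarray` materializes as NumPy object arrays. A quick illustration with the TF 1.x API and made-up bytes:

```python
import tensorflow as tf
from tensorflow.python.framework import tensor_util

proto = tf.make_tensor_proto(b"raw jpeg bytes")  # a DT_STRING scalar
arr = tensor_util.MakeNdarray(proto)
print(arr.dtype)  # object: such params are left as plain Variables above
```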
