Skip to content

Commit 943d5f5

Browse files
committed
[NNVM][TENSORFLOW] Cleanup Const, Placeholder, _input_shapes.
1 parent 6d68a18 commit 943d5f5

File tree

1 file changed

+61
-74
lines changed

1 file changed

+61
-74
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 61 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -833,28 +833,6 @@ def _get_abs_layer_name(node):
833833
params, num_layers)
834834
return sym
835835

836-
837-
def _parse_import_prerequisites(graph):
838-
""" Calculate the named preconditions from TensorFlow `graph`.
839-
Return prerequisites for parsing:
840-
a. Set of operator names which don't have their mapping in TVM, i.e.
841-
which are not supported
842-
"""
843-
missing_operators = set()
844-
for node in graph.node:
845-
if node.op == "Placeholder":
846-
pass
847-
elif node.op == "Const":
848-
pass
849-
else:
850-
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
851-
pass
852-
else:
853-
missing_operators.add(node.op)
854-
855-
return missing_operators
856-
857-
858836
class GraphProto(object):
859837
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
860838
Definition:
@@ -863,12 +841,8 @@ class GraphProto(object):
863841
def __init__(self):
864842
self._nodes = {}
865843
self._params = {}
866-
self._renames = {}
867-
self._replacements = {}
868844
self._output_shapes = {}
869-
self._num_input = 0
870845
self._num_param = 0
871-
self._input_node = ''
872846
self._num_rnn_layer = False
873847

874848
def from_tensorflow(self, graph):
@@ -907,7 +881,7 @@ def from_tensorflow(self, graph):
907881
raise ImportError(
908882
"Unable to import tensorflow which is required {}".format(e))
909883

910-
missing_operators = _parse_import_prerequisites(graph)
884+
missing_operators = self._parse_import_prerequisites(graph)
911885

912886
if missing_operators:
913887
raise NotImplementedError( \
@@ -917,58 +891,42 @@ def from_tensorflow(self, graph):
917891
for node in graph.node:
918892
# Tensorflow doesn't have seperate list for params extraction.
919893
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
894+
920895
input_shapes = {}
896+
897+
attr = self._parse_attr(node.attr)
898+
899+
#Variable converted to Const will not have only value attr
900+
if 'value' in attr and node.op == 'Const':
901+
tensor_value = attr['value']
902+
self._output_shapes[node.name] = \
903+
[tensor_util.TensorShapeProtoToList( \
904+
tensor_value.tensor_shape)]
905+
elif '_output_shapes' in attr:
906+
self._output_shapes[node.name] = \
907+
[tensor_util.TensorShapeProtoToList(shape) \
908+
for shape in attr['_output_shapes']]
909+
else:
910+
raise NotImplementedError( \
911+
"Please freeze the graph with add_shapes=True")
912+
921913
if node.op == "Placeholder":
922-
self._input_node = node.name
923-
self._num_input += 1
914+
self._nodes[node.name] = _sym.Variable(name=node.name,
915+
shape=self._output_shapes[node.name][0])
924916

925-
try:
926-
self._output_shapes[node.name] = \
927-
[tensor_util.TensorShapeProtoToList(shape) \
928-
for shape in self._parse_attr(node.attr)['_output_shapes']]
929-
self._nodes[node.name] = _sym.Variable(name=node.name,
930-
shape=self._output_shapes[node.name][0])
931-
input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
932-
except KeyError:
933-
raise NotImplementedError( \
934-
"Please freeze the graph with add_shapes=True")
917+
#input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
935918
elif node.op == "Const":
936-
if self._input_node == '':
937-
self._input_node = node.name
938-
self._num_input += 1
939-
self._nodes[node.name] = _sym.Variable(name=node.name)
940-
else:
941-
# Rest all nodes are Param nodes, lets parse
942-
self._num_param += 1
943-
for key, value in node.attr.items():
944-
self._parse_param(key, value, node.name)
945-
if node.name not in self._nodes:
946-
raise NotImplementedError( \
947-
"Const {} couldn't be converted to Param.".format(node.name))
948-
attr = self._parse_attr(node.attr)
949-
#Variable converted to Const will not have only value attr
950-
if 'value' in attr:
951-
tensor_value = attr['value']
952-
self._output_shapes[node.name] = \
953-
[tensor_util.TensorShapeProtoToList( \
954-
tensor_value.tensor_shape)]
955-
elif '_output_shapes' in attr:
956-
self._output_shapes[node.name] = \
957-
[tensor_util.TensorShapeProtoToList(shape) \
958-
for shape in self._parse_attr(node.attr)['_output_shapes']]
959-
else:
919+
# All Const nodes are Param nodes, lets parse
920+
self._num_param += 1
921+
for key, value in node.attr.items():
922+
self._parse_param(key, value, node.name)
923+
if node.name not in self._nodes:
960924
raise NotImplementedError( \
961-
"Please freeze the graph with add_shapes=True")
962-
else:
925+
"Const {} couldn't be converted to Param.".format(node.name))
926+
963927
attr = self._parse_attr(node.attr)
964-
try:
965-
self._output_shapes[node.name] = \
966-
[tensor_util.TensorShapeProtoToList(shape) \
967-
for shape in attr['_output_shapes']]
968-
except KeyError:
969-
raise NotImplementedError( \
970-
"Please freeze the graph with add_shapes=True")
971928

929+
else:
972930
# Pass the parsed shapes instead
973931
attr["_output_shapes"] = self._output_shapes[node.name]
974932

@@ -983,11 +941,12 @@ def from_tensorflow(self, graph):
983941
if ":" in node.input[0]:
984942
in_name, _ = node.input[0].split(':')
985943
node.input[0] = in_name
944+
945+
# Fill shapes for all inputs in a list
986946
try:
987947
inputs = [self._nodes[i] for i in node.input]
988948
for i in node.input:
989-
if i not in self._params:
990-
input_shapes[self._nodes[i]] = self._output_shapes[i]
949+
input_shapes[self._nodes[i]] = self._output_shapes[i]
991950
attr['_input_shapes'] = input_shapes
992951
except KeyError:
993952
# TODO: Need to find clean way to handle '^CheckNumerics'
@@ -999,18 +958,40 @@ def from_tensorflow(self, graph):
999958
# Assuming only one output.
1000959
self._nodes[node.name] = op
1001960
node_output = op
961+
1002962
# Assume the final node is the output node
1003963
out = node_output
1004964

1005965
#Add the RNN outputs also with 'head' nodes of the nnvm graph
1006966
if self._num_rnn_layer:
1007967
out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
1008968
out = [out, out_rnn]
969+
1009970
if isinstance(out, list):
1010971
out = _sym.Group(out)
1011972

1012973
return out, self._params
1013974

975+
def _parse_import_prerequisites(self, graph):
976+
""" Calculate the named preconditions from TensorFlow `graph`.
977+
Return prerequisites for parsing:
978+
a. Set of operator names which don't have their mapping in TVM, i.e.
979+
which are not supported
980+
"""
981+
missing_operators = set()
982+
for node in graph.node:
983+
if node.op == "Placeholder":
984+
pass
985+
elif node.op == "Const":
986+
pass
987+
else:
988+
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
989+
pass
990+
else:
991+
missing_operators.add(node.op)
992+
993+
return missing_operators
994+
1014995
def _parse_param(self, key, value, name):
1015996
try:
1016997
from tensorflow.python.framework import tensor_util
@@ -1020,6 +1001,12 @@ def _parse_param(self, key, value, name):
10201001

10211002
if key == 'value':
10221003
np_array = tensor_util.MakeNdarray(value.tensor)
1004+
1005+
if np_array.dtype == np.dtype(object):
1006+
# Object types are generally tensorflow DT_STRING (DecodeJpeg op). Just leave it as placeholder.
1007+
self._nodes[name] = _sym.Variable(name=name)
1008+
return
1009+
10231010
array_ndim = len(np_array.shape)
10241011
if array_ndim == 0:
10251012
new_array = np.empty([1], dtype=np_array.dtype)

0 commit comments

Comments
 (0)