@@ -833,28 +833,6 @@ def _get_abs_layer_name(node):
833833 params , num_layers )
834834 return sym
835835
836-
837- def _parse_import_prerequisites (graph ):
838- """ Calculate the named preconditions from TensorFlow `graph`.
839- Return prerequisites for parsing:
840- a. Set of operator names which don't have their mapping in TVM, i.e.
841- which are not supported
842- """
843- missing_operators = set ()
844- for node in graph .node :
845- if node .op == "Placeholder" :
846- pass
847- elif node .op == "Const" :
848- pass
849- else :
850- if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
851- pass
852- else :
853- missing_operators .add (node .op )
854-
855- return missing_operators
856-
857-
858836class GraphProto (object ):
859837 """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
860838 Definition:
@@ -863,12 +841,8 @@ class GraphProto(object):
863841 def __init__ (self ):
864842 self ._nodes = {}
865843 self ._params = {}
866- self ._renames = {}
867- self ._replacements = {}
868844 self ._output_shapes = {}
869- self ._num_input = 0
870845 self ._num_param = 0
871- self ._input_node = ''
872846 self ._num_rnn_layer = False
873847
874848 def from_tensorflow (self , graph ):
@@ -907,7 +881,7 @@ def from_tensorflow(self, graph):
907881 raise ImportError (
908882 "Unable to import tensorflow which is required {}" .format (e ))
909883
910- missing_operators = _parse_import_prerequisites (graph )
884+ missing_operators = self . _parse_import_prerequisites (graph )
911885
912886 if missing_operators :
913887 raise NotImplementedError ( \
@@ -917,58 +891,42 @@ def from_tensorflow(self, graph):
917891 for node in graph .node :
918892 # Tensorflow doesn't have seperate list for params extraction.
919893 # Operator name 'Const' is treated as a parameter to build NNVM params dict.
894+
920895 input_shapes = {}
896+
897+ attr = self ._parse_attr (node .attr )
898+
899+ #Variable converted to Const will not have only value attr
900+ if 'value' in attr and node .op == 'Const' :
901+ tensor_value = attr ['value' ]
902+ self ._output_shapes [node .name ] = \
903+ [tensor_util .TensorShapeProtoToList ( \
904+ tensor_value .tensor_shape )]
905+ elif '_output_shapes' in attr :
906+ self ._output_shapes [node .name ] = \
907+ [tensor_util .TensorShapeProtoToList (shape ) \
908+ for shape in attr ['_output_shapes' ]]
909+ else :
910+ raise NotImplementedError ( \
911+ "Please freeze the graph with add_shapes=True" )
912+
921913 if node .op == "Placeholder" :
922- self ._input_node = node .name
923- self . _num_input += 1
914+ self ._nodes [ node . name ] = _sym . Variable ( name = node .name ,
915+ shape = self . _output_shapes [ node . name ][ 0 ])
924916
925- try :
926- self ._output_shapes [node .name ] = \
927- [tensor_util .TensorShapeProtoToList (shape ) \
928- for shape in self ._parse_attr (node .attr )['_output_shapes' ]]
929- self ._nodes [node .name ] = _sym .Variable (name = node .name ,
930- shape = self ._output_shapes [node .name ][0 ])
931- input_shapes [self ._nodes [node .name ]] = self ._output_shapes [node .name ]
932- except KeyError :
933- raise NotImplementedError ( \
934- "Please freeze the graph with add_shapes=True" )
917+ #input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
935918 elif node .op == "Const" :
936- if self ._input_node == '' :
937- self ._input_node = node .name
938- self ._num_input += 1
939- self ._nodes [node .name ] = _sym .Variable (name = node .name )
940- else :
941- # Rest all nodes are Param nodes, lets parse
942- self ._num_param += 1
943- for key , value in node .attr .items ():
944- self ._parse_param (key , value , node .name )
945- if node .name not in self ._nodes :
946- raise NotImplementedError ( \
947- "Const {} couldn't be converted to Param." .format (node .name ))
948- attr = self ._parse_attr (node .attr )
949- #Variable converted to Const will not have only value attr
950- if 'value' in attr :
951- tensor_value = attr ['value' ]
952- self ._output_shapes [node .name ] = \
953- [tensor_util .TensorShapeProtoToList ( \
954- tensor_value .tensor_shape )]
955- elif '_output_shapes' in attr :
956- self ._output_shapes [node .name ] = \
957- [tensor_util .TensorShapeProtoToList (shape ) \
958- for shape in self ._parse_attr (node .attr )['_output_shapes' ]]
959- else :
919+ # All Const nodes are Param nodes, lets parse
920+ self ._num_param += 1
921+ for key , value in node .attr .items ():
922+ self ._parse_param (key , value , node .name )
923+ if node .name not in self ._nodes :
960924 raise NotImplementedError ( \
961- "Please freeze the graph with add_shapes=True" )
962- else :
925+ "Const {} couldn't be converted to Param." . format ( node . name ) )
926+
963927 attr = self ._parse_attr (node .attr )
964- try :
965- self ._output_shapes [node .name ] = \
966- [tensor_util .TensorShapeProtoToList (shape ) \
967- for shape in attr ['_output_shapes' ]]
968- except KeyError :
969- raise NotImplementedError ( \
970- "Please freeze the graph with add_shapes=True" )
971928
929+ else :
972930 # Pass the parsed shapes instead
973931 attr ["_output_shapes" ] = self ._output_shapes [node .name ]
974932
@@ -983,11 +941,12 @@ def from_tensorflow(self, graph):
983941 if ":" in node .input [0 ]:
984942 in_name , _ = node .input [0 ].split (':' )
985943 node .input [0 ] = in_name
944+
945+ # Fill shapes for all inputs in a list
986946 try :
987947 inputs = [self ._nodes [i ] for i in node .input ]
988948 for i in node .input :
989- if i not in self ._params :
990- input_shapes [self ._nodes [i ]] = self ._output_shapes [i ]
949+ input_shapes [self ._nodes [i ]] = self ._output_shapes [i ]
991950 attr ['_input_shapes' ] = input_shapes
992951 except KeyError :
993952 # TODO: Need to find clean way to handle '^CheckNumerics'
@@ -999,18 +958,40 @@ def from_tensorflow(self, graph):
999958 # Assuming only one output.
1000959 self ._nodes [node .name ] = op
1001960 node_output = op
961+
1002962 # Assume the final node is the output node
1003963 out = node_output
1004964
1005965 #Add the RNN outputs also with 'head' nodes of the nnvm graph
1006966 if self ._num_rnn_layer :
1007967 out_rnn = _sym .concatenate (* self ._out_rnn , axis = 0 )
1008968 out = [out , out_rnn ]
969+
1009970 if isinstance (out , list ):
1010971 out = _sym .Group (out )
1011972
1012973 return out , self ._params
1013974
975+ def _parse_import_prerequisites (self , graph ):
976+ """ Calculate the named preconditions from TensorFlow `graph`.
977+ Return prerequisites for parsing:
978+ a. Set of operator names which don't have their mapping in TVM, i.e.
979+ which are not supported
980+ """
981+ missing_operators = set ()
982+ for node in graph .node :
983+ if node .op == "Placeholder" :
984+ pass
985+ elif node .op == "Const" :
986+ pass
987+ else :
988+ if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
989+ pass
990+ else :
991+ missing_operators .add (node .op )
992+
993+ return missing_operators
994+
1014995 def _parse_param (self , key , value , name ):
1015996 try :
1016997 from tensorflow .python .framework import tensor_util
@@ -1020,6 +1001,12 @@ def _parse_param(self, key, value, name):
10201001
10211002 if key == 'value' :
10221003 np_array = tensor_util .MakeNdarray (value .tensor )
1004+
1005+ if np_array .dtype == np .dtype (object ):
1006+ # Object types are generally tensorflow DT_STRING (DecodeJpeg op). Just leave it as placeholder.
1007+ self ._nodes [name ] = _sym .Variable (name = name )
1008+ return
1009+
10231010 array_ndim = len (np_array .shape )
10241011 if array_ndim == 0 :
10251012 new_array = np .empty ([1 ], dtype = np_array .dtype )
0 commit comments