@@ -902,7 +902,7 @@ def from_tensorflow(self, graph):
902902 Follow the tensorflow graph definition to parse and convert it to NNVM.
903903 Some of the assumptions listed below.
904904
905- -> First Const or Placeholder node will be considered as graph input.
905+ -> First Placeholder or Const node will be considered as graph input.
906906 -> Rest all Const nodes are params.
907907 -> Last node is assumed as graph output.
908908 -> _output_shapes : Attribute should present in the tenserflow forzen graph.
@@ -924,17 +924,39 @@ def from_tensorflow(self, graph):
924924 params : dict
925925 A dict of name: tvm.nd.array pairs, used as pretrained weights
926926 """
927- # Parse throught all nodes and start extracting
928- # params aka Const nodes
929- # input nodes : First const node
930- # normal nodes : other normal nodes
931927
932928 try :
933929 from tensorflow .python .framework import tensor_util
934930 except ImportError as e :
935931 raise ImportError (
936932 "Unable to import tensorflow which is required {}" .format (e ))
937933
934+ # Parse the nodes briefly to determine the preconditions:
935+ # a. Placeholders
936+ # b. Missing operations to report in advance
937+ has_placeholders = False
938+ missing_operators = set ()
939+ for node in graph .node :
940+ if node .op == "Placeholder" :
941+ has_placeholders = True
942+ elif node .op == "Const" :
943+ pass
944+ else :
945+ if node .op in _identity_list :
946+ pass
947+ elif node .op in _convert_map :
948+ pass
949+ else :
950+ missing_operators .add (node .op )
951+
952+ if missing_operators :
953+ raise NotImplementedError ( \
954+ "The following operators are not implemented: {}" .format (missing_operators ))
955+
956+ # Parse the nodes to re-create TF graph using Symbol API of NNVM
957+ # * Input node will be produced by first placeholder node, or by first
958+ # Const node if there are no placeholders
959+ # * Other nodes will be mapped using `_convert_map` and other translation facilities
938960 for node in graph .node :
939961 # Tensorflow doesn't have seperate list for params extraction.
940962 # Operator name 'Const' is treated as a parameter to build NNVM params dict.
@@ -956,7 +978,7 @@ def from_tensorflow(self, graph):
956978 "Please freeze the graph with add_shapes=True" )
957979 elif node .op == "Const" :
958980 # Assuming first Const node as Graph Input node
959- if self ._input_node == '' :
981+ if self ._input_node == '' and not has_placeholders :
960982 self ._input_node = node .name
961983 self ._num_input += 1
962984 self ._nodes [node .name ] = _sym .Variable (name = node .name )
0 commit comments