@@ -880,6 +880,28 @@ def _get_abs_layer_name(node):
880880 params , num_layers )
881881 return sym
882882
883+
884+ def _parse_import_prerequisites (graph ):
885+ """ Calculate the named preconditions from TensorFlow `graph`.
886+ Return prerequisites for parsing:
887+ a. Set of operator names which don't have their mapping in TVM, i.e.
888+ which are not supported
889+ """
890+ missing_operators = set ()
891+ for node in graph .node :
892+ if node .op == "Placeholder" :
893+ pass
894+ elif node .op == "Const" :
895+ pass
896+ else :
897+ if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
898+ pass
899+ else :
900+ missing_operators .add (node .op )
901+
902+ return missing_operators
903+
904+
883905class GraphProto (object ):
884906 """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
885907 Definition:
@@ -902,7 +924,7 @@ def from_tensorflow(self, graph):
902924 Follow the tensorflow graph definition to parse and convert it to NNVM.
903925 Some of the assumptions listed below.
904926
905- -> First Const or Placeholder node will be considered as graph input.
927+ -> First Placeholder or Const node will be considered as graph input.
906928 -> Rest all Const nodes are params.
907929 -> Last node is assumed as graph output.
908930 -> _output_shapes : Attribute should present in the tenserflow forzen graph.
@@ -911,6 +933,7 @@ def from_tensorflow(self, graph):
911933 -> CheckNumerics: No implementation as of now for this.
912934 Just copies input to output.
913935
936+ TODO: Change algorithm to stop treating first 'Const' in a special way.
914937
915938 Parameters
916939 ----------
@@ -924,23 +947,25 @@ def from_tensorflow(self, graph):
924947 params : dict
925948 A dict of name: tvm.nd.array pairs, used as pretrained weights
926949 """
927- # Parse throught all nodes and start extracting
928- # params aka Const nodes
929- # input nodes : First const node
930- # normal nodes : other normal nodes
931950
932951 try :
933952 from tensorflow .python .framework import tensor_util
934953 except ImportError as e :
935954 raise ImportError (
936955 "Unable to import tensorflow which is required {}" .format (e ))
937956
957+ missing_operators = _parse_import_prerequisites (graph )
958+
959+ if missing_operators :
960+ raise NotImplementedError ( \
961+ "The following operators are not implemented: {}" .format (missing_operators ))
962+
963+ # Parse the nodes to re-create TF graph using Symbol API of NNVM
938964 for node in graph .node :
939965 # Tensorflow doesn't have seperate list for params extraction.
940966 # Operator name 'Const' is treated as a parameter to build NNVM params dict.
941967 input_shapes = {}
942968 if node .op == "Placeholder" :
943- # Assuming only one input graph with type 'Placeholder'
944969 self ._input_node = node .name
945970 self ._num_input += 1
946971
@@ -955,7 +980,6 @@ def from_tensorflow(self, graph):
955980 raise NotImplementedError ( \
956981 "Please freeze the graph with add_shapes=True" )
957982 elif node .op == "Const" :
958- # Assuming first Const node as Graph Input node
959983 if self ._input_node == '' :
960984 self ._input_node = node .name
961985 self ._num_input += 1
@@ -998,7 +1022,7 @@ def from_tensorflow(self, graph):
9981022 # Pass the node name too in attr
9991023 attr ["_node_name" ] = node .name
10001024
1001- #ToDo: Some of the tensorflow operators maintain internaly maintain
1025+ #ToDo: Some of the tensorflow operators internaly maintain
10021026 #execution layers and its output name will the layer number along with
10031027 #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
10041028 #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
0 commit comments