@@ -91,6 +91,20 @@ def _impl(inputs, attr, *args):
9191 return AttrCvt (op_name = "__pow_scalar__" , extras = {'scalar' : - 0.5 })(inputs , attr )
9292 return _impl
9393
94+ def _argx (func , func_name ):
95+ """ A common wrapper for argmin and argmax operations """
96+ def _impl (inputs , attr , params ):
97+ try :
98+ # In Tensorflow, `axis` argument is a Tensor, not attribute. We
99+ # support the case where it inputs from a scalar constant.
100+ axis_input_name = inputs [1 ].list_output_names ()[0 ]
101+ axis_input_vlaue = params [axis_input_name ].asnumpy ()[0 ]
102+ except (IndexError , KeyError ):
103+ raise TypeError ( \
104+ "Unsupported argument for `{}` : `axis` should be a constant" .format (func_name ))
105+ return func (inputs [0 ], axis = axis_input_vlaue , keepdims = False )
106+ return _impl
107+
94108def _elemwise (name ):
95109 def _impl (inputs , attr , * args ):
96110 assert len (inputs ) == 2 , "Math op take 2 inputs, {} given" .format (len (inputs ))
@@ -664,6 +678,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
664678# for 1 to N mapping(composed), use custom callable functions
665679# for N to 1 mapping, currently not supported(?)
666680_convert_map = {
681+ 'ArgMax' : _argx (_sym .argmax , 'argmax' ),
682+ 'ArgMin' : _argx (_sym .argmin , 'argmin' ),
667683 'AvgPool' : _pooling ('avg_pool' ),
668684 'BatchNormWithGlobalNormalization' : _batch_norm (),
669685 'BiasAdd' : _bias_add (),
@@ -879,6 +895,28 @@ def _get_abs_layer_name(node):
879895 params , num_layers )
880896 return sym
881897
898+
899+ def _parse_import_prerequisites (graph ):
900+ """ Calculate the named preconditions from TensorFlow `graph`.
901+ Return prerequisites for parsing:
902+ a. Set of operator names which don't have their mapping in TVM, i.e.
903+ which are not supported
904+ """
905+ missing_operators = set ()
906+ for node in graph .node :
907+ if node .op == "Placeholder" :
908+ pass
909+ elif node .op == "Const" :
910+ pass
911+ else :
912+ if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
913+ pass
914+ else :
915+ missing_operators .add (node .op )
916+
917+ return missing_operators
918+
919+
882920class GraphProto (object ):
883921 """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
884922 Definition:
@@ -901,7 +939,7 @@ def from_tensorflow(self, graph):
901939 Follow the tensorflow graph definition to parse and convert it to NNVM.
902940 Some of the assumptions listed below.
903941
904- -> First Const or Placeholder node will be considered as graph input.
942+ -> First Placeholder or Const node will be considered as graph input.
905943 -> Rest all Const nodes are params.
906944 -> Last node is assumed as graph output.
907945 -> _output_shapes : Attribute should present in the tenserflow forzen graph.
@@ -910,6 +948,7 @@ def from_tensorflow(self, graph):
910948 -> CheckNumerics: No implementation as of now for this.
911949 Just copies input to output.
912950
951+ TODO: Change algorithm to stop treating first 'Const' in a special way.
913952
914953 Parameters
915954 ----------
@@ -923,23 +962,25 @@ def from_tensorflow(self, graph):
923962 params : dict
924963 A dict of name: tvm.nd.array pairs, used as pretrained weights
925964 """
926- # Parse throught all nodes and start extracting
927- # params aka Const nodes
928- # input nodes : First const node
929- # normal nodes : other normal nodes
930965
931966 try :
932967 from tensorflow .python .framework import tensor_util
933968 except ImportError as e :
934969 raise ImportError (
935970 "Unable to import tensorflow which is required {}" .format (e ))
936971
972+ missing_operators = _parse_import_prerequisites (graph )
973+
974+ if missing_operators :
975+ raise NotImplementedError ( \
976+ "The following operators are not implemented: {}" .format (missing_operators ))
977+
978+ # Parse the nodes to re-create TF graph using Symbol API of NNVM
937979 for node in graph .node :
938980 # Tensorflow doesn't have seperate list for params extraction.
939981 # Operator name 'Const' is treated as a parameter to build NNVM params dict.
940982 input_shapes = {}
941983 if node .op == "Placeholder" :
942- # Assuming only one input graph with type 'Placeholder'
943984 self ._input_node = node .name
944985 self ._num_input += 1
945986
@@ -954,7 +995,6 @@ def from_tensorflow(self, graph):
954995 raise NotImplementedError ( \
955996 "Please freeze the graph with add_shapes=True" )
956997 elif node .op == "Const" :
957- # Assuming first Const node as Graph Input node
958998 if self ._input_node == '' :
959999 self ._input_node = node .name
9601000 self ._num_input += 1
@@ -997,7 +1037,7 @@ def from_tensorflow(self, graph):
9971037 # Pass the node name too in attr
9981038 attr ["_node_name" ] = node .name
9991039
1000- #ToDo: Some of the tensorflow operators maintain internaly maintain
1040+ #ToDo: Some of the tensorflow operators internaly maintain
10011041 #execution layers and its output name will the layer number along with
10021042 #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
10031043 #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,
0 commit comments