diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 27f19a1e177a..636f55adb863 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -91,6 +91,20 @@ def _impl(inputs, attr, *args): return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr) return _impl +def _argx(func, func_name): + """ A common wrapper for argmin and argmax operations """ + def _impl(inputs, attr, params): + try: + # In Tensorflow, `axis` argument is a Tensor, not attribute. We + # support the case where it inputs from a scalar constant. + axis_input_name = inputs[1].list_output_names()[0] + axis_input_vlaue = params[axis_input_name].asnumpy()[0] + except (IndexError, KeyError): + raise TypeError( \ + "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) + return func(inputs[0], axis=axis_input_vlaue, keepdims=False) + return _impl + def _elemwise(name): def _impl(inputs, attr, *args): assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs)) @@ -650,6 +664,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params): # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) _convert_map = { + 'ArgMax' : _argx(_sym.argmax, 'argmax'), + 'ArgMin' : _argx(_sym.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), 'BatchNormWithGlobalNormalization' : _batch_norm(), 'BiasAdd' : _bias_add(), @@ -864,6 +880,28 @@ def _get_abs_layer_name(node): params, num_layers) return sym + +def _parse_import_prerequisites(graph): + """ Calculate the named preconditions from TensorFlow `graph`. + Return prerequisites for parsing: + a. Set of operator names which don't have their mapping in TVM, i.e. + which are not supported + """ + missing_operators = set() + for node in graph.node: + if node.op == "Placeholder": + pass + elif node.op == "Const": + pass + else: + if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]): + pass + else: + missing_operators.add(node.op) + + return missing_operators + + class GraphProto(object): """ A helper class for handling nnvm graph copying from Tensorflow GraphDef. Definition: @@ -886,7 +924,7 @@ def from_tensorflow(self, graph): Follow the tensorflow graph definition to parse and convert it to NNVM. Some of the assumptions listed below. - -> First Const or Placeholder node will be considered as graph input. + -> First Placeholder or Const node will be considered as graph input. -> Rest all Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Attribute should present in the tenserflow forzen graph. @@ -895,6 +933,7 @@ def from_tensorflow(self, graph): -> CheckNumerics: No implementation as of now for this. Just copies input to output. + TODO: Change algorithm to stop treating first 'Const' in a special way. Parameters ---------- @@ -908,10 +947,6 @@ def from_tensorflow(self, graph): params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ - # Parse throught all nodes and start extracting - # params aka Const nodes - # input nodes : First const node - # normal nodes : other normal nodes try: from tensorflow.python.framework import tensor_util @@ -919,12 +954,18 @@ def from_tensorflow(self, graph): raise ImportError( "Unable to import tensorflow which is required {}".format(e)) + missing_operators = _parse_import_prerequisites(graph) + + if missing_operators: + raise NotImplementedError( \ + "The following operators are not implemented: {}".format(missing_operators)) + + # Parse the nodes to re-create TF graph using Symbol API of NNVM for node in graph.node: # Tensorflow doesn't have seperate list for params extraction. # Operator name 'Const' is treated as a parameter to build NNVM params dict. input_shapes = {} if node.op == "Placeholder": - # Assuming only one input graph with type 'Placeholder' self._input_node = node.name self._num_input += 1 @@ -939,7 +980,6 @@ def from_tensorflow(self, graph): raise NotImplementedError( \ "Please freeze the graph with add_shapes=True") elif node.op == "Const": - # Assuming first Const node as Graph Input node if self._input_node == '': self._input_node = node.name self._num_input += 1 @@ -982,7 +1022,7 @@ def from_tensorflow(self, graph): # Pass the node name too in attr attr["_node_name"] = node.name - #ToDo: Some of the tensorflow operators maintain internaly maintain + #ToDo: Some of the tensorflow operators internaly maintain #execution layers and its output name will the layer number along with #graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the #output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case, diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 045d154d9d8b..1a54ba52dc99 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -404,6 +404,37 @@ def test_forward_sigmoid(): _test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32')) +####################################################################### +# Argmin/Argmax +# ------------- + +def _test_argx(func, data, **kwargs): + + with tf.Graph().as_default(): + inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0") + + # pylint: disable=unused-variable + out = func(inp, name="argx0", **kwargs) + # pylint: enable=unused-variable + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess=sess, + input_graph_def=sess.graph.as_graph_def(add_shapes=True), + output_node_names=["argx0"]) + + tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0") + tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32') + + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5) + + sess.close() + +def test_argmin_argmax(): + for axis in [None,0,1,2]: + data = np.random.uniform(size=(8,4,9)).astype('float32') + _test_argx(tf.argmax, data=data, axis=axis) + _test_argx(tf.argmin, data=data, axis=axis) ####################################################################### # Variable