Skip to content

Commit 7ee9cca

Browse files
sergei-mironovtqchen
authored andcommitted
[NNVM] Support argmax/argmin in tensorflow frontend (#1514)
1 parent 54ca149 commit 7ee9cca

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ def _impl(inputs, attr, *args):
9191
return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
9292
return _impl
9393

94+
def _argx(func, func_name):
95+
""" A common wrapper for argmin and argmax operations """
96+
def _impl(inputs, attr, params):
97+
try:
98+
# In Tensorflow, `axis` argument is a Tensor, not attribute. We
99+
# support the case where it inputs from a scalar constant.
100+
axis_input_name = inputs[1].list_output_names()[0]
101+
axis_input_vlaue = params[axis_input_name].asnumpy()[0]
102+
except (IndexError, KeyError):
103+
raise TypeError( \
104+
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
105+
return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
106+
return _impl
107+
94108
def _elemwise(name):
95109
def _impl(inputs, attr, *args):
96110
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
@@ -664,6 +678,8 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
664678
# for 1 to N mapping(composed), use custom callable functions
665679
# for N to 1 mapping, currently not supported(?)
666680
_convert_map = {
681+
'ArgMax' : _argx(_sym.argmax, 'argmax'),
682+
'ArgMin' : _argx(_sym.argmin, 'argmin'),
667683
'AvgPool' : _pooling('avg_pool'),
668684
'BatchNormWithGlobalNormalization' : _batch_norm(),
669685
'BiasAdd' : _bias_add(),
@@ -879,6 +895,28 @@ def _get_abs_layer_name(node):
879895
params, num_layers)
880896
return sym
881897

898+
899+
def _parse_import_prerequisites(graph):
900+
""" Calculate the named preconditions from TensorFlow `graph`.
901+
Return prerequisites for parsing:
902+
a. Set of operator names which don't have their mapping in TVM, i.e.
903+
which are not supported
904+
"""
905+
missing_operators = set()
906+
for node in graph.node:
907+
if node.op == "Placeholder":
908+
pass
909+
elif node.op == "Const":
910+
pass
911+
else:
912+
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
913+
pass
914+
else:
915+
missing_operators.add(node.op)
916+
917+
return missing_operators
918+
919+
882920
class GraphProto(object):
883921
""" A helper class for handling nnvm graph copying from Tensorflow GraphDef.
884922
Definition:
@@ -901,7 +939,7 @@ def from_tensorflow(self, graph):
901939
Follow the tensorflow graph definition to parse and convert it to NNVM.
902940
Some of the assumptions listed below.
903941
904-
-> First Const or Placeholder node will be considered as graph input.
942+
-> First Placeholder or Const node will be considered as graph input.
905943
-> Rest all Const nodes are params.
906944
-> Last node is assumed as graph output.
907945
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
@@ -910,6 +948,7 @@ def from_tensorflow(self, graph):
910948
-> CheckNumerics: No implementation as of now for this.
911949
Just copies input to output.
912950
951+
TODO: Change algorithm to stop treating first 'Const' in a special way.
913952
914953
Parameters
915954
----------
@@ -923,23 +962,25 @@ def from_tensorflow(self, graph):
923962
params : dict
924963
A dict of name: tvm.nd.array pairs, used as pretrained weights
925964
"""
926-
# Parse throught all nodes and start extracting
927-
# params aka Const nodes
928-
# input nodes : First const node
929-
# normal nodes : other normal nodes
930965

931966
try:
932967
from tensorflow.python.framework import tensor_util
933968
except ImportError as e:
934969
raise ImportError(
935970
"Unable to import tensorflow which is required {}".format(e))
936971

972+
missing_operators = _parse_import_prerequisites(graph)
973+
974+
if missing_operators:
975+
raise NotImplementedError( \
976+
"The following operators are not implemented: {}".format(missing_operators))
977+
978+
# Parse the nodes to re-create TF graph using Symbol API of NNVM
937979
for node in graph.node:
938980
# Tensorflow doesn't have seperate list for params extraction.
939981
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
940982
input_shapes = {}
941983
if node.op == "Placeholder":
942-
# Assuming only one input graph with type 'Placeholder'
943984
self._input_node = node.name
944985
self._num_input += 1
945986

@@ -954,7 +995,6 @@ def from_tensorflow(self, graph):
954995
raise NotImplementedError( \
955996
"Please freeze the graph with add_shapes=True")
956997
elif node.op == "Const":
957-
# Assuming first Const node as Graph Input node
958998
if self._input_node == '':
959999
self._input_node = node.name
9601000
self._num_input += 1
@@ -997,7 +1037,7 @@ def from_tensorflow(self, graph):
9971037
# Pass the node name too in attr
9981038
attr["_node_name"] = node.name
9991039

1000-
#ToDo: Some of the tensorflow operators maintain internaly maintain
1040+
#ToDo: Some of the tensorflow operators internaly maintain
10011041
#execution layers and its output name will the layer number along with
10021042
#graph node name.eg: Node name:- 'Model/RNN/cell_0/RnnCell', but the
10031043
#output name will be 'Model/RNN/cell_0/RnnCell:0'. In this case,

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,37 @@ def test_forward_sigmoid():
404404

405405
_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))
406406

407+
#######################################################################
408+
# Argmin/Argmax
409+
# -------------
410+
411+
def _test_argx(func, data, **kwargs):
412+
413+
with tf.Graph().as_default():
414+
inp = constant_op.constant(data, shape=data.shape, dtype=data.dtype, name="c0")
415+
416+
# pylint: disable=unused-variable
417+
out = func(inp, name="argx0", **kwargs)
418+
# pylint: enable=unused-variable
419+
420+
with tf.Session() as sess:
421+
graph_def = tf.graph_util.convert_variables_to_constants(
422+
sess=sess,
423+
input_graph_def=sess.graph.as_graph_def(add_shapes=True),
424+
output_node_names=["argx0"])
425+
426+
tf_output = run_tf_graph(sess, data, input_node="c0:0", output_node="argx0:0")
427+
tvm_output = run_tvm_graph(graph_def, data, "c0", tf_output.shape, output_dtype='int32')
428+
429+
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)
430+
431+
sess.close()
432+
433+
def test_argmin_argmax():
434+
for axis in [None,0,1,2]:
435+
data = np.random.uniform(size=(8,4,9)).astype('float32')
436+
_test_argx(tf.argmax, data=data, axis=axis)
437+
_test_argx(tf.argmin, data=data, axis=axis)
407438

408439
#######################################################################
409440
# Variable

0 commit comments

Comments
 (0)