diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 35f4ff451cdb..8fdc5944fb91 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2047,3 +2047,36 @@ def convert_broadcast_to(node, **kwargs): ) return [tensor_node, expand_node] + + +@mx_op.register("topk") +def convert_topk(node, **kwargs): + """Map MXNet's topk operator attributes to onnx's TopK operator + and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = int(attrs.get('axis', '-1')) + k = int(attrs.get('k', '1')) + ret_type = attrs.get('ret_typ') + dtype = attrs.get('dtype') + outputs = [name + '_output0'] + + if ret_type and ret_type == 'both': + if dtype and dtype == 'int64': + outputs.append(name + '_output1') + else: + raise NotImplementedError("ONNX expects indices to be of type int64") + else: + raise NotImplementedError("ONNX expects both value and indices as output") + + topk_node = onnx.helper.make_node( + "TopK", + input_nodes, + outputs, + axis=axis, + k=k, + name=name + ) + + return [topk_node] diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index cf95bfef09a3..bd48d26ef6ba 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -23,7 +23,7 @@ from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size from ._op_translations import ceil, floor, hardsigmoid, global_lppooling -from ._op_translations import concat, hardmax +from ._op_translations import concat, hardmax, topk from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm @@ -147,5 +147,6 @@ 'DepthToSpace' : depthtospace, 'SpaceToDepth' : spacetodepth, 'Hardmax' : hardmax, - 'LpNormalization' : lpnormalization + 'LpNormalization' : lpnormalization, + 'TopK' : topk } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 734b438581a5..627181d6ae21 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -779,3 +779,11 @@ def lpnormalization(attrs, inputs, proto_obj): axis = int(attrs.get("axis", -1)) new_attrs.update(axis=axis) return 'norm', new_attrs, inputs + + +def topk(attrs, inputs, proto_obj): + """Returns the top k elements in an input array along the given axis.""" + new_attrs = translation_utils._add_extra_attributes(attrs, + {'ret_typ': 'both', + 'dtype': 'int64'}) + return 'topk', new_attrs, inputs diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 89b60d15e84f..eb6b71896e28 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -78,7 +78,8 @@ 'test_max_', 'test_softplus', 'test_reduce_', - 'test_split_equal' + 'test_split_equal', + 'test_top_k' ], 'import': ['test_gather', 'test_softsign',