Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-895] ONNX import/export: TopK (#13627)
Browse files Browse the repository at this point in the history
* ONNX import: TopK

* ONNX export: TopK

* Add test for TopK

* Rebasing

* Fix a comment and rebase

* Re-trigger CI

* Re-trigger CI
  • Loading branch information
vandanavk authored and Roshrini committed Aug 29, 2019
1 parent 196d1f4 commit b7cca01
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
33 changes: 33 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,3 +2047,36 @@ def convert_broadcast_to(node, **kwargs):
)

return [tensor_node, expand_node]


@mx_op.register("topk")
def convert_topk(node, **kwargs):
"""Map MXNet's topk operator attributes to onnx's TopK operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', '-1'))
k = int(attrs.get('k', '1'))
ret_type = attrs.get('ret_typ')
dtype = attrs.get('dtype')
outputs = [name + '_output0']

if ret_type and ret_type == 'both':
if dtype and dtype == 'int64':
outputs.append(name + '_output1')
else:
raise NotImplementedError("ONNX expects indices to be of type int64")
else:
raise NotImplementedError("ONNX expects both value and indices as output")

topk_node = onnx.helper.make_node(
"TopK",
input_nodes,
outputs,
axis=axis,
k=k,
name=name
)

return [topk_node]
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, shape, gather, lp_pooling, size
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat, hardmax
from ._op_translations import concat, hardmax, topk
from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
Expand Down Expand Up @@ -147,5 +147,6 @@
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth,
'Hardmax' : hardmax,
'LpNormalization' : lpnormalization
'LpNormalization' : lpnormalization,
'TopK' : topk
}
8 changes: 8 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,11 @@ def lpnormalization(attrs, inputs, proto_obj):
axis = int(attrs.get("axis", -1))
new_attrs.update(axis=axis)
return 'norm', new_attrs, inputs


def topk(attrs, inputs, proto_obj):
"""Returns the top k elements in an input array along the given axis."""
new_attrs = translation_utils._add_extra_attributes(attrs,
{'ret_typ': 'both',
'dtype': 'int64'})
return 'topk', new_attrs, inputs
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
'test_max_',
'test_softplus',
'test_reduce_',
'test_split_equal'
'test_split_equal',
'test_top_k'
],
'import': ['test_gather',
'test_softsign',
Expand Down

0 comments on commit b7cca01

Please sign in to comment.