Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: TopK
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 12, 2018
1 parent 449e17d commit d9066bb
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 7 deletions.
37 changes: 37 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,3 +1655,40 @@ def convert_size(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Size', node, kwargs)


@mx_op.register("topk")
def convert_topk(node, **kwargs):
"""Map MXNet's size_array operator attributes to onnx's Size operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', '-1'))
k = int(attrs.get('k', '1'))
ret_type = attrs.get('ret_typ')
outputs = [name+'_output0']

if ret_type and ret_type == 'both':
outputs.append(name + '_output1')
else:
raise NotImplementedError("ONNX expects both value and indices as output")

topk_node = onnx.helper.make_node(
"TopK",
input_nodes,
[outputs[0], 'cast_'+outputs[1]],
axis=axis,
k=k,
name=name
)

cast_node = onnx.helper.make_node(
"Cast",
['cast_'+outputs[1]],
[outputs[1]],
to=getattr(onnx.TensorProto, 'INT64'),
name=outputs[1]
)

return [topk_node, cast_node]
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
node_name = converted_node.name if converted_node.name else converted_node.output[0]
node_name = converted_node.output[0]
if node_name in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ._op_translations import reduce_sum_square, reduce_l1, reduce_l2, max_roi_pooling
from ._op_translations import log_softmax, softsign, lesser, greater, equal
from ._op_translations import logical_and, logical_or, logical_xor, logical_not
from ._op_translations import mean, depthtospace, spacetodepth
from ._op_translations import mean, depthtospace, spacetodepth, topk

# convert_map defines maps of ONNX operator names to converter functor(callable)
# defined in the op_translations module.
Expand Down Expand Up @@ -144,5 +144,6 @@
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling,
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth
'SpaceToDepth' : spacetodepth,
'TopK' : topk,
}
7 changes: 7 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,10 @@ def spacetodepth(attrs, inputs, proto_obj):
new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'})

return "space_to_depth", new_attrs, inputs


def topk(attrs, inputs, proto_obj):
"""Returns the top k elements in an input array along the given axis."""
new_attrs = translation_utils._add_extra_attributes(attrs,
{'ret_typ': 'both'})
return 'topk', new_attrs, inputs
6 changes: 4 additions & 2 deletions tests/python-pytest/onnx/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,7 @@ def run(self, inputs, **kwargs):
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]
result = []
for output in exe.outputs:
result.append(output.asnumpy())
return result
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
'test_hardsigmoid',
'test_instancenorm',
'test_shape',
'test_size'
'test_size',
'test_top_k'
]

BASIC_MODEL_TESTS = [
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/import/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@
'test_operator_params',
'test_operator_permute2',
'test_depthtospace',
'test_size'
'test_size',
'test_top_k'
]

BASIC_MODEL_TESTS = [
Expand Down

0 comments on commit d9066bb

Please sign in to comment.