Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import/export: TopK
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 18, 2018
1 parent 8dd2fb1 commit 94fb107
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 14 deletions.
37 changes: 37 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,3 +1655,40 @@ def convert_size(node, **kwargs):
and return the created node.
"""
return create_basic_op_node('Size', node, kwargs)


@mx_op.register("topk")
def convert_topk(node, **kwargs):
"""Map MXNet's size_array operator attributes to onnx's Size operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', '-1'))
k = int(attrs.get('k', '1'))
ret_type = attrs.get('ret_typ')
outputs = [name+'_output0']

if ret_type and ret_type == 'both':
outputs.append(name + '_output1')
else:
raise NotImplementedError("ONNX expects both value and indices as output")

topk_node = onnx.helper.make_node(
"TopK",
input_nodes,
[outputs[0], 'cast_'+outputs[1]],
axis=axis,
k=k,
name=name
)

cast_node = onnx.helper.make_node(
"Cast",
['cast_'+outputs[1]],
[outputs[1]],
to=getattr(onnx.TensorProto, 'INT64'),
name=outputs[1]
)

return [topk_node, cast_node]
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
node_name = converted_node.name if converted_node.name else converted_node.output[0]
node_name = converted_node.output[0]
if node_name in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ._op_translations import reduce_sum_square, reduce_l1, reduce_l2, max_roi_pooling
from ._op_translations import log_softmax, softsign, lesser, greater, equal
from ._op_translations import logical_and, logical_or, logical_xor, logical_not
from ._op_translations import mean, depthtospace, spacetodepth
from ._op_translations import mean, depthtospace, spacetodepth, topk

# convert_map defines maps of ONNX operator names to converter functor(callable)
# defined in the op_translations module.
Expand Down Expand Up @@ -144,5 +144,6 @@
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling,
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth
'SpaceToDepth' : spacetodepth,
'TopK' : topk,
}
7 changes: 7 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,10 @@ def spacetodepth(attrs, inputs, proto_obj):
new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'})

return "space_to_depth", new_attrs, inputs


def topk(attrs, inputs, proto_obj):
"""Returns the top k elements in an input array along the given axis."""
new_attrs = translation_utils._add_extra_attributes(attrs,
{'ret_typ': 'both'})
return 'topk', new_attrs, inputs
6 changes: 4 additions & 2 deletions tests/python-pytest/onnx/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,7 @@ def run(self, inputs, **kwargs):
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]
result = []
for output in exe.outputs:
result.append(output.asnumpy())
return result
52 changes: 43 additions & 9 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,19 @@ def forward_pass(sym, arg, aux, data_names, input_data):
# create module
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)
if not arg and not aux:
mod.init_params()
else:
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)
# run inference
batch = namedtuple('Batch', ['data'])
mod.forward(batch([mx.nd.array(input_data)]), is_train=False)

return mod.get_outputs()[0].asnumpy()
result = []
for output in mod.get_outputs():
result.append(output.asnumpy())
return result


def test_models(model_name, input_shape, output_shape):
Expand Down Expand Up @@ -139,8 +145,8 @@ def test_models(model_name, input_shape, output_shape):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)

# verify the results
npt.assert_equal(result.shape, output_data.shape)
npt.assert_almost_equal(output_data, result, decimal=3)
npt.assert_equal(result[0].shape, output_data.shape)
npt.assert_almost_equal(output_data, result[0], decimal=3)
logging.info(model_name + " conversion successful")


Expand All @@ -157,7 +163,7 @@ def test_model_accuracy(model_name, input_shape):
expected_result= []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
expected_result.append(result)
expected_result.append(result[0])

params = {}
params.update(arg_params)
Expand All @@ -179,7 +185,7 @@ def test_model_accuracy(model_name, input_shape):
actual_result = []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
actual_result.append(result)
actual_result.append(result[0])

# verify the results
for expected, actual in zip(expected_result, actual_result):
Expand Down Expand Up @@ -239,7 +245,7 @@ def test_square():

numpy_op = np.square(input1)

npt.assert_almost_equal(result, numpy_op)
npt.assert_almost_equal(result[0], numpy_op)


def test_softmax():
Expand All @@ -261,7 +267,35 @@ def test_softmax():
result = forward_pass(sym, arg_params, aux_params, ['ipsym'], input1)

# Comparing result of forward pass before using onnx export, import
npt.assert_almost_equal(result, softmax_out)
npt.assert_almost_equal(result[0], softmax_out)


@with_seed()
def test_topk():
input1 = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=np.float32)
k = 3
dtype = 'int32'
inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1))]
sym = mx.sym.topk(mx.sym.Variable('input1'), k=k, ret_typ='both', dtype=dtype)
sym_output = forward_pass(sym, None, None, ['input1'], input1)

outputs = [helper.make_tensor_value_info("output1", TensorProto.FLOAT, shape=np.shape(sym_output[0])),
helper.make_tensor_value_info("output2", TensorProto.FLOAT, shape=np.shape(sym_output[1]))]

nodes = [helper.make_node("TopK", ["input1"], ["output1", "output2"], k=k)]

graph = helper.make_graph(nodes,
"topk_test",
inputs,
outputs)

spacetodepth_model = helper.make_model(graph)

bkd_rep = backend.prepare(spacetodepth_model)
output = bkd_rep.run([input1])

npt.assert_almost_equal(output, sym_output)


@with_seed()
def test_comparison_ops():
Expand Down

0 comments on commit 94fb107

Please sign in to comment.