This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add onnx export support for one_hot and random_uniform_like and unit tests for one_hot. #19952

Merged
merged 8 commits on Mar 2, 2021
41 changes: 41 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -4035,6 +4035,47 @@ def convert_argsort(node, **kwargs):
    return nodes


@mx_op.register('one_hot')
def convert_one_hot(node, **kwargs):
    """Map MXNet's one_hot operator attributes to onnx's OneHot operator
    """
    from onnx.helper import make_node
    name, input_nodes, attrs = get_inputs(node, kwargs)

    depth = int(attrs.get('depth'))
    on_value = float(attrs.get('on_value', 1.))
    off_value = float(attrs.get('off_value', 0.))
    dtype = attrs.get('dtype', 'float32')

    # ONNX OneHot takes (indices, depth, values), where values is the
    # [off_value, on_value] pair; depth and values are passed as initializer tensors.
    create_tensor([off_value, on_value], name+'_values', kwargs['initializer'], dtype=dtype)
    create_tensor([depth], name+'_depth', kwargs['initializer'])
    nodes = [
        make_node('OneHot', [input_nodes[0], name+'_depth', name+'_values'], [name], name=name)
    ]

    return nodes
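

# Editor's usage sketch (assumption, not part of this PR): export a symbol that
# contains one_hot through the converter above. The export_model positional
# arguments (symbol, params, input shapes, input dtypes, output path) follow the
# MXNet 1.x contrib ONNX API; the shapes and file name here are hypothetical.
import numpy as np
import mxnet as mx

data = mx.sym.Variable('data')
sym = mx.sym.one_hot(data, depth=5, dtype='float32')
params = {}  # one_hot has no learnable parameters
onnx_file = mx.contrib.onnx.export_model(sym, params, [(4,)], [np.int64], 'one_hot.onnx')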


@mx_op.register('_random_uniform_like')
def convert_random_uniform_like(node, **kwargs):
    """Map MXNet's random_uniform_like operator attributes to onnx's RandomUniformLike operator
    """
    from onnx.helper import make_node
    # onnx and numpy are imported here so this excerpt is self-contained; they are
    # needed to map the dtype string to the ONNX TensorProto dtype enum below.
    import onnx
    import numpy as np
    name, input_nodes, attrs = get_inputs(node, kwargs)

    low = float(attrs.get('low', 0.))
    high = float(attrs.get('high', 1.))
    dtype = attrs.get('dtype', 'float32')

    nodes = [
        make_node('RandomUniformLike', [input_nodes[0]], [name], name=name,
                  dtype=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
                  low=low, high=high)
    ]

    return nodes
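

# Editor's verification sketch (assumption, not part of this PR): run a graph
# exported with RandomUniformLike in onnxruntime and check that the output keeps
# the input's shape and stays in [low, high). The .onnx file name and the input
# shape are hypothetical.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('random_uniform_like.onnx')
inp = np.zeros((2, 3), dtype=np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
assert out.shape == inp.shape              # same shape as the input tensor
assert ((out >= 0.0) & (out < 1.0)).all()  # defaults: low=0., high=1.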


@mx_op.register('SequenceReverse')
def convert_sequence_reverse(node, **kwargs):
"""Map MXNet's SequenceReverse op
9 changes: 9 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
@@ -1183,6 +1183,15 @@ def test_onnx_export_take_raise(tmp_path, dtype, axis):
    op_export_test('take', M, [x, y], tmp_path)


# onnxruntime currently only supports float32 and int64
@pytest.mark.parametrize("dtype", ["float32", "int64"])
@pytest.mark.parametrize("depth", [1, 3, 5, 10])
def test_onnx_export_one_hot(tmp_path, dtype, depth):
    M = def_model('one_hot', depth=depth, dtype=dtype)
    x = mx.random.randint(0, 10, (depth * depth)).astype('int64')
    op_export_test('one_hot', M, [x], tmp_path)
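

# Editor's worked example (not part of this PR): the kind of values the exported
# ONNX OneHot node is expected to reproduce against MXNet's own output.
# mx.nd.one_hot([1, 0, 2], depth=3) gives:
#   [[0. 1. 0.]
#    [1. 0. 0.]
#    [0. 0. 1.]]
import mxnet as mx
print(mx.nd.one_hot(mx.nd.array([1, 0, 2], dtype='int64'), depth=3))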


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('params', [((6, 5, 4), [1, 2, 4, 5, 6]),
                                    ((7, 3, 5), [1, 7, 4]),