diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 2ce0a6bc1def..2521cf5adb37 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -4035,6 +4035,47 @@ def convert_argsort(node, **kwargs): return nodes +@mx_op.register('one_hot') +def convert_one_hot(node, **kwargs): + """Map MXNet's one_hot operator attributes to onnx's OneHot operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + depth = int(attrs.get('depth')) + on_value = float(attrs.get('on_value', 1.)) + off_value = float(attrs.get('off_value', 0.)) + dtype = attrs.get('dtype', 'float32') + + create_tensor([off_value, on_value], name+'_values', kwargs['initializer'], dtype=np.dtype(dtype)) + create_tensor([depth], name+'_depth', kwargs['initializer']) + nodes = [ + make_node('OneHot', [input_nodes[0], name+'_depth', name+'_values'], [name], name=name) + ] + + return nodes + + +@mx_op.register('_random_uniform_like') +def convert_random_uniform_like(node, **kwargs): + """Map MXNet's random_uniform_like operator attributes to onnx's RandomUniformLike operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + low = float(attrs.get('low', 0.)) + high = float(attrs.get('high', 1.)) + dtype = attrs.get('dtype', 'float32') + + nodes = [ + make_node('RandomUniformLike', [input_nodes[0]], [name], name=name, + dtype=onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], + low=low, high=high) + ] + + return nodes + + @mx_op.register('SequenceReverse') def convert_sequence_reverse(node, **kwargs): """Map MXNet's SequenceReverse op diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 9931628118b8..f73672b59cc7 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1183,6 +1183,16 @@ def test_onnx_export_take_raise(tmp_path, dtype, axis): op_export_test('take', M, [x, y], tmp_path) +# onnxruntime currently does not support int32 +@pytest.mark.parametrize("dtype", ["float16", "float32", "int64"]) +@pytest.mark.parametrize("depth", [1, 3, 5, 10]) +@pytest.mark.parametrize("shape", [(1,1), (1,5), (5,5), (3,4,5)]) +def test_onnx_export_one_hot(tmp_path, dtype, depth, shape): + M = def_model('one_hot', depth=depth, dtype=dtype) + x = mx.random.randint(0, 10, shape).astype('int64') + op_export_test('one_hot', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) @pytest.mark.parametrize('params', [((6, 5, 4), [1, 2, 4, 5, 6]), ((7, 3, 5), [1, 7, 4]),