From f7e6b5f89c39f2d9b4fad58234e82c53c9db6ad6 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Feb 2021 10:27:23 -0800 Subject: [PATCH 1/4] rewrite take --- .../contrib/onnx/mx2onnx/_op_translations.py | 67 ++++++++++++++++--- tests/python-pytest/onnx/test_operators.py | 12 ++++ 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index c5e42a06a343..4cf512945949 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2393,18 +2393,67 @@ def convert_topk(node, **kwargs): def convert_take(node, **kwargs): """Map MXNet's Take operator attributes to onnx's Gather operator. """ + from onnx.helper import make_node + from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) - axis = int(attrs.get('axis', 0)) + mode = str(attrs.get('mode', 'clip')) - node = onnx.helper.make_node( - "Gather", - input_nodes, - [name], - axis=axis, - name=name, - ) - return [node] + data = input_nodes[0] + indices = input_nodes[1] + + nodes = [ + make_node('Cast', [indices], [name+'_indices'], to=int(TensorProto.INT64)), + ] + + if mode == 'raise': + nodes += [ + make_node('Gather', [data, name+'_indices'], [name], axis=axis, name=name) + ] + + return nodes + + nodes += [ + create_tensor([-1], name+'_-1', kwargs["initializer"]), + make_node('Shape', [data], [name+'_data_shape']), + ] + + # cornor case + if axis == -1: + nodes += [ + make_node('Shape', [name+'_data_shape'], [name+'_data_dim']), + make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']), + make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_slice0_out']), + ] + + else: + nodes += [ + create_tensor([axis], name+'_axis', kwargs["initializer"]), + create_tensor([axis+1], name+'_axis+1', kwargs["initializer"]), + make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis+1'], [name+'_slice0_out']), + ] + + if mode == 'clip': + nodes += [ + create_tensor([0], name+'_0', kwargs["initializer"]), + make_node('Add', [name+'_slice0_out', name+'_-1'], [name+'_max']), + make_node('Greater', [name+'_indices', name+'_max'], [name+'_max_mask']), + make_node('Where', [name+'_max_mask', name+'_max', name+'_indices'], [name+'_where0_out']), + make_node('Less', [name+'_indices', name+'_0'], [name+'_min_mask']), + make_node('Where', [name+'_min_mask', name+'_0', name+'_where0_out'], [name+'_where1_out']), + make_node('Gather', [data, name+'_where1_out'], [name], axis=axis, name=name) + ] + + elif mode == 'wrap': + nodes += [ + make_node('Mod', [name+'_indices', name+'_slice0_out'], [name+'_mod0_out']), + make_node('Gather', [data, name+'_mod0_out'], [name], axis=axis, name=name) + ] + + else: + raise NotImplementedError("mode must be clip, wrap or raise.") + + return nodes @mx_op.register("LayerNorm") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 90ec8f58b250..8a8966cdb206 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1139,3 +1139,15 @@ def test_onnx_export_tile(tmp_path, dtype, reps): x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype) M = def_model('tile', reps=reps) op_export_test('tile', M, [x], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2]) +@pytest.mark.parametrize('mode', ['clip', 'wrap']) +def test_onnx_export_take(tmp_path, dtype, axis, mode): + x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype) + y = mx.random.randint(-100, 100, (6, 7)).astype(dtype) + M1 = def_model('take') + op_export_test('take1', M1, [x, y], tmp_path) + M2 = def_model('take', axis=axis, mode=mode) + op_export_test('take2', M2, [x, y], tmp_path) From 084d95f28455e9eafe15a0e43f3d6bcf91bb9262 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Feb 2021 12:08:17 -0800 Subject: [PATCH 2/4] fix typo --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 4cf512945949..fe81af93d750 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2418,7 +2418,7 @@ def convert_take(node, **kwargs): make_node('Shape', [data], [name+'_data_shape']), ] - # cornor case + # corner case if axis == -1: nodes += [ make_node('Shape', [name+'_data_shape'], [name+'_data_dim']), From cb7c90b2d5a24223c36500621933cf1eab02ed1b Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 8 Feb 2021 11:52:28 -0800 Subject: [PATCH 3/4] add test for raise --- tests/python-pytest/onnx/test_operators.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 8a8966cdb206..ab95ae038b43 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1151,3 +1151,12 @@ def test_onnx_export_take(tmp_path, dtype, axis, mode): op_export_test('take1', M1, [x, y], tmp_path) M2 = def_model('take', axis=axis, mode=mode) op_export_test('take2', M2, [x, y], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2]) +def test_onnx_export_take_raise(tmp_path, dtype, axis): + x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype) + y = mx.random.randint(0, 4, (6, 7)).astype(dtype) + M = def_model('take', axis=axis, mode='raise') + op_export_test('take', M, [x, y], tmp_path) \ No newline at end of file From 7cc4d867aa6417b000319f28135b1ec109cb5fdb Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 12 Feb 2021 10:44:12 -0800 Subject: [PATCH 4/4] fix test_raise --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index ab95ae038b43..eb74630d7bca 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1157,6 +1157,6 @@ def test_onnx_export_take(tmp_path, dtype, axis, mode): @pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2]) def test_onnx_export_take_raise(tmp_path, dtype, axis): x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype) - y = mx.random.randint(0, 4, (6, 7)).astype(dtype) + y = mx.random.randint(0, 3, (6, 7)).astype(dtype) M = def_model('take', axis=axis, mode='raise') op_export_test('take', M, [x, y], tmp_path) \ No newline at end of file