diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index c5e42a06a343..fe81af93d750 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2393,18 +2393,67 @@ def convert_topk(node, **kwargs):
 def convert_take(node, **kwargs):
     """Map MXNet's Take operator attributes to onnx's Gather operator.
     """
+    from onnx.helper import make_node
+    from onnx import TensorProto
     name, input_nodes, attrs = get_inputs(node, kwargs)
-    axis = int(attrs.get('axis', 0))
+
+    axis = int(attrs.get('axis', 0))
+    mode = str(attrs.get('mode', 'clip'))
 
-    node = onnx.helper.make_node(
-        "Gather",
-        input_nodes,
-        [name],
-        axis=axis,
-        name=name,
-    )
-    return [node]
+    data = input_nodes[0]
+    indices = input_nodes[1]
+
+    nodes = [
+        make_node('Cast', [indices], [name+'_indices'], to=int(TensorProto.INT64)),
+    ]
+
+    if mode == 'raise':
+        nodes += [
+            make_node('Gather', [data, name+'_indices'], [name], axis=axis, name=name)
+        ]
+        return nodes
+
+    nodes += [
+        create_tensor([-1], name+'_-1', kwargs["initializer"]),
+        make_node('Shape', [data], [name+'_data_shape']),
+    ]
+
+    # corner case: axis == -1 refers to the last axis of the data
+    if axis == -1:
+        nodes += [
+            make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
+            make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']),
+            make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_slice0_out']),
+        ]
+    else:
+        nodes += [
+            create_tensor([axis], name+'_axis', kwargs["initializer"]),
+            create_tensor([axis+1], name+'_axis+1', kwargs["initializer"]),
+            make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis+1'], [name+'_slice0_out']),
+        ]
+
+    if mode == 'clip':
+        nodes += [
+            create_tensor([0], name+'_0', kwargs["initializer"]),
+            make_node('Add', [name+'_slice0_out', name+'_-1'], [name+'_max']),
+            make_node('Greater', [name+'_indices', name+'_max'], [name+'_max_mask']),
+            make_node('Where', [name+'_max_mask', name+'_max', name+'_indices'], [name+'_where0_out']),
+            make_node('Less', [name+'_indices', name+'_0'], [name+'_min_mask']),
+            make_node('Where', [name+'_min_mask', name+'_0', name+'_where0_out'], [name+'_where1_out']),
+            make_node('Gather', [data, name+'_where1_out'], [name], axis=axis, name=name)
+        ]
+    elif mode == 'wrap':
+        nodes += [
+            make_node('Mod', [name+'_indices', name+'_slice0_out'], [name+'_mod0_out']),
+            make_node('Gather', [data, name+'_mod0_out'], [name], axis=axis, name=name)
+        ]
+    else:
+        raise NotImplementedError("mode must be clip, wrap or raise.")
+
+    return nodes
 
 
 @mx_op.register("LayerNorm")
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 90ec8f58b250..eb74630d7bca 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1139,3 +1139,24 @@ def test_onnx_export_tile(tmp_path, dtype, reps):
     x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype)
     M = def_model('tile', reps=reps)
     op_export_test('tile', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
+@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
+@pytest.mark.parametrize('mode', ['clip', 'wrap'])
+def test_onnx_export_take(tmp_path, dtype, axis, mode):
+    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
+    y = mx.random.randint(-100, 100, (6, 7)).astype(dtype)
+    M1 = def_model('take')
+    op_export_test('take1', M1, [x, y], tmp_path)
+    M2 = def_model('take', axis=axis, mode=mode)
+    op_export_test('take2', M2, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
+@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
+def test_onnx_export_take_raise(tmp_path, dtype, axis):
+    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
+    y = mx.random.randint(0, 3, (6, 7)).astype(dtype)
+    M = def_model('take', axis=axis, mode='raise')
+    op_export_test('take', M, [x, y], tmp_path)
\ No newline at end of file
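
For reviewers: the sketch below (not part of the patch) mirrors the index normalization the exported graph performs before the final Gather node, written in NumPy instead of ONNX ops. The helper name take_indices and the variable dim (the length of the gathered axis, i.e. the value produced by the Slice on the data shape) are hypothetical and exist only for illustration. 'clip' corresponds to the Greater/Less plus the two Where nodes, 'wrap' to the Mod node, and 'raise' defers to the backend's Gather, which is expected to reject out-of-range indices.

import numpy as np

def take_indices(indices, dim, mode='clip'):
    # Illustrative only: NumPy equivalent of the index handling emitted by convert_take.
    idx = indices.astype(np.int64)                     # Cast -> INT64
    if mode == 'raise':
        return idx                                     # Gather directly; backend validates
    if mode == 'clip':
        idx = np.where(idx > dim - 1, dim - 1, idx)    # Greater + first Where: clamp upper bound
        idx = np.where(idx < 0, 0, idx)                # Less + second Where: clamp lower bound
        return idx
    if mode == 'wrap':
        return np.mod(idx, dim)                        # Mod: wrap around the axis length
    raise NotImplementedError("mode must be clip, wrap or raise.")

# Example with an axis of length 5 and out-of-range indices:
idx = np.array([-7, -1, 0, 4, 6])
print(take_indices(idx, 5, 'clip'))   # [0 0 0 4 4]
print(take_indices(idx, 5, 'wrap'))   # [3 4 0 4 1]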