diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 22266454cccd..46767980f4c3 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -186,13 +186,13 @@ def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): dims = np.shape(tensor_np) tensor_node = onnx.helper.make_tensor_value_info(tensor_name, data_type, dims) if dtype == np.float16: - tensor_list = tensor_np.view(dtype=np.uint16).flatten().tolist() + tensor_np = tensor_np.view(dtype=np.uint16) initializer.append( onnx.helper.make_tensor( name=tensor_name, data_type=data_type, dims=dims, - vals=tensor_list, + vals=tensor_np.flatten().tolist(), raw=False ) ) @@ -3306,3 +3306,45 @@ def convert_gather_nd(node, **kwargs): ] return nodes + + + +@mx_op.register('slice_like') +def convert_slice_like(node, **kwargs): + """Map MXNet's slice_like operator to onnx Slice operator.""" + from onnx.helper import make_node, make_tensor + from onnx import TensorProto + + name, input_nodes, attrs = get_inputs(node, kwargs) + + axes = convert_string_to_list(attrs.get('axes', 'None')) + zero = make_tensor(name+'_zero', TensorProto.INT64, [1], [0]) + + nodes = [] + if axes == [None]: + nodes += [ + make_node('Shape', [input_nodes[1]], [name+'_shape_1']), + make_node('Shape', [name+'_shape_1'], [name+'_dim_1']), + make_node('ConstantOfShape', [name+'_dim_1'], [name+'_starts'], value=zero), + make_node('Slice', [input_nodes[0], name+'_starts', name+'_shape_1'], [name]) + ] + else: + axes = [[i] for i in axes] + nodes += [ + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor(axes, name+'_axes_', kwargs['initializer']), + make_node('Shape', [input_nodes[0]], [name+'_shape_0']), + make_node('Shape', [input_nodes[1]], [name+'_shape_1']), + make_node('Shape', [name+'_shape_0'], [name+'_dim_0']), + make_node('Less', [name+'_axes_', name+'_0'], [name+'_less']), + make_node('Cast', [name+'_less'], [name+'_mask'], to=int(TensorProto.INT64)), + make_node('Mul', [name+'_mask', name+'_dim_0'], [name+'_mul']), + make_node('Add', [name+'_axes_', name+'_mul'], [name+'_axes']), + make_node('ConstantOfShape', [name+'_dim_0'], [name+'_starts'], value=zero), + make_node('GatherND', [name+'_shape_1', name+'_axes'], [name+'_gather']), + make_node('ScatterND', [name+'_shape_0', name+'_axes', name+'_gather'], + [name+'_ends']), + make_node('Slice', [input_nodes[0], name+'_starts', name+'_ends'], [name]) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 2c363a96ba04..21d46703f5e8 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -539,3 +539,21 @@ def test_onnx_export_gather_nd(tmp_path, dtype): M2 = def_model('gather_nd') op_export_test('gather_nd2', M2, [x2, y2], tmp_path) + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize('axes', [None, (0, 1, 2), (-2, -3), (-2, 0)]) +def test_onnx_export_slice_like(tmp_path, dtype, axes): + x = mx.nd.random.uniform(0, 1, (4, 5, 6, 7)).astype(dtype) + if axes is None: + M = def_model('slice_like') + y = mx.nd.zeros((2, 3, 4, 5), dtype=dtype) + op_export_test('slice_like', M, [x, y], tmp_path) + else: + M = def_model('slice_like', axes=axes) + y1 = mx.nd.zeros((2, 3, 4), dtype=dtype) + y2 = mx.nd.zeros((2, 3, 4, 5), dtype=dtype) + y3 = mx.nd.zeros((2, 3, 4, 5, 6), dtype=dtype) + op_export_test('slice_like_1', M, [x, y1], tmp_path) + op_export_test('slice_like_2', M, [x, y2], tmp_path) + op_export_test('slice_like_3', M, [x, y3], tmp_path) +