Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX support for slice_like (#19782)
Browse files Browse the repository at this point in the history
* slice_like

* Update _op_translations.py

* Update _op_translations.py
  • Loading branch information
Zha0q1 committed Jan 27, 2021
1 parent bd9d80c commit 7de30c2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
46 changes: 44 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,13 @@ def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'):
dims = np.shape(tensor_np)
tensor_node = onnx.helper.make_tensor_value_info(tensor_name, data_type, dims)
if dtype == np.float16:
tensor_list = tensor_np.view(dtype=np.uint16).flatten().tolist()
tensor_np = tensor_np.view(dtype=np.uint16)
initializer.append(
onnx.helper.make_tensor(
name=tensor_name,
data_type=data_type,
dims=dims,
vals=tensor_list,
vals=tensor_np.flatten().tolist(),
raw=False
)
)
Expand Down Expand Up @@ -3306,3 +3306,45 @@ def convert_gather_nd(node, **kwargs):
]

return nodes



@mx_op.register('slice_like')
def convert_slice_like(node, **kwargs):
"""Map MXNet's slice_like operator to onnx Slice operator."""
from onnx.helper import make_node, make_tensor
from onnx import TensorProto

name, input_nodes, attrs = get_inputs(node, kwargs)

axes = convert_string_to_list(attrs.get('axes', 'None'))
zero = make_tensor(name+'_zero', TensorProto.INT64, [1], [0])

nodes = []
if axes == [None]:
nodes += [
make_node('Shape', [input_nodes[1]], [name+'_shape_1']),
make_node('Shape', [name+'_shape_1'], [name+'_dim_1']),
make_node('ConstantOfShape', [name+'_dim_1'], [name+'_starts'], value=zero),
make_node('Slice', [input_nodes[0], name+'_starts', name+'_shape_1'], [name])
]
else:
axes = [[i] for i in axes]
nodes += [
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor(axes, name+'_axes_', kwargs['initializer']),
make_node('Shape', [input_nodes[0]], [name+'_shape_0']),
make_node('Shape', [input_nodes[1]], [name+'_shape_1']),
make_node('Shape', [name+'_shape_0'], [name+'_dim_0']),
make_node('Less', [name+'_axes_', name+'_0'], [name+'_less']),
make_node('Cast', [name+'_less'], [name+'_mask'], to=int(TensorProto.INT64)),
make_node('Mul', [name+'_mask', name+'_dim_0'], [name+'_mul']),
make_node('Add', [name+'_axes_', name+'_mul'], [name+'_axes']),
make_node('ConstantOfShape', [name+'_dim_0'], [name+'_starts'], value=zero),
make_node('GatherND', [name+'_shape_1', name+'_axes'], [name+'_gather']),
make_node('ScatterND', [name+'_shape_0', name+'_axes', name+'_gather'],
[name+'_ends']),
make_node('Slice', [input_nodes[0], name+'_starts', name+'_ends'], [name])
]

return nodes
18 changes: 18 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,21 @@ def test_onnx_export_gather_nd(tmp_path, dtype):
M2 = def_model('gather_nd')
op_export_test('gather_nd2', M2, [x2, y2], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('axes', [None, (0, 1, 2), (-2, -3), (-2, 0)])
def test_onnx_export_slice_like(tmp_path, dtype, axes):
x = mx.nd.random.uniform(0, 1, (4, 5, 6, 7)).astype(dtype)
if axes is None:
M = def_model('slice_like')
y = mx.nd.zeros((2, 3, 4, 5), dtype=dtype)
op_export_test('slice_like', M, [x, y], tmp_path)
else:
M = def_model('slice_like', axes=axes)
y1 = mx.nd.zeros((2, 3, 4), dtype=dtype)
y2 = mx.nd.zeros((2, 3, 4, 5), dtype=dtype)
y3 = mx.nd.zeros((2, 3, 4, 5, 6), dtype=dtype)
op_export_test('slice_like_1', M, [x, y1], tmp_path)
op_export_test('slice_like_2', M, [x, y2], tmp_path)
op_export_test('slice_like_3', M, [x, y3], tmp_path)

0 comments on commit 7de30c2

Please sign in to comment.