Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
rewrite take
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Feb 8, 2021
1 parent 031dc5b commit 47cca6d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
67 changes: 58 additions & 9 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2397,18 +2397,67 @@ def convert_topk(node, **kwargs):
def convert_take(node, **kwargs):
"""Map MXNet's Take operator attributes to onnx's Gather operator.
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get('axis', 0))
mode = str(attrs.get('mode', 'clip'))

node = onnx.helper.make_node(
"Gather",
input_nodes,
[name],
axis=axis,
name=name,
)
return [node]
data = input_nodes[0]
indices = input_nodes[1]

nodes = [
make_node('Cast', [indices], [name+'_indices'], to=int(TensorProto.INT64)),
]

if mode == 'raise':
nodes += [
make_node('Gather', [data, name+'_indices'], [name], axis=axis, name=name)
]

return nodes

nodes += [
create_tensor([-1], name+'_-1', kwargs["initializer"]),
make_node('Shape', [data], [name+'_data_shape']),
]

# cornor case
if axis == -1:
nodes += [
make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']),
make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_slice0_out']),
]

else:
nodes += [
create_tensor([axis], name+'_axis', kwargs["initializer"]),
create_tensor([axis+1], name+'_axis+1', kwargs["initializer"]),
make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis+1'], [name+'_slice0_out']),
]

if mode == 'clip':
nodes += [
create_tensor([0], name+'_0', kwargs["initializer"]),
make_node('Add', [name+'_slice0_out', name+'_-1'], [name+'_max']),
make_node('Greater', [name+'_indices', name+'_max'], [name+'_max_mask']),
make_node('Where', [name+'_max_mask', name+'_max', name+'_indices'], [name+'_where0_out']),
make_node('Less', [name+'_indices', name+'_0'], [name+'_min_mask']),
make_node('Where', [name+'_min_mask', name+'_0', name+'_where0_out'], [name+'_where1_out']),
make_node('Gather', [data, name+'_where1_out'], [name], axis=axis, name=name)
]

elif mode == 'wrap':
nodes += [
make_node('Mod', [name+'_indices', name+'_slice0_out'], [name+'_mod0_out']),
make_node('Gather', [data, name+'_mod0_out'], [name], axis=axis, name=name)
]

else:
raise NotImplementedError("mode must be clip, wrap or raise.")

return nodes


@mx_op.register("LayerNorm")
Expand Down
12 changes: 12 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,15 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i):
kwargs['is_ascend'] = is_ascend
M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs)
op_export_test('argsort', M, [A], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])
@pytest.mark.parametrize('mode', ['clip', 'wrap'])
def test_onnx_export_take(tmp_path, dtype, axis, mode):
x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
y = mx.random.randint(-100, 100, (6, 7)).astype(dtype)
M1 = def_model('take')
op_export_test('take1', M1, [x, y], tmp_path)
M2 = def_model('take', axis=axis, mode=mode)
op_export_test('take2', M2, [x, y], tmp_path)

0 comments on commit 47cca6d

Please sign in to comment.