Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX fix embedding and slice (#19695)
Browse files Browse the repository at this point in the history
* fix slice

* Update _op_translations.py
  • Loading branch information
Zha0q1 authored Dec 19, 2020
1 parent d06d705 commit d538eb3
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
61 changes: 42 additions & 19 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def convert_string_to_list(string_val):
val = val.replace("L", "")
val = val.replace("[", "")
val = val.replace("]", "")
if val not in ("", "None"):
if val == "None":
result_list.append(None)
elif val != "":
result_list.append(int(val))

return result_list
Expand Down Expand Up @@ -2516,17 +2518,17 @@ def convert_sequencemask(node, **kwargs):
def convert_embedding(node, **kwargs):
"""Map MXNet's Embedding operator attributes to onnx's
Gather operator."""
from onnx.helper import make_node
from onnx import TensorProto

name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get('axis', 0))
node = onnx.helper.make_node(
"Gather",
[input_nodes[1], input_nodes[0]],
[name],
axis=axis,
name=name
)
return [node]

nodes = [
make_node('Cast', [input_nodes[0]], [name+'_indices_casted'], to=int(TensorProto.INT64)),
make_node('Gather', [input_nodes[1], name+'_indices_casted'], [name], axis=axis, name=name)
]
return nodes

@mx_op.register("stack")
def convert_stack(node, **kwargs):
Expand Down Expand Up @@ -2558,19 +2560,40 @@ def convert_stack(node, **kwargs):
@mx_op.register("slice")
def convert_slice(node, **kwargs):
"""Map MXNet's slice operator to onnx Slice operator."""
from onnx.helper import make_node

name, input_nodes, attrs = get_inputs(node, kwargs)
starts = convert_string_to_list(attrs.get("begin"))
ends = convert_string_to_list(attrs.get("end"))
steps = attrs.get("step", [])

starts = convert_string_to_list(attrs.get('begin'))
ends = convert_string_to_list(attrs.get('end'))
steps = convert_string_to_list(attrs.get('step', '[]'))

assert len(starts) == len(ends)
if len(steps) == 0 or (len(steps) == 1 and steps[0] is None):
steps = [1 for x in starts]
else:
assert len(steps) == len(starts)
steps = [1 if x is None else x for x in steps]
for i, s in enumerate(steps):
if s < 0:
raise NotImplementedError('slice operator does not support negative steps yet')
if starts[i] is None:
starts[i] = 0
if ends[i] is None:
ends[i] = 2**63-1

nodes = [
create_const_node(name+"_begin", np.array(starts), kwargs),
create_const_node(name+"_end", np.array(ends), kwargs)
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
create_const_scalar_node(name+'_len_s', np.int64(len(starts)), kwargs),
make_node('Range', [name+'_0_s', name+'_len_s', name+'_1_s'], [name+'_axes']),
create_tensor(starts, name+'_starts', kwargs['initializer']),
create_tensor(ends, name+'_ends', kwargs['initializer']),
create_tensor(steps, name+'_steps', kwargs['initializer']),
make_node("Slice", [input_nodes[0], name+'_starts', name+'_ends', name+'_axes',
name+'_steps'], [name], name=name)
]
inputs = [input_nodes[0], name+"_begin", name+"_end"]
if len(steps) > 0:
nodes.append(create_const_node(name+"_steps", np.array(steps, dtype='int64'), kwargs))
inputs.append(name+"_steps")
nodes.append(onnx.helper.make_node("Slice", inputs, [name], name=name))

return nodes


Expand Down
10 changes: 7 additions & 3 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ def test_onnx_export_abs(tmp_path):
op_export_test('abs', M, [x], tmp_path)


def test_onnx_export_slice(tmp_path):
M = def_model('slice', begin=(0,1), end=(2,4))
x = mx.nd.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]], dtype='float32')
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'float16', 'int32', 'int64'])
@pytest.mark.parametrize('params', [[(0, 1), (2,3), (1, 1)],
[(None, 1), (2, None), None],
[(0, 0, 0), (None, 4, 5), (None, 1, 2)]])
def test_onnx_export_slice(tmp_path, dtype, params):
M = def_model('slice', begin=params[0], end=params[1], step=params[2])
x = mx.nd.arange(start=0, stop=60, dtype=dtype).reshape((3, 4, 5))
op_export_test('slice', M, [x], tmp_path)


Expand Down

0 comments on commit d538eb3

Please sign in to comment.