Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX support for _split_v2 (#20250)
Browse files Browse the repository at this point in the history
* splitv2

* support large index

* fix sanity

* fix sanity

Co-authored-by: Wei Chu <[email protected]>
  • Loading branch information
waytrue17 and Wei Chu authored May 6, 2021
1 parent 923b608 commit e329e84
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4440,6 +4440,7 @@ def convert_sequence_reverse(node, **kwargs):

return nodes


@mx_op.register("RNN")
def convert_RNN(node, **kwargs):
"""Map MXNet's RNN operator attributes to onnx's operators
Expand Down Expand Up @@ -4810,6 +4811,7 @@ def convert_RNN(node, **kwargs):
raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode")
return nodes


@mx_op.register('_rnn_param_concat')
def convert_rnn_param_concat(node, **kwargs):
"""Map MXNet's _rnn_param_concat operator
Expand Down Expand Up @@ -4852,3 +4854,36 @@ def convert_contrib_div_sqrt_dim(node, **kwargs):
]

return nodes


@mx_op.register('_split_v2')
def convert_contrib_split_v2(node, **kwargs):
"""Map MXNet's _split_v2 operator
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get('axis', 0))
squeeze_axis = attrs.get('squeeze_axis', 'False')
sections = int(attrs.get('sections', 0))
indices = convert_string_to_list(attrs.get('indices', '[]'))
if sections <= 0 and len(indices) == 0:
raise NotImplementedError('section or indices must be set')
if sections > 0:
output_nodes = [name+str(i) for i in range(sections)]
if squeeze_axis == 'False':
nodes = [
make_node('Split', input_nodes, output_nodes, axis=axis),
]
else:
output_nodes_ = [name+str(i)+'_' for i in range(sections)]
nodes = [
make_node('Split', input_nodes, output_nodes_, axis=axis),
]
for i in range(sections):
nodes += [
make_node("Squeeze", [output_nodes_[i]], [output_nodes[i]], axes=[axis]),
]
else:
raise NotImplementedError('indices is supported since ONNX 1.8.0 (opset13), please upgrade ONNX version')

return nodes
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def get_inputs(node, kwargs):
outputs_lookup = kwargs["outputs_lookup"]
inputs = node["inputs"]
attrs = node.get("attrs", {})

input_nodes = []
for ip in inputs:
input_node_name = outputs_lookup[ip[0]][ip[1]].name
Expand Down Expand Up @@ -1732,3 +1731,69 @@ def convert_logsoftmax(node, **kwargs):
)

return [node]


@mx_op.register('_split_v2', OPSET_VERSION)
def convert_contrib_split_v2(node, **kwargs):
"""Map MXNet's _split_v2 operator
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get('axis', 0))
squeeze_axis = attrs.get('squeeze_axis', 'False')
sections = int(attrs.get('sections', 0))
indices = convert_string_to_list(attrs.get('indices', '[]'))
if sections <= 0 and len(indices) == 0:
raise NotImplementedError('section or indices must be set')
if sections > 0:
output_nodes = [name+str(i) for i in range(sections)]
if squeeze_axis == 'False':
nodes = [
make_node('Split', input_nodes, output_nodes, axis=axis),
]
else:
output_nodes_ = [name+str(i)+'_' for i in range(sections)]
create_tensor([axis], name+'_axis', kwargs['initializer'])
nodes = [
make_node('Split', input_nodes, output_nodes_, axis=axis),
]
for i in range(sections):
nodes += [
make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]),
]
else:
indices.sort()
split = []
for i in range(1, len(indices)):
if indices[i] >= indices[i-1]:
split.append(indices[i] - indices[i-1])

output_nodes = [name+str(i) for i in range(len(split)+1)]
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([axis], name+'_axis', kwargs['initializer'])
create_tensor([axis+1], name+'_axis+1', kwargs['initializer'])
create_tensor(split, name+'_split_', kwargs['initializer'])
create_tensor([sum(split)], name+'_sum', kwargs['initializer'])
nodes = [
make_node('Shape', input_nodes, [name+'_shape']),
make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']),
make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']),
make_node('Concat', [name+'_split_', name+'_sub'], [name+'_concat'], axis=0),
make_node('Less', [name+'_concat', name+'_0'], [name+'_less']),
make_node('Where', [name+'_less', name+'_0', name+'_concat'], [name+'_split']),
]
if squeeze_axis == 'False':
nodes += [
make_node('Split', [input_nodes[0], name+'_split'], output_nodes, axis=axis),
]
else:
output_nodes_ = [name+str(i)+'_' for i in range(len(split)+1)]
nodes += [
make_node('Split', [input_nodes[0], name+'_split'], output_nodes_, axis=axis),
]
for i, output_node in enumerate(output_nodes):
nodes += [
make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_node]),
]

return nodes
16 changes: 16 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,3 +1858,19 @@ def rand_check(out):
def rand_check_nd(out):
return rand_check(out.asnumpy())
op_export_test('sample_multinomial', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check)


@pytest.mark.parametrize("dtype", ['float32', 'int32', 'int64'])
@pytest.mark.parametrize('params', [((2, 4, 6), (1, ), 0, True),
((4, 5, 6), (2, 4), 1, False),
((4, 5, 6, 7), (0, 2, 4), 2, False),
((4, 5, 6, 7), 3, -2, False),
((2, 6, 8), 8, -1, True)])
def test_onnx_export_split_v2(tmp_path, dtype, params):
from onnx.defs import onnx_opset_version
if onnx_opset_version() < 13 and not isinstance(params[1], int):
# opset12 only supports sections. indices is supported since opset13
return
M = def_model('split_v2', indices_or_sections=params[1], axis=params[2], squeeze_axis=params[3])
x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype)
op_export_test('split_v2', M, [x], tmp_path)

0 comments on commit e329e84

Please sign in to comment.