From 2fb4bcecfb16cb1ed44493f99cfc0d3963885156 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 5 May 2021 18:38:44 -0700 Subject: [PATCH 1/4] splitv2 --- .../_op_translations_opset12.py | 64 ++++++++++++++++++ .../_op_translations_opset13.py | 66 ++++++++++++++++++- tests/python-pytest/onnx/test_operators.py | 16 +++++ 3 files changed, 145 insertions(+), 1 deletion(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 2f38faa05f2b..417dda33a3f5 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -4440,6 +4440,7 @@ def convert_sequence_reverse(node, **kwargs): return nodes + @mx_op.register("RNN") def convert_RNN(node, **kwargs): """Map MXNet's RNN operator attributes to onnx's operators @@ -4810,6 +4811,7 @@ def convert_RNN(node, **kwargs): raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode") return nodes + @mx_op.register('_rnn_param_concat') def convert_rnn_param_concat(node, **kwargs): """Map MXNet's _rnn_param_concat operator @@ -4852,3 +4854,65 @@ def convert_contrib_div_sqrt_dim(node, **kwargs): ] return nodes + + +@mx_op.register('_contrib_div_sqrt_dim') +def convert_contrib_div_sqrt_dim(node, **kwargs): + """Map MXNet's _contrib_div_sqrt_dim operator + """ + from onnx.helper import make_node + name, input_nodes, _ = get_inputs(node, kwargs) + input_dtypes = get_input_dtypes(node, kwargs) + + dtype = input_dtypes[0] + dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] + + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Sub', [name+'_dim', name+'_1'], [name+'_dim_m1']), + make_node('Slice', [name+'_shape', name+'_dim_m1', name+'_dim', name+'_0'], [name+'_c_']), + make_node('Cast', [name+'_c_'], [name+'_c'], to=dtype_t), + make_node('Sqrt', [name+'_c'], [name+'_c_sqrt']), + make_node('Div', [name+'_1_f', name+'_c_sqrt'], [name+'_1_over_c_sqrt']), + make_node('Mul', [input_nodes[0], name+'_1_over_c_sqrt'], [name]) + ] + + return nodes + + +@mx_op.register('_split_v2') +def convert_contrib_split_v2(node, **kwargs): + """Map MXNet's _split_v2 operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + axis = int(attrs.get('axis', 0)) + squeeze_axis = attrs.get('squeeze_axis', 'False') + sections = int(attrs.get('sections', 0)) + indices = convert_string_to_list(attrs.get('indices', '[]')) + if sections <= 0 and len(indices) == 0: + raise NotImplementedError('section or indices must be set') + if sections > 0: + output_nodes = [name+str(i) for i in range(sections)] + if squeeze_axis == 'False': + nodes = [ + make_node('Split', input_nodes, output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(sections)] + nodes = [ + make_node('Split', input_nodes, output_nodes_, axis=axis), + ] + for i in range(sections): + nodes += [ + make_node("Squeeze", [output_nodes_[i]], [output_nodes[i]], axes=[axis]), + ] + else: + raise NotImplementedError('indices is supported since ONNX 1.8.0 (opset13), please upgrade ONNX version') + + return nodes + diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 4ac6dfdff21c..7ef4ecb37651 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -136,7 +136,6 @@ def get_inputs(node, kwargs): outputs_lookup = kwargs["outputs_lookup"] inputs = node["inputs"] attrs = node.get("attrs", {}) - input_nodes = [] for ip in inputs: input_node_name = outputs_lookup[ip[0]][ip[1]].name @@ -1732,3 +1731,68 @@ def convert_logsoftmax(node, **kwargs): ) return [node] + + +@mx_op.register('_split_v2', OPSET_VERSION) +def convert_contrib_split_v2(node, **kwargs): + """Map MXNet's _split_v2 operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + axis = int(attrs.get('axis', 0)) + squeeze_axis = attrs.get('squeeze_axis', 'False') + sections = int(attrs.get('sections', 0)) + indices = convert_string_to_list(attrs.get('indices', '[]')) + if sections <= 0 and len(indices) == 0: + raise NotImplementedError('section or indices must be set') + if sections > 0: + output_nodes = [name+str(i) for i in range(sections)] + if squeeze_axis == 'False': + nodes = [ + make_node('Split', input_nodes, output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(sections)] + create_tensor([axis], name+'_axis', kwargs['initializer']) + nodes = [ + make_node('Split', input_nodes, output_nodes_, axis=axis), + ] + for i in range(sections): + nodes += [ + make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), + ] + else: + # indices += [0] + indices.sort() + split = [] + for i in range(1, len(indices)): + if indices[i] >= indices[i-1]: + split.append(indices[i] - indices[i-1]) + + output_nodes = [name+str(i) for i in range(len(split)+1)] + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([axis], name+'_axis', kwargs['initializer']) + create_tensor([axis+1], name+'_axis+1', kwargs['initializer']) + create_tensor(split, name+'_split_', kwargs['initializer']) + create_tensor([sum(split)], name+'_sum', kwargs['initializer']) + nodes = [ + make_node('Shape', input_nodes, [name+'_shape']), + make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']), + make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']), + make_node('Concat', [name+'_split_', name+'_sub'], [name+'_split'], axis=0), + ] + if squeeze_axis == 'False': + nodes += [ + make_node('Split', [input_nodes[0], name+'_split'], output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(len(split)+1)] + nodes += [ + make_node('Split', [input_nodes[0], name+'_split'], output_nodes_, axis=axis), + ] + for i in range(len(output_nodes)): + nodes += [ + make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), + ] + + return nodes \ No newline at end of file diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index a2971b8f139d..cdbebb0d3a93 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1858,3 +1858,19 @@ def rand_check(out): def rand_check_nd(out): return rand_check(out.asnumpy()) op_export_test('sample_multinomial', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check) + + +@pytest.mark.parametrize("dtype", ['float32', 'int32', 'int64']) +@pytest.mark.parametrize('params', [((2, 4, 6), (1, ), 0, True), + ((4, 5, 6), (2, 4), 1, False), + ((4, 5, 6, 7), (0, 2, 4), 2, False), + ((4, 5, 6, 7), 3, -2, False), + ((2, 6, 8), 8, -1, True)]) +def test_onnx_export_split_v2(tmp_path, dtype, params): + from onnx.defs import onnx_opset_version + if onnx_opset_version() < 13 and not isinstance(params[1], int): + # opset12 only supports sections. indices is supported since opset13 + return + M = def_model('split_v2', indices_or_sections=params[1], axis=params[2], squeeze_axis=params[3]) + x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype) + op_export_test('split_v2', M, [x], tmp_path) From 1fbfd3ef7d1de44f445b09658d5b68d460e70974 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 5 May 2021 18:47:38 -0700 Subject: [PATCH 2/4] support large index --- .../onnx/mx2onnx/_op_translations/_op_translations_opset13.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 7ef4ecb37651..826198a5a9db 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1779,7 +1779,9 @@ def convert_contrib_split_v2(node, **kwargs): make_node('Shape', input_nodes, [name+'_shape']), make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']), make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']), - make_node('Concat', [name+'_split_', name+'_sub'], [name+'_split'], axis=0), + make_node('Concat', [name+'_split_', name+'_sub'], [name+'_concat'], axis=0), + make_node('Less', [name+'_concat', name+'_0'], [name+'_less']), + make_node('Where', [name+'_less', name+'_0', name+'_concat'], [name+'_split']), ] if squeeze_axis == 'False': nodes += [ From c5bad07c79081a02b7e5d02cdadd9619b9049d6e Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Wed, 5 May 2021 19:11:42 -0700 Subject: [PATCH 3/4] fix sanity --- .../_op_translations_opset12.py | 29 ------------------- .../_op_translations_opset13.py | 15 +++++----- 2 files changed, 7 insertions(+), 37 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 417dda33a3f5..a3c3c79809ed 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -4856,34 +4856,6 @@ def convert_contrib_div_sqrt_dim(node, **kwargs): return nodes -@mx_op.register('_contrib_div_sqrt_dim') -def convert_contrib_div_sqrt_dim(node, **kwargs): - """Map MXNet's _contrib_div_sqrt_dim operator - """ - from onnx.helper import make_node - name, input_nodes, _ = get_inputs(node, kwargs) - input_dtypes = get_input_dtypes(node, kwargs) - - dtype = input_dtypes[0] - dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype] - - create_tensor([0], name+'_0', kwargs['initializer']) - create_tensor([1], name+'_1', kwargs['initializer']) - create_tensor([1], name+'_1_f', kwargs['initializer'], dtype=dtype) - nodes = [ - make_node('Shape', [input_nodes[0]], [name+'_shape']), - make_node('Shape', [name+'_shape'], [name+'_dim']), - make_node('Sub', [name+'_dim', name+'_1'], [name+'_dim_m1']), - make_node('Slice', [name+'_shape', name+'_dim_m1', name+'_dim', name+'_0'], [name+'_c_']), - make_node('Cast', [name+'_c_'], [name+'_c'], to=dtype_t), - make_node('Sqrt', [name+'_c'], [name+'_c_sqrt']), - make_node('Div', [name+'_1_f', name+'_c_sqrt'], [name+'_1_over_c_sqrt']), - make_node('Mul', [input_nodes[0], name+'_1_over_c_sqrt'], [name]) - ] - - return nodes - - @mx_op.register('_split_v2') def convert_contrib_split_v2(node, **kwargs): """Map MXNet's _split_v2 operator @@ -4915,4 +4887,3 @@ def convert_contrib_split_v2(node, **kwargs): raise NotImplementedError('indices is supported since ONNX 1.8.0 (opset13), please upgrade ONNX version') return nodes - diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 826198a5a9db..ac3ff599c714 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1762,7 +1762,6 @@ def convert_contrib_split_v2(node, **kwargs): make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), ] else: - # indices += [0] indices.sort() split = [] for i in range(1, len(indices)): @@ -1776,12 +1775,12 @@ def convert_contrib_split_v2(node, **kwargs): create_tensor(split, name+'_split_', kwargs['initializer']) create_tensor([sum(split)], name+'_sum', kwargs['initializer']) nodes = [ - make_node('Shape', input_nodes, [name+'_shape']), - make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']), - make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']), - make_node('Concat', [name+'_split_', name+'_sub'], [name+'_concat'], axis=0), - make_node('Less', [name+'_concat', name+'_0'], [name+'_less']), - make_node('Where', [name+'_less', name+'_0', name+'_concat'], [name+'_split']), + make_node('Shape', input_nodes, [name+'_shape']), + make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']), + make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']), + make_node('Concat', [name+'_split_', name+'_sub'], [name+'_concat'], axis=0), + make_node('Less', [name+'_concat', name+'_0'], [name+'_less']), + make_node('Where', [name+'_less', name+'_0', name+'_concat'], [name+'_split']), ] if squeeze_axis == 'False': nodes += [ @@ -1797,4 +1796,4 @@ def convert_contrib_split_v2(node, **kwargs): make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), ] - return nodes \ No newline at end of file + return nodes From fb58449533e62c2bee5f5a423397fe4245350317 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 6 May 2021 10:54:22 -0700 Subject: [PATCH 4/4] fix sanity --- .../onnx/mx2onnx/_op_translations/_op_translations_opset13.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index ac3ff599c714..95bb27c7afcc 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -1791,9 +1791,9 @@ def convert_contrib_split_v2(node, **kwargs): nodes += [ make_node('Split', [input_nodes[0], name+'_split'], output_nodes_, axis=axis), ] - for i in range(len(output_nodes)): + for i, output_node in enumerate(output_nodes): nodes += [ - make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), + make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_node]), ] return nodes