From 441cbc0551d0c52b9209fea437eb2ca276354a85 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Tue, 12 Feb 2019 13:21:29 -0800 Subject: [PATCH] ONNX export: Support equal length splits (#14121) * ONNX export: Support equal length splits * Fix lint error * Add comment about checking for multiple outputs --- .../contrib/onnx/mx2onnx/_op_translations.py | 6 +++-- .../mxnet/contrib/onnx/mx2onnx/export_onnx.py | 23 +++++++++++-------- .../contrib/onnx/onnx2mx/_op_translations.py | 12 ++++++---- tests/python-pytest/onnx/test_cases.py | 4 ++-- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index e077824e0226..f9d170d81c13 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs): ) return [node] elif squeeze_axis == 0 and num_outputs > 1: + in_shape = kwargs.get('in_shape')[0] + split = in_shape[axis] // num_outputs node = onnx.helper.make_node( "Split", input_nodes, - [name], + [name+'_output'+str(i) for i in range(num_outputs)], axis=axis, - split=[num_outputs], + split=[split for _ in range(num_outputs)], name=name, ) return [node] diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index d0d4501d89f4..a7b11fc902db 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -262,17 +262,20 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # If converted node is NodeProto, add it in processed nodes list elif isinstance(converted_node, NodeProto): onnx_processed_nodes.append(converted_node) - node_name = converted_node.name if converted_node.name else converted_node.output[0] - if node_name in graph_outputs: - onnx_processed_outputs.append( - make_tensor_value_info( - name=node_name, - elem_type=in_type, - shape=graph_outputs[node_name] + # some operators have multiple outputs, + # therefore, check all output node names + node_names = list(converted_node.output) + for nodename in node_names: + if nodename in graph_outputs: + onnx_processed_outputs.append( + make_tensor_value_info( + name=nodename, + elem_type=in_type, + shape=graph_outputs[nodename] + ) ) - ) - if verbose: - logging.info("Output node is: %s", converted_node.name) + if verbose: + logging.info("Output node is: %s", nodename) elif isinstance(converted_node, TensorProto): raise ValueError("Did not expect TensorProto") else: diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index dc00feee815b..a7cef7674496 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj): if not split_list: num_outputs = len(proto_obj.model_metadata.get('output_tensor_data')) else: - raise NotImplementedError("Operator {} in MXNet does not support variable splits." - "Tracking the issue to support variable split here: " - "https://github.com/apache/incubator-mxnet/issues/11594" - .format('split')) + if len(set(split_list)) == 1: + num_outputs = len(split_list) + else: + raise NotImplementedError("Operator {} in MXNet does not support variable splits." + "Tracking the issue to support variable split here: " + "https://github.com/apache/incubator-mxnet/issues/11594" + .format('split')) new_attrs['num_outputs'] = num_outputs - return 'split', new_attrs, inputs def _slice(attrs, inputs, proto_obj): diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index b20db23aa1fd..89b60d15e84f 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -77,7 +77,8 @@ 'test_elu', 'test_max_', 'test_softplus', - 'test_reduce_' + 'test_reduce_', + 'test_split_equal' ], 'import': ['test_gather', 'test_softsign', @@ -88,7 +89,6 @@ 'test_averagepool_2d_precomputed_strides', 'test_averagepool_2d_strides', 'test_averagepool_3d', - 'test_split_equal', 'test_hardmax' ], 'export': ['test_random_uniform',