Skip to content

Commit

Permalink
ONNX export: Support equal length splits (apache#14121)
Browse files Browse the repository at this point in the history
* ONNX export: Support equal length splits

* Fix lint error

* Add comment about checking for multiple outputs
  • Loading branch information
vandanavk authored and vdantu committed Mar 31, 2019
1 parent 9a0cac6 commit 71e4f9d
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
6 changes: 4 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs):
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
in_shape = kwargs.get('in_shape')[0]
split = in_shape[axis] // num_outputs
node = onnx.helper.make_node(
"Split",
input_nodes,
[name],
[name+'_output'+str(i) for i in range(num_outputs)],
axis=axis,
split=[num_outputs],
split=[split for _ in range(num_outputs)],
name=name,
)
return [node]
Expand Down
23 changes: 13 additions & 10 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,20 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
node_name = converted_node.name if converted_node.name else converted_node.output[0]
if node_name in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
name=node_name,
elem_type=in_type,
shape=graph_outputs[node_name]
# some operators have multiple outputs,
# therefore, check all output node names
node_names = list(converted_node.output)
for nodename in node_names:
if nodename in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
name=nodename,
elem_type=in_type,
shape=graph_outputs[nodename]
)
)
)
if verbose:
logging.info("Output node is: %s", converted_node.name)
if verbose:
logging.info("Output node is: %s", nodename)
elif isinstance(converted_node, TensorProto):
raise ValueError("Did not expect TensorProto")
else:
Expand Down
12 changes: 7 additions & 5 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj):
if not split_list:
num_outputs = len(proto_obj.model_metadata.get('output_tensor_data'))
else:
raise NotImplementedError("Operator {} in MXNet does not support variable splits."
"Tracking the issue to support variable split here: "
"https://github.com/apache/incubator-mxnet/issues/11594"
.format('split'))
if len(set(split_list)) == 1:
num_outputs = len(split_list)
else:
raise NotImplementedError("Operator {} in MXNet does not support variable splits."
"Tracking the issue to support variable split here: "
"https://github.com/apache/incubator-mxnet/issues/11594"
.format('split'))

new_attrs['num_outputs'] = num_outputs

return 'split', new_attrs, inputs

def _slice(attrs, inputs, proto_obj):
Expand Down
4 changes: 2 additions & 2 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@
'test_elu',
'test_max_',
'test_softplus',
'test_reduce_'
'test_reduce_',
'test_split_equal'
],
'import': ['test_gather',
'test_softsign',
Expand All @@ -88,7 +89,6 @@
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
'test_split_equal',
'test_hardmax'
],
'export': ['test_random_uniform',
Expand Down

0 comments on commit 71e4f9d

Please sign in to comment.