diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index b02d970f9c2d..14c674f56f2d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -53,12 +53,8 @@ from __future__ import unicode_literals import logging import json -import numpy as np -from .... import context from .... import ndarray as nd -from .... import io -from .... import module as mod class MXNetGraph(object): @@ -95,60 +91,6 @@ def convert_layer(node, **kwargs): convert_func = MXNetGraph.registry_[op] return convert_func(node, **kwargs) - @staticmethod - def forward_pass(inputs, sym, arg_params, aux_params, output_label): - """Do a forward pass based on the sym and params to get the shape - of the output using dummy data - - Parameters - ---------- - inputs : json string - - sym : :class:`~mxnet.symbol.Symbol` - MXNet symbol object - arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` - Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format - aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` - Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format - - Returns - ------- - shape : Shape - Output shape - """ - # if label is not provided, MXNet adds label "softmax_label" by default - # while running load_checkpoint which is not actually a graph input. So ignoring it here - data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params - and graph_input != output_label] - - data_shapes = [] - # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs. - for idx, input_name in enumerate(data_names): - data_shapes.append((input_name, inputs[idx].shape)) - - # create module, passing cpu context - ctx = context.cpu() - test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None) - test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None) - - # initializing parameters for calculating result of each individual node - if arg_params is None and aux_params is None: - test_mod.init_params() - else: - test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True) - - data_forward = [] - for idx, input_name in enumerate(data_names): - val = inputs[idx] - data_forward.append(nd.array(val)) - - test_mod.forward(io.DataBatch(data_forward)) - result = test_mod.get_outputs()[0].asnumpy() - - return result.shape - - @staticmethod def split_params(sym, params): """Helper function to split params dictionary into args and aux params @@ -177,15 +119,40 @@ def split_params(sym, params): aux_params.update({aux: nd.array(params[aux])}) return arg_params, aux_params - @staticmethod - def infer_output_shape(sym, params, in_shape, output_label): - """Infer output shape by doing a forward pass using dummy inputs """ - # create dummy input - inputs = [np.random.randn(*input_shape) for input_shape in in_shape] - arg, aux = MXNetGraph.split_params(sym, params) - return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label) + def get_outputs(sym, params, in_shape, in_label): + """ Infer output shapes and return dictionary of output name to shape + + :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on + :param dic of (str, nd.NDArray) params: + :param list of tuple(int, ...) in_shape: list of all input shapes + :param in_label: name of label typically used in loss that may be left in graph. This name is + removed from list of inputs required by symbol + :return: dictionary of output name to shape + :rtype: dict of (str, tuple(int, ...)) + """ + # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided + # by user. Also remove in_label, which is the name of the label symbol that may have been used + # as the label for loss during training. + inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)} + # Add params and their shape to list of inputs + inputs.update({n: v.shape for n, v in params.items()}) + # Provide input data as well as input params to infer_shape() + _, out_shapes, _ = sym.infer_shape(**inputs) + + out_names = list() + for name in sym.list_outputs(): + if name.endswith('_output'): + out_names.append(name[:-len('_output')]) + else: + logging.warning("output '%s' does not end with '_output'", name) + out_names.append(name) + assert len(out_shapes) == len(out_names) + # bind output shapes with output names + graph_outputs = {n: s for n, s in zip(out_names, out_shapes)} + + return graph_outputs @staticmethod def convert_weights_to_numpy(weights_dict): @@ -228,9 +195,6 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # Deriving the output_label name. output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label" - # Determine output shape - output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label) - weights = MXNetGraph.convert_weights_to_numpy(params) mx_graph = json.loads(sym.tojson())["nodes"] @@ -242,6 +206,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) onnx_processed_outputs = [] index_lookup = [] + # Determine output shape + graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label) + graph_input_idx = 0 for idx, node in enumerate(mx_graph): op = node["op"] @@ -294,24 +261,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # If converted node is NodeProto, add it in processed nodes list elif isinstance(converted_node, NodeProto): onnx_processed_nodes.append(converted_node) - if idx == (len(mx_graph) - 1): - # If converted node doesnt have name, use it from output field - if not converted_node.name: - onnx_processed_outputs.append( - make_tensor_value_info( - name=converted_node.output[0], - elem_type=in_type, - shape=output_shape - ) - ) - else: - onnx_processed_outputs.append( - make_tensor_value_info( - name=converted_node.name, - elem_type=in_type, - shape=output_shape - ) + node_name = converted_node.name if converted_node.name else converted_node.output[0] + if node_name in graph_outputs: + onnx_processed_outputs.append( + make_tensor_value_info( + name=node_name, + elem_type=in_type, + shape=graph_outputs[node_name] ) + ) if verbose: logging.info("Output node is: %s", converted_node.name) elif isinstance(converted_node, TensorProto): diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 9f91369d667e..bbff7833fe20 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -28,11 +28,14 @@ import unittest import logging import tarfile +import tempfile from collections import namedtuple import numpy as np import numpy.testing as npt from onnx import numpy_helper, helper from onnx import TensorProto +from mxnet import nd, sym +from mxnet.gluon import nn from mxnet.test_utils import download from mxnet.contrib import onnx as onnx_mxnet import mxnet as mx @@ -238,6 +241,79 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) + +def _assert_sym_equal(lhs, rhs): + assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical + assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of outputs must be identical + + +def _force_list(output): + if isinstance(output, nd.NDArray): + return [output] + return list(output) + + +def _optional_group(symbols, group=False): + if group: + return sym.Group(symbols) + else: + return symbols + + +def _check_onnx_export(net, group_outputs=False): + net.initialize() + data = nd.random.uniform(0, 1, (1, 1024)) + output = _force_list(net(data)) # initialize weights + net_sym = _optional_group(net(sym.Variable('data')), group_outputs) + net_params = {name:param._reduce() for name, param in net.collect_params().items()} + with tempfile.TemporaryDirectory() as tmpdirname: + onnx_file_path = os.path.join(tmpdirname, 'net.onnx') + export_path = onnx_mxnet.export_model( + sym=net_sym, + params=net_params, + input_shape=[data.shape], + onnx_file_path=onnx_file_path) + assert export_path == onnx_file_path + # Try importing the model to symbol + _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0]) + + # Try importing the model to gluon + imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None) + _assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs)) + + # Confirm network outputs are the same + imported_net_output = _force_list(imported_net(data)) + for out, imp_out in zip(output, imported_net_output): + mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy()) + + +@with_seed() +def test_onnx_export_single_output(): + net = nn.HybridSequential(prefix='single_output_net') + with net.name_scope(): + net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) + _check_onnx_export(net) + + +@with_seed() +def test_onnx_export_multi_output(): + class MultiOutputBlock(nn.HybridBlock): + def __init__(self): + super(MultiOutputBlock, self).__init__() + with self.name_scope(): + self.net = nn.HybridSequential() + for i in range(10): + self.net.add(nn.Dense(100 + i * 10, activation='relu')) + + def hybrid_forward(self, F, x): + out = tuple(block(x) for block in self.net._children.values()) + return out + + net = MultiOutputBlock() + assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10 + _check_onnx_export(net, group_outputs=True) + + if __name__ == '__main__': test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000)) test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))