Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Onnx multi output (#13390)
Browse files Browse the repository at this point in the history
* Fix ONNX export to support multi-output graphs

* Add ONNX unit-test

* Added multi-output shape inference.

- Removed unnecessary forward_pass() call
- Modified infer_output_shape to return multiple shapes for multiple outputs as well as output names.

* Fixed pylint
  • Loading branch information
safrooze authored and ThomasDelteil committed Nov 26, 2018
1 parent 7b1e7a5 commit cb627bc
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 85 deletions.
128 changes: 43 additions & 85 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,8 @@
from __future__ import unicode_literals
import logging
import json
import numpy as np

from .... import context
from .... import ndarray as nd
from .... import io
from .... import module as mod


class MXNetGraph(object):
Expand Down Expand Up @@ -95,60 +91,6 @@ def convert_layer(node, **kwargs):
convert_func = MXNetGraph.registry_[op]
return convert_func(node, **kwargs)

@staticmethod
def forward_pass(inputs, sym, arg_params, aux_params, output_label):
"""Do a forward pass based on the sym and params to get the shape
of the output using dummy data
Parameters
----------
inputs : json string
sym : :class:`~mxnet.symbol.Symbol`
MXNet symbol object
arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
Returns
-------
shape : Shape
Output shape
"""
# if label is not provided, MXNet adds label "softmax_label" by default
# while running load_checkpoint which is not actually a graph input. So ignoring it here
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params
and graph_input != output_label]

data_shapes = []
# Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))

# create module, passing cpu context
ctx = context.cpu()
test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

# initializing parameters for calculating result of each individual node
if arg_params is None and aux_params is None:
test_mod.init_params()
else:
test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)

data_forward = []
for idx, input_name in enumerate(data_names):
val = inputs[idx]
data_forward.append(nd.array(val))

test_mod.forward(io.DataBatch(data_forward))
result = test_mod.get_outputs()[0].asnumpy()

return result.shape


@staticmethod
def split_params(sym, params):
"""Helper function to split params dictionary into args and aux params
Expand Down Expand Up @@ -177,15 +119,40 @@ def split_params(sym, params):
aux_params.update({aux: nd.array(params[aux])})
return arg_params, aux_params


@staticmethod
def infer_output_shape(sym, params, in_shape, output_label):
"""Infer output shape by doing a forward pass using dummy inputs """
# create dummy input
inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
arg, aux = MXNetGraph.split_params(sym, params)
return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
def get_outputs(sym, params, in_shape, in_label):
""" Infer output shapes and return dictionary of output name to shape
:param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
:param dic of (str, nd.NDArray) params:
:param list of tuple(int, ...) in_shape: list of all input shapes
:param in_label: name of label typically used in loss that may be left in graph. This name is
removed from list of inputs required by symbol
:return: dictionary of output name to shape
:rtype: dict of (str, tuple(int, ...))
"""
# remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
# by user. Also remove in_label, which is the name of the label symbol that may have been used
# as the label for loss during training.
inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
# Add params and their shape to list of inputs
inputs.update({n: v.shape for n, v in params.items()})
# Provide input data as well as input params to infer_shape()
_, out_shapes, _ = sym.infer_shape(**inputs)

out_names = list()
for name in sym.list_outputs():
if name.endswith('_output'):
out_names.append(name[:-len('_output')])
else:
logging.warning("output '%s' does not end with '_output'", name)
out_names.append(name)

assert len(out_shapes) == len(out_names)
# bind output shapes with output names
graph_outputs = {n: s for n, s in zip(out_names, out_shapes)}

return graph_outputs

@staticmethod
def convert_weights_to_numpy(weights_dict):
Expand Down Expand Up @@ -228,9 +195,6 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# Deriving the output_label name.
output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"

# Determine output shape
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)

weights = MXNetGraph.convert_weights_to_numpy(params)

mx_graph = json.loads(sym.tojson())["nodes"]
Expand All @@ -242,6 +206,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
onnx_processed_outputs = []
index_lookup = []

# Determine output shape
graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)

graph_input_idx = 0
for idx, node in enumerate(mx_graph):
op = node["op"]
Expand Down Expand Up @@ -294,24 +261,15 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
if idx == (len(mx_graph) - 1):
# If converted node doesnt have name, use it from output field
if not converted_node.name:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=output_shape
)
)
else:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=output_shape
)
node_name = converted_node.name if converted_node.name else converted_node.output[0]
if node_name in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
name=node_name,
elem_type=in_type,
shape=graph_outputs[node_name]
)
)
if verbose:
logging.info("Output node is: %s", converted_node.name)
elif isinstance(converted_node, TensorProto):
Expand Down
76 changes: 76 additions & 0 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
import unittest
import logging
import tarfile
import tempfile
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper, helper
from onnx import TensorProto
from mxnet import nd, sym
from mxnet.gluon import nn
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
import mxnet as mx
Expand Down Expand Up @@ -238,6 +241,79 @@ def test_square():

npt.assert_almost_equal(result, numpy_op)


def _assert_sym_equal(lhs, rhs):
assert lhs.list_inputs() == rhs.list_inputs() # input names must be identical
assert len(lhs.list_outputs()) == len(rhs.list_outputs()) # number of outputs must be identical


def _force_list(output):
if isinstance(output, nd.NDArray):
return [output]
return list(output)


def _optional_group(symbols, group=False):
if group:
return sym.Group(symbols)
else:
return symbols


def _check_onnx_export(net, group_outputs=False):
net.initialize()
data = nd.random.uniform(0, 1, (1, 1024))
output = _force_list(net(data)) # initialize weights
net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
net_params = {name:param._reduce() for name, param in net.collect_params().items()}
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
export_path = onnx_mxnet.export_model(
sym=net_sym,
params=net_params,
input_shape=[data.shape],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
_assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])

# Try importing the model to gluon
imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
_assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs))

# Confirm network outputs are the same
imported_net_output = _force_list(imported_net(data))
for out, imp_out in zip(output, imported_net_output):
mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy())


@with_seed()
def test_onnx_export_single_output():
net = nn.HybridSequential(prefix='single_output_net')
with net.name_scope():
net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
_check_onnx_export(net)


@with_seed()
def test_onnx_export_multi_output():
class MultiOutputBlock(nn.HybridBlock):
def __init__(self):
super(MultiOutputBlock, self).__init__()
with self.name_scope():
self.net = nn.HybridSequential()
for i in range(10):
self.net.add(nn.Dense(100 + i * 10, activation='relu'))

def hybrid_forward(self, F, x):
out = tuple(block(x) for block in self.net._children.values())
return out

net = MultiOutputBlock()
assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10
_check_onnx_export(net, group_outputs=True)


if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
Expand Down

0 comments on commit cb627bc

Please sign in to comment.