Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add resiliency to onnx export code (#13426) (#13567)
Browse files Browse the repository at this point in the history
* Added resiliency to onnx export code

- With previous infer-shape implementation, if input shape was list instead of tuple or if extra non-existent parameters were provided, the code would still work. The fixes in this commit make sure that behavior is restored to prevent any compatibility issues with existing export code.

* Fixed name of net in unittest

* Fix pylint
  • Loading branch information
safrooze authored and nswamy committed Dec 7, 2018
1 parent d63fdfd commit 2d08816
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def get_outputs(sym, params, in_shape, in_label):
# remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
# by user. Also remove in_label, which is the name of the label symbol that may have been used
# as the label for loss during training.
inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
inputs = {n: tuple(s) for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label],
in_shape)}
# Add params and their shape to list of inputs
inputs.update({n: v.shape for n, v in params.items()})
inputs.update({n: v.shape for n, v in params.items() if n in sym.list_inputs()})
# Provide input data as well as input params to infer_shape()
_, out_shapes, _ = sym.infer_shape(**inputs)

Expand Down
21 changes: 19 additions & 2 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,19 @@ def _optional_group(symbols, group=False):
return symbols


def _check_onnx_export(net, group_outputs=False):
def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params={}):
net.initialize()
data = nd.random.uniform(0, 1, (1, 1024))
output = _force_list(net(data)) # initialize weights
net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
net_params = {name:param._reduce() for name, param in net.collect_params().items()}
net_params.update(extra_params)
with tempfile.TemporaryDirectory() as tmpdirname:
onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
export_path = onnx_mxnet.export_model(
sym=net_sym,
params=net_params,
input_shape=[data.shape],
input_shape=[shape_type(data.shape)],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
Expand Down Expand Up @@ -340,6 +341,22 @@ def hybrid_forward(self, F, x):
_check_onnx_export(net, group_outputs=True)


@with_seed()
def test_onnx_export_list_shape():
net = nn.HybridSequential(prefix='list_shape_net')
with net.name_scope():
net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
_check_onnx_export(net, shape_type=list)


@with_seed()
def test_onnx_export_extra_params():
net = nn.HybridSequential(prefix='extra_params_net')
with net.name_scope():
net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
_check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})


if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
Expand Down

0 comments on commit 2d08816

Please sign in to comment.