Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix for params with no dims in onnx #13413

Merged
merged 7 commits into from
Jan 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=invalid-name,too-many-locals,no-self-use
""" Support import export formats."""
from __future__ import absolute_import as _abs
import numpy as np
from .... import symbol
from .... import ndarray as nd
from ....base import string_types
Expand Down Expand Up @@ -87,7 +88,7 @@ def from_onnx(self, graph):
params : dict
A dict of name: nd.array pairs, used as pretrained weights
"""
#get input, output shapes
# get input, output shapes
self.model_metadata = self.get_graph_metadata(graph)
# parse network inputs, aka parameters
for init_tensor in graph.initializer:
Expand Down Expand Up @@ -196,7 +197,11 @@ def _parse_array(self, tensor_proto):
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. "
+ "Instructions to install - https://github.com/onnx/onnx")
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
if len(tuple(tensor_proto.dims)) > 0:
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
else:
# If onnx's params are scalar values without dims mentioned.
np_array = np.array([to_array(tensor_proto)])
return nd.array(np_array)

def _parse_attr(self, attr_proto):
Expand Down
11 changes: 11 additions & 0 deletions tests/python-pytest/onnx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
'https://s3.amazonaws.com/download.onnx/models/opset_8/inception_v2.tar.gz'
}

test_model_path = "https://s3.amazonaws.com/onnx-mxnet/test_model.onnx"

def get_test_files(name):
"""Extract tar file and returns model path and input, output data"""
Expand Down Expand Up @@ -152,6 +153,16 @@ def get_model_results(modelpath):

logging.info(model_name + " conversion successful")

def test_nodims_import(self):
# Download test model without dims mentioned in params
test_model = download(test_model_path, dirname=CURR_PATH.__str__())
input_data = np.array([0.2, 0.5])
nd_data = mx.nd.array(input_data).expand_dims(0)
sym, arg_params, aux_params = onnx_mxnet.import_model(test_model)
model_metadata = onnx_mxnet.get_model_metadata(test_model)
input_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
output_data = forward_pass(sym, arg_params, aux_params, input_names, nd_data)
assert(output_data.shape == (1,1))

# test_case = ("model name", input shape, output shape)
test_cases = [
Expand Down