Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Refactor onnx tests for object classification, add object detection tests #19802

Merged
merged 12 commits into from
Feb 2, 2021
Merged
2 changes: 1 addition & 1 deletion ci/docker/install/ubuntu_onnx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ apt-get update || true
apt-get install -y libprotobuf-dev protobuf-compiler

echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..."
pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0 gluonnlp
pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp gluoncv
47 changes: 35 additions & 12 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,10 @@ def convert_slice_channel(node, **kwargs):
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

opset_version = kwargs['opset_version']
if opset_version < 11:
raise AttributeError('ONNX opset 11 or greater is required to export this operator')

num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
squeeze_axis = int(attrs.get("squeeze_axis", 0))
Expand All @@ -1733,15 +1737,12 @@ def convert_slice_channel(node, **kwargs):
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
in_shape = kwargs.get('in_shape')[0]
split = in_shape[axis] // num_outputs
node = onnx.helper.make_node(
"Split",
input_nodes,
[name+'_output'+str(i) for i in range(num_outputs)],
[name+str(i) for i in range(num_outputs)],
axis=axis,
split=[split for _ in range(num_outputs)],
name=name,
name=name
)
return [node]
else:
Expand Down Expand Up @@ -1973,7 +1974,15 @@ def convert_broadcast_equal(node, **kwargs):
"""Map MXNet's broadcast_equal operator attributes to onnx's Equal operator
and return the created node.
"""
return create_basic_op_node('Equal', node, kwargs)
from onnx.helper import make_node
name, input_nodes, _ = get_inputs(node, kwargs)
input_type = kwargs['in_type']

nodes = [
make_node("Equal", input_nodes, [name+"_equal"]),
make_node("Cast", [name+"_equal"], [name], name=name, to=int(input_type))
]
return nodes


@mx_op.register("broadcast_logical_and")
Expand Down Expand Up @@ -2682,10 +2691,15 @@ def convert_zeros_like(node, **kwargs):
"""Map MXNet's zeros_like operator attributes to onnx's ConstantOfShape operator.
"""
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)
name, input_nodes, attrs = get_inputs(node, kwargs)
dtype = attrs.get('dtype')
if dtype is not None:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
else:
data_type = kwargs['in_type']

# create tensor with shape of input
tensor_value = make_tensor(name+"_zero", kwargs['in_type'], [1], [0])
tensor_value = make_tensor(name+"_zero", data_type, [1], [0])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value)
Expand All @@ -2698,10 +2712,14 @@ def convert_ones_like(node, **kwargs):
"""Map MXNet's ones_like operator attributes to onnx's ConstantOfShape operator.
"""
from onnx.helper import make_node, make_tensor
name, input_nodes, _ = get_inputs(node, kwargs)

name, input_nodes, attrs = get_inputs(node, kwargs)
dtype = attrs.get('dtype')
if dtype is not None:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
else:
data_type = kwargs['in_type']
# create tensor with shape of input
tensor_value = make_tensor(name+"_one", kwargs['in_type'], [1], [1])
tensor_value = make_tensor(name+"_one", data_type, [1], [1])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value)
Expand Down Expand Up @@ -2839,6 +2857,11 @@ def convert_arange(node, **kwargs):
step = attrs.get('step', 1.)
dtype = attrs.get('dtype', 'float32')
repeat = int(attrs.get('repeat', 1))

if stop == 'None':
stop = start
start = 0

if repeat != 1:
raise NotImplementedError("arange operator with repeat != 1 not yet implemented.")

Expand Down Expand Up @@ -3252,7 +3275,7 @@ def convert_broadcast_mod(node, **kwargs):
make_node('Where', [name+'_mask', input_nodes[1], name+'_zero'], [name+'_adjustment']),
make_node('Add', [name+'_mod', name+'_adjustment'], [name+'_adjusted']),
make_node('Equal', [input_nodes[1], name+'_zero'], [name+'_mask_div_0']),
make_node('Where', [name+'_mask_div_0', name+'_zero', name+'_adjusted'], [name])
make_node('Where', [name+'_mask_div_0', name+'_zero', name+'_adjusted'], [name], name=name)
]

return nodes
Expand Down
172 changes: 104 additions & 68 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import gluoncv
import onnxruntime

from mxnet.test_utils import assert_almost_equal
Expand All @@ -27,99 +28,134 @@
import pytest
import shutil

# images that are tested and their accepted classes
test_images = [
['dog.jpg', [242,243]],
['apron.jpg', [411,578,638,639,689,775]],
['dolphin.jpg', [2,3,4,146,147,148,395]],
['hammerheadshark.jpg', [3,4]],
['lotus.jpg', [716,723,738,985]]
]

test_models = [

class GluonModel():
def __init__(self, model_name, input_shape, input_dtype, tmpdir):
self.model_name = model_name
self.input_shape = input_shape
self.input_dtype = input_dtype
self.modelpath = os.path.join(tmpdir, model_name)
self.ctx = mx.cpu(0)
self.get_model()
self.export()

def get_model(self):
self.model = mx.gluon.model_zoo.vision.get_model(self.model_name, pretrained=True, ctx=self.ctx, root=self.modelpath)
self.model.hybridize()

def export(self):
data = mx.nd.zeros(self.input_shape, dtype=self.input_dtype, ctx=self.ctx)
self.model.forward(data)
self.model.export(self.modelpath, 0)

def export_onnx(self):
onnx_file = self.modelpath + ".onnx"
mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params",
[self.input_shape], self.input_dtype, onnx_file)
return onnx_file

def predict(self, data):
return self.model(data)


def download_test_images(image_urls, tmpdir):
from urllib.parse import urlparse
paths = []
for url in image_urls:
filename = os.path.join(tmpdir, os.path.basename(urlparse(url).path))
mx.test_utils.download(url, fname=filename)
paths.append(filename)
return paths

@pytest.mark.parametrize('model', [
'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25',
'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25',
'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1', 'resnet50_v2',
'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2',
'squeezenet1.0', 'squeezenet1.1',
'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
]

@with_seed()
@pytest.mark.parametrize('model', test_models)
def test_cv_model_inference_onnxruntime(tmp_path, model):
def get_gluon_cv_model(model_name, tmp):
tmpfile = os.path.join(tmp, model_name)
ctx = mx.cpu(0)
net_fp32 = mx.gluon.model_zoo.vision.get_model(model_name, pretrained=True, ctx=ctx, root=tmp)
net_fp32.hybridize()
data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx)
net_fp32.forward(data)
net_fp32.export(tmpfile, 0)
sym_file = tmpfile + '-symbol.json'
params_file = tmpfile + '-0000.params'
return sym_file, params_file

def export_model_to_onnx(sym_file, params_file):
input_shape = (1,3,224,224)
onnx_file = os.path.join(os.path.dirname(sym_file), "model.onnx")
converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, [input_shape],
np.float32, onnx_file)
return onnx_file

])
def test_obj_class_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
image = mx.image.imread(imgfile).asnumpy()
image_data = np.array(image).transpose(2, 0, 1)
img_data = image_data.astype('float32')
mean_vec = np.array([0.485, 0.456, 0.406])
stddev_vec = np.array([0.229, 0.224, 0.225])
norm_img_data = np.zeros(img_data.shape).astype('float32')
img_data = mx.image.imread(imgfile).transpose([2, 0, 1]).astype('float32')
mean_vec = mx.nd.array([0.485, 0.456, 0.406])
stddev_vec = mx.nd.array([0.229, 0.224, 0.225])
norm_img_data = mx.nd.zeros(img_data.shape).astype('float32')
for i in range(img_data.shape[0]):
norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
return norm_img_data.reshape(1, 3, 224, 224).astype('float32')

def get_prediction(model, image):
pass
try:
tmp_path = str(tmp_path)
M = GluonModel(model, (1,3,224,224), 'float32', tmp_path)
onnx_file = M.export_onnx()

# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

def softmax(x):
x = x.reshape(-1)
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
test_image_urls = [
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dog.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/apron.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dolphin.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/hammerheadshark.jpg',
'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/lotus.jpg'
]

def load_imgnet_labels(tmpdir):
tmpfile = os.path.join(tmpdir, 'image_net_labels.json')
mx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/image_net_labels.json',
fname=tmpfile)
return np.array(json.load(open(tmpfile, 'r')))
for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(img)
mx_result = M.predict(img_data)
onnx_result = session.run([], {input_name: img_data.asnumpy()})[0]
assert_almost_equal(mx_result, onnx_result)

def download_test_images(tmpdir):
global test_images
for f,_ in test_images:
mx.test_utils.download('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/images/'+f+'?raw=true',
fname=os.path.join(tmpdir, f))
return test_images
finally:
shutil.rmtree(tmp_path)


tmp_path = str(tmp_path)
try:
#labels = load_imgnet_labels(tmp_path)
test_images = download_test_images(tmp_path)
sym_file, params_file = get_gluon_cv_model(model, tmp_path)
onnx_file = export_model_to_onnx(sym_file, params_file)

class GluonCVModel(GluonModel):
def __init__(self, *args, **kwargs):
super(GluonCVModel, self).__init__(*args, **kwargs)
def get_model(self):
self.model = gluoncv.model_zoo.get_model(self.model_name, pretrained=True, ctx=self.ctx)
self.model.hybridize()

@pytest.mark.parametrize('model', [
'center_net_resnet18_v1b_voc',
'center_net_resnet50_v1b_voc',
'center_net_resnet101_v1b_voc',
'center_net_resnet18_v1b_coco',
'center_net_resnet50_v1b_coco',
'center_net_resnet101_v1b_coco'
])
def test_obj_detection_model_inference_onnxruntime(tmp_path, model):
def normalize_image(imgfile):
x, _ = gluoncv.data.transforms.presets.center_net.load_test(imgfile, short=512)
return x

try:
tmp_path = str(tmp_path)
M = GluonCVModel(model, (1,3,512,683), 'float32', tmp_path)
onnx_file = M.export_onnx()
# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
input_name = session.get_inputs()[0].name

for img, accepted_ids in test_images:
img_data = normalize_image(os.path.join(tmp_path,img))
raw_result = session.run([], {input_name: img_data})
res = softmax(np.array(raw_result)).tolist()
class_idx = np.argmax(res)
assert(class_idx in accepted_ids)
test_image_urls = ['https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/dog.jpg']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we consider adding more test images?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we will soon in upcoming PRs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also we should have tests which avoid depending on extrenal sources

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In upcoming PRs for these onnxruntime tests, let's migrate to use either an s3 bucket for test data or check them into the repo. the first is ideal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also create a second repo just for resources. I would actually prefer that because this would allow everyone to fully replicate the setup even it the s3 bucket goes away


for img in download_test_images(test_image_urls, tmp_path):
img_data = normalize_image(os.path.join(tmp_path, img))
mx_class_ids, mx_scores, mx_boxes = M.predict(img_data)
onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()})
assert_almost_equal(mx_class_ids, onnx_class_ids)
assert_almost_equal(mx_scores, onnx_scores)
assert_almost_equal(mx_boxes, onnx_boxes)

finally:
shutil.rmtree(tmp_path)
Expand Down
29 changes: 24 additions & 5 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,13 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data):
op_export_test('arange_like', M, [x], tmp_path)


@pytest.mark.parametrize("stop", [2, 50, 5000])
@pytest.mark.parametrize("step", [0.25, 0.5, 1, 5])
@pytest.mark.parametrize("start", [0., 1.])
@pytest.mark.parametrize("params", [[0, 2, 1], [0, 50, 0.25], [-100, 100, 0.5], [5, None, 1], [-5, None, -1]])
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"])
def test_onnx_export_arange(tmp_path, dtype, start, stop, step):
def test_onnx_export_arange(tmp_path, dtype, params):
start, stop, step = params[0], params[1], params[2]
if "int" in dtype:
start = int(start)
stop = int(stop)
stop = int(stop) if stop != None else None
step = int(step)
if step == 0:
step = 1
Expand Down Expand Up @@ -311,6 +310,14 @@ def test_onnx_export_broadcast_add(tmp_path, dtype):
op_export_test('broadcast_add', M, [x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
def test_onnx_export_broadcast_equal(tmp_path, dtype):
M = def_model('broadcast_equal')
x = mx.nd.zeros((4,5,6), dtype=dtype)
y = mx.nd.ones((4,5,6), dtype=dtype)
op_export_test('broadcast_equal', M, [x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('axis', [0, 1, 2, -1])
def test_onnx_export_stack(tmp_path, dtype, axis):
Expand Down Expand Up @@ -613,6 +620,18 @@ def test_onnx_export_slice_like(tmp_path, dtype, axes):
op_export_test('slice_like_3', M, [x, y3], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
@pytest.mark.parametrize('axis', [None, 0, 2])
josephevans marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize('num_outputs', [2, 5])
def test_onnx_export_slice_channel(tmp_path, dtype, axis, num_outputs):
x = mx.nd.zeros((10,20,30,40), dtype=dtype)
if axis is None:
M = def_model('SliceChannel', num_outputs=num_outputs)
else:
M = def_model('SliceChannel', axis=axis, num_outputs=num_outputs)
op_export_test('SliceChannel', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('lhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
@pytest.mark.parametrize('rhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
Expand Down