From 35f129f2dfbf0a095834e8becef1fcfe133e8e97 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 27 Jan 2021 19:42:22 +0000 Subject: [PATCH 01/12] Fix onnx export of arange when stop is none. Refactor unit test to cover case. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 5 +++++ tests/python-pytest/onnx/test_operators.py | 9 ++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9a6f290cdf80..822760736762 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2839,6 +2839,11 @@ def convert_arange(node, **kwargs): step = attrs.get('step', 1.) dtype = attrs.get('dtype', 'float32') repeat = int(attrs.get('repeat', 1)) + + if stop == 'None': + stop = start + start = 0 + if repeat != 1: raise NotImplementedError("arange operator with repeat != 1 not yet implemented.") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 74e850be995e..81e733896fc6 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -118,14 +118,13 @@ def test_onnx_export_arange_like(tmp_path, dtype, axis, start, step, test_data): op_export_test('arange_like', M, [x], tmp_path) -@pytest.mark.parametrize("stop", [2, 50, 5000]) -@pytest.mark.parametrize("step", [0.25, 0.5, 1, 5]) -@pytest.mark.parametrize("start", [0., 1.]) +@pytest.mark.parametrize("params", [[0, 2, 1], [0, 50, 0.25], [-100, 100, 0.5], [5, None, 1], [-5, None, -1]]) @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64"]) -def test_onnx_export_arange(tmp_path, dtype, start, stop, step): +def test_onnx_export_arange(tmp_path, dtype, params): + start, stop, step = params[0], params[1], params[2] if "int" in dtype: start = int(start) - stop = int(stop) + stop = int(stop) if stop != None else None step = int(step) if step == 0: step = 1 From 9a95b577de3a7f92235c9013386cec9670071779 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Wed, 27 Jan 2021 21:23:54 +0000 Subject: [PATCH 02/12] Use correct outname names for SliceChannel. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 822760736762..e231e3bd07f5 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1738,7 +1738,7 @@ def convert_slice_channel(node, **kwargs): node = onnx.helper.make_node( "Split", input_nodes, - [name+'_output'+str(i) for i in range(num_outputs)], + [name+str(i) for i in range(num_outputs)], axis=axis, split=[split for _ in range(num_outputs)], name=name, From c6db38b631c4c843204287856225c8aa6dbf2ff9 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Thu, 28 Jan 2021 23:19:38 +0000 Subject: [PATCH 03/12] Cast output of broadcast_equal to input type, like MXNet does. --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index e231e3bd07f5..bc4ba5748dd0 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1973,7 +1973,15 @@ def convert_broadcast_equal(node, **kwargs): """Map MXNet's broadcast_equal operator attributes to onnx's Equal operator and return the created node. """ - return create_basic_op_node('Equal', node, kwargs) + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + input_type = kwargs['in_type'] + + nodes = [ + make_node("Equal", input_nodes, [name+"_equal"]), + make_node("Cast", [name+"_equal"], [name], name=name, to=int(input_type)) + ] + return nodes @mx_op.register("broadcast_logical_and") From c3d7584e55d10cd06341a4cbccf56680c633f4e4 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Thu, 28 Jan 2021 23:22:54 +0000 Subject: [PATCH 04/12] Add unit test for broadcast_equal onnx export. --- tests/python-pytest/onnx/test_operators.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 81e733896fc6..a410a7d0b716 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -310,6 +310,14 @@ def test_onnx_export_broadcast_add(tmp_path, dtype): op_export_test('broadcast_add', M, [x, y], tmp_path) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) +def test_onnx_export_broadcast_equal(tmp_path, dtype): + M = def_model('broadcast_equal') + x = mx.nd.zeros((4,5,6), dtype=dtype) + y = mx.nd.ones((4,5,6), dtype=dtype) + op_export_test('broadcast_equal', M, [x, y], tmp_path) + + @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('axis', [0, 1, 2, -1]) def test_onnx_export_stack(tmp_path, dtype, axis): From b613890ed48fd1c710027b48446d9f77ea98027a Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 00:44:05 +0000 Subject: [PATCH 05/12] Require onnx opset >= 11 for SliceChannel which doesn't require the splits attribute for onnx Split operator. Add operator unit test for SliceChannel. --- .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 9 +++++---- tests/python-pytest/onnx/test_operators.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index bc4ba5748dd0..388fe8d2c2d8 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1719,6 +1719,10 @@ def convert_slice_channel(node, **kwargs): """ name, input_nodes, attrs = get_inputs(node, kwargs) + opset_version = kwargs['opset_version'] + if opset_version < 11: + raise AttributeError('ONNX opset 11 or greater is required to export this operator') + num_outputs = int(attrs.get("num_outputs")) axis = int(attrs.get("axis", 1)) squeeze_axis = int(attrs.get("squeeze_axis", 0)) @@ -1733,15 +1737,12 @@ def convert_slice_channel(node, **kwargs): ) return [node] elif squeeze_axis == 0 and num_outputs > 1: - in_shape = kwargs.get('in_shape')[0] - split = in_shape[axis] // num_outputs node = onnx.helper.make_node( "Split", input_nodes, [name+str(i) for i in range(num_outputs)], axis=axis, - split=[split for _ in range(num_outputs)], - name=name, + name=name ) return [node] else: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index a410a7d0b716..878ab32fbea8 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -620,6 +620,18 @@ def test_onnx_export_slice_like(tmp_path, dtype, axes): op_export_test('slice_like_3', M, [x, y3], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64']) +@pytest.mark.parametrize('axis', [None, 0, 2]) +@pytest.mark.parametrize('num_outputs', [2, 5]) +def test_onnx_export_slice_channel(tmp_path, dtype, axis, num_outputs): + x = mx.nd.zeros((10,20,30,40), dtype=dtype) + if axis is None: + M = def_model('SliceChannel', num_outputs=num_outputs) + else: + M = def_model('SliceChannel', axis=axis, num_outputs=num_outputs) + op_export_test('SliceChannel', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) @pytest.mark.parametrize('lhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]]) @pytest.mark.parametrize('rhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]]) From 499fe2c1bd55ac306009dfad90b3d4ce57e64aae Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 01:55:04 +0000 Subject: [PATCH 06/12] Support passing dtype to zeros_like and ones_like. --- .../contrib/onnx/mx2onnx/_op_translations.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 388fe8d2c2d8..f1f4f675117b 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2691,10 +2691,15 @@ def convert_zeros_like(node, **kwargs): """Map MXNet's zeros_like operator attributes to onnx's ConstantOfShape operator. """ from onnx.helper import make_node, make_tensor - name, input_nodes, _ = get_inputs(node, kwargs) + name, input_nodes, attrs = get_inputs(node, kwargs) + dtype = attrs.get('dtype') + if dtype != None: + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + else: + data_type = kwargs['in_type'] # create tensor with shape of input - tensor_value = make_tensor(name+"_zero", kwargs['in_type'], [1], [0]) + tensor_value = make_tensor(name+"_zero", data_type, [1], [0]) nodes = [ make_node("Shape", [input_nodes[0]], [name+"_shape"]), make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value) @@ -2707,10 +2712,14 @@ def convert_ones_like(node, **kwargs): """Map MXNet's ones_like operator attributes to onnx's ConstantOfShape operator. """ from onnx.helper import make_node, make_tensor - name, input_nodes, _ = get_inputs(node, kwargs) - + name, input_nodes, attrs = get_inputs(node, kwargs) + dtype = attrs.get('dtype') + if dtype != None: + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + else: + data_type = kwargs['in_type'] # create tensor with shape of input - tensor_value = make_tensor(name+"_one", kwargs['in_type'], [1], [1]) + tensor_value = make_tensor(name+"_one", data_type, [1], [1]) nodes = [ make_node("Shape", [input_nodes[0]], [name+"_shape"]), make_node("ConstantOfShape", [name+"_shape"], [name], name=name, value=tensor_value) @@ -3266,7 +3275,7 @@ def convert_broadcast_mod(node, **kwargs): make_node('Where', [name+'_mask', input_nodes[1], name+'_zero'], [name+'_adjustment']), make_node('Add', [name+'_mod', name+'_adjustment'], [name+'_adjusted']), make_node('Equal', [input_nodes[1], name+'_zero'], [name+'_mask_div_0']), - make_node('Where', [name+'_mask_div_0', name+'_zero', name+'_adjusted'], [name]) + make_node('Where', [name+'_mask_div_0', name+'_zero', name+'_adjusted'], [name], name=name) ] return nodes From b2a476eb2a2b4d223a31213131f5cb95af57afb0 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 04:07:57 +0000 Subject: [PATCH 07/12] Refactor object classification tests to make more extendable and add object detection tests with currently supported models. --- tests/python-pytest/onnx/test_onnxruntime.py | 172 +++++++++++-------- 1 file changed, 104 insertions(+), 68 deletions(-) diff --git a/tests/python-pytest/onnx/test_onnxruntime.py b/tests/python-pytest/onnx/test_onnxruntime.py index 2cb2b772861b..0efb94df7ca6 100644 --- a/tests/python-pytest/onnx/test_onnxruntime.py +++ b/tests/python-pytest/onnx/test_onnxruntime.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +import gluoncv import onnxruntime from mxnet.test_utils import assert_almost_equal @@ -27,16 +28,47 @@ import pytest import shutil -# images that are tested and their accepted classes -test_images = [ - ['dog.jpg', [242,243]], - ['apron.jpg', [411,578,638,639,689,775]], - ['dolphin.jpg', [2,3,4,146,147,148,395]], - ['hammerheadshark.jpg', [3,4]], - ['lotus.jpg', [716,723,738,985]] -] -test_models = [ + +class GluonModel(): + def __init__(self, model_name, input_shape, input_dtype, tmpdir): + self.model_name = model_name + self.input_shape = input_shape + self.input_dtype = input_dtype + self.modelpath = os.path.join(tmpdir, model_name) + self.ctx = mx.cpu(0) + self.get_model() + self.export() + + def get_model(self): + self.model = mx.gluon.model_zoo.vision.get_model(self.model_name, pretrained=True, ctx=self.ctx, root=self.modelpath) + self.model.hybridize() + + def export(self): + data = mx.nd.zeros(self.input_shape, dtype=self.input_dtype, ctx=self.ctx) + self.model.forward(data) + self.model.export(self.modelpath, 0) + + def export_onnx(self): + onnx_file = self.modelpath + ".onnx" + mx.contrib.onnx.export_model(self.modelpath + "-symbol.json", self.modelpath + "-0000.params", + [self.input_shape], self.input_dtype, onnx_file) + return onnx_file + + def predict(self, data): + return self.model(data) + + +def download_test_images(image_urls, tmpdir): + from urllib.parse import urlparse + paths = [] + for url in image_urls: + filename = os.path.join(tmpdir, os.path.basename(urlparse(url).path)) + mx.test_utils.download(url, fname=filename) + paths.append(filename) + return paths + +@pytest.mark.parametrize('model', [ 'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25', 'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25', @@ -44,82 +76,86 @@ 'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2', 'squeezenet1.0', 'squeezenet1.1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn' -] - -@with_seed() -@pytest.mark.parametrize('model', test_models) -def test_cv_model_inference_onnxruntime(tmp_path, model): - def get_gluon_cv_model(model_name, tmp): - tmpfile = os.path.join(tmp, model_name) - ctx = mx.cpu(0) - net_fp32 = mx.gluon.model_zoo.vision.get_model(model_name, pretrained=True, ctx=ctx, root=tmp) - net_fp32.hybridize() - data = mx.nd.zeros((1,3,224,224), dtype='float32', ctx=ctx) - net_fp32.forward(data) - net_fp32.export(tmpfile, 0) - sym_file = tmpfile + '-symbol.json' - params_file = tmpfile + '-0000.params' - return sym_file, params_file - - def export_model_to_onnx(sym_file, params_file): - input_shape = (1,3,224,224) - onnx_file = os.path.join(os.path.dirname(sym_file), "model.onnx") - converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, [input_shape], - np.float32, onnx_file) - return onnx_file - +]) +def test_obj_class_model_inference_onnxruntime(tmp_path, model): def normalize_image(imgfile): - image = mx.image.imread(imgfile).asnumpy() - image_data = np.array(image).transpose(2, 0, 1) - img_data = image_data.astype('float32') - mean_vec = np.array([0.485, 0.456, 0.406]) - stddev_vec = np.array([0.229, 0.224, 0.225]) - norm_img_data = np.zeros(img_data.shape).astype('float32') + img_data = mx.image.imread(imgfile).transpose([2, 0, 1]).astype('float32') + mean_vec = mx.nd.array([0.485, 0.456, 0.406]) + stddev_vec = mx.nd.array([0.229, 0.224, 0.225]) + norm_img_data = mx.nd.zeros(img_data.shape).astype('float32') for i in range(img_data.shape[0]): norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i] return norm_img_data.reshape(1, 3, 224, 224).astype('float32') - def get_prediction(model, image): - pass + try: + tmp_path = str(tmp_path) + M = GluonModel(model, (1,3,224,224), 'float32', tmp_path) + onnx_file = M.export_onnx() + + # create onnxruntime session using the generated onnx file + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + session = onnxruntime.InferenceSession(onnx_file, ses_opt) + input_name = session.get_inputs()[0].name - def softmax(x): - x = x.reshape(-1) - e_x = np.exp(x - np.max(x)) - return e_x / e_x.sum(axis=0) + test_image_urls = [ + 'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dog.jpg', + 'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/apron.jpg', + 'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/dolphin.jpg', + 'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/hammerheadshark.jpg', + 'https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/images/lotus.jpg' + ] - def load_imgnet_labels(tmpdir): - tmpfile = os.path.join(tmpdir, 'image_net_labels.json') - mx.test_utils.download('https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/image_net_labels.json', - fname=tmpfile) - return np.array(json.load(open(tmpfile, 'r'))) + for img in download_test_images(test_image_urls, tmp_path): + img_data = normalize_image(img) + mx_result = M.predict(img_data) + onnx_result = session.run([], {input_name: img_data.asnumpy()})[0] + assert_almost_equal(mx_result, onnx_result) - def download_test_images(tmpdir): - global test_images - for f,_ in test_images: - mx.test_utils.download('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/onnx/images/'+f+'?raw=true', - fname=os.path.join(tmpdir, f)) - return test_images + finally: + shutil.rmtree(tmp_path) - tmp_path = str(tmp_path) - try: - #labels = load_imgnet_labels(tmp_path) - test_images = download_test_images(tmp_path) - sym_file, params_file = get_gluon_cv_model(model, tmp_path) - onnx_file = export_model_to_onnx(sym_file, params_file) +class GluonCVModel(GluonModel): + def __init__(self, *args, **kwargs): + super(GluonCVModel, self).__init__(*args, **kwargs) + def get_model(self): + self.model = gluoncv.model_zoo.get_model(self.model_name, pretrained=True, ctx=self.ctx) + self.model.hybridize() + +@pytest.mark.parametrize('model', [ + 'center_net_resnet18_v1b_voc', + 'center_net_resnet50_v1b_voc', + 'center_net_resnet101_v1b_voc', + 'center_net_resnet18_v1b_coco', + 'center_net_resnet50_v1b_coco', + 'center_net_resnet101_v1b_coco' +]) +def test_obj_detection_model_inference_onnxruntime(tmp_path, model): + def normalize_image(imgfile): + x, _ = gluoncv.data.transforms.presets.center_net.load_test(imgfile, short=512) + return x + + try: + tmp_path = str(tmp_path) + M = GluonCVModel(model, (1,3,512,683), 'float32', tmp_path) + onnx_file = M.export_onnx() # create onnxruntime session using the generated onnx file ses_opt = onnxruntime.SessionOptions() ses_opt.log_severity_level = 3 session = onnxruntime.InferenceSession(onnx_file, ses_opt) input_name = session.get_inputs()[0].name - for img, accepted_ids in test_images: - img_data = normalize_image(os.path.join(tmp_path,img)) - raw_result = session.run([], {input_name: img_data}) - res = softmax(np.array(raw_result)).tolist() - class_idx = np.argmax(res) - assert(class_idx in accepted_ids) + test_image_urls = ['https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/dog.jpg'] + + for img in download_test_images(test_image_urls, tmp_path): + img_data = normalize_image(os.path.join(tmp_path, img)) + mx_class_ids, mx_scores, mx_boxes = M.predict(img_data) + onnx_scores, onnx_class_ids, onnx_boxes = session.run([], {input_name: img_data.asnumpy()}) + assert_almost_equal(mx_class_ids, onnx_class_ids) + assert_almost_equal(mx_scores, onnx_scores) + assert_almost_equal(mx_boxes, onnx_boxes) finally: shutil.rmtree(tmp_path) From 86b039f5e7f5321fe6d5202076a128ad39b4b51d Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 05:53:53 +0000 Subject: [PATCH 08/12] Fix lint --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index f1f4f675117b..471df26686fc 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1975,7 +1975,7 @@ def convert_broadcast_equal(node, **kwargs): and return the created node. """ from onnx.helper import make_node - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, _ = get_inputs(node, kwargs) input_type = kwargs['in_type'] nodes = [ @@ -2693,7 +2693,7 @@ def convert_zeros_like(node, **kwargs): from onnx.helper import make_node, make_tensor name, input_nodes, attrs = get_inputs(node, kwargs) dtype = attrs.get('dtype') - if dtype != None: + if dtype is not None: data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] else: data_type = kwargs['in_type'] @@ -2714,7 +2714,7 @@ def convert_ones_like(node, **kwargs): from onnx.helper import make_node, make_tensor name, input_nodes, attrs = get_inputs(node, kwargs) dtype = attrs.get('dtype') - if dtype != None: + if dtype is not None: data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] else: data_type = kwargs['in_type'] From 808819cc409b307eebd77e492da2e861e698ef98 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 20:40:55 +0000 Subject: [PATCH 09/12] Update onnxruntime in CI to 1.6.0. --- ci/docker/install/ubuntu_onnx.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh index a6bca56518ff..3407caac215f 100755 --- a/ci/docker/install/ubuntu_onnx.sh +++ b/ci/docker/install/ubuntu_onnx.sh @@ -31,4 +31,4 @@ apt-get update || true apt-get install -y libprotobuf-dev protobuf-compiler echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..." -pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.4.0 gluonnlp +pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp From ebc8e6de6a25bbf920ffe7d62afaa631edc6723e Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Fri, 29 Jan 2021 22:08:10 +0000 Subject: [PATCH 10/12] Install gluoncv in CI environment for onnxruntime tests. --- ci/docker/install/ubuntu_onnx.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh index 3407caac215f..6a5104a8910a 100755 --- a/ci/docker/install/ubuntu_onnx.sh +++ b/ci/docker/install/ubuntu_onnx.sh @@ -31,4 +31,4 @@ apt-get update || true apt-get install -y libprotobuf-dev protobuf-compiler echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX, tabulate and onnxruntime..." -pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp +pip3 install pytest pytest-cov protobuf==3.5.2 onnx==1.7.0 Pillow==5.0.0 tabulate==0.7.5 onnxruntime==1.6.0 gluonnlp gluoncv From fcfe006f0371436acaa33922e8628984c321bbee Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Sat, 30 Jan 2021 01:38:26 +0000 Subject: [PATCH 11/12] Add test case for negative axis in SliceChannel. --- tests/python-pytest/onnx/test_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 878ab32fbea8..34ad8ffa8f71 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -621,7 +621,7 @@ def test_onnx_export_slice_like(tmp_path, dtype, axes): @pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64']) -@pytest.mark.parametrize('axis', [None, 0, 2]) +@pytest.mark.parametrize('axis', [None, 0, 2, -1]) @pytest.mark.parametrize('num_outputs', [2, 5]) def test_onnx_export_slice_channel(tmp_path, dtype, axis, num_outputs): x = mx.nd.zeros((10,20,30,40), dtype=dtype) From 22c9826d49af72d162b453c16364d2549d8ac806 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Sat, 30 Jan 2021 01:46:01 +0000 Subject: [PATCH 12/12] Add export functions and unit tests for broadcast_minimum and lesser_scalar operators. --- .../contrib/onnx/mx2onnx/_op_translations.py | 37 +++++++++++++++++++ tests/python-pytest/onnx/test_operators.py | 23 ++++++++++++ 2 files changed, 60 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 471df26686fc..990b98f868ec 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1548,6 +1548,13 @@ def convert_broadcast_mul(node, **kwargs): """ return create_basic_op_node('Mul', node, kwargs) +@mx_op.register("broadcast_minimum") +def convert_broadcast_min(node, **kwargs): + """Map MXNet's broadcast_minimum operator attributes to onnx's Min operator + and return the created node. + """ + return create_basic_op_node('Min', node, kwargs) + @mx_op.register("elemwise_div") def convert_elemwise_div(node, **kwargs): """Map MXNet's elemwise_div operator attributes to onnx's Div operator @@ -3116,6 +3123,36 @@ def convert_greater_scalar(node, **kwargs): return nodes +@mx_op.register("_lesser_scalar") +def convert_lesser_scalar(node, **kwargs): + """Map MXNet's lesser_scalar operator attributes to onnx's Less + operator and return the created node. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, attrs = get_inputs(node, kwargs) + + scalar = float(attrs.get('scalar')) + input_type = kwargs['in_type'] + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + + if str(dtype).startswith('int'): + scalar = int(scalar) + else: + if dtype == 'float16': + # when using float16, we must convert it to np.uint16 view first + # pylint: disable=too-many-function-args + scalar = np.float16(scalar).view(np.uint16) + + tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar]) + nodes = [ + make_node("Shape", [input_nodes[0]], [name+"_shape"]), + make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value), + make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]), + make_node("Cast", [name+"_lt"], [name], to=input_type, name=name) + ] + return nodes + + @mx_op.register("where") def convert_where(node, **kwargs): """Map MXNet's where operator attributes to onnx's Where diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 34ad8ffa8f71..f627b3306fa3 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -318,6 +318,17 @@ def test_onnx_export_broadcast_equal(tmp_path, dtype): op_export_test('broadcast_equal', M, [x, y], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +def test_onnx_export_broadcast_minimum(tmp_path, dtype): + M = def_model('broadcast_minimum') + if 'int' in dtype: + x = mx.nd.random.randint(0, 1000, (4, 5, 6), dtype=dtype) + y = mx.nd.random.randint(0, 1000, (4, 5, 6), dtype=dtype) + else: + x = mx.nd.random.uniform(0, 1000, (4, 5, 6), dtype=dtype) + y = mx.nd.random.uniform(0, 1000, (4, 5, 6), dtype=dtype) + op_export_test('broadcast_minimum', M, [x, y], tmp_path) + @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('axis', [0, 1, 2, -1]) def test_onnx_export_stack(tmp_path, dtype, axis): @@ -467,6 +478,18 @@ def test_onnx_export_greater_scalar(tmp_path, dtype, scalar): op_export_test('_internal._greater_scalar', M, [x], tmp_path) +@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) +@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.]) +def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar): + if 'int' in dtype: + scalar = int(scalar) + x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4)) + else: + x = mx.random.uniform(0, 9999, (5,10), dtype=dtype) + M = def_model('_internal._lesser_scalar', scalar=scalar) + op_export_test('_internal._lesser_scalar', M, [x], tmp_path) + + @pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) @pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)]) def test_onnx_export_where(tmp_path, dtype, shape):