diff --git a/Jenkinsfile b/Jenkinsfile index e43b6f0d74d1..73e73f27a710 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -578,7 +578,18 @@ try { } stage('Integration Test') { - parallel 'Python GPU': { + parallel 'Onnx CPU': { + node('mxnetlinux-cpu') { + ws('workspace/it-onnx-cpu') { + init_git() + unpack_lib('cpu') + timeout(time: max_time, unit: 'MINUTES') { + sh "ci/build.py --build --platform ubuntu_cpu /work/runtime_functions.sh integrationtest_ubuntu_cpu_onnx" + } + } + } + }, + 'Python GPU': { node('mxnetlinux-gpu') { ws('workspace/it-python-gpu') { init_git() diff --git a/ci/docker/Dockerfile.build.ubuntu_cpu b/ci/docker/Dockerfile.build.ubuntu_cpu index f86c2f2e724b..d652a0d89c58 100755 --- a/ci/docker/Dockerfile.build.ubuntu_cpu +++ b/ci/docker/Dockerfile.build.ubuntu_cpu @@ -42,6 +42,8 @@ COPY install/ubuntu_mklml.sh /work/ RUN /work/ubuntu_mklml.sh COPY install/ubuntu_caffe.sh /work/ RUN /work/ubuntu_caffe.sh +COPY install/ubuntu_onnx.sh /work/ +RUN /work/ubuntu_onnx.sh COPY install/ubuntu_docs.sh /work/ RUN /work/ubuntu_docs.sh COPY install/ubuntu_adduser.sh /work/ diff --git a/ci/docker/install/ubuntu_onnx.sh b/ci/docker/install/ubuntu_onnx.sh new file mode 100755 index 000000000000..72613cd57882 --- /dev/null +++ b/ci/docker/install/ubuntu_onnx.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+######################################################################
+# This script installs ONNX for Python 2 and Python 3, together with
+# the protobuf system packages and the Python test tooling, on Ubuntu.
+# Tested on Ubuntu 16.04 distro.
+######################################################################
+
+set -e
+set -x
+
+echo "Installing libprotobuf-dev and protobuf-compiler ..."
+apt-get install -y libprotobuf-dev protobuf-compiler
+
+echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..."
+pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5
+pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 14a256dd6ea0..39809f281276 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -412,6 +412,14 @@ unittest_centos7_gpu() {
     python3.6 -m "nose" --with-timer --verbose tests/python/gpu
 }
 
+integrationtest_ubuntu_cpu_onnx() {
+    set -ex
+    export PYTHONPATH=./python/  # resolve `import mxnet` against the in-tree python/ directory
+    python example/onnx/super_resolution.py
+    pytest tests/python-pytest/onnx/onnx_backend_test.py
+    pytest tests/python-pytest/onnx/onnx_test.py
+}
+
 integrationtest_ubuntu_gpu_python() {
     set -ex
     export PYTHONPATH=./python/
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
new file mode 100644
index 000000000000..1392b77715cb
--- /dev/null
+++ b/example/onnx/super_resolution.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more
contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example: import the ONNX super_resolution model into MXNet and run it on a test image"""
+from __future__ import absolute_import as _abs
+from __future__ import print_function
+from collections import namedtuple
+import logging
+import numpy as np
+from PIL import Image
+import mxnet as mx
+from mxnet.test_utils import download
+import mxnet.contrib.onnx as onnx_mxnet
+
+# set up the root logger so the INFO progress messages below are shown
+logging.basicConfig()
+LOGGER = logging.getLogger()
+LOGGER.setLevel(logging.INFO)
+
+def import_onnx():
+    """Download super_resolution.onnx and return it as an (mxnet symbol, params) pair"""
+    model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
+    download(model_url, 'super_resolution.onnx')
+
+    LOGGER.info("Converting onnx format to mxnet's symbol and params...")
+    sym, params = onnx_mxnet.import_model('super_resolution.onnx')
+    LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...")
+    return sym, params
+
+def get_test_image():
+    """Download the test image; return its Y plane (two leading singleton axes), Cb and Cr"""
+    # Load test image and resize it to the square input size used below
+    input_image_dim = 224
+    img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
+    download(img_url, 'super_res_input.jpg')
+    img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
+    img_ycbcr = img.convert("YCbCr")
+    img_y, img_cb, img_cr = img_ycbcr.split()
+    input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
+    return input_image, img_cb, img_cr
+
+def perform_inference(sym, params, input_img, img_cb, img_cr):
+    """Run the imported model on input_img, rebuild an RGB image from the output and save it"""
+    # create an inference-only module whose single data input is named 'input_0'
+    mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
+    mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
+    mod.set_params(arg_params=params, aux_params=None)
+
+    # run inference
+    batch = namedtuple('Batch', ['data'])
+    mod.forward(batch([mx.nd.array(input_img)]))
+
+    # clip the first output plane to [0, 255] and wrap it as an 8-bit grayscale (Y) image
+    img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0].
+                                         asnumpy().clip(0, 255)), mode='L')
+
+    result_img = Image.merge(
+        "YCbCr", [img_out_y,
+                  img_cb.resize(img_out_y.size, Image.BICUBIC),
+                  img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB")
+    output_img_dim = 672
+    assert result_img.size == (output_img_dim, output_img_dim)
+    LOGGER.info("Super Resolution example success.")
+    result_img.save("super_res_output.jpg")
+    return result_img
+
+if __name__ == '__main__':
+    MX_SYM, MX_PARAM = import_onnx()
+    INPUT_IMG, IMG_CB, IMG_CR = get_test_image()
+    perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index 36ee21305bfd..63cd8ce26649 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -28,5 +28,5 @@
 from . import tensorboard
 
 from . import text
-
+from . import onnx
 from . import io
diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py
new file mode 100644
index 000000000000..eff91206298f
--- /dev/null
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Module for importing and exporting ONNX models.""" + +from ._import.import_model import import_model diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/_import/__init__.py new file mode 100644 index 000000000000..002cfa925832 --- /dev/null +++ b/python/mxnet/contrib/onnx/_import/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""ONNX Import module""" +from . import import_model +from . 
import import_onnx diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/_import/import_helper.py new file mode 100644 index 000000000000..80541ec35774 --- /dev/null +++ b/python/mxnet/contrib/onnx/_import/import_helper.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# coding: utf-8 +# pylint: disable=invalid-name +"""Operator attributes conversion""" +from .op_translations import identity, random_uniform, random_normal +from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n +from .op_translations import tanh +from .op_translations import ceil, floor +from .op_translations import concat +from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected +from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm +from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm +from .op_translations import dropout, local_response_norm, conv, deconv +from .op_translations import reshape, cast, split, _slice, transpose, squeeze +from .op_translations import reciprocal, squareroot, power, exponent, _log +from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum +from .op_translations import reduce_prod, avg_pooling, max_pooling +from .op_translations import argmax, argmin, maximum, minimum + +# convert_map defines maps of ONNX operator names to converter functor(callable) +# defined in the op_translations module. 
+_convert_map = {
+    # Generator Functions
+    'Constant' : identity,
+    'RandomUniform' : random_uniform,
+    'RandomNormal' : random_normal,
+    'RandomUniformLike' : random_uniform,
+    'RandomNormalLike' : random_normal,
+    # Arithmetic Operators
+    'Add' : add,
+    'Sub' : subtract,
+    'Mul' : multiply,
+    'Div' : divide,
+    'Abs' : absolute,
+    'Neg' : negative,
+    'Sum' : add_n, # elementwise sum
+    # Hyperbolic functions
+    'Tanh' : tanh,
+    # Rounding
+    'Ceil' : ceil,
+    'Floor' : floor,
+    # Joining and splitting
+    'Concat' : concat,
+    # Basic neural network functions
+    'Sigmoid' : sigmoid,
+    'Relu' : relu,
+    'Pad' : pad,
+    'MatMul' : matrix_multiplication, # translated to linalg_gemm2
+    'Conv' : conv,
+    'ConvTranspose' : deconv,
+    'BatchNormalization': batch_norm,
+    'SpatialBN' : batch_norm,
+    'LeakyRelu' : leaky_relu,
+    'Elu' : _elu,
+    'PRelu' : _prelu,
+    'Softmax' : softmax,
+    'FC' : fully_connected,
+    'GlobalAveragePool' : global_avgpooling,
+    'GlobalMaxPool' : global_maxpooling,
+    'Gemm' : linalg_gemm,
+    'LRN' : local_response_norm,
+    'Dropout' : dropout,
+    # Changing shape and type.
+    'Reshape' : reshape,
+    'Cast' : cast,
+    'Split' : split,
+    'Slice' : _slice,
+    'Transpose' : transpose,
+    'Squeeze' : squeeze,
+    # Powers
+    'Reciprocal' : reciprocal,
+    'Sqrt' : squareroot,
+    'Pow' : power,
+    'Exp' : exponent,
+    'Log' : _log,
+    # Reduce Functions
+    'ReduceMax' : reduce_max,
+    'ReduceMean' : reduce_mean,
+    'ReduceMin' : reduce_min,
+    'ReduceSum' : reduce_sum,
+    'ReduceProd' : reduce_prod,
+    'AveragePool' : avg_pooling,
+    'MaxPool' : max_pooling,
+    # Sorting and Searching
+    'ArgMax' : argmax,
+    'ArgMin' : argmin,
+    'Max' : maximum, # elementwise maximum
+    'Min' : minimum # elementwise minimum
+}
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py
new file mode 100644
index 000000000000..1df429b4690f
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_model.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Public entry point for importing an ONNX model file into MXNet"""
+# pylint: disable=no-member
+
+from .import_onnx import GraphProto
+
+def import_model(model_file):
+    """Imports the ONNX model file passed as a parameter into MXNet symbol and parameters.
+
+    Parameters
+    ----------
+    model_file : str
+        ONNX model file name
+
+    Returns
+    -------
+    A (sym, params) pair:
+
+    sym : mxnet.symbol
+        Mxnet symbol
+    params : dict of str to mx.ndarray
+        Dict of converted parameters stored in mxnet.ndarray format
+    """
+    graph = GraphProto()
+
+    # onnx is imported lazily, inside the call, and a missing package is reported explicitly
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed")
+    model_proto = onnx.load(model_file)
+    sym, params = graph.from_onnx(model_proto.graph)
+    return sym, params
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py
new file mode 100644
index 000000000000..56181c777be4
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=invalid-name,too-many-locals,no-self-use
+""" Support import export formats."""
+from __future__ import absolute_import as _abs
+from .... import symbol
+from ....
import ndarray as nd
+from ....base import string_types
+from .import_helper import _convert_map as convert_map
+
+class GraphProto(object): # pylint: disable=too-few-public-methods
+    """Builds an mxnet symbol and its parameter dict from an ONNX pb2.GraphProto.
+    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
+    """
+    def __init__(self):
+        self._nodes = {}      # tensor name -> mxnet symbol producing it
+        self._params = {}     # parameter name -> nd.array holding its initializer
+        self._renames = {}    # original ONNX input name -> generated input_N / param_N name
+        self._num_input = 0
+        self._num_param = 0
+
+    def _convert_operator(self, node_name, op_name, attrs, inputs):
+        """Convert from onnx operator to mxnet operator.
+        The converter must specify conversions explicitly for incompatible name, and
+        apply handlers to operator attributes.
+
+        Parameters
+        ----------
+        :param node_name : str
+            name of the node to be translated (None lets mxnet pick a name).
+        :param op_name : str
+            ONNX operator type, i.e. a key of convert_map such as Conv or Gemm
+        :param attrs : dict
+            Dict of operator attributes
+        :param inputs: list
+            list of inputs to the operator
+        Returns
+        -------
+        :return mxnet_sym
+            Converted mxnet symbol; NotImplementedError / RuntimeError if it cannot be mapped
+        """
+        if op_name in convert_map:
+            op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)
+        else:
+            raise NotImplementedError("Operator {} not implemented.".format(op_name))
+        if isinstance(op_name, string_types):
+            new_op = getattr(symbol, op_name, None)
+            if new_op is None:  # check before calling, otherwise a TypeError masks the real cause
+                raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
+            if node_name is None:
+                mxnet_sym = new_op(*inputs, **new_attrs)
+            else:
+                mxnet_sym = new_op(name=node_name, *inputs, **new_attrs)
+            return mxnet_sym
+        return op_name  # the translator already built the symbol itself
+
+    def from_onnx(self, graph):
+        """Construct symbol from onnx graph.
+        The inputs from onnx graph is vague, only providing "1", "2"...
+        For convenience, we rename the `real` input names to "input_0",
+        "input_1"... And renaming parameters to "param_0", "param_1"...
+
+        Parameters
+        ----------
+        graph : onnx protobuf object
+            The loaded onnx graph
+
+        Returns
+        -------
+        sym :symbol.Symbol
+            The returned mxnet symbol
+        params : dict
+            A dict of name: nd.array pairs, used as pretrained weights
+        """
+        # parse network inputs, aka parameters
+        for init_tensor in graph.initializer:
+            if not init_tensor.name.strip():
+                raise ValueError("Tensor's name is required.")
+            self._params[init_tensor.name] = self._parse_array(init_tensor)
+
+        # converting GraphProto message
+        for i in graph.input:
+            if i.name in self._params:
+                # i is a param instead of input
+                name_param = 'param_{}'.format(self._num_param)
+                self._num_param += 1
+                self._params[name_param] = self._params.pop(i.name)
+                self._nodes[name_param] = symbol.Variable(name=name_param,
+                                                          shape=self._params[name_param].shape)
+                self._renames[i.name] = name_param
+            else:
+                name_input = 'input_{}'.format(self._num_input)
+                self._num_input += 1
+                self._nodes[name_input] = symbol.Variable(name=name_input)
+                self._renames[i.name] = name_input
+
+        # constructing nodes, nodes are stored as directed acyclic graph
+        # converting NodeProto message
+        for node in graph.node:
+            op_name = node.op_type
+            node_name = node.name.strip()
+            node_name = node_name if node_name else None
+            onnx_attr = self._parse_attr(node.attribute)
+            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)
+
+            assert len(node.output) == len(mxnet_sym.list_outputs()), (
+                "Output dimension mismatch between the onnx operator and the mxnet symbol " +
+                "{} vs {} for the operator - {}.".format(
+                    len(node.output), len(mxnet_sym.list_outputs()), op_name))
+            for out_idx, out_name in enumerate(node.output):
+                self._nodes[out_name] = mxnet_sym[out_idx]
+        # now return the outputs
+        out = [self._nodes[i.name] for i in graph.output]
+        if len(out) > 1:
+            out = symbol.Group(out)
+        else:
+            out = out[0]
+        return out, self._params
+
+    def _parse_array(self, tensor_proto):
+        """Grab data in TensorProto and convert it to an mxnet nd.array of shape tensor_proto.dims."""
+        try:
+            from onnx.numpy_helper import to_array
+        except ImportError as e:
+            raise ImportError("Unable to import onnx which is required {}".format(e))
+        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
+        return nd.array(np_array)
+
+    def _parse_attr(self, attr_proto):
+        """Convert a list of AttributeProto to a dict, with names as keys."""
+        attrs = {}
+        for a in attr_proto:
+            for f in ['f', 'i', 's']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+                    # String attributes may arrive as bytes (Python 3); decode them to text
+                    if isinstance(attrs[a.name], bytes):
+                        attrs[a.name] = attrs[a.name].decode(encoding='utf-8')
+            for f in ['floats', 'ints', 'strings']:
+                if list(getattr(a, f)):
+                    assert a.name not in attrs, "Only one type of attr is allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
+            for f in ['t', 'g']:
+                if a.HasField(f):
+                    attrs[a.name] = getattr(a, f)
+            for f in ['tensors', 'graphs']:
+                if list(getattr(a, f)):
+                    raise NotImplementedError("Field {} is not supported in mxnet.".format(f))
+            if a.name not in attrs:
+                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
+        return attrs
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
new file mode 100644
index 000000000000..a67c18199eb8
--- /dev/null
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -0,0 +1,411 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +""" Module for translating ONNX operators into Mxnet operatoes""" +# pylint: disable=unused-argument,protected-access +from . import translation_utils +from .... import symbol + +# Method definitions for the callable objects mapped in the import_helper module + +def identity(attrs, inputs, cls): + """Returns the identity function of the the input.""" + return 'identity', attrs, inputs + +def random_uniform(attrs, inputs, cls): + """Draw random samples from a uniform distribtuion.""" + new_attr = translation_utils._remove_attributes(attrs, ['seed']) + return 'random_uniform', new_attr, inputs + +def random_normal(attrs, inputs, cls): + """Draw random samples from a Gaussian distribution.""" + new_attr = translation_utils._remove_attributes(attrs, ['seed']) + new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 'loc'}) + return 'random_uniform', new_attr, inputs + +# Arithmetic Operations +def add(attrs, inputs, cls): + """Adding two tensors""" + new_attr = {} + if 'broadcast' in attrs and attrs['broadcast'] == 1: + op_value = translation_utils._fix_bias_shape('broadcast_add', inputs, cls) + return op_value, new_attr, inputs + return 'elemwise_add', new_attr, inputs + +def subtract(attrs, inputs, cls): + """Subtracting two tensors""" + new_attr = {} + if 'broadcast' in attrs and attrs['broadcast'] == 1: + return 'broadcast_sub', new_attr, inputs + return 'elemwise_sub', new_attr, inputs + + +def multiply(attrs, inputs, cls): + """Multiply two tensors""" + new_attr = {} + if 'broadcast' in attrs and attrs['broadcast'] == 1: 
+ op_value = translation_utils._fix_bias_shape('broadcast_mul', inputs, cls) + return op_value, new_attr, inputs + return 'elemwise_mul', new_attr, inputs + +def divide(attrs, inputs, cls): + """Divide two tensors""" + new_attr = {} + if 'broadcast' in attrs and attrs['broadcast'] == 1: + return 'broadcast_div', new_attr, inputs + return 'elemwise_div', new_attr, inputs + +def absolute(attrs, inputs, cls): + """Returns element-wise absolute value of the input.""" + return 'abs', attrs, inputs + +def negative(attrs, inputs, cls): + """Negation of every element in a tensor""" + return 'negative', attrs, inputs + +def add_n(attrs, inputs, cls): + """Elementwise sum of arrays""" + return 'add_n', attrs, inputs + +# Sorting and Searching +def argmax(attrs, inputs, cls): + """Returns indices of the maximum values along an axis""" + return 'argmax', attrs, inputs + + +def argmin(attrs, inputs, cls): + """Returns indices of the minimum values along an axis.""" + return 'argmin', attrs, inputs + +def maximum(attrs, inputs, cls): + """ + Elementwise maximum of arrays. + MXNet maximum compares only two symbols at a time. + ONNX can send more than two to compare. + Breaking into multiple mxnet ops to compare two symbols at a time + """ + if len(inputs) > 1: + mxnet_op = symbol.maximum(inputs[0], inputs[1]) + for op_input in inputs[2:]: + mxnet_op = symbol.maximum(mxnet_op, op_input) + else: + mxnet_op = inputs[0] + return mxnet_op, attrs, inputs + +def minimum(attrs, inputs, cls): + """Elementwise minimum of arrays.""" + # MXNet minimum compares only two symbols at a time. + # ONNX can send more than two to compare. 
+ # Breaking into multiple mxnet ops to compare two symbols at a time + if len(inputs) > 1: + mxnet_op = symbol.minimum(inputs[0], inputs[1]) + for op_input in inputs[2:]: + mxnet_op = symbol.minimum(mxnet_op, op_input) + else: + mxnet_op = inputs[0] + return mxnet_op, attrs, inputs + +#Hyperbolic functions +def tanh(attrs, inputs, cls): + """Returns the hyperbolic tangent of the input array.""" + return 'tanh', attrs, inputs + +# Rounding +def ceil(attrs, inputs, cls): + """ Calculate ceil value for input """ + return 'ceil', attrs, inputs + +def floor(attrs, inputs, cls): + """ Calculate floor value for input """ + return 'floor', attrs, inputs + +# Joining and spliting +def concat(attrs, inputs, cls): + """ Joins input arrays along a given axis. """ + new_attrs = translation_utils._fix_attribute_names(attrs, {'axis': 'dim'}) + return 'concat', new_attrs, inputs + + +# Basic neural network functions +def sigmoid(attrs, inputs, cls): + """Computes elementwise sigmoid of the input array""" + return 'sigmoid', attrs, inputs + +def relu(attrs, inputs, cls): + """Computes rectified linear function.""" + return 'relu', attrs, inputs + +def pad(attrs, inputs, cls): + """ Add padding to input tensor""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'pads' : 'pad_width', + 'value' : 'constant_value' + }) + new_attrs['pad_width'] = translation_utils._pad_sequence_fix(new_attrs.get('pad_width')) + return 'pad', new_attrs, inputs + +def matrix_multiplication(attrs, inputs, cls): + """Performs general matrix multiplication""" + return 'linalg_gemm2', attrs, inputs + +def batch_norm(attrs, inputs, cls): + """Batch normalization.""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'epsilon' : 'eps'}) + new_attrs = translation_utils._remove_attributes(new_attrs, + ['spatial', 'is_test', 'consumed_inputs']) + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'cudnn_off': 1}) + return 'BatchNorm', new_attrs, inputs + + +def leaky_relu(attrs, 
inputs, cls): + """Leaky Relu function""" + if 'alpha' in attrs: + new_attrs = translation_utils._fix_attribute_names(attrs, {'alpha' : 'slope'}) + else: + new_attrs = translation_utils._add_extra_attributes(attrs, {'slope': 0.01}) + return 'LeakyReLU', new_attrs, inputs + +def _elu(attrs, inputs, cls): + """Elu function""" + if 'alpha' in attrs: + new_attrs = translation_utils._fix_attribute_names(attrs, {'alpha' : 'slope'}) + else: + new_attrs = translation_utils._add_extra_attributes(attrs, {'slope': 1.0}) + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'act_type': 'elu'}) + return 'LeakyReLU', new_attrs, inputs + +def _prelu(attrs, inputs, cls): + """PRelu function""" + new_attrs = translation_utils._add_extra_attributes(attrs, {'act_type': 'prelu'}) + return 'LeakyReLU', new_attrs, inputs + +def softmax(attrs, inputs, cls): + """Softmax function.""" + if 'axis' not in attrs: + attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1}) + return 'softmax', attrs, inputs + +def conv(attrs, inputs, cls): + """Compute N-D convolution on (N+2)-D input.""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape' : 'kernel', + 'strides' : 'stride', + 'pads': 'pad', + 'dilations': 'dilate', + 'group': 'num_group'}) + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'num_group' : 1}) + new_attrs = translation_utils._fix_bias('Convolution', new_attrs, len(inputs)) + + new_attrs = translation_utils._fix_channels('Convolution', new_attrs, inputs, cls) + + return 'Convolution', new_attrs, inputs + + +def deconv(attrs, inputs, cls): + """Compute N-D convolution on (N+2)-D input.""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape' : 'kernel', + 'strides' : 'stride', + 'pads': 'pad', + 'dilations': 'dilate', + 'group': 'num_group'}) + new_attrs = translation_utils._add_extra_attributes(new_attrs, {'num_group' : 1}) + new_attrs = translation_utils._fix_bias('Deconvolution', new_attrs, 
len(inputs)) + + new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs, inputs, cls) + + return 'Convolution', new_attrs, inputs + + +def fully_connected(attrs, inputs, cls): + """Applies a linear transformation: Y=XWT+b.""" + new_attrs = translation_utils._remove_attributes(attrs, ['axis']) + + new_attrs = translation_utils._fix_bias('FullyConnected', new_attrs, len(inputs)) + + new_attrs = translation_utils._fix_channels('FullyConnected', new_attrs, inputs, cls) + + return 'FullyConnected', new_attrs, inputs + + +def global_maxpooling(attrs, inputs, cls): + """Performs max pooling on the input.""" + new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True, + 'kernel': (1, 1), + 'pool_type': 'max'}) + return 'pooling', new_attrs, inputs + + +def global_avgpooling(attrs, inputs, cls): + """Performs avg pooling on the input.""" + new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True, + 'kernel': (1, 1), + 'pool_type': 'avg'}) + return 'pooling', new_attrs, inputs + + +def linalg_gemm(attrs, inputs, cls): + """Performs general matrix multiplication and accumulation""" + new_attrs = translation_utils._fix_attribute_names(attrs, {'transA': 'transpose_a', + 'transB': 'transpose_b'}) + new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast']) + return translation_utils._fix_gemm('FullyConnected', inputs, new_attrs, cls) + +def local_response_norm(op_name, attrs, inputs): + """Local Response Normalization.""" + new_attrs = translation_utils._fix_attribute_names(attrs, + {'bias': 'knorm', + 'size' : 'nsize'}) + return 'LRN', new_attrs, inputs + +def dropout(op_name, attrs, inputs): + """Dropout Regularization.""" + new_attrs = translation_utils._fix_attribute_names(attrs, + {'ratio': 'p'}) + new_attrs = translation_utils._remove_attributes(new_attrs, ['is_test']) + return 'Dropout', new_attrs, inputs + +# Changing shape and type. 
# Every translator below takes (attrs, inputs, cls) and returns a 3-tuple
# (mxnet_op, mxnet_attrs, inputs). `mxnet_op` is either an operator name or an
# already-built symbol when several MXNet operators are needed.

def reshape(attrs, inputs, cls):
    """Reshape the given array by the shape attribute."""
    return 'reshape', attrs, inputs


def cast(attrs, inputs, cls):
    """Cast input to a given dtype ('to' -> 'dtype')."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'to': 'dtype'})
    return 'cast', new_attrs, inputs


def split(attrs, inputs, cls):
    """Split an array along an axis into sub-arrays ('split' -> 'num_outputs')."""
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'split': 'num_outputs'})
    return 'split', new_attrs, inputs


def _slice(attrs, inputs, cls):
    """Return a slice of the input tensor along multiple axes.

    ONNX slices several axes in one node; MXNet's ``slice_axis`` handles one
    axis at a time, so one ``slice_axis`` is chained per requested axis.
    Each axis is sliced exactly once: slicing the first axis a second time
    would apply ``begin[0]``/``end[0]`` to an already-shortened axis.

    Parameters
    ----------
    attrs : dict
        ONNX attributes 'starts', 'ends' and optionally 'axes'. When 'axes'
        is missing, axes 0..len(starts)-1 are used.
    inputs : list
        ``inputs[0]`` is the tensor to slice.

    Returns
    -------
    tuple
        (sliced_symbol, new_attrs, inputs)
    """
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'axes': 'axis',
                                                        'ends': 'end',
                                                        'starts': 'begin'})
    begin = new_attrs.get('begin')
    end = new_attrs.get('end')
    axes = new_attrs.get('axis', tuple(range(len(begin))))
    slice_op = inputs[0]
    for axis, start, stop in zip(axes, begin, end):
        slice_op = symbol.slice_axis(slice_op, axis=axis, begin=start, end=stop)
    return slice_op, new_attrs, inputs


def transpose(attrs, inputs, cls):
    """Transpose the input array ('perm' -> 'axes')."""
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'perm': 'axes'})
    return 'transpose', new_attrs, inputs


def squeeze(attrs, inputs, cls):
    """Remove single-dimensional entries from the shape of a tensor.

    No squeeze operator is used here; each listed axis is removed with
    ``split(..., num_outputs=1, squeeze_axis=1)``. Every removal shifts the
    axes to its right down by one, so the k-th removed axis (in ascending
    order) sits at ``axis - k`` by the time it is processed.

    NOTE(review): assumes non-negative axis indices -- confirm for callers
    that may pass negative axes.

    Returns
    -------
    tuple
        (squeezed_symbol, new_attrs, inputs)
    """
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'axes': 'axis'})
    axes = sorted(new_attrs.get('axis'))
    mxnet_op = inputs[0]
    for removed, axis in enumerate(axes):
        mxnet_op = symbol.split(mxnet_op, axis=axis - removed,
                                num_outputs=1, squeeze_axis=1)
    return mxnet_op, new_attrs, inputs


# Powers
def reciprocal(attrs, inputs, cls):
    """Returns the reciprocal of the argument, element-wise."""
    return 'reciprocal', attrs, inputs


def squareroot(attrs, inputs, cls):
    """Returns element-wise square-root value of the input."""
    return 'sqrt', attrs, inputs


def power(attrs, inputs, cls):
    """Returns element-wise result of base raised to powers from the exponent.

    Maps to 'broadcast_power' when the ONNX 'broadcast' attribute equals 1,
    otherwise to 'pow'.
    """
    new_attrs = translation_utils._fix_attribute_names(attrs, {'exponent': 'exp'})
    if 'broadcast' in attrs and attrs['broadcast'] == 1:
        new_attrs = translation_utils._remove_attributes(new_attrs, ['broadcast'])
        return 'broadcast_power', new_attrs, inputs
    return 'pow', new_attrs, inputs


def exponent(attrs, inputs, cls):
    """Elementwise exponent of input array."""
    return 'exp', attrs, inputs


def _log(attrs, inputs, cls):
    """Elementwise log of input array."""
    return 'log', attrs, inputs


# Reduce Functions
def reduce_max(attrs, inputs, cls):
    """Reduce the array along the given axes by maximum value."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes': 'axis'})
    return 'max', new_attrs, inputs


def reduce_mean(attrs, inputs, cls):
    """Reduce the array along the given axes by mean value."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes': 'axis'})
    return 'mean', new_attrs, inputs


def reduce_min(attrs, inputs, cls):
    """Reduce the array along the given axes by minimum value."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes': 'axis'})
    return 'min', new_attrs, inputs


def reduce_sum(attrs, inputs, cls):
    """Reduce the array along the given axes by summation."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes': 'axis'})
    return 'sum', new_attrs, inputs


def reduce_prod(attrs, inputs, cls):
    """Reduce the array along the given axes by product."""
    new_attrs = translation_utils._fix_attribute_names(attrs, {'axes': 'axis'})
    return 'prod', new_attrs, inputs


def avg_pooling(attrs, inputs, cls):
    """Average pooling.

    The pooling symbol is built by ``translation_utils._fix_pooling``, which
    inserts an explicit pad operator in front of the pooling operator.
    """
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'kernel_shape': 'kernel',
                                                        'strides': 'stride',
                                                        'pads': 'pad',
                                                       })
    new_attrs = translation_utils._add_extra_attributes(new_attrs,
                                                        {'pool_type': 'avg',
                                                         'pooling_convention': 'valid'
                                                        })
    new_op = translation_utils._fix_pooling('avg', inputs, new_attrs)

    return new_op, new_attrs, inputs


def max_pooling(attrs, inputs, cls):
    """Max pooling.

    The returned attributes record 'pool_type': 'max' so that they agree with
    the pooling symbol built by ``translation_utils._fix_pooling('max', ...)``.
    """
    new_attrs = translation_utils._fix_attribute_names(attrs,
                                                       {'kernel_shape': 'kernel',
                                                        'strides': 'stride',
                                                        'pads': 'pad',
                                                       })
    new_attrs = translation_utils._add_extra_attributes(new_attrs,
                                                        {'pool_type': 'max',
                                                         'pooling_convention': 'valid'
                                                        })
    new_op = translation_utils._fix_pooling('max', inputs, new_attrs)

    return new_op, new_attrs, inputs
def _fix_attribute_names(attrs, change_map):
    """Rename attributes according to `change_map`.

    Parameters
    ----------
    attrs : dict
        Operator attributes.
    change_map : dict
        Maps an onnx attribute name to the mxnet attribute name.

    Returns
    -------
    dict
        A new dict with renamed keys; keys absent from `change_map` are
        copied unchanged. `attrs` itself is not modified.
    """
    return {change_map.get(name, name): value for name, value in attrs.items()}


def _remove_attributes(attrs, remove_list):
    """Return a copy of `attrs` without the attributes named in `remove_list`.

    Parameters
    ----------
    attrs : dict
        Operator attributes.
    remove_list : list
        Attribute names to drop.

    Returns
    -------
    dict
        A new dict; `attrs` itself is not modified.
    """
    return {name: value for name, value in attrs.items() if name not in remove_list}


def _add_extra_attributes(attrs, extra_attr_map):
    """Return `attrs` extended with defaults from `extra_attr_map`.

    Parameters
    ----------
    attrs : dict
        Current operator attributes.
    extra_attr_map : dict
        Additional attributes; an entry is added only when `attrs` does not
        already contain that key.

    Returns
    -------
    dict
        A new dict. The caller's `attrs` is left untouched, matching
        `_fix_attribute_names` and `_remove_attributes`.
    """
    new_attrs = dict(attrs)
    for name, value in extra_attr_map.items():
        new_attrs.setdefault(name, value)
    return new_attrs


def _pad_sequence_fix(attr, kernel_dim=None):
    """Reorder onnx's pads sequence to match mxnet's pad_width layout.

    mxnet: (x1_begin, x1_end, ... , xn_begin, xn_end)
    onnx:  (x1_begin, x2_begin, ... , x1_end, x2_end, ...)

    Parameters
    ----------
    attr : sequence
        onnx pads. An odd-length sequence cannot be paired up and yields ().
    kernel_dim : int, optional
        Number of kernel axes; the result is zero-padded to 2 * kernel_dim
        values so every axis has a (begin, end) pair.

    Returns
    -------
    tuple
        Pads in mxnet order.
    """
    attr = tuple(attr)  # accept lists as well as tuples
    new_attr = ()
    if len(attr) % 2 == 0:
        half = len(attr) // 2
        for index in range(half):
            # begin of axis `index` followed by its end
            new_attr = new_attr + attr[index::half]
        # Making sure pad values are in the attr for all axes.
        if kernel_dim is not None:
            while len(new_attr) < kernel_dim*2:
                new_attr = new_attr + (0, 0)

    return new_attr


def _fix_pooling(pool_type, inputs, new_attr):
    """Build pad + Pooling symbols to emulate onnx's asymmetrical padding.

    Parameters
    ----------
    pool_type : str
        Passed to ``symbol.Pooling`` as `pool_type`.
    inputs : list
        ``inputs[0]`` is the tensor to pool.
    new_attr : dict
        Must contain 'kernel'; 'stride' and 'pad' are optional. A missing
        'pad' means no padding rather than an error.

    Returns
    -------
    The Pooling symbol applied to the padded input.
    """
    stride = new_attr.get('stride')
    kernel = new_attr.get('kernel')
    padding = new_attr.get('pad')
    if padding is None:
        padding = ()
    # The first two axes get zero padding, hence the leading four zeros.
    pad_width = (0, 0, 0, 0) + _pad_sequence_fix(padding, len(kernel))
    new_pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
    new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type,
                                    stride=stride, kernel=kernel)
    return new_pooling_op


def _fix_bias(op_name, attrs, num_inputs):
    """Set 'no_bias' from the number of inputs, since onnx has no such attribute.

    Three inputs mean a bias is present, two mean it is absent. `attrs` is
    updated in place and returned.

    Raises
    ------
    ValueError
        If `num_inputs` is neither 2 nor 3.
    """
    if num_inputs == 3:
        attrs['no_bias'] = False
    elif num_inputs == 2:
        attrs['no_bias'] = True
    else:
        raise ValueError("Unexpected number of inputs for: {}".format(op_name))
    return attrs


def _fix_bias_shape(op_name, inputs, cls):
    """Reshape the bias term (``inputs[1]``) to shape (1, -1, 1, 1) and apply it.

    Returns the broadcast_add / broadcast_mul symbol for those two operator
    names; for any other `op_name` the name itself is returned unchanged.
    """
    if len(cls._params) > 0:
        assert len(list(inputs)) == 2

    op_sym = symbol.reshape(inputs[1], shape=(1, -1, 1, 1))
    if op_name == 'broadcast_add':
        op_sym = symbol.broadcast_add(op_sym, inputs[0])
    elif op_name == 'broadcast_mul':
        op_sym = symbol.broadcast_mul(op_sym, inputs[0])
    else:
        op_sym = op_name
    return op_sym


def _fix_channels(op_name, attrs, inputs, cls):
    """Derive 'num_hidden' / 'num_filter' from the weight shape.

    onnx does not provide these attributes, so the shape of the weight
    (``inputs[1]``, looked up by name in ``cls._params``) is used. `attrs` is
    updated in place and returned.

    Raises
    ------
    ValueError
        If the weight is not present in ``cls._params``.
    """
    weight_name = inputs[1].name
    if weight_name not in cls._params:
        raise ValueError("Unable to get channels/units attr from onnx graph.")

    wshape = cls._params[weight_name].shape
    assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)

    if op_name == 'FullyConnected':
        attrs['num_hidden'] = wshape[0]
    elif op_name == 'Convolution':
        # Weight shape for Conv and FC: (M x C x kH x kW) : M is number of
        # feature maps/hidden and C is number of channels
        attrs['num_filter'] = wshape[0]
    elif op_name == 'Deconvolution':
        # Weight shape for DeConv : (C x M x kH x kW) : M is number of
        # feature maps/filters and C is number of channels
        attrs['num_filter'] = wshape[1]
    return attrs


def _fix_gemm(op_name, inputs, old_attr, cls):
    """Use the `op_name` operator (FullyConnected) in place of linalg_gemm.

    FullyConnected multiplies by the transposed weight, so B is transposed
    here exactly when the onnx node does *not* request a transposed B.

    Parameters
    ----------
    op_name : str
        Name of the symbol operator to return.
    inputs : list
        The three Gemm inputs (A, B, C). The list is not modified.
    old_attr : dict
        Gemm attributes. The transpose flags are accepted under the onnx
        names ('transA', 'transB') and under the renamed forms
        ('transpose_a', 'transpose_b') that callers may already have applied.
    cls : object
        Importer object; ``cls._params`` supplies the shape of C.

    Returns
    -------
    tuple
        (op_symbol, {'num_hidden': ...}, [alpha*A, B, beta*C])
    """
    op_sym = getattr(symbol, op_name, None)
    alpha = float(old_attr.get('alpha', 1.0))
    beta = float(old_attr.get('beta', 1.0))
    trans_a = int(old_attr.get('transA', old_attr.get('transpose_a', 0)))
    trans_b = int(old_attr.get('transB', old_attr.get('transpose_b', 0)))
    mat_a, mat_b, bias = inputs[0], inputs[1], inputs[2]
    if trans_a:
        mat_a = symbol.transpose(mat_a, axes=(1, 0))
    if not trans_b:
        mat_b = symbol.transpose(mat_b, axes=(1, 0))
    new_inputs = [alpha*mat_a, mat_b, beta*bias]
    new_attr = {'num_hidden': cls._params[bias.name].shape[0]}
    return op_sym, new_attr, new_inputs
# Using these functions for onnx test infrastructure.
# Implemented by following onnx docs guide:
# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
# MXNetBackend class will take an ONNX model with inputs, perform a computation,
# and then return the output.

class MXNetBackend(Backend):
    """ONNX test backend that evaluates nodes and models with MXNet."""

    @staticmethod
    def make_graph(node, inputs):
        """Wrap a single ONNX node in a GraphProto named "test".

        Every node input and output is declared as a FLOAT tensor of shape
        [1]. An input named 'W' is additionally registered as an initializer
        holding the matching entry of `inputs`.
        """
        initializer = []
        tensor_input_info = []
        for position, input_name in enumerate(node.input):
            tensor_input_info.append(
                helper.make_tensor_value_info(str(input_name), TensorProto.FLOAT, [1]))

            # Weight params are recognised purely by the name 'W'.
            if input_name == 'W':
                weight = inputs[position]
                initializer.append(helper.make_tensor(
                    name=input_name,
                    data_type=TensorProto.FLOAT,
                    dims=weight.shape,
                    vals=weight.flatten()))

        tensor_output_info = [
            helper.make_tensor_value_info(str(output_name), TensorProto.FLOAT, [1])
            for output_name in node.output]

        return helper.make_graph(
            [node],
            "test",
            tensor_input_info,
            tensor_output_info,
            initializer=initializer)

    @classmethod
    def run_node(cls, node, inputs, device='CPU'):
        """Run one ONNX node through an MXNet module and return its output.

        Parameters
        ----------
        node : onnx node object
            loaded onnx node (individual layer)
        inputs : numpy array
            input to run a node on
        device : 'CPU'
            device to run a node on

        Returns
        -------
        params : numpy array
            result obtained after running the operator
        """
        sym, _ = GraphProto().from_onnx(MXNetBackend.make_graph(node, inputs))
        data_names = list(sym.get_internals().list_inputs())
        dim_change_op_types = {'ReduceMin', 'ReduceMax', 'ReduceMean',
                               'ReduceProd', 'ReduceSum', 'Slice', 'Pad',
                               'Squeeze', 'Upsample', 'Reshape', 'Conv'}

        # Prepend a batch axis of size 1 when the inputs disagree on batch size.
        data_shapes = []
        for idx, input_name in enumerate(data_names):
            shape = inputs[idx].shape
            if len(inputs) > 1 and len(shape) < 4 and \
               len(set(x.shape[0] for x in inputs)) != 1:
                shape = (1,) + shape
            data_shapes.append((input_name, shape))

        if device != 'CPU':
            raise NotImplementedError("Only CPU context is supported for now")
        ctx = mx.cpu()

        mod = mx.mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
        mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

        # parameters are initialised so each individual node can be evaluated
        mod.init_params()

        # Operators listed above receive the input as provided; all others get
        # it wrapped in an extra leading dimension.
        keeps_dims = node.op_type in dim_change_op_types
        data_forward = [mx.nd.array(inputs[idx] if keeps_dims else [inputs[idx]])
                        for idx in range(len(data_names))]

        mod.forward(mx.io.DataBatch(data_forward))
        result = mod.get_outputs()[0].asnumpy()
        return [result] if keeps_dims else result

    @classmethod
    def prepare(cls, model, device='CPU', **kwargs):
        """Import a whole ONNX model for the onnx test backend.

        Parameters
        ----------
        model : onnx ModelProto object
            loaded onnx graph
        device : 'CPU'
            specifying device to run test on
        kwargs :
            other arguments

        Returns
        -------
        MXNetBackendRep : object
            Holds the imported symbol and params; its run method performs
            inference and returns the result for comparison.
        """
        sym, params = GraphProto().from_onnx(model.graph)
        return MXNetBackendRep(sym, params, device)

    @classmethod
    def supports_device(cls, device):
        """Supports only CPU for testing"""
        return device == 'CPU'

prepare = MXNetBackend.prepare

run_node = MXNetBackend.run_node

supports_device = MXNetBackend.supports_device
# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
# retrieve the corresponding results for comparison to the onnx backend.
# https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py.

class MXNetBackendRep(BackendRep):
    """Holds an imported model and runs it on MXNet for the onnx test
    infrastructure."""

    def __init__(self, symbol, params, device):
        self.symbol = symbol
        self.params = params
        self.device = device

    def run(self, inputs, **kwargs):
        """Run model inference on the first input and return the result.

        Parameters
        ----------
        inputs : numpy array
            input to run a layer on; only ``inputs[0]`` is used

        Returns
        -------
        params : numpy array
            single-element list holding the first module output
        """
        data = np.asarray(inputs[0], dtype='f')

        # only the cpu context is available
        if self.device != 'CPU':
            raise NotImplementedError("Only CPU context is supported for now")
        ctx = mx.cpu()

        module = mx.mod.Module(symbol=self.symbol, data_names=['input_0'], context=ctx,
                               label_names=None)
        module.bind(for_training=False, data_shapes=[('input_0', data.shape)],
                    label_shapes=None)
        module.set_params(arg_params=self.params, aux_params=None)

        # a minimal batch object exposing only the `data` field
        batch_type = namedtuple('Batch', ['data'])
        module.forward(batch_type([mx.nd.array(data)]))

        return [module.get_outputs()[0].asnumpy()]
"""ONNX test backend wrapper"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
try:
    import onnx.backend.test
except ImportError:
    raise ImportError("Onnx and protobuf need to be installed")

import backend as mxnet_backend

# This is a pytest magic variable to load extra plugins
pytest_plugins = "onnx.backend.test.report",

# Test collection built from the local MXNet backend module.
BACKEND_TEST = onnx.backend.test.BackendTest(mxnet_backend, __name__)

# Test-name patterns passed to BACKEND_TEST.include() below. Entries that are
# commented out are not included.
IMPLEMENTED_OPERATORS = [
    #Generator Functions
    #'test_constant*', # Identity Function
    #'test_random_uniform',
    #'test_random_normal',

    #Arithmetic Operators
    'test_add',
    'test_sub',
    'test_mul',
    'test_div',
    'test_neg',
    'test_abs',
    'test_sum',

    #Hyperbolic functions
    'test_tanh',

    #Rounding
    'test_ceil',
    'test_floor',

    ## Joining and splitting
    #'test_concat.*',  #---Failing test

    #Basic neural network functions
    'test_sigmoid',
    'test_relu',
    #'test_constant_pad',
    #'test_edge_pad',
    #'test_reflect_pad',
    'test_matmul',
    'test_leakyrelu',
    'test_elu',
    #'test_softmax*',
    'test_conv',
    'test_basic_conv',
    #'test_globalmaxpool',
    #'test_globalaveragepool',
    #'test_batch_norm',

    #Changing shape and type.
    'test_reshape_',
    #'test_AvgPool2D*',
    #'test_MaxPool2D*',
    #'test_cast',
    #'test_split',
    'test_slice_cpu',
    'test_default_axes', #make PR against onnx to fix the test name(grep-able)
    'test_slice_neg',
    #'test_slice_start_out_of_bounds',
    #'test_slice_end_out_of_bounds',
    #'test_transpose*',
    'test_squeeze_',

    #Powers
    'test_reciprocal',
    'test_sqrt',
    'test_pow_example',
    'test_pow_cpu',
    'test_pow_bcast_cpu',
    #'test_pow_bcast_axis0',
    'test_log_',
    'test_exp',

    # Sorting and Searching
    'test_argmax',
    'test_argmin',
    'test_max',
    'test_min'
    ]

for op_test in IMPLEMENTED_OPERATORS:
    BACKEND_TEST.include(op_test)

# import all test cases at global scope to make them visible to python.unittest
globals().update(BACKEND_TEST.enable_report().test_cases)

if __name__ == '__main__':
    unittest.main()
@with_seed()
def test_reduce_max():
    """ReduceMax over both axes of a 3x10 array must match np.max."""
    node_def = helper.make_node("ReduceMax", ["input1"], ["output"], axes=[1, 0], keepdims=1)
    input1 = np.random.ranf([3, 10]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    numpy_op = np.max(input1, axis=(1, 0), keepdims=True)
    npt.assert_almost_equal(output, numpy_op)

@with_seed()
def test_reduce_mean():
    """ReduceMean over both axes of a 3x10 array must match np.mean."""
    node_def = helper.make_node("ReduceMean", ["input1"], ["output"], axes=[1, 0], keepdims=1)
    input1 = np.random.ranf([3, 10]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    numpy_op = np.mean(input1, axis=(1, 0), keepdims=True)
    npt.assert_almost_equal(output, numpy_op, decimal=5)

@with_seed()
def test_reduce_min():
    """ReduceMin over both axes of a 3x10 array must match np.min."""
    node_def = helper.make_node("ReduceMin", ["input1"], ["output"], axes=[1, 0], keepdims=1)
    input1 = np.random.ranf([3, 10]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    numpy_op = np.min(input1, axis=(1, 0), keepdims=True)
    npt.assert_almost_equal(output, numpy_op)

@with_seed()
def test_reduce_sum():
    """ReduceSum over both axes of a 3x10 array must match np.sum."""
    node_def = helper.make_node("ReduceSum", ["input1"], ["output"], axes=[1, 0], keepdims=1)
    input1 = np.random.ranf([3, 10]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    numpy_op = np.sum(input1, axis=(1, 0), keepdims=True)
    npt.assert_almost_equal(output, numpy_op, decimal=5)

@with_seed()
def test_reduce_prod():
    """ReduceProd over both axes of a 3x10 array must match np.prod."""
    node_def = helper.make_node("ReduceProd", ["input1"], ["output"], axes=[1, 0], keepdims=1)
    input1 = np.random.ranf([3, 10]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    numpy_op = np.prod(input1, axis=(1, 0), keepdims=True)
    npt.assert_almost_equal(output, numpy_op, decimal=5)

@with_seed()
def test_squeeze():
    """Squeeze of axes 1 and 3 must match np.squeeze on the same axes."""
    node_def = helper.make_node("Squeeze", ["input1"], ["output"], axes=[1, 3])
    input1 = np.random.ranf([3, 1, 2, 1, 4]).astype("float32")
    output = mxnet_backend.run_node(node_def, [input1])[0]
    # np.squeeze documents `axis` as None, an int or a tuple of ints; a list
    # is not a supported form, so the axes are passed as a tuple.
    npt.assert_almost_equal(output, np.squeeze(input1, axis=(1, 3)))

def test_super_resolution_example():
    """Test the super resolution example in the example/onnx folder.

    Checks the names and counts of the imported symbol's inputs, outputs,
    attributes and params, then runs inference and compares the md5 digest
    and size of the resulting image with fixed expected values.
    """
    sys.path.insert(0, os.path.join(CURR_PATH, '../../../example/onnx/'))
    import super_resolution

    sym, params = super_resolution.import_onnx()
    assert sym is not None
    assert params is not None

    # input order is significant here, so compare position by position
    inputs = sym.list_inputs()
    assert len(inputs) == 9
    for i, input_param in enumerate(['param_7', 'param_5', 'param_3', 'param_1',
                                     'input_0', 'param_0', 'param_2', 'param_4', 'param_6']):
        assert inputs[i] == input_param

    assert len(sym.list_outputs()) == 1
    assert sym.list_outputs()[0] == 'reshape5_output'

    # only membership is checked for attribute and param keys
    attrs_keys = sym.attr_dict().keys()
    assert len(attrs_keys) == 19
    for key_item in ['reshape4', 'param_5', 'param_4', 'param_7',
                     'param_6', 'param_1', 'param_0', 'param_3',
                     'param_2', 'reshape2', 'reshape3', 'reshape0',
                     'reshape1', 'convolution2', 'convolution3',
                     'convolution0', 'convolution1', 'reshape5',
                     'transpose0']:
        assert key_item in attrs_keys

    param_keys = params.keys()
    assert len(param_keys) == 8
    for param_item in ['param_5', 'param_4', 'param_7', 'param_6',
                       'param_1', 'param_0', 'param_3', 'param_2']:
        assert param_item in param_keys

    logging.info("Asserted the result of the onnx model conversion")

    output_img_dim = 672
    input_image, img_cb, img_cr = super_resolution.get_test_image()
    result_img = super_resolution.perform_inference(sym, params, input_image,
                                                    img_cb, img_cr)

    assert hashlib.md5(result_img.tobytes()).hexdigest() == '0d98393a49b1d9942106a2ed89d1e854'
    assert result_img.size == (output_img_dim, output_img_dim)

if __name__ == '__main__':
    unittest.main()