From 7d91602ba771d973360f8a0c66c976c67f700aa3 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Mon, 25 Jun 2018 09:43:20 -0700 Subject: [PATCH] [MXNET-533] MXNet-ONNX export (#11213) * Resolve conflicts * Export module Test Framework * refactoring export to work with pretrained models * comments added * 1. Refactored export module. 2. Refactored test framework to support ONNX backened tests. 2. Added Operator support: - Convolution2D - BatchNorm - Add * Added Arithmetic operators: - Add, Sub, Mul, Div, Sum * Added operator support: - sigmoid, relu, pad( constant, edge, reflect), tanh - enabled corresponding ONNX backend tests. * Enabled ONNX tests: test_conv, test_basic_conv Added Operators : Ceil, Floor * Added support for: MaxPool, AvgPool, GlobalMaxPool, GlobalAvgPool, matmul * adding more operators * Added Operator support: ArgMax, ArgMin, maximum, minimum * Enabled more BASIC_MODEL tests * Added power operator tests * Added support for reshape. ONNX only supports 0, -1 special values. Added only for these. Fixed logic error with convert_string_to_list() * some tests enabled * enabling squeezenet * LRN Op support * mul_scalar modified to take scalar input * cleaning some code * Resolving conlicts on rebase * Resolving rebase conflicts * id mapping updated for all operators * save onnx models added, some code cleanup * enabled more tests * conv pad calc fixed * reshape op fix * Added support for elu, leakyRelu, prelu * Cleanup - Removed run_node, not needed anymore. - Used correct get_metadata api * valueinfoproto fix, googlenet test added * Removed redundant code. - run_node - Using correct get_metadata_api * dilation added * Lint fixes * lint fixes * some fixes to make export work with onx1.2.1 * enabled more tests * mxnet_export_test file added * duplicate file deleted * reduce ops added * some small fixes * some lint fixes * Add tests for inception_v1 and inception_v2 * Add CI runs for export module * docstring added * lint fixes, pooling attr fix * fix * fix global_pool * CI run fix * code cleanup * lint fix * some code cleanup * pad in pooling added * slicechannel notimplementederror raised * Added required license comments * Lint fixes * lint fix * lint fix * lint fix * lint fix * Correct license statement * Adding onnx a runtime dependency * Fix import module error for string_types * Making ONNX runtime dependency * fixing some comments * addressing some comments * params rename * lint fixes * fixes * spatial disabled, path fixed * fixing some comments * Added support for remaining act_type(softsign, sigmoid, softrelu) in Activation operator * changing import * adding some comments * Add squeeze op * Refactored logic to handle extra node(output label node) for saved mxnet model Added comments * minor fix for squeeze operator. Also, added error handling * identity operator added * scalar ops added * Renamed onnx support folders to mark it public folders Changed underline files public or private as per usage Resolved conflicts with the latest * Added support L2Normalization op Added some error checking * added comments and warning * added comments and warning * doc API ref added --- LICENSE | 52 +- ci/docker/runtime_functions.sh | 2 + docs/api/python/contrib/onnx.md | 2 + python/mxnet/contrib/onnx/__init__.py | 5 +- python/mxnet/contrib/onnx/mx2onnx/LICENSE | 44 + python/mxnet/contrib/onnx/mx2onnx/__init__.py | 24 + .../contrib/onnx/mx2onnx/_export_helper.py | 65 + .../contrib/onnx/mx2onnx/_op_translations.py | 1863 +++++++++++++++++ .../contrib/onnx/mx2onnx/export_model.py | 95 + .../mxnet/contrib/onnx/mx2onnx/export_onnx.py | 347 +++ .../onnx/{_import => onnx2mx}/__init__.py | 0 .../_import_helper.py} | 39 +- .../_op_translations.py} | 6 +- .../_translation_utils.py} | 1 + .../onnx/{_import => onnx2mx}/import_model.py | 0 .../onnx/{_import => onnx2mx}/import_onnx.py | 2 +- .../{_import => onnx2mx}/import_to_gluon.py | 0 tests/python-pytest/onnx/export/backend.py | 97 + .../python-pytest/onnx/export/backend_rep.py | 84 + .../onnx/export/mxnet_export_test.py | 191 ++ .../onnx/export/onnx_backend_test.py | 132 ++ .../onnx/import/gluon_backend.py | 6 +- .../onnx/import/mxnet_backend.py | 3 +- .../onnx/import/mxnet_backend_rep.py | 1 - tests/python-pytest/onnx/import/test_cases.py | 1 + 25 files changed, 3028 insertions(+), 34 deletions(-) create mode 100644 python/mxnet/contrib/onnx/mx2onnx/LICENSE create mode 100644 python/mxnet/contrib/onnx/mx2onnx/__init__.py create mode 100644 python/mxnet/contrib/onnx/mx2onnx/_export_helper.py create mode 100644 python/mxnet/contrib/onnx/mx2onnx/_op_translations.py create mode 100644 python/mxnet/contrib/onnx/mx2onnx/export_model.py create mode 100644 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py rename python/mxnet/contrib/onnx/{_import => onnx2mx}/__init__.py (100%) rename python/mxnet/contrib/onnx/{_import/import_helper.py => onnx2mx/_import_helper.py} (74%) rename python/mxnet/contrib/onnx/{_import/op_translations.py => onnx2mx/_op_translations.py} (99%) rename python/mxnet/contrib/onnx/{_import/translation_utils.py => onnx2mx/_translation_utils.py} (99%) rename python/mxnet/contrib/onnx/{_import => onnx2mx}/import_model.py (100%) rename python/mxnet/contrib/onnx/{_import => onnx2mx}/import_onnx.py (99%) rename python/mxnet/contrib/onnx/{_import => onnx2mx}/import_to_gluon.py (100%) create mode 100644 tests/python-pytest/onnx/export/backend.py create mode 100644 tests/python-pytest/onnx/export/backend_rep.py create mode 100644 tests/python-pytest/onnx/export/mxnet_export_test.py create mode 100644 tests/python-pytest/onnx/export/onnx_backend_test.py diff --git a/LICENSE b/LICENSE index 158bd37f2787..a8b57e583764 100644 --- a/LICENSE +++ b/LICENSE @@ -298,8 +298,6 @@ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ======================================================================================= Other Licenses ======================================================================================= @@ -512,3 +510,53 @@ For details, see, 3rdparty/dmlc-core/include/dmlc/concurrentqueue.h ======================================================================================= + + 11. ONNX Export module + For details, see, python/mxnet/contrib/onnx/_export/LICENSE + + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + # + # Based on + # https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/# + # Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + # + # Redistribution and use in source and binary forms, with or without + # modification, are permitted provided that the following conditions + # are met: + # * Redistributions of source code must retain the above copyright + # notice, this list of conditions and the following disclaimer. + # * Redistributions in binary form must reproduce the above copyright + # notice, this list of conditions and the following disclaimer in the + # documentation and/or other materials provided with the distribution. + # * Neither the name of NVIDIA CORPORATION nor the names of its + # contributors may be used to endorse or promote products derived + # from this software without specific prior written permission. + # + # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 6e6abf06c491..07980471c580 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -700,6 +700,8 @@ integrationtest_ubuntu_cpu_onnx() { pytest tests/python-pytest/onnx/import/mxnet_backend_test.py pytest tests/python-pytest/onnx/import/onnx_import_test.py pytest tests/python-pytest/onnx/import/gluon_backend_test.py + pytest tests/python-pytest/onnx/export/onnx_backend_test.py + python tests/python-pytest/onnx/export/mxnet_export_test.py } integrationtest_ubuntu_gpu_python() { diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md index 6fb546fc2b4e..8cd619809c19 100644 --- a/docs/api/python/contrib/onnx.md +++ b/docs/api/python/contrib/onnx.md @@ -24,6 +24,7 @@ This document describes all the ONNX-MXNet APIs. mxnet.contrib.onnx.import_model mxnet.contrib.onnx.get_model_metadata + mxnet.contrib.onnx.export_model ``` ## ONNX Tutorials @@ -46,6 +47,7 @@ This document describes all the ONNX-MXNet APIs. .. automodule:: mxnet.contrib.onnx :members: import_model :members: get_model_metadata + :members: export_model ``` diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py index 4f9296d3c56e..9f27060d3d6f 100644 --- a/python/mxnet/contrib/onnx/__init__.py +++ b/python/mxnet/contrib/onnx/__init__.py @@ -16,5 +16,6 @@ # under the License. """Module for ONNX model format support for Apache MXNet.""" -from ._import.import_model import import_model, get_model_metadata -from ._import.import_to_gluon import import_to_gluon +from .onnx2mx.import_model import import_model, get_model_metadata +from .onnx2mx.import_to_gluon import import_to_gluon +from .mx2onnx.export_model import export_model diff --git a/python/mxnet/contrib/onnx/mx2onnx/LICENSE b/python/mxnet/contrib/onnx/mx2onnx/LICENSE new file mode 100644 index 000000000000..3abe1ee8a8ee --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/LICENSE @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Based on +# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/# +# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/mxnet/contrib/onnx/mx2onnx/__init__.py b/python/mxnet/contrib/onnx/mx2onnx/__init__.py new file mode 100644 index 000000000000..238174e4a079 --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""ONNX Export module""" +from __future__ import absolute_import + +from . import export_model +from . import export_onnx +from . import _op_translations diff --git a/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py new file mode 100644 index 000000000000..781fb4cfbbc1 --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/_export_helper.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""export helper functions""" +# coding: utf-8 +import os +import logging +import mxnet as mx + + +def load_module(sym_filepath, params_filepath): + """Loads the MXNet model file and + returns MXNet symbol and params (weights). + + Parameters + ---------- + json_path : str + Path to the json file + params_path : str + Path to the params file + + Returns + ------- + sym : MXNet symbol + Model symbol object + + params : params object + Model weights including both arg and aux params. + """ + if not (os.path.isfile(sym_filepath) and os.path.isfile(params_filepath)): + raise ValueError("Symbol and params files provided are invalid") + else: + try: + # reads symbol.json file from given path and + # retrieves model prefix and number of epochs + model_name = sym_filepath.rsplit('.', 1)[0].rsplit('-', 1)[0] + params_file_list = params_filepath.rsplit('.', 1)[0].rsplit('-', 1) + # Setting num_epochs to 0 if not present in filename + num_epochs = 0 if len(params_file_list) == 1 else int(params_file_list[1]) + except IndexError: + logging.info("Model and params name should be in format: " + "prefix-symbol.json, prefix-epoch.params") + raise + + sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) + + # Merging arg and aux parameters + params = {} + params.update(arg_params) + params.update(aux_params) + + return sym, params diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py new file mode 100644 index 000000000000..5f5561ab32b6 --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -0,0 +1,1863 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Based on +# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/ +# mx2onnx_converter_functions.py +# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# coding: utf-8 +# pylint: disable=too-many-locals,no-else-return,too-many-lines +# pylint: disable=anomalous-backslash-in-string,eval-used +""" +Conversion Functions for common layers. +Add new functions here with a decorator. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import re +import logging +import numpy as np +from .export_onnx import MXNetGraph as mx_op + +def import_onnx_modules(): + """ To make sure ONNX is runtime dependency, it is imported used only when needed""" + try: + from onnx import helper, numpy_helper, mapping + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") + return helper, numpy_helper, mapping + + +def parse_helper(attrs, attrs_name, alt_value=None): + """Helper function to parse operator attributes in required format.""" + tuple_re = re.compile('\([0-9L|,| ]+\)') + if attrs is None: + return alt_value + attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name)) + if attrs_str is None: + return alt_value + attrs_match = tuple_re.search(attrs_str) + if attrs_match is not None: + if attrs_match.span() == (0, len(attrs_str)): + dims = eval(attrs_str) + return dims + else: + raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, str(attrs_str))) + return alt_value + +def transform_padding(pad_width): + """Helper function to convert padding format for pad operator. + """ + num_pad_values = len(pad_width) + onnx_pad_width = [0]*num_pad_values + + start_index = 0 + # num_pad_values will always be multiple of 2 + end_index = int(num_pad_values/2) + for idx in range(0, num_pad_values): + if idx % 2 == 0: + onnx_pad_width[start_index] = pad_width[idx] + start_index += 1 + else: + onnx_pad_width[end_index] = pad_width[idx] + end_index += 1 + + return onnx_pad_width + + +def convert_string_to_list(string_val): + """Helper function to convert string to list. + Used to convert shape attribute string to list format. + """ + result_list = [] + + list_string = string_val.split(',') + for val in list_string: + val = str(val.strip()) + val = val.replace("(", "") + val = val.replace(")", "") + val = val.replace("L", "") + val = val.replace("[", "") + val = val.replace("]", "") + if val != "" and val != "None": + result_list.append(int(val)) + + return result_list + +@mx_op.register("null") +def convert_weights_and_inputs(node, **kwargs): + """Helper function to convert weights and inputs. + """ + + helper, _, mapping = import_onnx_modules() + name = node["name"] + + if kwargs["is_input"] is False: + weights = kwargs["weights"] + initializer = kwargs["initializer"] + np_arr = weights[name] + data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] + dims = np.shape(np_arr) + + tensor_node = helper.make_tensor_value_info(name, data_type, dims) + + initializer.append( + helper.make_tensor( + name=name, + data_type=data_type, + dims=dims, + vals=np_arr.flatten().tolist(), + raw=False, + ) + ) + + return [tensor_node] + else: + tval_node = helper.make_tensor_value_info(name, kwargs["in_type"], kwargs["in_shape"]) + return [tval_node] + + +@mx_op.register("Convolution") +def convert_convolution(node, **kwargs): + """Map MXNet's convolution operator attributes to onnx's Conv operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + + num_inputs = len(inputs) + + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name + weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name + + if num_inputs > 2: + bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name + + attrs = node.get("attrs") + + kernel_dims = list(parse_helper(attrs, "kernel")) + stride_dims = list(parse_helper(attrs, "stride", [1, 1])) + pad_dims = list(parse_helper(attrs, "pad", [0, 0])) + num_group = int(attrs.get("num_group", 1)) + dilations = list(parse_helper(attrs, "dilate", [1, 1])) + + pad_dims = pad_dims + pad_dims + + input_nodes = [input_node, weights_node] + if num_inputs > 2: + input_nodes.append(bias_node) + + conv_node = helper.make_node( + "Conv", + inputs=input_nodes, + outputs=[name], + kernel_shape=kernel_dims, + strides=stride_dims, + dilations=dilations, + pads=pad_dims, + group=num_group, + name=name + ) + + return [conv_node] + + +@mx_op.register("FullyConnected") +def convert_fully_connected(node, **kwargs): + """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_id = kwargs["index_lookup"][inputs[0][0]] + weight_node_id = kwargs["index_lookup"][inputs[1][0]] + bias_node_id = kwargs["index_lookup"][inputs[2][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_id] + weights_node = proc_nodes[weight_node_id] + bias_node = proc_nodes[bias_node_id] + + input_name = input_node.name + weights_name = weights_node.name + bias_name = bias_node.name + + node = helper.make_node( + "Gemm", + [input_name, weights_name, bias_name], # input (A, B, C) - C can be in place + [name], # output + alpha=1.0, + beta=1.0, + transA=False, + transB=True, + name=name + ) + + return [node] + + +@mx_op.register("BatchNorm") +def convert_batchnorm(node, **kwargs): + """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + attrs = node["attrs"] + momentum = float(node.get("attrs", {}).get("momentum", 0.9)) + eps = float(attrs.get("eps", 0.001)) + + data_idx = kwargs["index_lookup"][inputs[0][0]] + gamma_idx = kwargs["index_lookup"][inputs[1][0]] + beta_idx = kwargs["index_lookup"][inputs[2][0]] + moving_mean_idx = kwargs["index_lookup"][inputs[3][0]] + moving_var_idx = kwargs["index_lookup"][inputs[4][0]] + + data_node = proc_nodes[data_idx].name + gamma_node = proc_nodes[gamma_idx].name + beta_node = proc_nodes[beta_idx].name + + mov_mean_node = proc_nodes[moving_mean_idx] + mov_mean_node = mov_mean_node.name + mov_var_node = proc_nodes[moving_var_idx].name + + bn_node = helper.make_node( + "BatchNormalization", + [data_node, + gamma_node, # scale + beta_node, # bias + mov_mean_node, + mov_var_node + ], + [name], + name=name, + epsilon=eps, + momentum=momentum, + # MXNet computes mean and variance per feature for batchnorm + # Default for onnx is across all spatial features. So disabling the parameter. + spatial=0 + ) + return [bn_node] + + +@mx_op.register("tanh") +def convert_tanh(node, **kwargs): + """Map MXNet's tanh operator attributes to onnx's Tanh operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Tanh', + [input_node], + [name], + name=name + ) + return [node] + +#Basic neural network functions +@mx_op.register("sigmoid") +def convert_sigmoid(node, **kwargs): + """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Sigmoid', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("relu") +def convert_relu(node, **kwargs): + """Map MXNet's relu operator attributes to onnx's Relu operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Relu', + [input_node], + [name], + name=name + ) + + return [node] + +@mx_op.register("Activation") +def convert_activation(node, **kwargs): + """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + + proc_nodes = kwargs["proc_nodes"] + attrs = node["attrs"] + act_type = attrs["act_type"] + + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_idx].output[0] + + # Creating a dictionary here, but if this titlecase pattern + # mxnet_name.title() + act_types = { + "tanh": "Tanh", + "relu": "Relu", + "sigmoid": "Sigmoid", + "softrelu": "Softplus", + "softsign": "Softsign" + } + + act_name = act_types.get(act_type) + if act_name: + node = helper.make_node( + act_name, + [input_node], + [name], + name=name + ) + else: + raise AttributeError( + "Activation %s not implemented or recognized in the converter" % act_type + ) + + return [node] + + +@mx_op.register("Pad") +def convert_pad(node, **kwargs): + """Map MXNet's pad operator attributes to onnx's Pad operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + attrs = node["attrs"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_idx].name + + mxnet_pad_width = convert_string_to_list(attrs.get("pad_width")) + onnx_pad_width = transform_padding(mxnet_pad_width) + + pad_mode = attrs.get("mode") + + if pad_mode == "constant": + pad_value = float(attrs.get("constant_value")) \ + if "constant_value" in attrs else 0.0 + node = helper.make_node( + 'Pad', + inputs=[input_node], + outputs=[name], + mode='constant', + value=pad_value, + pads=onnx_pad_width, + name=name + ) + else: + node = helper.make_node( + 'Pad', + inputs=[input_node], + outputs=[name], + mode=pad_mode, + pads=onnx_pad_width, + name=name + ) + + return [node] + + +@mx_op.register("_linalg_gemm2") +def convert_linalg_gemm2(node, **kwargs): + """Map MXNet's _linalg_gemm2 operator attributes to onnx's + MatMul and Transpose operators based on the values set for + transpose_a, transpose_b attributes. + Return multiple nodes created. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + node_inputs = node["inputs"] + name = node["name"] + + input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node_a = proc_nodes[input_a_idx].name + input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] + input_node_b = proc_nodes[input_b_idx].name + + # Getting the attributes and assigning default values. + if "attrs" in node: + attrs = node["attrs"] + alpha = float(attrs["alpha"]) + trans_a = int(attrs["transpose_a"]) + trans_b = int(attrs["transpose_b"]) + else: + alpha = 1.0 + trans_a = 0 + trans_b = 0 + + op_name = "transpose" + str(kwargs["idx"]) + + if alpha == 1.0 and trans_a == 0 and trans_b == 0: + matmul_node = helper.make_node( + 'MatMul', + inputs=[input_node_a, input_node_b], + outputs=[name], + name=name + ) + return [matmul_node] + elif trans_a == 1 and trans_b == 0: + op_name = "transpose" + str(kwargs["idx"]) + node_name = op_name+"_a" + trans_a_node = helper.make_node( + 'Transpose', + inputs=[input_node_a], + outputs=[op_name+"_a"], + name=node_name + ) + + matmul_node = helper.make_node( + 'MatMul', + inputs=[node_name, input_node_b], + outputs=[name], + name=name + ) + return [trans_a_node, matmul_node] + + elif trans_a == 0 and trans_b == 1: + node_name = op_name + "_b" + trans_b_node = helper.make_node( + 'Transpose', + inputs=[input_node_b], + outputs=[op_name+"_b"], + name=node_name + ) + + matmul_node = helper.make_node( + 'MatMul', + inputs=[input_node_a, node_name], + outputs=[name], + name=name + ) + + return [trans_b_node, matmul_node] + else: + node_name_a = op_name+"_a" + trans_a_node = helper.make_node( + 'Transpose', + inputs=[input_node_a], + outputs=[op_name+"_a"], + name=node_name_a + ) + + node_name_b = op_name + "_b" + trans_b_node = helper.make_node( + 'Transpose', + inputs=[input_node_b], + outputs=[op_name+"_b"], + name=node_name_b + ) + + matmul_node = helper.make_node( + 'MatMul', + inputs=[node_name_a, node_name_b], + outputs=[name], + name=name + ) + + return [trans_a_node, trans_b_node, matmul_node] + + +@mx_op.register("Pooling") +def convert_pooling(node, **kwargs): + """Map MXNet's Pooling operator attributes to onnx's + MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators + based on the input node's attributes and return the created node. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + attrs = node["attrs"] + kernel = eval(attrs["kernel"]) + pool_type = attrs["pool_type"] + stride = eval(attrs["stride"]) if attrs.get("stride") else None + global_pool = True if "global_pool" in attrs and\ + attrs.get("global_pool") == "True" else False + node_inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node = proc_nodes[input_node_idx] + name = node["name"] + + pooling_convention = attrs.get('pooling_convention', 'valid') + + if pooling_convention == 'full': + pooling_warning = "Pooling: ONNX currently doesn't support pooling_convention. " \ + "This might lead to shape or accuracy issues. " \ + "https://github.com/onnx/onnx/issues/549" + + logging.warning(pooling_warning) + + pad_dims = list(parse_helper(attrs, "pad", [0, 0])) + pad_dims = pad_dims + pad_dims + pool_types = {"max": "MaxPool", "avg": "AveragePool"} + global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"} + + if global_pool: + node = helper.make_node( + global_pool_types[pool_type], + [input_node.name], # input + [name], + name=name + ) + else: + node = helper.make_node( + pool_types[pool_type], + [input_node.name], # input + [name], + kernel_shape=kernel, + pads=pad_dims, + strides=stride, + name=name + ) + + return [node] + + +@mx_op.register("exp") +def convert_exp(node, **kwargs): + """Map MXNet's exp operator attributes to onnx's Exp operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Exp", + [input_node], + [name], + name=name, + ) + return [node] + + +@mx_op.register("_copy") +def convert_identity(node, **kwargs): + """Map MXNet's _copy operator attributes to onnx's Identity operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Identity", + [input_node], + [name], + name=name, + ) + return [node] + + +@mx_op.register("LeakyReLU") +def convert_leakyrelu(node, **kwargs): + """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators + based on the input node's attributes and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + attrs = node["attrs"] + + act_type = attrs.get("act_type", "LeakyRelu") + alpha = float(attrs.get("slope", 0.25)) + + act_name = {"elu": "Elu", "LeakyRelu": "LeakyRelu", "prelu": "PRelu"} + + if act_type == "prelu": + alpha_node_index = kwargs["index_lookup"][inputs[1][0]] + alpha_node_name = proc_nodes[alpha_node_index].name + + node = helper.make_node( + act_name[act_type], + inputs=[input_node, alpha_node_name], + outputs=[name], + name=name) + else: + node = helper.make_node( + act_name[act_type], + inputs=[input_node], + outputs=[name], + name=name, + alpha=alpha) + + return [node] + + +@mx_op.register("softmax") +def convert_softmax(node, **kwargs): + """Map MXNet's softmax operator attributes to onnx's Softmax operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + inputs = node["inputs"] + input_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_idx] + + name = node["name"] + axis = int(node.get("attrs", {}).get("axis", -1)) + + softmax_node = helper.make_node( + "Softmax", + [input_node.name], + [name], + axis=axis, + name=name + ) + + return [softmax_node] + + +# There's also mx.sym.softmax(), which doesn't do cross-entropy loss, +# just softmax for inference - hence the name convert_softmax_output. +@mx_op.register("SoftmaxOutput") +def convert_softmax_output(node, **kwargs): + """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + inputs = node["inputs"] + input1_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input1 = proc_nodes[input1_idx] + name = node["name"] + + softmax_node = helper.make_node( + "Softmax", + [input1.output[0]], + [name], + axis=1, + name=name + ) + + return [softmax_node] + + +@mx_op.register("Concat") +def convert_concat(node, **kwargs): + """Map MXNet's Concat operator attributes to onnx's Concat operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + proc_nodes = kwargs["proc_nodes"] + input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in inputs] + axis = int(node.get("attrs", {}).get("dim", 1)) + concat_node = helper.make_node( + "Concat", + input_names, + [name], + axis=axis, + name=name + ) + return [concat_node] + + +@mx_op.register("transpose") +def convert_transpose(node, **kwargs): + """Map MXNet's transpose operator attributes to onnx's Transpose operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_idx].name + axes = node.get("attrs", {}).get("axes", ()) + if axes: + axes = tuple(map(int, re.findall(r'\d+', axes))) + + transpose_node = helper.make_node( + "Transpose", + [input_node], + [name], + perm=axes, + name=name + ) + else: + transpose_node = helper.make_node( + "Transpose", + [input_node], + [name], + name=name + ) + + return [transpose_node] + + +@mx_op.register("LRN") +def convert_lrn(node, **kwargs): + """Map MXNet's LRN operator attributes to onnx's LRN operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_idx].name + + attrs = node["attrs"] + alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001 + beta = float(attrs["beta"]) if "beta" in attrs else 0.75 + bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0 + size = int(attrs["nsize"]) + + lrn_node = helper.make_node( + "LRN", + inputs=[input_node], + outputs=[name], + name=name, + alpha=alpha, + beta=beta, + bias=bias, + size=size + ) + + return [lrn_node] + + +@mx_op.register("L2Normalization") +def convert_l2normalization(node, **kwargs): + """Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_id = kwargs["index_lookup"][node["inputs"][0][0]] + input_name = kwargs["proc_nodes"][input_id].name + attrs = node["attrs"] + mode = attrs.get("mode", "instance") + + if mode != "channel": + raise AttributeError("ONNX currently supports channel mode only") + + l2norm_node = helper.make_node( + "LpNormalization", + [input_name], + [name], + axis=1, # channel only + name=name + ) + return [l2norm_node] + + +@mx_op.register("Dropout") +def convert_dropout(node, **kwargs): + """Map MXNet's Dropout operator attributes to onnx's Dropout operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_id = kwargs["index_lookup"][node["inputs"][0][0]] + input_name = kwargs["proc_nodes"][input_id].name + attrs = node["attrs"] + probability = float(attrs["p"]) + + dropout_node = helper.make_node( + "Dropout", + [input_name], + [name], + ratio=probability, + name=name + ) + return [dropout_node] + + +@mx_op.register("Flatten") +def convert_flatten(node, **kwargs): + """Map MXNet's Flatten operator attributes to onnx's Flatten operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_idx].name # .output[0] + + flatten_node = helper.make_node( + "Flatten", + [input_node], + [name], + name=name + ) + return [flatten_node] + + +def scalar_op_helper(node, op_name, **kwargs): + """Helper function for scalar arithmetic operations""" + helper, numpy_helper, mapping = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + scalar_value = [float(node.get("attrs", {}).get("scalar", 1))] + + input_name_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_name_id].name + + initializer = kwargs["initializer"] + flag = True + # If the input value is in initializer, just multiply with scalar input + # and create a new initializer + for i in initializer: + if i.name == input_node: + if op_name == 'Mul': + new_initializer = numpy_helper.to_array(i) * scalar_value[0] + elif op_name == 'Sub': + new_initializer = numpy_helper.to_array(i) - scalar_value[0] + elif op_name == 'Add': + new_initializer = numpy_helper.to_array(i) + scalar_value[0] + elif op_name == 'Div': + new_initializer = numpy_helper.to_array(i) / scalar_value[0] + flag = False + break + + # else create a new tensor of the scalar value, add it in initializer + if flag is True: + np_arr = np.array(scalar_value) + data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] + dims = np.shape(np_arr) + + scalar_op_name = "scalar_op" + str(kwargs["idx"]) + tensor_node = helper.make_tensor_value_info(scalar_op_name, data_type, dims) + + initializer.append( + helper.make_tensor( + name=scalar_op_name, + data_type=data_type, + dims=dims, + vals=scalar_value, + raw=False, + ) + ) + + mul_node = helper.make_node( + op_name, + [input_node, scalar_op_name], + [name], + name=name + ) + + return [tensor_node, mul_node] + else: + data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype] + dims = np.shape(new_initializer) + + new_a_node = input_node + str(kwargs["idx"]) + tensor_node = helper.make_tensor_value_info(new_a_node, data_type, dims) + + initializer.append( + helper.make_tensor( + name=new_a_node, + data_type=data_type, + dims=dims, + vals=new_initializer, + raw=False, + ) + ) + return [tensor_node] + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_mul_scalar") +def convert_mul_scalar(node, **kwargs): + """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Mul', **kwargs) + + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_minus_scalar") +def convert_minus_scalar(node, **kwargs): + """Map MXNet's _minus_scalar operator attributes to onnx's Minus operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Sub', **kwargs) + + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_plus_scalar") +def convert_add_scalar(node, **kwargs): + """Map MXNet's _plus_scalar operator attributes to onnx's Add operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Add', **kwargs) + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_div_scalar") +def convert_div_scalar(node, **kwargs): + """Map MXNet's _div_scalar operator attributes to onnx's Div operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Div', **kwargs) + + +# Sorting and Searching +@mx_op.register("argmax") +def convert_argmax(node, **kwargs): + """Map MXNet's argmax operator attributes to onnx's ArgMax operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + node_inputs = node["inputs"] + + input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node = proc_nodes[input_node_idx].name + name = node["name"] + attrs = node["attrs"] + + axis = int(attrs.get("axis")) + keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 + + node = helper.make_node( + 'ArgMax', + inputs=[input_node], + axis=axis, + keepdims=keepdims, + outputs=[name], + name=name + ) + return [node] + +@mx_op.register("argmin") +def convert_argmin(node, **kwargs): + """Map MXNet's argmin operator attributes to onnx's ArgMin operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + node_inputs = node["inputs"] + + input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node = proc_nodes[input_node_idx].name + name = node["name"] + attrs = node["attrs"] + + axis = int(attrs.get("axis")) + keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 + + node = helper.make_node( + 'ArgMin', + inputs=[input_node], + axis=axis, + keepdims=keepdims, + outputs=[name], + name=name + ) + return [node] + +@mx_op.register("_maximum") +def convert_maximum(node, **kwargs): + """Map MXNet's _maximum operator attributes to onnx's Max operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + node_inputs = node["inputs"] + + input_node_list = [] + for node_input in node_inputs: + node_id = kwargs["index_lookup"][node_input[0]] + input_node_list.append(proc_nodes[node_id].name) + + name = node["name"] + + node = helper.make_node( + 'Max', + inputs=input_node_list, + outputs=[name], + name=name, + ) + + return [node] + + +@mx_op.register("_minimum") +def convert_minimum(node, **kwargs): + """Map MXNet's _minimum operator attributes to onnx's Min operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + proc_nodes = kwargs["proc_nodes"] + node_inputs = node["inputs"] + + input_node_list = [] + for node_input in node_inputs: + node_id = kwargs["index_lookup"][node_input[0]] + input_node_list.append(proc_nodes[node_id].name) + + name = node["name"] + + node = helper.make_node( + 'Min', + inputs=input_node_list, + outputs=[name], + name=name, + ) + + return [node] + + +@mx_op.register("min") +def convert_min(node, **kwargs): + """Map MXNet's min operator attributes to onnx's ReduceMin operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + mx_axis = node.get("attrs", {}).get("axis", None) + axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None + + keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if axes is not None: + node = helper.make_node( + 'ReduceMin', + inputs=[input_node], + outputs=[name], + axes=axes, + keepdims=keepdims, + name=name + ) + + return [node] + else: + node = helper.make_node( + 'ReduceMin', + inputs=[input_node], + outputs=[name], + keepdims=keepdims, + name=name + ) + + return [node] + + +@mx_op.register("max") +def convert_max(node, **kwargs): + """Map MXNet's max operator attributes to onnx's ReduceMax operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + mx_axis = node.get("attrs", {}).get("axis", None) + axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None + + keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if axes is not None: + node = helper.make_node( + 'ReduceMax', + inputs=[input_node], + outputs=[name], + axes=axes, + keepdims=keepdims, + name=name + ) + + return [node] + else: + node = helper.make_node( + 'ReduceMax', + inputs=[input_node], + outputs=[name], + keepdims=keepdims, + name=name + ) + + return [node] + + +@mx_op.register("mean") +def convert_mean(node, **kwargs): + """Map MXNet's mean operator attributes to onnx's ReduceMean operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + mx_axis = node.get("attrs", {}).get("axis", None) + axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None + + keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if axes is not None: + node = helper.make_node( + 'ReduceMean', + inputs=[input_node], + outputs=[name], + axes=axes, + keepdims=keepdims, + name=name + ) + + return [node] + else: + node = helper.make_node( + 'ReduceMean', + inputs=[input_node], + outputs=[name], + keepdims=keepdims, + name=name + ) + + return [node] + + +@mx_op.register("prod") +def convert_prod(node, **kwargs): + """Map MXNet's prod operator attributes to onnx's ReduceProd operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + mx_axis = node.get("attrs", {}).get("axis", None) + axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None + + keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if axes is not None: + node = helper.make_node( + 'ReduceProd', + inputs=[input_node], + outputs=[name], + axes=axes, + keepdims=keepdims, + name=name + ) + + return [node] + else: + node = helper.make_node( + 'ReduceProd', + inputs=[input_node], + outputs=[name], + keepdims=keepdims, + name=name + ) + + return [node] + + +# Arithmetic Operations +@mx_op.register("elemwise_add") +def convert_elementwise_add(node, **kwargs): + """Map MXNet's elemwise_add operator attributes to onnx's Add operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + add_node = helper.make_node( + "Add", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [add_node] + + +@mx_op.register("broadcast_add") +def covert_broadcast_add(node, **kwargs): + """Map MXNet's broadcast_add operator attributes to onnx's Add operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + add_node = helper.make_node( + "Add", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [add_node] + + +@mx_op.register("elemwise_sub") +def convert_elementwise_sub(node, **kwargs): + """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + sub_node = helper.make_node( + "Sub", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [sub_node] + +@mx_op.register("broadcast_sub") +def covert_broadcast_sub(node, **kwargs): + """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + sub_node = helper.make_node( + "Sub", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [sub_node] + + +@mx_op.register("elemwise_mul") +def convert_elemwise_mul(node, **kwargs): + """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + mul_node = helper.make_node( + "Mul", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [mul_node] + +@mx_op.register("broadcast_mul") +def convert_broadcast_mul(node, **kwargs): + """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + mul_node = helper.make_node( + "Mul", + [input_node_a, input_node_b], + [name], + name=name + ) + + return [mul_node] + + +@mx_op.register("elemwise_div") +def convert_elemwise_div(node, **kwargs): + """Map MXNet's elemwise_div operator attributes to onnx's Div operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + div_node = helper.make_node( + "Div", + [input_node_a, input_node_b], + [name], + name=name + ) + + return [div_node] + + +@mx_op.register("broadcast_div") +def convert_broadcast_div(node, **kwargs): + """Map MXNet's broadcast_div operator attributes to onnx's Div operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + div_node = helper.make_node( + "Div", + [input_node_a, input_node_b], + [name], + name=name + ) + + return [div_node] + + +@mx_op.register("negative") +def convert_negative(node, **kwargs): + """Map MXNet's negative operator attributes to onnx's Neg operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + + input_node = proc_nodes[input_node_id].name + + neg_node = helper.make_node( + "Neg", + [input_node], + [name], + name=name, + ) + + return [neg_node] + + +@mx_op.register("abs") +def convert_abs(node, **kwargs): + """Map MXNet's abs operator attributes to onnx's Abs operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + + input_node = proc_nodes[input_node_id].name + + abs_node = helper.make_node( + "Abs", + [input_node], + [name], + name=name + ) + + return [abs_node] + + +@mx_op.register("add_n") +def convert_addn(node, **kwargs): + """Map MXNet's add_n operator attributes to onnx's Sum operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_list = [] + for input_val in inputs: + input_list.append(proc_nodes[kwargs["index_lookup"][input_val[0]]].name) + + sum_node = helper.make_node( + "Sum", + input_list, + [name], + name=name + ) + return [sum_node] + + # Rounding +@mx_op.register("ceil") +def convert_ceil(node, **kwargs): + """Map MXNet's ceil operator attributes to onnx's Ceil operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Ceil", + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("floor") +def convert_floor(node, **kwargs): + """Map MXNet's floor operator attributes to onnx's Floor operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Floor", + [input_node], + [name], + name=name + ) + return [node] + +# Changing shape and type. +@mx_op.register("Reshape") +def convert_reshape(node, **kwargs): + """Map MXNet's Reshape operator attributes to onnx's Reshape operator. + Converts output shape attribute to output shape tensor + and return multiple created nodes. + """ + helper, _, mapping = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + attrs = node["attrs"] + + output_shape_list = convert_string_to_list(attrs["shape"]) + + initializer = kwargs["initializer"] + output_shape_np = np.array(output_shape_list) + data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype] + dims = np.shape(output_shape_np) + + output_shape_name = "reshape_attr_tensor" + str(kwargs["idx"]) + tensor_node = helper.make_tensor_value_info(output_shape_name, data_type, dims) + + initializer.append( + helper.make_tensor( + name=output_shape_name, + data_type=data_type, + dims=dims, + vals=output_shape_list, + raw=False, + ) + ) + + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + input_node_name = proc_nodes[input_node_idx].name + + not_supported_shape = [-2, -3, -4] + + for val in output_shape_list: + if val in not_supported_shape: + raise AttributeError("Shape value not supported in ONNX", val) + + reshape_node = helper.make_node( + "Reshape", + [input_node_name, output_shape_name], + [name], + name=name + ) + + return [tensor_node, reshape_node] + +@mx_op.register("Cast") +def convert_cast(node, **kwargs): + """Map MXNet's Cast operator attributes to onnx's Cast operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + dtype = node["attrs"]["dtype"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Cast", + [input_node], + [name], + to=dtype, + name=name, + ) + return [node] + + +@mx_op.register("slice_axis") +def convert_slice_axis(node, **kwargs): + """Map MXNet's slice_axis operator attributes to onnx's Slice operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + axes = int(node["attrs"]["axis"]) + starts = int(node["attrs"]["begin"]) + if node["attrs"]["end"] == 'None': + raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") + else: + ends = int(node["attrs"]["end"]) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Slice", + [input_node], + [name], + axes=[axes], + starts=[starts], + ends=[ends], + name=name, + ) + return [node] + + +@mx_op.register("SliceChannel") +def convert_slice_channel(node, **kwargs): + """Map MXNet's SliceChannel operator attributes to onnx's Squeeze or Split + operator based on squeeze_axis attribute + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + num_outputs = int(node.get("attrs", {})["num_outputs"]) + axis = int(node.get("attrs", {}).get("axis", 1)) + squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0)) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + if squeeze_axis == 1 and num_outputs == 1: + node = helper.make_node( + "Squeeze", + [input_node], + [name], + axes=[axis], + name=name, + ) + return [node] + elif squeeze_axis == 0 and num_outputs > 1: + node = helper.make_node( + "Split", + [input_node], + [name], + axis=axis, + split=[num_outputs], + name=name, + ) + return [node] + else: + raise NotImplementedError("SliceChannel operator with num_outputs>1 and" + "squeeze_axis true is not implemented.") + + +@mx_op.register("expand_dims") +def convert_expand_dims(node, **kwargs): + """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + axis = int(node["attrs"]["axis"]) + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Unsqueeze", + [input_node], + [name], + axes=[axis], + name=name, + ) + return [node] + +@mx_op.register("squeeze") +def convert_squeeze(node, **kwargs): + """Map MXNet's squeeze operator attributes to onnx's squeeze operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + if "axis" in node["attrs"]: + axis = convert_string_to_list(node["attrs"]["axis"]) + else: + raise AttributeError("Missing axis attribute: ONNX currently requires axis to " + "be specified for squeeze operator") + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Squeeze", + [input_node], + [name], + axes=axis, + name=name, + ) + return [node] + + +@mx_op.register("log") +def convert_log(node, **kwargs): + """Map MXNet's log operator attributes to onnx's Log operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Log", + [input_node], + [name], + name=name, + ) + return [node] + + +@mx_op.register("reciprocal") +def convert_reciprocal(node, **kwargs): + """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Reciprocal", + [input_node], + [name], + name=name, + ) + return [node] + + +@mx_op.register("_power") +def convert_power(node, **kwargs): + """Map MXNet's _power operator attributes to onnx's Pow operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + node = helper.make_node( + "Pow", + [input_node_a, input_node_b], + [name], + name=None + ) + return [node] + +@mx_op.register("sqrt") +def convert_sqrt(node, **kwargs): + """Map MXNet's sqrt operator attributes to onnx's Sqrt operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = helper.make_node( + "Sqrt", + [input_node], + [name], + name=name, + ) + return [node] diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py new file mode 100644 index 000000000000..0dbfdc1d7b92 --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +#pylint: disable-msg=too-many-arguments + +"""export function""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +import logging +import numpy as np + +from ....base import string_types +from .... import symbol +from .export_onnx import MXNetGraph +from ._export_helper import load_module + + +def export_model(sym, params, input_shape, input_type=np.float32, + onnx_file_path='model.onnx', verbose=False): + """Exports the MXNet model file, passed as a parameter, into ONNX model. + Accepts both symbol,parameter objects as well as json and params filepaths as input. + Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/ONNX + + Parameters + ---------- + sym : str or symbol object + Path to the json file or Symbol object + params : str or symbol object + Path to the params file or params dictionary. (Including both arg_params and aux_params) + input_shape : List of tuple + Input shape of the model e.g [(1,3,224,224)] + input_type : data type + Input data type e.g. np.float32 + onnx_file_path : str + Path where to save the generated onnx file + verbose : Boolean + If true will print logs of the model conversion + + Returns + ------- + onnx_file_path : str + Onnx file path + """ + + try: + from onnx import helper, mapping + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") + + converter = MXNetGraph() + + data_format = np.dtype(input_type) + # if input parameters are strings(file paths), load files and create symbol parameter objects + if isinstance(sym, string_types) and isinstance(params, string_types): + logging.info("Converting json and weight file to sym and params") + sym_obj, params_obj = load_module(sym, params) + onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape, + mapping.NP_TYPE_TO_TENSOR_TYPE[data_format], + verbose=verbose) + elif isinstance(sym, symbol.Symbol) and isinstance(params, dict): + onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape, + mapping.NP_TYPE_TO_TENSOR_TYPE[data_format], + verbose=verbose) + else: + raise ValueError("Input sym and params should either be files or objects") + + # Create the model (ModelProto) + onnx_model = helper.make_model(onnx_graph) + + # Save model on disk + with open(onnx_file_path, "wb") as file_handle: + serialized = onnx_model.SerializeToString() + file_handle.write(serialized) + logging.info("Input shape of the model %s ", input_shape) + logging.info("Exported ONNX file %s saved to disk", onnx_file_path) + + return onnx_file_path diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py new file mode 100644 index 000000000000..11847381ab24 --- /dev/null +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -0,0 +1,347 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Based on +# https://github.com/NVIDIA/mxnet_to_onnx/blob/master/mx2onnx_converter/mx2onnx_converter.py# +# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# coding: utf-8 +# pylint: disable=invalid-name,too-many-locals,no-self-use,too-many-arguments, +# pylint: disable=maybe-no-member,too-many-nested-blocks +"""MXNet to ONNX graph converter functions""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +import logging +import json +import numpy as np + +from .... import context +from .... import ndarray as nd +from .... import io +from .... import module as mod + + +class MXNetGraph(object): + """Class to convert MXNet to ONNX graph""" + registry_ = {} + input_output_maps_ = {} + + def __init__(self): + # topologically sorted nodes + self.nodes = [] + self.input_tensors = [] + self.output_tensors = [] + + @staticmethod + def register(op_name): + """Register operators""" + def wrapper(func): + """Helper function to map functions""" + MXNetGraph.registry_[op_name] = func + return func + + return wrapper + + @staticmethod + def convert_layer(node, **kwargs): + """Convert MXNet layer to ONNX""" + op = str(node["op"]) + if op not in MXNetGraph.registry_: + raise AttributeError("No conversion function registered for op type %s yet." % op) + convert_func = MXNetGraph.registry_[op] + return convert_func(node, **kwargs) + + @staticmethod + def forward_pass(inputs, sym, arg_params, aux_params, output_label): + """Do a forward pass based on the sym and params to get the shape + of the output using dummy data + + Parameters + ---------- + inputs : json string + + sym : :class:`~mxnet.symbol.Symbol` + MXNet symbol object + arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + + Returns + ------- + shape : Shape + Output shape + """ + # if label is not provided, MXNet adds label "softmax_label" by default + # while running load_checkpoint which is not actually a graph input. So ignoring it here + data_names = [graph_input for graph_input in sym.list_inputs() + if graph_input not in arg_params and graph_input not in aux_params + and graph_input != output_label] + + data_shapes = [] + # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs. + for idx, input_name in enumerate(data_names): + data_shapes.append((input_name, inputs[idx].shape)) + + # create module, passing cpu context + ctx = context.cpu() + test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None) + test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None) + + # initializing parameters for calculating result of each individual node + if arg_params is None and aux_params is None: + test_mod.init_params() + else: + test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True) + + data_forward = [] + for idx, input_name in enumerate(data_names): + val = inputs[idx] + data_forward.append(nd.array(val)) + + test_mod.forward(io.DataBatch(data_forward)) + result = test_mod.get_outputs()[0].asnumpy() + + return result.shape + + + @staticmethod + def split_params(sym, params): + """Helper function to split params dictionary into args and aux params + + Parameters + ---------- + sym : :class:`~mxnet.symbol.Symbol` + MXNet symbol object + params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + + Returns + ------- + arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + """ + arg_params = {} + aux_params = {} + for args in sym.list_arguments(): + if args in params: + arg_params.update({args: nd.array(params[args])}) + for aux in sym.list_auxiliary_states(): + if aux in params: + aux_params.update({aux: nd.array(params[aux])}) + return arg_params, aux_params + + + @staticmethod + def infer_output_shape(sym, params, in_shape, output_label): + """Infer output shape by doing a forward pass using dummy inputs """ + # create dummy input + inputs = [np.random.randn(*input_shape) for input_shape in in_shape] + arg, aux = MXNetGraph.split_params(sym, params) + return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label) + + + @staticmethod + def convert_weights_to_numpy(weights_dict): + """Convert weights to numpy""" + return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy()) + for k, v in weights_dict.items()]) + + def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False): + """Convert MXNet graph to ONNX graph + + Parameters + ---------- + sym : :class:`~mxnet.symbol.Symbol` + MXNet symbol object + params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray` + Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format + in_shape : List of tuple + Input shape of the model e.g [(1,3,224,224)] + in_type : data type + Input data type e.g. np.float32 + verbose : Boolean + If true will print logs of the model conversion + + Returns + ------- + graph : GraphProto + ONNX graph + """ + try: + from onnx import (checker, helper, NodeProto, ValueInfoProto, TensorProto) + from onnx.helper import make_tensor_value_info + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") + + # When MXNet model is saved to json file , MXNet adds a node for label. + # The name of this node is, name of the last node + "_label" ( i.e if last node + # name is "Softmax", this node will have a name "Softmax_label". Also, the new node + # will always be second last node in the json graph. + # Deriving the output_label name. + output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label" + + # Determine output shape + output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label) + + weights = MXNetGraph.convert_weights_to_numpy(params) + + mx_graph = json.loads(sym.tojson())["nodes"] + + initializer = [] + all_processed_nodes = [] + onnx_processed_nodes = [] + onnx_processed_inputs = [] + onnx_processed_outputs = [] + index_lookup = [] + + graph_input_idx = 0 + for idx, node in enumerate(mx_graph): + op = node["op"] + name = node["name"] + if verbose: + logging.info("Converting idx: %d, op: %s, name: %s", idx, op, name) + + # A node is an input node if its op_name is "null" and is not + # in params dict + if op == "null" and name not in params: + # Handling graph input + + # Skipping output_label node, as this node is not part of graph + # Refer "output_label" assignment above for more details. + if name == output_label: + continue + converted = MXNetGraph.convert_layer( + node, + is_input=True, + mx_graph=mx_graph, + weights=weights, + in_shape=in_shape[graph_input_idx], + in_type=in_type, + proc_nodes=all_processed_nodes, + initializer=initializer, + index_lookup=index_lookup) + graph_input_idx += 1 + + else: + # Handling graph layers + converted = MXNetGraph.convert_layer( + node, + is_input=False, + mx_graph=mx_graph, + weights=weights, + in_shape=in_shape, + in_type=in_type, + proc_nodes=all_processed_nodes, + initializer=initializer, + index_lookup=index_lookup, + idx=idx + ) + + if isinstance(converted, list): + # Iterate for all converted nodes + for converted_node in converted: + # If converted node is ValueInfoProto, add it in inputs + if isinstance(converted_node, ValueInfoProto): + onnx_processed_inputs.append(converted_node) + # If converted node is NodeProto, add it in processed nodes list + elif isinstance(converted_node, NodeProto): + onnx_processed_nodes.append(converted_node) + if idx == (len(mx_graph) - 1): + # If converted node doesnt have name, use it from output field + if not converted_node.name: + onnx_processed_outputs.append( + make_tensor_value_info( + name=converted_node.output[0], + elem_type=in_type, + shape=output_shape + ) + ) + else: + onnx_processed_outputs.append( + make_tensor_value_info( + name=converted_node.name, + elem_type=in_type, + shape=output_shape + ) + ) + if verbose: + logging.info("Output node is: %s", converted_node.name) + elif isinstance(converted_node, TensorProto): + raise ValueError("Did not expect TensorProto") + else: + raise ValueError("node is of an unrecognized type: %s" % type(node)) + + all_processed_nodes.append(converted_node) + + if idx > 0: + # Handling extra node added to the graph if the MXNet model was + # saved to json file, + # refer "output_label" initialization above for more details. + # if extra node was added then prev_index to the last node is adjusted. + if idx == (len(mx_graph) - 1) and \ + mx_graph[len(mx_graph)-2]["name"] == output_label: + prev_index = index_lookup[idx - 2] + else: + prev_index = index_lookup[idx - 1] + + index_lookup.append(prev_index+len(converted)) + else: + index_lookup.append(len(converted) - 1) + else: + logging.info("Operator converter function should always return a list") + + graph = helper.make_graph( + onnx_processed_nodes, + "mxnet_converted_model", + onnx_processed_inputs, + onnx_processed_outputs + ) + + graph.initializer.extend(initializer) + + checker.check_graph(graph) + return graph diff --git a/python/mxnet/contrib/onnx/_import/__init__.py b/python/mxnet/contrib/onnx/onnx2mx/__init__.py similarity index 100% rename from python/mxnet/contrib/onnx/_import/__init__.py rename to python/mxnet/contrib/onnx/onnx2mx/__init__.py diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py similarity index 74% rename from python/mxnet/contrib/onnx/_import/import_helper.py rename to python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index 3dfff3ed6818..c19f0f2cb246 100644 --- a/python/mxnet/contrib/onnx/_import/import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -15,27 +15,27 @@ # specific language governing permissions and limitations # under the License. -# coding: utf-8 +# coding: utf-8_ # pylint: disable=invalid-name """Operator attributes conversion""" -from .op_translations import identity, random_uniform, random_normal -from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n -from .op_translations import tanh -from .op_translations import ceil, floor -from .op_translations import concat -from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected -from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm -from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm -from .op_translations import dropout, local_response_norm, conv, deconv -from .op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten -from .op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze -from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum -from .op_translations import reduce_prod, avg_pooling, max_pooling -from .op_translations import argmax, argmin, maximum, minimum -from .op_translations import clip, reduce_log_sum, reduce_log_sum_exp -from .op_translations import reduce_sum_square, reduce_l2, max_roi_pooling, instance_norm -from .op_translations import log_softmax, softsign, lesser, greater, equal -from .op_translations import logical_and, logical_or, logical_xor, logical_not +from ._op_translations import identity, random_uniform, random_normal +from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n +from ._op_translations import tanh +from ._op_translations import ceil, floor +from ._op_translations import concat +from ._op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected +from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm +from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm +from ._op_translations import dropout, local_response_norm, conv, deconv +from ._op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten +from ._op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze +from ._op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum +from ._op_translations import reduce_prod, avg_pooling, max_pooling +from ._op_translations import argmax, argmin, maximum, minimum +from ._op_translations import clip, reduce_log_sum, reduce_log_sum_exp +from ._op_translations import reduce_sum_square, reduce_l2, max_roi_pooling, instance_norm +from ._op_translations import log_softmax, softsign, lesser, greater, equal +from ._op_translations import logical_and, logical_or, logical_xor, logical_not # convert_map defines maps of ONNX operator names to converter functor(callable) # defined in the op_translations module. @@ -89,6 +89,7 @@ 'Squeeze' : squeeze, 'Unsqueeze' : unsqueeze, 'Flatten' : flatten, + 'Identity' : identity, #Powers 'Reciprocal' : reciprocal, 'Sqrt' : squareroot, diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py similarity index 99% rename from python/mxnet/contrib/onnx/_import/op_translations.py rename to python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 0fad0080bef0..2b98aa08febf 100644 --- a/python/mxnet/contrib/onnx/_import/op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -19,7 +19,7 @@ """ Module for translating ONNX operators into Mxnet operatoes""" # pylint: disable=unused-argument,protected-access import numpy as np -from . import translation_utils +from . import _translation_utils as translation_utils from .... import symbol # Method definitions for the callable objects mapped in the import_helper module @@ -130,7 +130,7 @@ def maximum(attrs, inputs, proto_obj): for op_input in inputs[2:]: mxnet_op = symbol.maximum(mxnet_op, op_input) else: - mxnet_op = inputs[0] + mxnet_op = symbol.maximum(inputs[0], inputs[0]) return mxnet_op, attrs, inputs def minimum(attrs, inputs, proto_obj): @@ -143,7 +143,7 @@ def minimum(attrs, inputs, proto_obj): for op_input in inputs[2:]: mxnet_op = symbol.minimum(mxnet_op, op_input) else: - mxnet_op = inputs[0] + mxnet_op = symbol.minimum(inputs[0], inputs[0]) return mxnet_op, attrs, inputs def lesser(attrs, inputs, proto_obj): diff --git a/python/mxnet/contrib/onnx/_import/translation_utils.py b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py similarity index 99% rename from python/mxnet/contrib/onnx/_import/translation_utils.py rename to python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py index fe25a94baa7d..f63c1e9e8e62 100644 --- a/python/mxnet/contrib/onnx/_import/translation_utils.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py @@ -168,6 +168,7 @@ def _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj): op_sym = op_name return op_sym + def _fix_channels(op_name, attrs, inputs, proto_obj): """A workaround for getting 'channels' or 'units' since onnx don't provide these attributes. We check the shape of weights provided to get the number. diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/onnx2mx/import_model.py similarity index 100% rename from python/mxnet/contrib/onnx/_import/import_model.py rename to python/mxnet/contrib/onnx/onnx2mx/import_model.py diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py similarity index 99% rename from python/mxnet/contrib/onnx/_import/import_onnx.py rename to python/mxnet/contrib/onnx/onnx2mx/import_onnx.py index d81ec96537f3..4e851712972f 100644 --- a/python/mxnet/contrib/onnx/_import/import_onnx.py +++ b/python/mxnet/contrib/onnx/onnx2mx/import_onnx.py @@ -23,7 +23,7 @@ from .... import cpu, gpu from .... import ndarray as nd from ....base import string_types -from .import_helper import _convert_map as convert_map +from ._import_helper import _convert_map as convert_map class GraphProto(object): # pylint: disable=too-few-public-methods """A helper class for handling mxnet symbol copying from pb2.GraphProto. diff --git a/python/mxnet/contrib/onnx/_import/import_to_gluon.py b/python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py similarity index 100% rename from python/mxnet/contrib/onnx/_import/import_to_gluon.py rename to python/mxnet/contrib/onnx/onnx2mx/import_to_gluon.py diff --git a/tests/python-pytest/onnx/export/backend.py b/tests/python-pytest/onnx/export/backend.py new file mode 100644 index 000000000000..e23cc01494e9 --- /dev/null +++ b/tests/python-pytest/onnx/export/backend.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""backend wrapper for onnx test infrastructure""" +import numpy as np +from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto +from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph +try: + from onnx import helper, TensorProto, mapping + from onnx.backend.base import Backend +except ImportError: + raise ImportError("Onnx and protobuf need to be installed") +from backend_rep import MXNetBackendRep + +# Using these functions for onnx test infrastructure. +# Implemented by following onnx docs guide: +# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md +# MXNetBackend class will take an ONNX model with inputs, perform a computation, +# and then return the output. + +class MXNetBackend(Backend): + """MXNet backend for ONNX""" + + @staticmethod + def perform_import_export(graph_proto, input_shape): + """ Import ONNX model to mxnet model and then export to ONNX model + and then import it back to mxnet for verifying the result""" + graph = GraphProto() + + sym, arg_params, aux_params = graph.from_onnx(graph_proto) + + params = {} + params.update(arg_params) + params.update(aux_params) + # exporting to onnx graph proto format + converter = MXNetGraph() + graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape, in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]) + + # importing back to MXNET for verifying result. + sym, arg_params, aux_params = graph.from_onnx(graph_proto) + + return sym, arg_params, aux_params + + + @classmethod + def prepare(cls, model, device='CPU', **kwargs): + """For running end to end model(used for onnx test backend) + + Parameters + ---------- + model : onnx ModelProto object + loaded onnx graph + device : 'CPU' + specifying device to run test on + kwargs : + other arguments + + Returns + ------- + MXNetBackendRep : object + Returns object of MXNetBackendRep class which will be in turn + used to run inference on the input model and return the result for comparison. + """ + + graph = GraphProto() + metadata = graph.get_graph_metadata(model.graph) + input_data = metadata['input_tensor_data'] + input_shape = [data[1] for data in input_data] + sym, arg_params, aux_params = MXNetBackend.perform_import_export(model.graph, input_shape) + return MXNetBackendRep(sym, arg_params, aux_params, device) + + @classmethod + def supports_device(cls, device): + """Supports only CPU for testing""" + return device == 'CPU' + + +prepare = MXNetBackend.prepare + +run_node = MXNetBackend.run_node + +supports_device = MXNetBackend.supports_device diff --git a/tests/python-pytest/onnx/export/backend_rep.py b/tests/python-pytest/onnx/export/backend_rep.py new file mode 100644 index 000000000000..8729eafea1a1 --- /dev/null +++ b/tests/python-pytest/onnx/export/backend_rep.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""backend rep for onnx test infrastructure""" +try: + from onnx.backend.base import BackendRep +except ImportError: + raise ImportError("Onnx and protobuf need to be installed") +import mxnet as mx + +# Using these functions for onnx test infrastructure. +# Implemented by following onnx docs guide: +# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md +# MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to +# execute a model repeatedly. +# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and +# retrieve the corresponding results for comparison to the onnx backend. +# https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py. + +class MXNetBackendRep(BackendRep): + """Running model inference on mxnet engine and return the result + to onnx test infrastructure for comparison.""" + def __init__(self, symbol, arg_params, aux_params, device): + self.symbol = symbol + self.arg_params = arg_params + self.aux_params = aux_params + self.device = device + + def run(self, inputs, **kwargs): + """Run model inference and return the result + + Parameters + ---------- + inputs : numpy array + input to run a layer on + + Returns + ------- + params : numpy array + result obtained after running the inference on mxnet + """ + data_forward = [] + for val in inputs: + data_forward.append(mx.nd.array(val)) + # create module, passing cpu context + if self.device == 'CPU': + ctx = mx.cpu() + else: + raise NotImplementedError("ONNX tests are run only for CPU context.") + + # To fetch the data names of the input to the model we list the inputs of the symbol graph + # and exclude the argument and auxiliary parameters from the list + data_names = [graph_input for graph_input in self.symbol.list_inputs() + if graph_input not in self.arg_params and graph_input not in self.aux_params] + + data_shapes = [] + for idx, input_name in enumerate(data_names): + data_shapes.append((input_name, inputs[idx].shape)) + + mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx, + label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes, + label_shapes=None) + mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params) + + # run inference + mod.forward(mx.io.DataBatch(data_forward)) + result = mod.get_outputs()[0].asnumpy() + return [result] diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py new file mode 100644 index 000000000000..7e1df07cbaa1 --- /dev/null +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Tests for individual operators +This module contains operator tests which currently do not exist on +ONNX backend test framework. Once we have PRs on the ONNX repo and get +those PRs merged, this file will get EOL'ed. +""" +# pylint: disable=too-many-locals,wrong-import-position,import-error +from __future__ import absolute_import +import sys +import os +import logging +import tarfile +from collections import namedtuple +import numpy as np +import numpy.testing as npt +from onnx import numpy_helper +from onnx import TensorProto +from mxnet.test_utils import download +from mxnet.contrib import onnx as onnx_mxnet +import mxnet as mx +CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(CURR_PATH, '../../python/unittest')) +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +URLS = { + 'bvlc_googlenet': + 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_googlenet.tar.gz', + 'bvlc_reference_caffenet': + 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_caffenet.tar.gz', + 'bvlc_reference_rcnn_ilsvrc13': + 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/bvlc_reference_rcnn_ilsvrc13.tar.gz', + 'inception_v1': + 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v1.tar.gz', + 'inception_v2': + 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz' +} + +def get_test_files(name): + """Extract tar file and returns model path and input, output data""" + tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__()) + # extract tar file + tar_path = os.path.join(CURR_PATH, tar_name) + tar = tarfile.open(tar_path.__str__(), "r:*") + tar.extractall(path=CURR_PATH.__str__()) + tar.close() + data_dir = os.path.join(CURR_PATH, name) + model_path = os.path.join(data_dir, 'model.onnx') + + inputs = [] + outputs = [] + # get test files + for test_file in os.listdir(data_dir): + case_dir = os.path.join(data_dir, test_file) + # skip the non-dir files + if not os.path.isdir(case_dir): + continue + input_file = os.path.join(case_dir, 'input_0.pb') + input_tensor = TensorProto() + with open(input_file, 'rb') as proto_file: + input_tensor.ParseFromString(proto_file.read()) + inputs.append(numpy_helper.to_array(input_tensor)) + + output_tensor = TensorProto() + output_file = os.path.join(case_dir, 'output_0.pb') + with open(output_file, 'rb') as proto_file: + output_tensor.ParseFromString(proto_file.read()) + outputs.append(numpy_helper.to_array(output_tensor)) + + return model_path, inputs, outputs + + +def forward_pass(sym, arg, aux, data_names, input_data): + """ Perform forward pass on given data""" + # create module + mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) + mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) + mod.set_params(arg_params=arg, aux_params=aux, + allow_missing=True, allow_extra=True) + # run inference + batch = namedtuple('Batch', ['data']) + mod.forward(batch([mx.nd.array(input_data)]), is_train=False) + + return mod.get_outputs()[0].asnumpy() + + +def test_models(model_name, input_shape, output_shape): + """ Tests Googlenet model for both onnx import and export""" + model_path, inputs, outputs = get_test_files(model_name) + logging.info("Translating model from ONNX model zoo to Mxnet") + sym, arg_params, aux_params = onnx_mxnet.import_model(model_path) + params = {} + params.update(arg_params) + params.update(aux_params) + + dir_path = os.path.dirname(model_path) + new_model_name = "exported_" + model_name + ".onnx" + onnx_file = os.path.join(dir_path, new_model_name) + + logging.info("Translating converted model from mxnet to ONNX") + converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file) + + sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path) + + metadata = onnx_mxnet.get_model_metadata(converted_model_path) + assert len(metadata) == 2 + assert metadata.get('input_tensor_data') + assert metadata.get('input_tensor_data')[0][1] == input_shape + assert metadata.get('output_tensor_data') + assert metadata.get('output_tensor_data')[0][1] == output_shape + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] + + logging.info("Running inference on onnx re-import model in mxnet") + # run test for each test file + for input_data, output_data in zip(inputs, outputs): + result = forward_pass(sym, arg_params, aux_params, data_names, input_data) + + # verify the results + npt.assert_equal(result.shape, output_data.shape) + npt.assert_almost_equal(output_data, result, decimal=3) + logging.info(model_name + " conversion successful") + + +def test_model_accuracy(model_name, input_shape): + """ Imports ONNX model, runs inference, exports and imports back + run inference, compare result with the previous inference result""" + model_path, inputs, outputs = get_test_files(model_name) + logging.info("Translating model from ONNX model zoo to Mxnet") + sym, arg_params, aux_params = onnx_mxnet.import_model(model_path) + + metadata = onnx_mxnet.get_model_metadata(model_path) + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] + + expected_result= [] + for input_data, output_data in zip(inputs, outputs): + result = forward_pass(sym, arg_params, aux_params, data_names, input_data) + expected_result.append(result) + + params = {} + params.update(arg_params) + params.update(aux_params) + + dir_path = os.path.dirname(model_path) + new_model_name = "exported_" + model_name + ".onnx" + onnx_file = os.path.join(dir_path, new_model_name) + + logging.info("Translating converted model from mxnet to ONNX") + converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, + onnx_file) + + sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model_path) + + metadata = onnx_mxnet.get_model_metadata(converted_model_path) + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] + + actual_result = [] + for input_data, output_data in zip(inputs, outputs): + result = forward_pass(sym, arg_params, aux_params, data_names, input_data) + actual_result.append(result) + + # verify the results + for expected, actual in zip(expected_result, actual_result): + npt.assert_equal(expected.shape, actual.shape) + npt.assert_almost_equal(expected, actual, decimal=3) + + +if __name__ == '__main__': + test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000)) + test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000)) + test_models("bvlc_reference_rcnn_ilsvrc13", (1, 3, 224, 224), (1, 200)) + + # Comparing MXNet inference result, since MXNet results don't match + # ONNX expected results due to AveragePool issue github issue(#10194) + test_model_accuracy("inception_v1", (1, 3, 224, 224)) + test_model_accuracy("inception_v2", (1, 3, 224, 224)) diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py new file mode 100644 index 000000000000..803d290b9c69 --- /dev/null +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""ONNX test backend wrapper""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +try: + import onnx.backend.test +except ImportError: + raise ImportError("Onnx and protobuf need to be installed") + +import backend as mxnet_backend + +# This is a pytest magic variable to load extra plugins +pytest_plugins = "onnx.backend.test.report", + +BACKEND_TESTS = onnx.backend.test.BackendTest(mxnet_backend, __name__) + +IMPLEMENTED_OPERATORS_TEST = [ + 'test_random_uniform', + 'test_random_normal', + 'test_add', + 'test_sub', + 'test_mul', + 'test_div', + 'test_neg', + 'test_abs', + 'test_sum', + 'test_tanh', + 'test_ceil', + 'test_floor', + 'test_concat', + 'test_identity', + 'test_sigmoid', + 'test_relu', + 'test_constant_pad', + 'test_edge_pad', + 'test_reflect_pad', + 'test_reduce_min', + 'test_reduce_max', + 'test_reduce_mean', + 'test_reduce_prod', + 'test_squeeze', + 'test_softmax_example', + 'test_softmax_large_number', + 'test_softmax_axis_2', + 'test_transpose', + 'test_globalmaxpool', + 'test_globalaveragepool', + 'test_slice_cpu', + 'test_slice_neg', + 'test_squeeze_', + 'test_reciprocal', + 'test_sqrt', + 'test_pow', + 'test_exp', + 'test_argmax', + 'test_argmin', + 'test_min', + 'test_max' + #pytorch operator tests + 'test_operator_exp', + 'test_operator_maxpool', + 'test_operator_params', + 'test_operator_permute2' + ] + +BASIC_MODEL_TESTS = [ + 'test_AvgPool2D', + 'test_BatchNorm', + 'test_ConstantPad2d', + 'test_Conv2d', + 'test_ELU', + 'test_LeakyReLU', + 'test_MaxPool', + 'test_PReLU', + 'test_ReLU', + 'test_Sigmoid', + 'test_Softmax', + 'test_softmax_functional', + 'test_softmax_lastdim', + 'test_Tanh' + ] + +STANDARD_MODEL = [ + 'test_bvlc_alexnet', + 'test_densenet121', + # 'test_inception_v1', + # 'test_inception_v2', + 'test_resnet50', + # 'test_shufflenet', + 'test_squeezenet', + 'test_vgg16', + 'test_vgg19' + ] + +for op_test in IMPLEMENTED_OPERATORS_TEST: + BACKEND_TESTS.include(op_test) + +for basic_model_test in BASIC_MODEL_TESTS: + BACKEND_TESTS.include(basic_model_test) + +for std_model_test in STANDARD_MODEL: + BACKEND_TESTS.include(std_model_test) + +BACKEND_TESTS.exclude('.*broadcast.*') +BACKEND_TESTS.exclude('.*bcast.*') + + +# import all test cases at global scope to make them visible to python.unittest +globals().update(BACKEND_TESTS.enable_report().test_cases) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/python-pytest/onnx/import/gluon_backend.py b/tests/python-pytest/onnx/import/gluon_backend.py index d2946f7bb541..302fd4dcf08f 100644 --- a/tests/python-pytest/onnx/import/gluon_backend.py +++ b/tests/python-pytest/onnx/import/gluon_backend.py @@ -17,10 +17,8 @@ # coding: utf-8 """Gluon backend wrapper for onnx test infrastructure""" -import mxnet as mx -from mxnet import nd -from mxnet.contrib.onnx._import.import_onnx import GraphProto -import numpy as np +from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto + try: from onnx import helper, TensorProto from onnx.backend.base import Backend diff --git a/tests/python-pytest/onnx/import/mxnet_backend.py b/tests/python-pytest/onnx/import/mxnet_backend.py index bbe8899dee15..10f89ecbbbc7 100644 --- a/tests/python-pytest/onnx/import/mxnet_backend.py +++ b/tests/python-pytest/onnx/import/mxnet_backend.py @@ -17,8 +17,7 @@ # coding: utf-8 """MXNet backend wrapper for onnx test infrastructure""" -import mxnet as mx -from mxnet.contrib.onnx._import.import_onnx import GraphProto +from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto try: from onnx import helper, TensorProto from onnx.backend.base import Backend diff --git a/tests/python-pytest/onnx/import/mxnet_backend_rep.py b/tests/python-pytest/onnx/import/mxnet_backend_rep.py index 5ce29f54150a..067ef1568309 100644 --- a/tests/python-pytest/onnx/import/mxnet_backend_rep.py +++ b/tests/python-pytest/onnx/import/mxnet_backend_rep.py @@ -17,7 +17,6 @@ # coding: utf-8 """MXNet backend rep for onnx test infrastructure""" -import numpy as np try: from onnx.backend.base import BackendRep except ImportError: diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index 1a4d8c4fe37b..f7addbb29b32 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -31,6 +31,7 @@ 'test_ceil', 'test_floor', 'test_concat', + 'test_identity', 'test_sigmoid', 'test_relu', 'test_constant_pad',