This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-34] Onnx Module to import onnx models into mxnet (#9963)
* Onnx Module to import onnx models into mxnet * Change package name to onnx from serde. * Remove onnx install time dependency * Remove Renamer class * Add apache license to files. * Refactor test files to tests/python folder. * Removed export folder. * Refactor Attribute COnverter logic Signed-off-by: Acharya <[email protected]> * Changing the translation and utils file. Signed-off-by: Acharya <[email protected]> * - Fixed Pylint issues - Added Sigmoid operator. - Add onnx, protobuf as CI pipeline dependencies. Signed-off-by: Acharya <[email protected]> * Add UTs for reduce ops. Signed-off-by: Acharya <[email protected]> * pylint - newline, whitespace. Signed-off-by: Acharya <[email protected]> * Added operators: - AvgPool - ArgMax - ArgMin - Abs Minor changes in logic for import_onnx * Added operators: - Ceil - Cast - Constant * - Added Pad operator support. - Minor changes for comments * RandomUniform,Normal,Sub,Mul,Div,Tanh,Relu,Reciprocal,Sqrt operators added. Signed-off-by: Acharya <[email protected]> * lint fix * Add protobuf-compile to CI bash script. Add MatMul and Pow operator. Signed-off-by: Acharya <[email protected]> * Max,Min,Sum,Reduce operators. Signed-off-by: Acharya <[email protected]> * BatchNorm,SpatialBN, Split Signed-off-by: Acharya <[email protected]> * Slice,Transpose and Squeeze Operators. Signed-off-by: Acharya <[email protected]> * Onnx tests in CI integration tests. Signed-off-by: Acharya <[email protected]> * Addressing Marco's comments Signed-off-by: Acharya <[email protected]> * Floor, LeakyRelu, Elu, PRelu, Softmax, Exp, Log operator. * Added operators: - Convolution - Deconvolution Refactored convert_operator * lint fix * Rebase fix * Added Maxpool operator * Adding FullyConnected operator * Adding operator- GlobalPooling - max and avg Minor lint fixes. * Adding operator - Gemm * Change test Path, LRN and Dropout operator. * Add asserts for the super_res example. Signed-off-by: Acharya <[email protected]> * Fixing conv test failures. Removed redundant code * Update Jenkins job. Signed-off-by: Acharya <[email protected]> * Nits: Removing commented out code * Rebase after Docker PR * Fetch test files by version number. Verify the high resolution example. Signed-off-by: Acharya <[email protected]> * Fix method arguments for Python3.5+ Signed-off-by: Acharya <[email protected]> * Remove logging configuration from test files. Signed-off-by: Acharya <[email protected]> * Verify result image in example by hash Signed-off-by: Acharya <[email protected]> * Remove fetching test files by ETag. Will add it as a separate PR as per review comments. Signed-off-by: Acharya <[email protected]>
- Loading branch information
1 parent
f202f10
commit 93e0ceb
Showing
17 changed files
with
1,585 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
###################################################################### | ||
# This script installs ONNX for Python along with all required dependencies | ||
# on a Ubuntu Machine. | ||
# Tested on Ubuntu 16.04 distro. | ||
###################################################################### | ||
|
||
set -e | ||
set -x | ||
|
||
echo "Installing libprotobuf-dev and protobuf-compiler ..." | ||
apt-get install -y libprotobuf-dev protobuf-compiler | ||
|
||
echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..." | ||
pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5 | ||
pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Testing super_resolution model conversion""" | ||
from __future__ import absolute_import as _abs | ||
from __future__ import print_function | ||
from collections import namedtuple | ||
import logging | ||
import numpy as np | ||
from PIL import Image | ||
import mxnet as mx | ||
from mxnet.test_utils import download | ||
import mxnet.contrib.onnx as onnx_mxnet | ||
|
||
# set up logger | ||
logging.basicConfig() | ||
LOGGER = logging.getLogger() | ||
LOGGER.setLevel(logging.INFO) | ||
|
||
def import_onnx(): | ||
"""Import the onnx model into mxnet""" | ||
model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx' | ||
download(model_url, 'super_resolution.onnx') | ||
|
||
LOGGER.info("Converting onnx format to mxnet's symbol and params...") | ||
sym, params = onnx_mxnet.import_model('super_resolution.onnx') | ||
LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...") | ||
return sym, params | ||
|
||
def get_test_image(): | ||
"""Download and process the test image""" | ||
# Load test image | ||
input_image_dim = 224 | ||
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg' | ||
download(img_url, 'super_res_input.jpg') | ||
img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim)) | ||
img_ycbcr = img.convert("YCbCr") | ||
img_y, img_cb, img_cr = img_ycbcr.split() | ||
input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :] | ||
return input_image, img_cb, img_cr | ||
|
||
def perform_inference(sym, params, input_img, img_cb, img_cr): | ||
"""Perform inference on image using mxnet""" | ||
# create module | ||
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None) | ||
mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)]) | ||
mod.set_params(arg_params=params, aux_params=None) | ||
|
||
# run inference | ||
batch = namedtuple('Batch', ['data']) | ||
mod.forward(batch([mx.nd.array(input_img)])) | ||
|
||
# Save the result | ||
img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0]. | ||
asnumpy().clip(0, 255)), mode='L') | ||
|
||
result_img = Image.merge( | ||
"YCbCr", [img_out_y, | ||
img_cb.resize(img_out_y.size, Image.BICUBIC), | ||
img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB") | ||
output_img_dim = 672 | ||
assert result_img.size == (output_img_dim, output_img_dim) | ||
LOGGER.info("Super Resolution example success.") | ||
result_img.save("super_res_output.jpg") | ||
return result_img | ||
|
||
if __name__ == '__main__': | ||
MX_SYM, MX_PARAM = import_onnx() | ||
INPUT_IMG, IMG_CB, IMG_CR = get_test_image() | ||
perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,5 +28,5 @@ | |
from . import tensorboard | ||
|
||
from . import text | ||
|
||
from . import onnx | ||
from . import io |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Module for importing and exporting ONNX models.""" | ||
|
||
from ._import.import_model import import_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
"""ONNX Import module""" | ||
from . import import_model | ||
from . import import_onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
# pylint: disable=invalid-name | ||
"""Operator attributes conversion""" | ||
from .op_translations import identity, random_uniform, random_normal | ||
from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n | ||
from .op_translations import tanh | ||
from .op_translations import ceil, floor | ||
from .op_translations import concat | ||
from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected | ||
from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm | ||
from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm | ||
from .op_translations import dropout, local_response_norm, conv, deconv | ||
from .op_translations import reshape, cast, split, _slice, transpose, squeeze | ||
from .op_translations import reciprocal, squareroot, power, exponent, _log | ||
from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum | ||
from .op_translations import reduce_prod, avg_pooling, max_pooling | ||
from .op_translations import argmax, argmin, maximum, minimum | ||
|
||
# convert_map defines maps of ONNX operator names to converter functor(callable) | ||
# defined in the op_translations module. | ||
_convert_map = { | ||
# Generator Functions | ||
'Constant' : identity, | ||
'RandomUniform' : random_uniform, | ||
'RandomNormal' : random_normal, | ||
'RandomUniformLike' : random_uniform, | ||
'RandomNormalLike' : random_normal, | ||
# Arithmetic Operators | ||
'Add' : add, | ||
'Sub' : subtract, | ||
'Mul' : multiply, | ||
'Div' : divide, | ||
'Abs' : absolute, | ||
'Neg' : negative, | ||
'Sum' : add_n, #elemwise sum | ||
#Hyperbolic functions | ||
'Tanh' : tanh, | ||
# Rounding | ||
'Ceil' : ceil, | ||
'Floor' : floor, | ||
# Joining and spliting | ||
'Concat' : concat, | ||
# Basic neural network functions | ||
'Sigmoid' : sigmoid, | ||
'Relu' : relu, | ||
'Pad' : pad, | ||
'MatMul' : matrix_multiplication, #linalg_gemm2 | ||
'Conv' : conv, | ||
'ConvTranspose' : deconv, | ||
'BatchNormalization': batch_norm, | ||
'SpatialBN' : batch_norm, | ||
'LeakyRelu' : leaky_relu, | ||
'Elu' : _elu, | ||
'PRelu' : _prelu, | ||
'Softmax' : softmax, | ||
'FC' : fully_connected, | ||
'GlobalAveragePool' : global_avgpooling, | ||
'GlobalMaxPool' : global_maxpooling, | ||
'Gemm' : linalg_gemm, | ||
'LRN' : local_response_norm, | ||
'Dropout' : dropout, | ||
# Changing shape and type. | ||
'Reshape' : reshape, | ||
'Cast' : cast, | ||
'Split' : split, | ||
'Slice' : _slice, | ||
'Transpose' : transpose, | ||
'Squeeze' : squeeze, | ||
#Powers | ||
'Reciprocal' : reciprocal, | ||
'Sqrt' : squareroot, | ||
'Pow' : power, | ||
'Exp' : exponent, | ||
'Log' : _log, | ||
# Reduce Functions | ||
'ReduceMax' : reduce_max, | ||
'ReduceMean' : reduce_mean, | ||
'ReduceMin' : reduce_min, | ||
'ReduceSum' : reduce_sum, | ||
'ReduceProd' : reduce_prod, | ||
'AveragePool' : avg_pooling, | ||
'MaxPool' : max_pooling, | ||
# Sorting and Searching | ||
'ArgMax' : argmax, | ||
'ArgMin' : argmin, | ||
'Max' : maximum, #elemwise maximum | ||
'Min' : minimum #elemwise minimum | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
"""import function""" | ||
# pylint: disable=no-member | ||
|
||
from .import_onnx import GraphProto | ||
|
||
def import_model(model_file): | ||
"""Imports the ONNX model file passed as a parameter into MXNet symbol and parameters. | ||
Parameters | ||
---------- | ||
model_file : str | ||
ONNX model file name | ||
Returns | ||
------- | ||
Mxnet symbol and parameter objects. | ||
sym : mxnet.symbol | ||
Mxnet symbol | ||
params : dict of str to mx.ndarray | ||
Dict of converted parameters stored in mxnet.ndarray format | ||
""" | ||
graph = GraphProto() | ||
|
||
# loads model file and returns ONNX protobuf object | ||
try: | ||
import onnx | ||
except ImportError: | ||
raise ImportError("Onnx and protobuf need to be installed") | ||
model_proto = onnx.load(model_file) | ||
sym, params = graph.from_onnx(model_proto.graph) | ||
return sym, params |
Oops, something went wrong.