Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-34] Onnx Module to import onnx models into mxnet #9963

Merged
merged 36 commits into from
Mar 14, 2018
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d80f95a
Onnx Module to import onnx models into mxnet
Feb 26, 2018
23c409a
Changing the translation and utils file.
Mar 5, 2018
2d0c4e4
- Fixed Pylint issues
Mar 6, 2018
a350505
Add UTs for reduce ops.
Mar 6, 2018
678cd7a
pylint - newline, whitespace.
Mar 6, 2018
0c66f32
Added operators:
rajanksin Mar 6, 2018
5f91c6c
Added operators:
rajanksin Mar 6, 2018
9a615b2
- Added Pad operator support.
rajanksin Mar 6, 2018
0be8eae
RandomUniform,Normal,Sub,Mul,Div,Tanh,Relu,Reciprocal,Sqrt operators
Mar 7, 2018
5936755
lint fix
rajanksin Mar 7, 2018
96d967f
Add protobuf-compile to CI bash script. Add MatMul and Pow operator.
Mar 7, 2018
df61e4f
Max,Min,Sum,Reduce operators.
Mar 7, 2018
8e74bd8
BatchNorm,SpatialBN, Split
Mar 7, 2018
40d9d13
Slice,Transpose and Squeeze Operators.
Mar 8, 2018
a1f3782
Onnx tests in CI integration tests.
Mar 8, 2018
65f627b
Addressing Marco's comments
Mar 9, 2018
495169e
Floor, LeakyRelu, Elu, PRelu, Softmax, Exp, Log operator.
Mar 9, 2018
48b2a7c
Added operators:
rajanksin Mar 9, 2018
51189a8
lint fix
rajanksin Mar 9, 2018
69bf6f8
Rebase fix
rajanksin Mar 9, 2018
ee0393a
Added Maxpool operator
rajanksin Mar 9, 2018
0ebd5b0
Adding FullyConnected operator
rajanksin Mar 9, 2018
40cbe11
Adding operator- GlobalPooling - max and avg
rajanksin Mar 9, 2018
71517b5
Adding operator - Gemm
rajanksin Mar 9, 2018
3dd7a2e
Change test Path, LRN and Dropout operator.
Mar 9, 2018
11be77b
Add asserts for the super_res example.
Mar 9, 2018
d98de71
Fixing conv test failures.
rajanksin Mar 9, 2018
4ef7c36
Update Jenkins job.
Mar 10, 2018
e96b41b
Nits: Removing commented out code
rajanksin Mar 10, 2018
f979d6c
Rebase after Docker PR
Mar 10, 2018
c6c8038
Merge branch 'onnx1' of
Mar 10, 2018
1d02490
Fetch test files by version number. Verify the high resolution example.
Mar 12, 2018
b4b6f9a
Fix method arguments for Python3.5+
Mar 13, 2018
612e2a6
Remove logging configuration from test files.
Mar 13, 2018
a9b2f62
Verify result image in example by hash
Mar 13, 2018
b24baba
Remove fetching test files by ETag.
Mar 14, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,18 @@ try {
}

stage('Integration Test') {
parallel 'Python GPU': {
parallel 'Onnx CPU': {
node('mxnetlinux-cpu') {
ws('workspace/it-onnx-cpu') {
init_git()
unpack_lib('cpu')
timeout(time: max_time, unit: 'MINUTES') {
sh "ci/build.py --build --platform ubuntu_cpu /work/runtime_functions.sh integrationtest_ubuntu_cpu_onnx"
}
}
}
},
'Python GPU': {
node('mxnetlinux-gpu') {
ws('workspace/it-python-gpu') {
init_git()
Expand Down
2 changes: 2 additions & 0 deletions ci/docker/Dockerfile.build.ubuntu_cpu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ COPY install/ubuntu_mklml.sh /work/
RUN /work/ubuntu_mklml.sh
COPY install/ubuntu_caffe.sh /work/
RUN /work/ubuntu_caffe.sh
COPY install/ubuntu_onnx.sh /work/
RUN /work/ubuntu_onnx.sh
COPY install/ubuntu_docs.sh /work/
RUN /work/ubuntu_docs.sh
COPY install/ubuntu_adduser.sh /work/
Expand Down
34 changes: 34 additions & 0 deletions ci/docker/install/ubuntu_onnx.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

######################################################################
# This script installs ONNX for Python along with all required dependencies
# on a Ubuntu Machine.
# Tested on Ubuntu 16.04 distro.
######################################################################

set -e
set -x

echo "Installing libprotobuf-dev and protobuf-compiler ..."
apt-get install -y libprotobuf-dev protobuf-compiler

echo "Installing pytest, pytest-cov, protobuf, Pillow, ONNX and tabulate ..."
pip2 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5
pip3 install pytest==3.4.0 pytest-cov==2.5.1 protobuf==3.0.0 onnx==1.0.1 Pillow==5.0.0 tabulate==0.7.5
8 changes: 8 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,14 @@ unittest_centos7_gpu() {
python3.6 -m "nose" --with-timer --verbose tests/python/gpu
}

integrationtest_ubuntu_cpu_onnx() {
set -ex
export PYTHONPATH=./python/
python example/onnx/super_resolution.py
pytest tests/python-pytest/onnx/onnx_backend_test.py
pytest tests/python-pytest/onnx/onnx_test.py
}

integrationtest_ubuntu_gpu_python() {
set -ex
export PYTHONPATH=./python/
Expand Down
84 changes: 84 additions & 0 deletions example/onnx/super_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Testing super_resolution model conversion"""
from __future__ import absolute_import as _abs
from __future__ import print_function
from collections import namedtuple
import logging
import numpy as np
from PIL import Image
import mxnet as mx
from mxnet.test_utils import download
import mxnet.contrib.onnx as onnx_mxnet

# set up logger
logging.basicConfig()
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

def import_onnx():
"""Import the onnx model into mxnet"""
model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
download(model_url, 'super_resolution.onnx', version_tag='"7348c879d16c42bc77e24e270f663524"')

LOGGER.info("Converting onnx format to mxnet's symbol and params...")
sym, params = onnx_mxnet.import_model('super_resolution.onnx')
LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...")
return sym, params

def get_test_image():
"""Download and process the test image"""
# Load test image
input_image_dim = 224
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
download(img_url, 'super_res_input.jpg', version_tag='"02c90a7248e51316b11f7f39dd1b226d"')
img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
img_ycbcr = img.convert("YCbCr")
img_y, img_cb, img_cr = img_ycbcr.split()
input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
return input_image, img_cb, img_cr

def perform_inference(sym, params, input_img, img_cb, img_cr):
"""Perform inference on image using mxnet"""
# create module
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', input_img.shape)])
mod.set_params(arg_params=params, aux_params=None)

# run inference
batch = namedtuple('Batch', ['data'])
mod.forward(batch([mx.nd.array(input_img)]))

# Save the result
img_out_y = Image.fromarray(np.uint8(mod.get_outputs()[0][0][0].
asnumpy().clip(0, 255)), mode='L')

result_img = Image.merge(
"YCbCr", [img_out_y,
img_cb.resize(img_out_y.size, Image.BICUBIC),
img_cr.resize(img_out_y.size, Image.BICUBIC)]).convert("RGB")
output_img_dim = 672
assert result_img.size == (output_img_dim, output_img_dim)
LOGGER.info("Super Resolution example success.")
result_img.save("super_res_output.jpg")
return result_img

if __name__ == '__main__':
MX_SYM, MX_PARAM = import_onnx()
INPUT_IMG, IMG_CB, IMG_CR = get_test_image()
perform_inference(MX_SYM, MX_PARAM, INPUT_IMG, IMG_CB, IMG_CR)
2 changes: 1 addition & 1 deletion python/mxnet/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@
from . import tensorboard

from . import text

from . import onnx
anirudhacharya marked this conversation as resolved.
Show resolved Hide resolved
from . import io
20 changes: 20 additions & 0 deletions python/mxnet/contrib/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Module for importing and exporting ONNX models."""

from ._import.import_model import import_model
21 changes: 21 additions & 0 deletions python/mxnet/contrib/onnx/_import/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""ONNX Import module"""
from . import import_model
from . import import_onnx
105 changes: 105 additions & 0 deletions python/mxnet/contrib/onnx/_import/import_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name
"""Operator attributes conversion"""
from .op_translations import identity, random_uniform, random_normal
from .op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from .op_translations import tanh
from .op_translations import ceil, floor
from .op_translations import concat
from .op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
from .op_translations import global_avgpooling, global_maxpooling, linalg_gemm
from .op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm
from .op_translations import dropout, local_response_norm, conv, deconv
anirudhacharya marked this conversation as resolved.
Show resolved Hide resolved
from .op_translations import reshape, cast, split, _slice, transpose, squeeze
from .op_translations import reciprocal, squareroot, power, exponent, _log
from .op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
from .op_translations import reduce_prod, avg_pooling, max_pooling
from .op_translations import argmax, argmin, maximum, minimum

# convert_map defines maps of ONNX operator names to converter functor(callable)
anirudhacharya marked this conversation as resolved.
Show resolved Hide resolved
# defined in the op_translations module.
_convert_map = {
# Generator Functions
'Constant' : identity,
'RandomUniform' : random_uniform,
'RandomNormal' : random_normal,
'RandomUniformLike' : random_uniform,
'RandomNormalLike' : random_normal,
# Arithmetic Operators
'Add' : add,
'Sub' : subtract,
'Mul' : multiply,
'Div' : divide,
'Abs' : absolute,
'Neg' : negative,
'Sum' : add_n, #elemwise sum
#Hyperbolic functions
'Tanh' : tanh,
# Rounding
'Ceil' : ceil,
'Floor' : floor,
# Joining and spliting
'Concat' : concat,
# Basic neural network functions
'Sigmoid' : sigmoid,
'Relu' : relu,
'Pad' : pad,
'MatMul' : matrix_multiplication, #linalg_gemm2
'Conv' : conv,
'ConvTranspose' : deconv,
'BatchNormalization': batch_norm,
'SpatialBN' : batch_norm,
'LeakyRelu' : leaky_relu,
'Elu' : _elu,
'PRelu' : _prelu,
'Softmax' : softmax,
'FC' : fully_connected,
'GlobalAveragePool' : global_avgpooling,
'GlobalMaxPool' : global_maxpooling,
'Gemm' : linalg_gemm,
'LRN' : local_response_norm,
'Dropout' : dropout,
# Changing shape and type.
'Reshape' : reshape,
'Cast' : cast,
'Split' : split,
'Slice' : _slice,
'Transpose' : transpose,
'Squeeze' : squeeze,
#Powers
'Reciprocal' : reciprocal,
'Sqrt' : squareroot,
'Pow' : power,
'Exp' : exponent,
'Log' : _log,
# Reduce Functions
'ReduceMax' : reduce_max,
'ReduceMean' : reduce_mean,
'ReduceMin' : reduce_min,
'ReduceSum' : reduce_sum,
'ReduceProd' : reduce_prod,
'AveragePool' : avg_pooling,
'MaxPool' : max_pooling,
# Sorting and Searching
'ArgMax' : argmax,
'ArgMin' : argmin,
'Max' : maximum, #elemwise maximum
'Min' : minimum #elemwise minimum
}
46 changes: 46 additions & 0 deletions python/mxnet/contrib/onnx/_import/import_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""import function"""
# pylint: disable=no-member

from .import_onnx import GraphProto

def import_model(model_file):
"""Imports the supplied ONNX model file into MXNet symbol and parameters.
:parameters model_file
----------
model_file : ONNX model file name

:returns (sym, params)
anirudhacharya marked this conversation as resolved.
Show resolved Hide resolved
-------
sym : mx.symbol
Compatible mxnet symbol
params : dict of str to mx.ndarray
Dict of converted parameters stored in mx.ndarray format
"""
graph = GraphProto()

# loads model file and returns ONNX protobuf object
try:
import onnx
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
anirudhacharya marked this conversation as resolved.
Show resolved Hide resolved
model_proto = onnx.load(model_file)
sym, params = graph.from_onnx(model_proto.graph)
return sym, params
Loading