Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Instance normalization, Shape (#12920)
Browse files Browse the repository at this point in the history
* ONNX import/export: Make backend_rep common

* ONNX export: Instance Normalization

* ONNX export: Shape operator
  • Loading branch information
vandanavk authored and sandeep-krishnamurthy committed Dec 1, 2018
1 parent baeada4 commit 77510d7
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 116 deletions.
26 changes: 26 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,23 @@ def convert_identity(node, **kwargs):
"""
return create_basic_op_node('Identity', node, kwargs)

@mx_op.register("InstanceNorm")
def convert_instancenorm(node, **kwargs):
"""Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator
based on the input node's attributes and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

eps = float(attrs.get("eps", 0.001))

node = onnx.helper.make_node(
'InstanceNormalization',
inputs=input_nodes,
outputs=[name],
name=name,
epsilon=eps)

return [node]

@mx_op.register("LeakyReLU")
def convert_leakyrelu(node, **kwargs):
Expand Down Expand Up @@ -1546,6 +1563,15 @@ def convert_sum(node, **kwargs):
)
return [node]


@mx_op.register("shape_array")
def convert_shape(node, **kwargs):
"""Map MXNet's shape_array operator attributes to onnx's Shape operator
and return the created node.
"""
return create_basic_op_node('Shape', node, kwargs)


@mx_op.register("hard_sigmoid")
def convert_hardsigmoid(node, **kwargs):
"""Map MXNet's hard_sigmoid operator attributes to onnx's HardSigmoid operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
# under the License.

# coding: utf-8
"""backend rep for onnx test infrastructure"""
"""MXNet backend rep for onnx test infrastructure"""
try:
from onnx.backend.base import BackendRep
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+ " install - https://github.com/onnx/onnx#installation")
import mxnet as mx

# Using these functions for onnx test infrastructure.
# Implemented by following onnx docs guide:
# https://github.com/onnx/onnx/blob/master/docs/Implementing%20an%20ONNX%20backend.md
# https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md
# MXNetBackendRep object will be returned by MXNetBackend's prepare method which is used to
# execute a model repeatedly.
# Inputs will be passed to the run method of MXNetBackendRep class, it will perform computation and
Expand Down Expand Up @@ -54,9 +55,6 @@ def run(self, inputs, **kwargs):
params : numpy array
result obtained after running the inference on mxnet
"""
data_forward = []
for val in inputs:
data_forward.append(mx.nd.array(val))
# create module, passing cpu context
if self.device == 'CPU':
ctx = mx.cpu()
Expand All @@ -68,17 +66,19 @@ def run(self, inputs, **kwargs):
data_names = [graph_input for graph_input in self.symbol.list_inputs()
if graph_input not in self.arg_params and graph_input not in self.aux_params]

data_shapes = []
data_forward = []
for idx, input_name in enumerate(data_names):
data_shapes.append((input_name, inputs[idx].shape))
val = inputs[idx]
data_forward.append(mx.nd.array(val))

mod = mx.mod.Module(symbol=self.symbol, data_names=data_names, context=ctx,
label_names=None)
mod.bind(for_training=False, data_shapes=data_shapes,
label_shapes=None)
mod.set_params(arg_params=self.arg_params, aux_params=self.aux_params)
if self.arg_params:
for idx, input_name in enumerate(self.arg_params):
val = self.arg_params[input_name]
data_names.append(input_name)
data_forward.append(mx.nd.array(val))

# run inference
mod.forward(mx.io.DataBatch(data_forward))
result = mod.get_outputs()[0].asnumpy()
args = dict(zip(data_names, data_forward))
exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
exe.forward(is_train=False)
result = exe.outputs[0].asnumpy()
return [result]
4 changes: 4 additions & 0 deletions tests/python-pytest/onnx/export/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# coding: utf-8
"""backend wrapper for onnx test infrastructure"""
import os
import sys
import numpy as np
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
from mxnet.contrib.onnx.mx2onnx.export_onnx import MXNetGraph
Expand All @@ -25,6 +27,8 @@
from onnx.backend.base import Backend
except ImportError:
raise ImportError("Onnx and protobuf need to be installed")
CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../'))
from backend_rep import MXNetBackendRep

# Using these functions for onnx test infrastructure.
Expand Down
4 changes: 3 additions & 1 deletion tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
'test_clip'
'test_cast',
'test_depthtospace',
'test_hardsigmoid'
'test_hardsigmoid',
'test_instancenorm',
'test_shape'
]

BASIC_MODEL_TESTS = [
Expand Down
6 changes: 5 additions & 1 deletion tests/python-pytest/onnx/import/mxnet_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

# coding: utf-8
"""MXNet backend wrapper for onnx test infrastructure"""
import os
import sys
from mxnet.contrib.onnx.onnx2mx.import_onnx import GraphProto
try:
from onnx import helper, TensorProto
from onnx.backend.base import Backend
except ImportError:
raise ImportError("Onnx and protobuf need to be installed. Instructions to"
+ " install - https://github.com/onnx/onnx#installation")
from mxnet_backend_rep import MXNetBackendRep
CURR_PATH = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(CURR_PATH, '../'))
from backend_rep import MXNetBackendRep

# MXNetBackend class will take an ONNX model with inputs, perform a computation,
# and then return the output.
Expand Down
98 changes: 0 additions & 98 deletions tests/python-pytest/onnx/import/mxnet_backend_rep.py

This file was deleted.

0 comments on commit 77510d7

Please sign in to comment.