onnx export ops #13821

Merged · 5 commits · Jan 10, 2019
74 changes: 68 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -656,12 +656,19 @@ def convert_exp(node, **kwargs):
    return create_basic_op_node('Exp', node, kwargs)

@mx_op.register("_copy")
-def convert_identity(node, **kwargs):
+def convert_copy(node, **kwargs):
    """Map MXNet's _copy operator attributes to onnx's Identity operator
    and return the created node.
    """
    return create_basic_op_node('Identity', node, kwargs)

+@mx_op.register("identity")
+def convert_identity(node, **kwargs):
+    """Map MXNet's identity operator attributes to onnx's ConstantFill operator
+    and return the created node.
+    """
+    return create_basic_op_node('ConstantFill', node, kwargs)

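Both converters above go through create_basic_op_node, this file's helper for emitting a single ONNX node with no extra attributes. A minimal sketch of that pattern, assuming the get_inputs utility from the same file (the real helper is defined earlier in _op_translations.py and may differ in detail):

    # Sketch only: one ONNX node, same inputs, output named after the node.
    def create_basic_op_node(op_name, node, kwargs):
        name, input_nodes, _ = get_inputs(node, kwargs)
        return [onnx.helper.make_node(op_name, input_nodes, [name], name=name)]
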
@mx_op.register("InstanceNorm")
def convert_instancenorm(node, **kwargs):
"""Map MXNet's InstanceNorm operator attributes to onnx's InstanceNormalization operator
@@ -752,6 +759,31 @@ def convert_softmax_output(node, **kwargs):

    return [softmax_node]

@mx_op.register("LogisticRegressionOutput")
def convert_logistic_regression_output(node, **kwargs):
"""Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
and return the created node.
"""
name = node["name"]
input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
input1 = kwargs["proc_nodes"][input1_idx]
sigmoid_node = onnx.helper.make_node(
"Sigmoid",
[input1.name],
[name],
name=name
)
return [sigmoid_node]

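In the forward pass LogisticRegressionOutput applies the logistic function to its first input; the label input only feeds the training loss, so exporting it as a single Sigmoid node and dropping the label is faithful at inference time. A quick numerical check of that equivalence (illustration, not part of the diff):

    import numpy as np
    import mxnet as mx

    x = mx.nd.array([-1.0, 0.0, 1.0])
    label = mx.nd.zeros_like(x)                    # ignored in the forward pass
    out = mx.nd.LogisticRegressionOutput(x, label)
    assert np.allclose(out.asnumpy(), 1.0 / (1.0 + np.exp(-x.asnumpy())))
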
@mx_op.register("BlockGrad")
def convert_blockgrad(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)

@mx_op.register("MakeLoss")
def convert_makeloss(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)

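BlockGrad and MakeLoss only affect gradient flow during training; their forward pass returns the input unchanged, which is why the exporter can treat them as skip operators. For illustration:

    import mxnet as mx

    x = mx.nd.array([1.0, 2.0, 3.0])
    assert (mx.nd.BlockGrad(x).asnumpy() == x.asnumpy()).all()
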
@mx_op.register("Concat")
def convert_concat(node, **kwargs):
@@ -898,7 +930,7 @@ def convert_clip(node, **kwargs):
def scalar_op_helper(node, op_name, **kwargs):
    """Helper function for scalar arithmetic operations"""
    name, input_nodes, attrs = get_inputs(node, kwargs)
-
+    from onnx import numpy_helper
    input_type = kwargs["in_type"]
    scalar_value = np.array([attrs.get("scalar", 1)],
                            dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type])
@@ -910,13 +942,21 @@ def scalar_op_helper(node, op_name, **kwargs):
    for i in initializer:
        if i.name == input_nodes[0]:
            if op_name == 'Mul':
-                new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) * scalar_value[0]
            elif op_name == 'Sub':
-                new_initializer = onnx.numpy_helper.to_array(i) - scalar_value[0]
+                if name.startswith("_rminusscalar"):
+                    new_initializer = scalar_value[0] - numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) - scalar_value[0]
            elif op_name == 'Add':
-                new_initializer = onnx.numpy_helper.to_array(i) + scalar_value[0]
+                new_initializer = numpy_helper.to_array(i) + scalar_value[0]
            elif op_name == 'Div':
-                new_initializer = onnx.numpy_helper.to_array(i) / scalar_value[0]
+                if name.startswith("_rdivscalar"):
+                    new_initializer = scalar_value[0] / numpy_helper.to_array(i)
+                else:
+                    new_initializer = numpy_helper.to_array(i) / scalar_value[0]
+            elif op_name == 'Pow':
+                new_initializer = numpy_helper.to_array(i) ** scalar_value[0]
            flag = False
            break

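When the tensor operand of a scalar op is a weight that already lives in the ONNX initializer, the helper folds the arithmetic into the initializer instead of emitting a node; the name-prefix checks are what distinguish x - s from s - x and x / s from s / x. The folding arithmetic, spelled out in plain NumPy (illustration only):

    import numpy as np

    w = np.array([1.0, 4.0], dtype=np.float32)  # stand-in for an initializer tensor
    s = np.float32(2)                           # stand-in for attrs["scalar"]

    assert np.allclose(w - s, [-1.0, 2.0])      # _minus_scalar   -> Sub
    assert np.allclose(s - w, [1.0, -2.0])      # _rminus_scalar  -> reversed Sub
    assert np.allclose(s / w, [2.0, 0.5])       # _rdiv_scalar    -> reversed Div
    assert np.allclose(w ** s, [1.0, 16.0])     # _power_scalar   -> Pow
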
@@ -982,6 +1022,13 @@ def convert_minus_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Sub', **kwargs)

@mx_op.register("_rminus_scalar")
def convert_rminus_scalar(node, **kwargs):
"""Map MXNet's _rminus_scalar operator attributes to onnx's Sub operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Sub', **kwargs)

# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_plus_scalar")
@@ -1001,6 +1048,21 @@ def convert_div_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Div', **kwargs)

@mx_op.register("_rdiv_scalar")
def convert_rdiv_scalar(node, **kwargs):
"""Map MXNet's _rdiv_scalar operator attributes to onnx's Div operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Div', **kwargs)

@mx_op.register("_power_scalar")
def convert_pow_scalar(node, **kwargs):
"""Map MXNet's _pow_scalar operator attributes to onnx's Pow operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Pow', **kwargs)

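These registrations cover what MXNet's operator sugar lowers to: 2 - x becomes _rminus_scalar, 2 / x becomes _rdiv_scalar, and x ** 2 becomes _power_scalar. A minimal export sketch in the style of the new tests (shapes and file name are arbitrary):

    import numpy as np
    import mxnet as mx
    from mxnet.contrib import onnx as onnx_mxnet

    x = mx.sym.Variable('input1')
    sym = x ** 2                            # lowered to _power_scalar
    onnx_mxnet.export_model(sym, {}, [(1, 10, 2, 3)], np.float32,
                            onnx_file_path='pow_scalar.onnx')
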
# Sorting and Searching
@mx_op.register("argmax")
1 change: 0 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -730,7 +730,6 @@ def spacetodepth(attrs, inputs, proto_obj):

return "space_to_depth", new_attrs, inputs


def hardmax(attrs, inputs, proto_obj):
"""Returns batched one-hot vectors."""
input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/backend_test.py
@@ -71,7 +71,8 @@ def prepare_tests(backend, oper):
    for std_model_test in std_models:
        BACKEND_TESTS.include(std_model_test)

-    BACKEND_TESTS.exclude('.*bcast.*')
+    # Tests for scalar ops are in test_node.py
+    BACKEND_TESTS.exclude('.*scalar.*')

    return BACKEND_TESTS

32 changes: 29 additions & 3 deletions tests/python-pytest/onnx/test_node.py
@@ -161,6 +161,30 @@ def test_import_export(self):
            if check_shape:
                npt.assert_equal(output[0].shape, outputshape)

+        input1 = get_rnd((1, 10, 2, 3))
+        ipsym = mx.sym.Variable("input1")
+        for test in test_scalar_ops:
+            if test == 'Add':
+                outsym = 2 + ipsym
+            if test == "Sub":
+                outsym = ipsym - 2
+            if test == "rSub":
+                outsym = ipsym.__rsub__(2)
+            if test == "Mul":
+                outsym = 2 * ipsym
+            if test == "Div":
+                outsym = ipsym / 2
+            if test == "Pow":
+                outsym = ipsym ** 2
+            forward_op = forward_pass(outsym, None, None, ['input1'], input1)
+            converted_model = onnx_mxnet.export_model(outsym, {}, [np.shape(input1)], np.float32,
+                                                      onnx_file_path=outsym.name + ".onnx")
+
+            sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+            result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+
+            npt.assert_almost_equal(result, forward_op)

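forward_pass is the test file's existing helper; roughly, it binds the symbol on CPU without labels and runs a single batch. A sketch under that assumption (the real helper is defined near the top of test_node.py and may differ):

    def forward_pass(sym, arg, aux, data_names, input_data):
        mod = mx.mod.Module(symbol=sym, data_names=data_names,
                            label_names=None, context=mx.cpu())
        mod.bind(for_training=False,
                 data_shapes=[(data_names[0], np.shape(input_data))])
        if not arg and not aux:
            mod.init_params()
        else:
            mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True)
        mod.forward(mx.io.DataBatch([mx.nd.array(input_data)]), is_train=False)
        return mod.get_outputs()[0].asnumpy()
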
    def test_imports(self):
        for test in import_test_cases:
            test_name, onnx_name, inputs, np_op, attrs = test
@@ -173,7 +197,6 @@ def test_imports(self):
            mxnet_out = bkd_rep.run(inputs)
            npt.assert_almost_equal(np_out, mxnet_out)

-
# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
@@ -198,6 +221,8 @@ def test_imports(self):
     {'block_size': 2}, False, {}, True, False),
    ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), get_rnd(1000)],
     {'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
+    ("test_logistic_regression", mx.sym.LogisticRegressionOutput, "Sigmoid",
+     [get_rnd((1000, 1000)), get_rnd((1000, 1000))], {}, True, {}, True, False),
    ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), get_rnd((4, 3)), get_rnd(4)],
     {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
    ("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
@@ -223,12 +248,13 @@ def test_imports(self):
     {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
]

+test_scalar_ops = ['Add', 'Sub', 'rSub', 'Mul', 'Div', 'Pow']
+
# test_case = ("test_case_name", "ONNX_op_name", [input_list], np_op, attribute map)
import_test_cases = [
    ("test_lpnormalization_default", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':-1}),
    ("test_lpnormalization_ord1", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':-1}),
-    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}),
-    ("test_lpnormalization_ord_axis", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':2})
+    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
]

if __name__ == '__main__':