Skip to content

Commit

Permalink
ONNX ops: norm exported and lpnormalization imported (apache#13806)
Browse files Browse the repository at this point in the history
* ReduceL1, l2 export, lpnormalization import added

* fix

* fix

* fix

* fix
  • Loading branch information
Roshrini authored and haohuw committed Jun 23, 2019
1 parent d0d22b4 commit 3f8238f
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 27 deletions.
34 changes: 34 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,40 @@ def convert_logsoftmax(node, **kwargs):
)
return [node]

@mx_op.register("norm")
def convert_norm(node, **kwargs):
"""Map MXNet's norm operator attributes to onnx's ReduceL1 and ReduceL2 operators
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

mx_axis = attrs.get("axis", None)
axes = convert_string_to_list(str(mx_axis)) if mx_axis else None

keepdims = get_boolean_attribute_value(attrs, "keepdims")
ord = int(attrs.get("ord", 2))

onnx_op_name = "ReduceL1" if ord == 1 else "ReduceL2"

if axes:
reduce_node = onnx.helper.make_node(
onnx_op_name,
input_nodes,
[name],
axes=axes,
keepdims=keepdims,
name=name
)
return [reduce_node]
else:
reduce_node = onnx.helper.make_node(
onnx_op_name,
input_nodes,
[name],
keepdims=keepdims,
name=name
)
return [reduce_node]

@mx_op.register("_sample_multinomial")
def convert_multinomial(node, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ._op_translations import reduce_sum_square, reduce_l1, reduce_l2, max_roi_pooling
from ._op_translations import log_softmax, softsign, lesser, greater, equal
from ._op_translations import logical_and, logical_or, logical_xor, logical_not
from ._op_translations import mean, depthtospace, spacetodepth
from ._op_translations import mean, depthtospace, spacetodepth, lpnormalization

# convert_map defines maps of ONNX operator names to converter functor(callable)
# defined in the op_translations module.
Expand Down Expand Up @@ -146,5 +146,6 @@
'LpPool' : lp_pooling,
'DepthToSpace' : depthtospace,
'SpaceToDepth' : spacetodepth,
'Hardmax' : hardmax
'Hardmax' : hardmax,
'LpNormalization' : lpnormalization
}
9 changes: 9 additions & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,12 @@ def hardmax(attrs, inputs, proto_obj):
one_hot = symbol.one_hot(amax, depth=new_shape[-1])
hardmax_op = symbol.reshape(one_hot, input_shape)
return hardmax_op, attrs, inputs

def lpnormalization(attrs, inputs, proto_obj):
"""ONNX does not have eps attribute, so cannot map it to L2normalization in MXNet
without that, it works as norm operator discussion in PR:
https://github.com/onnx/onnx/pull/1330"""
new_attrs = translation_utils._fix_attribute_names(attrs, {'p': 'ord'})
axis = int(attrs.get("axis", -1))
new_attrs.update(axis=axis)
return 'norm', new_attrs, inputs
4 changes: 2 additions & 2 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@
'test_selu_default',
'test_elu',
'test_max_',
'test_softplus'
'test_softplus',
'test_reduce_'
],
'import': ['test_gather',
'test_softsign',
'test_reduce_',
'test_mean',
'test_averagepool_1d',
'test_averagepool_2d_pads_count_include_pad',
Expand Down
66 changes: 43 additions & 23 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,35 +106,35 @@ def forward_pass(sym, arg, aux, data_names, input_data):
return mod.get_outputs()[0].asnumpy()


class TestNode(unittest.TestCase):
""" Tests for models.
Tests are dynamically added.
Therefore edit test_models to add more tests.
"""
def get_input_tensors(input_data):
input_tensor = []
input_names = []
input_sym = []
for idx, ip in enumerate(input_data):
name = "input" + str(idx + 1)
input_sym.append(mx.sym.Variable(name))
input_names.append(name)
input_tensor.append(helper.make_tensor_value_info(name,
TensorProto.FLOAT, shape=np.shape(ip)))
return input_names, input_tensor, input_sym

def test_import_export(self):
def get_input_tensors(input_data):
input_tensor = []
input_names = []
input_sym = []
for idx, ip in enumerate(input_data):
name = "input" + str(idx + 1)
input_sym.append(mx.sym.Variable(name))
input_names.append(name)
input_tensor.append(helper.make_tensor_value_info(name,
TensorProto.FLOAT, shape=np.shape(ip)))
return input_names, input_tensor, input_sym

def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, attr):
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=output_shape)]
def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, attr):
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=output_shape)]

nodes = [helper.make_node(output_name, input_names, ["output"], **attr)]
nodes = [helper.make_node(output_name, input_names, ["output"], **attr)]

graph = helper.make_graph(nodes, testname, inputs, outputs)
graph = helper.make_graph(nodes, testname, inputs, outputs)

model = helper.make_model(graph)
return model
model = helper.make_model(graph)
return model

class TestNode(unittest.TestCase):
""" Tests for models.
Tests are dynamically added.
Therefore edit test_models to add more tests.
"""
def test_import_export(self):
for test in test_cases:
test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, fix_attrs, check_value, check_shape = test
with self.subTest(test_name):
Expand All @@ -161,6 +161,18 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
if check_shape:
npt.assert_equal(output[0].shape, outputshape)

def test_imports(self):
for test in import_test_cases:
test_name, onnx_name, inputs, np_op, attrs = test
with self.subTest(test_name):
names, input_tensors, inputsym = get_input_tensors(inputs)
np_out = [np_op(*inputs, **attrs)]
output_shape = np.shape(np_out)
onnx_model = get_onnx_graph(test_name, names, input_tensors, onnx_name, output_shape, attrs)
bkd_rep = backend.prepare(onnx_model, operation='import')
mxnet_out = bkd_rep.run(inputs)
npt.assert_almost_equal(np_out, mxnet_out)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
Expand Down Expand Up @@ -211,5 +223,13 @@ def get_onnx_graph(testname, input_names, inputs, output_name, output_shape, att
{'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, True)
]

# test_case = ("test_case_name", "ONNX_op_name", [input_list], np_op, attribute map)
import_test_cases = [
("test_lpnormalization_default", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':-1}),
("test_lpnormalization_ord1", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':-1}),
("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}),
("test_lpnormalization_ord_axis", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':1, 'axis':2})
]

if __name__ == '__main__':
unittest.main()

0 comments on commit 3f8238f

Please sign in to comment.