diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8e3c46dceb42..e077824e0226 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1969,3 +1969,79 @@ def convert_roipooling(node, **kwargs): name=name ) return [node] + + +@mx_op.register("tile") +def convert_tile(node, **kwargs): + """Map MXNet's Tile operator attributes to onnx's Tile + operator and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + reps_list = convert_string_to_list(attrs["reps"]) + + initializer = kwargs["initializer"] + reps_shape_np = np.array(reps_list, dtype='int64') + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype] + dims = np.shape(reps_shape_np) + + output_shape_name = "reps_attr_tensor" + str(kwargs["idx"]) + tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims) + + initializer.append( + onnx.helper.make_tensor( + name=output_shape_name, + data_type=data_type, + dims=dims, + vals=reps_list, + raw=False, + ) + ) + + input_nodes.append(output_shape_name) + tile_node = onnx.helper.make_node( + "Tile", + input_nodes, + [name], + name=name + ) + + return [tensor_node, tile_node] + + +@mx_op.register("broadcast_to") +def convert_broadcast_to(node, **kwargs): + """Map MXNet's broadcast_to operator attributes to onnx's Expand + operator and return the created node. + """ + name, input_nodes, attrs = get_inputs(node, kwargs) + + shape_list = convert_string_to_list(attrs["shape"]) + + initializer = kwargs["initializer"] + output_shape_np = np.array(shape_list, dtype='int64') + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype] + dims = np.shape(output_shape_np) + + output_shape_name = "expand_attr_tensor" + str(kwargs["idx"]) + tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims) + + initializer.append( + onnx.helper.make_tensor( + name=output_shape_name, + data_type=data_type, + dims=dims, + vals=shape_list, + raw=False, + ) + ) + + input_nodes.append(output_shape_name) + expand_node = onnx.helper.make_node( + "Expand", + input_nodes, + [name], + name=name + ) + + return [tensor_node, expand_node] diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 25fe9c9f9a51..96045516c69e 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -31,7 +31,7 @@ from collections import namedtuple import numpy as np import numpy.testing as npt -from onnx import numpy_helper, helper, load_model +from onnx import checker, numpy_helper, helper, load_model from onnx import TensorProto from mxnet.test_utils import download from mxnet.contrib import onnx as onnx_mxnet @@ -206,6 +206,18 @@ def test_imports(self): mxnet_out = bkd_rep.run(inputs) npt.assert_almost_equal(np_out, mxnet_out, decimal=4) + def test_exports(self): + input_shape = (2,1,3,1) + for test in export_test_cases: + test_name, onnx_name, mx_op, attrs = test + input_sym = mx.sym.var('data') + outsym = mx_op(input_sym, **attrs) + converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32, + onnx_file_path=outsym.name + ".onnx") + model = load_model(converted_model) + checker.check_model(model) + + # test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False, # fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name}, # 'remove': [attr_name], @@ -274,5 +286,11 @@ def test_imports(self): ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}) ] +# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute map) +export_test_cases = [ + ("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}), + ("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)}) +] + if __name__ == '__main__': unittest.main()