Skip to content

Commit

Permalink
ONNX export: broadcast_to, tile ops (apache#13981)
Browse files Browse the repository at this point in the history
* Expand,tile op export

* fix

* adding test cases

* adding comments
  • Loading branch information
Roshrini authored and haohuw committed Jun 23, 2019
1 parent ee63ebd commit e3a99ca
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
76 changes: 76 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,3 +1969,79 @@ def convert_roipooling(node, **kwargs):
name=name
)
return [node]


@mx_op.register("tile")
def convert_tile(node, **kwargs):
"""Map MXNet's Tile operator attributes to onnx's Tile
operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

reps_list = convert_string_to_list(attrs["reps"])

initializer = kwargs["initializer"]
reps_shape_np = np.array(reps_list, dtype='int64')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype]
dims = np.shape(reps_shape_np)

output_shape_name = "reps_attr_tensor" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims)

initializer.append(
onnx.helper.make_tensor(
name=output_shape_name,
data_type=data_type,
dims=dims,
vals=reps_list,
raw=False,
)
)

input_nodes.append(output_shape_name)
tile_node = onnx.helper.make_node(
"Tile",
input_nodes,
[name],
name=name
)

return [tensor_node, tile_node]


@mx_op.register("broadcast_to")
def convert_broadcast_to(node, **kwargs):
"""Map MXNet's broadcast_to operator attributes to onnx's Expand
operator and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

shape_list = convert_string_to_list(attrs["shape"])

initializer = kwargs["initializer"]
output_shape_np = np.array(shape_list, dtype='int64')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype]
dims = np.shape(output_shape_np)

output_shape_name = "expand_attr_tensor" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims)

initializer.append(
onnx.helper.make_tensor(
name=output_shape_name,
data_type=data_type,
dims=dims,
vals=shape_list,
raw=False,
)
)

input_nodes.append(output_shape_name)
expand_node = onnx.helper.make_node(
"Expand",
input_nodes,
[name],
name=name
)

return [tensor_node, expand_node]
20 changes: 19 additions & 1 deletion tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from collections import namedtuple
import numpy as np
import numpy.testing as npt
from onnx import numpy_helper, helper, load_model
from onnx import checker, numpy_helper, helper, load_model
from onnx import TensorProto
from mxnet.test_utils import download
from mxnet.contrib import onnx as onnx_mxnet
Expand Down Expand Up @@ -206,6 +206,18 @@ def test_imports(self):
mxnet_out = bkd_rep.run(inputs)
npt.assert_almost_equal(np_out, mxnet_out, decimal=4)

def test_exports(self):
input_shape = (2,1,3,1)
for test in export_test_cases:
test_name, onnx_name, mx_op, attrs = test
input_sym = mx.sym.var('data')
outsym = mx_op(input_sym, **attrs)
converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32,
onnx_file_path=outsym.name + ".onnx")
model = load_model(converted_model)
checker.check_model(model)


# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False,
# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
# 'remove': [attr_name],
Expand Down Expand Up @@ -274,5 +286,11 @@ def test_imports(self):
("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
]

# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute map)
export_test_cases = [
("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}),
("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)})
]

if __name__ == '__main__':
unittest.main()

0 comments on commit e3a99ca

Please sign in to comment.