Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 #12399

Merged
merged 2 commits into from
Jan 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,68 @@ def convert_convolution(node, **kwargs):
return [conv_node]


@mx_op.register("Deconvolution")
def convert_deconvolution(node, **kwargs):
"""Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator
and return the created node.
"""
name, inputs, attrs = get_inputs(node, kwargs)

kernel_dims = list(parse_helper(attrs, "kernel"))
stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
num_group = int(attrs.get("num_group", 1))
dilations = list(parse_helper(attrs, "dilate", [1, 1]))
adj_dims = list(parse_helper(attrs, "adj", [0, 0]))

pad_dims = pad_dims + pad_dims

deconv_node = onnx.helper.make_node(
"ConvTranspose",
inputs=inputs,
outputs=[name],
kernel_shape=kernel_dims,
strides=stride_dims,
dilations=dilations,
output_padding=adj_dims,
pads=pad_dims,
group=num_group,
name=name
)

return [deconv_node]


@mx_op.register("Crop")
def convert_crop(node, **kwargs):
Roshrini marked this conversation as resolved.
Show resolved Hide resolved
"""Map MXNet's crop operator attributes to onnx's Crop operator
and return the created node.
"""
name, inputs, attrs = get_inputs(node, kwargs)
num_inputs = len(inputs)

y, x = list(parse_helper(attrs, "offset", [0, 0]))
h, w = list(parse_helper(attrs, "h_w", [0, 0]))
if num_inputs > 1:
h, w = kwargs["out_shape"][-2:]
border = [x, y, x + w, y + h]

crop_node = onnx.helper.make_node(
"Crop",
inputs=[inputs[0]],
outputs=[name],
border=border,
scale=[1, 1],
name=name
)

logging.warning(
"Using an experimental ONNX operator: Crop. " \
"Its definition can change.")

return [crop_node]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a test for crop too? - in mxnet_export_test.py maybe

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vandanavk Could you advise where should I put the test for just export? I made the test but am unsure where should it go after the refactor. My understanding is that test_nodes.py is for things that can be both exported and imported, should I put it in mxnet_export_test.py then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ptrendx export tests usually perform an import-export-import on an ONNX model and then compare inference results.

if you are planning to add an ONNX import for crop, then adding the test would be straightforward - just add a test_case in test_node.py.

if you are adding export alone for crop, then in test_node.py, create a new test case list (same format as the existing maybe), add a new function test_export() in class TestNode(unittest.TestCase), export to onnx format. not sure how to test inference in this case though.


@mx_op.register("FullyConnected")
def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
Expand Down Expand Up @@ -583,8 +645,8 @@ def convert_pooling(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)

kernel = eval(attrs["kernel"])
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
global_pool = get_boolean_attribute_value(attrs, "global_pool")

pooling_convention = attrs.get('pooling_convention', 'valid')
Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@
'test_Softmax',
'test_softmax_functional',
'test_softmax_lastdim',
]
],
'export': ['test_ConvTranspose2d']
}

STANDARD_MODEL = {
Expand Down