From abd6951ab2d72ef9681ae842aab182b5dac9152e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 3 Jan 2019 14:09:59 -0800 Subject: [PATCH 1/2] Added Deconvolution and Crop to ONNX exporter --- .../contrib/onnx/mx2onnx/_op_translations.py | 64 ++++++++++++++++++- tests/python-pytest/onnx/test_cases.py | 3 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3baf10a10d39..6c5c21e5d493 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -219,6 +219,68 @@ def convert_convolution(node, **kwargs): return [conv_node] +@mx_op.register("Deconvolution") +def convert_deconvolution(node, **kwargs): + """Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator + and return the created node. + """ + name, inputs, attrs = get_inputs(node, kwargs) + + kernel_dims = list(parse_helper(attrs, "kernel")) + stride_dims = list(parse_helper(attrs, "stride", [1, 1])) + pad_dims = list(parse_helper(attrs, "pad", [0, 0])) + num_group = int(attrs.get("num_group", 1)) + dilations = list(parse_helper(attrs, "dilate", [1, 1])) + adj_dims = list(parse_helper(attrs, "adj", [0, 0])) + + pad_dims = pad_dims + pad_dims + + deconv_node = onnx.helper.make_node( + "ConvTranspose", + inputs=inputs, + outputs=[name], + kernel_shape=kernel_dims, + strides=stride_dims, + dilations=dilations, + output_padding=adj_dims, + pads=pad_dims, + group=num_group, + name=name + ) + + return [deconv_node] + + +@mx_op.register("Crop") +def convert_crop(node, **kwargs): + """Map MXNet's crop operator attributes to onnx's Crop operator + and return the created node. + """ + name, inputs, attrs = get_inputs(node, kwargs) + num_inputs = len(inputs) + + y, x = list(parse_helper(attrs, "offset", [0, 0])) + h, w = list(parse_helper(attrs, "h_w", [0, 0])) + if num_inputs > 1: + h, w = kwargs["out_shape"][-2:] + border = [x, y, x + w, y + h] + + crop_node = onnx.helper.make_node( + "Crop", + inputs=[inputs[0]], + outputs=[name], + border=border, + scale=[1, 1], + name=name + ) + + logging.warning( + "Using an experimental ONNX operator: Crop. " \ + "Its definition can change.") + + return [crop_node] + + @mx_op.register("FullyConnected") def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator @@ -584,7 +646,7 @@ def convert_pooling(node, **kwargs): kernel = eval(attrs["kernel"]) pool_type = attrs["pool_type"] - stride = eval(attrs["stride"]) if attrs.get("stride") else None + stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1) global_pool = get_boolean_attribute_value(attrs, "global_pool") pooling_convention = attrs.get('pooling_convention', 'valid') diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index 6a189b62492d..0178a2c9ecbc 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -115,7 +115,8 @@ 'test_Softmax', 'test_softmax_functional', 'test_softmax_lastdim', - ] + ], + 'export': ['test_ConvTranspose2d'] } STANDARD_MODEL = { From 889c1cd561fafa53d765d5412427a87d7d0f32c7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 3 Jan 2019 14:16:59 -0800 Subject: [PATCH 2/2] Added default for pool_type --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 6c5c21e5d493..b5dbb461ab47 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -645,7 +645,7 @@ def convert_pooling(node, **kwargs): name, input_nodes, attrs = get_inputs(node, kwargs) kernel = eval(attrs["kernel"]) - pool_type = attrs["pool_type"] + pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max" stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1) global_pool = get_boolean_attribute_value(attrs, "global_pool")