Skip to content

Commit

Permalink
ONNX export: Add Crop, Deconvolution and fix the default stride of Po…
Browse files Browse the repository at this point in the history
…oling to 1 (apache#12399)

* Added Deconvolution and Crop to ONNX exporter

* Added default for pool_type
  • Loading branch information
ptrendx authored and stephenrawls committed Feb 16, 2019
1 parent f8c38c2 commit ae12ee5
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
66 changes: 64 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,68 @@ def convert_convolution(node, **kwargs):
return [conv_node]


@mx_op.register("Deconvolution")
def convert_deconvolution(node, **kwargs):
"""Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator
and return the created node.
"""
name, inputs, attrs = get_inputs(node, kwargs)

kernel_dims = list(parse_helper(attrs, "kernel"))
stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
num_group = int(attrs.get("num_group", 1))
dilations = list(parse_helper(attrs, "dilate", [1, 1]))
adj_dims = list(parse_helper(attrs, "adj", [0, 0]))

pad_dims = pad_dims + pad_dims

deconv_node = onnx.helper.make_node(
"ConvTranspose",
inputs=inputs,
outputs=[name],
kernel_shape=kernel_dims,
strides=stride_dims,
dilations=dilations,
output_padding=adj_dims,
pads=pad_dims,
group=num_group,
name=name
)

return [deconv_node]


@mx_op.register("Crop")
def convert_crop(node, **kwargs):
"""Map MXNet's crop operator attributes to onnx's Crop operator
and return the created node.
"""
name, inputs, attrs = get_inputs(node, kwargs)
num_inputs = len(inputs)

y, x = list(parse_helper(attrs, "offset", [0, 0]))
h, w = list(parse_helper(attrs, "h_w", [0, 0]))
if num_inputs > 1:
h, w = kwargs["out_shape"][-2:]
border = [x, y, x + w, y + h]

crop_node = onnx.helper.make_node(
"Crop",
inputs=[inputs[0]],
outputs=[name],
border=border,
scale=[1, 1],
name=name
)

logging.warning(
"Using an experimental ONNX operator: Crop. " \
"Its definition can change.")

return [crop_node]


@mx_op.register("FullyConnected")
def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
Expand Down Expand Up @@ -583,8 +645,8 @@ def convert_pooling(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)

kernel = eval(attrs["kernel"])
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
global_pool = get_boolean_attribute_value(attrs, "global_pool")
p_value = attrs.get('p_value', 'None')

Expand Down
3 changes: 2 additions & 1 deletion tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@
'test_Softmax',
'test_softmax_functional',
'test_softmax_lastdim',
]
],
'export': ['test_ConvTranspose2d']
}

STANDARD_MODEL = {
Expand Down

0 comments on commit ae12ee5

Please sign in to comment.