Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Add Crop, Deconvolution and fix the default stride of Po…
Browse files Browse the repository at this point in the history
…oling to 1
  • Loading branch information
ptrendx committed Aug 29, 2018
1 parent 1f0d6ba commit fd894b3
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,89 @@ def convert_convolution(node, **kwargs):
return [conv_node]


@mx_op.register("Deconvolution")
def convert_deconvolution(node, **kwargs):
"""Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]

num_inputs = len(inputs)

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name
weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name

if num_inputs > 2:
bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name

attrs = node.get("attrs")

kernel_dims = list(parse_helper(attrs, "kernel"))
stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
num_group = int(attrs.get("num_group", 1))
dilations = list(parse_helper(attrs, "dilate", [1, 1]))
adj_dims = list(parse_helper(attrs, "adj"))

pad_dims = pad_dims + pad_dims

input_nodes = [input_node, weights_node]
if num_inputs > 2:
input_nodes.append(bias_node)

deconv_node = helper.make_node(
"ConvTranspose",
inputs=input_nodes,
outputs=[name],
kernel_shape=kernel_dims,
strides=stride_dims,
dilations=dilations,
output_padding=adj_dims,
pads=pad_dims,
group=num_group,
name=name
)

return [deconv_node]


@mx_op.register("Crop")
def convert_crop(node, **kwargs):
"""Map MXNet's crop operator attributes to onnx's Crop operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]

num_inputs = len(inputs)

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name

attrs = node.get("attrs")

x, y = list(parse_helper(attrs, "offset"))
h, w = list(parse_helper(attrs, "h_w", [0, 0]))
if num_inputs > 1:
h, w = kwargs["out_shape"][-2:]
border = [x, y, x + w, y + h]

crop_node = helper.make_node(
"Crop",
inputs=[input_node],
outputs=[name],
border=border,
scale=[1, 1],
name=name
)

return [crop_node]


@mx_op.register("FullyConnected")
def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
Expand Down Expand Up @@ -612,7 +695,7 @@ def convert_pooling(node, **kwargs):
attrs = node["attrs"]
kernel = eval(attrs["kernel"])
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
global_pool = True if "global_pool" in attrs and\
attrs.get("global_pool") == "True" else False
node_inputs = node["inputs"]
Expand Down

0 comments on commit fd894b3

Please sign in to comment.