Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Use inputs from get_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Nov 20, 2018
1 parent 3075806 commit ba0b25e
Showing 1 changed file with 1 addition and 14 deletions.
15 changes: 1 addition & 14 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,6 @@ def convert_deconvolution(node, **kwargs):
"""
name, inputs, attrs = get_inputs(node, kwargs)

num_inputs = len(inputs)

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name
weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name

if num_inputs > 2:
bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name

kernel_dims = list(parse_helper(attrs, "kernel"))
stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
Expand All @@ -244,13 +235,9 @@ def convert_deconvolution(node, **kwargs):

pad_dims = pad_dims + pad_dims

input_nodes = [input_node, weights_node]
if num_inputs > 2:
input_nodes.append(bias_node)

deconv_node = onnx.helper.make_node(
"ConvTranspose",
inputs=input_nodes,
inputs=inputs,
outputs=[name],
kernel_shape=kernel_dims,
strides=stride_dims,
Expand Down

0 comments on commit ba0b25e

Please sign in to comment.