Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'onnx_export' of https://github.com/Roshrini/mxnet into …
Browse files Browse the repository at this point in the history
…onnx_export
  • Loading branch information
Roshrini authored and rajanksin committed Jun 14, 2018
2 parents 4e46acd + 4e8b00d commit 9922a00
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/mxnet/contrib/onnx/_export/op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,14 +1709,18 @@ def convert_expand_dims(node, **kwargs):

@mx_op.register("squeeze")
def convert_squeeze(node, **kwargs):
"""Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
"""Map MXNet's squeeze operator attributes to onnx's squeeze operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
inputs = node["inputs"]
axis = int(node["attrs"]["axis"])
if "axis" in node["attrs"]:
axis = convert_string_to_list(node["attrs"]["axis"])
else:
raise AttributeError("Missing axis attribute: ONNX currently requires axis to "
"be specified for squeeze operator")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
input_node = proc_nodes[input_node_id].name
Expand All @@ -1725,7 +1729,7 @@ def convert_squeeze(node, **kwargs):
"Squeeze",
[input_node],
[name],
axes=[axis],
axes=axis,
name=name,
)
return [node]
Expand Down

0 comments on commit 9922a00

Please sign in to comment.