Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
minor fix for squeeze operator.
Browse files Browse the repository at this point in the history
Also, added error handling
  • Loading branch information
rajanksin committed Jun 13, 2018
1 parent cc04a35 commit 4e8b00d
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/mxnet/contrib/onnx/_export/op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,14 +1643,18 @@ def convert_expand_dims(node, **kwargs):

@mx_op.register("squeeze")
def convert_squeeze(node, **kwargs):
"""Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator
"""Map MXNet's squeeze operator attributes to onnx's squeeze operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
inputs = node["inputs"]
axis = int(node["attrs"]["axis"])
if "axis" in node["attrs"]:
axis = convert_string_to_list(node["attrs"]["axis"])
else:
raise AttributeError("Missing axis attribute: ONNX currently requires axis to "
"be specified for squeeze operator")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
input_node = proc_nodes[input_node_id].name
Expand All @@ -1659,7 +1663,7 @@ def convert_squeeze(node, **kwargs):
"Squeeze",
[input_node],
[name],
axes=[axis],
axes=axis,
name=name,
)
return [node]
Expand Down

0 comments on commit 4e8b00d

Please sign in to comment.