From 4e8b00d09fbabdb6e4988f84a54174393f81ba5c Mon Sep 17 00:00:00 2001 From: spidyDev Date: Wed, 13 Jun 2018 14:38:30 -0700 Subject: [PATCH] minor fix for squeeze operator. Also, added error handling --- python/mxnet/contrib/onnx/_export/op_translations.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/_export/op_translations.py b/python/mxnet/contrib/onnx/_export/op_translations.py index 4ef836116ea2..0e0087c801d2 100644 --- a/python/mxnet/contrib/onnx/_export/op_translations.py +++ b/python/mxnet/contrib/onnx/_export/op_translations.py @@ -1643,14 +1643,18 @@ def convert_expand_dims(node, **kwargs): @mx_op.register("squeeze") def convert_squeeze(node, **kwargs): - """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator + """Map MXNet's squeeze operator attributes to onnx's squeeze operator and return the created node. """ helper, _, _ = import_onnx_modules() name = node["name"] proc_nodes = kwargs["proc_nodes"] inputs = node["inputs"] - axis = int(node["attrs"]["axis"]) + if "axis" in node["attrs"]: + axis = convert_string_to_list(node["attrs"]["axis"]) + else: + raise AttributeError("Missing axis attribute: ONNX currently requires axis to " + "be specified for squeeze operator") input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1659,7 +1663,7 @@ def convert_squeeze(node, **kwargs): "Squeeze", [input_node], [name], - axes=[axis], + axes=axis, name=name, ) return [node]