diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 204195d7409a..bc7d17e3c927 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -127,6 +127,14 @@ def convert_string_to_list(string_val): return result_list +def convert_bool_to_int(attrs, attr_name): + """ Helper function to convert a string version + of Boolean attributes to integer for ONNX. + Takes attribute dictionary and attr_name as + parameters. + """ + return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. @@ -587,10 +595,8 @@ def convert_dot(node, **kwargs): trans_a_node = None trans_b_node = None - trans_a = 1 if ("transpose_a" in attrs) and \ - attrs.get("transpose_a") in ["True", "1"] else 0 - trans_b = 1 if ("transpose_b" in attrs) and \ - attrs.get("transpose_b") in ["True", "1"] else 0 + trans_a = convert_bool_to_int(attrs, "transpose_a") + trans_b = convert_bool_to_int(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) create_helper_trans_node(op_name, input_node_a, 'a') @@ -732,8 +738,8 @@ def convert_pooling(node, **kwargs): kernel = eval(attrs["kernel"]) pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None - global_pool = True if "global_pool" in attrs and\ - attrs.get("global_pool") == "True" else False + global_pool = convert_bool_to_int(attrs, "global_pool") + node_inputs = node["inputs"] input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] input_node = proc_nodes[input_node_idx] @@ -753,7 +759,7 @@ def convert_pooling(node, **kwargs): pool_types = {"max": "MaxPool", "avg": "AveragePool"} global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"} - if global_pool: + if global_pool == 1: node = onnx.helper.make_node( global_pool_types[pool_type], [input_node.name], # input