Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Helper function to convert bool string attributes to int
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 8, 2018
1 parent 836ba78 commit f5a84e8
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def convert_string_to_list(string_val):

return result_list

def convert_bool_to_int(attrs, attr_name):
""" Helper function to convert a string version
of Boolean attributes to integer for ONNX.
Takes attribute dictionary and attr_name as
parameters.
"""
return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0

@mx_op.register("null")
def convert_weights_and_inputs(node, **kwargs):
"""Helper function to convert weights and inputs.
Expand Down Expand Up @@ -587,10 +595,8 @@ def convert_dot(node, **kwargs):
trans_a_node = None
trans_b_node = None

trans_a = 1 if ("transpose_a" in attrs) and \
attrs.get("transpose_a") in ["True", "1"] else 0
trans_b = 1 if ("transpose_b" in attrs) and \
attrs.get("transpose_b") in ["True", "1"] else 0
trans_a = convert_bool_to_int(attrs, "transpose_a")
trans_b = convert_bool_to_int(attrs, "transpose_b")

op_name = "transpose" + str(kwargs["idx"])
create_helper_trans_node(op_name, input_node_a, 'a')
Expand Down Expand Up @@ -732,8 +738,8 @@ def convert_pooling(node, **kwargs):
kernel = eval(attrs["kernel"])
pool_type = attrs["pool_type"]
stride = eval(attrs["stride"]) if attrs.get("stride") else None
global_pool = True if "global_pool" in attrs and\
attrs.get("global_pool") == "True" else False
global_pool = convert_bool_to_int(attrs, "global_pool")

node_inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][node_inputs[0][0]]
input_node = proc_nodes[input_node_idx]
Expand All @@ -753,7 +759,7 @@ def convert_pooling(node, **kwargs):
pool_types = {"max": "MaxPool", "avg": "AveragePool"}
global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}

if global_pool:
if global_pool == 1:
node = onnx.helper.make_node(
global_pool_types[pool_type],
[input_node.name], # input
Expand Down

0 comments on commit f5a84e8

Please sign in to comment.