diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..bed316b694c3 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -74,7 +74,7 @@ def import_onnx_modules(): def parse_helper(attrs, attrs_name, alt_value=None): """Helper function to parse operator attributes in required format.""" tuple_re = re.compile('\([0-9L|,| ]+\)') - if attrs is None: + if attrs is None or attrs == {}: return alt_value attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name)) if attrs_str is None: @@ -135,12 +135,21 @@ def get_boolean_attribute_value(attrs, attr_name): """ return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 +def get_inputs(node, kwargs): + """Helper function to get inputs""" + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + attrs = node.get("attrs", {}) + + return onnx, name, proc_nodes, inputs, attrs + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, _, _, _ = get_inputs(node, kwargs) if kwargs["is_input"] is False: weights = kwargs["weights"] @@ -172,21 +181,16 @@ def convert_convolution(node, **kwargs): """Map MXNet's convolution operator attributes to onnx's Conv operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) num_inputs = len(inputs) - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name if num_inputs > 2: bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name - attrs = node.get("attrs") - kernel_dims = list(parse_helper(attrs, "kernel")) stride_dims = list(parse_helper(attrs, "stride", [1, 1])) pad_dims = list(parse_helper(attrs, "pad", [0, 0])) @@ -219,18 +223,18 @@ def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + initializer = kwargs["initializer"] no_bias = get_boolean_attribute_value(attrs, "no_bias") input_node_id = kwargs["index_lookup"][inputs[0][0]] weight_node_id = kwargs["index_lookup"][inputs[1][0]] - - proc_nodes = kwargs["proc_nodes"] + bias_node_id = kwargs["index_lookup"][inputs[2][0]] + input_node = proc_nodes[input_node_id] + weights_node = proc_nodes[weight_node_id] + bias_node = proc_nodes[bias_node_id] input_node = proc_nodes[input_node_id] input_name = input_node.name @@ -280,12 +284,8 @@ def convert_batchnorm(node, **kwargs): """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - attrs = node["attrs"] momentum = float(node.get("attrs", {}).get("momentum", 0.9)) eps = float(attrs.get("eps", 0.001)) @@ -327,11 +327,9 @@ def convert_tanh(node, **kwargs): """Map MXNet's tanh operator attributes to onnx's Tanh operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -347,11 +345,9 @@ def convert_cos(node, **kwargs): """Map MXNet's cos operator attributes to onnx's Cos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -367,11 +363,9 @@ def convert_sin(node, **kwargs): """Map MXNet's sin operator attributes to onnx's Sin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -387,11 +381,9 @@ def convert_tan(node, **kwargs): """Map MXNet's tan operator attributes to onnx's tan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -407,11 +399,9 @@ def convert_acos(node, **kwargs): """Map MXNet's acos operator attributes to onnx's acos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -427,11 +417,9 @@ def convert_asin(node, **kwargs): """Map MXNet's asin operator attributes to onnx's asin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -447,11 +435,9 @@ def convert_atan(node, **kwargs): """Map MXNet's atan operator attributes to onnx's atan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -468,11 +454,9 @@ def convert_sigmoid(node, **kwargs): """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -488,11 +472,9 @@ def convert_relu(node, **kwargs): """Map MXNet's relu operator attributes to onnx's Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_node_idx].name node = onnx.helper.make_node( @@ -509,14 +491,10 @@ def convert_activation(node, **kwargs): """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] act_type = attrs["act_type"] - inputs = node["inputs"] input_node_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_idx].output[0] @@ -551,11 +529,8 @@ def convert_pad(node, **kwargs): """Map MXNet's pad operator attributes to onnx's Pad operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - attrs = node["attrs"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_idx].name @@ -608,17 +583,13 @@ def convert_dot(node, **kwargs): """Map MXNet's dot operator attributes to onnx's MatMul and Transpose operators based on the values set for transpose_a, transpose_b attributes.""" - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_a_idx = kwargs["index_lookup"][inputs[0][0]] input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] + input_b_idx = kwargs["index_lookup"][inputs[1][0]] input_node_b = proc_nodes[input_b_idx].name - attrs = node.get('attrs', {}) trans_a_node = None trans_b_node = None @@ -660,26 +631,17 @@ def convert_linalg_gemm2(node, **kwargs): transpose_a, transpose_b attributes. Return multiple nodes created. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_a_idx = kwargs["index_lookup"][inputs[0][0]] input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] + input_b_idx = kwargs["index_lookup"][inputs[1][0]] input_node_b = proc_nodes[input_b_idx].name # Getting the attributes and assigning default values. - if "attrs" in node: - attrs = node["attrs"] - alpha = float(attrs["alpha"]) - trans_a = int(attrs["transpose_a"]) - trans_b = int(attrs["transpose_b"]) - else: - alpha = 1.0 - trans_a = 0 - trans_b = 0 + alpha = float(attrs.get("alpha", 1.0)) + trans_a = get_boolean_attribute_value(attrs, "transpose_a") + trans_b = get_boolean_attribute_value(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) @@ -759,18 +721,15 @@ def convert_pooling(node, **kwargs): MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + kernel = eval(attrs["kernel"]) pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None global_pool = get_boolean_attribute_value(attrs, "global_pool") - node_inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_idx] - name = node["name"] pooling_convention = attrs.get('pooling_convention', 'valid') @@ -812,10 +771,7 @@ def convert_exp(node, **kwargs): """Map MXNet's exp operator attributes to onnx's Exp operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -834,10 +790,7 @@ def convert_identity(node, **kwargs): """Map MXNet's _copy operator attributes to onnx's Identity operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -856,13 +809,10 @@ def convert_leakyrelu(node, **kwargs): """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name - attrs = node["attrs"] act_type = attrs.get("act_type", "leaky") alpha = float(attrs.get("slope", 0.25)) @@ -901,14 +851,12 @@ def convert_softmax(node, **kwargs): """Map MXNet's softmax operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_idx] - name = node["name"] - axis = int(node.get("attrs", {}).get("axis", -1)) + axis = int(attrs.get("axis", -1)) softmax_node = onnx.helper.make_node( "Softmax", @@ -928,12 +876,10 @@ def convert_softmax_output(node, **kwargs): """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) + input1_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] input1 = proc_nodes[input1_idx] - name = node["name"] softmax_node = onnx.helper.make_node( "Softmax", @@ -951,12 +897,11 @@ def convert_concat(node, **kwargs): """Map MXNet's Concat operator attributes to onnx's Concat operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - proc_nodes = kwargs["proc_nodes"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in inputs] - axis = int(node.get("attrs", {}).get("dim", 1)) + + axis = int(attrs.get("dim", 1)) concat_node = onnx.helper.make_node( "Concat", input_names, @@ -972,12 +917,12 @@ def convert_transpose(node, **kwargs): """Map MXNet's transpose operator attributes to onnx's Transpose operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_idx].name - axes = node.get("attrs", {}).get("axes", ()) + + axes = attrs.get("axes", ()) if axes: axes = tuple(map(int, re.findall(r'\d+', axes))) @@ -1004,13 +949,11 @@ def convert_lrn(node, **kwargs): """Map MXNet's LRN operator attributes to onnx's LRN operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, proc_nodes, _, attrs = get_inputs(node, kwargs) + input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] input_node = proc_nodes[input_idx].name - attrs = node["attrs"] alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001 beta = float(attrs["beta"]) if "beta" in attrs else 0.75 bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0 @@ -1035,11 +978,10 @@ def convert_l2normalization(node, **kwargs): """Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + input_id = kwargs["index_lookup"][inputs[0][0]] + input_name = proc_nodes[input_id].name mode = attrs.get("mode", "instance") if mode != "channel": @@ -1060,11 +1002,10 @@ def convert_dropout(node, **kwargs): """Map MXNet's Dropout operator attributes to onnx's Dropout operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + input_id = kwargs["index_lookup"][inputs[0][0]] + input_name = proc_nodes[input_id].name probability = float(attrs["p"]) dropout_node = onnx.helper.make_node( @@ -1082,10 +1023,9 @@ def convert_flatten(node, **kwargs): """Map MXNet's Flatten operator attributes to onnx's Flatten operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) + + input_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_idx].name # .output[0] flatten_node = onnx.helper.make_node( @@ -1101,12 +1041,10 @@ def convert_clip(node, **kwargs): """Map MXNet's Clip operator attributes to onnx's Clip operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + input_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_idx].name - attrs = node["attrs"] a_min = np.float(attrs.get('a_min', -np.inf)) a_max = np.float(attrs.get('a_max', np.inf)) @@ -1123,11 +1061,9 @@ def convert_clip(node, **kwargs): def scalar_op_helper(node, op_name, **kwargs): """Helper function for scalar arithmetic operations""" - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - scalar_value = [float(node.get("attrs", {}).get("scalar", 1))] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + scalar_value = [float(attrs.get("scalar", 1))] input_name_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_name_id].name @@ -1239,14 +1175,10 @@ def convert_argmax(node, **kwargs): """Map MXNet's argmax operator attributes to onnx's ArgMax operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 @@ -1266,14 +1198,10 @@ def convert_argmin(node, **kwargs): """Map MXNet's argmin operator attributes to onnx's ArgMin operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 @@ -1293,17 +1221,13 @@ def convert_maximum(node, **kwargs): """Map MXNet's _maximum operator attributes to onnx's Max operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_list = [] - for node_input in node_inputs: + for node_input in inputs: node_id = kwargs["index_lookup"][node_input[0]] input_node_list.append(proc_nodes[node_id].name) - name = node["name"] - node = onnx.helper.make_node( 'Max', inputs=input_node_list, @@ -1319,17 +1243,13 @@ def convert_minimum(node, **kwargs): """Map MXNet's _minimum operator attributes to onnx's Min operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) input_node_list = [] - for node_input in node_inputs: + for node_input in inputs: node_id = kwargs["index_lookup"][node_input[0]] input_node_list.append(proc_nodes[node_id].name) - name = node["name"] - node = onnx.helper.make_node( 'Min', inputs=input_node_list, @@ -1345,15 +1265,12 @@ def convert_min(node, **kwargs): """Map MXNet's min operator attributes to onnx's ReduceMin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + keepdims = int(attrs.get("keepdims", 0)) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1386,15 +1303,12 @@ def convert_max(node, **kwargs): """Map MXNet's max operator attributes to onnx's ReduceMax operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + keepdims = int(attrs.get("keepdims", 0)) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1427,15 +1341,12 @@ def convert_mean(node, **kwargs): """Map MXNet's mean operator attributes to onnx's ReduceMean operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + keepdims = int(attrs.get("keepdims", 0)) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1468,15 +1379,12 @@ def convert_prod(node, **kwargs): """Map MXNet's prod operator attributes to onnx's ReduceProd operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) + keepdims = int(attrs.get("keepdims", 0)) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1510,10 +1418,7 @@ def convert_elementwise_add(node, **kwargs): """Map MXNet's elemwise_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1536,10 +1441,7 @@ def covert_broadcast_add(node, **kwargs): """Map MXNet's broadcast_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1562,10 +1464,7 @@ def convert_elementwise_sub(node, **kwargs): """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1587,10 +1486,7 @@ def covert_broadcast_sub(node, **kwargs): """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1613,10 +1509,7 @@ def convert_elemwise_mul(node, **kwargs): """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1638,10 +1531,7 @@ def convert_broadcast_mul(node, **kwargs): """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1664,10 +1554,7 @@ def convert_elemwise_div(node, **kwargs): """Map MXNet's elemwise_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1690,10 +1577,7 @@ def convert_broadcast_div(node, **kwargs): """Map MXNet's broadcast_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -1716,10 +1600,7 @@ def convert_negative(node, **kwargs): """Map MXNet's negative operator attributes to onnx's Neg operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] @@ -1740,10 +1621,7 @@ def convert_abs(node, **kwargs): """Map MXNet's abs operator attributes to onnx's Abs operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] @@ -1764,10 +1642,7 @@ def convert_addn(node, **kwargs): """Map MXNet's add_n operator attributes to onnx's Sum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_list = [] for input_val in inputs: @@ -1787,10 +1662,7 @@ def convert_ceil(node, **kwargs): """Map MXNet's ceil operator attributes to onnx's Ceil operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1808,10 +1680,7 @@ def convert_floor(node, **kwargs): """Map MXNet's floor operator attributes to onnx's Floor operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1831,11 +1700,7 @@ def convert_reshape(node, **kwargs): Converts output shape attribute to output shape tensor and return multiple created nodes. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) output_shape_list = convert_string_to_list(attrs["shape"]) @@ -1880,11 +1745,9 @@ def convert_cast(node, **kwargs): """Map MXNet's Cast operator attributes to onnx's Cast operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - dtype = node["attrs"]["dtype"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + dtype = attrs["dtype"] # dtype can be mapped only with types from TensorProto # float32 is mapped to float and float64 to double in onnx @@ -1912,16 +1775,14 @@ def convert_slice_axis(node, **kwargs): """Map MXNet's slice_axis operator attributes to onnx's Slice operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axes = int(node["attrs"]["axis"]) - starts = int(node["attrs"]["begin"]) - if node["attrs"]["end"] == 'None': + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + axes = int(attrs["axis"]) + starts = int(attrs["begin"]) + if attrs["end"] == 'None': raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") else: - ends = int(node["attrs"]["end"]) + ends = int(attrs["end"]) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1944,13 +1805,11 @@ def convert_slice_channel(node, **kwargs): operator based on squeeze_axis attribute and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - num_outputs = int(node.get("attrs", {})["num_outputs"]) - axis = int(node.get("attrs", {}).get("axis", 1)) - squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0)) + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + num_outputs = int(attrs["num_outputs"]) + axis = int(attrs.get("axis", 1)) + squeeze_axis = int(attrs.get("squeeze_axis", 0)) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -1984,11 +1843,9 @@ def convert_expand_dims(node, **kwargs): """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axis = int(node["attrs"]["axis"]) + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + axis = int(attrs["axis"]) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2007,12 +1864,10 @@ def convert_squeeze(node, **kwargs): """Map MXNet's squeeze operator attributes to onnx's squeeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - if "axis" in node["attrs"]: - axis = convert_string_to_list(node["attrs"]["axis"]) + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) + + if "axis" in attrs: + axis = convert_string_to_list(attrs["axis"]) else: raise AttributeError("Missing axis attribute: ONNX currently requires axis to " "be specified for squeeze operator") @@ -2035,10 +1890,7 @@ def convert_log(node, **kwargs): """Map MXNet's log operator attributes to onnx's Log operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2057,10 +1909,7 @@ def convert_reciprocal(node, **kwargs): """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2078,10 +1927,7 @@ def convert_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -2102,10 +1948,7 @@ def convert_broadcast_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_b_id = kwargs["index_lookup"][inputs[1][0]] @@ -2126,10 +1969,7 @@ def convert_sqrt(node, **kwargs): """Map MXNet's sqrt operator attributes to onnx's Sqrt operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2147,11 +1987,7 @@ def convert_depthtospace(node, **kwargs): """Map MXNet's depth_to_space operator attributes to onnx's DepthToSpace operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2172,11 +2008,7 @@ def convert_spacetodepth(node, **kwargs): """Map MXNet's space_to_depth operator attributes to onnx's SpaceToDepth operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) input_node_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_node_id].name @@ -2197,10 +2029,7 @@ def convert_square(node, **kwargs): """Map MXNet's square operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, proc_nodes, inputs, _ = get_inputs(node, kwargs) input_node_a_id = kwargs["index_lookup"][inputs[0][0]] input_node_a = proc_nodes[input_node_a_id].name @@ -2233,11 +2062,7 @@ def convert_sum(node, **kwargs): """Map MXNet's sum operator attributes to onnx's ReduceSum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, proc_nodes, inputs, attrs = get_inputs(node, kwargs) mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None