From 0027cc057e9a292648261e89d88573f400e1626c Mon Sep 17 00:00:00 2001 From: vandanavk Date: Thu, 18 Oct 2018 14:09:44 -0700 Subject: [PATCH] ONNX export: Cleanup input retrieval --- .../contrib/onnx/mx2onnx/_op_translations.py | 874 +++++------------- 1 file changed, 214 insertions(+), 660 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..d30d3a2345f8 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -74,7 +74,7 @@ def import_onnx_modules(): def parse_helper(attrs, attrs_name, alt_value=None): """Helper function to parse operator attributes in required format.""" tuple_re = re.compile('\([0-9L|,| ]+\)') - if attrs is None: + if attrs is None or attrs == {}: return alt_value attrs_str = None if attrs.get(attrs_name) is None else str(attrs.get(attrs_name)) if attrs_str is None: @@ -135,12 +135,27 @@ def get_boolean_attribute_value(attrs, attr_name): """ return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 +def get_inputs(node, kwargs): + """Helper function to get inputs""" + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + index_lookup = kwargs["index_lookup"] + inputs = node["inputs"] + attrs = node.get("attrs", {}) + + input_nodes = [] + for ip in inputs: + input_node_id = index_lookup[ip[0]] + input_nodes.append(proc_nodes[input_node_id].name) + + return onnx, name, input_nodes, attrs + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, _, _ = get_inputs(node, kwargs) if kwargs["is_input"] is False: weights = kwargs["weights"] @@ -172,20 +187,7 @@ def convert_convolution(node, **kwargs): """Map MXNet's convolution operator attributes to onnx's Conv operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - - num_inputs = len(inputs) - - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[kwargs["index_lookup"][inputs[0][0]]].name - weights_node = proc_nodes[kwargs["index_lookup"][inputs[1][0]]].name - - if num_inputs > 2: - bias_node = proc_nodes[kwargs["index_lookup"][inputs[2][0]]].name - - attrs = node.get("attrs") + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) kernel_dims = list(parse_helper(attrs, "kernel")) stride_dims = list(parse_helper(attrs, "stride", [1, 1])) @@ -195,10 +197,6 @@ def convert_convolution(node, **kwargs): pad_dims = pad_dims + pad_dims - input_nodes = [input_node, weights_node] - if num_inputs > 2: - input_nodes.append(bias_node) - conv_node = onnx.helper.make_node( "Conv", inputs=input_nodes, @@ -219,32 +217,15 @@ def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + initializer = kwargs["initializer"] no_bias = get_boolean_attribute_value(attrs, "no_bias") - input_node_id = kwargs["index_lookup"][inputs[0][0]] - weight_node_id = kwargs["index_lookup"][inputs[1][0]] - - proc_nodes = kwargs["proc_nodes"] - - input_node = proc_nodes[input_node_id] - input_name = input_node.name - - weights_node = proc_nodes[weight_node_id] - weights_name = weights_node.name - fcnode = [] - if no_bias == 0: - bias_node_id = kwargs["index_lookup"][inputs[2][0]] - bias_node = proc_nodes[bias_node_id] - bias_name = bias_node.name - else: + if no_bias: data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] bias_name = "bias" + str(kwargs["idx"]) tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, (1,)) @@ -257,11 +238,12 @@ def convert_fully_connected(node, **kwargs): raw=False, ) ) + input_nodes.append(bias_name) fcnode.append(tensor_node) node = onnx.helper.make_node( "Gemm", - [input_name, weights_name, bias_name], # input (A, B, C) - C can be in place + input_nodes, # input (A, B, C) - C can be in place [name], # output alpha=1.0, beta=1.0, @@ -280,37 +262,14 @@ def convert_batchnorm(node, **kwargs): """Map MXNet's BatchNorm operator attributes to onnx's BatchNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node["attrs"] - momentum = float(node.get("attrs", {}).get("momentum", 0.9)) + momentum = float(attrs.get("momentum", 0.9)) eps = float(attrs.get("eps", 0.001)) - data_idx = kwargs["index_lookup"][inputs[0][0]] - gamma_idx = kwargs["index_lookup"][inputs[1][0]] - beta_idx = kwargs["index_lookup"][inputs[2][0]] - moving_mean_idx = kwargs["index_lookup"][inputs[3][0]] - moving_var_idx = kwargs["index_lookup"][inputs[4][0]] - - data_node = proc_nodes[data_idx].name - gamma_node = proc_nodes[gamma_idx].name - beta_node = proc_nodes[beta_idx].name - - mov_mean_node = proc_nodes[moving_mean_idx] - mov_mean_node = mov_mean_node.name - mov_var_node = proc_nodes[moving_var_idx].name - bn_node = onnx.helper.make_node( "BatchNormalization", - [data_node, - gamma_node, # scale - beta_node, # bias - mov_mean_node, - mov_var_node - ], + input_nodes, [name], name=name, epsilon=eps, @@ -327,16 +286,11 @@ def convert_tanh(node, **kwargs): """Map MXNet's tanh operator attributes to onnx's Tanh operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Tanh', - [input_node], + input_nodes, [name], name=name ) @@ -347,16 +301,11 @@ def convert_cos(node, **kwargs): """Map MXNet's cos operator attributes to onnx's Cos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Cos', - [input_node], + input_nodes, [name], name=name ) @@ -367,16 +316,11 @@ def convert_sin(node, **kwargs): """Map MXNet's sin operator attributes to onnx's Sin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Sin', - [input_node], + input_nodes, [name], name=name ) @@ -387,16 +331,11 @@ def convert_tan(node, **kwargs): """Map MXNet's tan operator attributes to onnx's tan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Tan', - [input_node], + input_nodes, [name], name=name ) @@ -407,16 +346,11 @@ def convert_acos(node, **kwargs): """Map MXNet's acos operator attributes to onnx's acos operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Acos', - [input_node], + input_nodes, [name], name=name ) @@ -427,16 +361,11 @@ def convert_asin(node, **kwargs): """Map MXNet's asin operator attributes to onnx's asin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Asin', - [input_node], + input_nodes, [name], name=name ) @@ -447,16 +376,11 @@ def convert_atan(node, **kwargs): """Map MXNet's atan operator attributes to onnx's atan operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Atan', - [input_node], + input_nodes, [name], name=name ) @@ -468,16 +392,11 @@ def convert_sigmoid(node, **kwargs): """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Sigmoid', - [input_node], + input_nodes, [name], name=name ) @@ -488,16 +407,11 @@ def convert_relu(node, **kwargs): """Map MXNet's relu operator attributes to onnx's Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Relu', - [input_node], + input_nodes, [name], name=name ) @@ -509,17 +423,10 @@ def convert_activation(node, **kwargs): """Map MXNet's Activation operator attributes to onnx's Tanh/Relu operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] act_type = attrs["act_type"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_idx].output[0] - # Creating a dictionary here, but if this titlecase pattern # mxnet_name.title() act_types = { @@ -534,7 +441,7 @@ def convert_activation(node, **kwargs): if act_name: node = onnx.helper.make_node( act_name, - [input_node], + input_nodes, [name], name=name ) @@ -551,13 +458,7 @@ def convert_pad(node, **kwargs): """Map MXNet's pad operator attributes to onnx's Pad operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - attrs = node["attrs"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_idx].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) mxnet_pad_width = convert_string_to_list(attrs.get("pad_width")) onnx_pad_width = transform_padding(mxnet_pad_width) @@ -569,7 +470,7 @@ def convert_pad(node, **kwargs): if "constant_value" in attrs else 0.0 node = onnx.helper.make_node( 'Pad', - inputs=[input_node], + inputs=input_nodes, outputs=[name], mode='constant', value=pad_value, @@ -579,7 +480,7 @@ def convert_pad(node, **kwargs): else: node = onnx.helper.make_node( 'Pad', - inputs=[input_node], + inputs=input_nodes, outputs=[name], mode=pad_mode, pads=onnx_pad_width, @@ -608,17 +509,8 @@ def convert_dot(node, **kwargs): """Map MXNet's dot operator attributes to onnx's MatMul and Transpose operators based on the values set for transpose_a, transpose_b attributes.""" - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] - - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] - input_node_b = proc_nodes[input_b_idx].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node.get('attrs', {}) trans_a_node = None trans_b_node = None @@ -626,14 +518,12 @@ def convert_dot(node, **kwargs): trans_b = get_boolean_attribute_value(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) - create_helper_trans_node(op_name, input_node_a, 'a') - create_helper_trans_node(op_name, input_node_b, 'b') if trans_a: - trans_a_node = create_helper_trans_node(op_name, input_node_a, 'a') + trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a') input_node_a = op_name+"_a" if trans_b: - trans_b_node = create_helper_trans_node(op_name, input_node_b, 'b') + trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b') input_node_b = op_name+"_b" matmul_node = onnx.helper.make_node( @@ -660,33 +550,19 @@ def convert_linalg_gemm2(node, **kwargs): transpose_a, transpose_b attributes. Return multiple nodes created. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - name = node["name"] - - input_a_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node_a = proc_nodes[input_a_idx].name - input_b_idx = kwargs["index_lookup"][node_inputs[1][0]] - input_node_b = proc_nodes[input_b_idx].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) # Getting the attributes and assigning default values. - if "attrs" in node: - attrs = node["attrs"] - alpha = float(attrs["alpha"]) - trans_a = int(attrs["transpose_a"]) - trans_b = int(attrs["transpose_b"]) - else: - alpha = 1.0 - trans_a = 0 - trans_b = 0 + alpha = float(attrs.get("alpha", 1.0)) + trans_a = get_boolean_attribute_value(attrs, "transpose_a") + trans_b = get_boolean_attribute_value(attrs, "transpose_b") op_name = "transpose" + str(kwargs["idx"]) if alpha == 1.0 and trans_a == 0 and trans_b == 0: matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[input_node_a, input_node_b], + inputs=input_nodes, outputs=[name], name=name ) @@ -696,14 +572,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name = op_name+"_a" trans_a_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_a], + inputs=[input_nodes[0]], outputs=[op_name+"_a"], name=node_name ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[node_name, input_node_b], + inputs=[node_name, input_nodes[1]], outputs=[name], name=name ) @@ -713,14 +589,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name = op_name + "_b" trans_b_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_b], + inputs=[input_nodes[1]], outputs=[op_name+"_b"], name=node_name ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[input_node_a, node_name], + inputs=[input_nodes[0], node_name], outputs=[name], name=name ) @@ -730,7 +606,7 @@ def convert_linalg_gemm2(node, **kwargs): node_name_a = op_name+"_a" trans_a_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_a], + inputs=[input_nodes[0]], outputs=[op_name+"_a"], name=node_name_a ) @@ -738,14 +614,14 @@ def convert_linalg_gemm2(node, **kwargs): node_name_b = op_name + "_b" trans_b_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node_b], + inputs=[input_nodes[1]], outputs=[op_name+"_b"], name=node_name_b ) matmul_node = onnx.helper.make_node( 'MatMul', - inputs=[node_name_a, node_name_b], + inputs=input_nodes, outputs=[name], name=name ) @@ -759,19 +635,13 @@ def convert_pooling(node, **kwargs): MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + kernel = eval(attrs["kernel"]) pool_type = attrs["pool_type"] stride = eval(attrs["stride"]) if attrs.get("stride") else None global_pool = get_boolean_attribute_value(attrs, "global_pool") - node_inputs = node["inputs"] - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx] - name = node["name"] - pooling_convention = attrs.get('pooling_convention', 'valid') if pooling_convention == 'full': @@ -789,14 +659,14 @@ def convert_pooling(node, **kwargs): if global_pool: node = onnx.helper.make_node( global_pool_types[pool_type], - [input_node.name], # input + input_nodes, # input [name], name=name ) else: node = onnx.helper.make_node( pool_types[pool_type], - [input_node.name], # input + input_nodes, # input [name], kernel_shape=kernel, pads=pad_dims, @@ -812,17 +682,11 @@ def convert_exp(node, **kwargs): """Map MXNet's exp operator attributes to onnx's Exp operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Exp", - [input_node], + input_nodes, [name], name=name, ) @@ -834,17 +698,11 @@ def convert_identity(node, **kwargs): """Map MXNet's _copy operator attributes to onnx's Identity operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Identity", - [input_node], + input_nodes, [name], name=name, ) @@ -856,13 +714,7 @@ def convert_leakyrelu(node, **kwargs): """Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators based on the input node's attributes and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) act_type = attrs.get("act_type", "leaky") alpha = float(attrs.get("slope", 0.25)) @@ -870,25 +722,16 @@ def convert_leakyrelu(node, **kwargs): act_name = {"elu": "Elu", "leaky": "LeakyRelu", "prelu": "PRelu", "selu": "Selu"} - if act_type == "prelu": - alpha_node_index = kwargs["index_lookup"][inputs[1][0]] - alpha_node_name = proc_nodes[alpha_node_index].name - + if act_type == "prelu" or act_type == "selu": node = onnx.helper.make_node( act_name[act_type], - inputs=[input_node, alpha_node_name], - outputs=[name], - name=name) - elif act_type == "selu": - node = onnx.helper.make_node( - act_name[act_type], - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name) else: node = onnx.helper.make_node( act_name[act_type], - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name, alpha=alpha) @@ -901,18 +744,13 @@ def convert_softmax(node, **kwargs): """Map MXNet's softmax operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] - input_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - name = node["name"] - axis = int(node.get("attrs", {}).get("axis", -1)) + axis = int(attrs.get("axis", -1)) softmax_node = onnx.helper.make_node( "Softmax", - [input_node.name], + input_nodes, [name], axis=axis, name=name @@ -928,12 +766,10 @@ def convert_softmax_output(node, **kwargs): """Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator and return the created node. """ - onnx = import_onnx_modules() - inputs = node["inputs"] - input1_idx = kwargs["index_lookup"][inputs[0][0]] - proc_nodes = kwargs["proc_nodes"] - input1 = proc_nodes[input1_idx] - name = node["name"] + onnx, name, _, _ = get_inputs(node, kwargs) + + input1_idx = kwargs["index_lookup"][node["inputs"][0][0]] + input1 = kwargs["proc_nodes"][input1_idx] softmax_node = onnx.helper.make_node( "Softmax", @@ -951,15 +787,12 @@ def convert_concat(node, **kwargs): """Map MXNet's Concat operator attributes to onnx's Concat operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - inputs = node["inputs"] - proc_nodes = kwargs["proc_nodes"] - input_names = [proc_nodes[kwargs["index_lookup"][i[0]]].name for i in inputs] - axis = int(node.get("attrs", {}).get("dim", 1)) + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = int(attrs.get("dim", 1)) concat_node = onnx.helper.make_node( "Concat", - input_names, + input_nodes, [name], axis=axis, name=name @@ -972,18 +805,15 @@ def convert_transpose(node, **kwargs): """Map MXNet's transpose operator attributes to onnx's Transpose operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name - axes = node.get("attrs", {}).get("axes", ()) + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + + axes = attrs.get("axes", ()) if axes: axes = tuple(map(int, re.findall(r'\d+', axes))) transpose_node = onnx.helper.make_node( "Transpose", - [input_node], + input_nodes, [name], perm=axes, name=name @@ -991,7 +821,7 @@ def convert_transpose(node, **kwargs): else: transpose_node = onnx.helper.make_node( "Transpose", - [input_node], + input_nodes, [name], name=name ) @@ -1004,13 +834,8 @@ def convert_lrn(node, **kwargs): """Map MXNet's LRN operator attributes to onnx's LRN operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - attrs = node["attrs"] alpha = float(attrs["alpha"]) if "alpha" in attrs else 0.0001 beta = float(attrs["beta"]) if "beta" in attrs else 0.75 bias = float(attrs["knorm"]) if "knorm" in attrs else 1.0 @@ -1018,7 +843,7 @@ def convert_lrn(node, **kwargs): lrn_node = onnx.helper.make_node( "LRN", - inputs=[input_node], + inputs=input_nodes, outputs=[name], name=name, alpha=alpha, @@ -1035,11 +860,8 @@ def convert_l2normalization(node, **kwargs): """Map MXNet's L2Normalization operator attributes to onnx's LpNormalization operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + mode = attrs.get("mode", "instance") if mode != "channel": @@ -1047,7 +869,7 @@ def convert_l2normalization(node, **kwargs): l2norm_node = onnx.helper.make_node( "LpNormalization", - [input_name], + input_nodes, [name], axis=1, # channel only name=name @@ -1060,16 +882,13 @@ def convert_dropout(node, **kwargs): """Map MXNet's Dropout operator attributes to onnx's Dropout operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_id = kwargs["index_lookup"][node["inputs"][0][0]] - input_name = kwargs["proc_nodes"][input_id].name - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + probability = float(attrs["p"]) dropout_node = onnx.helper.make_node( "Dropout", - [input_name], + input_nodes, [name], ratio=probability, name=name @@ -1082,15 +901,11 @@ def convert_flatten(node, **kwargs): """Map MXNet's Flatten operator attributes to onnx's Flatten operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name # .output[0] + onnx, name, input_nodes, _ = get_inputs(node, kwargs) flatten_node = onnx.helper.make_node( "Flatten", - [input_node], + input_nodes, [name], name=name ) @@ -1101,18 +916,14 @@ def convert_clip(node, **kwargs): """Map MXNet's Clip operator attributes to onnx's Clip operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - input_idx = kwargs["index_lookup"][node["inputs"][0][0]] - proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_idx].name - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + a_min = np.float(attrs.get('a_min', -np.inf)) a_max = np.float(attrs.get('a_max', np.inf)) clip_node = onnx.helper.make_node( "Clip", - [input_node], + input_nodes, [name], name=name, min=a_min, @@ -1123,21 +934,16 @@ def convert_clip(node, **kwargs): def scalar_op_helper(node, op_name, **kwargs): """Helper function for scalar arithmetic operations""" - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - scalar_value = [float(node.get("attrs", {}).get("scalar", 1))] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - input_name_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_name_id].name + scalar_value = [float(attrs.get("scalar", 1))] initializer = kwargs["initializer"] flag = True # If the input value is in initializer, just multiply with scalar input # and create a new initializer for i in initializer: - if i.name == input_node: + if i.name == input_nodes[0]: if op_name == 'Mul': new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0] elif op_name == 'Sub': @@ -1170,7 +976,7 @@ def scalar_op_helper(node, op_name, **kwargs): mul_node = onnx.helper.make_node( op_name, - [input_node, scalar_op_name], + [input_nodes[0], scalar_op_name], [name], name=name ) @@ -1180,7 +986,7 @@ def scalar_op_helper(node, op_name, **kwargs): data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype] dims = np.shape(new_initializer) - new_a_node = input_node + str(kwargs["idx"]) + new_a_node = input_nodes[0] + str(kwargs["idx"]) tensor_node = onnx.helper.make_tensor_value_info(new_a_node, data_type, dims) initializer.append( @@ -1239,21 +1045,14 @@ def convert_argmax(node, **kwargs): """Map MXNet's argmax operator attributes to onnx's ArgMax operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 node = onnx.helper.make_node( 'ArgMax', - inputs=[input_node], + inputs=input_nodes, axis=axis, keepdims=keepdims, outputs=[name], @@ -1266,21 +1065,14 @@ def convert_argmin(node, **kwargs): """Map MXNet's argmin operator attributes to onnx's ArgMin operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_idx = kwargs["index_lookup"][node_inputs[0][0]] - input_node = proc_nodes[input_node_idx].name - name = node["name"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) axis = int(attrs.get("axis")) keepdims = int(attrs.get("keepdims")) if "keepdims" in attrs else 1 node = onnx.helper.make_node( 'ArgMin', - inputs=[input_node], + inputs=input_nodes, axis=axis, keepdims=keepdims, outputs=[name], @@ -1293,20 +1085,11 @@ def convert_maximum(node, **kwargs): """Map MXNet's _maximum operator attributes to onnx's Max operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_list = [] - for node_input in node_inputs: - node_id = kwargs["index_lookup"][node_input[0]] - input_node_list.append(proc_nodes[node_id].name) - - name = node["name"] + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Max', - inputs=input_node_list, + inputs=input_nodes, outputs=[name], name=name, ) @@ -1319,20 +1102,11 @@ def convert_minimum(node, **kwargs): """Map MXNet's _minimum operator attributes to onnx's Min operator and return the created node. """ - onnx = import_onnx_modules() - proc_nodes = kwargs["proc_nodes"] - node_inputs = node["inputs"] - - input_node_list = [] - for node_input in node_inputs: - node_id = kwargs["index_lookup"][node_input[0]] - input_node_list.append(proc_nodes[node_id].name) - - name = node["name"] + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( 'Min', - inputs=input_node_list, + inputs=input_nodes, outputs=[name], name=name, ) @@ -1345,23 +1119,17 @@ def convert_min(node, **kwargs): """Map MXNet's min operator attributes to onnx's ReduceMin operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMin', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1372,7 +1140,7 @@ def convert_min(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMin', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1386,23 +1154,17 @@ def convert_max(node, **kwargs): """Map MXNet's max operator attributes to onnx's ReduceMax operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMax', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1413,7 +1175,7 @@ def convert_max(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMax', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1427,23 +1189,17 @@ def convert_mean(node, **kwargs): """Map MXNet's mean operator attributes to onnx's ReduceMean operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceMean', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1454,7 +1210,7 @@ def convert_mean(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceMean', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1468,23 +1224,17 @@ def convert_prod(node, **kwargs): """Map MXNet's prod operator attributes to onnx's ReduceProd operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - mx_axis = node.get("attrs", {}).get("axis", None) + mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None - keepdims = int(node.get("attrs", {}).get("keepdims", 0)) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + keepdims = int(attrs.get("keepdims", 0)) if axes is not None: node = onnx.helper.make_node( 'ReduceProd', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -1495,7 +1245,7 @@ def convert_prod(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceProd', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name @@ -1510,20 +1260,11 @@ def convert_elementwise_add(node, **kwargs): """Map MXNet's elemwise_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) add_node = onnx.helper.make_node( "Add", - [input_node_a, input_node_b], + input_nodes, [name], name=name, ) @@ -1536,20 +1277,11 @@ def covert_broadcast_add(node, **kwargs): """Map MXNet's broadcast_add operator attributes to onnx's Add operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) add_node = onnx.helper.make_node( "Add", - [input_node_a, input_node_b], + input_nodes, [name], name=name, ) @@ -1562,20 +1294,11 @@ def convert_elementwise_sub(node, **kwargs): """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) sub_node = onnx.helper.make_node( "Sub", - [input_node_a, input_node_b], + input_nodes, [name], name=name, ) @@ -1587,20 +1310,11 @@ def covert_broadcast_sub(node, **kwargs): """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) sub_node = onnx.helper.make_node( "Sub", - [input_node_a, input_node_b], + input_nodes, [name], name=name, ) @@ -1613,20 +1327,11 @@ def convert_elemwise_mul(node, **kwargs): """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) mul_node = onnx.helper.make_node( "Mul", - [input_node_a, input_node_b], + input_nodes, [name], name=name, ) @@ -1638,20 +1343,11 @@ def convert_broadcast_mul(node, **kwargs): """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) mul_node = onnx.helper.make_node( "Mul", - [input_node_a, input_node_b], + input_nodes, [name], name=name ) @@ -1664,20 +1360,11 @@ def convert_elemwise_div(node, **kwargs): """Map MXNet's elemwise_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) div_node = onnx.helper.make_node( "Div", - [input_node_a, input_node_b], + input_nodes, [name], name=name ) @@ -1690,20 +1377,11 @@ def convert_broadcast_div(node, **kwargs): """Map MXNet's broadcast_div operator attributes to onnx's Div operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) div_node = onnx.helper.make_node( "Div", - [input_node_a, input_node_b], + input_nodes, [name], name=name ) @@ -1716,18 +1394,11 @@ def convert_negative(node, **kwargs): """Map MXNet's negative operator attributes to onnx's Neg operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) neg_node = onnx.helper.make_node( "Neg", - [input_node], + input_nodes, [name], name=name, ) @@ -1740,18 +1411,11 @@ def convert_abs(node, **kwargs): """Map MXNet's abs operator attributes to onnx's Abs operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) abs_node = onnx.helper.make_node( "Abs", - [input_node], + input_nodes, [name], name=name ) @@ -1764,18 +1428,11 @@ def convert_addn(node, **kwargs): """Map MXNet's add_n operator attributes to onnx's Sum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_list = [] - for input_val in inputs: - input_list.append(proc_nodes[kwargs["index_lookup"][input_val[0]]].name) + onnx, name, input_nodes, _ = get_inputs(node, kwargs) sum_node = onnx.helper.make_node( "Sum", - input_list, + input_nodes, [name], name=name ) @@ -1787,17 +1444,11 @@ def convert_ceil(node, **kwargs): """Map MXNet's ceil operator attributes to onnx's Ceil operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Ceil", - [input_node], + input_nodes, [name], name=name ) @@ -1808,17 +1459,11 @@ def convert_floor(node, **kwargs): """Map MXNet's floor operator attributes to onnx's Floor operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Floor", - [input_node], + input_nodes, [name], name=name ) @@ -1831,11 +1476,7 @@ def convert_reshape(node, **kwargs): Converts output shape attribute to output shape tensor and return multiple created nodes. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) output_shape_list = convert_string_to_list(attrs["shape"]) @@ -1857,8 +1498,7 @@ def convert_reshape(node, **kwargs): ) ) - input_node_idx = kwargs["index_lookup"][inputs[0][0]] - input_node_name = proc_nodes[input_node_idx].name + input_nodes.append(output_shape_name) not_supported_shape = [-2, -3, -4] @@ -1868,7 +1508,7 @@ def convert_reshape(node, **kwargs): reshape_node = onnx.helper.make_node( "Reshape", - [input_node_name, output_shape_name], + input_nodes, [name], name=name ) @@ -1880,11 +1520,9 @@ def convert_cast(node, **kwargs): """Map MXNet's Cast operator attributes to onnx's Cast operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - dtype = node["attrs"]["dtype"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + + dtype = attrs["dtype"] # dtype can be mapped only with types from TensorProto # float32 is mapped to float and float64 to double in onnx @@ -1894,12 +1532,9 @@ def convert_cast(node, **kwargs): elif dtype == 'float64': dtype = 'double' - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - node = onnx.helper.make_node( "Cast", - [input_node], + input_nodes, [name], to=getattr(onnx.TensorProto, dtype.upper()), name=name, @@ -1912,23 +1547,18 @@ def convert_slice_axis(node, **kwargs): """Map MXNet's slice_axis operator attributes to onnx's Slice operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axes = int(node["attrs"]["axis"]) - starts = int(node["attrs"]["begin"]) - if node["attrs"]["end"] == 'None': + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + + axes = int(attrs["axis"]) + starts = int(attrs["begin"]) + if attrs["end"] == 'None': raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute") else: - ends = int(node["attrs"]["end"]) - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + ends = int(attrs["end"]) node = onnx.helper.make_node( "Slice", - [input_node], + input_nodes, [name], axes=[axes], starts=[starts], @@ -1944,21 +1574,16 @@ def convert_slice_channel(node, **kwargs): operator based on squeeze_axis attribute and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - num_outputs = int(node.get("attrs", {})["num_outputs"]) - axis = int(node.get("attrs", {}).get("axis", 1)) - squeeze_axis = int(node.get("attrs", {}).get("squeeze_axis", 0)) + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + num_outputs = int(attrs["num_outputs"]) + axis = int(attrs.get("axis", 1)) + squeeze_axis = int(attrs.get("squeeze_axis", 0)) if squeeze_axis == 1 and num_outputs == 1: node = onnx.helper.make_node( "Squeeze", - [input_node], + input_nodes, [name], axes=[axis], name=name, @@ -1967,7 +1592,7 @@ def convert_slice_channel(node, **kwargs): elif squeeze_axis == 0 and num_outputs > 1: node = onnx.helper.make_node( "Split", - [input_node], + input_nodes, [name], axis=axis, split=[num_outputs], @@ -1984,18 +1609,13 @@ def convert_expand_dims(node, **kwargs): """Map MXNet's expand_dims operator attributes to onnx's Unsqueeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - axis = int(node["attrs"]["axis"]) + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + axis = int(attrs["axis"]) node = onnx.helper.make_node( "Unsqueeze", - [input_node], + input_nodes, [name], axes=[axis], name=name, @@ -2007,22 +1627,17 @@ def convert_squeeze(node, **kwargs): """Map MXNet's squeeze operator attributes to onnx's squeeze operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - if "axis" in node["attrs"]: - axis = convert_string_to_list(node["attrs"]["axis"]) + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) + + if "axis" in attrs: + axis = convert_string_to_list(attrs["axis"]) else: raise AttributeError("Missing axis attribute: ONNX currently requires axis to " "be specified for squeeze operator") - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - node = onnx.helper.make_node( "Squeeze", - [input_node], + input_nodes, [name], axes=axis, name=name, @@ -2035,17 +1650,11 @@ def convert_log(node, **kwargs): """Map MXNet's log operator attributes to onnx's Log operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Log", - [input_node], + input_nodes, [name], name=name, ) @@ -2057,17 +1666,11 @@ def convert_reciprocal(node, **kwargs): """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Reciprocal", - [input_node], + input_nodes, [name], name=name, ) @@ -2078,20 +1681,11 @@ def convert_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Pow", - [input_node_a, input_node_b], + input_nodes, [name], name=name ) @@ -2102,20 +1696,11 @@ def convert_broadcast_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_b_id = kwargs["index_lookup"][inputs[1][0]] - - input_node_a = proc_nodes[input_node_a_id].name - input_node_b = proc_nodes[input_node_b_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Pow", - [input_node_a, input_node_b], + input_nodes, [name], name=name ) @@ -2126,17 +1711,11 @@ def convert_sqrt(node, **kwargs): """Map MXNet's sqrt operator attributes to onnx's Sqrt operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) node = onnx.helper.make_node( "Sqrt", - [input_node], + input_nodes, [name], name=name, ) @@ -2147,20 +1726,13 @@ def convert_depthtospace(node, **kwargs): """Map MXNet's depth_to_space operator attributes to onnx's DepthToSpace operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) blksize = int(attrs.get("block_size", 0)) node = onnx.helper.make_node( "DepthToSpace", - [input_node], + input_nodes, [name], blocksize=blksize, name=name, @@ -2172,20 +1744,13 @@ def convert_spacetodepth(node, **kwargs): """Map MXNet's space_to_depth operator attributes to onnx's SpaceToDepth operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] - - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) blksize = int(attrs.get("block_size", 0)) node = onnx.helper.make_node( "SpaceToDepth", - [input_node], + input_nodes, [name], blocksize=blksize, name=name, @@ -2197,13 +1762,7 @@ def convert_square(node, **kwargs): """Map MXNet's square operator attributes to onnx's Pow operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - - input_node_a_id = kwargs["index_lookup"][inputs[0][0]] - input_node_a = proc_nodes[input_node_a_id].name + onnx, name, input_nodes, _ = get_inputs(node, kwargs) initializer = kwargs["initializer"] data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')] @@ -2220,9 +1779,11 @@ def convert_square(node, **kwargs): ) ) + input_nodes.append(power2_name) + node = onnx.helper.make_node( "Pow", - [input_node_a, power2_name], + input_nodes, [name], name=name ) @@ -2233,24 +1794,17 @@ def convert_sum(node, **kwargs): """Map MXNet's sum operator attributes to onnx's ReduceSum operator and return the created node. """ - onnx = import_onnx_modules() - name = node["name"] - proc_nodes = kwargs["proc_nodes"] - inputs = node["inputs"] - attrs = node["attrs"] + onnx, name, input_nodes, attrs = get_inputs(node, kwargs) mx_axis = attrs.get("axis", None) axes = convert_string_to_list(str(mx_axis)) if mx_axis is not None else None keepdims = get_boolean_attribute_value(attrs, "keepdims") - input_node_id = kwargs["index_lookup"][inputs[0][0]] - input_node = proc_nodes[input_node_id].name - if axes: node = onnx.helper.make_node( 'ReduceSum', - inputs=[input_node], + inputs=input_nodes, outputs=[name], axes=axes, keepdims=keepdims, @@ -2259,7 +1813,7 @@ def convert_sum(node, **kwargs): else: node = onnx.helper.make_node( 'ReduceSum', - inputs=[input_node], + inputs=input_nodes, outputs=[name], keepdims=keepdims, name=name