diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 32c9494c2e9a..11e75d9a6000 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -145,6 +145,19 @@ def get_inputs(node, kwargs): return name, input_nodes, attrs +def create_basic_op_node(op_name, node, kwargs): + """Helper function to create a basic operator + node that doesn't contain op specific attrs""" + name, input_nodes, _ = get_inputs(node, kwargs) + + node = onnx.helper.make_node( + op_name, + input_nodes, + [name], + name=name + ) + return [node] + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. @@ -280,105 +293,49 @@ def convert_tanh(node, **kwargs): """Map MXNet's tanh operator attributes to onnx's Tanh operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Tanh', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Tanh', node, kwargs) @mx_op.register("cos") def convert_cos(node, **kwargs): """Map MXNet's cos operator attributes to onnx's Cos operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Cos', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Cos', node, kwargs) @mx_op.register("sin") def convert_sin(node, **kwargs): """Map MXNet's sin operator attributes to onnx's Sin operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Sin', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Sin', node, kwargs) @mx_op.register("tan") def convert_tan(node, **kwargs): """Map MXNet's tan operator attributes to onnx's tan operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Tan', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Tan', node, kwargs) @mx_op.register("arccos") def convert_acos(node, **kwargs): """Map MXNet's acos operator attributes to onnx's acos operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Acos', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Acos', node, kwargs) @mx_op.register("arcsin") def convert_asin(node, **kwargs): """Map MXNet's asin operator attributes to onnx's asin operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Asin', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Asin', node, kwargs) @mx_op.register("arctan") def convert_atan(node, **kwargs): """Map MXNet's atan operator attributes to onnx's atan operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Atan', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Atan', node, kwargs) #Basic neural network functions @mx_op.register("sigmoid") @@ -386,31 +343,14 @@ def convert_sigmoid(node, **kwargs): """Map MXNet's sigmoid operator attributes to onnx's Sigmoid operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Sigmoid', - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Sigmoid', node, kwargs) @mx_op.register("relu") def convert_relu(node, **kwargs): """Map MXNet's relu operator attributes to onnx's Relu operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Relu', - input_nodes, - [name], - name=name - ) - - return [node] + return create_basic_op_node('Relu', node, kwargs) @mx_op.register("Activation") def convert_activation(node, **kwargs): @@ -674,31 +614,14 @@ def convert_exp(node, **kwargs): """Map MXNet's exp operator attributes to onnx's Exp operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Exp", - input_nodes, - [name], - name=name, - ) - return [node] - + return create_basic_op_node('Exp', node, kwargs) @mx_op.register("_copy") def convert_identity(node, **kwargs): """Map MXNet's _copy operator attributes to onnx's Identity operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Identity", - input_nodes, - [name], - name=name, - ) - return [node] + return create_basic_op_node('Identity', node, kwargs) @mx_op.register("LeakyReLU") @@ -893,15 +816,7 @@ def convert_flatten(node, **kwargs): """Map MXNet's Flatten operator attributes to onnx's Flatten operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - flatten_node = onnx.helper.make_node( - "Flatten", - input_nodes, - [name], - name=name - ) - return [flatten_node] + return create_basic_op_node('Flatten', node, kwargs) @mx_op.register("clip") def convert_clip(node, **kwargs): @@ -1077,16 +992,7 @@ def convert_maximum(node, **kwargs): """Map MXNet's _maximum operator attributes to onnx's Max operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Max', - inputs=input_nodes, - outputs=[name], - name=name, - ) - - return [node] + return create_basic_op_node('Max', node, kwargs) @mx_op.register("_minimum") @@ -1094,17 +1000,7 @@ def convert_minimum(node, **kwargs): """Map MXNet's _minimum operator attributes to onnx's Min operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - 'Min', - inputs=input_nodes, - outputs=[name], - name=name, - ) - - return [node] - + return create_basic_op_node('Min', node, kwargs) @mx_op.register("min") def convert_min(node, **kwargs): @@ -1252,16 +1148,7 @@ def convert_elementwise_add(node, **kwargs): """Map MXNet's elemwise_add operator attributes to onnx's Add operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - add_node = onnx.helper.make_node( - "Add", - input_nodes, - [name], - name=name, - ) - - return [add_node] + return create_basic_op_node('Add', node, kwargs) @mx_op.register("broadcast_add") @@ -1269,16 +1156,7 @@ def covert_broadcast_add(node, **kwargs): """Map MXNet's broadcast_add operator attributes to onnx's Add operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - add_node = onnx.helper.make_node( - "Add", - input_nodes, - [name], - name=name, - ) - - return [add_node] + return create_basic_op_node('Add', node, kwargs) @mx_op.register("elemwise_sub") @@ -1286,149 +1164,63 @@ def convert_elementwise_sub(node, **kwargs): """Map MXNet's elemwise_sub operator attributes to onnx's Sub operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - sub_node = onnx.helper.make_node( - "Sub", - input_nodes, - [name], - name=name, - ) - - return [sub_node] + return create_basic_op_node('Sub', node, kwargs) @mx_op.register("broadcast_sub") def covert_broadcast_sub(node, **kwargs): """Map MXNet's broadcast_sub operator attributes to onnx's Sub operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - sub_node = onnx.helper.make_node( - "Sub", - input_nodes, - [name], - name=name, - ) - - return [sub_node] - + return create_basic_op_node('Sub', node, kwargs) @mx_op.register("elemwise_mul") def convert_elemwise_mul(node, **kwargs): """Map MXNet's elemwise_mul operator attributes to onnx's Mul operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - mul_node = onnx.helper.make_node( - "Mul", - input_nodes, - [name], - name=name, - ) - - return [mul_node] + return create_basic_op_node('Mul', node, kwargs) @mx_op.register("broadcast_mul") def convert_broadcast_mul(node, **kwargs): """Map MXNet's broadcast_mul operator attributes to onnx's Mul operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - mul_node = onnx.helper.make_node( - "Mul", - input_nodes, - [name], - name=name - ) - - return [mul_node] - + return create_basic_op_node('Mul', node, kwargs) @mx_op.register("elemwise_div") def convert_elemwise_div(node, **kwargs): """Map MXNet's elemwise_div operator attributes to onnx's Div operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - div_node = onnx.helper.make_node( - "Div", - input_nodes, - [name], - name=name - ) - - return [div_node] - + return create_basic_op_node('Div', node, kwargs) @mx_op.register("broadcast_div") def convert_broadcast_div(node, **kwargs): """Map MXNet's broadcast_div operator attributes to onnx's Div operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - div_node = onnx.helper.make_node( - "Div", - input_nodes, - [name], - name=name - ) - - return [div_node] - + return create_basic_op_node('Div', node, kwargs) @mx_op.register("negative") def convert_negative(node, **kwargs): """Map MXNet's negative operator attributes to onnx's Neg operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - neg_node = onnx.helper.make_node( - "Neg", - input_nodes, - [name], - name=name, - ) - - return [neg_node] - + return create_basic_op_node('Neg', node, kwargs) @mx_op.register("abs") def convert_abs(node, **kwargs): """Map MXNet's abs operator attributes to onnx's Abs operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - abs_node = onnx.helper.make_node( - "Abs", - input_nodes, - [name], - name=name - ) - - return [abs_node] - + return create_basic_op_node('Abs', node, kwargs) @mx_op.register("add_n") def convert_addn(node, **kwargs): """Map MXNet's add_n operator attributes to onnx's Sum operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - sum_node = onnx.helper.make_node( - "Sum", - input_nodes, - [name], - name=name - ) - return [sum_node] + return create_basic_op_node('Sum', node, kwargs) # Rounding @mx_op.register("ceil") @@ -1436,30 +1228,14 @@ def convert_ceil(node, **kwargs): """Map MXNet's ceil operator attributes to onnx's Ceil operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Ceil", - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Ceil', node, kwargs) @mx_op.register("floor") def convert_floor(node, **kwargs): """Map MXNet's floor operator attributes to onnx's Floor operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Floor", - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Floor', node, kwargs) # Changing shape and type. @mx_op.register("Reshape") @@ -1641,76 +1417,35 @@ def convert_log(node, **kwargs): """Map MXNet's log operator attributes to onnx's Log operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Log", - input_nodes, - [name], - name=name, - ) - return [node] - + return create_basic_op_node('Log', node, kwargs) @mx_op.register("reciprocal") def convert_reciprocal(node, **kwargs): """Map MXNet's reciprocal operator attributes to onnx's Reciprocal operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Reciprocal", - input_nodes, - [name], - name=name, - ) - return [node] + return create_basic_op_node('Reciprocal', node, kwargs) @mx_op.register("_power") def convert_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Pow", - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Pow', node, kwargs) @mx_op.register("broadcast_power") def convert_broadcast_power(node, **kwargs): """Map MXNet's _power operator attributes to onnx's Pow operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Pow", - input_nodes, - [name], - name=name - ) - return [node] + return create_basic_op_node('Pow', node, kwargs) @mx_op.register("sqrt") def convert_sqrt(node, **kwargs): """Map MXNet's sqrt operator attributes to onnx's Sqrt operator and return the created node. """ - name, input_nodes, _ = get_inputs(node, kwargs) - - node = onnx.helper.make_node( - "Sqrt", - input_nodes, - [name], - name=name, - ) - return [node] + return create_basic_op_node('Sqrt', node, kwargs) @mx_op.register("depth_to_space") def convert_depthtospace(node, **kwargs):