From 9685b2fdf6a223ab01683f5e2c3b671be43ea8e5 Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 31 Aug 2018 09:10:12 -0700 Subject: [PATCH] Add trigonometric operators --- .../contrib/onnx/mx2onnx/_op_translations.py | 120 ++++++++++++++++++ .../onnx/export/onnx_backend_test.py | 6 + 2 files changed, 126 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index af7fedb33cb9..0960776251c4 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -308,6 +308,126 @@ def convert_tanh(node, **kwargs): ) return [node] +@mx_op.register("cos") +def convert_cos(node, **kwargs): + """Map MXNet's cos operator attributes to onnx's Cos operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Cos', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("sin") +def convert_sin(node, **kwargs): + """Map MXNet's sin operator attributes to onnx's Sin operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Sin', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("tan") +def convert_tan(node, **kwargs): + """Map MXNet's tan operator attributes to onnx's tan operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Tan', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("arccos") +def convert_acos(node, **kwargs): + """Map MXNet's acos operator attributes to onnx's acos operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Acos', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("arcsin") +def convert_asin(node, **kwargs): + """Map MXNet's asin operator attributes to onnx's asin operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Asin', + [input_node], + [name], + name=name + ) + return [node] + +@mx_op.register("arctan") +def convert_atan(node, **kwargs): + """Map MXNet's atan operator attributes to onnx's atan operator + and return the created node. + """ + helper, _, _ = import_onnx_modules() + name = node["name"] + inputs = node["inputs"] + input_node_idx = kwargs["index_lookup"][inputs[0][0]] + proc_nodes = kwargs["proc_nodes"] + input_node = proc_nodes[input_node_idx].name + + node = helper.make_node( + 'Atan', + [input_node], + [name], + name=name + ) + return [node] + #Basic neural network functions @mx_op.register("sigmoid") def convert_sigmoid(node, **kwargs): diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 1fbfde5977eb..19bf6993e7cd 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -45,6 +45,12 @@ 'test_abs', 'test_sum', 'test_tanh', + 'test_cos', + 'test_sin', + 'test_tan', + 'test_acos', + 'test_asin', + 'test_atan' 'test_ceil', 'test_floor', 'test_concat',