Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-878] Add trigonometric operators to onnx #12424

Merged
merged 1 commit into from
Sep 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,126 @@ def convert_tanh(node, **kwargs):
)
return [node]

@mx_op.register("cos")
def convert_cos(node, **kwargs):
"""Map MXNet's cos operator attributes to onnx's Cos operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Cos',
[input_node],
[name],
name=name
)
return [node]

@mx_op.register("sin")
def convert_sin(node, **kwargs):
"""Map MXNet's sin operator attributes to onnx's Sin operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Sin',
[input_node],
[name],
name=name
)
return [node]

@mx_op.register("tan")
def convert_tan(node, **kwargs):
"""Map MXNet's tan operator attributes to onnx's tan operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Tan',
[input_node],
[name],
name=name
)
return [node]

@mx_op.register("arccos")
def convert_acos(node, **kwargs):
"""Map MXNet's acos operator attributes to onnx's acos operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Acos',
[input_node],
[name],
name=name
)
return [node]

@mx_op.register("arcsin")
def convert_asin(node, **kwargs):
"""Map MXNet's asin operator attributes to onnx's asin operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Asin',
[input_node],
[name],
name=name
)
return [node]

@mx_op.register("arctan")
def convert_atan(node, **kwargs):
"""Map MXNet's atan operator attributes to onnx's atan operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
input_node_idx = kwargs["index_lookup"][inputs[0][0]]
proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_idx].name

node = helper.make_node(
'Atan',
[input_node],
[name],
name=name
)
return [node]

#Basic neural network functions
@mx_op.register("sigmoid")
def convert_sigmoid(node, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
'test_abs',
'test_sum',
'test_tanh',
'test_cos',
'test_sin',
'test_tan',
'test_acos',
'test_asin',
'test_atan'
'test_ceil',
'test_floor',
'test_concat',
Expand Down