Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX fix log_softmax for opset 12 (#20209)
Browse files Browse the repository at this point in the history
* fix log_softmax for opset 12

* Update _op_translations_opset12.py

* Update _op_translations_opset13.py
  • Loading branch information
Zha0q1 authored Apr 26, 2021
1 parent 4988630 commit dca422c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2395,22 +2395,28 @@ def convert_logsoftmax(node, **kwargs):
"""Map MXNet's log_softmax operator attributes to onnx's LogSoftMax operator
and return the created node.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

# Converting to int
axis = int(attrs.get("axis", -1))
temp = attrs.get("temperature", 'None')
temp = attrs.get('temperature', 'None')
use_length = attrs.get('use_length', 'False')

if temp != 'None':
raise AttributeError("LogSoftMax: ONNX supports only temperature=None")
raise AttributeError('LogSoftMax currently does not support temperature!=None')

node = onnx.helper.make_node(
'LogSoftmax',
input_nodes,
[name],
axis=axis,
name=name
)
return [node]
if use_length in ['1', 'True']:
raise AttributeError('LogSoftMax currently does not support use_length==True')

nodes = [
make_node('Exp', [input_nodes[0]], [name+'_exp']),
make_node('ReduceSum', [name+'_exp'], [name+'_rsum'], axes=[axis], keepdims=1),
make_node('Div', [name+'_exp', name+'_rsum'], [name+'_div']),
make_node('Log', [name+'_div'], [name])
]

return nodes

@mx_op.register("norm")
def convert_norm(node, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1615,3 +1615,32 @@ def convert_norm(node, **kwargs):
make_node('Reshape', [name+'_norm', name+'_1'], [name])
]
return nodes


@mx_op.register("log_softmax", OPSET_VERSION)
def convert_logsoftmax(node, **kwargs):
"""Map MXNet's log_softmax operator attributes to onnx's LogSoftMax operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

# Converting to int
axis = int(attrs.get("axis", -1))
temp = attrs.get('temperature', 'None')
use_length = attrs.get('use_length', 'False')

if temp != 'None':
raise AttributeError('LogSoftMax currently does not support temperature!=None')

if use_length in ['1', 'True']:
raise AttributeError('LogSoftMax currently does not support use_length==True')

node = onnx.helper.make_node(
'LogSoftmax',
input_nodes,
[name],
axis=axis,
name=name
)

return [node]

0 comments on commit dca422c

Please sign in to comment.