This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] ONNX fix softmax #19691

Merged 11 commits on Dec 20, 2020
82 changes: 74 additions & 8 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -848,20 +848,86 @@ def convert_softmax(node, **kwargs):
"""Map MXNet's softmax operator attributes to onnx's Softmax operator
and return the created node.
"""
+    from onnx.helper import make_node
+    from onnx import TensorProto
    name, input_nodes, attrs = get_inputs(node, kwargs)

    axis = int(attrs.get("axis", -1))
+    temperature = attrs.get("temperature", None)
+    if temperature and float(temperature) != 1.0:
+        raise NotImplementedError("Temperature is not supported for now.")
+    use_length = attrs.get("use_length", None)
+    input_type = kwargs["in_type"]
+    data = input_nodes[0]

-    softmax_node = onnx.helper.make_node(
-        "Softmax",
-        input_nodes,
-        [name],
-        axis=axis,
-        name=name
-    )
+    nodes = [
+        make_node("Exp", [data], [name+"_exp_out"]),
+        make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
+    ]
+    if len(input_nodes) == 1:
+        nodes += [
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+        ]
+        return nodes
+    elif use_length == "True":
+        length = input_nodes[1]

-    return [softmax_node]
+        nodes += [
+            # const nodes
+            create_tensor([axis], name+"_axis", kwargs["initializer"]),
+            create_tensor([], name+"_void", kwargs["initializer"]),
+            create_tensor([0], name+"_0", kwargs["initializer"]),
+            create_tensor([1], name+"_1", kwargs["initializer"]),
+            create_const_scalar_node(name+'_-1_s', np.int64(-1), kwargs),
+            create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
+            create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
+            # cast data type
+            make_node("Cast", [length], [name+"_length"], to=int(TensorProto.INT64)),
+            make_node("Cast", [name+"_0"], [name+"_0_itype"], to=input_type),
+            make_node("Cast", [name+"_1"], [name+"_1_itype"], to=input_type),
+            # softmax output
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            # update axis
+            make_node("Shape", [data], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Add", [name+"_in_dim", name+"_axis"], [name+"_dim+axis"]),
+            make_node("Less", [name+"_axis", name+"_0_s"], [name+"_less0_out"]),
+            make_node("Where", [name+"_less0_out", name+"_dim+axis", name+"_axis"], [name+"_final_axis"]),
+            # data mask
+            make_node("Add", [name+"_final_axis", name+"_1_s"], [name+"_final_axis+1"]),
+            make_node("Slice", [name+"_shape0_out", name+"_final_axis", name+"_final_axis+1"], [name+"_axis_dim"]),
+            make_node("Reshape", [name+"_axis_dim", name+"_void"], [name+"_axis_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_axis_dim_s", name+"_1_s"], [name+"_range0_out"]),
+            # one hot for axis
+            make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range1_out"]),
+            make_node("Equal", [name+"_range1_out", name+"_final_axis"], [name+"_equal_out"]),
+            make_node("Cast", [name+"_equal_out"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            # reshape data mask for less
+            make_node("Sub", [name+"_axis_dim_s", name+"_1_s"], [name+"_sub0_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub0_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add0_out"]),
+            make_node('Reshape', [name+"_range0_out", name+"_add0_out"], [name+"_reshape0_out"]),
+            # reshape length for less
+            make_node("Mul", [name+"_one_hot", name+"_-1_s"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_shape0_out", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_add1_out", name+"_sub1_out"], [name+"_mul2_out"]),
+            make_node("Add", [name+"_mul2_out", name+"_1_s"], [name+"_add2_out"]),
+            make_node('Reshape', [name+"_length", name+"_add2_out"], [name+"_reshape1_out"]),
+            # mask output
+            make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
+            make_node("Cast", [name+"_less_out"], [name+"_mask"], to=input_type),
+            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1),
+            make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
+            make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
+            make_node("Div", [name+"_mul3_out", name+"_where_out"], [name], name=name)
+        ]
+        return nodes

+    else:
+        raise NotImplementedError("use_length must be true when both data and length are passed in.")

# There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
# just softmax for inference - hence the name convert_softmax_output.
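Note: the exporter now builds softmax out of primitive ONNX ops (Exp, ReduceSum, Div, plus Range, Less, Equal and Where for the masked path) instead of emitting a single Softmax node. The plain branch computes softmax(x) = exp(x) / sum(exp(x)) along the chosen axis; the use_length branch additionally zeroes every position at or beyond the given length and renormalizes, leaving fully masked rows at zero. A rough numpy sketch of what the emitted graph computes (editorial illustration only; the helper name is made up, and the missing max-subtraction mirrors the graph above rather than a numerically hardened softmax):

    import numpy as np

    def masked_softmax_reference(data, length, axis):
        e = np.exp(data)                              # Exp node
        axis = axis % data.ndim                       # same effect as the Less/Where axis fix-up
        idx_shape = [1] * data.ndim
        idx_shape[axis] = data.shape[axis]
        idx = np.arange(data.shape[axis]).reshape(idx_shape)
        mask = idx < np.expand_dims(length, axis)     # Range/Less mask: keep the first `length` entries
        masked = e * mask
        denom = masked.sum(axis=axis, keepdims=True)  # ReduceSum over the masked values
        denom = np.where(denom == 0, 1, denom)        # Equal/Where guard so all-masked rows stay zero
        return masked / denom                         # final Div

With a mask of all ones this reduces to the plain exp/sum/div branch.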
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -314,7 +314,7 @@ def _selu(attrs, inputs, proto_obj):
def softmax(attrs, inputs, proto_obj):
    """Softmax function."""
    if 'axis' not in attrs:
-        attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
+        attrs = translation_utils._add_extra_attributes(attrs, {'axis': -1})
    return 'softmax', attrs, inputs

def log_softmax(attrs, inputs, proto_obj):
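Note: with this change an imported ONNX Softmax that carries no axis attribute is mapped to MXNet softmax over the last axis rather than axis 1. The two defaults only diverge for inputs with more than two dimensions; a small numpy illustration (editorial, the helper is ad hoc):

    import numpy as np

    def ref_softmax(x, axis):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    x3 = np.random.rand(2, 3, 4).astype('float32')
    print(np.allclose(ref_softmax(x3, axis=1), ref_softmax(x3, axis=-1)))   # False: different axes
    x2 = np.random.rand(2, 3).astype('float32')
    print(np.allclose(ref_softmax(x2, axis=1), ref_softmax(x2, axis=-1)))   # True: axes coincide in 2-D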
19 changes: 18 additions & 1 deletion tests/python-pytest/onnx/test_operators.py
@@ -51,7 +51,8 @@ def export_to_onnx(model, model_name, inputs):

def onnx_rt(onnx_file, inputs):
    sess = rt.InferenceSession(onnx_file)
-    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
+    dtype_0 = inputs[0].asnumpy().dtype
+    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
    pred = sess.run(None, input_dict)[0]
    return pred
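Note: casting every input to the dtype of the first one suggests the exported test graphs declare all of their inputs with the data dtype, so the integer length arrays used by the softmax tests below would otherwise be rejected by onnxruntime as a type mismatch. A sketch of the feed this helper ends up building (editorial; the input names 'data' and 'length' are hypothetical):

    import numpy as np

    data = np.random.uniform(0, 1, (2, 3, 4)).astype('float32')
    length = np.array([[2, 0, 4], [0, 0, 0]])                     # int64 by default
    feeds = {'data': data, 'length': length.astype(data.dtype)}   # everything in the data dtype
    # sess.run(None, feeds) then mirrors what onnx_rt does for the two-input case.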

@@ -302,3 +303,19 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
    M = def_model('Cast', dtype=dst_dtype)
    x = mx.nd.ones(shape, dtype=src_dtype)
    op_export_test('Cast', M, [x], tmp_path)


+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+def test_onnx_export_softmax(tmp_path, dtype):
+    x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
+    M1 = def_model('softmax')
+    op_export_test('softmax_1', M1, [x], tmp_path)
+    M2 = def_model('softmax', use_length=True, axis=0)
+    l2 = mx.nd.array([[2, 0, 2, 1], [1, 1, 2, 1], [0, 0, 0, 1]], dtype=int)
+    op_export_test('softmax_2', M2, [x, l2], tmp_path)
+    M3 = def_model('softmax', use_length=True, axis=-1)
+    l3 = mx.nd.array([[2, 0, 4], [0, 0, 0]], dtype=int)
+    op_export_test('softmax_3', M3, [x, l3], tmp_path)
+    M4 = def_model('softmax', use_length=True, axis=1)
+    l4 = mx.nd.array([[2, 0, 3, 1], [0, 1, 0, 0]], dtype=int)
+    op_export_test('softmax_4', M4, [x, l4], tmp_path)
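Note: each length array states, per slice along the non-softmax axes, how many leading entries take part in the softmax; the remainder are forced to zero. To eyeball the expected output for the axis=-1 case directly in MXNet (editorial sketch, assuming the lengths are passed through the length argument when use_length=True):

    import mxnet as mx

    x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype='float32')
    l3 = mx.nd.array([[2, 0, 4], [0, 0, 0]], dtype=int)
    y = mx.nd.softmax(x, length=l3, axis=-1, use_length=True)
    print(y)   # rows with length 0 should come out as all zeros, matching the exported graph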