This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

[v1.x] ONNX fix softmax (#19691)
* fix softmax
* add test
* fix typo
* fix test shape
* update test data type
* add more tests
* fix temperature
* fix onnx2mx
* remove temperature
* update msg
* update msg

Co-authored-by: Wei Chu <[email protected]>
waytrue17 and Wei Chu committed Dec 20, 2020
1 parent 403d31f commit 5fce08a
Showing 3 changed files with 93 additions and 10 deletions.
82 changes: 74 additions & 8 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -874,20 +874,86 @@ def convert_softmax(node, **kwargs):
"""Map MXNet's softmax operator attributes to onnx's Softmax operator
and return the created node.
"""
from onnx.helper import make_node
from onnx import TensorProto
name, input_nodes, attrs = get_inputs(node, kwargs)

axis = int(attrs.get("axis", -1))
temperature = attrs.get("temperature", None)
if temperature and float(temperature) != 1.0:
raise NotImplementedError("Temperature is not supported for now.")
use_length = attrs.get("use_length", None)
input_type = kwargs["in_type"]
data = input_nodes[0]

softmax_node = onnx.helper.make_node(
"Softmax",
input_nodes,
[name],
axis=axis,
name=name
)
nodes = [
make_node("Exp", [data], [name+"_exp_out"]),
make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
]
if len(input_nodes) == 1:
nodes += [
make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
]
return nodes
elif use_length == "True":
length = input_nodes[1]

return [softmax_node]
nodes += [
# const nodes
create_tensor([axis], name+"_axis", kwargs["initializer"]),
create_tensor([], name+"_void", kwargs["initializer"]),
create_tensor([0], name+"_0", kwargs["initializer"]),
create_tensor([1], name+"_1", kwargs["initializer"]),
create_const_scalar_node(name+'_-1_s', np.int64(-1), kwargs),
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
# cast data type
make_node("Cast", [length], [name+"_length"], to=int(TensorProto.INT64)),
make_node("Cast", [name+"_0"], [name+"_0_itype"], to=input_type),
make_node("Cast", [name+"_1"], [name+"_1_itype"], to=input_type),
# softmax output
make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
# update axis
make_node("Shape", [data], [name+"_shape0_out"]),
make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
make_node("Add", [name+"_in_dim", name+"_axis"], [name+"_dim+axis"]),
make_node("Less", [name+"_axis", name+"_0_s"], [name+"_less0_out"]),
make_node("Where", [name+"_less0_out", name+"_dim+axis", name+"_axis"], [name+"_final_axis"]),
# data mask
make_node("Add", [name+"_final_axis", name+"_1_s"], [name+"_final_axis+1"]),
make_node("Slice", [name+"_shape0_out", name+"_final_axis", name+"_final_axis+1"], [name+"_axis_dim"]),
make_node("Reshape", [name+"_axis_dim", name+"_void"], [name+"_axis_dim_s"]),
make_node("Range", [name+"_0_s", name+"_axis_dim_s", name+"_1_s"], [name+"_range0_out"]),
# one hot for axis
make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]),
make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range1_out"]),
make_node("Equal", [name+"_range1_out", name+"_final_axis"], [name+"_equal_out"]),
make_node("Cast", [name+"_equal_out"], [name+"_one_hot"], to=int(TensorProto.INT64)),
# reshape data mask for less
make_node("Sub", [name+"_axis_dim_s", name+"_1_s"], [name+"_sub0_out"]),
make_node("Mul", [name+"_one_hot", name+"_sub0_out"], [name+"_mul0_out"]),
make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add0_out"]),
make_node('Reshape', [name+"_range0_out", name+"_add0_out"], [name+"_reshape0_out"]),
# reshape length for less
make_node("Mul", [name+"_one_hot", name+"_-1_s"], [name+"_mul1_out"]),
make_node("Add", [name+"_mul1_out", name+"_1_s"], [name+"_add1_out"]),
make_node("Sub", [name+"_shape0_out", name+"_1_s"], [name+"_sub1_out"]),
make_node("Mul", [name+"_add1_out", name+"_sub1_out"], [name+"_mul2_out"]),
make_node("Add", [name+"_mul2_out", name+"_1_s"], [name+"_add2_out"]),
make_node('Reshape', [name+"_length", name+"_add2_out"], [name+"_reshape1_out"]),
# mask output
make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
make_node("Cast", [name+"_less_out"], [name+"_mask"], to=input_type),
make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1),
make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
make_node("Div", [name+"_mul3_out", name+"_where_out"], [name], name=name)
]
return nodes

+    else:
+        raise NotImplementedError("use_length must be True when both data and length are passed in.")

# There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
# just softmax for inference - hence the name convert_softmax_output.
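
As a sanity check on the graph above, here is a rough NumPy sketch of the length-masked softmax that the use_length branch is intended to compute. The helper name and the formulation are illustrative only, not part of the PR:

import numpy as np

def masked_softmax_ref(data, length, axis=-1):
    """Hypothetical NumPy reference for softmax with use_length=True."""
    axis = axis % data.ndim                       # resolve a negative axis, like the Shape/Where "update axis" nodes
    exp = np.exp(data)                            # Exp node
    idx = np.arange(data.shape[axis])             # Range node along the softmax axis
    shape = [1] * data.ndim
    shape[axis] = -1
    mask = idx.reshape(shape) < np.expand_dims(length, axis)   # Less node: keep the first `length` entries
    masked = exp * mask                           # Mul with the 0/1 mask
    denom = masked.sum(axis=axis, keepdims=True)  # ReduceSum node
    denom = np.where(denom == 0, 1, denom)        # Equal/Where guard for fully masked rows
    return masked / denom                         # final Div node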
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -314,7 +314,7 @@ def _selu(attrs, inputs, proto_obj):
def softmax(attrs, inputs, proto_obj):
    """Softmax function."""
    if 'axis' not in attrs:
-        attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
+        attrs = translation_utils._add_extra_attributes(attrs, {'axis': -1})
    return 'softmax', attrs, inputs

def log_softmax(attrs, inputs, proto_obj):
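
This one-line change matters because MXNet's softmax normalizes over the last axis by default, while the old importer default of axis=1 only matches that for 2-D inputs. A quick illustration (shape chosen arbitrarily):

import mxnet as mx

x = mx.nd.random.uniform(shape=(2, 3, 4))
# Sums along the normalized axis are all ones; the two defaults disagree for 3-D input.
print(mx.nd.softmax(x, axis=1).sum(axis=1))    # normalized over the middle axis
print(mx.nd.softmax(x, axis=-1).sum(axis=-1))  # normalized over the last axis (MXNet's default)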
19 changes: 18 additions & 1 deletion tests/python-pytest/onnx/test_operators.py
Expand Up @@ -51,7 +51,8 @@ def export_to_onnx(model, model_name, inputs):

def onnx_rt(onnx_file, inputs):
    sess = rt.InferenceSession(onnx_file)
-    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
+    dtype_0 = inputs[0].asnumpy().dtype
+    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
    pred = sess.run(None, input_dict)[0]
    return pred
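
The helper now casts every input to the dtype of the first one before calling onnxruntime, presumably so the integer length array is fed with the same dtype as the floating-point data input the model was exported with. A small illustration of the effect (the arrays here are made up):

import numpy as np

x = np.random.uniform(size=(2, 3, 4)).astype('float16')
length = np.array([[2, 0, 4], [0, 0, 0]], dtype='int64')
dtype_0 = x.dtype
feeds = [x, length.astype(dtype_0)]
print([a.dtype for a in feeds])  # [dtype('float16'), dtype('float16')]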

@@ -309,3 +310,19 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
    M = def_model('Cast', dtype=dst_dtype)
    x = mx.nd.ones(shape, dtype=src_dtype)
    op_export_test('Cast', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+def test_onnx_export_softmax(tmp_path, dtype):
+    x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
+    M1 = def_model('softmax')
+    op_export_test('softmax_1', M1, [x], tmp_path)
+    M2 = def_model('softmax', use_length=True, axis=0)
+    l2 = mx.nd.array([[2,0,2,1],[1,1,2,1], [0,0,0,1]], dtype=int)
+    op_export_test('softmax_2', M2, [x, l2], tmp_path)
+    M3 = def_model('softmax', use_length=True, axis=-1)
+    l3 = mx.nd.array([[2,0,4],[0,0,0]], dtype=int)
+    op_export_test('softmax_3', M3, [x, l3], tmp_path)
+    M4 = def_model('softmax', use_length=True, axis=1)
+    l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int)
+    op_export_test('softmax_4', M4, [x, l4], tmp_path)
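
As the test data shows, each length array has the data shape (2, 3, 4) with the softmax axis removed. A small check of that relationship (illustrative only):

import numpy as np

data_shape = (2, 3, 4)
for axis, length_shape in [(0, (3, 4)), (-1, (2, 3)), (1, (2, 4))]:
    expected = tuple(np.delete(data_shape, axis % len(data_shape)))
    assert length_shape == expected, (axis, expected)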
