diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0a9d9aac71e4..f908d312a813 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -848,20 +848,86 @@ def convert_softmax(node, **kwargs):
     """Map MXNet's softmax operator attributes to onnx's Softmax operator
     and return the created node.
     """
+    from onnx.helper import make_node
+    from onnx import TensorProto
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     axis = int(attrs.get("axis", -1))
+    temperature = attrs.get("temperature", None)
+    if temperature and float(temperature) != 1.0:
+        raise NotImplementedError("Temperature is not supported for now.")
+    use_length = attrs.get("use_length", None)
+    input_type = kwargs["in_type"]
+    data = input_nodes[0]
 
-    softmax_node = onnx.helper.make_node(
-        "Softmax",
-        input_nodes,
-        [name],
-        axis=axis,
-        name=name
-    )
+    nodes = [
+        make_node("Exp", [data], [name+"_exp_out"]),
+        make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
+    ]
+    if len(input_nodes) == 1:
+        nodes += [
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+        ]
+        return nodes
+    elif use_length == "True":
+        length = input_nodes[1]
 
-    return [softmax_node]
+        nodes += [
+            # const nodes
+            create_tensor([axis], name+"_axis", kwargs["initializer"]),
+            create_tensor([], name+"_void", kwargs["initializer"]),
+            create_tensor([0], name+"_0", kwargs["initializer"]),
+            create_tensor([1], name+"_1", kwargs["initializer"]),
+            create_const_scalar_node(name+'_-1_s', np.int64(-1), kwargs),
+            create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
+            create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
+            # cast data type
+            make_node("Cast", [length], [name+"_length"], to=int(TensorProto.INT64)),
+            make_node("Cast", [name+"_0"], [name+"_0_itype"], to=input_type),
+            make_node("Cast", [name+"_1"], [name+"_1_itype"], to=input_type),
+            # softmax output
+            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            # update axis
+            make_node("Shape", [data], [name+"_shape0_out"]),
+            make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
+            make_node("Add", [name+"_in_dim", name+"_axis"], [name+"_dim+axis"]),
+            make_node("Less", [name+"_axis", name+"_0_s"], [name+"_less0_out"]),
+            make_node("Where", [name+"_less0_out", name+"_dim+axis", name+"_axis"], [name+"_final_axis"]),
+            # data mask
+            make_node("Add", [name+"_final_axis", name+"_1_s"], [name+"_final_axis+1"]),
+            make_node("Slice", [name+"_shape0_out", name+"_final_axis", name+"_final_axis+1"], [name+"_axis_dim"]),
+            make_node("Reshape", [name+"_axis_dim", name+"_void"], [name+"_axis_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_axis_dim_s", name+"_1_s"], [name+"_range0_out"]),
+            # one hot for axis
+            make_node("Reshape", [name+"_in_dim", name+"_void"], [name+"_in_dim_s"]),
+            make_node("Range", [name+"_0_s", name+"_in_dim_s", name+"_1_s"], [name+"_range1_out"]),
+            make_node("Equal", [name+"_range1_out", name+"_final_axis"], [name+"_equal_out"]),
+            make_node("Cast", [name+"_equal_out"], [name+"_one_hot"], to=int(TensorProto.INT64)),
+            # reshape data mask for less
+            make_node("Sub", [name+"_axis_dim_s", name+"_1_s"], [name+"_sub0_out"]),
+            make_node("Mul", [name+"_one_hot", name+"_sub0_out"], [name+"_mul0_out"]),
+            make_node("Add", [name+"_mul0_out", name+"_1_s"], [name+"_add0_out"]),
+            make_node('Reshape', [name+"_range0_out", name+"_add0_out"], [name+"_reshape0_out"]),
+            # reshape length for less
+            make_node("Mul", [name+"_one_hot", name+"_-1_s"], [name+"_mul1_out"]),
+            make_node("Add", [name+"_mul1_out", name+"_1_s"], [name+"_add1_out"]),
+            make_node("Sub", [name+"_shape0_out", name+"_1_s"], [name+"_sub1_out"]),
+            make_node("Mul", [name+"_add1_out", name+"_sub1_out"], [name+"_mul2_out"]),
+            make_node("Add", [name+"_mul2_out", name+"_1_s"], [name+"_add2_out"]),
+            make_node('Reshape', [name+"_length", name+"_add2_out"], [name+"_reshape1_out"]),
+            # mask output
+            make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
+            make_node("Cast", [name+"_less_out"], [name+"_mask"], to=input_type),
+            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1),
+            make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
+            make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
+            make_node("Div", [name+"_mul3_out", name+"_where_out"], [name], name=name)
+        ]
+        return nodes
+    else:
+        raise NotImplementedError("use_length must be True when both data and length are passed in.")
 
 
 # There's also mx.sym.softmax(), which doesn't do cross-entropy loss,
 # just softmax for inference - hence the name convert_softmax_output.
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index 51fe41895ec0..69bec1dd603a 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -314,7 +314,7 @@ def _selu(attrs, inputs, proto_obj):
 def softmax(attrs, inputs, proto_obj):
     """Softmax function."""
     if 'axis' not in attrs:
-        attrs = translation_utils._add_extra_attributes(attrs, {'axis': 1})
+        attrs = translation_utils._add_extra_attributes(attrs, {'axis': -1})
     return 'softmax', attrs, inputs
 
 def log_softmax(attrs, inputs, proto_obj):
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 87554ed795a8..4f259c08c62a 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -51,7 +51,8 @@ def export_to_onnx(model, model_name, inputs):
 
 def onnx_rt(onnx_file, inputs):
     sess = rt.InferenceSession(onnx_file)
-    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
+    dtype_0 = inputs[0].asnumpy().dtype
+    input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
     pred = sess.run(None, input_dict)[0]
     return pred
 
@@ -302,3 +303,19 @@ def test_onnx_export_cast(tmp_path, src_dtype, dst_dtype, shape):
     M = def_model('Cast', dtype=dst_dtype)
     x = mx.nd.ones(shape, dtype=src_dtype)
     op_export_test('Cast', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+def test_onnx_export_softmax(tmp_path, dtype):
+    x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype)
+    M1 = def_model('softmax')
+    op_export_test('softmax_1', M1, [x], tmp_path)
+    M2 = def_model('softmax', use_length=True, axis=0)
+    l2 = mx.nd.array([[2,0,2,1],[1,1,2,1], [0,0,0,1]], dtype=int)
+    op_export_test('softmax_2', M2, [x, l2], tmp_path)
+    M3 = def_model('softmax', use_length=True, axis=-1)
+    l3 = mx.nd.array([[2,0,4],[0,0,0]], dtype=int)
+    op_export_test('softmax_3', M3, [x, l3], tmp_path)
+    M4 = def_model('softmax', use_length=True, axis=1)
+    l4 = mx.nd.array([[2,0,3,1],[0,1,0,0]], dtype=int)
+    op_export_test('softmax_4', M4, [x, l4], tmp_path)
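
Note (not part of the diff): the use_length branch above expresses a length-masked softmax entirely in ONNX primitives, with Range/Less building the mask and Equal/Where guarding against division by zero on fully masked rows, which is hard to follow node by node. The NumPy sketch below shows the computation the exported graph is expected to perform: a plain softmax when only the data input is given, and a masked, re-normalized softmax when a length tensor is also given. The function and variable names here are illustrative only.

import numpy as np

def masked_softmax_reference(data, length=None, axis=-1):
    # Plain softmax, matching the Exp -> ReduceSum -> Div chain emitted when
    # only the data input is present.
    exp = np.exp(data)
    out = exp / exp.sum(axis=axis, keepdims=True)
    if length is None:
        return out
    # Keep positions whose index along `axis` is smaller than `length`;
    # `length` has data's shape with `axis` removed, as in the tests above.
    ax = axis % data.ndim
    idx = np.arange(data.shape[ax]).reshape(
        [-1 if i == ax else 1 for i in range(data.ndim)])
    mask = (idx < np.expand_dims(length, ax)).astype(data.dtype)
    out = out * mask
    # Re-normalize over the kept positions; replacing a zero denominator with 1
    # (the Equal/Where pair in the graph) makes fully masked rows all zeros.
    denom = out.sum(axis=ax, keepdims=True)
    denom = np.where(denom == 0, np.ones_like(denom), denom)
    return out / denom

x = np.random.uniform(0, 1, (2, 3, 4)).astype('float32')
l3 = np.array([[2, 0, 4], [0, 0, 0]])
y = masked_softmax_reference(x, l3, axis=-1)

With the softmax_3 test inputs above, y[1] comes out as all zeros because every length in that row is 0, while y[0, 2] sums to 1 since all four positions are kept; the exported graph should agree with this reference up to numerical error.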