diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index 3778e386a98a..32fcf49410b6 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -943,11 +943,11 @@ def convert_softmax(node, **kwargs):
         nodes = [
             make_node("Div", [data, name+"_tmp"], [name+'_data']),
             make_node("Exp", [name+'_data'], [name+"_exp_out"]),
-            make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
+            make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1),
         ]
         if len(input_nodes) == 1:
             nodes += [
-                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name),
             ]
         return nodes
     elif use_length == "True":
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
index 0b563ededdde..9764062cc7a0 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py
@@ -636,12 +636,10 @@ def convert_softmax(node, **kwargs):
         create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype)
         nodes = [
             make_node("Div", [data, name+"_tmp"], [name+'_data']),
-            make_node("Exp", [name+'_data'], [name+"_exp_out"]),
-            make_node("ReduceSum", [name+"_exp_out", name+"_axes"], [name+"_rsum_out"], keepdims=1)
         ]
         if len(input_nodes) == 1:
             nodes += [
-                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+                make_node("Softmax", [name+'_data'], [name], axis=axis)
             ]
         return nodes
     elif use_length == "True":
@@ -656,7 +654,7 @@ def convert_softmax(node, **kwargs):
             make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t),
             make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t),
             # softmax output
-            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            make_node("Softmax", [name+'_data'], [name+"_softmax_out"], axis=axis),
             # update axis
             make_node("Shape", [data], [name+"_shape0_out"]),
             make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
@@ -688,7 +686,7 @@ def convert_softmax(node, **kwargs):
             # mask output
             make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
             make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t),
-            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("Mul", [name+"_softmax_out", name+"_mask"], [name+"_mul3_out"]),
             make_node("ReduceSum", [name+"_mul3_out", name+"_axes"], [name+"_rsum1_out"], keepdims=1),
             make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
             make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
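
Why this rewrite is safe (a minimal NumPy sketch, not part of the patch): the opset 12 hunk only adds trailing commas, while the opset 13 hunks collapse the manual Exp -> ReduceSum -> Div subgraph into ONNX's native Softmax node fed by the temperature-scaling Div, since opset 13 Softmax normalizes along an arbitrary axis. The names below are illustrative; the check compares the two formulations numerically.

import numpy as np

def softmax_decomposed(x, axis, temperature=1.0):
    # Old graph: Div (by temperature) -> Exp -> ReduceSum(keepdims=1) -> Div.
    exp_out = np.exp(x / temperature)
    rsum_out = np.sum(exp_out, axis=axis, keepdims=True)
    return exp_out / rsum_out

def softmax_fused(x, axis, temperature=1.0):
    # New graph: Div (by temperature) -> Softmax(axis=axis). Subtracting the
    # per-axis max mirrors the numerically stable form ONNX runtimes
    # typically use; it does not change the result.
    scaled = x / temperature
    shifted = scaled - np.max(scaled, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.randn(2, 3, 4).astype(np.float32)
for axis in (-1, 1):
    assert np.allclose(softmax_decomposed(x, axis, 0.5),
                       softmax_fused(x, axis, 0.5), atol=1e-6)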