From b4d1f141c8f577bc8466939d5173edbb03021d29 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 30 Apr 2021 14:11:03 -0700 Subject: [PATCH 1/3] optimize softmax --- .../mx2onnx/_op_translations/_op_translations_opset13.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 0b563ededdde..9764062cc7a0 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -636,12 +636,10 @@ def convert_softmax(node, **kwargs): create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype) nodes = [ make_node("Div", [data, name+"_tmp"], [name+'_data']), - make_node("Exp", [name+'_data'], [name+"_exp_out"]), - make_node("ReduceSum", [name+"_exp_out", name+"_axes"], [name+"_rsum_out"], keepdims=1) ] if len(input_nodes) == 1: nodes += [ - make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name) + make_node("Softmax", [name+'_data'], [name], axis=axis) ] return nodes elif use_length == "True": @@ -656,7 +654,7 @@ def convert_softmax(node, **kwargs): make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t), make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t), # softmax output - make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]), + make_node("Softmax", [name+'_data'], [name+"_softmax_out"], axis=axis), # update axis make_node("Shape", [data], [name+"_shape0_out"]), make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), @@ -688,7 +686,7 @@ def convert_softmax(node, **kwargs): # mask output make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]), make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t), - make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]), + make_node("Mul", [name+"_softmax_out", name+"_mask"], [name+"_mul3_out"]), make_node("ReduceSum", [name+"_mul3_out", name+"_axes"], [name+"_rsum1_out"], keepdims=1), make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]), make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]), From 023b67aa85d4da9a5325ec6a45565291c53c3726 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 30 Apr 2021 14:19:08 -0700 Subject: [PATCH 2/3] same for opset12 --- .../mx2onnx/_op_translations/_op_translations_opset12.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 3778e386a98a..e7d46b48236e 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -942,12 +942,10 @@ def convert_softmax(node, **kwargs): create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype) nodes = [ make_node("Div", [data, name+"_tmp"], [name+'_data']), - make_node("Exp", [name+'_data'], [name+"_exp_out"]), - make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1) ] if len(input_nodes) == 1: nodes += [ - make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name) + make_node("Softmax", [name+'_data'], [name], axis=axis), ] return nodes elif use_length == "True": @@ -965,7 +963,7 @@ def convert_softmax(node, **kwargs): make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t), make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t), # softmax output - make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]), + make_node("Softmax", [name+'_data'], [name+"_softmax_out"], axis=axis), # update axis make_node("Shape", [data], [name+"_shape0_out"]), make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), @@ -997,7 +995,7 @@ def convert_softmax(node, **kwargs): # mask output make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]), make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t), - make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]), + make_node("Mul", [name+"_softmax_out", name+"_mask"], [name+"_mul3_out"]), make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1), make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]), make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]), From 0bc24bf73346cd09b906bf2ab03a45883d440d52 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 30 Apr 2021 14:41:35 -0700 Subject: [PATCH 3/3] revert opset12 --- .../mx2onnx/_op_translations/_op_translations_opset12.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index e7d46b48236e..32fcf49410b6 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -942,10 +942,12 @@ def convert_softmax(node, **kwargs): create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype) nodes = [ make_node("Div", [data, name+"_tmp"], [name+'_data']), + make_node("Exp", [name+'_data'], [name+"_exp_out"]), + make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1), ] if len(input_nodes) == 1: nodes += [ - make_node("Softmax", [name+'_data'], [name], axis=axis), + make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name), ] return nodes elif use_length == "True": @@ -963,7 +965,7 @@ def convert_softmax(node, **kwargs): make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t), make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t), # softmax output - make_node("Softmax", [name+'_data'], [name+"_softmax_out"], axis=axis), + make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]), # update axis make_node("Shape", [data], [name+"_shape0_out"]), make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]), @@ -995,7 +997,7 @@ def convert_softmax(node, **kwargs): # mask output make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]), make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t), - make_node("Mul", [name+"_softmax_out", name+"_mask"], [name+"_mul3_out"]), + make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]), make_node("ReduceSum", [name+"_mul3_out"], [name+"_rsum1_out"], axes=[axis], keepdims=1), make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]), make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),