This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit b1b3634
[v1.x] ONNX optimize softmax (#20231)
* optimize softmax

* same for opset12

* revert opset12

Co-authored-by: Wei Chu <[email protected]>
waytrue17 and Wei Chu authored May 1, 2021
1 parent 2127c3e commit b1b3634
Showing 2 changed files with 5 additions and 7 deletions.
File 1 of 2:
@@ -943,11 +943,11 @@ def convert_softmax(node, **kwargs):
         nodes = [
             make_node("Div", [data, name+"_tmp"], [name+'_data']),
             make_node("Exp", [name+'_data'], [name+"_exp_out"]),
-            make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1)
+            make_node("ReduceSum", [name+"_exp_out"], [name+"_rsum_out"], axes=[axis], keepdims=1),
         ]
         if len(input_nodes) == 1:
             nodes += [
-                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name),
             ]
             return nodes
         elif use_length == "True":
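The change in this first file is purely stylistic: trailing commas are added after the last entries of the two node lists, while the exported chain itself (Div for temperature scaling, then Exp, ReduceSum, and a final Div) is unchanged. For reference, a minimal numpy sketch of what that chain computes; the function name and test values are illustrative, not from the source:

import numpy as np

def decomposed_softmax(data, temperature=1.0, axis=-1):
    """numpy analogue of the opset-12 node chain above."""
    scaled = data / temperature                       # Div by name+"_tmp"
    exp_out = np.exp(scaled)                          # Exp
    rsum_out = exp_out.sum(axis=axis, keepdims=True)  # ReduceSum, keepdims=1
    return exp_out / rsum_out                         # final Div

x = np.random.rand(2, 3).astype(np.float32)
print(decomposed_softmax(x, temperature=0.5))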
File 2 of 2:
@@ -636,12 +636,10 @@ def convert_softmax(node, **kwargs):
         create_tensor([temperature], name+"_tmp", kwargs["initializer"], dtype=dtype)
         nodes = [
             make_node("Div", [data, name+"_tmp"], [name+'_data']),
-            make_node("Exp", [name+'_data'], [name+"_exp_out"]),
-            make_node("ReduceSum", [name+"_exp_out", name+"_axes"], [name+"_rsum_out"], keepdims=1)
         ]
         if len(input_nodes) == 1:
             nodes += [
-                make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name], name=name)
+                make_node("Softmax", [name+'_data'], [name], axis=axis)
             ]
             return nodes
         elif use_length == "True":
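This hunk carries the actual optimization: for opset 13 the exporter emits ONNX's built-in Softmax operator on the temperature-scaled input instead of spelling out Exp, ReduceSum, and Div. Opset 13 is the first version in which Softmax reduces over a single given axis (earlier opsets flattened the input to 2-D around the axis), which is likely why the "revert opset12" step in the commit message keeps the decomposed form in the opset-12 file. A sketch of the resulting two-node graph built with onnx.helper; the input name, shape, and temperature value here are illustrative assumptions:

import onnx
from onnx import TensorProto, helper

# Illustrative values; the real exporter derives name, axis, and
# temperature from the MXNet node being converted.
name, axis, temperature = "sm0", -1, 1.0

tmp = helper.make_tensor(name + "_tmp", TensorProto.FLOAT, [1], [temperature])
nodes = [
    helper.make_node("Div", ["data", name + "_tmp"], [name + "_data"]),
    helper.make_node("Softmax", [name + "_data"], [name], axis=axis),
]
graph = helper.make_graph(
    nodes, "softmax_opt",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info(name, TensorProto.FLOAT, [2, 3])],
    initializer=[tmp],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)  # two nodes replace the old four-node chain

Besides being shorter, the fused form lets a runtime dispatch to its native softmax kernel, which is typically numerically stabilized (max subtraction), unlike the raw Exp of the decomposed graph.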
@@ -656,7 +654,7 @@ def convert_softmax(node, **kwargs):
             make_node("Cast", [name+"_0"], [name+"_0_itype"], to=dtype_t),
             make_node("Cast", [name+"_1"], [name+"_1_itype"], to=dtype_t),
             # softmax output
-            make_node("Div", [name+"_exp_out", name+"_rsum_out"], [name+"_div1_out"]),
+            make_node("Softmax", [name+'_data'], [name+"_softmax_out"], axis=axis),
             # update axis
             make_node("Shape", [data], [name+"_shape0_out"]),
             make_node("Shape", [name+"_shape0_out"], [name+"_in_dim"]),
@@ -688,7 +686,7 @@ def convert_softmax(node, **kwargs):
             # mask output
             make_node("Less", [name+"_reshape0_out", name+"_reshape1_out"], [name+"_less_out"]),
             make_node("Cast", [name+"_less_out"], [name+"_mask"], to=dtype_t),
-            make_node("Mul", [name+"_div1_out", name+"_mask"], [name+"_mul3_out"]),
+            make_node("Mul", [name+"_softmax_out", name+"_mask"], [name+"_mul3_out"]),
             make_node("ReduceSum", [name+"_mul3_out", name+"_axes"], [name+"_rsum1_out"], keepdims=1),
             make_node("Equal", [name+"_rsum1_out", name+"_0_itype"], [name+"_equal1_out"]),
             make_node("Where", [name+"_equal1_out", name+"_1_itype", name+"_rsum1_out"], [name+"_where_out"]),
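Taken together, the use_length branch computes a full softmax, zeroes out positions at or beyond each row's length (Less then Cast produces the 0/1 mask), and renormalizes; the Equal/Where pair substitutes 1 for a zero sum so fully masked rows divide safely. A minimal numpy sketch of that logic, assuming the softmax axis is the last dimension and that the Where output feeds a final Div (that Div is outside the lines shown):

import numpy as np

def masked_softmax_lastaxis(data, length):
    """numpy analogue of the exported use_length nodes, last-axis case."""
    # Softmax over the full axis first (the name+"_softmax_out" node above).
    e = np.exp(data - data.max(axis=-1, keepdims=True))
    softmax_out = e / e.sum(axis=-1, keepdims=True)
    # Less -> Cast: 1.0 where the position index is below the row's length.
    idx = np.arange(data.shape[-1])[None, :]
    mask = (idx < length[:, None]).astype(data.dtype)
    # Mul -> ReduceSum -> Equal/Where: renormalize, guarding all-zero rows.
    mul3_out = softmax_out * mask
    rsum1_out = mul3_out.sum(axis=-1, keepdims=True)
    where_out = np.where(rsum1_out == 0, np.ones_like(rsum1_out), rsum1_out)
    return mul3_out / where_out

x = np.array([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]], dtype=np.float32)
print(masked_softmax_lastaxis(x, length=np.array([2, 0])))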
