From be4e56c41dbc10171fb272dca1df4b854cb31fe5 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 13 Oct 2021 16:56:05 +0200 Subject: [PATCH] unified softplus kernel --- paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc | 9 +-------- paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h | 9 ++++++--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 96fbafe2d1a85..29106dc30498e 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -173,14 +173,7 @@ struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> { template <typename T> struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> { void operator()(const framework::ExecutionContext &ctx) const { - const float beta = ctx.Attr<float>("beta"); - // if beta is equal to 1.0f then we can simply use oneDNN's soft_relu but if - // it has other value, we have to fuse binary + eltwise + binary - if (beta == 1.0f) { - eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_soft_relu); - } else { - custom_softplus_eltwise_forward<T>(ctx); - } + custom_softplus_eltwise_forward<T>(ctx); } }; diff --git a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h index b74451d353855..fdb2c534e0363 100644 --- a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h +++ b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h @@ -37,7 +37,11 @@ class SoftplusMKLDNNHandler dnnl::post_ops post_ops; post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_soft_relu, 0.0f, 0.0f); - post_ops.append_binary(dnnl::algorithm::binary_div, beta_md); + if (beta != 1.0f) { + post_ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, + 1.0f / beta, 0.0f); + } + dnnl::primitive_attr attrs; attrs.set_post_ops(post_ops); @@ -78,8 +82,7 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { const
std::unordered_map<int, dnnl::memory> args = { {DNNL_ARG_SRC_0, *src_memory_p}, {DNNL_ARG_SRC_1, *beta_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}, - {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, *beta_memory_p}}; + {DNNL_ARG_DST, *dst_memory_p}}; binary_p->execute(astream, args); astream.wait();