modify mish op and add mish api (#38734)
* add mish operator and api

* remove redundant code and modify grad_atol of mish unittest

* modify mish code to be consistent with other activation implementation
wangxinxin08 authored Jan 7, 2022
1 parent fb3313e commit 8c92337
Showing 13 changed files with 304 additions and 538 deletions.
37 changes: 37 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -806,6 +806,36 @@ Swish Activation Operator.
}
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Mish operator");
AddOutput("Out", "Output of Mish operator");
AddAttr<float>(
"threshold",
"Constant threshold of softplus in Mish operator. Approximate value "
"of softplus will be used if absolute value of input is greater than "
":attr:`threshold`")
.SetDefault(20.f);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC(
Mish Activation Operator.
.. math::
softplus(x) = \begin{cases}
x, \text{if } x > \text{threshold} \\
\ln(1 + e^{x}), \text{otherwise}
\end{cases}
out = x * \tanh(softplus(x))
)DOC");
}
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -1901,4 +1931,11 @@ REGISTER_OP_VERSION(softplus)
.NewAttr("threshold", "The threshold value of the new formula",
20.0f));

REGISTER_OP_VERSION(mish)
.AddCheckpoint(
R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_mkldnn", "(bool, default false) Only used in mkldnn kernel",
false));

/* ========================================================================== */
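
The thresholded softplus documented in MishOpMaker above can be sanity-checked with a tiny standalone scalar reference. The sketch below is illustrative only and is not part of this commit; names such as softplus_ref and mish_ref are made up for the example, and the threshold defaults to the same 20.0f used by the operator.

// Minimal scalar sketch of the documented formula (not Paddle code).
#include <cmath>
#include <cstdio>

// softplus(x) = x            if x > threshold  (avoids overflow of exp(x))
//             = ln(1 + e^x)  otherwise
float softplus_ref(float x, float threshold) {
  return x > threshold ? x : std::log(1.0f + std::exp(x));
}

// mish(x) = x * tanh(softplus(x))
float mish_ref(float x, float threshold = 20.0f) {
  return x * std::tanh(softplus_ref(x, threshold));
}

int main() {
  // mish(0) == 0, mish(x) -> x for large positive x, mish(x) -> 0 for very
  // negative x.
  const float xs[] = {-5.0f, -1.0f, 0.0f, 1.0f, 5.0f, 25.0f};
  for (float x : xs) {
    std::printf("mish(%.1f) = %f\n", x, mish_ref(x));
  }
  return 0;
}
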
50 changes: 50 additions & 0 deletions paddle/fluid/operators/activation_op.cu
@@ -1145,6 +1145,55 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}

// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// Inputs: args[0], the input x
__device__ __forceinline__ T operator()(const T& arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
return static_cast<T>(x * tanh(sp));
}
};

template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}

// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
MPType gsp =
(x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
MPType tsp = tanh(sp);
return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
@@ -1808,6 +1857,7 @@ REGISTER_OP_CUDA_KERNEL(
__macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \
CudaHardSigmoidGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, CudaThresholdedReluFunctor, \
CudaThresholdedReluGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
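
The backward formula in CudaMishGradFunctor can be checked numerically. Below is a hedged scalar sketch, not part of the commit: mish_grad_ref mirrors the functor's tsp + x * (1 - tsp * tsp) * gsp expression and is compared against a central finite difference of the forward pass; all names are illustrative.

// Scalar reference for the mish gradient, checked against finite differences
// (illustrative only, not Paddle code).
#include <cmath>
#include <cstdio>

static float softplus_ref(float x, float threshold) {
  return x > threshold ? x : std::log(1.0f + std::exp(x));
}

static float mish_ref(float x, float threshold) {
  return x * std::tanh(softplus_ref(x, threshold));
}

// d mish / dx = tanh(sp) + x * (1 - tanh(sp)^2) * d softplus / dx,
// where d softplus / dx = sigmoid(x) below the threshold and is taken as 1
// above it, matching CudaMishGradFunctor.
static float mish_grad_ref(float x, float threshold) {
  float sp = softplus_ref(x, threshold);
  float gsp = x > threshold ? 1.0f : 1.0f / (1.0f + std::exp(-x));
  float tsp = std::tanh(sp);
  return tsp + x * (1.0f - tsp * tsp) * gsp;
}

int main() {
  const float threshold = 20.0f;
  const float eps = 1e-3f;
  const float xs[] = {-4.0f, -0.5f, 0.0f, 0.5f, 4.0f};
  for (float x : xs) {
    // Central difference of the forward pass; should closely match the
    // analytic gradient.
    float numeric =
        (mish_ref(x + eps, threshold) - mish_ref(x - eps, threshold)) /
        (2.0f * eps);
    std::printf("x = % .2f  analytic = % .6f  numeric = % .6f\n", x,
                mish_grad_ref(x, threshold), numeric);
  }
  return 0;
}
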
41 changes: 41 additions & 0 deletions paddle/fluid/operators/activation_op.h
@@ -1412,6 +1412,46 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
out.device(d) = x * sp.tanh();
}
};

// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}

template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
auto gsp = static_cast<T>(1) - (-sp).exp();
auto tsp = sp.tanh();
dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
@@ -2841,4 +2881,5 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \
ThresholdedReluGradFunctor); \
__macro(mish, Mish, MishFunctor, MishGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
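
A note on the two gradient implementations above: the comments write the softplus derivative as (1 - exp(-sp)), and MishGradFunctor evaluates it that way, while CudaMishGradFunctor uses 1 / (1 + exp(-x)). For sp = ln(1 + e^x) these are the same quantity:

1 - exp(-sp) = 1 - 1 / (1 + e^x) = e^x / (1 + e^x) = 1 / (1 + e^(-x)) = sigmoid(x)

Above the threshold the CUDA functor uses exactly 1, and the Eigen expression's 1 - exp(-x) is equal to 1 up to float32 precision for x > 20, so both backends compute the same derivative.
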
121 changes: 0 additions & 121 deletions paddle/fluid/operators/mish_op.cc

This file was deleted.

(diffs for the remaining 9 changed files are not shown)
