diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 381183b3255c6c..a684a66f6709b9 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -3589,9 +3589,15 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor { template struct CudaReciprocalGradFunctor : public BaseActivationFunctor { + using MPType = typename phi::dtype::MPTypeTrait::Type; + // dx = -dout * out^2 - __device__ __forceinline__ T operator()(const T dout, const T out) const { - return -dout * out * out; + __device__ __forceinline__ T operator()(const T arg_dout, + const T arg_out) const { + MPType dout = static_cast(arg_dout); + MPType out = static_cast(arg_out); + return static_cast(-dout * + static_cast(static_cast(out * out))); } static constexpr ActBwdOpFwdDeps FwdDeps() {