Skip to content

Commit 49ece35

Browse files
authored
[Accuracy diff No.52] Fix accuracy diff for reciprocal API (#73128)
* fix reciprocal_grad accuracy * fix compile * refine
1 parent 4936b77 commit 49ece35

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3589,9 +3589,15 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
35893589

35903590
template <typename T>
35913591
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
3592+
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
3593+
35923594
// dx = -dout * out^2
3593-
__device__ __forceinline__ T operator()(const T dout, const T out) const {
3594-
return -dout * out * out;
3595+
__device__ __forceinline__ T operator()(const T arg_dout,
3596+
const T arg_out) const {
3597+
MPType dout = static_cast<MPType>(arg_dout);
3598+
MPType out = static_cast<MPType>(arg_out);
3599+
return static_cast<T>(-dout *
3600+
static_cast<MPType>(static_cast<T>(out * out)));
35953601
}
35963602

35973603
static constexpr ActBwdOpFwdDeps FwdDeps() {

0 commit comments

Comments
 (0)