File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -3589,9 +3589,15 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
35893589
35903590template <typename T>
35913591struct CudaReciprocalGradFunctor : public BaseActivationFunctor <T> {
3592+ using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
3593+
35923594 // dx = -dout * out^2
3593- __device__ __forceinline__ T operator ()(const T dout, const T out) const {
3594- return -dout * out * out;
3595+ __device__ __forceinline__ T operator ()(const T arg_dout,
3596+ const T arg_out) const {
3597+ MPType dout = static_cast <MPType>(arg_dout);
3598+ MPType out = static_cast <MPType>(arg_out);
3599+ return static_cast <T>(-dout *
3600+ static_cast <MPType>(static_cast <T>(out * out)));
35953601 }
35963602
35973603 static constexpr ActBwdOpFwdDeps FwdDeps () {
You can’t perform that action at this time.
0 commit comments