@@ -21,6 +21,7 @@ limitations under the License. */
2121#include " paddle/phi/common/amp_type_traits.h"
2222#include " paddle/phi/common/memory_utils.h"
2323#include " paddle/phi/core/dense_tensor.h"
24+ #include " paddle/phi/kernels/cast_kernel.h"
2425#include " paddle/phi/kernels/funcs/common_shape.h"
2526#include " paddle/phi/kernels/funcs/elementwise_utils.h"
2627#include " paddle/phi/kernels/funcs/for_range.h"
@@ -64,18 +65,28 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
6465 const CPUContext &dev_ctx,
6566 DX_OP dx_op,
6667 DY_OP dy_op) {
68+ using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
69+
6770 std::vector<int64_t > index_array (max_dim, 0 );
6871 const T *x_data = x.data <T>();
6972 const T *y_data = y.data <T>();
7073 const Tout *out_data = out.data <Tout>();
7174 const Tout *dout_data = dout.data <Tout>();
72- T *dx_data = dx == nullptr ? nullptr : dev_ctx.Alloc <T>(dx);
73- T *dy_data = dy == nullptr ? nullptr : dev_ctx.Alloc <T>(dy);
74- if (dx_data != nullptr ) {
75- memset (dx_data, 0 , dx->numel () * sizeof (T));
75+
76+ DenseTensor dx_mp, dy_mp;
77+ MPType *dx_mp_data = nullptr ;
78+ MPType *dy_mp_data = nullptr ;
79+ if (dx != nullptr ) {
80+ dx_mp.Resize (dx->dims ());
81+ dev_ctx.Alloc <MPType>(&dx_mp);
82+ dx_mp_data = dx_mp.data <MPType>();
83+ memset (dx_mp_data, 0 , dx->numel () * sizeof (MPType));
7684 }
77- if (dy_data != nullptr ) {
78- memset (dy_data, 0 , dy->numel () * sizeof (T));
85+ if (dy != nullptr ) {
86+ dy_mp.Resize (dy->dims ());
87+ dev_ctx.Alloc <MPType>(&dy_mp);
88+ dy_mp_data = dy_mp.data <MPType>();
89+ memset (dy_mp_data, 0 , dy->numel () * sizeof (MPType));
7990 }
8091 const int64_t out_size = std::accumulate (out_dims_array,
8192 out_dims_array + max_dim,
@@ -87,22 +98,32 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
8798 GetElementwiseIndex<int64_t >(x_dims_array, max_dim, index_array.data ());
8899 y_index =
89100 GetElementwiseIndex<int64_t >(y_dims_array, max_dim, index_array.data ());
90- if (dx_data != nullptr ) {
91- dx_data [x_index] += dx_op (x_data[x_index],
92- y_data[y_index],
93- out_data[out_index],
94- dout_data[out_index]);
101+ if (dx_mp_data != nullptr ) {
102+ dx_mp_data [x_index] += static_cast <MPType>( dx_op (x_data[x_index],
103+ y_data[y_index],
104+ out_data[out_index],
105+ dout_data[out_index]) );
95106 }
96- if (dy_data != nullptr ) {
97- dy_data [y_index] += dy_op (x_data[x_index],
98- y_data[y_index],
99- out_data[out_index],
100- dout_data[out_index]);
107+ if (dy_mp_data != nullptr ) {
108+ dy_mp_data [y_index] += static_cast <MPType>( dy_op (x_data[x_index],
109+ y_data[y_index],
110+ out_data[out_index],
111+ dout_data[out_index]) );
101112 }
102113
103114 UpdateElementwiseIndexArray<int64_t >(
104115 out_dims_array, max_dim, index_array.data ());
105116 }
117+ if (dx != nullptr ) {
118+ dev_ctx.Alloc <T>(dx);
119+ phi::CastKernel<MPType, CPUContext>(
120+ dev_ctx, dx_mp, phi::CppTypeToDataType<T>::Type (), dx);
121+ }
122+ if (dy != nullptr ) {
123+ dev_ctx.Alloc <T>(dy);
124+ phi::CastKernel<MPType, CPUContext>(
125+ dev_ctx, dy_mp, phi::CppTypeToDataType<T>::Type (), dy);
126+ }
106127}
107128
108129template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
0 commit comments