
Commit 54947b2

[Accuracy diff No.141] Fix accuracy diff for paddle.Tensor.__mul__ API (#74198)
* [Accuracy diff No.141] Fix accuracy diff for paddle.Tensor.__mul__ API
* add header file
1 parent 80ad645 · commit 54947b2
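Why this fixes an accuracy diff: bfloat16 keeps only 8 bits of mantissa, so when the CPU broadcast-gradient kernel repeatedly adds small per-element contributions into a bfloat16 buffer, low-order bits are rounded away on every addition and the error compounds. The change below instead accumulates in MPType (float when T is bfloat16, via MPTypeTrait) and casts back to T once at the end. A minimal standalone sketch of the effect, assuming truncation-style bfloat16 rounding; to_bf16 is a hypothetical helper for illustration, not a Paddle API:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: emulate bfloat16 by truncating a float's mantissa
// to its top 7 bits (real bfloat16 uses round-to-nearest-even; truncation
// is close enough to show the effect).
float to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign, exponent, and 7 mantissa bits
  std::memcpy(&f, &bits, sizeof(bits));
  return f;
}

int main() {
  float acc_bf16 = 0.0f;  // accumulator rounded to bf16 after every add
  float acc_fp32 = 0.0f;  // accumulator kept at float precision
  for (int i = 0; i < 1024; ++i) {
    acc_bf16 = to_bf16(acc_bf16 + to_bf16(0.01f));
    acc_fp32 += 0.01f;
  }
  // The bf16-style accumulator drifts well away from the fp32 result;
  // accumulating in float and rounding once at the end avoids this.
  std::printf("bf16-style: %f  fp32: %f\n", acc_bf16, acc_fp32);
  return 0;
}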

File tree: 2 files changed, +58 −16 lines

paddle/phi/kernels/funcs/elementwise_grad_base.h

Lines changed: 37 additions & 16 deletions
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/elementwise_utils.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
@@ -64,18 +65,28 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
                             const CPUContext &dev_ctx,
                             DX_OP dx_op,
                             DY_OP dy_op) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
   std::vector<int64_t> index_array(max_dim, 0);
   const T *x_data = x.data<T>();
   const T *y_data = y.data<T>();
   const Tout *out_data = out.data<Tout>();
   const Tout *dout_data = dout.data<Tout>();
-  T *dx_data = dx == nullptr ? nullptr : dev_ctx.Alloc<T>(dx);
-  T *dy_data = dy == nullptr ? nullptr : dev_ctx.Alloc<T>(dy);
-  if (dx_data != nullptr) {
-    memset(dx_data, 0, dx->numel() * sizeof(T));
+
+  DenseTensor dx_mp, dy_mp;
+  MPType *dx_mp_data = nullptr;
+  MPType *dy_mp_data = nullptr;
+  if (dx != nullptr) {
+    dx_mp.Resize(dx->dims());
+    dev_ctx.Alloc<MPType>(&dx_mp);
+    dx_mp_data = dx_mp.data<MPType>();
+    memset(dx_mp_data, 0, dx->numel() * sizeof(MPType));
   }
-  if (dy_data != nullptr) {
-    memset(dy_data, 0, dy->numel() * sizeof(T));
+  if (dy != nullptr) {
+    dy_mp.Resize(dy->dims());
+    dev_ctx.Alloc<MPType>(&dy_mp);
+    dy_mp_data = dy_mp.data<MPType>();
+    memset(dy_mp_data, 0, dy->numel() * sizeof(MPType));
   }
   const int64_t out_size = std::accumulate(out_dims_array,
                                            out_dims_array + max_dim,
@@ -87,22 +98,32 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
         GetElementwiseIndex<int64_t>(x_dims_array, max_dim, index_array.data());
     y_index =
         GetElementwiseIndex<int64_t>(y_dims_array, max_dim, index_array.data());
-    if (dx_data != nullptr) {
-      dx_data[x_index] += dx_op(x_data[x_index],
-                                y_data[y_index],
-                                out_data[out_index],
-                                dout_data[out_index]);
+    if (dx_mp_data != nullptr) {
+      dx_mp_data[x_index] += static_cast<MPType>(dx_op(x_data[x_index],
+                                                       y_data[y_index],
+                                                       out_data[out_index],
+                                                       dout_data[out_index]));
     }
-    if (dy_data != nullptr) {
-      dy_data[y_index] += dy_op(x_data[x_index],
-                                y_data[y_index],
-                                out_data[out_index],
-                                dout_data[out_index]);
+    if (dy_mp_data != nullptr) {
+      dy_mp_data[y_index] += static_cast<MPType>(dy_op(x_data[x_index],
+                                                       y_data[y_index],
+                                                       out_data[out_index],
+                                                       dout_data[out_index]));
     }

     UpdateElementwiseIndexArray<int64_t>(
         out_dims_array, max_dim, index_array.data());
   }
+  if (dx != nullptr) {
+    dev_ctx.Alloc<T>(dx);
+    phi::CastKernel<MPType, CPUContext>(
+        dev_ctx, dx_mp, phi::CppTypeToDataType<T>::Type(), dx);
+  }
+  if (dy != nullptr) {
+    dev_ctx.Alloc<T>(dy);
+    phi::CastKernel<MPType, CPUContext>(
+        dev_ctx, dy_mp, phi::CppTypeToDataType<T>::Type(), dy);
+  }
 }

 template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
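Distilled, the pattern these hunks apply is: allocate a scratch buffer at the higher precision, accumulate there, then convert back to the storage type in a single pass (the kernel uses phi::CastKernel for that last step). A sketch under assumed names — ReduceGrad and the MPType default parameter are illustrative, not Paddle's API:

#include <cstdint>

// Sketch: sum n gradient contributions for one output element.
// Accumulating in MPType means each += rounds at float precision;
// the only rounding to T happens once, in the final static_cast
// (the role CastKernel plays in the real kernel above).
template <typename T, typename MPType = float>
void ReduceGrad(const T *contrib, int64_t n, T *grad_out) {
  MPType acc = static_cast<MPType>(0);
  for (int64_t i = 0; i < n; ++i) {
    acc += static_cast<MPType>(contrib[i]);
  }
  *grad_out = static_cast<T>(acc);
}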

test/legacy_test/test_multiply.py

Lines changed: 21 additions & 0 deletions
@@ -290,5 +290,26 @@ def init_shapes(self):
         self.y_shape = [5, 1]


+class TestMultiplyApiBF16(unittest.TestCase):
+    # For now, only check that multiply with bfloat16 runs successfully, including backward.
+    def setUp(self):
+        paddle.device.set_device('cpu')
+
+    def test_multiply(self):
+        self.x_shape = [1, 1024, 32, 128]
+        self.y_shape = [1, 1024, 1, 128]
+        x = paddle.rand(self.x_shape, dtype='bfloat16')
+        x.stop_gradient = False
+        y = paddle.rand(self.y_shape, dtype='bfloat16')
+        y.stop_gradient = False
+        res = paddle.multiply(x, y)
+        loss = res.sum()
+        loss.backward()
+        assert x.grad is not None
+        assert x.grad.dtype == paddle.bfloat16
+        assert y.grad is not None
+        assert y.grad.dtype == paddle.bfloat16
+
+
 if __name__ == '__main__':
     unittest.main()
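Note the test's design choice: because bfloat16 arithmetic is inherently low-precision, it does not compare gradient values against a reference. It asserts only that backward completes and that both gradients come back non-null with dtype paddle.bfloat16, on CPU and with broadcasting shapes, which should exercise the new MPType accumulation path end to end.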
