From b8591faf4fe3ef5c4935e109f61f4f8295e3096a Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Wed, 23 Jul 2025 18:15:18 +0800 Subject: [PATCH 1/2] [Accuracy diff No.141] Fix accuracy diff for paddle.Tensor.__mul__ API --- .../phi/kernels/funcs/elementwise_grad_base.h | 52 +++++++++++++------ test/legacy_test/test_multiply.py | 21 ++++++++ 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index 26d3195f6a5bfb..21d27a7f85e1d1 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -64,18 +64,28 @@ void CommonGradBroadcastCPU(const DenseTensor &x, const CPUContext &dev_ctx, DX_OP dx_op, DY_OP dy_op) { + using MPType = typename phi::dtype::MPTypeTrait<T>::Type; + std::vector<int> index_array(max_dim, 0); const T *x_data = x.data<T>(); const T *y_data = y.data<T>(); const Tout *out_data = out.data<Tout>(); const Tout *dout_data = dout.data<Tout>(); - T *dx_data = dx == nullptr ? nullptr : dev_ctx.Alloc<T>(dx); - T *dy_data = dy == nullptr ? 
nullptr : dev_ctx.Alloc<T>(dy); - if (dx_data != nullptr) { - memset(dx_data, 0, dx->numel() * sizeof(T)); + + DenseTensor dx_mp, dy_mp; + MPType *dx_mp_data = nullptr; + MPType *dy_mp_data = nullptr; + if (dx != nullptr) { + dx_mp.Resize(dx->dims()); + dev_ctx.Alloc<MPType>(&dx_mp); + dx_mp_data = dx_mp.data<MPType>(); + memset(dx_mp_data, 0, dx->numel() * sizeof(MPType)); } - if (dy_data != nullptr) { - memset(dy_data, 0, dy->numel() * sizeof(T)); + if (dy != nullptr) { + dy_mp.Resize(dy->dims()); + dev_ctx.Alloc<MPType>(&dy_mp); + dy_mp_data = dy_mp.data<MPType>(); + memset(dy_mp_data, 0, dy->numel() * sizeof(MPType)); } const int64_t out_size = std::accumulate(out_dims_array, out_dims_array + max_dim, @@ -87,22 +97,32 @@ void CommonGradBroadcastCPU(const DenseTensor &x, GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); - if (dx_data != nullptr) { - dx_data[x_index] += dx_op(x_data[x_index], - y_data[y_index], - out_data[out_index], - dout_data[out_index]); + if (dx_mp_data != nullptr) { + dx_mp_data[x_index] += static_cast<MPType>(dx_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index])); } - if (dy_data != nullptr) { - dy_data[y_index] += dy_op(x_data[x_index], - y_data[y_index], - out_data[out_index], - dout_data[out_index]); + if (dy_mp_data != nullptr) { + dy_mp_data[y_index] += static_cast<MPType>(dy_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index])); } UpdateElementwiseIndexArray( out_dims_array, max_dim, index_array.data()); } + if (dx != nullptr) { + dev_ctx.Alloc<T>(dx); + phi::CastKernel<MPType, CPUContext>( + dev_ctx, dx_mp, phi::CppTypeToDataType<T>::Type(), dx); + } + if (dy != nullptr) { + dev_ctx.Alloc<T>(dy); + phi::CastKernel<MPType, CPUContext>( + dev_ctx, dy_mp, phi::CppTypeToDataType<T>::Type(), dy); + } } template diff --git a/test/legacy_test/test_multiply.py b/test/legacy_test/test_multiply.py index 59702400db8e55..b82438e8aa4634 100755 --- a/test/legacy_test/test_multiply.py +++ 
b/test/legacy_test/test_multiply.py @@ -290,5 +290,26 @@ def init_shapes(self): self.y_shape = [5, 1] +class TestMultiplyApiBF16(unittest.TestCase): + # Now only check the successful run of multiply with bfloat16 and backward. + def setUp(self): + paddle.device.set_device('cpu') + + def test_multiply(self): + self.x_shape = [1, 1024, 32, 128] + self.y_shape = [1, 1024, 1, 128] + x = paddle.rand(self.x_shape, dtype='bfloat16') + x.stop_gradient = False + y = paddle.rand(self.y_shape, dtype='bfloat16') + y.stop_gradient = False + res = paddle.multiply(x, y) + loss = res.sum() + loss.backward() + assert x.grad is not None + assert x.grad.dtype == paddle.bfloat16 + assert y.grad is not None + assert y.grad.dtype == paddle.bfloat16 + + if __name__ == '__main__': unittest.main() From 244e7d0323fa105bb97203c45a005eb274400352 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Wed, 23 Jul 2025 18:32:19 +0800 Subject: [PATCH 2/2] add headfile --- paddle/phi/kernels/funcs/elementwise_grad_base.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index 21d27a7f85e1d1..300aeaac2626b8 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/elementwise_utils.h" #include "paddle/phi/kernels/funcs/for_range.h"