Skip to content

Commit 461166c

Browse files
authored
[Accuracy diff No.20、61、9-10] Fix accuracy diff for expand_grad API (#72992)
* fix * fix
1 parent ac7aa19 commit 461166c

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

paddle/phi/kernels/impl/expand_grad_kernel_impl.h

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/phi/core/tensor_utils.h"
18+
#include "paddle/phi/kernels/cast_kernel.h"
1819
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/eigen/common.h"
2021
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -32,19 +33,50 @@ void ExpandBackward(const Context& dev_ctx,
3233
dev_ctx.template Alloc<T>(in_grad);
3334
in_grad->data<T>();
3435

35-
auto x_grad = EigenVector<T>::Flatten(*in_grad);
36-
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
37-
for (size_t i = 0; i < reshape_size; ++i) {
38-
reshape_dims[i] = reshape_dims_vec[i];
39-
}
40-
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
41-
for (size_t i = 0; i < reduce_size; ++i) {
42-
reduce_dims[i] = reduce_dims_vec[i];
36+
if constexpr (std::is_same_v<T, dtype::float16> ||
37+
std::is_same_v<T, dtype::bfloat16>) {
38+
const DenseTensor out_grad_fp32 =
39+
phi::Cast<T, Context>(dev_ctx, out_grad, DataType::FLOAT32);
40+
DenseTensor in_grad_fp32;
41+
in_grad_fp32.Resize(in_grad->dims());
42+
dev_ctx.template Alloc<float>(&in_grad_fp32);
43+
44+
auto x_grad = EigenVector<float>::Flatten(in_grad_fp32);
45+
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
46+
for (size_t i = 0; i < reshape_size; ++i) {
47+
reshape_dims[i] = reshape_dims_vec[i];
48+
}
49+
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
50+
for (size_t i = 0; i < reduce_size; ++i) {
51+
reduce_dims[i] = reduce_dims_vec[i];
52+
}
53+
const auto out_grad0 = EigenVector<float>::Flatten(out_grad_fp32);
54+
auto& place = *dev_ctx.eigen_device();
55+
phi::funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, float, Dims>::
56+
Eval(place, x_grad, out_grad0, reduce_dims, reshape_dims);
57+
58+
if constexpr (std::is_same_v<T, dtype::float16>) {
59+
phi::CastKernel<float, Context>(
60+
dev_ctx, in_grad_fp32, DataType::FLOAT16, in_grad);
61+
} else {
62+
phi::CastKernel<float, Context>(
63+
dev_ctx, in_grad_fp32, DataType::BFLOAT16, in_grad);
64+
}
65+
} else {
66+
auto x_grad = EigenVector<T>::Flatten(*in_grad);
67+
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
68+
for (size_t i = 0; i < reshape_size; ++i) {
69+
reshape_dims[i] = reshape_dims_vec[i];
70+
}
71+
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
72+
for (size_t i = 0; i < reduce_size; ++i) {
73+
reduce_dims[i] = reduce_dims_vec[i];
74+
}
75+
auto out_grad0 = EigenVector<T>::Flatten(out_grad);
76+
auto& place = *dev_ctx.eigen_device();
77+
phi::funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::
78+
Eval(place, x_grad, out_grad0, reduce_dims, reshape_dims);
4379
}
44-
auto out_grad0 = EigenVector<T>::Flatten(out_grad);
45-
auto& place = *dev_ctx.eigen_device();
46-
phi::funcs::EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
47-
place, x_grad, out_grad0, reduce_dims, reshape_dims);
4880
}
4981

5082
template <typename T, typename Context>

0 commit comments

Comments
 (0)