Skip to content

Commit 9a0351c

Browse files
authored
[Accuracy diff No.57、69] Fix accuracy diff for sum API (#73012)
1 parent 8fea5d4 commit 9a0351c

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

paddle/phi/kernels/cpu/reduce_sum_kernel.cc

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,41 @@ void SumRawKernel(const Context& dev_ctx,
4444
out);
4545
return;
4646
}
47-
phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
48-
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
47+
if constexpr (std::is_same_v<T, phi::dtype::float16> ||
48+
std::is_same_v<T, phi::dtype::bfloat16>) {
49+
DenseTensor x_fp32 = phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32);
50+
DataType final_out_dtype = out_dtype;
51+
if (final_out_dtype == DataType::UNDEFINED) {
52+
final_out_dtype = x.dtype();
53+
}
54+
if (final_out_dtype == DataType::FLOAT32) {
55+
phi::Reduce<CPUContext, float, phi::funcs::SumFunctor>(
56+
dev_ctx,
57+
x_fp32,
58+
reduce_all,
59+
dims.GetData(),
60+
keep_dim,
61+
phi::DataType::UNDEFINED,
62+
out);
63+
} else {
64+
DenseTensor intermediate_result;
65+
intermediate_result.set_meta(out->meta());
66+
phi::Reduce<CPUContext, float, phi::funcs::SumFunctor>(
67+
dev_ctx,
68+
x_fp32,
69+
reduce_all,
70+
dims.GetData(),
71+
keep_dim,
72+
phi::DataType::UNDEFINED,
73+
&intermediate_result);
74+
75+
phi::CastKernel<float, Context>(
76+
dev_ctx, intermediate_result, final_out_dtype, out);
77+
}
78+
} else {
79+
phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
80+
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
81+
}
4982
}
5083

5184
} // namespace phi

test/legacy_test/test_reduce_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,24 @@ def test_check_grad(self):
207207
)
208208

209209

210+
def create_test_fp16_class_cpu(parent):
    """Create and register an FP16 CPU variant of a sum-op test class.

    Args:
        parent: base OpTest class whose input/attr setup is reused; only
            the dtype and the check tolerances are overridden.

    Returns:
        The newly created test class. It is also registered into this
        module's globals under "<parent>_Fp16CPU" so unittest discovery
        actually runs it.
    """

    class TestSumOpFp16CPU(parent):
        def init_dtype(self):
            # Run the op in float16 on CPU; the kernel accumulates in
            # float32 internally (see reduce_sum_kernel.cc change).
            self.dtype = np.float16

        def test_check_output(self):
            # Loose tolerances: fp16 inputs cannot match the fp32/fp64
            # numpy reference bit-for-bit.
            self.check_output(check_pir=True, rtol=1e-2, atol=1e-2)

        def test_check_grad(self):
            self.check_grad(
                ['X'],
                'Out',
                check_prim=True,
                check_prim_pir=True,
                check_pir=True,
            )

    # BUG FIX: the class was previously discarded on return, so every
    # create_test_fp16_class_cpu(...) call was a no-op and unittest never
    # discovered these tests. Register under a unique per-parent name,
    # mirroring the create_test_fp16_class factory convention.
    cls_name = "{}_{}".format(parent.__name__, "Fp16CPU")
    TestSumOpFp16CPU.__name__ = cls_name
    globals()[cls_name] = TestSumOpFp16CPU
    return TestSumOpFp16CPU
226+
227+
210228
class TestSumOp3D0size(TestSumOp3Dim):
211229

212230
def test_check_output(self):
@@ -261,6 +279,14 @@ def init_attrs(self):
261279
create_test_fp16_class(TestSumOp_withInt)
262280
create_test_fp16_class(TestSumOp3Dim)
263281

282+
create_test_fp16_class_cpu(TestSumOp)
283+
create_test_fp16_class_cpu(TestSumOp_ZeroDim)
284+
create_test_fp16_class_cpu(TestSumOp5D)
285+
create_test_fp16_class_cpu(TestSumOp6D)
286+
create_test_fp16_class_cpu(TestSumOp8D)
287+
create_test_fp16_class_cpu(TestSumOp_withInt)
288+
create_test_fp16_class_cpu(TestSumOp3Dim)
289+
264290

265291
def create_test_bf16_class(parent):
266292
@unittest.skipIf(

0 commit comments

Comments (0)