Skip to content

Commit 55face6

Browse files
authored
[Accuracy diff No.106] Fix accuracy diff for paddle.nn.functional.adaptive_avg_pool2d API (#74077)
* fix KernelPool2DGrad
* fix pool2d
* improve
* add test
1 parent 9454bd4 commit 55face6

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

paddle/phi/kernels/funcs/fast_divmod.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,11 @@ struct FastDivMod {
6767
return result;
6868
}
6969

70+
__device__ __forceinline__ uint32_t DivCeil(uint32_t n) const {
71+
DivModT res = Divmod(n);
72+
return res.val[1] > 0 ? res.val[0] + 1 : res.val[0];
73+
}
74+
7075
int32_t shift_val;
7176
uint32_t divisor;
7277
uint32_t multiplier;
@@ -108,6 +113,10 @@ struct FastDivMod<int64_t> {
108113
uint64_t q = Div(n);
109114
return {q, n - q * divisor};
110115
}
116+
__device__ __forceinline__ uint64_t DivCeil(uint32_t n) const {
117+
DivModT res = Divmod(n);
118+
return res.val[1] > 0 ? res.val[0] + 1 : res.val[0];
119+
}
111120

112121
int shift_val;
113122
uint64_t divisor;

paddle/phi/kernels/funcs/pooling.cu

Lines changed: 25 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,21 @@ __device__ void OffsetPreparationFor4Dimension(IndexT index,
124124
}
125125
}
126126

127+
template <typename IndexT>
128+
__device__ void PreparationPoolSize(IndexT index,
129+
IndexT input_size,
130+
IndexT output_size,
131+
FastDivMod<IndexT> divmods,
132+
IndexT* tmp_size
133+
134+
) {
135+
IndexT left = (index == 0) ? 0 : divmods.Div(index * input_size);
136+
IndexT right = (index == output_size - 1)
137+
? input_size
138+
: divmods.DivCeil((index + 1) * input_size);
139+
*tmp_size = right - left;
140+
}
141+
127142
template <typename PoolProcess, typename T, typename IndexT>
128143
__global__ void KernelPool2D(const IndexT nthreads,
129144
const T* input_data,
@@ -304,22 +319,23 @@ __global__ void KernelPool2DGrad(
304319
output_grad += output_offset;
305320

306321
if (adaptive) {
322+
auto tmp_phstart = divmods.height.Divmod(h_offset * output_height);
323+
auto tmp_pwstart = divmods.width.Divmod(w_offset * output_width);
307324
auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
308325
auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
309-
phstart = divmods.height.Div(h_offset * output_height);
310-
pwstart = divmods.width.Div(w_offset * output_width);
326+
phstart = tmp_phstart.val[0];
327+
pwstart = tmp_pwstart.val[0];
311328
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
312329
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];
313330

331+
IndexT tmp_height, tmp_width;
314332
for (IndexT ph = phstart; ph < phend; ++ph) {
333+
PreparationPoolSize(
334+
ph, input_height, output_height, divmods.ksize_h, &tmp_height);
335+
315336
for (IndexT pw = pwstart; pw < pwend; ++pw) {
316-
auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
317-
auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
318-
auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
319-
: ksize_w_divmod.val[0];
320-
auto tmp_height = ksize_h_divmod.val[1] > 0
321-
? ksize_h_divmod.val[0] + 1
322-
: ksize_h_divmod.val[0];
337+
PreparationPoolSize(
338+
pw, input_width, output_width, divmods.ksize_w, &tmp_width);
323339
IndexT pool_size = tmp_height * tmp_width;
324340
IndexT tmp_idx = ph * output_width + pw;
325341
IndexT output_sub_idx =

test/legacy_test/test_adaptive_avg_pool2d.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,28 @@ def test_dynamic_graph(self):
217217
out_6.numpy(), self.res_3_np, rtol=1e-5, atol=1e-8
218218
)
219219

220+
def test_grad(self):
221+
for use_cuda in (
222+
[False, True] if core.is_compiled_with_cuda() else [False]
223+
):
224+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
225+
paddle.disable_static(place=place)
226+
x = paddle.to_tensor(self.x_np)
227+
x.stop_gradient = False
228+
for output_size in [[3, 3], [2, 5], [8, 8]]:
229+
out = paddle.nn.functional.adaptive_avg_pool2d(
230+
x=x, output_size=output_size
231+
)
232+
x_grad = paddle.grad(
233+
[out],
234+
[x],
235+
grad_outputs=paddle.ones_like(out),
236+
allow_unused=True,
237+
)
238+
np.testing.assert_allclose(
239+
paddle.sum(x_grad[0]), out.numel(), rtol=1e-6
240+
)
241+
220242

221243
class TestAdaptiveAvgPool2DClassAPI(unittest.TestCase):
222244
def setUp(self):

0 commit comments

Comments (0)