Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions paddle/phi/kernels/funcs/fast_divmod.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ struct FastDivMod {
return result;
}

__device__ __forceinline__ uint32_t DivCeil(uint32_t n) const {
DivModT res = Divmod(n);
return res.val[1] > 0 ? res.val[0] + 1 : res.val[0];
}

int32_t shift_val;
uint32_t divisor;
uint32_t multiplier;
Expand Down Expand Up @@ -108,6 +113,10 @@ struct FastDivMod<int64_t> {
uint64_t q = Div(n);
return {q, n - q * divisor};
}
__device__ __forceinline__ uint64_t DivCeil(uint32_t n) const {
DivModT res = Divmod(n);
return res.val[1] > 0 ? res.val[0] + 1 : res.val[0];
}

int shift_val;
uint64_t divisor;
Expand Down
34 changes: 25 additions & 9 deletions paddle/phi/kernels/funcs/pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ __device__ void OffsetPreparationFor4Dimension(IndexT index,
}
}

template <typename IndexT>
__device__ void PreparationPoolSize(IndexT index,
IndexT input_size,
IndexT output_size,
FastDivMod<IndexT> divmods,
IndexT* tmp_size

) {
IndexT left = (index == 0) ? 0 : divmods.Div(index * input_size);
IndexT right = (index == output_size - 1)
? input_size
: divmods.DivCeil((index + 1) * input_size);
*tmp_size = right - left;
}

template <typename PoolProcess, typename T, typename IndexT>
__global__ void KernelPool2D(const IndexT nthreads,
const T* input_data,
Expand Down Expand Up @@ -300,22 +315,23 @@ __global__ void KernelPool2DGrad(
output_grad += output_offset;

if (adaptive) {
auto tmp_phstart = divmods.height.Divmod(h_offset * output_height);
auto tmp_pwstart = divmods.width.Divmod(w_offset * output_width);
auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
phstart = divmods.height.Div(h_offset * output_height);
pwstart = divmods.width.Div(w_offset * output_width);
phstart = tmp_phstart.val[0];
pwstart = tmp_pwstart.val[0];
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

IndexT tmp_height, tmp_width;
for (IndexT ph = phstart; ph < phend; ++ph) {
PreparationPoolSize(
ph, input_height, output_height, divmods.ksize_h, &tmp_height);

for (IndexT pw = pwstart; pw < pwend; ++pw) {
auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
: ksize_w_divmod.val[0];
auto tmp_height = ksize_h_divmod.val[1] > 0
? ksize_h_divmod.val[0] + 1
: ksize_h_divmod.val[0];
PreparationPoolSize(
pw, input_width, output_width, divmods.ksize_w, &tmp_width);
IndexT pool_size = tmp_height * tmp_width;
IndexT tmp_idx = ph * output_width + pw;
IndexT output_sub_idx =
Expand Down
22 changes: 22 additions & 0 deletions test/legacy_test/test_adaptive_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,28 @@ def test_dynamic_graph(self):
out_6.numpy(), self.res_3_np, rtol=1e-5, atol=1e-8
)

def test_grad(self):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_tensor(self.x_np)
x.stop_gradient = False
for output_size in [[3, 3], [2, 5], [8, 8]]:
out = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=output_size
)
x_grad = paddle.grad(
[out],
[x],
grad_outputs=paddle.ones_like(out),
allow_unused=True,
)
np.testing.assert_allclose(
paddle.sum(x_grad[0]), out.numel(), rtol=1e-6
)


class TestAdaptiveAvgPool2DClassAPI(unittest.TestCase):
def setUp(self):
Expand Down