@@ -124,6 +124,21 @@ __device__ void OffsetPreparationFor4Dimension(IndexT index,
124124 }
125125}
126126
127+ template <typename IndexT>
128+ __device__ void PreparationPoolSize (IndexT index,
129+ IndexT input_size,
130+ IndexT output_size,
131+ FastDivMod<IndexT> divmods,
132+ IndexT* tmp_size
133+
134+ ) {
135+ IndexT left = (index == 0 ) ? 0 : divmods.Div (index * input_size);
136+ IndexT right = (index == output_size - 1 )
137+ ? input_size
138+ : divmods.DivCeil ((index + 1 ) * input_size);
139+ *tmp_size = right - left;
140+ }
141+
127142template <typename PoolProcess, typename T, typename IndexT>
128143__global__ void KernelPool2D (const IndexT nthreads,
129144 const T* input_data,
@@ -304,22 +319,23 @@ __global__ void KernelPool2DGrad(
304319 output_grad += output_offset;
305320
306321 if (adaptive) {
322+ auto tmp_phstart = divmods.height .Divmod (h_offset * output_height);
323+ auto tmp_pwstart = divmods.width .Divmod (w_offset * output_width);
307324 auto tmp_phend = divmods.height .Divmod ((h_offset + 1 ) * output_height);
308325 auto tmp_pwend = divmods.width .Divmod ((w_offset + 1 ) * output_width);
309- phstart = divmods. height . Div (h_offset * output_height) ;
310- pwstart = divmods. width . Div (w_offset * output_width) ;
326+ phstart = tmp_phstart. val [ 0 ] ;
327+ pwstart = tmp_pwstart. val [ 0 ] ;
311328 phend = tmp_phend.val [1 ] > 0 ? tmp_phend.val [0 ] + 1 : tmp_phend.val [0 ];
312329 pwend = tmp_pwend.val [1 ] > 0 ? tmp_pwend.val [0 ] + 1 : tmp_pwend.val [0 ];
313330
331+ IndexT tmp_height, tmp_width;
314332 for (IndexT ph = phstart; ph < phend; ++ph) {
333+ PreparationPoolSize (
334+ ph, input_height, output_height, divmods.ksize_h , &tmp_height);
335+
315336 for (IndexT pw = pwstart; pw < pwend; ++pw) {
316- auto ksize_w_divmod = divmods.ksize_w .Divmod (input_width);
317- auto ksize_h_divmod = divmods.ksize_h .Divmod (input_height);
318- auto tmp_width = ksize_w_divmod.val [1 ] > 0 ? ksize_w_divmod.val [0 ] + 1
319- : ksize_w_divmod.val [0 ];
320- auto tmp_height = ksize_h_divmod.val [1 ] > 0
321- ? ksize_h_divmod.val [0 ] + 1
322- : ksize_h_divmod.val [0 ];
337+ PreparationPoolSize (
338+ pw, input_width, output_width, divmods.ksize_w , &tmp_width);
323339 IndexT pool_size = tmp_height * tmp_width;
324340 IndexT tmp_idx = ph * output_width + pw;
325341 IndexT output_sub_idx =
0 commit comments