From a38f7a1f407c4410e10f2ad0c669250cc344f00b Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 14:56:03 +0800 Subject: [PATCH 01/13] refactor roi_pooling backward --- src/operator/roi_pooling.cc | 76 +++++-------------------------------- 1 file changed, 10 insertions(+), 66 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 8862d0db1401..96792e09958c 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -144,7 +144,6 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, Dtype *bottom_diff = in_grad.dptr_; Dtype *argmax_data = max_idx.dptr_; - const int batch_size_ = in_grad.size(0); const int channels_ = in_grad.size(1); const int height_ = in_grad.size(2); const int width_ = in_grad.size(3); @@ -153,73 +152,18 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, const int num_rois = bbox.size(0); - for (int b = 0; b < batch_size_; ++b) { + for (int r = 0; r < num_rois; ++r) { + int b = static_cast(bottom_rois[r * 5]); for (int c = 0; c < channels_; ++c) { - for (int h = 0; h < height_; ++h) { - for (int w = 0; w < width_; ++w) { - int offset_bottom_diff = (b * channels_ + c) * height_ * width_; - offset_bottom_diff += h * width_ + w; - - Dtype gradient = 0; - // Accumulate gradient over all ROIs that pooled this element - for (int roi_n = 0; roi_n < num_rois; ++roi_n) { - const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - assert(roi_batch_ind >= 0); - assert(roi_batch_ind < batch_size_); - if (b != roi_batch_ind) { - continue; - } - - int roi_start_w = std::round(offset_bottom_rois[1] * spatial_scale_); - int roi_start_h = std::round(offset_bottom_rois[2] * spatial_scale_); - int roi_end_w = std::round(offset_bottom_rois[3] * spatial_scale_); - int roi_end_h = std::round(offset_bottom_rois[4] * spatial_scale_); - - bool in_roi = (w >= roi_start_w && w <= roi_end_w && - h >= roi_start_h && h <= roi_end_h); - if (!in_roi) { - continue; - } - - // force malformed ROIs to be 1 * 1 - int roi_height = max(roi_end_h - roi_start_h + 1, 1); - int roi_width = max(roi_end_w - roi_start_w + 1, 1); - const Dtype bin_size_h = static_cast(roi_height) - / static_cast(pooled_height_); - const Dtype bin_size_w = static_cast(roi_width) - / static_cast(pooled_width_); - - // compute pooled regions correspond to original (h, w) point - int phstart = static_cast(floor(static_cast(h - roi_start_h) - / bin_size_h)); - int pwstart = static_cast(floor(static_cast(w - roi_start_w) - / bin_size_w)); - int phend = static_cast(ceil(static_cast(h - roi_start_h + 1) - / bin_size_h)); - int pwend = static_cast(ceil(static_cast(w - roi_start_w + 1) - / bin_size_w)); - - // clip to boundaries of pooled region - phstart = min(max(phstart, 0), pooled_height_); - phend = min(max(phend, 0), pooled_height_); - pwstart = min(max(pwstart, 0), pooled_width_); - pwend = min(max(pwend, 0), pooled_width_); - - // accumulate over gradients in pooled regions - int offset = (roi_n * channels_ + c) * pooled_height_ * pooled_width_; - const Dtype* offset_top_diff = top_diff + offset; - const Dtype* offset_argmax_data = argmax_data + offset; - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - const int pooled_index = ph * pooled_width_ + pw; - if (static_cast(offset_argmax_data[pooled_index]) == h * width_ + w) { - gradient += offset_top_diff[pooled_index]; - } - } - } + for (int h = 0; h < pooled_height_; ++h) { + for (int w = 0; w < pooled_width_; ++w) { + int offset_top = (r * channels_ + c) * pooled_height_ * pooled_width_; + offset_top += h * pooled_width_ + w; + int max_idx = static_cast(argmax_data[offset_top]); + if (max_idx >= 0) { + int offset_bottom_diff = (b * channels_ + c) * height_ * width_ + max_idx; + bottom_diff[offset_bottom_diff] += top_diff[offset_top]; } - bottom_diff[offset_bottom_diff] += gradient; } } } From aa7f8291999b5e9bf444750ca2e4e45fbbbf9c95 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 15:28:56 +0800 Subject: [PATCH 02/13] update max_idx --- src/operator/roi_pooling.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 96792e09958c..d3ae7b81124d 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -81,12 +81,13 @@ inline void ROIPoolForward(const Tensor &out, const Dtype bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); - const Dtype* batch_data = bottom_data + data_size * roi_batch_ind; + int offset_batch_data = data_size * roi_batch_ind; #pragma omp parallel for for (int c = 0; c < channels_; ++c) { // Increment all data pointers - const Dtype* batch_data_c = batch_data + c * data_size_c; + int offset_batch_data_c = offset_batch_data + c * data_size_c; + const Dtype* batch_data_c = bottom_data + offset_batch_data_c; Dtype* top_data_c = top_data_n + c * out_size_c; Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c; @@ -122,7 +123,7 @@ inline void ROIPoolForward(const Tensor &out, const int index = h * width_ + w; if (batch_data_c[index] > top_data_c[pool_index]) { top_data_c[pool_index] = batch_data_c[index]; - argmax_data_c[pool_index] = index; + argmax_data_c[pool_index] = offset_batch_data_c + index; } } } @@ -140,20 +141,15 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &max_idx, const float spatial_scale_) { const Dtype *top_diff = out_grad.dptr_; - const Dtype *bottom_rois = bbox.dptr_; Dtype *bottom_diff = in_grad.dptr_; Dtype *argmax_data = max_idx.dptr_; const int channels_ = in_grad.size(1); - const int height_ = in_grad.size(2); - const int width_ = in_grad.size(3); const int pooled_height_ = out_grad.size(2); const int pooled_width_ = out_grad.size(3); - const int num_rois = bbox.size(0); for (int r = 0; r < num_rois; ++r) { - int b = static_cast(bottom_rois[r * 5]); for (int c = 0; c < channels_; ++c) { for (int h = 0; h < pooled_height_; ++h) { for (int w = 0; w < pooled_width_; ++w) { @@ -161,8 +157,7 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, offset_top += h * pooled_width_ + w; int max_idx = static_cast(argmax_data[offset_top]); if (max_idx >= 0) { - int offset_bottom_diff = (b * channels_ + c) * height_ * width_ + max_idx; - bottom_diff[offset_bottom_diff] += top_diff[offset_top]; + bottom_diff[max_idx] += top_diff[offset_top]; } } } From 4653a97f0ffd03d919f1e29f9dc884005e90b64a Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 08:36:27 +0000 Subject: [PATCH 03/13] refactor gpu part --- src/operator/roi_pooling.cu | 83 ++++++------------------------------- 1 file changed, 13 insertions(+), 70 deletions(-) diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index 9ea99b309aaf..2c69d73bd5ef 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -34,8 +34,8 @@ namespace cuda { template __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, - const float spatial_scale, const int channels, - const int height, const int width, + const float spatial_scale, const int batch_size, + const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const Dtype* bottom_rois, Dtype* top_data, Dtype* argmax_data) { @@ -51,9 +51,9 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, bottom_rois += n * 5; int roi_batch_ind = bottom_rois[0]; - if (roi_batch_ind < 0) { + if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { top_data[index] = 0; - argmax_data[index] = 0; + argmax_data[index] = -1; continue; } @@ -90,13 +90,14 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, Dtype maxval = is_empty ? 0 : -FLT_MAX; // If nothing is pooled, argmax = -1 causes nothing to be backprop'd int maxidx = -1; - bottom_data += (roi_batch_ind * channels + c) * height * width; + int offset_bottom_data = (roi_batch_ind * channels + c) * height * width; + bottom_data += offset_bottom_data; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int bottom_index = h * width + w; if (bottom_data[bottom_index] > maxval) { maxval = bottom_data[bottom_index]; - maxidx = bottom_index; + maxidx = offset_bottom_data + bottom_index; } } } @@ -116,6 +117,7 @@ inline void ROIPoolForward(const Tensor &out, Dtype *top_data = out.dptr_; Dtype *argmax_data = max_idx.dptr_; const int count = out.shape_.Size(); + const int batch_size = data.size(0); const int channels = data.size(1); const int height = data.size(2); const int width = data.size(3); @@ -127,7 +129,7 @@ inline void ROIPoolForward(const Tensor &out, CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Forward"); cudaStream_t stream = Stream::GetStream(out.stream_); ROIPoolForwardKernel<<>>( - count, bottom_data, spatial_scale, channels, height, width, + count, bottom_data, spatial_scale, batch_size, channels, height, width, pooled_height, pooled_width, bottom_rois, top_data, argmax_data); MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolForwardKernel); } @@ -142,69 +144,10 @@ __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { - // (n, c, h, w) coords in bottom data - int w = index % width; - int h = (index / width) % height; - int c = (index / width / height) % channels; - int n = index / width / height / channels; - - Dtype gradient = 0; - // Accumulate gradient over all ROIs that pooled this element - for (int roi_n = 0; roi_n < num_rois; ++roi_n) { - const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - // Skip if ROI's batch index doesn't match n - if (n != roi_batch_ind) { - continue; - } - - int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); - - // Skip if ROI doesn't include (h, w) - const bool in_roi = (w >= roi_start_w && w <= roi_end_w && - h >= roi_start_h && h <= roi_end_h); - if (!in_roi) { - continue; - } - - int offset = (roi_n * channels + c) * pooled_height * pooled_width; - const Dtype* offset_top_diff = top_diff + offset; - const Dtype* offset_argmax_data = argmax_data + offset; - - // Compute feasible set of pooled units that could have pooled - // this bottom unit - - // Force malformed ROIs to be 1x1 - int roi_width = max(roi_end_w - roi_start_w + 1, 1); - int roi_height = max(roi_end_h - roi_start_h + 1, 1); - - Dtype bin_size_h = static_cast(roi_height) - / static_cast(pooled_height); - Dtype bin_size_w = static_cast(roi_width) - / static_cast(pooled_width); - - int phstart = floor(static_cast(h - roi_start_h) / bin_size_h); - int phend = ceil(static_cast(h - roi_start_h + 1) / bin_size_h); - int pwstart = floor(static_cast(w - roi_start_w) / bin_size_w); - int pwend = ceil(static_cast(w - roi_start_w + 1) / bin_size_w); - - phstart = min(max(phstart, 0), pooled_height); - phend = min(max(phend, 0), pooled_height); - pwstart = min(max(pwstart, 0), pooled_width); - pwend = min(max(pwend, 0), pooled_width); - - for (int ph = phstart; ph < phend; ++ph) { - for (int pw = pwstart; pw < pwend; ++pw) { - if (static_cast(offset_argmax_data[ph * pooled_width + pw]) == (h * width + w)) { - gradient += offset_top_diff[ph * pooled_width + pw]; - } - } - } + int max_idx = static_cast(argmax_data[index]); + if (max_idx >= 0) { + atomicAdd(&bottom_diff[max_idx], top_diff[index]); } - bottom_diff[index] += gradient; } } @@ -218,7 +161,7 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Dtype *bottom_rois = bbox.dptr_; Dtype *bottom_diff = in_grad.dptr_; Dtype *argmax_data = max_idx.dptr_; - const int count = in_grad.shape_.Size(); + const int count = out_grad.shape_.Size(); const int num_rois = bbox.size(0); const int channels = in_grad.size(1); const int height = in_grad.size(2); From 094678660c93acf060564a6e1cfe92f658732bac Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 16:45:23 +0800 Subject: [PATCH 04/13] update --- src/operator/roi_pooling.cc | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index d3ae7b81124d..4198728cccfd 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -144,23 +144,12 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, Dtype *bottom_diff = in_grad.dptr_; Dtype *argmax_data = max_idx.dptr_; - const int channels_ = in_grad.size(1); - const int pooled_height_ = out_grad.size(2); - const int pooled_width_ = out_grad.size(3); - const int num_rois = bbox.size(0); + const int count = out_grad.shape_.Size(); - for (int r = 0; r < num_rois; ++r) { - for (int c = 0; c < channels_; ++c) { - for (int h = 0; h < pooled_height_; ++h) { - for (int w = 0; w < pooled_width_; ++w) { - int offset_top = (r * channels_ + c) * pooled_height_ * pooled_width_; - offset_top += h * pooled_width_ + w; - int max_idx = static_cast(argmax_data[offset_top]); - if (max_idx >= 0) { - bottom_diff[max_idx] += top_diff[offset_top]; - } - } - } + for (int index = 0; index < count; ++index) { + int max_idx = static_cast(argmax_data[index]); + if (max_idx >= 0) { + bottom_diff[max_idx] += top_diff[index]; } } From f30db1b2c4e082271364d13441018b15f1cd9533 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 15:20:54 +0000 Subject: [PATCH 05/13] update max_idx as int --- src/operator/roi_pooling-inl.h | 6 +++--- src/operator/roi_pooling.cc | 14 +++++++------- src/operator/roi_pooling.cu | 20 ++++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/operator/roi_pooling-inl.h b/src/operator/roi_pooling-inl.h index ce0efe9b07c9..b20130edf4a6 100644 --- a/src/operator/roi_pooling-inl.h +++ b/src/operator/roi_pooling-inl.h @@ -83,7 +83,7 @@ class ROIPoolingOp : public Operator { Tensor data = in_data[roipool::kData].get(s); Tensor bbox = in_data[roipool::kBox].get(s); Tensor out = out_data[roipool::kOut].get(s); - Tensor max_idx = out_data[roipool::kMaxIdx].get(s); + Tensor max_idx = out_data[roipool::kMaxIdx].get(s); CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(bbox.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); @@ -114,7 +114,7 @@ class ROIPoolingOp : public Operator { Tensor grad_out = out_grad[roipool::kOut].get(s); Tensor bbox = in_data[roipool::kBox].get(s); - Tensor max_idx = out_data[roipool::kMaxIdx].get(s); + Tensor max_idx = out_data[roipool::kMaxIdx].get(s); Tensor grad_in = in_grad[roipool::kData].get(s); Tensor grad_roi = in_grad[roipool::kBox].get(s); CHECK_EQ(grad_out.CheckContiguous(), true); @@ -202,7 +202,7 @@ class ROIPoolingProp : public OperatorProperty { out_type->clear(); out_type->push_back(dtype); - out_type->push_back(dtype); + out_type->push_back(mshadow::kInt32); return true; } diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 4198728cccfd..76c4518d5752 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -40,12 +40,12 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale_) { const Dtype *bottom_data = data.dptr_; const Dtype *bottom_rois = bbox.dptr_; Dtype *top_data = out.dptr_; - Dtype *argmax_data = max_idx.dptr_; + int *argmax_data = max_idx.dptr_; const int channels_ = data.size(1); const int height_ = data.size(2); const int width_ = data.size(3); @@ -64,7 +64,7 @@ inline void ROIPoolForward(const Tensor &out, // Increment ROI data pointer const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1); Dtype *top_data_n = top_data + n * out_size; - Dtype *argmax_data_n = argmax_data + n * max_idx_size; + int *argmax_data_n = argmax_data + n * max_idx_size; int roi_batch_ind = bottom_rois_n[0]; int roi_start_w = std::round(bottom_rois_n[1] * spatial_scale_); int roi_start_h = std::round(bottom_rois_n[2] * spatial_scale_); @@ -89,7 +89,7 @@ inline void ROIPoolForward(const Tensor &out, int offset_batch_data_c = offset_batch_data + c * data_size_c; const Dtype* batch_data_c = bottom_data + offset_batch_data_c; Dtype* top_data_c = top_data_n + c * out_size_c; - Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c; + int* argmax_data_c = argmax_data_n + c * max_idx_size_c; for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { @@ -138,16 +138,16 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale_) { const Dtype *top_diff = out_grad.dptr_; Dtype *bottom_diff = in_grad.dptr_; - Dtype *argmax_data = max_idx.dptr_; + int *argmax_data = max_idx.dptr_; const int count = out_grad.shape_.Size(); for (int index = 0; index < count; ++index) { - int max_idx = static_cast(argmax_data[index]); + int max_idx = argmax_data[index]; if (max_idx >= 0) { bottom_diff[max_idx] += top_diff[index]; } diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index 2c69d73bd5ef..6d85d5a54973 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -38,7 +38,7 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const Dtype* bottom_rois, Dtype* top_data, - Dtype* argmax_data) { + int* argmax_data) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { @@ -102,7 +102,7 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, } } top_data[index] = maxval; - argmax_data[index] = (Dtype)maxidx; + argmax_data[index] = maxidx; } } @@ -110,12 +110,12 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { const Dtype *bottom_data = data.dptr_; const Dtype *bottom_rois = bbox.dptr_; Dtype *top_data = out.dptr_; - Dtype *argmax_data = max_idx.dptr_; + int *argmax_data = max_idx.dptr_; const int count = out.shape_.Size(); const int batch_size = data.size(0); const int channels = data.size(1); @@ -136,7 +136,7 @@ inline void ROIPoolForward(const Tensor &out, template __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, - const Dtype* argmax_data, const int num_rois, + const int* argmax_data, const int num_rois, const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, @@ -144,7 +144,7 @@ __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { - int max_idx = static_cast(argmax_data[index]); + int max_idx = argmax_data[index]; if (max_idx >= 0) { atomicAdd(&bottom_diff[max_idx], top_diff[index]); } @@ -155,12 +155,12 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { const Dtype *top_diff = out_grad.dptr_; const Dtype *bottom_rois = bbox.dptr_; Dtype *bottom_diff = in_grad.dptr_; - Dtype *argmax_data = max_idx.dptr_; + int *argmax_data = max_idx.dptr_; const int count = out_grad.shape_.Size(); const int num_rois = bbox.size(0); const int channels = in_grad.size(1); @@ -185,7 +185,7 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale); } @@ -194,7 +194,7 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale); } From 8d354cc8efb8904175c4e3dfe38e8e5091bae1ea Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Mon, 6 May 2019 23:54:43 +0800 Subject: [PATCH 06/13] remove unused arguments --- src/operator/roi_pooling.cu | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index 6d85d5a54973..38398151a387 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -136,11 +136,7 @@ inline void ROIPoolForward(const Tensor &out, template __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, - const int* argmax_data, const int num_rois, - const float spatial_scale, const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - Dtype* bottom_diff, const Dtype* bottom_rois) { + const int* argmax_data, Dtype* bottom_diff) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { @@ -158,24 +154,16 @@ inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &max_idx, const float spatial_scale) { const Dtype *top_diff = out_grad.dptr_; - const Dtype *bottom_rois = bbox.dptr_; Dtype *bottom_diff = in_grad.dptr_; int *argmax_data = max_idx.dptr_; const int count = out_grad.shape_.Size(); - const int num_rois = bbox.size(0); - const int channels = in_grad.size(1); - const int height = in_grad.size(2); - const int width = in_grad.size(3); - const int pooled_height = out_grad.size(2); - const int pooled_width = out_grad.size(3); const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim); dim3 dimBlock(kMaxThreadsPerBlock); CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Backward"); cudaStream_t stream = Stream::GetStream(in_grad.stream_); ROIPoolBackwardAccKernel<<>>( - count, top_diff, argmax_data, num_rois, spatial_scale, channels, height, width, - pooled_height, pooled_width, bottom_diff, bottom_rois); + count, top_diff, argmax_data, bottom_diff); MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolBackwardAccKernel); } From a82195f579cb23eb5bc631fb2e1ac9d35aaec961 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Tue, 7 May 2019 08:48:09 +0000 Subject: [PATCH 07/13] trigger CI From da11b2bb361a5dffb53db3bfeb142fdf2d26a4bc Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 9 May 2019 13:24:14 +0800 Subject: [PATCH 08/13] fix index type as index_t --- src/operator/roi_pooling-inl.h | 9 ++++++-- src/operator/roi_pooling.cc | 40 +++++++++++++++++----------------- src/operator/roi_pooling.cu | 28 ++++++++++++------------ 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/operator/roi_pooling-inl.h b/src/operator/roi_pooling-inl.h index b20130edf4a6..1d2a93683121 100644 --- a/src/operator/roi_pooling-inl.h +++ b/src/operator/roi_pooling-inl.h @@ -114,7 +114,7 @@ class ROIPoolingOp : public Operator { Tensor grad_out = out_grad[roipool::kOut].get(s); Tensor bbox = in_data[roipool::kBox].get(s); - Tensor max_idx = out_data[roipool::kMaxIdx].get(s); + Tensor max_idx = out_data[roipool::kMaxIdx].get(s); Tensor grad_in = in_grad[roipool::kData].get(s); Tensor grad_roi = in_grad[roipool::kBox].get(s); CHECK_EQ(grad_out.CheckContiguous(), true); @@ -195,6 +195,7 @@ class ROIPoolingProp : public OperatorProperty { bool InferType(std::vector *in_type, std::vector *out_type, std::vector *aux_type) const override { + using namespace mshadow; CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; CHECK_EQ(dtype, (*in_type)[1]); @@ -202,7 +203,11 @@ class ROIPoolingProp : public OperatorProperty { out_type->clear(); out_type->push_back(dtype); - out_type->push_back(mshadow::kInt32); +# if MXNET_USE_INT64_TENSOR_SIZE == 1 + out_type->push_back(kInt64); +# else + out_type->push_back(kInt32); +# endif return true; } diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 76c4518d5752..2067aa5b4005 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -40,12 +40,12 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale_) { const Dtype *bottom_data = data.dptr_; const Dtype *bottom_rois = bbox.dptr_; Dtype *top_data = out.dptr_; - int *argmax_data = max_idx.dptr_; + index_t *argmax_data = max_idx.dptr_; const int channels_ = data.size(1); const int height_ = data.size(2); const int width_ = data.size(3); @@ -53,25 +53,25 @@ inline void ROIPoolForward(const Tensor &out, const int pooled_width_ = out.size(3); const int num_rois = bbox.size(0); - const int data_size = data.size(1) * data.size(2) * data.size(3); - const int data_size_c = data.size(2) * data.size(3); - const int out_size_c = out.size(2) * out.size(3); - const int out_size = channels_ * out_size_c; - const int max_idx_size_c = max_idx.size(2) * max_idx.size(3); - const int max_idx_size = channels_ * max_idx_size_c; + const index_t data_size = data.size(1) * data.size(2) * data.size(3); + const index_t data_size_c = data.size(2) * data.size(3); + const index_t out_size_c = out.size(2) * out.size(3); + const index_t out_size = channels_ * out_size_c; + const index_t max_idx_size_c = max_idx.size(2) * max_idx.size(3); + const index_t max_idx_size = channels_ * max_idx_size_c; // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R for (int n = 0; n < num_rois; ++n) { // Increment ROI data pointer const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1); Dtype *top_data_n = top_data + n * out_size; - int *argmax_data_n = argmax_data + n * max_idx_size; - int roi_batch_ind = bottom_rois_n[0]; + index_t *argmax_data_n = argmax_data + n * max_idx_size; + int roi_batch_ind = static_cast(bottom_rois_n[0]); int roi_start_w = std::round(bottom_rois_n[1] * spatial_scale_); int roi_start_h = std::round(bottom_rois_n[2] * spatial_scale_); int roi_end_w = std::round(bottom_rois_n[3] * spatial_scale_); int roi_end_h = std::round(bottom_rois_n[4] * spatial_scale_); assert(roi_batch_ind >= 0); - assert(static_cast(roi_batch_ind) < data.size(0) /* batch size */); + assert(roi_batch_ind < data.size(0) /* batch size */); // force malformed ROIs to be 1 * 1 int roi_height = max(roi_end_h - roi_start_h + 1, 1); @@ -81,15 +81,15 @@ inline void ROIPoolForward(const Tensor &out, const Dtype bin_size_w = static_cast(roi_width) / static_cast(pooled_width_); - int offset_batch_data = data_size * roi_batch_ind; + index_t offset_batch_data = data_size * roi_batch_ind; #pragma omp parallel for for (int c = 0; c < channels_; ++c) { // Increment all data pointers - int offset_batch_data_c = offset_batch_data + c * data_size_c; + index_t offset_batch_data_c = offset_batch_data + c * data_size_c; const Dtype* batch_data_c = bottom_data + offset_batch_data_c; Dtype* top_data_c = top_data_n + c * out_size_c; - int* argmax_data_c = argmax_data_n + c * max_idx_size_c; + index_t* argmax_data_c = argmax_data_n + c * max_idx_size_c; for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { @@ -112,7 +112,7 @@ inline void ROIPoolForward(const Tensor &out, bool is_empty = (hend <= hstart) || (wend <= wstart); - const int pool_index = ph * pooled_width_ + pw; + const index_t pool_index = ph * pooled_width_ + pw; if (is_empty) { top_data_c[pool_index] = 0; argmax_data_c[pool_index] = -1; @@ -120,7 +120,7 @@ inline void ROIPoolForward(const Tensor &out, for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - const int index = h * width_ + w; + const index_t index = h * width_ + w; if (batch_data_c[index] > top_data_c[pool_index]) { top_data_c[pool_index] = batch_data_c[index]; argmax_data_c[pool_index] = offset_batch_data_c + index; @@ -138,16 +138,16 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale_) { const Dtype *top_diff = out_grad.dptr_; Dtype *bottom_diff = in_grad.dptr_; - int *argmax_data = max_idx.dptr_; + index_t *argmax_data = max_idx.dptr_; - const int count = out_grad.shape_.Size(); + const index_t count = out_grad.shape_.Size(); for (int index = 0; index < count; ++index) { - int max_idx = argmax_data[index]; + index_t max_idx = argmax_data[index]; if (max_idx >= 0) { bottom_diff[max_idx] += top_diff[index]; } diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index 38398151a387..b73a46f742c3 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -38,8 +38,8 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const Dtype* bottom_rois, Dtype* top_data, - int* argmax_data) { - for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; + index_t* argmax_data) { + for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { // (n, c, ph, pw) is an element in the pooled output @@ -49,7 +49,7 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, int n = index / pooled_width / pooled_height / channels; bottom_rois += n * 5; - int roi_batch_ind = bottom_rois[0]; + int roi_batch_ind = static_cast(bottom_rois[0]); if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) { top_data[index] = 0; @@ -89,12 +89,12 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data, // Define an empty pooling region to be zero Dtype maxval = is_empty ? 0 : -FLT_MAX; // If nothing is pooled, argmax = -1 causes nothing to be backprop'd - int maxidx = -1; - int offset_bottom_data = (roi_batch_ind * channels + c) * height * width; + index_t maxidx = -1; + index_t offset_bottom_data = (roi_batch_ind * channels + c) * height * width; bottom_data += offset_bottom_data; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - int bottom_index = h * width + w; + index_t bottom_index = h * width + w; if (bottom_data[bottom_index] > maxval) { maxval = bottom_data[bottom_index]; maxidx = offset_bottom_data + bottom_index; @@ -110,13 +110,13 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { const Dtype *bottom_data = data.dptr_; const Dtype *bottom_rois = bbox.dptr_; Dtype *top_data = out.dptr_; - int *argmax_data = max_idx.dptr_; - const int count = out.shape_.Size(); + index_t *argmax_data = max_idx.dptr_; + const index_t count = out.shape_.Size(); const int batch_size = data.size(0); const int channels = data.size(1); const int height = data.size(2); @@ -137,10 +137,10 @@ inline void ROIPoolForward(const Tensor &out, template __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, const int* argmax_data, Dtype* bottom_diff) { - for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; + for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { - int max_idx = argmax_data[index]; + index_t max_idx = argmax_data[index]; if (max_idx >= 0) { atomicAdd(&bottom_diff[max_idx], top_diff[index]); } @@ -151,12 +151,12 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { const Dtype *top_diff = out_grad.dptr_; Dtype *bottom_diff = in_grad.dptr_; - int *argmax_data = max_idx.dptr_; - const int count = out_grad.shape_.Size(); + index_t *argmax_data = max_idx.dptr_; + const index_t count = out_grad.shape_.Size(); const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim); dim3 dimBlock(kMaxThreadsPerBlock); From ec474eff45b4a48efe99078c51e8e74dafe31a16 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 9 May 2019 13:48:26 +0800 Subject: [PATCH 09/13] remove assert, fix invalid box --- src/operator/roi_pooling.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 2067aa5b4005..bba3bea5ce6a 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -46,6 +46,7 @@ inline void ROIPoolForward(const Tensor &out, const Dtype *bottom_rois = bbox.dptr_; Dtype *top_data = out.dptr_; index_t *argmax_data = max_idx.dptr_; + const int batch_size = data.size(0); const int channels_ = data.size(1); const int height_ = data.size(2); const int width_ = data.size(3); @@ -65,13 +66,13 @@ inline void ROIPoolForward(const Tensor &out, const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1); Dtype *top_data_n = top_data + n * out_size; index_t *argmax_data_n = argmax_data + n * max_idx_size; - int roi_batch_ind = static_cast(bottom_rois_n[0]); int roi_start_w = std::round(bottom_rois_n[1] * spatial_scale_); int roi_start_h = std::round(bottom_rois_n[2] * spatial_scale_); int roi_end_w = std::round(bottom_rois_n[3] * spatial_scale_); int roi_end_h = std::round(bottom_rois_n[4] * spatial_scale_); - assert(roi_batch_ind >= 0); - assert(roi_batch_ind < data.size(0) /* batch size */); + + int roi_batch_ind = static_cast(bottom_rois_n[0]); + bool is_ind_invalid = (roi_batch_ind < 0) || (roi_batch_ind >= batch_size); // force malformed ROIs to be 1 * 1 int roi_height = max(roi_end_h - roi_start_h + 1, 1); @@ -113,9 +114,10 @@ inline void ROIPoolForward(const Tensor &out, bool is_empty = (hend <= hstart) || (wend <= wstart); const index_t pool_index = ph * pooled_width_ + pw; - if (is_empty) { + if (is_empty || is_ind_invalid) { top_data_c[pool_index] = 0; argmax_data_c[pool_index] = -1; + continue; } for (int h = hstart; h < hend; ++h) { From 21c94642980701a8ae37a0fe07613c3dc4bebb7e Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 9 May 2019 13:56:40 +0800 Subject: [PATCH 10/13] update --- src/operator/roi_pooling-inl.h | 2 +- src/operator/roi_pooling.cu | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/roi_pooling-inl.h b/src/operator/roi_pooling-inl.h index 1d2a93683121..a189fe231826 100644 --- a/src/operator/roi_pooling-inl.h +++ b/src/operator/roi_pooling-inl.h @@ -83,7 +83,7 @@ class ROIPoolingOp : public Operator { Tensor data = in_data[roipool::kData].get(s); Tensor bbox = in_data[roipool::kBox].get(s); Tensor out = out_data[roipool::kOut].get(s); - Tensor max_idx = out_data[roipool::kMaxIdx].get(s); + Tensor max_idx = out_data[roipool::kMaxIdx].get(s); CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(bbox.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index b73a46f742c3..50fae147415c 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -173,7 +173,7 @@ template inline void ROIPoolForward(const Tensor &out, const Tensor &data, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale); } @@ -182,7 +182,7 @@ template inline void ROIPoolBackwardAcc(const Tensor &in_grad, const Tensor &out_grad, const Tensor &bbox, - const Tensor &max_idx, + const Tensor &max_idx, const float spatial_scale) { cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale); } From 4d603b0b1a3a354af08140cb1362d1817cec6cdd Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Thu, 9 May 2019 14:53:01 +0800 Subject: [PATCH 11/13] fix --- src/operator/roi_pooling.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu index 50fae147415c..b4013a672711 100644 --- a/src/operator/roi_pooling.cu +++ b/src/operator/roi_pooling.cu @@ -136,7 +136,7 @@ inline void ROIPoolForward(const Tensor &out, template __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff, - const int* argmax_data, Dtype* bottom_diff) { + const index_t* argmax_data, Dtype* bottom_diff) { for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x * gridDim.y) { From 37795f496706b8193a11841c4a9fd913ee34cf70 Mon Sep 17 00:00:00 2001 From: Wang Jiajun Date: Fri, 10 May 2019 00:41:43 +0800 Subject: [PATCH 12/13] trigger CI From cb9c62b40f7432bd24a823ff64ac0e8ec956bede Mon Sep 17 00:00:00 2001 From: JackieWu Date: Thu, 4 Jul 2019 13:04:45 +0800 Subject: [PATCH 13/13] retrigger CI