diff --git a/src/operator/roi_pooling-inl.h b/src/operator/roi_pooling-inl.h
index ce0efe9b07c9..a189fe231826 100644
--- a/src/operator/roi_pooling-inl.h
+++ b/src/operator/roi_pooling-inl.h
@@ -83,7 +83,7 @@ class ROIPoolingOp : public Operator {
     Tensor<xpu, 4, DType> data = in_data[roipool::kData].get<xpu, 4, DType>(s);
     Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
     Tensor<xpu, 4, DType> out = out_data[roipool::kOut].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, index_t> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, index_t>(s);
     CHECK_EQ(data.CheckContiguous(), true);
     CHECK_EQ(bbox.CheckContiguous(), true);
     CHECK_EQ(out.CheckContiguous(), true);
@@ -114,7 +114,7 @@ class ROIPoolingOp : public Operator {
     Tensor<xpu, 4, DType> grad_out = out_grad[roipool::kOut].get<xpu, 4, DType>(s);
     Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
-    Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, index_t> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, index_t>(s);
     Tensor<xpu, 4, DType> grad_in = in_grad[roipool::kData].get<xpu, 4, DType>(s);
     Tensor<xpu, 2, DType> grad_roi = in_grad[roipool::kBox].get<xpu, 2, DType>(s);
     CHECK_EQ(grad_out.CheckContiguous(), true);
@@ -195,6 +195,7 @@ class ROIPoolingProp : public OperatorProperty {
   bool InferType(std::vector<int> *in_type,
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
+    using namespace mshadow;
     CHECK_EQ(in_type->size(), 2U);
     int dtype = (*in_type)[0];
     CHECK_EQ(dtype, (*in_type)[1]);
@@ -202,7 +203,11 @@ class ROIPoolingProp : public OperatorProperty {
     out_type->clear();
     out_type->push_back(dtype);
-    out_type->push_back(dtype);
+#if MXNET_USE_INT64_TENSOR_SIZE == 1
+    out_type->push_back(kInt64);
+#else
+    out_type->push_back(kInt32);
+#endif
     return true;
   }
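Note: the header change above switches the argmax output from the data dtype to an integer index type (kInt64 when MXNET_USE_INT64_TENSOR_SIZE is set, kInt32 otherwise). The motivation is precision: a float32 mantissa has only 24 bits, so a Dtype-typed argmax buffer cannot represent flat element indices above 2^24 exactly and silently rounds them on large inputs. A minimal standalone C++ sketch of that failure mode (illustrative only, not part of the patch):

    #include <cstdint>
    #include <iostream>

    int main() {
      // 2^24 + 1 is the first integer that a float32 cannot represent exactly.
      const int64_t flat_index = (int64_t{1} << 24) + 1;
      const float as_dtype = static_cast<float>(flat_index);  // what a Dtype argmax stored
      std::cout << "stored:    " << flat_index << "\n"
                << "recovered: " << static_cast<int64_t>(as_dtype) << "\n";
      // Prints 16777217 vs 16777216: the recovered argmax points one element off,
      // so the backward pass would route the gradient to the wrong input cell.
      return 0;
    }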
diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc
index 8862d0db1401..bba3bea5ce6a 100644
--- a/src/operator/roi_pooling.cc
+++ b/src/operator/roi_pooling.cc
@@ -40,12 +40,13 @@ template<typename Dtype>
 inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
                            const Tensor<cpu, 4, Dtype> &data,
                            const Tensor<cpu, 2, Dtype> &bbox,
-                           const Tensor<cpu, 4, Dtype> &max_idx,
+                           const Tensor<cpu, 4, index_t> &max_idx,
                            const float spatial_scale_) {
   const Dtype *bottom_data = data.dptr_;
   const Dtype *bottom_rois = bbox.dptr_;
   Dtype *top_data = out.dptr_;
-  Dtype *argmax_data = max_idx.dptr_;
+  index_t *argmax_data = max_idx.dptr_;
+  const int batch_size = data.size(0);
   const int channels_ = data.size(1);
   const int height_ = data.size(2);
   const int width_ = data.size(3);
@@ -53,25 +54,25 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
   const int pooled_width_ = out.size(3);

   const int num_rois = bbox.size(0);

-  const int data_size = data.size(1) * data.size(2) * data.size(3);
-  const int data_size_c = data.size(2) * data.size(3);
-  const int out_size_c = out.size(2) * out.size(3);
-  const int out_size = channels_ * out_size_c;
-  const int max_idx_size_c = max_idx.size(2) * max_idx.size(3);
-  const int max_idx_size = channels_ * max_idx_size_c;
+  const index_t data_size = data.size(1) * data.size(2) * data.size(3);
+  const index_t data_size_c = data.size(2) * data.size(3);
+  const index_t out_size_c = out.size(2) * out.size(3);
+  const index_t out_size = channels_ * out_size_c;
+  const index_t max_idx_size_c = max_idx.size(2) * max_idx.size(3);
+  const index_t max_idx_size = channels_ * max_idx_size_c;
   // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
   for (int n = 0; n < num_rois; ++n) {
     // Increment ROI data pointer
     const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1);
     Dtype *top_data_n = top_data + n * out_size;
-    Dtype *argmax_data_n = argmax_data + n * max_idx_size;
-    int roi_batch_ind = bottom_rois_n[0];
+    index_t *argmax_data_n = argmax_data + n * max_idx_size;
     int roi_start_w = std::round(bottom_rois_n[1] * spatial_scale_);
     int roi_start_h = std::round(bottom_rois_n[2] * spatial_scale_);
     int roi_end_w = std::round(bottom_rois_n[3] * spatial_scale_);
     int roi_end_h = std::round(bottom_rois_n[4] * spatial_scale_);
-    assert(roi_batch_ind >= 0);
-    assert(static_cast<index_t>(roi_batch_ind) < data.size(0)  /* batch size */);
+
+    int roi_batch_ind = static_cast<int>(bottom_rois_n[0]);
+    bool is_ind_invalid = (roi_batch_ind < 0) || (roi_batch_ind >= batch_size);

     // force malformed ROIs to be 1 * 1
     int roi_height = max(roi_end_h - roi_start_h + 1, 1);
@@ -81,14 +82,15 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
     const Dtype bin_size_w = static_cast<Dtype>(roi_width)
                              / static_cast<Dtype>(pooled_width_);

-    const Dtype* batch_data = bottom_data + data_size * roi_batch_ind;
+    index_t offset_batch_data = data_size * roi_batch_ind;

     #pragma omp parallel for
     for (int c = 0; c < channels_; ++c) {
       // Increment all data pointers
-      const Dtype* batch_data_c = batch_data + c * data_size_c;
+      index_t offset_batch_data_c = offset_batch_data + c * data_size_c;
+      const Dtype* batch_data_c = bottom_data + offset_batch_data_c;
       Dtype* top_data_c = top_data_n + c * out_size_c;
-      Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c;
+      index_t* argmax_data_c = argmax_data_n + c * max_idx_size_c;

       for (int ph = 0; ph < pooled_height_; ++ph) {
         for (int pw = 0; pw < pooled_width_; ++pw) {
@@ -111,18 +113,19 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,

           bool is_empty = (hend <= hstart) || (wend <= wstart);

-          const int pool_index = ph * pooled_width_ + pw;
-          if (is_empty) {
+          const index_t pool_index = ph * pooled_width_ + pw;
+          if (is_empty || is_ind_invalid) {
             top_data_c[pool_index] = 0;
             argmax_data_c[pool_index] = -1;
+            continue;
           }

           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              const int index = h * width_ + w;
+              const index_t index = h * width_ + w;
               if (batch_data_c[index] > top_data_c[pool_index]) {
                 top_data_c[pool_index] = batch_data_c[index];
-                argmax_data_c[pool_index] = index;
+                argmax_data_c[pool_index] = offset_batch_data_c + index;
               }
             }
           }
@@ -137,91 +140,18 @@ template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
                                const Tensor<cpu, 4, Dtype> &out_grad,
                                const Tensor<cpu, 2, Dtype> &bbox,
-                               const Tensor<cpu, 4, Dtype> &max_idx,
+                               const Tensor<cpu, 4, index_t> &max_idx,
                                const float spatial_scale_) {
   const Dtype *top_diff = out_grad.dptr_;
-  const Dtype *bottom_rois = bbox.dptr_;
   Dtype *bottom_diff = in_grad.dptr_;
-  Dtype *argmax_data = max_idx.dptr_;
-
-  const int batch_size_ = in_grad.size(0);
-  const int channels_ = in_grad.size(1);
-  const int height_ = in_grad.size(2);
-  const int width_ = in_grad.size(3);
-  const int pooled_height_ = out_grad.size(2);
-  const int pooled_width_ = out_grad.size(3);
-
-  const int num_rois = bbox.size(0);
-
-  for (int b = 0; b < batch_size_; ++b) {
-    for (int c = 0; c < channels_; ++c) {
-      for (int h = 0; h < height_; ++h) {
-        for (int w = 0; w < width_; ++w) {
-          int offset_bottom_diff = (b * channels_ + c) * height_ * width_;
-          offset_bottom_diff += h * width_ + w;
-
-          Dtype gradient = 0;
-          // Accumulate gradient over all ROIs that pooled this element
-          for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
-            const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
-            int roi_batch_ind = offset_bottom_rois[0];
-            assert(roi_batch_ind >= 0);
-            assert(roi_batch_ind < batch_size_);
-            if (b != roi_batch_ind) {
-              continue;
-            }
+  index_t *argmax_data = max_idx.dptr_;

-            int roi_start_w = std::round(offset_bottom_rois[1] * spatial_scale_);
-            int roi_start_h = std::round(offset_bottom_rois[2] * spatial_scale_);
-            int roi_end_w = std::round(offset_bottom_rois[3] * spatial_scale_);
-            int roi_end_h = std::round(offset_bottom_rois[4] * spatial_scale_);
+  const index_t count = out_grad.shape_.Size();

-            bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
-                           h >= roi_start_h && h <= roi_end_h);
-            if (!in_roi) {
-              continue;
-            }
-
-            // force malformed ROIs to be 1 * 1
-            int roi_height = max(roi_end_h - roi_start_h + 1, 1);
-            int roi_width = max(roi_end_w - roi_start_w + 1, 1);
-            const Dtype bin_size_h = static_cast<Dtype>(roi_height)
-                                     / static_cast<Dtype>(pooled_height_);
-            const Dtype bin_size_w = static_cast<Dtype>(roi_width)
-                                     / static_cast<Dtype>(pooled_width_);
-
-            // compute pooled regions correspond to original (h, w) point
-            int phstart = static_cast<int>(floor(static_cast<Dtype>(h - roi_start_h)
-                                                 / bin_size_h));
-            int pwstart = static_cast<int>(floor(static_cast<Dtype>(w - roi_start_w)
-                                                 / bin_size_w));
-            int phend = static_cast<int>(ceil(static_cast<Dtype>(h - roi_start_h + 1)
-                                              / bin_size_h));
-            int pwend = static_cast<int>(ceil(static_cast<Dtype>(w - roi_start_w + 1)
-                                              / bin_size_w));
-
-            // clip to boundaries of pooled region
-            phstart = min(max(phstart, 0), pooled_height_);
-            phend = min(max(phend, 0), pooled_height_);
-            pwstart = min(max(pwstart, 0), pooled_width_);
-            pwend = min(max(pwend, 0), pooled_width_);
-
-            // accumulate over gradients in pooled regions
-            int offset = (roi_n * channels_ + c) * pooled_height_ * pooled_width_;
-            const Dtype* offset_top_diff = top_diff + offset;
-            const Dtype* offset_argmax_data = argmax_data + offset;
-            for (int ph = phstart; ph < phend; ++ph) {
-              for (int pw = pwstart; pw < pwend; ++pw) {
-                const int pooled_index = ph * pooled_width_ + pw;
-                if (static_cast<int>(offset_argmax_data[pooled_index]) == h * width_ + w) {
-                  gradient += offset_top_diff[pooled_index];
-                }
-              }
-            }
-          }
-          bottom_diff[offset_bottom_diff] += gradient;
-        }
-      }
+  for (int index = 0; index < count; ++index) {
+    index_t max_idx = argmax_data[index];
+    if (max_idx >= 0) {
+      bottom_diff[max_idx] += top_diff[index];
     }
   }
 }
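Note: the CPU backward above is rewritten from a gather into a scatter. Because the forward pass now records each maximum's global flat offset into the input tensor (offset_batch_data_c + index), the backward no longer has to re-derive, for every input element, which ROI bins could have pooled it; the O(batch * channels * height * width * num_rois) search collapses into a single pass over out_grad. A minimal sketch of the pattern, using hypothetical std::vector arguments in place of mshadow tensors:

    #include <cstdint>
    #include <vector>

    // argmax[i] holds the global flat input offset that produced output element i,
    // or -1 when the pooling bin was empty or the ROI batch index was invalid.
    void RoiPoolBackwardSketch(const std::vector<float>& top_diff,
                               const std::vector<int64_t>& argmax,
                               std::vector<float>* bottom_diff) {
      for (size_t i = 0; i < top_diff.size(); ++i) {
        const int64_t src = argmax[i];
        if (src >= 0) {
          (*bottom_diff)[src] += top_diff[i];  // scatter-add into the input gradient
        }
      }
    }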
diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu
index 9ea99b309aaf..b4013a672711 100644
--- a/src/operator/roi_pooling.cu
+++ b/src/operator/roi_pooling.cu
@@ -34,12 +34,12 @@ namespace cuda {

 template<typename Dtype>
 __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
-                                     const float spatial_scale, const int channels,
-                                     const int height, const int width,
+                                     const float spatial_scale, const int batch_size,
+                                     const int channels, const int height, const int width,
                                      const int pooled_height, const int pooled_width,
                                      const Dtype* bottom_rois, Dtype* top_data,
-                                     Dtype* argmax_data) {
-  for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
+                                     index_t* argmax_data) {
+  for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x * gridDim.y) {
     // (n, c, ph, pw) is an element in the pooled output
@@ -49,11 +49,11 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
     int n = index / pooled_width / pooled_height / channels;

     bottom_rois += n * 5;
-    int roi_batch_ind = bottom_rois[0];
+    int roi_batch_ind = static_cast<int>(bottom_rois[0]);

-    if (roi_batch_ind < 0) {
+    if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) {
       top_data[index] = 0;
-      argmax_data[index] = 0;
+      argmax_data[index] = -1;
       continue;
     }
@@ -89,19 +89,20 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
     // Define an empty pooling region to be zero
     Dtype maxval = is_empty ? 0 : -FLT_MAX;
     // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
-    int maxidx = -1;
-    bottom_data += (roi_batch_ind * channels + c) * height * width;
+    index_t maxidx = -1;
+    index_t offset_bottom_data = (roi_batch_ind * channels + c) * height * width;
+    bottom_data += offset_bottom_data;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        int bottom_index = h * width + w;
+        index_t bottom_index = h * width + w;
         if (bottom_data[bottom_index] > maxval) {
           maxval = bottom_data[bottom_index];
-          maxidx = bottom_index;
+          maxidx = offset_bottom_data + bottom_index;
         }
       }
     }
     top_data[index] = maxval;
-    argmax_data[index] = (Dtype)maxidx;
+    argmax_data[index] = maxidx;
   }
 }

@@ -109,13 +110,14 @@ template<typename Dtype>
 inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
                            const Tensor<gpu, 4, Dtype> &data,
                            const Tensor<gpu, 2, Dtype> &bbox,
-                           const Tensor<gpu, 4, Dtype> &max_idx,
+                           const Tensor<gpu, 4, index_t> &max_idx,
                            const float spatial_scale) {
   const Dtype *bottom_data = data.dptr_;
   const Dtype *bottom_rois = bbox.dptr_;
   Dtype *top_data = out.dptr_;
-  Dtype *argmax_data = max_idx.dptr_;
-  const int count = out.shape_.Size();
+  index_t *argmax_data = max_idx.dptr_;
+  const index_t count = out.shape_.Size();
+  const int batch_size = data.size(0);
   const int channels = data.size(1);
   const int height = data.size(2);
   const int width = data.size(3);
@@ -127,84 +129,21 @@ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
   CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Forward");
   cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
   ROIPoolForwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
-      count, bottom_data, spatial_scale, channels, height, width,
+      count, bottom_data, spatial_scale, batch_size, channels, height, width,
       pooled_height, pooled_width, bottom_rois, top_data, argmax_data);
   MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolForwardKernel);
 }

 template<typename Dtype>
 __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff,
-                                         const Dtype* argmax_data, const int num_rois,
-                                         const float spatial_scale, const int channels,
-                                         const int height, const int width,
-                                         const int pooled_height, const int pooled_width,
-                                         Dtype* bottom_diff, const Dtype* bottom_rois) {
-  for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
+                                         const index_t* argmax_data, Dtype* bottom_diff) {
+  for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x * gridDim.y) {
-    // (n, c, h, w) coords in bottom data
-    int w = index % width;
-    int h = (index / width) % height;
-    int c = (index / width / height) % channels;
-    int n = index / width / height / channels;
-
-    Dtype gradient = 0;
-    // Accumulate gradient over all ROIs that pooled this element
-    for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
-      const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
-      int roi_batch_ind = offset_bottom_rois[0];
-      // Skip if ROI's batch index doesn't match n
-      if (n != roi_batch_ind) {
-        continue;
-      }
-
-      int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
-      int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
-      int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
-      int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
-
-      // Skip if ROI doesn't include (h, w)
-      const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
-                           h >= roi_start_h && h <= roi_end_h);
-      if (!in_roi) {
-        continue;
-      }
-
-      int offset = (roi_n * channels + c) * pooled_height * pooled_width;
-      const Dtype* offset_top_diff = top_diff + offset;
-      const Dtype* offset_argmax_data = argmax_data + offset;
-
-      // Compute feasible set of pooled units that could have pooled
-      // this bottom unit
-
-      // Force malformed ROIs to be 1x1
-      int roi_width = max(roi_end_w - roi_start_w + 1, 1);
-      int roi_height = max(roi_end_h - roi_start_h + 1, 1);
-
-      Dtype bin_size_h = static_cast<Dtype>(roi_height)
-                         / static_cast<Dtype>(pooled_height);
-      Dtype bin_size_w = static_cast<Dtype>(roi_width)
-                         / static_cast<Dtype>(pooled_width);
-
-      int phstart = floor(static_cast<Dtype>(h - roi_start_h) / bin_size_h);
-      int phend = ceil(static_cast<Dtype>(h - roi_start_h + 1) / bin_size_h);
-      int pwstart = floor(static_cast<Dtype>(w - roi_start_w) / bin_size_w);
-      int pwend = ceil(static_cast<Dtype>(w - roi_start_w + 1) / bin_size_w);
-
-      phstart = min(max(phstart, 0), pooled_height);
-      phend = min(max(phend, 0), pooled_height);
-      pwstart = min(max(pwstart, 0), pooled_width);
-      pwend = min(max(pwend, 0), pooled_width);
-
-      for (int ph = phstart; ph < phend; ++ph) {
-        for (int pw = pwstart; pw < pwend; ++pw) {
-          if (static_cast<int>(offset_argmax_data[ph * pooled_width + pw]) == (h * width + w)) {
-            gradient += offset_top_diff[ph * pooled_width + pw];
-          }
-        }
-      }
+    index_t max_idx = argmax_data[index];
+    if (max_idx >= 0) {
+      atomicAdd(&bottom_diff[max_idx], top_diff[index]);
     }
-    bottom_diff[index] += gradient;
   }
 }

@@ -212,27 +151,19 @@ template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
                                const Tensor<gpu, 4, Dtype> &out_grad,
                                const Tensor<gpu, 2, Dtype> &bbox,
-                               const Tensor<gpu, 4, Dtype> &max_idx,
+                               const Tensor<gpu, 4, index_t> &max_idx,
                                const float spatial_scale) {
   const Dtype *top_diff = out_grad.dptr_;
-  const Dtype *bottom_rois = bbox.dptr_;
   Dtype *bottom_diff = in_grad.dptr_;
-  Dtype *argmax_data = max_idx.dptr_;
-  const int count = in_grad.shape_.Size();
-  const int num_rois = bbox.size(0);
-  const int channels = in_grad.size(1);
-  const int height = in_grad.size(2);
-  const int width = in_grad.size(3);
-  const int pooled_height = out_grad.size(2);
-  const int pooled_width = out_grad.size(3);
+  index_t *argmax_data = max_idx.dptr_;
+  const index_t count = out_grad.shape_.Size();
   const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
   dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
   dim3 dimBlock(kMaxThreadsPerBlock);
   CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Backward");
   cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
   ROIPoolBackwardAccKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
-      count, top_diff, argmax_data, num_rois, spatial_scale, channels, height, width,
-      pooled_height, pooled_width, bottom_diff, bottom_rois);
+      count, top_diff, argmax_data, bottom_diff);
   MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolBackwardAccKernel);
 }

@@ -242,7 +173,7 @@ template<typename Dtype>
 inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
                            const Tensor<gpu, 4, Dtype> &data,
                            const Tensor<gpu, 2, Dtype> &bbox,
-                           const Tensor<gpu, 4, Dtype> &max_idx,
+                           const Tensor<gpu, 4, index_t> &max_idx,
                            const float spatial_scale) {
   cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale);
 }
@@ -251,7 +182,7 @@ template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
                                const Tensor<gpu, 4, Dtype> &out_grad,
                                const Tensor<gpu, 2, Dtype> &bbox,
-                               const Tensor<gpu, 4, Dtype> &max_idx,
+                               const Tensor<gpu, 4, index_t> &max_idx,
                                const float spatial_scale) {
   cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale);
 }
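Note: the GPU backward follows the same scatter pattern as the CPU path, but in parallel: each thread owns one output gradient element, and overlapping ROIs mean several threads can target the same input element, so the update must be an atomicAdd rather than a plain +=. One consequence of floating-point atomics is that the accumulation order, and therefore the low-order bits of the result, can vary across runs. A hypothetical minimal CUDA kernel showing just this core (float-only here; the patch itself is templated on Dtype):

    #include <cstdint>

    __global__ void RoiPoolBackwardSketch(const int64_t count,
                                          const float* top_diff,
                                          const int64_t* argmax,
                                          float* bottom_diff) {
      // Grid-stride loop over all pooled output elements.
      for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
           i < count;
           i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
        const int64_t src = argmax[i];
        if (src >= 0) {
          // Overlapping ROIs may pool the same input element, so serialize updates.
          atomicAdd(&bottom_diff[src], top_diff[i]);
        }
      }
    }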