diff --git a/src/operator/contrib/roi_align-inl.h b/src/operator/contrib/roi_align-inl.h index 263f72a6abc0..9f4d7ce48827 100644 --- a/src/operator/contrib/roi_align-inl.h +++ b/src/operator/contrib/roi_align-inl.h @@ -20,7 +20,7 @@ * Copyright (c) 2018 by Contributors * \file roi_align-inl.h * \brief roi align operator and symbol - * \author Hang Zhang + * \author Hang Zhang, Shesung * modified from Caffe2 */ #ifndef MXNET_OPERATOR_CONTRIB_ROI_ALIGN_INL_H_ @@ -35,7 +35,6 @@ namespace mxnet { namespace op { - // Declare enumeration of input order to make code more intuitive. // These enums are only visible within this header namespace roialign { @@ -48,6 +47,7 @@ struct ROIAlignParam : public dmlc::Parameter { TShape pooled_size; float spatial_scale; int sample_ratio; + bool position_sensitive; DMLC_DECLARE_PARAMETER(ROIAlignParam) { DMLC_DECLARE_FIELD(pooled_size) .set_expect_ndim(2).enforce_nonzero() @@ -57,6 +57,10 @@ struct ROIAlignParam : public dmlc::Parameter { "Equals the reciprocal of total stride in convolutional layers"); DMLC_DECLARE_FIELD(sample_ratio).set_default(-1) .describe("Optional sampling ratio of ROI align, using adaptive size by default."); + DMLC_DECLARE_FIELD(position_sensitive).set_default(false) + .describe("Whether to perform position-sensitive RoI pooling. PSRoIPooling is " + "first proposaled by R-FCN and it can reduce the input channels by ph*pw times, " + "where (ph, pw) is the pooled_size"); } }; diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc index 76675677fa08..e584ea30325d 100644 --- a/src/operator/contrib/roi_align.cc +++ b/src/operator/contrib/roi_align.cc @@ -20,7 +20,7 @@ * Copyright (c) 2018 by Contributors * \file roi_align.cc * \brief roi align operator - * \author Hang Zhang + * \author Hang Zhang, Shesung * Adapted from Caffe2 */ #include "./roi_align-inl.h" @@ -142,6 +142,7 @@ void ROIAlignForward( const int nthreads, const T* bottom_data, const T& spatial_scale, + const bool position_sensitive, const int channels, const int height, const int width, @@ -156,6 +157,8 @@ void ROIAlignForward( int n_rois = nthreads / channels / pooled_width / pooled_height; // (n, c, ph, pw) is an element in the pooled output // can be parallelized using omp +#pragma omp parallel for \ +num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int n = 0; n < n_rois; n++) { int index_n = n * channels * pooled_width * pooled_height; @@ -208,19 +211,23 @@ void ROIAlignForward( roi_bin_grid_w, &pre_calc); - int c; -#pragma omp parallel for private(c) \ -num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) - for (c = 0; c < channels; c++) { + for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * pooled_width * pooled_height; - const T* offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; int pre_calc_index = 0; for (int ph = 0; ph < pooled_height; ph++) { for (int pw = 0; pw < pooled_width; pw++) { int index = index_n_c + ph * pooled_width + pw; + int c_unpooled = c; + int channels_unpooled = channels; + if (position_sensitive) { + c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw; + channels_unpooled = channels * pooled_height * pooled_width; + } + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled) + * height * width; T output_val = 0.; for (int iy = 0; iy < roi_bin_grid_h; iy++) { for (int ix = 0; ix < roi_bin_grid_w; ix++) { @@ -310,6 +317,7 @@ void ROIAlignBackward( const T* top_diff, const int /*num_rois*/, const T& spatial_scale, + const bool position_sensitive, const int channels, const int height, const int width, @@ -347,8 +355,15 @@ void ROIAlignBackward( T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + int c_unpooled = c; + int channels_unpooled = channels; + if (position_sensitive) { + c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw; + channels_unpooled = channels * pooled_height * pooled_width; + } T* offset_bottom_diff = - bottom_diff + (roi_batch_ind * channels + c) * height * width; + bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled) + * height * width; int top_offset = (n * channels + c) * pooled_height * pooled_width; const T* offset_top_diff = top_diff + top_offset; @@ -426,7 +441,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, const int count = out_data[roialign::kOut].Size(); // const int num_rois = in_data[roialign::kBox].size(0); - const int channels = in_data[roialign::kData].size(1); + const int channels = out_data[roialign::kOut].size(1); // channels of pooled output const int height = in_data[roialign::kData].size(2); const int width = in_data[roialign::kData].size(3); const int pooled_height = out_data[roialign::kOut].size(2); @@ -439,9 +454,9 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, const DType *bottom_rois = in_data[roialign::kBox].dptr(); DType *top_data = out_data[roialign::kOut].dptr(); - ROIAlignForward(count, bottom_data, param.spatial_scale, channels, - height, width, pooled_height, pooled_width, param.sample_ratio, - bottom_rois, rois_cols, top_data); + ROIAlignForward(count, bottom_data, param.spatial_scale, param.position_sensitive, + channels, height, width, pooled_height, pooled_width, + param.sample_ratio, bottom_rois, rois_cols, top_data); }) } @@ -470,7 +485,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, const int count = out_grad[0].Size(); const int num_rois = in_data[0].size(0); - const int channels = outputs[0].size(1); + const int channels = out_grad[0].size(1); // channels of pooled output const int height = outputs[0].size(2); const int width = outputs[0].size(3); const int pooled_height = out_grad[0].size(2); @@ -489,8 +504,9 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, Fill(s, outputs[0], kWriteTo, static_cast(0)); } ROIAlignBackward(count, top_diff, num_rois, param.spatial_scale, - channels, height, width, pooled_height, pooled_width, - param.sample_ratio, grad_in, bottom_rois, rois_cols); + param.position_sensitive, channels, height, width, + pooled_height, pooled_width, param.sample_ratio, grad_in, + bottom_rois, rois_cols); } if (kWriteTo == req[roialign::kBox]) { Fill(s, outputs[1], kWriteTo, static_cast(0)); @@ -545,8 +561,17 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017 CHECK_EQ(bshape[1], 5) << "bbox should be a 2D tensor of shape [batch, 5]"; // out: [num_rois, c, pooled_h, pooled_w] out_shape->clear(); - out_shape->push_back( - Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1])); + if (param.position_sensitive) { + CHECK_EQ(dshape[1] % (param.pooled_size[0]*param.pooled_size[1]), 0) << + "Input channels should be divided by pooled_size[0]*pooled_size[1]" + "when position_sensitive is true."; + out_shape->push_back( + Shape4(bshape[0], dshape[1]/param.pooled_size[0]/param.pooled_size[1], + param.pooled_size[0], param.pooled_size[1])); + } else { + out_shape->push_back( + Shape4(bshape[0], dshape[1], param.pooled_size[0], param.pooled_size[1])); + } return true; }) .set_attr("FInferType", [](const nnvm::NodeAttrs& attrs, diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu index d3db70b73b1a..38b461d5f58c 100644 --- a/src/operator/contrib/roi_align.cu +++ b/src/operator/contrib/roi_align.cu @@ -20,7 +20,7 @@ * Copyright (c) 2018 by Contributors * \file roi_align.cu * \brief roi align operator - * \author Hang Zhang + * \author Hang Zhang, Shesung * Adapted from Caffe2 */ #include "./roi_align-inl.h" @@ -111,6 +111,7 @@ __global__ void RoIAlignForwardKernel( const int nthreads, const T* bottom_data, const T spatial_scale, + const bool position_sensitive, const int channels, const int height, const int width, @@ -145,8 +146,15 @@ __global__ void RoIAlignForwardKernel( T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + int c_unpooled = c; + int channels_unpooled = channels; + if (position_sensitive) { + c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw; + channels_unpooled = channels * pooled_height * pooled_width; + } const T* offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; + bottom_data + (roi_batch_ind * channels_unpooled + c_unpooled) + * height * width; // We use roi_bin_grid to sample the grid and mimic integral int roi_bin_grid_h = (sampling_ratio > 0) @@ -242,6 +250,7 @@ __global__ void RoIAlignBackwardKernel( const T* top_diff, const int num_rois, const T spatial_scale, + const bool position_sensitive, const int channels, const int height, const int width, @@ -276,8 +285,15 @@ __global__ void RoIAlignBackwardKernel( T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + int c_unpooled = c; + int channels_unpooled = channels; + if (position_sensitive) { + c_unpooled = c * pooled_height * pooled_width + ph * pooled_width + pw; + channels_unpooled = channels * pooled_height * pooled_width; + } T* offset_bottom_diff = - bottom_diff + (roi_batch_ind * channels + c) * height * width; + bottom_diff + (roi_batch_ind * channels_unpooled + c_unpooled) + * height * width; int top_offset = (n * channels + c) * pooled_height * pooled_width; const T* offset_top_diff = top_diff + top_offset; @@ -357,7 +373,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, const int count = out_data[roialign::kOut].Size(); const int num_rois = in_data[roialign::kBox].size(0); - const int channels = in_data[roialign::kData].size(1); + const int channels = out_data[roialign::kOut].size(1); // channels of pooled output const int height = in_data[roialign::kData].size(2); const int width = in_data[roialign::kData].size(3); const int pooled_height = out_data[roialign::kOut].size(2); @@ -377,6 +393,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, count, bottom_data, param.spatial_scale, + param.position_sensitive, channels, height, width, @@ -414,7 +431,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, const int count = out_grad[0].Size(); const int num_rois = in_data[0].size(0); - const int channels = outputs[0].size(1); + const int channels = out_grad[0].size(1); // channels of pooled output const int height = outputs[0].size(2); const int width = outputs[0].size(3); const int pooled_height = out_grad[0].size(2); @@ -445,6 +462,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, top_diff, num_rois, param.spatial_scale, + param.position_sensitive, channels, height, width, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index a895594ce289..d8e80d7d6938 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -17,6 +17,7 @@ # pylint: skip-file from __future__ import print_function +from __future__ import division import numpy as np import mxnet as mx import copy @@ -6899,14 +6900,16 @@ def bilinear_interpolate(bottom, height, width, y, x): ] return val, grad - def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio, dy): + def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_ratio, + position_sensitive, dy): N, C, H, W = data.shape R = rois.shape[0] PH, PW = pooled_size assert len(rois.shape) == 2 assert rois.shape[1] == 5 - out = np.zeros((R, C, PH, PW)) + C_out = C // PH // PW if position_sensitive else C + out = np.zeros((R, C_out, PH, PW)) dx = np.zeros_like(data) drois = np.zeros_like(rois) @@ -6924,24 +6927,25 @@ def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_r roi_bin_grid_h = int(np.ceil(roi_h * 1.0 / PH)) roi_bin_grid_w = int(np.ceil(roi_w * 1.0 / PW)) count = roi_bin_grid_h * roi_bin_grid_w - for c in range(C): + for c in range(C_out): for ph in range(PH): for pw in range(PW): val = 0.0 + c_in = c * PH * PW + ph * PW + pw if position_sensitive else c for iy in range(roi_bin_grid_h): y = sh + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h for ix in range(roi_bin_grid_w): x = sw + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w - v, g = bilinear_interpolate(bdata[c], H, W, y, x) + v, g = bilinear_interpolate(bdata[c_in], H, W, y, x) val += v # compute grad for qy, qx, qw in g: - dx[batch_ind, c, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count + dx[batch_ind, c_in, qy, qx] += dy[r, c, ph, pw] * qw * 1.0 / count out[r, c, ph, pw] = val * 1.0 / count return out, [dx, drois] - def test_roi_align_value(sampling_ratio=0): + def test_roi_align_value(sampling_ratio=0, position_sensitive=False): ctx=default_context() dtype = np.float32 @@ -6950,6 +6954,7 @@ def test_roi_align_value(sampling_ratio=0): assert H == W R = 7 pooled_size = (3, 4) + C = C * pooled_size[0] * pooled_size[1] if position_sensitive else C spatial_scale = H * 1.0 / dlen data = mx.nd.array(np.arange(N*C*W*H).reshape((N,C,H,W)), ctx=ctx, dtype = dtype) @@ -6964,11 +6969,14 @@ def test_roi_align_value(sampling_ratio=0): rois.attach_grad() with mx.autograd.record(): output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size, - spatial_scale=spatial_scale, sample_ratio=sampling_ratio) - dy = mx.nd.random.uniform(-1, 1, (R, C) + pooled_size, ctx=ctx, dtype = dtype) + spatial_scale=spatial_scale, sample_ratio=sampling_ratio, + position_sensitive=position_sensitive) + C_out = C // pooled_size[0] // pooled_size[1] if position_sensitive else C + dy = mx.nd.random.uniform(-1, 1, (R, C_out) + pooled_size, ctx=ctx, dtype = dtype) output.backward(dy) real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size, - spatial_scale, sampling_ratio, dy.asnumpy()) + spatial_scale, sampling_ratio, + position_sensitive, dy.asnumpy()) assert np.allclose(output.asnumpy(), real_output) # It seems that the precision between Cfloat and Pyfloat is different. assert np.allclose(data.grad.asnumpy(), dx, atol = 1e-5), np.abs(data.grad.asnumpy() - dx).max() @@ -6994,7 +7002,8 @@ def test_roi_align_autograd(sampling_ratio=0): numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx) test_roi_align_value() - test_roi_align_value(2) + test_roi_align_value(sampling_ratio=2) + test_roi_align_value(position_sensitive=True) test_roi_align_autograd() @with_seed()