diff --git a/src/operator/contrib/mrcnn_mask_target-inl.h b/src/operator/contrib/mrcnn_mask_target-inl.h index af297e745487..2318691be9fe 100644 --- a/src/operator/contrib/mrcnn_mask_target-inl.h +++ b/src/operator/contrib/mrcnn_mask_target-inl.h @@ -46,6 +46,7 @@ struct MRCNNMaskTargetParam : public dmlc::Parameter { int num_rois; int num_classes; int sample_ratio; + bool aligned; mxnet::TShape mask_size; DMLC_DECLARE_PARAMETER(MRCNNMaskTargetParam) { @@ -58,6 +59,9 @@ struct MRCNNMaskTargetParam : public dmlc::Parameter { .describe("Size of the pooled masks height and width: (h, w)."); DMLC_DECLARE_FIELD(sample_ratio).set_default(2) .describe("Sampling ratio of ROI align. Set to -1 to use adaptative size."); + DMLC_DECLARE_FIELD(aligned).set_default(false) + .describe("Center-aligned ROIAlign introduced in Detectron2. " + "To enable, set aligned to True."); } }; diff --git a/src/operator/contrib/mrcnn_mask_target.cu b/src/operator/contrib/mrcnn_mask_target.cu index 779fb0141e64..e85b6a5e2c95 100644 --- a/src/operator/contrib/mrcnn_mask_target.cu +++ b/src/operator/contrib/mrcnn_mask_target.cu @@ -119,7 +119,8 @@ __device__ void RoIAlignForward( const int sampling_ratio, const int num_rois, const int num_gtmasks, - T* out_data) { // (B, N, C, H, W) + const bool continuous_coordinate, + T* out_data) {// (B, N, C, H, W) // Update kernel for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < num_el; @@ -135,14 +136,21 @@ __device__ void RoIAlignForward( const T* offset_rois = rois + batch_idx * (4 * num_rois) + n * 4; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_rois[0]; - T roi_start_h = offset_rois[1]; - T roi_end_w = offset_rois[2]; - T roi_end_h = offset_rois[3]; + T roi_offset = continuous_coordinate ? 
static_cast(0.5) : static_cast(0); + T roi_start_w = offset_rois[0] - roi_offset; + T roi_start_h = offset_rois[1] - roi_offset; + T roi_end_w = offset_rois[2] - roi_offset; + T roi_end_h = offset_rois[3] - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + if (!continuous_coordinate) { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -190,6 +198,7 @@ __global__ void MRCNNMaskTargetKernel(const DType *rois, DType* sampled_masks, DType* mask_cls, const int total_out_el, + const bool aligned, int batch_size, int num_classes, int num_rois, @@ -202,7 +211,7 @@ __global__ void MRCNNMaskTargetKernel(const DType *rois, // computing sampled_masks RoIAlignForward(gt_masks, rois, matches, total_out_el, num_classes, gt_height, gt_width, mask_size_h, mask_size_w, - sample_ratio, num_rois, num_gtmasks, sampled_masks); + sample_ratio, num_rois, num_gtmasks, aligned, sampled_masks); // computing mask_cls int num_masks = batch_size * num_rois * num_classes; int mask_vol = mask_size_h * mask_size_w; @@ -251,8 +260,8 @@ void MRCNNMaskTargetRun(const MRCNNMaskTargetParam& param, const std::vecto MRCNNMaskTargetKernel<<>> (rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_, - out_masks.dptr_, out_mask_cls.dptr_, - num_el, batch_size, param.num_classes, param.num_rois, + out_masks.dptr_, out_mask_cls.dptr_, num_el, param.aligned, + batch_size, param.num_classes, param.num_rois, num_gtmasks, gt_height, gt_width, param.mask_size[0], param.mask_size[1], param.sample_ratio); MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNMaskTargetKernel); diff --git 
a/src/operator/contrib/roi_align-inl.h b/src/operator/contrib/roi_align-inl.h index b28e437a7e09..8b3b914435cc 100644 --- a/src/operator/contrib/roi_align-inl.h +++ b/src/operator/contrib/roi_align-inl.h @@ -48,6 +48,7 @@ struct ROIAlignParam : public dmlc::Parameter { float spatial_scale; int sample_ratio; bool position_sensitive; + bool aligned; DMLC_DECLARE_PARAMETER(ROIAlignParam) { DMLC_DECLARE_FIELD(pooled_size) .set_expect_ndim(2).enforce_nonzero() @@ -61,6 +62,9 @@ struct ROIAlignParam : public dmlc::Parameter { .describe("Whether to perform position-sensitive RoI pooling. PSRoIPooling is " "first proposaled by R-FCN and it can reduce the input channels by ph*pw times, " "where (ph, pw) is the pooled_size"); + DMLC_DECLARE_FIELD(aligned).set_default(false) + .describe("Center-aligned ROIAlign introduced in Detectron2. " + "To enable, set aligned to True."); } }; diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc index ee91561a6818..4741be189efc 100644 --- a/src/operator/contrib/roi_align.cc +++ b/src/operator/contrib/roi_align.cc @@ -143,6 +143,7 @@ void ROIAlignForward( const T* bottom_data, const T& spatial_scale, const bool position_sensitive, + const bool continuous_coordinate, const int channels, const int height, const int width, @@ -175,14 +176,22 @@ num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) } // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[0] * spatial_scale; - T roi_start_h = offset_bottom_rois[1] * spatial_scale; - T roi_end_w = offset_bottom_rois[2] * spatial_scale; - T roi_end_h = offset_bottom_rois[3] * spatial_scale; - - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_offset = continuous_coordinate ? 
static_cast(0.5) : static_cast(0); + T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; + T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; + T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; + T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (continuous_coordinate) { + CHECK_GT(roi_width, 0.); + CHECK_GT(roi_height, 0.); + } else { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -322,6 +331,7 @@ void ROIAlignBackward( const int /*num_rois*/, const T& spatial_scale, const bool position_sensitive, + const bool continuous_coordinate, const int channels, const int height, const int width, @@ -349,14 +359,19 @@ void ROIAlignBackward( } // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[0] * spatial_scale; - T roi_start_h = offset_bottom_rois[1] * spatial_scale; - T roi_end_w = offset_bottom_rois[2] * spatial_scale; - T roi_end_h = offset_bottom_rois[3] * spatial_scale; - - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_offset = continuous_coordinate ? 
static_cast(0.5) : static_cast(0); + T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset; + T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset; + T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset; + T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!continuous_coordinate) { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -460,7 +475,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, DType *top_data = out_data[roialign::kOut].dptr(); ROIAlignForward(count, bottom_data, param.spatial_scale, param.position_sensitive, - channels, height, width, pooled_height, pooled_width, + param.aligned, channels, height, width, pooled_height, pooled_width, param.sample_ratio, bottom_rois, rois_cols, top_data); }) } @@ -509,7 +524,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, Fill(s, outputs[0], kWriteTo, static_cast(0)); } ROIAlignBackward(count, top_diff, num_rois, param.spatial_scale, - param.position_sensitive, channels, height, width, + param.position_sensitive, param.aligned, channels, height, width, pooled_height, pooled_width, param.sample_ratio, grad_in, bottom_rois, rois_cols); } diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu index a0fd6f93686c..8cf5f1f5efe9 100644 --- a/src/operator/contrib/roi_align.cu +++ b/src/operator/contrib/roi_align.cu @@ -109,6 +109,7 @@ __global__ void RoIAlignForwardKernel( const T* bottom_data, const T spatial_scale, const bool position_sensitive, + const bool continuous_coordinate, const int channels, const int height, const int width, @@ -133,18 +134,19 @@ __global__ void 
RoIAlignForwardKernel( } // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_offset = continuous_coordinate ? static_cast(0.5) : static_cast(0); + T roi_start_w = offset_bottom_rois[1] * spatial_scale - roi_offset; + T roi_start_h = offset_bottom_rois[2] * spatial_scale - roi_offset; + T roi_end_w = offset_bottom_rois[3] * spatial_scale - roi_offset; + T roi_end_h = offset_bottom_rois[4] * spatial_scale - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!continuous_coordinate) { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -253,6 +255,7 @@ __global__ void RoIAlignBackwardKernel( const int num_rois, const T spatial_scale, const bool position_sensitive, + const bool continuous_coordinate, const int channels, const int height, const int width, @@ -273,18 +276,19 @@ __global__ void RoIAlignBackwardKernel( if (roi_batch_ind < 0) continue; // Do not using rounding; this implementation detail is critical - T roi_start_w = offset_bottom_rois[1] * spatial_scale; - T roi_start_h = offset_bottom_rois[2] * spatial_scale; - T roi_end_w = 
offset_bottom_rois[3] * spatial_scale; - T roi_end_h = offset_bottom_rois[4] * spatial_scale; - // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); - // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); - // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); - // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); - - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_offset = continuous_coordinate ? static_cast(0.5) : static_cast(0); + T roi_start_w = offset_bottom_rois[1] * spatial_scale - roi_offset; + T roi_start_h = offset_bottom_rois[2] * spatial_scale - roi_offset; + T roi_end_w = offset_bottom_rois[3] * spatial_scale - roi_offset; + T roi_end_h = offset_bottom_rois[4] * spatial_scale - roi_offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!continuous_coordinate) { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -397,6 +401,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs, bottom_data, param.spatial_scale, param.position_sensitive, + param.aligned, channels, height, width, @@ -466,6 +471,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs, num_rois, param.spatial_scale, param.position_sensitive, + param.aligned, channels, height, width,