diff --git a/src/operator/contrib/deformable_psroi_pooling-inl.h b/src/operator/contrib/deformable_psroi_pooling-inl.h index e466c065abbc..78124d2a26a6 100644 --- a/src/operator/contrib/deformable_psroi_pooling-inl.h +++ b/src/operator/contrib/deformable_psroi_pooling-inl.h @@ -51,11 +51,11 @@ namespace deformablepsroipool { struct DeformablePSROIPoolingParam : public dmlc::Parameter { // mxnet::TShape pooled_size; float spatial_scale; - int output_dim; - int group_size; - int pooled_size; - int part_size; - int sample_per_part; + index_t output_dim; + index_t group_size; + index_t pooled_size; + index_t part_size; + index_t sample_per_part; float trans_std; bool no_trans; DMLC_DECLARE_PARAMETER(DeformablePSROIPoolingParam) { @@ -82,10 +82,10 @@ class DeformablePSROIPoolingOp : public Operator { } virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; size_t in_expected = param_.no_trans? 2 : 3; size_t out_expected = 2; @@ -119,12 +119,12 @@ class DeformablePSROIPoolingOp : public Operator { } virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; size_t in_expected = param_.no_trans ? 2 : 3; size_t out_expected = 2; @@ -216,8 +216,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty { } bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + mxnet::ShapeVector *out_shape, + mxnet::ShapeVector *aux_shape) const override { using namespace mshadow; if (param_.no_trans) { CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]"; @@ -248,8 +248,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty { } bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { + std::vector *out_type, + std::vector *aux_type) const override { CHECK_GE(in_type->size(), 2); int dtype = (*in_type)[0]; CHECK_EQ(dtype, (*in_type)[1]); @@ -272,10 +272,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty { } // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { + std::vector DeclareBackwardDependency(const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { if (param_.no_trans) { return{ out_grad[deformablepsroipool::kOut], in_data[deformablepsroipool::kData], in_data[deformablepsroipool::kBox], out_data[deformablepsroipool::kTopCount] }; @@ -292,8 +291,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty { return NULL; } - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + Operator* CreateOperatorEx(Context ctx, + mxnet::ShapeVector *in_shape, + std::vector *in_type) const override; private: diff --git a/src/operator/contrib/deformable_psroi_pooling.cc b/src/operator/contrib/deformable_psroi_pooling.cc index d9d4cf8f78c5..697376dd573f 100644 --- a/src/operator/contrib/deformable_psroi_pooling.cc +++ b/src/operator/contrib/deformable_psroi_pooling.cc @@ -35,43 +35,309 @@ using std::max; using std::min; using std::floor; using std::ceil; +using std::round; namespace mshadow { + + template + inline DType bilinear_interp_cpu(const DType* data, + const DType x, const DType y, + const index_t width, const index_t height) { + index_t x1 = floor(x); + index_t x2 = ceil(x); + index_t y1 = floor(y); + index_t y2 = ceil(y); + DType dist_x = static_cast(x - x1); + DType dist_y = static_cast(y - y1); + DType value11 = data[y1 * width + x1]; + DType value12 = data[y2 * width + x1]; + DType value21 = data[y1 * width + x2]; + DType value22 = data[y2 * width + x2]; + DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; + } + + template + inline void DeformablePSROIPoolForwardCPU(const index_t count, const DType* bottom_data, + const DType spatial_scale, const index_t channels, + const index_t height, const index_t width, + const index_t pooled_height, const index_t pooled_width, + const DType* bottom_rois, const DType* bottom_trans, + const bool no_trans, const DType trans_std, + const index_t sample_per_part, const index_t output_dim, + const index_t group_size, const index_t part_size, + const index_t num_classes, + const index_t channels_each_class, + DType* top_data, DType* top_count) { + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); +#pragma omp parallel for num_threads(omp_threads) + for (index_t index = 0; index < count; index++) { + // The output is in order (n, ctop, ph, pw) + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const DType* offset_bottom_rois = bottom_rois + n * 5; + index_t roi_batch_ind = offset_bottom_rois[0]; + DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 + DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); + + // Compute w and h at bottom + DType bin_size_h = roi_height / static_cast(pooled_height); + DType bin_size_w = roi_width / static_cast(pooled_width); + + DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; + DType trans_x = no_trans ? static_cast(0) : + bottom_trans[(((n * num_classes + class_id) * 2) + * part_size + part_h) + * part_size + part_w] * trans_std; + DType trans_y = no_trans ? static_cast(0) : + bottom_trans[(((n * num_classes + class_id) * 2 + 1) + * part_size + part_h) + * part_size + part_w] * trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + DType sum = 0; + index_t count = 0; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; + } + w = min(max(w, static_cast(0)), static_cast(width - 1)); + h = min(max(h, static_cast(0)), static_cast(height - 1)); + index_t c = (ctop * group_size + gh) * group_size + gw; + DType val = bilinear_interp_cpu(offset_bottom_data + c * height * width, + w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? static_cast(0) : sum / count; + top_count[index] = count; + } + } + template inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - // NOT_IMPLEMENTED; + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { + const DType *bottom_data = data.dptr_; + const DType *bottom_rois = bbox.dptr_; + const DType *bottom_trans = no_trans ? nullptr : trans.dptr_; + DType *top_data = out.dptr_; + DType *top_count_data = top_count.dptr_; + const index_t count = out.shape_.Size(); + const index_t channels = data.size(1); + const index_t height = data.size(2); + const index_t width = data.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans.size(1) / 2; + const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; + DeformablePSROIPoolForwardCPU(count, bottom_data, spatial_scale, + channels, height, width, + pooled_height, pooled_width, + bottom_rois, bottom_trans, + no_trans, trans_std, sample_per_part, + output_dim, group_size, part_size, num_classes, + channels_each_class, top_data, top_count_data); + return; } + template + inline void DeformablePSROIPoolBackwardAccCPU(const index_t count, const DType* top_diff, + const DType* top_count, const index_t num_rois, + const DType spatial_scale, const index_t channels, + const index_t height, const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const index_t output_dim, + DType* bottom_data_diff, + DType* bottom_trans_diff, + const DType* bottom_data, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, + const DType trans_std, + const index_t sample_per_part, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class) { + for (index_t index = 0; index < count; index++) { + // The output is in order (n, ctop, ph, pw) + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const DType* offset_bottom_rois = bottom_rois + n * 5; + index_t roi_batch_ind = offset_bottom_rois[0]; + DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 + DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); + + // Compute w and h at bottom + DType bin_size_h = roi_height / static_cast(pooled_height); + DType bin_size_w = roi_width / static_cast(pooled_width); + + DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; + DType trans_x = no_trans ? static_cast(0) : + bottom_trans[(((n * num_classes + class_id) * 2) + * part_size + part_h) + * part_size + part_w] * trans_std; + DType trans_y = no_trans ? static_cast(0) : + bottom_trans[(((n * num_classes + class_id) * 2 + 1) + * part_size + part_h) + * part_size + part_w] * trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) { + continue; + } + DType diff_val = top_diff[index] / top_count[index]; + const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + index_t gw = floor(static_cast(pw)* group_size / pooled_width); + index_t gh = floor(static_cast(ph)* group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; + } + w = min(max(w, static_cast(0)), static_cast(width - 1)); + h = min(max(h, static_cast(0)), static_cast(height - 1)); + index_t c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + index_t x0 = floor(w); + index_t x1 = ceil(w); + index_t y0 = floor(h); + index_t y1 = ceil(h); + DType dist_x = w - x0, dist_y = h - y0; + DType q00 = (1 - dist_x) * (1 - dist_y); + DType q01 = (1 - dist_x) * dist_y; + DType q10 = dist_x * (1 - dist_y); + DType q11 = dist_x * dist_y; + index_t bottom_index_base = c * height * width; + offset_bottom_data_diff[bottom_index_base + y0 * width + x0] += q00 * diff_val; + offset_bottom_data_diff[bottom_index_base + y1 * width + x0] += q01 * diff_val; + offset_bottom_data_diff[bottom_index_base + y0 * width + x1] += q10 * diff_val; + offset_bottom_data_diff[bottom_index_base + y1 * width + x1] += q11 * diff_val; + + if (no_trans) { + continue; + } + DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); + diff_x *= trans_std * diff_val * roi_width; + DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); + diff_y *= trans_std * diff_val * roi_height; + + index_t offset_trans_diff = (((n * num_classes + class_id) * 2) + * part_size + part_h) * part_size + part_w; + bottom_trans_diff[offset_trans_diff] += diff_x; + bottom_trans_diff[offset_trans_diff + part_size * part_size] += diff_y; + } + } + } + } + template inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - // NOT_IMPLEMENTED; + const Tensor &trans_grad, + const Tensor &out_grad, + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { + const DType *top_diff = out_grad.dptr_; + const DType *bottom_data = data.dptr_; + const DType *bottom_rois = bbox.dptr_; + const DType *bottom_trans = no_trans ? nullptr : trans.dptr_; + DType *bottom_data_diff = in_grad.dptr_; + DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_; + const DType *top_count_data = top_count.dptr_; + const index_t count = out_grad.shape_.Size(); + const index_t num_rois = bbox.size(0); + const index_t channels = in_grad.size(1); + const index_t height = in_grad.size(2); + const index_t width = in_grad.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; + const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; + DeformablePSROIPoolBackwardAccCPU(count, top_diff, top_count_data, num_rois, + spatial_scale, channels, height, width, + pooled_height, pooled_width, output_dim, + bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, + no_trans, trans_std, sample_per_part, + group_size, part_size, num_classes, + channels_each_class); + return; } } // namespace mshadow @@ -88,9 +354,9 @@ namespace op { return op; } - Operator *DeformablePSROIPoolingProp::CreateOperatorEx( - Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const { + Operator *DeformablePSROIPoolingProp::CreateOperatorEx(Context ctx, + mxnet::ShapeVector *in_shape, + std::vector *in_type) const { mxnet::ShapeVector out_shape, aux_shape; std::vector out_type, aux_type; CHECK(InferType(in_type, &out_type, &aux_type)); diff --git a/src/operator/contrib/deformable_psroi_pooling.cu b/src/operator/contrib/deformable_psroi_pooling.cu index bf7d1c0bc755..6c89746b43ab 100644 --- a/src/operator/contrib/deformable_psroi_pooling.cu +++ b/src/operator/contrib/deformable_psroi_pooling.cu @@ -46,56 +46,52 @@ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ namespace mshadow { namespace cuda { template - __device__ DType bilinear_interp( - const DType* data, - const DType x, - const DType y, - const int width, - const int height) { - int x1 = floor(x); - int x2 = ceil(x); - int y1 = floor(y); - int y2 = ceil(y); + __device__ DType bilinear_interp(const DType* data, + const DType x, const DType y, + const index_t width, const index_t height) { + index_t x1 = floor(x); + index_t x2 = ceil(x); + index_t y1 = floor(y); + index_t y2 = ceil(y); DType dist_x = static_cast(x - x1); DType dist_y = static_cast(y - y1); - DType value11 = data[y1*width + x1]; - DType value12 = data[y2*width + x1]; - DType value21 = data[y1*width + x2]; - DType value22 = data[y2*width + x2]; - DType value = (1 - dist_x)*(1 - dist_y)*value11 + (1 - dist_x)*dist_y*value12 - + dist_x*(1 - dist_y)*value21 + dist_x*dist_y*value22; + DType value11 = data[y1 * width + x1]; + DType value12 = data[y2 * width + x1]; + DType value21 = data[y1 * width + x2]; + DType value22 = data[y2 * width + x2]; + DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; return value; } template - __global__ void DeformablePSROIPoolForwardKernel( - const int count, - const DType* bottom_data, - const DType spatial_scale, - const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - const DType* bottom_rois, const DType* bottom_trans, - const bool no_trans, - const DType trans_std, - const int sample_per_part, - const int output_dim, - const int group_size, - const int part_size, - const int num_classes, - const int channels_each_class, - DType* top_data, - DType* top_count) { + __global__ void DeformablePSROIPoolForwardKernel(const index_t count, + const DType* bottom_data, + const DType spatial_scale, + const index_t channels, + const index_t height, const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, const DType trans_std, + const index_t sample_per_part, + const index_t output_dim, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class, + DType* top_data, DType* top_count) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int ctop = (index / pooled_width / pooled_height) % output_dim; - int n = index / pooled_width / pooled_height / output_dim; + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const DType* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + index_t roi_batch_ind = offset_bottom_rois[0]; DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; @@ -112,9 +108,9 @@ namespace cuda { DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - int part_h = floor(static_cast(ph) / pooled_height*part_size); - int part_w = floor(static_cast(pw) / pooled_width*part_size); - int class_id = ctop / channels_each_class; + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; DType trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) @@ -124,33 +120,32 @@ namespace cuda { * part_size + part_h) * part_size + part_w] * trans_std; - DType wstart = static_cast(pw)* bin_size_w - + roi_start_w; + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h - + roi_start_h; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; DType sum = 0; - int count = 0; - int gw = floor(static_cast(pw) * group_size / pooled_width); - int gh = floor(static_cast(ph)* group_size / pooled_height); - gw = min(max(gw, 0), group_size - 1); - gh = min(max(gh, 0), group_size - 1); + index_t count = 0; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; - for (int ih = 0; ih < sample_per_part; ih++) { - for (int iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw*sub_bin_size_w; - DType h = hstart + ih*sub_bin_size_h; + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; // bilinear interpolation - if (w<-0.5 || w>width - 0.5 || h<-0.5 || h>height - 0.5) { + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); - int c = (ctop*group_size + gh)*group_size + gw; - DType val = bilinear_interp(offset_bottom_data + c*height*width, w, h, width, height); + index_t c = (ctop * group_size + gh) * group_size + gw; + DType val = bilinear_interp(offset_bottom_data + c * height * width, + w, h, width, height); sum += val; count++; } @@ -162,75 +157,74 @@ namespace cuda { template inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - // LOG(INFO) << "DeformablePSROIPoolForward"; + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { const DType *bottom_data = data.dptr_; const DType *bottom_rois = bbox.dptr_; const DType *bottom_trans = no_trans ? NULL : trans.dptr_; DType *top_data = out.dptr_; DType *top_count_data = top_count.dptr_; - const int count = out.shape_.Size(); - const int channels = data.size(1); - const int height = data.size(2); - const int width = data.size(3); - const int pooled_height = pooled_size; - const int pooled_width = pooled_size; - const int num_classes = no_trans ? 1 : trans.size(1) / 2; - const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + const index_t count = out.shape_.Size(); + const index_t channels = data.size(1); + const index_t height = data.size(2); + const index_t width = data.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans.size(1) / 2; + const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; cudaStream_t stream = Stream::GetStream(out.stream_); - DeformablePSROIPoolForwardKernel << > >( - count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width, - bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim, - group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); + DeformablePSROIPoolForwardKernel<<< + mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum, + 0, stream>>>(count, bottom_data, spatial_scale, channels, height, width, + pooled_height, pooled_width, bottom_rois, bottom_trans, + no_trans, trans_std, sample_per_part, output_dim, + group_size, part_size, num_classes, + channels_each_class, top_data, top_count_data); DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError()); } template - __global__ void DeformablePSROIPoolBackwardAccKernel( - const int count, - const DType* top_diff, - const DType* top_count, - const int num_rois, - const DType spatial_scale, - const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - const int output_dim, - DType* bottom_data_diff, DType* bottom_trans_diff, - const DType* bottom_data, - const DType* bottom_rois, - const DType* bottom_trans, - const bool no_trans, - const DType trans_std, - const int sample_per_part, - const int group_size, - const int part_size, - const int num_classes, - const int channels_each_class) { + __global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count, + const DType* top_diff, + const DType* top_count, + const index_t num_rois, + const DType spatial_scale, + const index_t channels, + const index_t height, + const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const index_t output_dim, + DType* bottom_data_diff, + DType* bottom_trans_diff, + const DType* bottom_data, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, + const DType trans_std, + const index_t sample_per_part, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class) { CUDA_KERNEL_LOOP(index, count) { // The output is in order (n, ctop, ph, pw) - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int ctop = (index / pooled_width / pooled_height) % output_dim; - int n = index / pooled_width / pooled_height / output_dim; + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; // [start, end) interval for spatial sampling const DType* offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; + index_t roi_batch_ind = offset_bottom_rois[0]; DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; @@ -247,9 +241,9 @@ namespace cuda { DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - int part_h = floor(static_cast(ph) / pooled_height*part_size); - int part_w = floor(static_cast(pw) / pooled_width*part_size); - int class_id = ctop / channels_each_class; + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; DType trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) @@ -259,11 +253,9 @@ namespace cuda { * part_size + part_h) * part_size + part_w] * trans_std; - DType wstart = static_cast(pw)* bin_size_w - + roi_start_w; + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h - + roi_start_h; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; hstart += trans_y * roi_height; if (top_count[index] <= 0) { @@ -272,51 +264,49 @@ namespace cuda { DType diff_val = top_diff[index] / top_count[index]; const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; - int gw = floor(static_cast(pw)* group_size / pooled_width); - int gh = floor(static_cast(ph)* group_size / pooled_height); - gw = min(max(gw, 0), group_size - 1); - gh = min(max(gh, 0), group_size - 1); - - for (int ih = 0; ih < sample_per_part; ih++) { - for (int iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw*sub_bin_size_w; - DType h = hstart + ih*sub_bin_size_h; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; // bilinear interpolation - if (w<-0.5 || w>width - 0.5 || h<-0.5 || h>height - 0.5) { + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { continue; } w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); - int c = (ctop*group_size + gh)*group_size + gw; + index_t c = (ctop * group_size + gh) * group_size + gw; // backward on feature - int x0 = floor(w); - int x1 = ceil(w); - int y0 = floor(h); - int y1 = ceil(h); + index_t x0 = floor(w); + index_t x1 = ceil(w); + index_t y0 = floor(h); + index_t y1 = ceil(h); DType dist_x = w - x0, dist_y = h - y0; - DType q00 = (1 - dist_x)*(1 - dist_y); - DType q01 = (1 - dist_x)*dist_y; - DType q10 = dist_x*(1 - dist_y); - DType q11 = dist_x*dist_y; - int bottom_index_base = c * height *width; - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0*width + x0, q00*diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1*width + x0, q01*diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0*width + x1, q10*diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1*width + x1, q11*diff_val); + DType q00 = (1 - dist_x) * (1 - dist_y); + DType q01 = (1 - dist_x) * dist_y; + DType q10 = dist_x * (1 - dist_y); + DType q11 = dist_x * dist_y; + index_t bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); if (no_trans) { continue; } - DType U00 = offset_bottom_data[bottom_index_base + y0*width + x0]; - DType U01 = offset_bottom_data[bottom_index_base + y1*width + x0]; - DType U10 = offset_bottom_data[bottom_index_base + y0*width + x1]; - DType U11 = offset_bottom_data[bottom_index_base + y1*width + x1]; - DType diff_x = (U11*dist_y + U10*(1 - dist_y) - U01*dist_y - U00*(1 - dist_y)) - *trans_std*diff_val; - diff_x *= roi_width; - DType diff_y = (U11*dist_x + U01*(1 - dist_x) - U10*dist_x - U00*(1 - dist_x)) - *trans_std*diff_val; - diff_y *= roi_height; + DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); + diff_x *= trans_std * diff_val * roi_width; + DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); + diff_y *= trans_std * diff_val * roi_height; atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) @@ -332,21 +322,16 @@ namespace cuda { template inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - // LOG(INFO) << "DeformablePSROIPoolBackward"; + const Tensor &trans_grad, + const Tensor &out_grad, + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { const DType *top_diff = out_grad.dptr_; const DType *bottom_data = data.dptr_; const DType *bottom_rois = bbox.dptr_; @@ -354,23 +339,25 @@ namespace cuda { DType *bottom_data_diff = in_grad.dptr_; DType *bottom_trans_diff = no_trans ? NULL : trans_grad.dptr_; const DType *top_count_data = top_count.dptr_; - const int count = out_grad.shape_.Size(); - const int num_rois = bbox.size(0); - const int channels = in_grad.size(1); - const int height = in_grad.size(2); - const int width = in_grad.size(3); - const int pooled_height = pooled_size; - const int pooled_width = pooled_size; - const int num_classes = no_trans ? 1 : trans_grad.size(1) / 2; - const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + const index_t count = out_grad.shape_.Size(); + const index_t num_rois = bbox.size(0); + const index_t channels = in_grad.size(1); + const index_t height = in_grad.size(2); + const index_t width = in_grad.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; + const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; cudaStream_t stream = Stream::GetStream(in_grad.stream_); - DeformablePSROIPoolBackwardAccKernel << > >( - count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width, - pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, - bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, - group_size, part_size, num_classes, channels_each_class); + DeformablePSROIPoolBackwardAccKernel<<< + mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum, + 0, stream >>>(count, top_diff, top_count_data, num_rois, spatial_scale, + channels, height, width, pooled_height, pooled_width, + output_dim, bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, + no_trans, trans_std, sample_per_part, group_size, + part_size, num_classes, channels_each_class); DeformablePSROIPOOLING_CUDA_CHECK(cudaPeekAtLastError()); } @@ -378,41 +365,36 @@ namespace cuda { template inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count, no_trans, spatial_scale, - output_dim, group_size, pooled_size, part_size, sample_per_part, trans_std); + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { + cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count, + no_trans, spatial_scale, output_dim, + group_size, pooled_size, part_size, + sample_per_part, trans_std); } template inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) { - cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox, trans, - top_count, no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size, - sample_per_part, trans_std); + const Tensor &trans_grad, + const Tensor &out_grad, + const Tensor &data, + const Tensor &bbox, + const Tensor &trans, + const Tensor &top_count, + const bool no_trans, const float spatial_scale, + const index_t output_dim, const index_t group_size, + const index_t pooled_size, const index_t part_size, + const index_t sample_per_part, const float trans_std) { + cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox, + trans, top_count, no_trans, spatial_scale, + output_dim, group_size, pooled_size, + part_size, sample_per_part, trans_std); } } // namespace mshadow diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 9c88dc15488c..9c004cdfdab1 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1633,6 +1633,24 @@ def test_deformable_psroipooling_with_type(): 'deformable_psroipool_trans': (2, 4, 3, 3), 'type_dict': {'deformable_psroipool_data': np.float16, 'deformable_psroipool_rois': np.float16, 'deformable_psroipool_trans': np.float16}}, + {'ctx': mx.cpu(0), + 'deformable_psroipool_data': (1, 18, 14, 14), + 'deformable_psroipool_rois': (2, 5), + 'deformable_psroipool_trans': (2, 4, 3, 3), + 'type_dict': {'deformable_psroipool_data': np.float64, 'deformable_psroipool_rois': np.float64, + 'deformable_psroipool_trans': np.float64}}, + {'ctx': mx.cpu(0), + 'deformable_psroipool_data': (1, 18, 14, 14), + 'deformable_psroipool_rois': (2, 5), + 'deformable_psroipool_trans': (2, 4, 3, 3), + 'type_dict': {'deformable_psroipool_data': np.float32, 'deformable_psroipool_rois': np.float32, + 'deformable_psroipool_trans': np.float32}}, + {'ctx': mx.cpu(0), + 'deformable_psroipool_data': (1, 18, 14, 14), + 'deformable_psroipool_rois': (2, 5), + 'deformable_psroipool_trans': (2, 4, 3, 3), + 'type_dict': {'deformable_psroipool_data': np.float16, 'deformable_psroipool_rois': np.float16, + 'deformable_psroipool_trans': np.float16}}, ] check_consistency(sym, ctx_list, scale=0.1, tol=tol,