
Accelerate ROIPooling layer #14894

Merged 13 commits on Jul 12, 2019
11 changes: 8 additions & 3 deletions src/operator/roi_pooling-inl.h
@@ -83,7 +83,7 @@ class ROIPoolingOp : public Operator {
Tensor<xpu, 4, DType> data = in_data[roipool::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
Tensor<xpu, 4, DType> out = out_data[roipool::kOut].get<xpu, 4, DType>(s);
- Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
+ Tensor<xpu, 4, index_t> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, index_t>(s);
CHECK_EQ(data.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(out.CheckContiguous(), true);
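One observation on the hunk above: max_idx previously reused the data tensor's floating-point DType to store argmax positions. A float32 represents consecutive integers exactly only up to 2^24, so larger flat indices silently round to a neighboring value. A minimal standalone sketch of that failure mode (values chosen for illustration, not taken from the patch):

#include <iostream>

int main() {
  // 2^24 + 1 is the first integer that float32 cannot represent exactly.
  const float idx_as_float = 16777217.0f;
  // Prints 16777216: the stored "index" has silently moved to another element.
  std::cout << static_cast<long long>(idx_as_float) << '\n';
  return 0;
}

Switching the tensor to index_t avoids both the precision loss and the float-to-int casts in the backward pass.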
@@ -114,7 +114,7 @@ class ROIPoolingOp : public Operator {

Tensor<xpu, 4, DType> grad_out = out_grad[roipool::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
- Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
+ Tensor<xpu, 4, index_t> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, index_t>(s);
Tensor<xpu, 4, DType> grad_in = in_grad[roipool::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 2, DType> grad_roi = in_grad[roipool::kBox].get<xpu, 2, DType>(s);
CHECK_EQ(grad_out.CheckContiguous(), true);
@@ -195,14 +195,19 @@ class ROIPoolingProp : public OperatorProperty {
bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
+ using namespace mshadow;
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
CHECK_EQ(dtype, (*in_type)[1]);
CHECK_NE(dtype, -1) << "Input must have specified type";

out_type->clear();
out_type->push_back(dtype);
- out_type->push_back(dtype);
+ #if MXNET_USE_INT64_TENSOR_SIZE == 1
+ out_type->push_back(kInt64);
+ #else
+ out_type->push_back(kInt32);
+ #endif
return true;
}
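The integer width of the new third output follows the build flag: with large-tensor support (MXNET_USE_INT64_TENSOR_SIZE=1) the argmax holds flat offsets into the whole input tensor, which can exceed the int32 range. A rough sanity check under an assumed extreme input shape (illustrative only, not from the patch):

#include <cstdint>
#include <iostream>

int main() {
  // Assumed worst-case input: batch 32, 1024 channels, 512 x 512 feature map.
  const int64_t n = 32, c = 1024, h = 512, w = 512;
  const int64_t max_flat_offset = n * c * h * w - 1;  // 8,589,934,591
  std::cout << "max flat offset: " << max_flat_offset << '\n'
            << "fits in int32: "
            << (max_flat_offset <= INT32_MAX ? "yes" : "no") << '\n';  // prints "no"
  return 0;
}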

128 changes: 29 additions & 99 deletions src/operator/roi_pooling.cc
@@ -40,38 +40,39 @@ template<typename Dtype>
inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
const Tensor<cpu, 4, Dtype> &data,
const Tensor<cpu, 2, Dtype> &bbox,
- const Tensor<cpu, 4, Dtype> &max_idx,
+ const Tensor<cpu, 4, index_t> &max_idx,
const float spatial_scale_) {
const Dtype *bottom_data = data.dptr_;
const Dtype *bottom_rois = bbox.dptr_;
Dtype *top_data = out.dptr_;
- Dtype *argmax_data = max_idx.dptr_;
+ index_t *argmax_data = max_idx.dptr_;
+ const int batch_size = data.size(0);
const int channels_ = data.size(1);
const int height_ = data.size(2);
const int width_ = data.size(3);
const int pooled_height_ = out.size(2);
const int pooled_width_ = out.size(3);

const int num_rois = bbox.size(0);
- const int data_size = data.size(1) * data.size(2) * data.size(3);
- const int data_size_c = data.size(2) * data.size(3);
- const int out_size_c = out.size(2) * out.size(3);
- const int out_size = channels_ * out_size_c;
- const int max_idx_size_c = max_idx.size(2) * max_idx.size(3);
- const int max_idx_size = channels_ * max_idx_size_c;
+ const index_t data_size = data.size(1) * data.size(2) * data.size(3);
+ const index_t data_size_c = data.size(2) * data.size(3);
+ const index_t out_size_c = out.size(2) * out.size(3);
+ const index_t out_size = channels_ * out_size_c;
+ const index_t max_idx_size_c = max_idx.size(2) * max_idx.size(3);
+ const index_t max_idx_size = channels_ * max_idx_size_c;
// For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
for (int n = 0; n < num_rois; ++n) {
// Increment ROI data pointer
const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1);
Dtype *top_data_n = top_data + n * out_size;
- Dtype *argmax_data_n = argmax_data + n * max_idx_size;
- int roi_batch_ind = bottom_rois_n[0];
+ index_t *argmax_data_n = argmax_data + n * max_idx_size;
int roi_start_w = std::round(bottom_rois_n[1] * spatial_scale_);
int roi_start_h = std::round(bottom_rois_n[2] * spatial_scale_);
int roi_end_w = std::round(bottom_rois_n[3] * spatial_scale_);
int roi_end_h = std::round(bottom_rois_n[4] * spatial_scale_);
- assert(roi_batch_ind >= 0);
- assert(static_cast<index_t>(roi_batch_ind) < data.size(0) /* batch size */);

+ int roi_batch_ind = static_cast<int>(bottom_rois_n[0]);
+ bool is_ind_invalid = (roi_batch_ind < 0) || (roi_batch_ind >= batch_size);

// force malformed ROIs to be 1 * 1
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
@@ -81,14 +81,15 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
const Dtype bin_size_w = static_cast<Dtype>(roi_width)
/ static_cast<Dtype>(pooled_width_);

- const Dtype* batch_data = bottom_data + data_size * roi_batch_ind;
+ index_t offset_batch_data = data_size * roi_batch_ind;

#pragma omp parallel for
for (int c = 0; c < channels_; ++c) {
// Increment all data pointers
- const Dtype* batch_data_c = batch_data + c * data_size_c;
+ index_t offset_batch_data_c = offset_batch_data + c * data_size_c;
+ const Dtype* batch_data_c = bottom_data + offset_batch_data_c;
Dtype* top_data_c = top_data_n + c * out_size_c;
- Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c;
+ index_t* argmax_data_c = argmax_data_n + c * max_idx_size_c;

for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
@@ -111,18 +113,19 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,

bool is_empty = (hend <= hstart) || (wend <= wstart);

- const int pool_index = ph * pooled_width_ + pw;
- if (is_empty) {
+ const index_t pool_index = ph * pooled_width_ + pw;
+ if (is_empty || is_ind_invalid) {
top_data_c[pool_index] = 0;
argmax_data_c[pool_index] = -1;
+ continue;
}

for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
- const int index = h * width_ + w;
+ const index_t index = h * width_ + w;
if (batch_data_c[index] > top_data_c[pool_index]) {
top_data_c[pool_index] = batch_data_c[index];
- argmax_data_c[pool_index] = index;
+ argmax_data_c[pool_index] = offset_batch_data_c + index;
}
}
}
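The load-bearing change in this forward hunk is the last `+` line above: the recorded argmax is now an offset from the start of the entire input tensor (offset_batch_data_c + index) rather than an index within one H x W plane. That is what lets the backward pass below collapse into a single scatter. A small sketch contrasting the two conventions (names are illustrative, not from the patch):

#include <cstddef>
#include <iostream>

using index_t = std::ptrdiff_t;  // stand-in for mshadow's index_t

// Old convention: index within one (batch, channel) plane; the backward pass
// had to re-derive which ROI, batch, and channel each entry belonged to.
index_t local_argmax(index_t h, index_t w, index_t W) {
  return h * W + w;
}

// New convention: offset from the start of the whole N x C x H x W input,
// usable directly as a scatter destination in the backward pass.
index_t global_argmax(index_t batch, index_t c, index_t h, index_t w,
                      index_t C, index_t H, index_t W) {
  const index_t plane_offset = (batch * C + c) * H * W;  // offset_batch_data_c
  return plane_offset + h * W + w;
}

int main() {
  // Assumed toy shape: C = 2, H = W = 4; element (batch 1, channel 1, h 2, w 3).
  std::cout << local_argmax(2, 3, 4) << '\n';               // 11
  std::cout << global_argmax(1, 1, 2, 3, 2, 4, 4) << '\n';  // (1*2+1)*16 + 11 = 59
  return 0;
}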
@@ -137,91 +140,18 @@ template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
const Tensor<cpu, 4, Dtype> &out_grad,
const Tensor<cpu, 2, Dtype> &bbox,
- const Tensor<cpu, 4, Dtype> &max_idx,
+ const Tensor<cpu, 4, index_t> &max_idx,
const float spatial_scale_) {
const Dtype *top_diff = out_grad.dptr_;
const Dtype *bottom_rois = bbox.dptr_;
Dtype *bottom_diff = in_grad.dptr_;
- Dtype *argmax_data = max_idx.dptr_;

- const int batch_size_ = in_grad.size(0);
- const int channels_ = in_grad.size(1);
- const int height_ = in_grad.size(2);
- const int width_ = in_grad.size(3);
- const int pooled_height_ = out_grad.size(2);
- const int pooled_width_ = out_grad.size(3);

- const int num_rois = bbox.size(0);

- for (int b = 0; b < batch_size_; ++b) {
- for (int c = 0; c < channels_; ++c) {
- for (int h = 0; h < height_; ++h) {
- for (int w = 0; w < width_; ++w) {
- int offset_bottom_diff = (b * channels_ + c) * height_ * width_;
- offset_bottom_diff += h * width_ + w;

- Dtype gradient = 0;
- // Accumulate gradient over all ROIs that pooled this element
- for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
- const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
- int roi_batch_ind = offset_bottom_rois[0];
- assert(roi_batch_ind >= 0);
- assert(roi_batch_ind < batch_size_);
- if (b != roi_batch_ind) {
- continue;
- }
+ index_t *argmax_data = max_idx.dptr_;

- int roi_start_w = std::round(offset_bottom_rois[1] * spatial_scale_);
- int roi_start_h = std::round(offset_bottom_rois[2] * spatial_scale_);
- int roi_end_w = std::round(offset_bottom_rois[3] * spatial_scale_);
- int roi_end_h = std::round(offset_bottom_rois[4] * spatial_scale_);
+ const index_t count = out_grad.shape_.Size();

- bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
- h >= roi_start_h && h <= roi_end_h);
- if (!in_roi) {
- continue;
- }

- // force malformed ROIs to be 1 * 1
- int roi_height = max(roi_end_h - roi_start_h + 1, 1);
- int roi_width = max(roi_end_w - roi_start_w + 1, 1);
- const Dtype bin_size_h = static_cast<Dtype>(roi_height)
- / static_cast<Dtype>(pooled_height_);
- const Dtype bin_size_w = static_cast<Dtype>(roi_width)
- / static_cast<Dtype>(pooled_width_);

- // compute pooled regions correspond to original (h, w) point
- int phstart = static_cast<int>(floor(static_cast<Dtype>(h - roi_start_h)
- / bin_size_h));
- int pwstart = static_cast<int>(floor(static_cast<Dtype>(w - roi_start_w)
- / bin_size_w));
- int phend = static_cast<int>(ceil(static_cast<Dtype>(h - roi_start_h + 1)
- / bin_size_h));
- int pwend = static_cast<int>(ceil(static_cast<Dtype>(w - roi_start_w + 1)
- / bin_size_w));

- // clip to boundaries of pooled region
- phstart = min(max(phstart, 0), pooled_height_);
- phend = min(max(phend, 0), pooled_height_);
- pwstart = min(max(pwstart, 0), pooled_width_);
- pwend = min(max(pwend, 0), pooled_width_);

- // accumulate over gradients in pooled regions
- int offset = (roi_n * channels_ + c) * pooled_height_ * pooled_width_;
- const Dtype* offset_top_diff = top_diff + offset;
- const Dtype* offset_argmax_data = argmax_data + offset;
- for (int ph = phstart; ph < phend; ++ph) {
- for (int pw = pwstart; pw < pwend; ++pw) {
- const int pooled_index = ph * pooled_width_ + pw;
- if (static_cast<int>(offset_argmax_data[pooled_index]) == h * width_ + w) {
- gradient += offset_top_diff[pooled_index];
- }
- }
- }
- }
- bottom_diff[offset_bottom_diff] += gradient;
- }
- }
+ for (int index = 0; index < count; ++index) {
+ index_t max_idx = argmax_data[index];
+ if (max_idx >= 0) {
+ bottom_diff[max_idx] += top_diff[index];
}
}
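Taken together, the rewrite above replaces the old quadruple loop over every input element crossed with every ROI (roughly O(N*C*H*W*num_rois)) with one pass over the pooled outputs (O(num_rois*C*pooled_h*pooled_w)). A self-contained sketch of the new scatter, under the assumption that argmax entries of -1 mark empty or invalid bins (helper names here are hypothetical):

#include <cstddef>
#include <iostream>
#include <vector>

using index_t = std::ptrdiff_t;  // stand-in for mshadow's index_t

// Scatter each output gradient to the input element that won the max in the
// forward pass; -1 entries (empty/invalid bins) contribute nothing.
void roi_pool_backward_acc(const std::vector<float>& top_diff,
                           const std::vector<index_t>& argmax,
                           std::vector<float>& bottom_diff) {
  for (std::size_t i = 0; i < top_diff.size(); ++i) {
    const index_t m = argmax[i];
    if (m >= 0) {
      bottom_diff[m] += top_diff[i];
    }
  }
}

int main() {
  std::vector<float> top_diff = {1.0f, 2.0f, 3.0f};
  std::vector<index_t> argmax = {4, -1, 4};  // two bins chose the same input element
  std::vector<float> bottom_diff(8, 0.0f);
  roi_pool_backward_acc(top_diff, argmax, bottom_diff);
  std::cout << bottom_diff[4] << '\n';  // 4: gradients from both bins accumulate
  return 0;
}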
