Add Psroipooling CPU implementation (apache#12738)
* add psroipooling cpu impl

* minor fix

* revert copyright

* fix testcase

* add openmp

* no openmp for backward
zhreshold authored and lanking520 committed Oct 24, 2018
1 parent 5b01eeb commit 27b9e09
Showing 2 changed files with 199 additions and 8 deletions.
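For context, here is a minimal usage sketch of the operator this patch enables on CPU. The shapes match the new unit test below, the values are illustrative only, and the call uses the existing mx.nd.contrib.PSROIPooling API:

import mxnet as mx

# 18 input channels = output_dim (2) x pooled_size^2 (3 x 3), as PSROIPooling expects.
data = mx.nd.random.uniform(shape=(1, 18, 14, 14), ctx=mx.cpu(0))
# Each ROI row is (batch_index, x1, y1, x2, y2) in original-image coordinates.
rois = mx.nd.array([[0, 10, 22, 161, 173]], ctx=mx.cpu(0))
out = mx.nd.contrib.PSROIPooling(data=data, rois=rois, spatial_scale=0.0625,
                                 output_dim=2, pooled_size=3)
print(out.shape)  # (1, 2, 3, 3): one (output_dim, pooled, pooled) map per ROI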
176 changes: 172 additions & 4 deletions src/operator/contrib/psroi_pooling.cc
@@ -38,25 +38,193 @@ using std::floor;
using std::ceil;

namespace mshadow {

template <typename DType>
inline void PSROIPoolForwardCPU(
const int count,
const DType* bottom_data,
const DType spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const DType* bottom_rois,
const int output_dim,
const int group_size,
DType* top_data) {
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(omp_threads)
for (int index = 0; index < count; index++) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;

// [start, end) interval for spatial sampling
const DType* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale;
DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale;
DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;

// Force too small ROIs to be 1x1
DType roi_width = max(roi_end_w - roi_start_w, static_cast<DType>(0.1)); // avoid 0
DType roi_height = max(roi_end_h - roi_start_h, static_cast<DType>(0.1));

// Compute w and h at bottom
DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

int hstart = floor(static_cast<DType>(ph) * bin_size_h
+ roi_start_h);
int wstart = floor(static_cast<DType>(pw) * bin_size_w
+ roi_start_w);
int hend = ceil(static_cast<DType>(ph + 1) * bin_size_h
+ roi_start_h);
int wend = ceil(static_cast<DType>(pw + 1) * bin_size_w
+ roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);

int gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
int gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
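// Map (ctop, gh, gw) to the position-sensitive input channel for this bin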
int c = (ctop * group_size + gh) * group_size + gw;

const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
DType out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int bottom_index = h * width + w;
out_sum += offset_bottom_data[bottom_index];
}
}

DType bin_area = (hend - hstart) * (wend - wstart);
top_data[index] = is_empty ? (DType)0. : out_sum / bin_area;
}
}
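
As a cross-check on the kernel above, here is a hedged NumPy sketch (not part of the patch; the helper name psroipool_bin is hypothetical) that reproduces the same computation for one flat output index of a single ROI:

import numpy as np

def psroipool_bin(bottom_data, roi, index, spatial_scale,
                  pooled_h, pooled_w, output_dim, group_size):
    # Mirrors PSROIPoolForwardCPU for one output element (ctop, ph, pw), n = 0.
    height, width = bottom_data.shape[2], bottom_data.shape[3]
    pw = index % pooled_w
    ph = (index // pooled_w) % pooled_h
    ctop = (index // (pooled_w * pooled_h)) % output_dim
    # ROI corners are rounded, then scaled into feature-map coordinates.
    x1, y1 = round(roi[1]) * spatial_scale, round(roi[2]) * spatial_scale
    x2 = (round(roi[3]) + 1.0) * spatial_scale
    y2 = (round(roi[4]) + 1.0) * spatial_scale
    bin_h = max(y2 - y1, 0.1) / pooled_h  # force tiny ROIs to at least 0.1
    bin_w = max(x2 - x1, 0.1) / pooled_w
    hstart = min(max(int(np.floor(ph * bin_h + y1)), 0), height)
    hend = min(max(int(np.ceil((ph + 1) * bin_h + y1)), 0), height)
    wstart = min(max(int(np.floor(pw * bin_w + x1)), 0), width)
    wend = min(max(int(np.ceil((pw + 1) * bin_w + x1)), 0), width)
    if hend <= hstart or wend <= wstart:
        return 0.0
    gh = min(max(int(ph * group_size / pooled_h), 0), group_size - 1)
    gw = min(max(int(pw * group_size / pooled_w), 0), group_size - 1)
    c = (ctop * group_size + gh) * group_size + gw  # position-sensitive channel
    region = bottom_data[int(roi[0]), c, hstart:hend, wstart:wend]
    return region.sum() / region.size  # plain average over the bin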

template<typename DType>
inline void PSROIPoolForward(const Tensor<cpu, 4, DType> &out,
const Tensor<cpu, 4, DType> &data,
const Tensor<cpu, 2, DType> &bbox,
const float spatial_scale,
const int output_dim_,
const int group_size_) {
const DType *bottom_data = data.dptr_;
const DType *bottom_rois = bbox.dptr_;
DType *top_data = out.dptr_;
const int count = out.shape_.Size();
const int channels = data.size(1);
const int height = data.size(2);
const int width = data.size(3);
const int pooled_height = out.size(2);
const int pooled_width = out.size(3);
PSROIPoolForwardCPU<DType>(
count, bottom_data, spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_rois, output_dim_, group_size_, top_data);

return;
}

template <typename DType>
inline void PSROIPoolBackwardAccCPU(
const int count,
const DType* top_diff,
const int num_rois,
const DType spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int group_size,
const int output_dim,
DType* bottom_diff,
const DType* bottom_rois) {
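// Kept serial on purpose ("no openmp for backward" above): overlapping ROIs
// can accumulate into the same bottom_diff elements, so a plain
// #pragma omp parallel for here would race on those writes.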
for (int index = 0; index < count; index++) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;

// [start, end) interval for spatial sampling
const DType* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale;
DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale;
DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;

// Force too small ROIs to be 1x1
DType roi_width = max(roi_end_w - roi_start_w, static_cast<DType>(0.1)); // avoid 0
DType roi_height = max(roi_end_h - roi_start_h, static_cast<DType>(0.1));

// Compute w and h at bottom
DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
DType bin_size_w = roi_width / static_cast<DType>(pooled_width);

int hstart = floor(static_cast<DType>(ph) * bin_size_h
+ roi_start_h);
int wstart = floor(static_cast<DType>(pw) * bin_size_w
+ roi_start_w);
int hend = ceil(static_cast<DType>(ph + 1) * bin_size_h
+ roi_start_h);
int wend = ceil(static_cast<DType>(pw + 1) * bin_size_w
+ roi_start_w);
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart, 0), height);
hend = min(max(hend, 0), height);
wstart = min(max(wstart, 0), width);
wend = min(max(wend, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Compute c at bottom
int gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
int gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
int c = (ctop * group_size + gh) * group_size + gw;
DType* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
DType bin_area = (hend - hstart) * (wend - wstart);
DType diff_val = is_empty ? (DType)0. : top_diff[index] / bin_area;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int bottom_index = h * width + w;
offset_bottom_diff[bottom_index] += diff_val;
}
}
}
}


template<typename DType>
inline void PSROIPoolBackwardAcc(const Tensor<cpu, 4, DType> &in_grad,
const Tensor<cpu, 4, DType> &out_grad,
const Tensor<cpu, 2, DType> &bbox,
const float spatial_scale,
const int output_dim_,
const int group_size_) {
const DType *top_diff = out_grad.dptr_;
const DType *bottom_rois = bbox.dptr_;
DType *bottom_diff = in_grad.dptr_;
const int count = out_grad.shape_.Size();
const int num_rois = bbox.size(0);
const int channels = in_grad.size(1);
const int height = in_grad.size(2);
const int width = in_grad.size(3);
const int pooled_height = out_grad.size(2);
const int pooled_width = out_grad.size(3);
PSROIPoolBackwardAccCPU<DType>(
count, top_diff, num_rois, spatial_scale, channels, height, width,
pooled_height, pooled_width, group_size_, output_dim_, bottom_diff, bottom_rois);

return;
}
} // namespace mshadow
31 changes: 27 additions & 4 deletions tests/python/unittest/test_operator.py
@@ -5037,10 +5037,33 @@ def test_psroipooling():
group_size=num_group, pooled_size=num_group,
output_dim=num_classes, name='test_op')
rtol, atol = 1e-2, 1e-3
check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol,
                       grad_nodes=grad_nodes)


@with_seed()
def test_psroipooling_with_type():
arg_params = {
'psroipool_rois': np.array([[0, 10, 22, 161, 173], [0, 20, 15, 154, 160]])}

# plain psroipooling
sym = mx.sym.contrib.PSROIPooling(spatial_scale=0.0625, output_dim=2, pooled_size=3, name='psroipool')
ctx_list = [{'ctx': mx.cpu(0),
'psroipool_data': (1, 18, 14, 14),
'psroipool_rois': (2, 5),
'type_dict': {'psroipool_data': np.float64, 'psroipool_rois': np.float64}},
{'ctx': mx.cpu(0),
'psroipool_data': (1, 18, 14, 14),
'psroipool_rois': (2, 5),
'type_dict': {'psroipool_data': np.float32, 'psroipool_rois': np.float32}},
{'ctx': mx.cpu(0),
'psroipool_data': (1, 18, 14, 14),
'psroipool_rois': (2, 5),
'type_dict': {'psroipool_data': np.float16, 'psroipool_rois': np.float16}},
]

check_consistency(sym, ctx_list, grad_req={'psroipool_data': 'write',
'psroipool_rois': 'null'}, arg_params=arg_params)
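
To run just these tests locally, something along these lines should work with the nose-based runner MXNet used at the time (exact invocation depends on your environment):

nosetests -v tests/python/unittest/test_operator.py:test_psroipooling tests/python/unittest/test_operator.py:test_psroipooling_with_type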


@with_seed()
