Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update to index_t
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed May 24, 2019
1 parent bdca339 commit 83d6801
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 177 deletions.
10 changes: 5 additions & 5 deletions src/operator/contrib/deformable_psroi_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ namespace deformablepsroipool {
struct DeformablePSROIPoolingParam : public dmlc::Parameter<DeformablePSROIPoolingParam> {
// mxnet::TShape pooled_size;
float spatial_scale;
int output_dim;
int group_size;
int pooled_size;
int part_size;
int sample_per_part;
index_t output_dim;
index_t group_size;
index_t pooled_size;
index_t part_size;
index_t sample_per_part;
float trans_std;
bool no_trans;
DMLC_DECLARE_PARAMETER(DeformablePSROIPoolingParam) {
Expand Down
164 changes: 82 additions & 82 deletions src/operator/contrib/deformable_psroi_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ namespace mshadow {
template <typename DType>
inline DType bilinear_interp_cpu(const DType* data,
const DType x, const DType y,
const int width, const int height) {
int x1 = floor(x);
int x2 = ceil(x);
int y1 = floor(y);
int y2 = ceil(y);
const index_t width, const index_t height) {
index_t x1 = floor(x);
index_t x2 = ceil(x);
index_t y1 = floor(y);
index_t y2 = ceil(y);
DType dist_x = static_cast<DType>(x - x1);
DType dist_y = static_cast<DType>(y - y1);
DType value11 = data[y1 * width + x1];
Expand All @@ -59,28 +59,28 @@ namespace mshadow {
}

template <typename DType>
inline void DeformablePSROIPoolForwardCPU(const int count, const DType* bottom_data,
const DType spatial_scale, const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
inline void DeformablePSROIPoolForwardCPU(const index_t count, const DType* bottom_data,
const DType spatial_scale, const index_t channels,
const index_t height, const index_t width,
const index_t pooled_height, const index_t pooled_width,
const DType* bottom_rois, const DType* bottom_trans,
const bool no_trans, const DType trans_std,
const int sample_per_part, const int output_dim,
const int group_size, const int part_size,
const int num_classes, const int channels_each_class,
const index_t sample_per_part, const index_t output_dim,
const index_t group_size, const index_t part_size,
const index_t num_classes, const index_t channels_each_class,
DType* top_data, DType* top_count) {
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(omp_threads)
for (int index = 0; index < count; index++) {
for (index_t index = 0; index < count; index++) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
index_t pw = index % pooled_width;
index_t ph = (index / pooled_width) % pooled_height;
index_t ctop = (index / pooled_width / pooled_height) % output_dim;
index_t n = index / pooled_width / pooled_height / output_dim;

// [start, end) interval for spatial sampling
const DType* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
index_t roi_batch_ind = offset_bottom_rois[0];
DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
Expand All @@ -97,9 +97,9 @@ namespace mshadow {
DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

int part_h = floor(static_cast<DType>(ph) / pooled_height*part_size);
int part_w = floor(static_cast<DType>(pw) / pooled_width*part_size);
int class_id = ctop / channels_each_class;
index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
index_t class_id = ctop / channels_each_class;
DType trans_x = no_trans ? static_cast<DType>(0) :
bottom_trans[(((n * num_classes + class_id) * 2)
* part_size + part_h)
Expand All @@ -115,15 +115,15 @@ namespace mshadow {
hstart += trans_y * roi_height;

DType sum = 0;
int count = 0;
int gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
int gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
index_t count = 0;
index_t gw = floor(static_cast<DType>(pw) * group_size / pooled_width);
index_t gh = floor(static_cast<DType>(ph) * group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);

const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
for (int ih = 0; ih < sample_per_part; ih++) {
for (int iw = 0; iw < sample_per_part; iw++) {
for (index_t ih = 0; ih < sample_per_part; ih++) {
for (index_t iw = 0; iw < sample_per_part; iw++) {
DType w = wstart + iw * sub_bin_size_w;
DType h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
Expand All @@ -132,7 +132,7 @@ namespace mshadow {
}
w = min(max(w, static_cast<DType>(0)), static_cast<DType>(width - 1));
h = min(max(h, static_cast<DType>(0)), static_cast<DType>(height - 1));
int c = (ctop * group_size + gh) * group_size + gw;
index_t c = (ctop * group_size + gh) * group_size + gw;
DType val = bilinear_interp_cpu(offset_bottom_data + c * height * width,
w, h, width, height);
sum += val;
Expand All @@ -151,22 +151,22 @@ namespace mshadow {
const Tensor<cpu, 4, DType> &trans,
const Tensor<cpu, 4, DType> &top_count,
const bool no_trans, const float spatial_scale,
const int output_dim, const int group_size,
const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std) {
const index_t output_dim, const index_t group_size,
const index_t pooled_size, const index_t part_size,
const index_t sample_per_part, const float trans_std) {
const DType *bottom_data = data.dptr_;
const DType *bottom_rois = bbox.dptr_;
const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
DType *top_data = out.dptr_;
DType *top_count_data = top_count.dptr_;
const int count = out.shape_.Size();
const int channels = data.size(1);
const int height = data.size(2);
const int width = data.size(3);
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int num_classes = no_trans ? 1 : trans.size(1) / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
const index_t count = out.shape_.Size();
const index_t channels = data.size(1);
const index_t height = data.size(2);
const index_t width = data.size(3);
const index_t pooled_height = pooled_size;
const index_t pooled_width = pooled_size;
const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;
const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
DeformablePSROIPoolForwardCPU<DType>(count, bottom_data, spatial_scale,
channels, height, width,
pooled_height, pooled_width,
Expand All @@ -179,35 +179,35 @@ namespace mshadow {
}

template <typename DType>
inline void DeformablePSROIPoolBackwardAccCPU(const int count, const DType* top_diff,
const DType* top_count, const int num_rois,
const DType spatial_scale, const int channels,
const int height, const int width,
const int pooled_height,
const int pooled_width,
const int output_dim,
inline void DeformablePSROIPoolBackwardAccCPU(const index_t count, const DType* top_diff,
const DType* top_count, const index_t num_rois,
const DType spatial_scale, const index_t channels,
const index_t height, const index_t width,
const index_t pooled_height,
const index_t pooled_width,
const index_t output_dim,
DType* bottom_data_diff,
DType* bottom_trans_diff,
const DType* bottom_data,
const DType* bottom_rois,
const DType* bottom_trans,
const bool no_trans,
const DType trans_std,
const int sample_per_part,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class) {
for (int index = 0; index < count; index++) {
const index_t sample_per_part,
const index_t group_size,
const index_t part_size,
const index_t num_classes,
const index_t channels_each_class) {
for (index_t index = 0; index < count; index++) {
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
index_t pw = index % pooled_width;
index_t ph = (index / pooled_width) % pooled_height;
index_t ctop = (index / pooled_width / pooled_height) % output_dim;
index_t n = index / pooled_width / pooled_height / output_dim;

// [start, end) interval for spatial sampling
const DType* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
index_t roi_batch_ind = offset_bottom_rois[0];
DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
Expand All @@ -224,9 +224,9 @@ namespace mshadow {
DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);

int part_h = floor(static_cast<DType>(ph) / pooled_height*part_size);
int part_w = floor(static_cast<DType>(pw) / pooled_width*part_size);
int class_id = ctop / channels_each_class;
index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
index_t class_id = ctop / channels_each_class;
DType trans_x = no_trans ? static_cast<DType>(0) :
bottom_trans[(((n * num_classes + class_id) * 2)
* part_size + part_h)
Expand All @@ -247,13 +247,13 @@ namespace mshadow {
DType diff_val = top_diff[index] / top_count[index];
const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
int gw = floor(static_cast<DType>(pw)* group_size / pooled_width);
int gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
index_t gw = floor(static_cast<DType>(pw)* group_size / pooled_width);
index_t gh = floor(static_cast<DType>(ph)* group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);

for (int ih = 0; ih < sample_per_part; ih++) {
for (int iw = 0; iw < sample_per_part; iw++) {
for (index_t ih = 0; ih < sample_per_part; ih++) {
for (index_t iw = 0; iw < sample_per_part; iw++) {
DType w = wstart + iw * sub_bin_size_w;
DType h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
Expand All @@ -262,18 +262,18 @@ namespace mshadow {
}
w = min(max(w, static_cast<DType>(0)), static_cast<DType>(width - 1));
h = min(max(h, static_cast<DType>(0)), static_cast<DType>(height - 1));
int c = (ctop * group_size + gh) * group_size + gw;
index_t c = (ctop * group_size + gh) * group_size + gw;
// backward on feature
int x0 = floor(w);
int x1 = ceil(w);
int y0 = floor(h);
int y1 = ceil(h);
index_t x0 = floor(w);
index_t x1 = ceil(w);
index_t y0 = floor(h);
index_t y1 = ceil(h);
DType dist_x = w - x0, dist_y = h - y0;
DType q00 = (1 - dist_x) * (1 - dist_y);
DType q01 = (1 - dist_x) * dist_y;
DType q10 = dist_x * (1 - dist_y);
DType q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
index_t bottom_index_base = c * height * width;
offset_bottom_data_diff[bottom_index_base + y0 * width + x0] += q00 * diff_val;
offset_bottom_data_diff[bottom_index_base + y1 * width + x0] += q01 * diff_val;
offset_bottom_data_diff[bottom_index_base + y0 * width + x1] += q10 * diff_val;
Expand All @@ -291,7 +291,7 @@ namespace mshadow {
DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x);
diff_y *= trans_std * diff_val * roi_height;

int offset_trans_diff = (((n * num_classes + class_id) * 2)
index_t offset_trans_diff = (((n * num_classes + class_id) * 2)
* part_size + part_h) * part_size + part_w;
bottom_trans_diff[offset_trans_diff] += diff_x;
bottom_trans_diff[offset_trans_diff + part_size * part_size] += diff_y;
Expand All @@ -309,25 +309,25 @@ namespace mshadow {
const Tensor<cpu, 4, DType> &trans,
const Tensor<cpu, 4, DType> &top_count,
const bool no_trans, const float spatial_scale,
const int output_dim, const int group_size,
const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std) {
const index_t output_dim, const index_t group_size,
const index_t pooled_size, const index_t part_size,
const index_t sample_per_part, const float trans_std) {
const DType *top_diff = out_grad.dptr_;
const DType *bottom_data = data.dptr_;
const DType *bottom_rois = bbox.dptr_;
const DType *bottom_trans = no_trans ? nullptr : trans.dptr_;
DType *bottom_data_diff = in_grad.dptr_;
DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_;
const DType *top_count_data = top_count.dptr_;
const int count = out_grad.shape_.Size();
const int num_rois = bbox.size(0);
const int channels = in_grad.size(1);
const int height = in_grad.size(2);
const int width = in_grad.size(3);
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
const index_t count = out_grad.shape_.Size();
const index_t num_rois = bbox.size(0);
const index_t channels = in_grad.size(1);
const index_t height = in_grad.size(2);
const index_t width = in_grad.size(3);
const index_t pooled_height = pooled_size;
const index_t pooled_width = pooled_size;
const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2;
const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
DeformablePSROIPoolBackwardAccCPU<DType>(count, top_diff, top_count_data, num_rois,
spatial_scale, channels, height, width,
pooled_height, pooled_width, output_dim,
Expand Down
Loading

0 comments on commit 83d6801

Please sign in to comment.