diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu index a08339d776ff1..3b25676fb0c36 100644 --- a/paddle/fluid/operators/roi_align_op.cu +++ b/paddle/fluid/operators/roi_align_op.cu @@ -26,6 +26,7 @@ using LoDTensor = framework::LoDTensor; static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaxinumNumBlocks = 4096; +static constexpr int kROISize = 4; static inline int NumBlocks(const int N) { return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h index 29c9268d5241c..1ab5ddc83fb67 100644 --- a/paddle/fluid/operators/roi_align_op.h +++ b/paddle/fluid/operators/roi_align_op.h @@ -12,6 +12,7 @@ limitations under the License. */ #pragma once #include #include +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" @@ -22,72 +23,150 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -static constexpr int kROISize = 4; +namespace { +constexpr size_t get_offset(size_t x, size_t y, size_t width) { + return y * width + x; +} template -void PreCalcForBilinearInterpolate( - const platform::DeviceContext& ctx, const int height, const int width, - const int pooled_height, const int pooled_width, const int iy_upper, - const int ix_upper, T roi_ymin, T roi_xmin, T bin_size_h, T bin_size_w, - int roi_bin_grid_h, int roi_bin_grid_w, Tensor* pre_pos, Tensor* pre_w) { - int pre_calc_index = 0; - int* pre_pos_data = pre_pos->mutable_data(ctx.GetPlace()); - T* pre_w_data = pre_w->mutable_data(ctx.GetPlace()); - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - for (int iy = 0; iy < iy_upper; iy++) { - // calculate y of sample points - T y = roi_ymin + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); - // calculate x of samle points - for (int ix = 0; ix < ix_upper; ix++) { - T x = roi_xmin + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); +struct offsets_and_ratios { + offsets_and_ratios() = default; + offsets_and_ratios(std::size_t xy, std::size_t xY, std::size_t Xy, + std::size_t XY, T xy_ratio, T xY_ratio, T Xy_ratio, + T XY_ratio) + : xy(xy), + xY(xY), + Xy(Xy), + XY(XY), + xy_ratio(xy_ratio), + xY_ratio(xY_ratio), + Xy_ratio(Xy_ratio), + XY_ratio(XY_ratio){}; + + std::size_t xy = 0; + std::size_t xY = 0; + std::size_t Xy = 0; + std::size_t XY = 0; + T xy_ratio = 0.0f; + T xY_ratio = 0.0f; + T Xy_ratio = 0.0f; + T XY_ratio = 0.0f; +}; + +template +std::vector> get_indexes_and_ratios( + std::size_t width, std::size_t height, const T roi_width, + const T roi_height, const T roi_xmin, const T roi_ymin, + std::size_t pooled_width, std::size_t roi_bin_grid_w, + std::size_t pooled_height, std::size_t roi_bin_grid_h) { + const auto ind_num = + pooled_width * roi_bin_grid_w * pooled_height * roi_bin_grid_h; + + std::vector> interpolation_cords; + interpolation_cords.reserve(ind_num); + + const auto bin_w = roi_width / pooled_width; + const auto bin_h = roi_height / pooled_height; + + for (std::size_t py = 0; py < pooled_height; py++) { + for (std::size_t px = 0; px < pooled_width; px++) { + for (std::size_t iy = 0; iy < roi_bin_grid_h; iy++) { + // calculate x of sample points + auto y = + roi_ymin + + bin_h * (py + + static_cast(iy + .5f) / static_cast(roi_bin_grid_h)); + for (std::size_t ix = 0; ix < roi_bin_grid_w; ix++) { + // calculate x of sample points + auto x = roi_xmin + + bin_w * (px + + static_cast(ix + .5f) / + static_cast(roi_bin_grid_w)); + // deal with elements out of map if (y < -1.0 || y > height || x < -1.0 || x > width) { - for (int i = 0; i < kROISize; ++i) { - pre_pos_data[i + pre_calc_index * kROISize] = 0; - pre_w_data[i + pre_calc_index * kROISize] = 0; - } - pre_calc_index += 1; + interpolation_cords.emplace_back(); continue; } y = y <= 0 ? 0 : y; x = x <= 0 ? 0 : x; - int y_low = static_cast(y); - int x_low = static_cast(x); - int y_high; - int x_high; - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = static_cast(y_low); + std::size_t x_low_index = static_cast(x); + std::size_t x_high_index; + if (x_low_index >= width - 1) { + x_high_index = x_low_index = width - 1; + x = static_cast(x_low_index); } else { - y_high = y_low + 1; + x_high_index = x_low_index + 1; } - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = static_cast(x_low); + T x_ratio = x_high_index - x; + + std::size_t y_low_index = static_cast(y); + std::size_t y_high_index; + if (y_low_index >= height - 1) { + y_high_index = y_low_index = height - 1; + y = static_cast(y_low_index); } else { - x_high = x_low + 1; + y_high_index = y_low_index + 1; } - T ly = y - y_low, lx = x - x_low; - T hy = 1. - ly, hx = 1. - lx; - pre_pos_data[pre_calc_index * kROISize] = y_low * width + x_low; - pre_pos_data[pre_calc_index * kROISize + 1] = y_low * width + x_high; - pre_pos_data[pre_calc_index * kROISize + 2] = y_high * width + x_low; - pre_pos_data[pre_calc_index * kROISize + 3] = y_high * width + x_high; - pre_w_data[pre_calc_index * kROISize] = hy * hx; - pre_w_data[pre_calc_index * kROISize + 1] = hy * lx; - pre_w_data[pre_calc_index * kROISize + 2] = ly * hx; - pre_w_data[pre_calc_index * kROISize + 3] = ly * lx; - pre_calc_index += 1; + T y_ratio = y_high_index - y; + + auto xy = get_offset(x_low_index, y_low_index, width); + auto xY = get_offset(x_low_index, y_high_index, width); + auto Xy = get_offset(x_high_index, y_low_index, width); + auto XY = get_offset(x_high_index, y_high_index, width); + + auto xy_ratio = x_ratio * y_ratio; + auto xY_ratio = x_ratio * (1 - y_ratio); + auto Xy_ratio = (1 - x_ratio) * y_ratio; + auto XY_ratio = (1 - x_ratio) * (1 - y_ratio); + + interpolation_cords.emplace_back(xy, xY, Xy, XY, xy_ratio, xY_ratio, + Xy_ratio, XY_ratio); } } } } + return interpolation_cords; +} + +template +void interpolate(std::vector& interpolated_values, + const std::vector>& interpolation_cords, + const T* data) { + for (auto& ic : interpolation_cords) { + auto xlyl_offset = ic.xy; + auto xhyl_offset = ic.Xy; + auto xlyh_offset = ic.xY; + auto xhyh_offset = ic.XY; + + auto xlyl_ratio = ic.xy_ratio; + auto xhyl_ratio = ic.Xy_ratio; + auto xlyh_ratio = ic.xY_ratio; + auto xhyh_ratio = ic.XY_ratio; + + interpolated_values.emplace_back( + xlyl_ratio * data[xlyl_offset] + xhyl_ratio * data[xhyl_offset] + + xlyh_ratio * data[xlyh_offset] + xhyh_ratio * data[xhyh_offset]); + } +} + +template +void avg_pool(const std::vector& interpolated_values, T* output_data, + int roi_bin_grid_w, int roi_bin_grid_h, int pooled_width, + int pooled_height) { + const auto data_amount = pooled_width * pooled_height; + const auto grid_points = roi_bin_grid_w * roi_bin_grid_h; + const T count = 1.0 / grid_points; + auto val_begin = interpolated_values.cbegin(); + for (auto i = 0; i < data_amount; ++i) { + T sum = 0.0; + auto val_end = val_begin + grid_points; + sum = std::accumulate(val_begin, val_end, sum); + val_begin = val_end; + output_data[i] = sum * count; + } +} } template @@ -147,8 +226,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel { auto sampling_ratio = ctx.Attr("sampling_ratio"); auto aligned = ctx.Attr("aligned"); - auto& dev_ctx = ctx.template device_context(); - auto in_dims = in->dims(); int batch_size = in_dims[0]; int channels = in_dims[1]; @@ -209,7 +286,7 @@ class CPUROIAlignOpKernel : public framework::OpKernel { "of rois from RoIsLoD is %d", rois_num, rois_num_with_lod)); for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + for (std::size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { roi_batch_id_data[i] = n; } } @@ -231,8 +308,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel { roi_height = std::max(roi_height, static_cast(1.)); } - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); const T* batch_data = input_data + roi_batch_id * in_stride[0]; int roi_bin_grid_h = (sampling_ratio > 0) @@ -241,41 +316,20 @@ class CPUROIAlignOpKernel : public framework::OpKernel { int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); - Tensor pre_pos; - Tensor pre_w; - int pre_size = count * out_stride[1]; - pre_pos.Resize({pre_size, kROISize}); - pre_w.Resize({pre_size, kROISize}); - - PreCalcForBilinearInterpolate( - dev_ctx, height, width, pooled_height, pooled_width, roi_bin_grid_h, - roi_bin_grid_w, roi_ymin, roi_xmin, bin_size_h, bin_size_w, - roi_bin_grid_h, roi_bin_grid_w, &pre_pos, &pre_w); - const int* pre_pos_data = pre_pos.data(); - const T* pre_w_data = pre_w.data(); - for (int c = 0; c < channels; c++) { - int pre_calc_index = 0; - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - const int pool_index = ph * pooled_width + pw; - T output_val = 0; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - for (int i = 0; i < kROISize; i++) { - int pos = pre_pos_data[pre_calc_index * kROISize + i]; - T w = pre_w_data[pre_calc_index * kROISize + i]; - output_val += w * batch_data[pos]; - } - pre_calc_index += 1; - } - } - output_val /= count; - output_data[pool_index] = output_val; - } - } + + auto interpolation_cords = get_indexes_and_ratios( + width, height, roi_width, roi_height, roi_xmin, roi_ymin, + pooled_width, roi_bin_grid_w, pooled_height, roi_bin_grid_h); + + std::vector interpolated_values; + interpolated_values.reserve(interpolation_cords.size()); + for (auto channel = 0; channel < channels; ++channel) { + interpolate(interpolated_values, interpolation_cords, batch_data); + avg_pool(interpolated_values, output_data, roi_bin_grid_w, + roi_bin_grid_h, pooled_width, pooled_height); batch_data += in_stride[1]; output_data += out_stride[1]; + interpolated_values.clear(); } rois_data += roi_stride[0]; } @@ -328,7 +382,7 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel { auto rois_lod = rois->lod().back(); rois_batch_size = rois_lod.size() - 1; for (int n = 0; n < rois_batch_size; ++n) { - for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + for (std::size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { roi_batch_id_data[i] = n; } }