diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 2da95050223e..cf322c8bd8cf 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -18,6 +18,90 @@
 namespace mxnet {
 namespace op {
+
+/*!
+* \brief structure for numerical tuple input
+* \tparam VType data type of param
+*/
+template<typename VType>
+struct NumericalParam {
+  NumericalParam() {}
+  explicit NumericalParam(VType *begin, VType *end) {
+    int32_t size = static_cast<int32_t>(end - begin);
+    info.resize(size);
+    for (int i = 0; i < size; ++i) {
+      info[i] = *(begin + i);
+    }
+  }
+  inline size_t ndim() const {
+    return info.size();
+  }
+  std::vector<VType> info;
+};
+
+template<typename VType>
+inline std::istream &operator>>(std::istream &is, NumericalParam<VType> &param) {
+  while (true) {
+    char ch = is.get();
+    if (ch == '(') break;
+    if (!isspace(ch)) {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  VType idx;
+  std::vector<VType> tmp;
+  // deal with empty case
+  size_t pos = is.tellg();
+  char ch = is.get();
+  if (ch == ')') {
+    param.info = tmp;
+    return is;
+  }
+  is.seekg(pos);
+  // finish deal
+  while (is >> idx) {
+    tmp.push_back(idx);
+    char ch;
+    do {
+      ch = is.get();
+    } while (isspace(ch));
+    if (ch == ',') {
+      while (true) {
+        ch = is.peek();
+        if (isspace(ch)) {
+          is.get(); continue;
+        }
+        if (ch == ')') {
+          is.get(); break;
+        }
+        break;
+      }
+      if (ch == ')') break;
+    } else if (ch == ')') {
+      break;
+    } else {
+      is.setstate(std::ios::failbit);
+      return is;
+    }
+  }
+  param.info = tmp;
+  return is;
+}
+
+template<typename VType>
+inline std::ostream &operator<<(std::ostream &os, const NumericalParam<VType> &param) {
+  os << '(';
+  for (index_t i = 0; i < param.info.size(); ++i) {
+    if (i != 0) os << ',';
+    os << param.info[i];
+  }
+  // python style tuple
+  if (param.info.size() == 1) os << ',';
+  os << ')';
+  return os;
+}
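+
+// Accepted text form for the two stream operators above: a parenthesized,
+// comma-separated tuple such as "(8,16,32)". "()" parses to an empty tuple,
+// and a single-element tuple is printed with a trailing comma, "(8,)",
+// matching python's repr.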
 /*!
  * \brief assign the expression to out according to request
  * \param out the data to be assigned
diff --git a/src/operator/proposal-inl.h b/src/operator/proposal-inl.h
new file mode 100644
index 000000000000..a5e10c6a7e51
--- /dev/null
+++ b/src/operator/proposal-inl.h
@@ -0,0 +1,296 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file proposal-inl.h
+ * \brief Proposal Operator
+ * \author Piotr Teterwak, Jian Guo
+*/
+#ifndef MXNET_OPERATOR_PROPOSAL_INL_H_
+#define MXNET_OPERATOR_PROPOSAL_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include <cmath>
+#include <cstring>
+#include <ctime>
+#include <iostream>
+#include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./native_op-inl.h"
+#include "./rcnn_utils.h"
+
+
+namespace mxnet {
+namespace op {
+
+namespace proposal {
+enum ProposalOpType {kTrain, kTest};
+enum ProposalOpInputs {kClsProb, kBBoxPred, kImInfo};
+enum ProposalOpOutputs {kOut, kScore};
+enum ProposalForwardResource {kTempResource};
+}  // namespace proposal
+
+
+struct ProposalParam : public dmlc::Parameter<ProposalParam> {
+  int rpn_pre_nms_top_n;
+  int rpn_post_nms_top_n;
+  float threshold;
+  int rpn_min_size;
+  NumericalParam<float> scales;
+  NumericalParam<float> ratios;
+  int feature_stride;
+  bool output_score;
+  DMLC_DECLARE_PARAMETER(ProposalParam) {
+    float tmp[] = {0, 0, 0, 0};
+    DMLC_DECLARE_FIELD(rpn_pre_nms_top_n).set_default(6000)
+    .describe("Number of top scoring boxes to keep before applying "
+              "non-maximum suppression to RPN proposals");
+    DMLC_DECLARE_FIELD(rpn_post_nms_top_n).set_default(300)
+    .describe("Number of top scoring boxes to keep after applying "
+              "non-maximum suppression to RPN proposals");
+    DMLC_DECLARE_FIELD(threshold).set_default(0.7)
+    .describe("IoU threshold for non-maximum suppression; boxes that overlap "
+              "a kept box by at least this much are suppressed");
+    DMLC_DECLARE_FIELD(rpn_min_size).set_default(16)
+    .describe("Minimum height or width of a proposal");
+    tmp[0] = 4.0f; tmp[1] = 8.0f; tmp[2] = 16.0f; tmp[3] = 32.0f;
+    DMLC_DECLARE_FIELD(scales).set_default(NumericalParam<float>(tmp, tmp + 4))
+    .describe("Used to generate anchor windows by enumerating scales");
+    tmp[0] = 0.5f; tmp[1] = 1.0f; tmp[2] = 2.0f;
+    DMLC_DECLARE_FIELD(ratios).set_default(NumericalParam<float>(tmp, tmp + 3))
+    .describe("Used to generate anchor windows by enumerating ratios");
+    DMLC_DECLARE_FIELD(feature_stride).set_default(16)
+    .describe("The size of the receptive field of each unit in the convolution layer "
+              "of the rpn, for example the product of all strides prior to this layer");
+    DMLC_DECLARE_FIELD(output_score).set_default(false)
+    .describe("Add score to outputs");
+  }
+};
+
+template<typename xpu>
+class ProposalOp : public Operator{
+ public:
+  explicit ProposalOp(ProposalParam param) {
+    this->param_ = param;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 3);
+    CHECK_EQ(out_data.size(), 2);
+    CHECK_GT(req.size(), 1);
+    CHECK_EQ(req[proposal::kOut], kWriteTo);
+
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+
+    Shape<4> scores_shape = Shape4(in_data[proposal::kClsProb].shape_[0],
+                                   in_data[proposal::kClsProb].shape_[1] / 2,
+                                   in_data[proposal::kClsProb].shape_[2],
+                                   in_data[proposal::kClsProb].shape_[3]);
+    // skip the background half of cls_prob: the foreground scores start one
+    // (b, A, h, w) block past the beginning of the buffer
+    real_t* foreground_score_ptr = reinterpret_cast<real_t *>(
+        in_data[proposal::kClsProb].dptr_) + scores_shape.Size();
+    Tensor<cpu, 4> scores = Tensor<cpu, 4>(foreground_score_ptr, scores_shape);
+    Tensor<cpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<cpu, 4, real_t>(s);
+    Tensor<cpu, 2> im_info = in_data[proposal::kImInfo].get<cpu, 2, real_t>(s);
+
+    Tensor<cpu, 2> out = out_data[proposal::kOut].get<cpu, 2, real_t>(s);
+    Tensor<cpu, 2> out_score = out_data[proposal::kScore].get<cpu, 2, real_t>(s);
+
+    index_t num_anchors = in_data[proposal::kClsProb].shape_[1] / 2;
+    index_t height = scores.size(2);
+    index_t width = scores.size(3);
+    index_t count = num_anchors * height * width;
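+
+    // Proposal pipeline: (1) tile the base anchors over the feature map,
+    // (2) apply the predicted bbox deltas, (3) clip to the image border,
+    // (4) keep the rpn_pre_nms_top_n highest-scoring boxes, (5) run NMS and
+    // emit at most rpn_post_nms_top_n rois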
+    Tensor<cpu, 2> workspace_proposals =
+        ctx.requested[proposal::kTempResource].get_space<cpu>(
+            Shape2(count, 5), s);
+    Tensor<cpu, 2> workspace_ordered_proposals =
+        ctx.requested[proposal::kTempResource].get_space<cpu>(
+            Shape2(param_.rpn_pre_nms_top_n, 5), s);
+    Tensor<cpu, 2> workspace_pre_nms =
+        ctx.requested[proposal::kTempResource].get_space<cpu>(
+            Shape2(2, count), s);
+    Tensor<cpu, 2> workspace_nms =
+        ctx.requested[proposal::kTempResource].get_space<cpu>(
+            Shape2(3, param_.rpn_pre_nms_top_n), s);
+
+    // Generate anchors
+    std::vector<float> base_anchor(4);
+    base_anchor[0] = 0.0;
+    base_anchor[1] = 0.0;
+    base_anchor[2] = param_.feature_stride - 1.0;
+    base_anchor[3] = param_.feature_stride - 1.0;
+    CHECK_EQ(num_anchors, param_.ratios.info.size() * param_.scales.info.size());
+    std::vector<float> anchors;
+    utils::GenerateAnchors(base_anchor,
+                           param_.ratios.info,
+                           param_.scales.info,
+                           anchors);
+    std::memcpy(workspace_proposals.dptr_, &anchors[0], sizeof(float) * anchors.size());
+
+    // Enumerate all shifted anchors
+    for (index_t i = 0; i < num_anchors; ++i) {
+      for (index_t j = 0; j < height; ++j) {
+        for (index_t k = 0; k < width; ++k) {
+          index_t index = j * (width * num_anchors) + k * (num_anchors) + i;
+          workspace_proposals[index][0] = workspace_proposals[i][0] + k * param_.feature_stride;
+          workspace_proposals[index][1] = workspace_proposals[i][1] + j * param_.feature_stride;
+          workspace_proposals[index][2] = workspace_proposals[i][2] + k * param_.feature_stride;
+          workspace_proposals[index][3] = workspace_proposals[i][3] + j * param_.feature_stride;
+          workspace_proposals[index][4] = scores[0][i][j][k];
+        }
+      }
+    }
+
+    utils::BBoxTransformInv(workspace_proposals, bbox_deltas, &(workspace_proposals));
+    utils::ClipBoxes(Shape2(im_info[0][0], im_info[0][1]), &(workspace_proposals));
+
+    Tensor<cpu, 1> score = workspace_pre_nms[0];
+    Tensor<cpu, 1> order = workspace_pre_nms[1];
+
+    utils::CopyScore(workspace_proposals,
+                     score,
+                     order);
+    utils::ReverseArgsort(score,
+                          order);
+    utils::ReorderProposals(workspace_proposals,
+                            order,
+                            param_.rpn_pre_nms_top_n,
+                            workspace_ordered_proposals);
+
+    real_t scale = im_info[0][2];
+    index_t out_size = 0;
+    Tensor<cpu, 1> area = workspace_nms[0];
+    Tensor<cpu, 1> suppressed = workspace_nms[1];
+    Tensor<cpu, 1> keep = workspace_nms[2];
+
+    utils::NonMaximumSuppression(workspace_ordered_proposals,
+                                 param_.threshold,
+                                 param_.rpn_min_size * scale,
+                                 param_.rpn_post_nms_top_n,
+                                 area,
+                                 suppressed,
+                                 keep,
+                                 &out_size);
+
+    // fill in output rois
+    for (index_t i = 0; i < out.size(0); ++i) {
+      index_t index = keep[i];
+      // batch index 0
+      out[i][0] = 0;
+      for (index_t j = 0; j < 4; ++j) {
+        if (i < out_size) {
+          out[i][j + 1] = workspace_ordered_proposals[index][j];
+        } else {
+          out[i][j + 1] = 0;
+        }
+      }
+    }
+
+    // fill in output score
+    for (index_t i = 0; i < out_score.size(0); i++) {
+      index_t index = keep[i];
+      if (i < out_size) {
+        out_score[i][0] = workspace_ordered_proposals[index][4];
+      } else {
+        out_score[i][0] = 0;
+      }
+    }
+  }
+
+ private:
+  ProposalParam param_;
+};  // class ProposalOp
+
+template<typename xpu>
+Operator *CreateOp(ProposalParam param);
+
+
+#if DMLC_USE_CXX11
+class ProposalProp : public OperatorProperty {
+ public:
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    param_.Init(kwargs);
+  }
+
+  std::map<std::string, std::string> GetParams() const override {
+    return param_.__DICT__();
+  }
+
+  bool InferShape(std::vector<TShape> *in_shape,
+                  std::vector<TShape> *out_shape,
+                  std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
+    CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]";
+    const TShape &dshape = in_shape->at(proposal::kClsProb);
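+    // cls_prob is (batch, 2 * num_anchors, height, width), so bbox_pred
+    // must be (batch, 4 * num_anchors, height, width)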
+    if (dshape.ndim() == 0) return false;
+    Shape<4> bbox_pred_shape;
+    bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]);
+    SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred,
+                       bbox_pred_shape);
+    Shape<2> im_info_shape;
+    im_info_shape = Shape2(1, 3);
+    SHAPE_ASSIGN_CHECK(*in_shape, proposal::kImInfo, im_info_shape);
+    out_shape->clear();
+    // output
+    out_shape->push_back(Shape2(param_.rpn_post_nms_top_n, 5));
+    // score
+    out_shape->push_back(Shape2(param_.rpn_post_nms_top_n, 1));
+    return true;
+  }
+
+  OperatorProperty* Copy() const override {
+    auto ptr = new ProposalProp();
+    ptr->param_ = param_;
+    return ptr;
+  }
+
+  std::string TypeString() const override {
+    return "Proposal";
+  }
+
+  std::vector<ResourceRequest> ForwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
+  std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const override {
+    return {};
+  }
+
+  int NumVisibleOutputs() const override {
+    if (param_.output_score) {
+      return 2;
+    } else {
+      return 1;
+    }
+  }
+
+  int NumOutputs() const override {
+    return 2;
+  }
+
+  std::vector<std::string> ListArguments() const override {
+    return {"cls_prob", "bbox_pred", "im_info"};
+  }
+
+  std::vector<std::string> ListOutputs() const override {
+    return {"output", "score"};
+  }
+
+  Operator* CreateOperator(Context ctx) const override;
+
+ private:
+  ProposalParam param_;
+};  // class ProposalProp
+
+#endif  // DMLC_USE_CXX11
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_PROPOSAL_INL_H_
diff --git a/src/operator/proposal.cc b/src/operator/proposal.cc
new file mode 100644
index 000000000000..7c357c6ab610
--- /dev/null
+++ b/src/operator/proposal.cc
@@ -0,0 +1,32 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file proposal.cc
+ * \brief
+ * \author Piotr Teterwak
+*/
+
+#include "./proposal-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>(ProposalParam param) {
+  return new ProposalOp<cpu>(param);
+}
+
+Operator* ProposalProp::CreateOperator(Context ctx) const {
+  DO_BIND_DISPATCH(CreateOp, param_);
+}
+
+DMLC_REGISTER_PARAMETER(ProposalParam);
+
+MXNET_REGISTER_OP_PROPERTY(Proposal, ProposalProp)
+.describe("Generate region proposals via RPN")
+.add_argument("cls_prob", "Symbol", "Score of how likely a proposal is an object.")
+.add_argument("bbox_pred", "Symbol", "BBox predicted deltas from anchors for proposals")
+.add_argument("im_info", "Symbol", "Image size and scale.")
+.add_arguments(ProposalParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
+
diff --git a/src/operator/proposal.cu b/src/operator/proposal.cu
new file mode 100644
index 000000000000..3da65b15ef91
--- /dev/null
+++ b/src/operator/proposal.cu
@@ -0,0 +1,476 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file proposal.cu
+ * \brief Proposal Operator
+ * \author Shaoqing Ren, Jian Guo
+*/
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mshadow/tensor.h>
+#include <mshadow/cuda/tensor_gpu-inl.cuh>
+#include <thrust/sort.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include <cmath>
+#include <cstring>
+#include <iostream>
+#include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./native_op-inl.h"
+#include "./rcnn_utils.h"
+#include "./proposal-inl.h"
+
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+#define FRCNN_CUDA_CHECK(condition) \
+  /* Code block avoids redefinition of cudaError_t error */ \
+  do { \
+    cudaError_t error = condition; \
+    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
+} while (0)
+
+namespace mshadow {
+namespace cuda {
+
+// scores are (b, anchor, h, w)
+// workspace_proposals are (h * w * anchor, 5)
+// w defines "x" and h defines "y"
+// count should be total anchors numbers, h * w * anchors
+template<typename Dtype>
+__global__ void ProposalGridKernel(const int count,
+                                   const int num_anchors,
+                                   const int height,
+                                   const int width,
+                                   const int feature_stride,
+                                   const Dtype* scores,
+                                   Dtype* workspace_proposals) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    int a = index % num_anchors;
+    int w = (index / num_anchors) % width;
+    int h = index / num_anchors / width;
+
+    workspace_proposals[index * 5 + 0] = workspace_proposals[a * 5 + 0] + w * feature_stride;
+    workspace_proposals[index * 5 + 1] = workspace_proposals[a * 5 + 1] + h * feature_stride;
+    workspace_proposals[index * 5 + 2] = workspace_proposals[a * 5 + 2] + w * feature_stride;
+    workspace_proposals[index * 5 + 3] = workspace_proposals[a * 5 + 3] + h * feature_stride;
+    workspace_proposals[index * 5 + 4] = scores[(a * height + h) * width + w];
+  }
+}
+
+// boxes are (h * w * anchor, 5)
+// deltas are (b, 4 * anchor, h, w)
+// out_pred_boxes are (h * w * anchor, 5)
+// count should be total anchors numbers, h * w * anchors
+// in-place write: boxes and out_pred_boxes are the same location
+template<typename Dtype>
+__global__ void BBoxPredKernel(const int count,
+                               const int num_anchors,
+                               const int feat_height,
+                               const int feat_width,
+                               const float im_height,
+                               const float im_width,
+                               const Dtype* boxes,
+                               const Dtype* deltas,
+                               Dtype* out_pred_boxes) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    int a = index % num_anchors;
+    int w = (index / num_anchors) % feat_width;
+    int h = index / num_anchors / feat_width;
+
+    float width = boxes[index * 5 + 2] - boxes[index * 5 + 0] + 1.0f;
+    float height = boxes[index * 5 + 3] - boxes[index * 5 + 1] + 1.0f;
+    float ctr_x = boxes[index * 5 + 0] + 0.5f * (width - 1.0f);
+    float ctr_y = boxes[index * 5 + 1] + 0.5f * (height - 1.0f);
+
+    float dx = deltas[((a * 4) * feat_height + h) * feat_width + w];
+    float dy = deltas[((a * 4 + 1) * feat_height + h) * feat_width + w];
+    float dw = deltas[((a * 4 + 2) * feat_height + h) * feat_width + w];
+    float dh = deltas[((a * 4 + 3) * feat_height + h) * feat_width + w];
+
+    float pred_ctr_x = dx * width + ctr_x;
+    float pred_ctr_y = dy * height + ctr_y;
+    float pred_w = exp(dw) * width;
+    float pred_h = exp(dh) * height;
+
+    float pred_x1 = pred_ctr_x - 0.5f * (pred_w - 1.0f);
+    float pred_y1 = pred_ctr_y - 0.5f * (pred_h - 1.0f);
+    float pred_x2 = pred_ctr_x + 0.5f * (pred_w - 1.0f);
+    float pred_y2 = pred_ctr_y + 0.5f * (pred_h - 1.0f);
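+
+    // Clip the predicted box to the image: both corners are forced into
+    // [0, im_width - 1] x [0, im_height - 1]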
+    pred_x1 = max(min(pred_x1, im_width - 1.0f), 0.0f);
+    pred_y1 = max(min(pred_y1, im_height - 1.0f), 0.0f);
+    pred_x2 = max(min(pred_x2, im_width - 1.0f), 0.0f);
+    pred_y2 = max(min(pred_y2, im_height - 1.0f), 0.0f);
+
+    out_pred_boxes[index * 5 + 0] = pred_x1;
+    out_pred_boxes[index * 5 + 1] = pred_y1;
+    out_pred_boxes[index * 5 + 2] = pred_x2;
+    out_pred_boxes[index * 5 + 3] = pred_y2;
+  }
+}
+
+// filter out boxes with a side shorter than rpn_min_size
+// filter: set score to zero
+// dets (n, 5)
+template<typename Dtype>
+__global__ void FilterBoxKernel(const int count,
+                                const float min_size,
+                                Dtype* dets) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    float iw = dets[index * 5 + 2] - dets[index * 5 + 0] + 1.0f;
+    float ih = dets[index * 5 + 3] - dets[index * 5 + 1] + 1.0f;
+    if (iw < min_size || ih < min_size) {
+      dets[index * 5 + 4] = 0.0f;
+    }
+  }
+}
+
+// copy score and init order
+// dets (n, 5); score (n, ); order (n, )
+// count should be n (total anchors or proposals)
+template<typename Dtype>
+__global__ void CopyScoreKernel(const int count,
+                                const Dtype* dets,
+                                Dtype* score,
+                                Dtype* order) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    score[index] = dets[index * 5 + 4];
+    order[index] = index;
+  }
+}
+
+// reorder proposals according to order and keep the top_n proposals
+// prev_dets (n, 5); order (n, ); dets (n, 5)
+// count should be output anchor numbers (top_n)
+template<typename Dtype>
+__global__ void ReorderProposalsKernel(const int count,
+                                       const Dtype* prev_dets,
+                                       const Dtype* order,
+                                       const int top_n,
+                                       Dtype* dets) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    if (index >= top_n) {
+      return;
+    }
+    const int order_i = order[index];
+    for (int j = 0; j < 5; j++) {
+      dets[index * 5 + j] = prev_dets[order_i * 5 + j];
+    }
+  }
+}
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+  float left = max(a[0], b[0]), right = min(a[2], b[2]);
+  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float interS = width * height;
+  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  return interS / (Sa + Sb - interS);
+}
+
+__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+                           const float *dev_boxes, unsigned long long *dev_mask) {
+  const int threadsPerBlock = sizeof(unsigned long long) * 8;
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+  const int col_size =
+        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  __shared__ float block_boxes[threadsPerBlock * 5];
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 5 + 0] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
+    block_boxes[threadIdx.x * 5 + 1] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
+    block_boxes[threadIdx.x * 5 + 2] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
+    block_boxes[threadIdx.x * 5 + 3] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
+    block_boxes[threadIdx.x * 5 + 4] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
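+
+  // Each thread owns one box of the row block and records, in a 64-bit
+  // bitmask, which boxes of the column block overlap it above the threshold;
+  // the per-block masks are reduced sequentially on the host in _nms below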
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+    const float *cur_box = dev_boxes + cur_box_idx * 5;
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
+    dev_mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+void _nms(const mshadow::Tensor<gpu, 2>& boxes,
+          const float nms_overlap_thresh,
+          int *keep,
+          int *num_out) {
+  const int threadsPerBlock = sizeof(unsigned long long) * 8;
+  const int boxes_num = boxes.size(0);
+  const int boxes_dim = boxes.size(1);
+
+  float* boxes_dev = boxes.dptr_;
+  unsigned long long* mask_dev = NULL;
+
+  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
+  FRCNN_CUDA_CHECK(cudaMalloc(&mask_dev,
+                              boxes_num * col_blocks * sizeof(unsigned long long)));
+
+  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
+              DIVUP(boxes_num, threadsPerBlock));
+  dim3 threads(threadsPerBlock);
+  nms_kernel<<<blocks, threads>>>(boxes_num,
+                                  nms_overlap_thresh,
+                                  boxes_dev,
+                                  mask_dev);
+  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
+  FRCNN_CUDA_CHECK(cudaMemcpy(&mask_host[0],
+                              mask_dev,
+                              sizeof(unsigned long long) * boxes_num * col_blocks,
+                              cudaMemcpyDeviceToHost));
+
+  std::vector<unsigned long long> remv(col_blocks);
+  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+  int num_to_keep = 0;
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / threadsPerBlock;
+    int inblock = i % threadsPerBlock;
+
+    if (!(remv[nblock] & (1ULL << inblock))) {
+      keep[num_to_keep++] = i;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv[j] |= p[j];
+      }
+    }
+  }
+  *num_out = num_to_keep;
+
+  FRCNN_CUDA_CHECK(cudaFree(mask_dev));
+}
+
+// copy proposals to output
+// dets (top_n, 5); keep (top_n, ); out (top_n, )
+// count should be top_n (total anchors or proposals)
+template<typename Dtype>
+__global__ void PrepareOutput(const int count,
+                              const Dtype* dets,
+                              const int* keep,
+                              const int out_size,
+                              Dtype* out,
+                              Dtype* score) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < count;
+       index += blockDim.x * gridDim.x) {
+    out[index * 5] = 0;
+    if (index < out_size) {
+      int keep_i = keep[index];
+      for (int j = 0; j < 4; ++j) {
+        out[index * 5 + j + 1] = dets[keep_i * 5 + j];
+      }
+      score[index] = dets[keep_i * 5 + 4];
+    } else {
+      for (int j = 0; j < 4; ++j) {
+        out[index * 5 + j + 1] = 0;
+      }
+      score[index] = 0;
+    }
+  }
+}
+
+}  // namespace cuda
+}  // namespace mshadow
+
+namespace mxnet {
+namespace op {
+
+template<typename xpu>
+class ProposalGPUOp : public Operator{
+ public:
+  explicit ProposalGPUOp(ProposalParam param) {
+    this->param_ = param;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    using namespace mshadow::cuda;
+    CHECK_EQ(in_data.size(), 3);
+    CHECK_EQ(out_data.size(), 2);
+    CHECK_GT(req.size(), 1);
+    CHECK_EQ(req[proposal::kOut], kWriteTo);
+
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+
+    Shape<4> fg_scores_shape = Shape4(in_data[proposal::kClsProb].shape_[0],
+                                      in_data[proposal::kClsProb].shape_[1] / 2,
+                                      in_data[proposal::kClsProb].shape_[2],
+                                      in_data[proposal::kClsProb].shape_[3]);
+    real_t* foreground_score_ptr = reinterpret_cast<real_t *>(
+        in_data[proposal::kClsProb].dptr_) + fg_scores_shape.Size();
+    Tensor<gpu, 4> scores = Tensor<gpu, 4>(foreground_score_ptr, fg_scores_shape);
+    Tensor<gpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<gpu, 4, real_t>(s);
+    Tensor<gpu, 2> im_info = in_data[proposal::kImInfo].get<gpu, 2, real_t>(s);
+
+    Tensor<gpu, 2> out = out_data[proposal::kOut].get<gpu, 2, real_t>(s);
+    Tensor<gpu, 2> out_score = out_data[proposal::kScore].get<gpu, 2, real_t>(s);
+
+    index_t num_anchors = in_data[proposal::kClsProb].shape_[1] / 2;
+    index_t height = scores.size(2);
+    index_t width = scores.size(3);
+    index_t count = num_anchors * height * width;  // count of total anchors
+
+    float* workspace_proposals_ptr = NULL;
+    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5));
+    Tensor<gpu, 2> workspace_proposals(workspace_proposals_ptr, Shape2(count, 5));
+    float* workspace_ordered_proposals_ptr = NULL;
+    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr,
+                                sizeof(float) * param_.rpn_pre_nms_top_n * 5));
+    Tensor<gpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr,
+                                               Shape2(param_.rpn_pre_nms_top_n, 5));
+    float* workspace_pre_nms_ptr = NULL;
+    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_pre_nms_ptr, sizeof(float) * count * 2));
+    Tensor<gpu, 2> workspace_pre_nms(workspace_pre_nms_ptr, Shape2(2, count));
+
+    // Generate first anchors based on base anchor
+    std::vector<float> base_anchor(4);
+    base_anchor[0] = 0.0;
+    base_anchor[1] = 0.0;
+    base_anchor[2] = param_.feature_stride - 1.0;
+    base_anchor[3] = param_.feature_stride - 1.0;
+    CHECK_EQ(num_anchors, param_.ratios.info.size() * param_.scales.info.size());
+    std::vector<float> anchors;
+    utils::GenerateAnchors(base_anchor,
+                           param_.ratios.info,
+                           param_.scales.info,
+                           anchors);
+
+    // Copy generated anchors to GPU
+    cudaMemcpy(workspace_proposals.dptr_, &anchors[0], sizeof(float) * anchors.size(),
+               cudaMemcpyHostToDevice);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // Copy proposals to a mesh grid
+    dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
+    dim3 dimBlock(kMaxThreadsPerBlock);
+    CheckLaunchParam(dimGrid, dimBlock, "ProposalGrid");
+    ProposalGridKernel<<<dimGrid, dimBlock>>>(
+      count, num_anchors, height, width, param_.feature_stride,
+      scores.dptr_, workspace_proposals.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // im_info is small, we want to copy them to cpu
+    std::vector<float> cpu_im_info(3);
+    cudaMemcpy(&cpu_im_info[0], im_info.dptr_, sizeof(float) * cpu_im_info.size(),
+               cudaMemcpyDeviceToHost);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // Transform anchors and bbox_deltas into bboxes
+    CheckLaunchParam(dimGrid, dimBlock, "BBoxPred");
+    BBoxPredKernel<<<dimGrid, dimBlock>>>(
+      count, num_anchors, height, width, cpu_im_info[0], cpu_im_info[1],
+      workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    Tensor<gpu, 1> score = workspace_pre_nms[0];
+    Tensor<gpu, 1> order = workspace_pre_nms[1];
+
+    // filter boxes whose height or width is less than rpn_min_size
+    CheckLaunchParam(dimGrid, dimBlock, "FilterBox");
+    FilterBoxKernel<<<dimGrid, dimBlock>>>(
+      count, param_.rpn_min_size * cpu_im_info[2], workspace_proposals.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // Copy score to a contiguous memory
+    CheckLaunchParam(dimGrid, dimBlock, "CopyScore");
+    CopyScoreKernel<<<dimGrid, dimBlock>>>(
+      count, workspace_proposals.dptr_, score.dptr_, order.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
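+
+    // thrust::sort_by_key sorts score in place (descending, via
+    // thrust::greater) and applies the same permutation to order, so order
+    // ends up holding proposal indices from best to worst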
+    // argsort score, save order
+    thrust::sort_by_key(thrust::device,
+                        score.dptr_,
+                        score.dptr_ + score.size(0),
+                        order.dptr_,
+                        thrust::greater<real_t>());
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // Reorder proposals according to order
+    const int top_n = param_.rpn_pre_nms_top_n;
+    dimGrid.x = (top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+    CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
+    ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
+      top_n, workspace_proposals.dptr_, order.dptr_, top_n,
+      workspace_ordered_proposals.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr));
+    FRCNN_CUDA_CHECK(cudaFree(workspace_pre_nms_ptr));
+
+    // perform nms
+    std::vector<int> _keep(workspace_ordered_proposals.size(0));
+    int out_size = 0;
+    _nms(workspace_ordered_proposals,
+         param_.threshold,
+         &_keep[0],
+         &out_size);
+
+    // copy nms result to gpu
+    int* keep;
+    FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * _keep.size()));
+    cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(), cudaMemcpyHostToDevice);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // copy results after nms
+    const int post_top_n = param_.rpn_post_nms_top_n;
+    dimGrid.x = (post_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+    CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput");
+    PrepareOutput<<<dimGrid, dimBlock>>>(
+      post_top_n, workspace_ordered_proposals.dptr_, keep, out_size,
+      out.dptr_, out_score.dptr_);
+    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+    // free temporary memory
+    FRCNN_CUDA_CHECK(cudaFree(keep));
+    FRCNN_CUDA_CHECK(cudaFree(workspace_ordered_proposals_ptr));
+  }
+
+ private:
+  ProposalParam param_;
+};  // class ProposalGPUOp
+
+template<>
+Operator* CreateOp<gpu>(ProposalParam param) {
+  return new ProposalGPUOp<gpu>(param);
+}
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/rcnn_utils.h b/src/operator/rcnn_utils.h
new file mode 100644
index 000000000000..d677596fc913
--- /dev/null
+++ b/src/operator/rcnn_utils.h
@@ -0,0 +1,249 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file rcnn_utils.h
+ * \brief Proposal Operator
+ * \author Bing Xu, Jian Guo
+*/
+#ifndef MXNET_OPERATOR_RCNN_UTILS_H_
+#define MXNET_OPERATOR_RCNN_UTILS_H_
+#include <mxnet/base.h>
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+//========================
+// Anchor Generation Utils
+//========================
+namespace mxnet {
+namespace op {
+namespace utils {
+
+inline void _MakeAnchor(float w,
+                        float h,
+                        float x_ctr,
+                        float y_ctr,
+                        std::vector<float>& out_anchors) {
+  out_anchors.push_back(x_ctr - 0.5f * (w - 1.0f));
+  out_anchors.push_back(y_ctr - 0.5f * (h - 1.0f));
+  out_anchors.push_back(x_ctr + 0.5f * (w - 1.0f));
+  out_anchors.push_back(y_ctr + 0.5f * (h - 1.0f));
+  out_anchors.push_back(0.0f);
+}
+
+inline void _Transform(float scale,
+                       float ratio,
+                       const std::vector<float>& base_anchor,
+                       std::vector<float>& out_anchors) {
+  float w = base_anchor[2] - base_anchor[0] + 1.0f;
+  float h = base_anchor[3] - base_anchor[1] + 1.0f;
+  float x_ctr = base_anchor[0] + 0.5 * (w - 1.0f);
+  float y_ctr = base_anchor[1] + 0.5 * (h - 1.0f);
+  float size = w * h;
+  float size_ratios = std::floor(size / ratio);
+  float new_w = std::floor(std::sqrt(size_ratios) + 0.5f) * scale;
+  float new_h = std::floor((new_w / scale * ratio) + 0.5f) * scale;
+
+  _MakeAnchor(new_w, new_h, x_ctr,
+              y_ctr, out_anchors);
+}
+
+// out_anchors must have shape (n, 5), where n is ratios.size() * scales.size()
+inline void GenerateAnchors(const std::vector<float>& base_anchor,
+                            const std::vector<float>& ratios,
+                            const std::vector<float>& scales,
+                            std::vector<float>& out_anchors) {
+  for (size_t j = 0; j < ratios.size(); ++j) {
+    for (size_t k = 0; k < scales.size(); ++k) {
+      _Transform(scales[k], ratios[j], base_anchor, out_anchors);
+    }
+  }
+}
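+
+// Worked example: with base_anchor = (0, 0, 15, 15) (feature_stride 16),
+// ratio 0.5 and scale 8 give new_w = 184 and new_h = 96 around the center
+// (7.5, 7.5), i.e. the anchor (-84, -40, 99, 55, 0)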
+
+}  // namespace utils
+}  // namespace op
+}  // namespace mxnet
+
+//============================
+// Bounding Box Transform Utils
+//============================
+namespace mxnet {
+namespace op {
+namespace utils {
+
+inline void BBoxTransformInv(const mshadow::Tensor<cpu, 2>& boxes,
+                             const mshadow::Tensor<cpu, 4>& deltas,
+                             mshadow::Tensor<cpu, 2> *out_pred_boxes) {
+  CHECK_GE(boxes.size(1), 4);
+  CHECK_GE(out_pred_boxes->size(1), 4);
+  size_t anchors = deltas.size(1)/4;
+  size_t heights = deltas.size(2);
+  size_t widths = deltas.size(3);
+
+  for (size_t a = 0; a < anchors; ++a) {
+    for (size_t h = 0; h < heights; ++h) {
+      for (size_t w = 0; w < widths; ++w) {
+        index_t index = h * (widths * anchors) + w * (anchors) + a;
+        float width = boxes[index][2] - boxes[index][0] + 1.0;
+        float height = boxes[index][3] - boxes[index][1] + 1.0;
+        float ctr_x = boxes[index][0] + 0.5 * width;
+        float ctr_y = boxes[index][1] + 0.5 * height;
+
+        float dx = deltas[0][a*4 + 0][h][w];
+        float dy = deltas[0][a*4 + 1][h][w];
+        float dw = deltas[0][a*4 + 2][h][w];
+        float dh = deltas[0][a*4 + 3][h][w];
+
+        float pred_ctr_x = dx * width + ctr_x;
+        float pred_ctr_y = dy * height + ctr_y;
+        float pred_w = exp(dw) * width;
+        float pred_h = exp(dh) * height;
+
+        (*out_pred_boxes)[index][0] = pred_ctr_x - 0.5 * pred_w;
+        (*out_pred_boxes)[index][1] = pred_ctr_y - 0.5 * pred_h;
+        (*out_pred_boxes)[index][2] = pred_ctr_x + 0.5 * pred_w;
+        (*out_pred_boxes)[index][3] = pred_ctr_y + 0.5 * pred_h;
+      }
+    }
+  }
+}
+
+inline void ClipBoxes(const mshadow::Shape<2>& im_shape,
+                      mshadow::Tensor<cpu, 2> *in_out_boxes) {
+  CHECK_GE(in_out_boxes->size(1), 4);
+  size_t num_boxes = in_out_boxes->size(0);
+
+  for (size_t i = 0; i < num_boxes; ++i) {
+    (*in_out_boxes)[i][0] = std::max(std::min((*in_out_boxes)[i][0],
+        static_cast<float>(im_shape[1] - 1)), static_cast<float>(0));
+    (*in_out_boxes)[i][1] = std::max(std::min((*in_out_boxes)[i][1],
+        static_cast<float>(im_shape[0] - 1)), static_cast<float>(0));
+    (*in_out_boxes)[i][2] = std::max(std::min((*in_out_boxes)[i][2],
+        static_cast<float>(im_shape[1] - 1)), static_cast<float>(0));
+    (*in_out_boxes)[i][3] = std::max(std::min((*in_out_boxes)[i][3],
+        static_cast<float>(im_shape[0] - 1)), static_cast<float>(0));
+  }
+}
+
+}  // namespace utils
+}  // namespace op
+}  // namespace mxnet
+
+//=====================
+// NMS Utils
+//=====================
+namespace mxnet {
+namespace op {
+namespace utils {
+
+struct ReverseArgsortCompl {
+  const float *val_;
+  explicit ReverseArgsortCompl(float *val)
+    : val_(val) {}
+  bool operator() (float i, float j) {
+    return (val_[static_cast<index_t>(i)] >
+            val_[static_cast<index_t>(j)]);
+  }
+};
+
+// filter box by setting its confidence to zero
+inline void FilterBox(mshadow::Tensor<cpu, 2>& dets,
+                      const float min_size) {
+  for (index_t i = 0; i < dets.size(0); i++) {
+    float iw = dets[i][2] - dets[i][0] + 1.0f;
+    float ih = dets[i][3] - dets[i][1] + 1.0f;
+    if (iw < min_size || ih < min_size) {
+      dets[i][4] = 0.0f;
+    }
+  }
+}
+
+// copy score and init order
+inline void CopyScore(const mshadow::Tensor<cpu, 2>& dets,
+                      mshadow::Tensor<cpu, 1>& score,
+                      mshadow::Tensor<cpu, 1>& order) {
+  for (index_t i = 0; i < dets.size(0); i++) {
+    score[i] = dets[i][4];
+    order[i] = i;
+  }
+}
+
+// sort order array according to score
+inline void ReverseArgsort(const mshadow::Tensor<cpu, 1>& score,
+                           mshadow::Tensor<cpu, 1>& order) {
+  ReverseArgsortCompl cmpl(score.dptr_);
+  std::sort(order.dptr_, order.dptr_ + score.size(0), cmpl);
+}
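+
+// Note: order stores indices as floats because it shares the real-valued
+// pre-nms workspace with score; the comparator casts back before indexing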
+
+// reorder proposals according to order and keep the pre_nms_top_n proposals
+// dets.size(0) == pre_nms_top_n
+inline void ReorderProposals(const mshadow::Tensor<cpu, 2>& prev_dets,
+                             const mshadow::Tensor<cpu, 1>& order,
+                             const index_t pre_nms_top_n,
+                             mshadow::Tensor<cpu, 2>& dets) {
+  CHECK_EQ(dets.size(0), pre_nms_top_n);
+  for (index_t i = 0; i < dets.size(0); i++) {
+    const index_t index = order[i];
+    for (index_t j = 0; j < dets.size(1); j++) {
+      dets[i][j] = prev_dets[index][j];
+    }
+  }
+}
+
+// greedily keep the max detections (already sorted)
+inline void NonMaximumSuppression(const mshadow::Tensor<cpu, 2>& dets,
+                                  const float thresh,
+                                  const float min_size,
+                                  const index_t post_nms_top_n,
+                                  mshadow::Tensor<cpu, 1>& area,
+                                  mshadow::Tensor<cpu, 1>& suppressed,
+                                  mshadow::Tensor<cpu, 1>& keep,
+                                  index_t *out_size) {
+  CHECK_EQ(dets.shape_[1], 5) << "dets: [x1, y1, x2, y2, score]";
+  CHECK_GT(dets.shape_[0], 0);
+  CHECK_EQ(dets.CheckContiguous(), true);
+  CHECK_EQ(area.CheckContiguous(), true);
+  CHECK_EQ(suppressed.CheckContiguous(), true);
+  CHECK_EQ(keep.CheckContiguous(), true);
+  // calculate area
+  for (index_t i = 0; i < dets.size(0); ++i) {
+    area[i] = (dets[i][2] - dets[i][0] + 1) *
+              (dets[i][3] - dets[i][1] + 1);
+  }
+
+  // calculate nms
+  *out_size = 0;
+  for (index_t i = 0; i < dets.size(0) && (*out_size) < post_nms_top_n; ++i) {
+    float ix1 = dets[i][0];
+    float iy1 = dets[i][1];
+    float ix2 = dets[i][2];
+    float iy2 = dets[i][3];
+    float iarea = area[i];
+
+    if (suppressed[i] > 0.0f) {
+      continue;
+    }
+
+    keep[(*out_size)++] = i;
+    for (index_t j = i + 1; j < dets.size(0); j++) {
+      if (suppressed[j] > 0.0f) {
+        continue;
+      }
+      float xx1 = std::max(ix1, dets[j][0]);
+      float yy1 = std::max(iy1, dets[j][1]);
+      float xx2 = std::min(ix2, dets[j][2]);
+      float yy2 = std::min(iy2, dets[j][3]);
+      float w = std::max(0.0f, xx2 - xx1 + 1.0f);
+      float h = std::max(0.0f, yy2 - yy1 + 1.0f);
+      float inter = w * h;
+      float ovr = inter / (iarea + area[j] - inter);
+      if (ovr > thresh) {
+        suppressed[j] = 1.0f;
+      }
+    }
+  }
+}
+
+}  // namespace utils
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_RCNN_UTILS_H_
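
Usage sketch (assumes the build exposes the registered operator to the Python
frontend as mx.symbol.Proposal; the variable names and parameter values below
are illustrative, not part of the diff):

    import mxnet as mx

    cls_prob  = mx.symbol.Variable('cls_prob')   # (1, 2 * A, H, W) fg/bg softmax scores
    bbox_pred = mx.symbol.Variable('bbox_pred')  # (1, 4 * A, H, W) anchor deltas
    im_info   = mx.symbol.Variable('im_info')    # (1, 3): height, width, scale

    rois = mx.symbol.Proposal(
        cls_prob=cls_prob, bbox_pred=bbox_pred, im_info=im_info,
        feature_stride=16, scales=(4, 8, 16, 32), ratios=(0.5, 1, 2),
        rpn_pre_nms_top_n=6000, rpn_post_nms_top_n=300,
        threshold=0.7, rpn_min_size=16)
    # rois: (rpn_post_nms_top_n, 5) rows of [batch_index, x1, y1, x2, y2]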