From 426f1e4986c8ea11d3fd66f377494fe360d0f821 Mon Sep 17 00:00:00 2001 From: Jian Guo Date: Mon, 31 Oct 2016 21:24:29 +0800 Subject: [PATCH] update proposal op --- src/operator/proposal-inl.h | 18 ++++++----- src/operator/proposal.cu | 63 +++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/operator/proposal-inl.h b/src/operator/proposal-inl.h index 5b0c009b8382..23c8e39e23b7 100644 --- a/src/operator/proposal-inl.h +++ b/src/operator/proposal-inl.h @@ -176,26 +176,30 @@ class ProposalOp : public Operator{ // fill in output rois for (index_t i = 0; i < out.size(0); ++i) { - index_t index = keep[i]; //batch index 0 out[i][0] = 0; - for (index_t j = 0; j < 4; ++j) { - if (i < out_size) { + if (i < out_size) { + index_t index = keep[i]; + for (index_t j = 0; j < 4; ++j) { out[i][j + 1] = workspace_ordered_proposals[index][j]; - } else { - out[i][j + 1] = 0; + } + } else { + index_t index = keep[i % out_size]; + for (index_t j = 0; j < 4; ++j) { + out[i][j + 1] = workspace_ordered_proposals[index][j]; } } } // fill in output score for (index_t i = 0; i < out_score.size(0); i++) { - index_t index = keep[i]; if (i < out_size) { + index_t index = keep[i]; out_score[i][0] = workspace_ordered_proposals[index][4]; } else { - out_score[i][0] = 0; + index_t index = keep[i % out_size]; + out_score[i][0] = workspace_ordered_proposals[index][4]; } } } diff --git a/src/operator/proposal.cu b/src/operator/proposal.cu index 25fcffe52155..36f76ebfc768 100644 --- a/src/operator/proposal.cu +++ b/src/operator/proposal.cu @@ -130,7 +130,11 @@ __global__ void FilterBoxKernel(const int count, float iw = dets[index * 5 + 2] - dets[index * 5 + 0] + 1.0f; float ih = dets[index * 5 + 3] - dets[index * 5 + 1] + 1.0f; if (iw < min_size || ih < min_size) { - dets[index * 5 + 4] = 0.0f; + dets[index * 5 + 0] -= min_size / 2; + dets[index * 5 + 1] -= min_size / 2; + dets[index * 5 + 2] += min_size / 2; + dets[index * 5 + 3] += min_size / 2; + dets[index * 5 + 4] = -1.0f; } } } @@ -158,14 +162,10 @@ template __global__ void ReorderProposalsKernel(const int count, const Dtype* prev_dets, const int* order, - const int top_n, Dtype* dets) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) { - if (index > top_n) { - return; - } const int order_i = order[index]; for (int j = 0; j < 5; j ++) { dets[index * 5 + j] = prev_dets[order_i * 5 + j]; @@ -300,10 +300,11 @@ __global__ void PrepareOutput(const int count, } score[index] = dets[keep_i * 5 + 4]; } else { + int keep_i = keep[index % out_size]; for (int j = 0; j < 4; ++j) { - out[index * 5 + j + 1] = 0; + out[index * 5 + j + 1] = dets[keep_i * 5 + j]; } - score[index] = 0; + score[index] = dets[keep_i * 5 + 4]; } } } @@ -348,24 +349,10 @@ class ProposalGPUOp : public Operator{ Tensor out = out_data[proposal::kOut].get(s); Tensor out_score = out_data[proposal::kScore].get(s); - index_t num_anchors = in_data[proposal::kClsProb].shape_[1] / 2; - index_t height = scores.size(2); - index_t width = scores.size(3); - index_t count = num_anchors * height * width; // count of total anchors - index_t rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n : count; // set to -1 for max - - float* workspace_proposals_ptr = NULL; - FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5)); - Tensor workspace_proposals(workspace_proposals_ptr, Shape2(count, 5)); - float* workspace_ordered_proposals_ptr = NULL; - FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr, sizeof(float) * rpn_pre_nms_top_n * 5)); - Tensor workspace_ordered_proposals(workspace_ordered_proposals_ptr, Shape2(rpn_pre_nms_top_n, 5)); - float* score_ptr = NULL; - FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count)); - Tensor score(score_ptr, Shape1(count)); - int* order_ptr = NULL; - FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count)); - Tensor order(order_ptr, Shape1(count)); + int num_anchors = in_data[proposal::kClsProb].shape_[1] / 2; + int height = scores.size(2); + int width = scores.size(3); + int count = num_anchors * height * width; // count of total anchors // Generate first anchors based on base anchor std::vector base_anchor(4); @@ -380,7 +367,11 @@ class ProposalGPUOp : public Operator{ param_.scales.info, anchors); - // Copy generated anchors to GPU + // Copy generated anchors to GPU + float* workspace_proposals_ptr = NULL; + FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5)); + Tensor workspace_proposals(workspace_proposals_ptr, Shape2(count, 5)); + cudaMemcpy(workspace_proposals.dptr_, &anchors[0], sizeof(float) * anchors.size(), cudaMemcpyHostToDevice); FRCNN_CUDA_CHECK(cudaPeekAtLastError()); @@ -413,6 +404,13 @@ class ProposalGPUOp : public Operator{ FRCNN_CUDA_CHECK(cudaPeekAtLastError()); // Copy score to a continuous memory + float* score_ptr = NULL; + FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count)); + Tensor score(score_ptr, Shape1(count)); + int* order_ptr = NULL; + FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count)); + Tensor order(order_ptr, Shape1(count)); + CheckLaunchParam(dimGrid, dimBlock, "CopyScore"); CopyScoreKernel<<>>( count, workspace_proposals.dptr_, score.dptr_, order.dptr_); @@ -427,11 +425,16 @@ class ProposalGPUOp : public Operator{ FRCNN_CUDA_CHECK(cudaPeekAtLastError()); // Reorder proposals according to order - const int top_n = std::min(rpn_pre_nms_top_n, count); - dimGrid.x = (top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + int rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n : count; // set to -1 for max + rpn_pre_nms_top_n = std::min(rpn_pre_nms_top_n, count); + float* workspace_ordered_proposals_ptr = NULL; + FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr, sizeof(float) * rpn_pre_nms_top_n * 5)); + Tensor workspace_ordered_proposals(workspace_ordered_proposals_ptr, Shape2(rpn_pre_nms_top_n, 5)); + + dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals"); ReorderProposalsKernel<<>>( - top_n, workspace_proposals.dptr_, order.dptr_, top_n, workspace_ordered_proposals.dptr_); + rpn_pre_nms_top_n, workspace_proposals.dptr_, order.dptr_, workspace_ordered_proposals.dptr_); FRCNN_CUDA_CHECK(cudaPeekAtLastError()); FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr)); @@ -453,7 +456,7 @@ class ProposalGPUOp : public Operator{ FRCNN_CUDA_CHECK(cudaPeekAtLastError()); // copy results after nms - const int post_top_n = param_.rpn_post_nms_top_n; + int post_top_n = std::min(param_.rpn_post_nms_top_n, rpn_pre_nms_top_n); dimGrid.x = (post_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput"); PrepareOutput<<>>(