Skip to content

Commit

Permalink
Fix "invalid argument" in proposal memory allocation: when rpn_pre_nms_top_n is unset (-1), fall back to the total anchor count instead of allocating with a negative size
Browse files Browse the repository at this point in the history
  • Loading branch information
ijkguo committed Jul 24, 2016
1 parent fab3594 commit f39da2d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/operator/proposal-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,16 @@ class ProposalOp : public Operator{
index_t height = scores.size(2);
index_t width = scores.size(3);
index_t count = num_anchors * height * width;
index_t rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n : count;

Tensor<cpu, 2> workspace_proposals = ctx.requested[proposal::kTempResource].get_space<cpu>(
Shape2(count, 5), s);
Tensor<cpu, 2> workspace_ordered_proposals = ctx.requested[proposal::kTempResource].get_space<cpu>(
Shape2(param_.rpn_pre_nms_top_n, 5), s);
Shape2(rpn_pre_nms_top_n, 5), s);
Tensor<cpu, 2> workspace_pre_nms = ctx.requested[proposal::kTempResource].get_space<cpu>(
Shape2(2, count), s);
Tensor<cpu, 2> workspace_nms = ctx.requested[proposal::kTempResource].get_space<cpu>(
Shape2(3, param_.rpn_pre_nms_top_n), s);
Shape2(3, rpn_pre_nms_top_n), s);

// Generate anchors
std::vector<float> base_anchor(4);
Expand Down Expand Up @@ -156,7 +157,7 @@ class ProposalOp : public Operator{
order);
utils::ReorderProposals(workspace_proposals,
order,
param_.rpn_pre_nms_top_n,
rpn_pre_nms_top_n,
workspace_ordered_proposals);

real_t scale = im_info[0][2];
Expand Down
7 changes: 4 additions & 3 deletions src/operator/proposal.cu
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,14 @@ class ProposalGPUOp : public Operator{
index_t height = scores.size(2);
index_t width = scores.size(3);
index_t count = num_anchors * height * width; // count of total anchors
index_t rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n : count; // set to -1 for max

float* workspace_proposals_ptr = NULL;
FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5));
Tensor<xpu, 2> workspace_proposals(workspace_proposals_ptr, Shape2(count, 5));
float* workspace_ordered_proposals_ptr = NULL;
FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr, sizeof(float) * param_.rpn_pre_nms_top_n * 5));
Tensor<xpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr, Shape2(param_.rpn_pre_nms_top_n, 5));
FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr, sizeof(float) * rpn_pre_nms_top_n * 5));
Tensor<xpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr, Shape2(rpn_pre_nms_top_n, 5));
float* workspace_pre_nms_ptr = NULL;
FRCNN_CUDA_CHECK(cudaMalloc(&workspace_pre_nms_ptr, sizeof(float) * count * 2));
Tensor<xpu, 2> workspace_pre_nms(workspace_pre_nms_ptr, Shape2(2, count));
Expand Down Expand Up @@ -426,7 +427,7 @@ class ProposalGPUOp : public Operator{
FRCNN_CUDA_CHECK(cudaPeekAtLastError());

// Reorder proposals according to order
const int top_n = param_.rpn_pre_nms_top_n;
const int top_n = rpn_pre_nms_top_n;
dimGrid.x = (top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
Expand Down

0 comments on commit f39da2d

Please sign in to comment.