From 8df40a635db3e70df5140a94ffc1f602e90ba4b2 Mon Sep 17 00:00:00 2001
From: PengfeiChen
Date: Thu, 8 Jun 2017 16:33:36 -0700
Subject: [PATCH 1/2] fix proposal batch size > 1

---
 src/operator/contrib/proposal-inl.h |   6 +-
 src/operator/contrib/proposal.cc    | 354 +++++++++++++++-------------
 src/operator/contrib/proposal.cu    | 251 ++++++++++-----------
 src/operator/roi_pooling-inl.h      |  36 +--
 src/operator/roi_pooling.cc         |  30 +--
 src/operator/roi_pooling.cu         |  34 +--
 6 files changed, 370 insertions(+), 341 deletions(-)

diff --git a/src/operator/contrib/proposal-inl.h b/src/operator/contrib/proposal-inl.h
index ed0ec826588f..40d59aade18a 100644
--- a/src/operator/contrib/proposal-inl.h
+++ b/src/operator/contrib/proposal-inl.h
@@ -2,7 +2,7 @@
  * Copyright (c) 2015 by Contributors
  * \file proposal-inl.h
  * \brief Proposal Operator
- * \author Piotr Teterwak, Bing Xu, Jian Guo
+ * \author Piotr Teterwak, Bing Xu, Jian Guo, Pengfei Chen
 */
 #ifndef MXNET_OPERATOR_CONTRIB_PROPOSAL_INL_H_
 #define MXNET_OPERATOR_CONTRIB_PROPOSAL_INL_H_
@@ -186,9 +186,9 @@ class ProposalProp : public OperatorProperty {
     SHAPE_ASSIGN_CHECK(*in_shape, proposal::kImInfo, im_info_shape);
     out_shape->clear();
     // output
-    out_shape->push_back(Shape2(param_.rpn_post_nms_top_n, 5));
+    out_shape->push_back(Shape3(dshape[0], param_.rpn_post_nms_top_n, 5));
     // score
-    out_shape->push_back(Shape2(param_.rpn_post_nms_top_n, 1));
+    out_shape->push_back(Shape3(dshape[0], param_.rpn_post_nms_top_n, 1));
     return true;
   }
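The `InferShape` change above is the heart of patch 1: both outputs gain a leading batch dimension taken from `dshape[0]`. A minimal sketch of the resulting output contract (the sizes here are illustrative only, not from the patch):

```cpp
#include <cstdio>

// Sketch of the new Proposal output shapes, assuming dshape = {N, 2*A, H, W}
// is the cls_prob input shape and rpn_post_nms_top_n comes from the params.
int main() {
  int dshape[4] = {8, 2 * 9, 38, 50};  // hypothetical batch of 8, 9 anchors
  int rpn_post_nms_top_n = 300;
  // rois:  (batch, rpn_post_nms_top_n, 5) -- was (rpn_post_nms_top_n, 5)
  std::printf("rois:  (%d, %d, 5)\n", dshape[0], rpn_post_nms_top_n);
  // score: (batch, rpn_post_nms_top_n, 1) -- was (rpn_post_nms_top_n, 1)
  std::printf("score: (%d, %d, 1)\n", dshape[0], rpn_post_nms_top_n);
}
```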
diff --git a/src/operator/contrib/proposal.cc b/src/operator/contrib/proposal.cc
index 06a0565bf822..a118385a5f6f 100644
--- a/src/operator/contrib/proposal.cc
+++ b/src/operator/contrib/proposal.cc
@@ -15,55 +15,58 @@ namespace op {
 namespace utils {

 // bbox prediction and clip to the image borders
-inline void BBoxTransformInv(const mshadow::Tensor<cpu, 2>& boxes,
+inline void BBoxTransformInv(const mshadow::Tensor<cpu, 3>& boxes,
                              const mshadow::Tensor<cpu, 4>& deltas,
                              const float im_height,
                              const float im_width,
                              const int real_height,
                              const int real_width,
-                             mshadow::Tensor<cpu, 2> *out_pred_boxes) {
+                             mshadow::Tensor<cpu, 3> *out_pred_boxes) {
   CHECK_GE(boxes.size(1), 4);
   CHECK_GE(out_pred_boxes->size(1), 4);
+  int nbatch = deltas.size(0);
   int anchors = deltas.size(1)/4;
   int heights = deltas.size(2);
   int widths = deltas.size(3);
-  for (int a = 0; a < anchors; ++a) {
-    for (int h = 0; h < heights; ++h) {
-      for (int w = 0; w < widths; ++w) {
-        index_t index = h * (widths * anchors) + w * (anchors) + a;
-        float width = boxes[index][2] - boxes[index][0] + 1.0;
-        float height = boxes[index][3] - boxes[index][1] + 1.0;
-        float ctr_x = boxes[index][0] + 0.5 * (width - 1.0);
-        float ctr_y = boxes[index][1] + 0.5 * (height - 1.0);
-
-        float dx = deltas[0][a*4 + 0][h][w];
-        float dy = deltas[0][a*4 + 1][h][w];
-        float dw = deltas[0][a*4 + 2][h][w];
-        float dh = deltas[0][a*4 + 3][h][w];
-
-        float pred_ctr_x = dx * width + ctr_x;
-        float pred_ctr_y = dy * height + ctr_y;
-        float pred_w = exp(dw) * width;
-        float pred_h = exp(dh) * height;
-
-        float pred_x1 = pred_ctr_x - 0.5 * (pred_w - 1.0);
-        float pred_y1 = pred_ctr_y - 0.5 * (pred_h - 1.0);
-        float pred_x2 = pred_ctr_x + 0.5 * (pred_w - 1.0);
-        float pred_y2 = pred_ctr_y + 0.5 * (pred_h - 1.0);
-
-        pred_x1 = std::max(std::min(pred_x1, im_width - 1.0f), 0.0f);
-        pred_y1 = std::max(std::min(pred_y1, im_height - 1.0f), 0.0f);
-        pred_x2 = std::max(std::min(pred_x2, im_width - 1.0f), 0.0f);
-        pred_y2 = std::max(std::min(pred_y2, im_height - 1.0f), 0.0f);
-
-        (*out_pred_boxes)[index][0] = pred_x1;
-        (*out_pred_boxes)[index][1] = pred_y1;
-        (*out_pred_boxes)[index][2] = pred_x2;
-        (*out_pred_boxes)[index][3] = pred_y2;
-
-        if (h >= real_height || w >= real_width) {
-          (*out_pred_boxes)[index][4] = -1.0;
-        }
-      }
-    }
-  }
+  for (int n = 0; n < nbatch; ++n) {
+    for (int a = 0; a < anchors; ++a) {
+      for (int h = 0; h < heights; ++h) {
+        for (int w = 0; w < widths; ++w) {
+          index_t index = h * (widths * anchors) + w * (anchors) + a;
+          float width = boxes[n][index][2] - boxes[n][index][0] + 1.0;
+          float height = boxes[n][index][3] - boxes[n][index][1] + 1.0;
+          float ctr_x = boxes[n][index][0] + 0.5 * (width - 1.0);
+          float ctr_y = boxes[n][index][1] + 0.5 * (height - 1.0);
+
+          float dx = deltas[n][a*4 + 0][h][w];
+          float dy = deltas[n][a*4 + 1][h][w];
+          float dw = deltas[n][a*4 + 2][h][w];
+          float dh = deltas[n][a*4 + 3][h][w];
+
+          float pred_ctr_x = dx * width + ctr_x;
+          float pred_ctr_y = dy * height + ctr_y;
+          float pred_w = exp(dw) * width;
+          float pred_h = exp(dh) * height;
+
+          float pred_x1 = pred_ctr_x - 0.5 * (pred_w - 1.0);
+          float pred_y1 = pred_ctr_y - 0.5 * (pred_h - 1.0);
+          float pred_x2 = pred_ctr_x + 0.5 * (pred_w - 1.0);
+          float pred_y2 = pred_ctr_y + 0.5 * (pred_h - 1.0);
+
+          pred_x1 = std::max(std::min(pred_x1, im_width - 1.0f), 0.0f);
+          pred_y1 = std::max(std::min(pred_y1, im_height - 1.0f), 0.0f);
+          pred_x2 = std::max(std::min(pred_x2, im_width - 1.0f), 0.0f);
+          pred_y2 = std::max(std::min(pred_y2, im_height - 1.0f), 0.0f);
+
+          (*out_pred_boxes)[n][index][0] = pred_x1;
+          (*out_pred_boxes)[n][index][1] = pred_y1;
+          (*out_pred_boxes)[n][index][2] = pred_x2;
+          (*out_pred_boxes)[n][index][3] = pred_y2;
+
+          if (h >= real_height || w >= real_width) {
+            (*out_pred_boxes)[n][index][4] = -1.0;
+          }
+        }
+      }
+    }
+  }
 }
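All of these loops address proposals through the same flattened index, `index = h * (widths * anchors) + w * anchors + a`, so the anchor varies fastest. A quick self-contained check of that packing (toy sizes, my own):

```cpp
#include <cassert>

// Sketch of the (h, w, a) -> flat index packing used by the proposal loops:
// index = h * (widths * anchors) + w * anchors + a, i.e. anchor varies fastest.
int main() {
  const int anchors = 9, heights = 4, widths = 5;  // toy sizes
  int expect = 0;
  for (int h = 0; h < heights; ++h)
    for (int w = 0; w < widths; ++w)
      for (int a = 0; a < anchors; ++a) {
        int index = h * (widths * anchors) + w * anchors + a;
        assert(index == expect++);  // dense and unique per (h, w, a)
      }
  return 0;
}
```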
@@ -71,50 +74,52 @@ inline void BBoxTransformInv(const mshadow::Tensor<cpu, 2>& boxes,
 }

 // iou prediction and clip to the image border
-inline void IoUTransformInv(const mshadow::Tensor<cpu, 2>& boxes,
+inline void IoUTransformInv(const mshadow::Tensor<cpu, 3>& boxes,
                             const mshadow::Tensor<cpu, 4>& deltas,
                             const float im_height,
                             const float im_width,
                             const int real_height,
                             const int real_width,
-                            mshadow::Tensor<cpu, 2> *out_pred_boxes) {
-  CHECK_GE(boxes.size(1), 4);
-  CHECK_GE(out_pred_boxes->size(1), 4);
+                            mshadow::Tensor<cpu, 3> *out_pred_boxes) {
+  CHECK_GE(boxes.size(2), 4);
+  CHECK_GE(out_pred_boxes->size(2), 4);
+  int nbatch = deltas.size(0);
   int anchors = deltas.size(1)/4;
   int heights = deltas.size(2);
   int widths = deltas.size(3);
-
-  for (int a = 0; a < anchors; ++a) {
-    for (int h = 0; h < heights; ++h) {
-      for (int w = 0; w < widths; ++w) {
-        index_t index = h * (widths * anchors) + w * (anchors) + a;
-        float x1 = boxes[index][0];
-        float y1 = boxes[index][1];
-        float x2 = boxes[index][2];
-        float y2 = boxes[index][3];
-
-        float dx1 = deltas[0][a * 4 + 0][h][w];
-        float dy1 = deltas[0][a * 4 + 1][h][w];
-        float dx2 = deltas[0][a * 4 + 2][h][w];
-        float dy2 = deltas[0][a * 4 + 3][h][w];
-
-        float pred_x1 = x1 + dx1;
-        float pred_y1 = y1 + dy1;
-        float pred_x2 = x2 + dx2;
-        float pred_y2 = y2 + dy2;
-
-        pred_x1 = std::max(std::min(pred_x1, im_width - 1.0f), 0.0f);
-        pred_y1 = std::max(std::min(pred_y1, im_height - 1.0f), 0.0f);
-        pred_x2 = std::max(std::min(pred_x2, im_width - 1.0f), 0.0f);
-        pred_y2 = std::max(std::min(pred_y2, im_height - 1.0f), 0.0f);
-
-        (*out_pred_boxes)[index][0] = pred_x1;
-        (*out_pred_boxes)[index][1] = pred_y1;
-        (*out_pred_boxes)[index][2] = pred_x2;
-        (*out_pred_boxes)[index][3] = pred_y2;
-
-        if (h >= real_height || w >= real_width) {
-          (*out_pred_boxes)[index][4] = -1.0f;
-        }
-      }
-    }
-  }
+  for (int n = 0; n < nbatch; ++n) {
+    for (int a = 0; a < anchors; ++a) {
+      for (int h = 0; h < heights; ++h) {
+        for (int w = 0; w < widths; ++w) {
+          index_t index = h * (widths * anchors) + w * (anchors) + a;
+          float x1 = boxes[n][index][0];
+          float y1 = boxes[n][index][1];
+          float x2 = boxes[n][index][2];
+          float y2 = boxes[n][index][3];
+
+          float dx1 = deltas[n][a * 4 + 0][h][w];
+          float dy1 = deltas[n][a * 4 + 1][h][w];
+          float dx2 = deltas[n][a * 4 + 2][h][w];
+          float dy2 = deltas[n][a * 4 + 3][h][w];
+
+          float pred_x1 = x1 + dx1;
+          float pred_y1 = y1 + dy1;
+          float pred_x2 = x2 + dx2;
+          float pred_y2 = y2 + dy2;
+
+          pred_x1 = std::max(std::min(pred_x1, im_width - 1.0f), 0.0f);
+          pred_y1 = std::max(std::min(pred_y1, im_height - 1.0f), 0.0f);
+          pred_x2 = std::max(std::min(pred_x2, im_width - 1.0f), 0.0f);
+          pred_y2 = std::max(std::min(pred_y2, im_height - 1.0f), 0.0f);
+
+          (*out_pred_boxes)[n][index][0] = pred_x1;
+          (*out_pred_boxes)[n][index][1] = pred_y1;
+          (*out_pred_boxes)[n][index][2] = pred_x2;
+          (*out_pred_boxes)[n][index][3] = pred_y2;
+
+          if (h >= real_height || w >= real_width) {
+            (*out_pred_boxes)[n][index][4] = -1.0f;
+          }
+        }
+      }
+    }
+  }
 }
@@ -123,17 +128,19 @@ inline void IoUTransformInv(const mshadow::Tensor<cpu, 2>& boxes,

 // filter box by set confidence to zero
 // * height or width < rpn_min_size
-inline void FilterBox(mshadow::Tensor<cpu, 2> *dets,
+inline void FilterBox(mshadow::Tensor<cpu, 3> *dets,
                       const float min_size) {
-  for (index_t i = 0; i < dets->size(0); i++) {
-    float iw = (*dets)[i][2] - (*dets)[i][0] + 1.0f;
-    float ih = (*dets)[i][3] - (*dets)[i][1] + 1.0f;
-    if (iw < min_size || ih < min_size) {
-      (*dets)[i][0] -= min_size / 2;
-      (*dets)[i][1] -= min_size / 2;
-      (*dets)[i][2] += min_size / 2;
-      (*dets)[i][3] += min_size / 2;
-      (*dets)[i][4] = -1.0f;
-    }
-  }
+  for (index_t n = 0; n < dets->size(0); n++) {
+    for (index_t i = 0; i < dets->size(1); i++) {
+      float iw = (*dets)[n][i][2] - (*dets)[n][i][0] + 1.0f;
+      float ih = (*dets)[n][i][3] - (*dets)[n][i][1] + 1.0f;
+      if (iw < min_size || ih < min_size) {
+        (*dets)[n][i][0] -= min_size / 2;
+        (*dets)[n][i][1] -= min_size / 2;
+        (*dets)[n][i][2] += min_size / 2;
+        (*dets)[n][i][3] += min_size / 2;
+        (*dets)[n][i][4] = -1.0f;
+      }
+    }
+  }
 }
@@ -160,12 +167,14 @@ struct ReverseArgsortCompl {
 };

 // copy score and init order
-inline void CopyScore(const mshadow::Tensor<cpu, 2>& dets,
-                      mshadow::Tensor<cpu, 1> *score,
-                      mshadow::Tensor<cpu, 1> *order) {
-  for (index_t i = 0; i < dets.size(0); i++) {
-    (*score)[i] = dets[i][4];
-    (*order)[i] = i;
-  }
+inline void CopyScore(const mshadow::Tensor<cpu, 3>& dets,
+                      mshadow::Tensor<cpu, 2> *score,
+                      mshadow::Tensor<cpu, 2> *order) {
+  for (index_t n = 0; n < dets.size(0); n++) {
+    for (index_t i = 0; i < dets.size(1); i++) {
+      (*score)[n][i] = dets[n][i][4];
+      (*order)[n][i] = i;
+    }
+  }
 }
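`CopyScore` now fills `score` and `order` per image; the subsequent `ReverseArgsort` (unchanged by the patch, built on `ReverseArgsortCompl`) sorts `order` by descending score. A stand-in sketch of that step using only the standard library:

```cpp
#include <algorithm>
#include <vector>

// Stand-in for utils::ReverseArgsort: sort indices so that order[0] points at
// the highest score. Mirrors what CopyScore + ReverseArgsort do per image.
void ReverseArgsortSketch(const std::vector<float>& score,
                          std::vector<int>* order) {
  std::stable_sort(order->begin(), order->end(),
                   [&score](int a, int b) { return score[a] > score[b]; });
}
```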
@@ -270,25 +279,24 @@ class ProposalOp : public Operator{
     CHECK_EQ(out_data.size(), 2);
     CHECK_GT(req.size(), 1);
     CHECK_EQ(req[proposal::kOut], kWriteTo);
-    CHECK_EQ(in_data[proposal::kClsProb].shape_[0], 1)
-      << "Sorry, multiple images each device is not implemented.";

     Stream<cpu> *s = ctx.get_stream<cpu>();

-    Shape<4> scores_shape = Shape4(in_data[proposal::kClsProb].shape_[0],
-                                   in_data[proposal::kClsProb].shape_[1] / 2,
-                                   in_data[proposal::kClsProb].shape_[2],
-                                   in_data[proposal::kClsProb].shape_[3]);
-    real_t* foreground_score_ptr = in_data[proposal::kClsProb].dptr<real_t>()
-                                    + scores_shape.Size();
-    Tensor<cpu, 4> scores = Tensor<cpu, 4>(foreground_score_ptr, scores_shape);
+    // Shape<3> fg_scores_shape = Shape3(in_data[proposal::kClsProb].shape_[1] / 2,
+    //                                   in_data[proposal::kClsProb].shape_[2],
+    //                                   in_data[proposal::kClsProb].shape_[3]);
+    // real_t* foreground_score_ptr = in_data[proposal::kClsProb].dptr<real_t>()
+    //                                 + fg_scores_shape.Size();
+    Tensor<cpu, 4> scores = in_data[proposal::kClsProb].get<cpu, 4, real_t>(s);
     Tensor<cpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<cpu, 4, real_t>(s);
     Tensor<cpu, 2> im_info = in_data[proposal::kImInfo].get<cpu, 2, real_t>(s);

-    Tensor<cpu, 2> out = out_data[proposal::kOut].get<cpu, 2, real_t>(s);
-    Tensor<cpu, 2> out_score = out_data[proposal::kScore].get<cpu, 2, real_t>(s);
+    Tensor<cpu, 3> out = out_data[proposal::kOut].get<cpu, 3, real_t>(s);
+    Tensor<cpu, 3> out_score = out_data[proposal::kScore].get<cpu, 3, real_t>(s);

-    int num_anchors = in_data[proposal::kClsProb].shape_[1] / 2;
+    int nbatch = scores.size(0);
+    int num_anchors = scores.size(1) / 2;
     int height = scores.size(2);
     int width = scores.size(3);
     int count = num_anchors * height * width;
@@ -296,19 +304,19 @@ class ProposalOp : public Operator{
     rpn_pre_nms_top_n = std::min(rpn_pre_nms_top_n, count);
     int rpn_post_nms_top_n = std::min(param_.rpn_post_nms_top_n, rpn_pre_nms_top_n);

-    int workspace_size = count * 5 + 2 * count + rpn_pre_nms_top_n * 5 + 3 * rpn_pre_nms_top_n;
+    int workspace_size = nbatch * (count * 5 + 2 * count + rpn_pre_nms_top_n * 5 + 3 * rpn_pre_nms_top_n);
     Tensor<cpu, 1> workspace = ctx.requested[proposal::kTempResource].get_space<cpu>(
       Shape1(workspace_size), s);
     int start = 0;
-    Tensor<cpu, 2> workspace_proposals(workspace.dptr_ + start, Shape2(count, 5));
-    start += count * 5;
-    Tensor<cpu, 2> workspace_pre_nms(workspace.dptr_ + start, Shape2(2, count));
-    start += 2 * count;
-    Tensor<cpu, 2> workspace_ordered_proposals(workspace.dptr_ + start,
-                                               Shape2(rpn_pre_nms_top_n, 5));
-    start += rpn_pre_nms_top_n * 5;
-    Tensor<cpu, 2> workspace_nms(workspace.dptr_ + start, Shape2(3, rpn_pre_nms_top_n));
-    start += 3 * rpn_pre_nms_top_n;
+    Tensor<cpu, 3> workspace_proposals(workspace.dptr_ + start, Shape3(nbatch, count, 5));
+    start += nbatch * count * 5;
+    Tensor<cpu, 3> workspace_pre_nms(workspace.dptr_ + start, Shape3(2, nbatch, count));
+    start += nbatch * 2 * count;
+    Tensor<cpu, 3> workspace_ordered_proposals(workspace.dptr_ + start,
+                                               Shape3(nbatch, rpn_pre_nms_top_n, 5));
+    start += nbatch * rpn_pre_nms_top_n * 5;
+    Tensor<cpu, 3> workspace_nms(workspace.dptr_ + start, Shape3(3, nbatch, rpn_pre_nms_top_n));
+    start += nbatch * 3 * rpn_pre_nms_top_n;
     CHECK_EQ(workspace_size, start) << workspace_size << " " << start << std::endl;

     // Generate anchors
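The temp workspace is one flat float buffer carved into four regions, each now scaled by `nbatch`. A sketch of the offset bookkeeping with toy sizes (the numbers are hypothetical, not from the patch):

```cpp
#include <cstdio>

// Sketch of the batched workspace layout used above: one flat float buffer
// carved into four regions, each scaled by nbatch.
int main() {
  const int nbatch = 2, count = 9 * 38 * 50, top_n = 6000;  // toy sizes
  int start = 0;
  const int proposals = start;  start += nbatch * count * 5;  // (n, count, 5)
  const int pre_nms   = start;  start += nbatch * 2 * count;  // score + order
  const int ordered   = start;  start += nbatch * top_n * 5;  // (n, top_n, 5)
  const int nms       = start;  start += nbatch * 3 * top_n;  // area/suppressed/keep
  std::printf("offsets: %d %d %d %d, total %d floats\n",
              proposals, pre_nms, ordered, nms, start);
}
```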
@@ -323,18 +331,22 @@ class ProposalOp : public Operator{
                            param_.ratios.info,
                            param_.scales.info,
                            &anchors);
-    std::memcpy(workspace_proposals.dptr_, &anchors[0], sizeof(float) * anchors.size());
+    for (int n = 0; n < nbatch; n++) {
+      std::memcpy(workspace_proposals.dptr_ + n * 5 * count, &anchors[0], sizeof(float) * anchors.size());
+    }

     // Enumerate all shifted anchors
-    for (index_t i = 0; i < num_anchors; ++i) {
-      for (index_t j = 0; j < height; ++j) {
-        for (index_t k = 0; k < width; ++k) {
-          index_t index = j * (width * num_anchors) + k * (num_anchors) + i;
-          workspace_proposals[index][0] = workspace_proposals[i][0] + k * param_.feature_stride;
-          workspace_proposals[index][1] = workspace_proposals[i][1] + j * param_.feature_stride;
-          workspace_proposals[index][2] = workspace_proposals[i][2] + k * param_.feature_stride;
-          workspace_proposals[index][3] = workspace_proposals[i][3] + j * param_.feature_stride;
-          workspace_proposals[index][4] = scores[0][i][j][k];
-        }
-      }
-    }
+    for (index_t n = 0; n < nbatch; ++n) {
+      for (index_t i = 0; i < num_anchors; ++i) {
+        for (index_t j = 0; j < height; ++j) {
+          for (index_t k = 0; k < width; ++k) {
+            index_t index = j * (width * num_anchors) + k * (num_anchors) + i;
+            workspace_proposals[n][index][0] = workspace_proposals[n][i][0] + k * param_.feature_stride;
+            workspace_proposals[n][index][1] = workspace_proposals[n][i][1] + j * param_.feature_stride;
+            workspace_proposals[n][index][2] = workspace_proposals[n][i][2] + k * param_.feature_stride;
+            workspace_proposals[n][index][3] = workspace_proposals[n][i][3] + j * param_.feature_stride;
+            workspace_proposals[n][index][4] = scores[n][num_anchors + i][j][k];
+          }
+        }
+      }
+    }
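Note the score lookup: `scores` is now the full (N, 2A, H, W) `cls_prob` tensor rather than a pointer pre-offset past the background scores, so the foreground score for anchor `i` lives at channel `num_anchors + i` (the original patch text indexed the channel dimension with `i + width * height * num_anchors`, which overruns it; the channel form above is the equivalent of the old pointer offset). A sketch of the same lookup as a raw NCHW offset, for illustration only:

```cpp
// Sketch: locating the foreground score for anchor a at (h, w) in a
// cls_prob tensor of shape (N, 2*A, H, W) stored NCHW-contiguously.
inline int FgScoreOffset(int n, int a, int h, int w,
                         int A, int H, int W) {
  // channels [0, A) hold background, [A, 2*A) hold foreground scores
  return ((n * 2 * A + (A + a)) * H + h) * W + w;
}
```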
@@ -354,58 +366,66 @@ class ProposalOp : public Operator{
     }

     utils::FilterBox(&workspace_proposals, param_.rpn_min_size * im_info[0][2]);

-    Tensor<cpu, 1> score = workspace_pre_nms[0];
-    Tensor<cpu, 1> order = workspace_pre_nms[1];
+    Tensor<cpu, 2> score = workspace_pre_nms[0];
+    Tensor<cpu, 2> order = workspace_pre_nms[1];

     utils::CopyScore(workspace_proposals,
                      &score,
                      &order);
-    utils::ReverseArgsort(score,
-                          &order);
-    utils::ReorderProposals(workspace_proposals,
-                            order,
-                            rpn_pre_nms_top_n,
-                            &workspace_ordered_proposals);
-
-    index_t out_size = 0;
-    Tensor<cpu, 1> area = workspace_nms[0];
-    Tensor<cpu, 1> suppressed = workspace_nms[1];
-    Tensor<cpu, 1> keep = workspace_nms[2];
-    suppressed = 0;  // surprised!
-
-    utils::NonMaximumSuppression(workspace_ordered_proposals,
-                                 param_.threshold,
-                                 rpn_post_nms_top_n,
-                                 &area,
-                                 &suppressed,
-                                 &keep,
-                                 &out_size);
-
-    // fill in output rois
-    for (index_t i = 0; i < out.size(0); ++i) {
-      // batch index 0
-      out[i][0] = 0;
-      if (i < out_size) {
-        index_t index = keep[i];
-        for (index_t j = 0; j < 4; ++j) {
-          out[i][j + 1] = workspace_ordered_proposals[index][j];
-        }
-      } else {
-        index_t index = keep[i % out_size];
-        for (index_t j = 0; j < 4; ++j) {
-          out[i][j + 1] = workspace_ordered_proposals[index][j];
-        }
-      }
-    }
-
-    // fill in output score
-    for (index_t i = 0; i < out_score.size(0); i++) {
-      if (i < out_size) {
-        index_t index = keep[i];
-        out_score[i][0] = workspace_ordered_proposals[index][4];
-      } else {
-        index_t index = keep[i % out_size];
-        out_score[i][0] = workspace_ordered_proposals[index][4];
-      }
-    }
+
+    Tensor<cpu, 2> area = workspace_nms[0];
+    Tensor<cpu, 2> suppressed = workspace_nms[1];
+    Tensor<cpu, 2> keep = workspace_nms[2];
+
+    for (int n = 0; n < nbatch; n++) {
+      Tensor<cpu, 1> cur_order = order[n];
+      Tensor<cpu, 1> cur_area = area[n];
+      Tensor<cpu, 1> cur_keep = keep[n];
+      Tensor<cpu, 1> cur_suppressed = suppressed[n];
+      Tensor<cpu, 2> cur_workspace_ordered_proposals = workspace_ordered_proposals[n];
+      utils::ReverseArgsort(score[n],
+                            &cur_order);
+      utils::ReorderProposals(workspace_proposals[n],
+                              cur_order,
+                              rpn_pre_nms_top_n,
+                              &cur_workspace_ordered_proposals);
+      index_t out_size = 0;
+      cur_suppressed = 0;  // surprised!
+
+      utils::NonMaximumSuppression(cur_workspace_ordered_proposals,
+                                   param_.threshold,
+                                   rpn_post_nms_top_n,
+                                   &cur_area,
+                                   &cur_suppressed,
+                                   &cur_keep,
+                                   &out_size);
+
+      // fill in output rois
+      for (index_t i = 0; i < out.size(1); ++i) {
+        // batch index
+        out[n][i][0] = n;
+        if (i < out_size) {
+          index_t index = cur_keep[i];
+          for (index_t j = 0; j < 4; ++j) {
+            out[n][i][j + 1] = cur_workspace_ordered_proposals[index][j];
+          }
+        } else {
+          index_t index = cur_keep[i % out_size];
+          for (index_t j = 0; j < 4; ++j) {
+            out[n][i][j + 1] = cur_workspace_ordered_proposals[index][j];
+          }
+        }
+      }
+
+      // fill in output score
+      for (index_t i = 0; i < out_score.size(1); i++) {
+        if (i < out_size) {
+          index_t index = cur_keep[i];
+          out_score[n][i][0] = cur_workspace_ordered_proposals[index][4];
+        } else {
+          index_t index = cur_keep[i % out_size];
+          out_score[n][i][0] = cur_workspace_ordered_proposals[index][4];
+        }
+      }
+    }
   }
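Two small corrections were folded into the loop above relative to the raw patch text: the reset targets the per-image slice (`cur_suppressed = 0`, not the whole 3-D `suppressed` tensor), and the ROI batch index is written as `n` to match the `batchIdx` the GPU kernel writes below. When NMS keeps fewer than `rpn_post_nms_top_n` boxes, the remaining rows are padded by cycling through the kept indices; a sketch of that rule in isolation:

```cpp
#include <vector>

// Sketch of the output-padding rule above: if NMS kept out_size boxes but the
// output must hold top_n rows, surviving boxes are repeated round-robin
// (assumes out_size > 0, as the operator does).
std::vector<int> PadKeep(const std::vector<int>& keep, int out_size, int top_n) {
  std::vector<int> rows(top_n);
  for (int i = 0; i < top_n; ++i)
    rows[i] = keep[i < out_size ? i : i % out_size];
  return rows;
}
```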
diff --git a/src/operator/contrib/proposal.cu b/src/operator/contrib/proposal.cu
index ce1e9e5945d0..9f7acb7fd4a5 100644
--- a/src/operator/contrib/proposal.cu
+++ b/src/operator/contrib/proposal.cu
@@ -2,7 +2,7 @@
  * Copyright (c) 2015 by Contributors
  * \file proposal.cu
  * \brief Proposal Operator
- * \author Shaoqing Ren, Jian Guo
+ * \author Shaoqing Ren, Jian Guo, Pengfei Chen
 */
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
@@ -344,12 +344,13 @@ __global__ void PrepareOutput(const int count,
                               const Dtype* dets,
                               const int* keep,
                               const int out_size,
+                              const int batchIdx,
                               Dtype* out,
                               Dtype* score) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x;
        index < count;
        index += blockDim.x * gridDim.x) {
-    out[index * 5] = 0;
+    out[index * 5] = batchIdx;
     if (index < out_size) {
       int keep_i = keep[index];
       for (int j = 0; j < 4; ++j) {
@@ -391,25 +392,22 @@ class ProposalGPUOp : public Operator{
     CHECK_EQ(out_data.size(), 2);
     CHECK_GT(req.size(), 1);
     CHECK_EQ(req[proposal::kOut], kWriteTo);
-    CHECK_EQ(in_data[proposal::kClsProb].shape_[0], 1)
-      << "Sorry, multiple images each device is not implemented.";

     Stream<gpu> *s = ctx.get_stream<gpu>();

-    Shape<4> fg_scores_shape = Shape4(in_data[proposal::kClsProb].shape_[0],
-                                      in_data[proposal::kClsProb].shape_[1] / 2,
+    Shape<3> fg_scores_shape = Shape3(in_data[proposal::kClsProb].shape_[1] / 2,
                                       in_data[proposal::kClsProb].shape_[2],
                                       in_data[proposal::kClsProb].shape_[3]);
-    real_t* foreground_score_ptr = in_data[proposal::kClsProb].dptr<real_t>()
-                                    + fg_scores_shape.Size();
-    Tensor<gpu, 4> scores = Tensor<gpu, 4>(foreground_score_ptr, fg_scores_shape);
+
+    Tensor<gpu, 4> scores = in_data[proposal::kClsProb].get<gpu, 4, real_t>(s);
     Tensor<gpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<gpu, 4, real_t>(s);
     Tensor<gpu, 2> im_info = in_data[proposal::kImInfo].get<gpu, 2, real_t>(s);

-    Tensor<gpu, 2> out = out_data[proposal::kOut].get<gpu, 2, real_t>(s);
-    Tensor<gpu, 2> out_score = out_data[proposal::kScore].get<gpu, 2, real_t>(s);
+    Tensor<gpu, 3> out = out_data[proposal::kOut].get<gpu, 3, real_t>(s);
+    Tensor<gpu, 3> out_score = out_data[proposal::kScore].get<gpu, 3, real_t>(s);

-    int num_anchors = in_data[proposal::kClsProb].shape_[1] / 2;
+    int nbatch = scores.size(0);
+    int num_anchors = scores.size(1) / 2;
     int height = scores.size(2);
     int width = scores.size(3);
     int count = num_anchors * height * width;  // count of total anchors
@@ -433,117 +431,128 @@ class ProposalGPUOp : public Operator{

     // Copy generated anchors to GPU
     float* workspace_proposals_ptr = NULL;
-    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5));
-    Tensor<gpu, 2> workspace_proposals(workspace_proposals_ptr, Shape2(count, 5));
-    FRCNN_CUDA_CHECK(cudaMemcpy(workspace_proposals.dptr_,
-                                &anchors[0], sizeof(float) * anchors.size(),
-                                cudaMemcpyHostToDevice));
-
-    // Copy proposals to a mesh grid
-    dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
-    dim3 dimBlock(kMaxThreadsPerBlock);
-    CheckLaunchParam(dimGrid, dimBlock, "ProposalGrid");
-    ProposalGridKernel<<<dimGrid, dimBlock>>>(
-      count, num_anchors, height, width, param_.feature_stride,
-      scores.dptr_, workspace_proposals.dptr_);
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // im_info is small, we want to copy them to cpu
-    std::vector<float> cpu_im_info(3);
-    FRCNN_CUDA_CHECK(cudaMemcpy(&cpu_im_info[0], im_info.dptr_,
-                                sizeof(float) * cpu_im_info.size(),
-                                cudaMemcpyDeviceToHost));
-
-    // prevent padded predictions
-    int real_height = static_cast<int>(cpu_im_info[0] / param_.feature_stride);
-    int real_width = static_cast<int>(cpu_im_info[1] / param_.feature_stride);
-    CHECK_GE(height, real_height) << height << " " << real_height << std::endl;
-    CHECK_GE(width, real_width) << width << " " << real_width << std::endl;
-
-    // Transform anchors and bbox_deltas into bboxes
-    CheckLaunchParam(dimGrid, dimBlock, "BBoxPred");
-    if (param_.iou_loss) {
-      IoUPredKernel<<<dimGrid, dimBlock>>>(
-        count, num_anchors, height, width, real_height, real_width,
-        cpu_im_info[0], cpu_im_info[1],
-        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
-    } else {
-      BBoxPredKernel<<<dimGrid, dimBlock>>>(
-        count, num_anchors, height, width, real_height, real_width,
-        cpu_im_info[0], cpu_im_info[1],
-        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
-    }
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // filter boxes with less than rpn_min_size
-    CheckLaunchParam(dimGrid, dimBlock, "FilterBox");
-    FilterBoxKernel<<<dimGrid, dimBlock>>>(
-      count, param_.rpn_min_size * cpu_im_info[2], workspace_proposals.dptr_);
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // Copy score to a continuous memory
-    float* score_ptr = NULL;
-    FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count));
-    Tensor<gpu, 1> score(score_ptr, Shape1(count));
-    int* order_ptr = NULL;
-    FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count));
-    Tensor<gpu, 1, int> order(order_ptr, Shape1(count));
-
-    CheckLaunchParam(dimGrid, dimBlock, "CopyScore");
-    CopyScoreKernel<<<dimGrid, dimBlock>>>(
-      count, workspace_proposals.dptr_, score.dptr_, order.dptr_);
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // argsort score, save order
-    thrust::stable_sort_by_key(thrust::device,
-                               score.dptr_,
-                               score.dptr_ + score.size(0),
-                               order.dptr_,
-                               thrust::greater<real_t>());
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // Reorder proposals according to order
-    float* workspace_ordered_proposals_ptr = NULL;
-    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr,
-                                sizeof(float) * rpn_pre_nms_top_n * 5));
-    Tensor<gpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr,
-                                               Shape2(rpn_pre_nms_top_n, 5));
-
-    dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
-    CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
-    ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
-      rpn_pre_nms_top_n, workspace_proposals.dptr_, order.dptr_, workspace_ordered_proposals.dptr_);
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-    FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr));
-    FRCNN_CUDA_CHECK(cudaFree(score_ptr));
-    FRCNN_CUDA_CHECK(cudaFree(order_ptr));
-
-    // perform nms
-    std::vector<int> _keep(workspace_ordered_proposals.size(0));
-    int out_size = 0;
-    _nms(workspace_ordered_proposals,
-         param_.threshold,
-         &_keep[0],
-         &out_size);
-
-    // copy nms result to gpu
-    int* keep;
-    FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * _keep.size()));
-    FRCNN_CUDA_CHECK(cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(),
-                                cudaMemcpyHostToDevice));
-
-    // copy results after nms
-    dimGrid.x = (rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
-    CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput");
-    PrepareOutput<<<dimGrid, dimBlock>>>(
-      rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size,
-      out.dptr_, out_score.dptr_);
-    FRCNN_CUDA_CHECK(cudaPeekAtLastError());
-
-    // free temporary memory
-    FRCNN_CUDA_CHECK(cudaFree(keep));
-    FRCNN_CUDA_CHECK(cudaFree(workspace_ordered_proposals_ptr));
+    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * nbatch * count * 5));
+    Tensor<gpu, 3> workspace_proposals(workspace_proposals_ptr, Shape3(nbatch, count, 5));
+
+    // im_info is small, we want to copy them to cpu
+    std::vector<float> cpu_im_info(3);
+    FRCNN_CUDA_CHECK(cudaMemcpy(&cpu_im_info[0], im_info.dptr_,
+                                sizeof(float) * cpu_im_info.size(),
+                                cudaMemcpyDeviceToHost));
+
+    // prevent padded predictions
+    int real_height = static_cast<int>(cpu_im_info[0] / param_.feature_stride);
+    int real_width = static_cast<int>(cpu_im_info[1] / param_.feature_stride);
+    CHECK_GE(height, real_height) << height << " " << real_height << std::endl;
+    CHECK_GE(width, real_width) << width << " " << real_width << std::endl;
+
+    // Copy anchors for all images in batch
+    for (int i = 0; i < nbatch; i++) {
+      float* cur_batch_workspace_proposals_ptr = workspace_proposals.dptr_ + i * 5 * count;
+      FRCNN_CUDA_CHECK(cudaMemcpy(cur_batch_workspace_proposals_ptr,
+                                  &anchors[0], sizeof(float) * anchors.size(),
+                                  cudaMemcpyHostToDevice));
+
+      // get current batch foreground score
+      real_t* foreground_score_ptr = reinterpret_cast<real_t *>(scores.dptr_) + i * 2 * count
+                                      + fg_scores_shape.Size();
+      Tensor<gpu, 3> fg_scores = Tensor<gpu, 3>(foreground_score_ptr, fg_scores_shape);
+
+      // Copy proposals to a mesh grid
+      dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
+      dim3 dimBlock(kMaxThreadsPerBlock);
+      CheckLaunchParam(dimGrid, dimBlock, "ProposalGrid");
+      ProposalGridKernel<<<dimGrid, dimBlock>>>(
+        count, num_anchors, height, width, param_.feature_stride,
+        fg_scores.dptr_, cur_batch_workspace_proposals_ptr);
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // Transform anchors and bbox_deltas into bboxes
+      CheckLaunchParam(dimGrid, dimBlock, "BBoxPred");
+      if (param_.iou_loss) {
+        IoUPredKernel<<<dimGrid, dimBlock>>>(
+          count, num_anchors, height, width, real_height, real_width,
+          cpu_im_info[0], cpu_im_info[1],
+          cur_batch_workspace_proposals_ptr, bbox_deltas.dptr_ + i * 4 * count,
+          cur_batch_workspace_proposals_ptr);
+      } else {
+        BBoxPredKernel<<<dimGrid, dimBlock>>>(
+          count, num_anchors, height, width, real_height, real_width,
+          cpu_im_info[0], cpu_im_info[1],
+          cur_batch_workspace_proposals_ptr, bbox_deltas.dptr_ + i * 4 * count,
+          cur_batch_workspace_proposals_ptr);
+      }
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // filter boxes with less than rpn_min_size
+      CheckLaunchParam(dimGrid, dimBlock, "FilterBox");
+      FilterBoxKernel<<<dimGrid, dimBlock>>>(
+        count, param_.rpn_min_size * cpu_im_info[2], cur_batch_workspace_proposals_ptr);
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // Copy score to a continuous memory
+      float* score_ptr = NULL;
+      FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count));
+      Tensor<gpu, 1> score(score_ptr, Shape1(count));
+      int* order_ptr = NULL;
+      FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count));
+      Tensor<gpu, 1, int> order(order_ptr, Shape1(count));
+
+      CheckLaunchParam(dimGrid, dimBlock, "CopyScore");
+      CopyScoreKernel<<<dimGrid, dimBlock>>>(
+        count, cur_batch_workspace_proposals_ptr, score.dptr_, order.dptr_);
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // argsort score, save order
+      thrust::stable_sort_by_key(thrust::device,
+                                 score.dptr_,
+                                 score.dptr_ + score.size(0),
+                                 order.dptr_,
+                                 thrust::greater<real_t>());
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // Reorder proposals according to order
+      float* workspace_ordered_proposals_ptr = NULL;
+      FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr,
+                                  sizeof(float) * rpn_pre_nms_top_n * 5));
+      Tensor<gpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr,
+                                                 Shape2(rpn_pre_nms_top_n, 5));
+
+      dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+      CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
+      ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
+        rpn_pre_nms_top_n, cur_batch_workspace_proposals_ptr, order.dptr_,
+        workspace_ordered_proposals.dptr_);
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      FRCNN_CUDA_CHECK(cudaFree(score_ptr));
+      FRCNN_CUDA_CHECK(cudaFree(order_ptr));
+
+      // perform nms
+      std::vector<int> _keep(workspace_ordered_proposals.size(0));
+      int out_size = 0;
+      _nms(workspace_ordered_proposals,
+           param_.threshold,
+           &_keep[0],
+           &out_size);
+
+      // copy nms result to gpu
+      int* keep;
+      FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * _keep.size()));
+      FRCNN_CUDA_CHECK(cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(),
+                                  cudaMemcpyHostToDevice));
+
+      // copy results after nms
+      dimGrid.x = (rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+      CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput");
+      PrepareOutput<<<dimGrid, dimBlock>>>(
+        rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size, i,
+        out.dptr_ + i * 5 * rpn_post_nms_top_n,
+        out_score.dptr_ + i * rpn_post_nms_top_n);
+      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
+
+      // free temporary memory
+      FRCNN_CUDA_CHECK(cudaFree(keep));
+      FRCNN_CUDA_CHECK(cudaFree(workspace_ordered_proposals_ptr));
+    }
+    FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr));
   }

   virtual void Backward(const OpContext &ctx,
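One extra line was added at the end of the batch loop relative to the raw patch text: the old code freed `workspace_proposals_ptr` after reordering, and the batched version otherwise never releases it, so the `cudaFree` after the loop closes that leak. For reference, the per-image argsort step uses thrust exactly as shown; a minimal self-contained sketch of it:

```cpp
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// Minimal sketch of the per-image argsort step: sort keys (scores) in
// descending order and carry the proposal indices along as values.
void ArgsortDescending(thrust::device_vector<float>* score,
                       thrust::device_vector<int>* order) {
  thrust::sequence(order->begin(), order->end());        // order = 0..count-1
  thrust::stable_sort_by_key(thrust::device,
                             score->begin(), score->end(),
                             order->begin(),
                             thrust::greater<float>());  // highest score first
}
```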
diff --git a/src/operator/roi_pooling-inl.h b/src/operator/roi_pooling-inl.h
index cc1555d8c330..1e3512c56800 100644
--- a/src/operator/roi_pooling-inl.h
+++ b/src/operator/roi_pooling-inl.h
@@ -60,14 +60,14 @@ class ROIPoolingOp : public Operator {
     size_t expected = 2;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), expected);
-    CHECK_EQ(out_data[roipool::kOut].shape_[0], in_data[roipool::kBox].shape_[0]);
-    CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
+    CHECK_EQ(out_data[roipool::kOut].shape_[1], in_data[roipool::kBox].shape_[1]);
+    CHECK_EQ(out_data[roipool::kMaxIdx].shape_[1], in_data[roipool::kBox].shape_[1]);
     Stream<xpu> *s = ctx.get_stream<xpu>();

     Tensor<xpu, 4> data = in_data[roipool::kData].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
-    Tensor<xpu, 4> out = out_data[roipool::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 3> bbox = in_data[roipool::kBox].get<xpu, 3, real_t>(s);
+    Tensor<xpu, 5> out = out_data[roipool::kOut].get<xpu, 5, real_t>(s);
+    Tensor<xpu, 5> max_idx = out_data[roipool::kMaxIdx].get<xpu, 5, real_t>(s);
     CHECK_EQ(data.CheckContiguous(), true);
     CHECK_EQ(bbox.CheckContiguous(), true);
     CHECK_EQ(out.CheckContiguous(), true);
@@ -88,19 +88,19 @@ class ROIPoolingOp : public Operator {
     size_t expected = 2;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), expected);
-    CHECK_EQ(out_grad[roipool::kOut].shape_[0], in_data[roipool::kBox].shape_[0]);
-    CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
+    CHECK_EQ(out_grad[roipool::kOut].shape_[1], in_data[roipool::kBox].shape_[1]);
+    CHECK_EQ(out_data[roipool::kMaxIdx].shape_[1], in_data[roipool::kBox].shape_[1]);
     CHECK_NE(req[roipool::kData], kWriteInplace) <<
       "ROIPooling: Backward doesn't support kWriteInplace.";
     CHECK_NE(req[roipool::kBox], kWriteInplace) <<
       "ROIPooling: Backward doesn't support kWriteInplace.";
     Stream<xpu> *s = ctx.get_stream<xpu>();

-    Tensor<xpu, 4> grad_out = out_grad[roipool::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
-    Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 5> grad_out = out_grad[roipool::kOut].get<xpu, 5, real_t>(s);
+    Tensor<xpu, 3> bbox = in_data[roipool::kBox].get<xpu, 3, real_t>(s);
+    Tensor<xpu, 5> max_idx = out_data[roipool::kMaxIdx].get<xpu, 5, real_t>(s);
     Tensor<xpu, 4> grad_in = in_grad[roipool::kData].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 2> grad_roi = in_grad[roipool::kBox].get<xpu, 2, real_t>(s);
+    Tensor<xpu, 3> grad_roi = in_grad[roipool::kBox].get<xpu, 3, real_t>(s);
     CHECK_EQ(grad_out.CheckContiguous(), true);
     CHECK_EQ(bbox.CheckContiguous(), true);
     CHECK_EQ(max_idx.CheckContiguous(), true);
@@ -161,18 +161,18 @@ class ROIPoolingProp : public OperatorProperty {
     TShape dshape = in_shape->at(roipool::kData);
     CHECK_EQ(dshape.ndim(), 4U) << "data should be a 4D tensor";

-    // bbox: [num_rois, 5]
+    // bbox: [batch_size, num_rois, 5]
     TShape bshape = in_shape->at(roipool::kBox);
-    CHECK_EQ(bshape.ndim(), 2U) << "bbox should be a 2D tensor of shape [batch, 5]";
-    CHECK_EQ(bshape[1], 5U) << "bbox should be a 2D tensor of shape [batch, 5]";
+    CHECK_EQ(bshape.ndim(), 3U) << "bbox should be a 3D tensor of shape [batch, num_rois, 5]";
+    CHECK_EQ(bshape[2], 5U) << "bbox should be a 3D tensor of shape [batch, num_rois, 5]";

-    // out: [num_rois, c, pooled_h, pooled_w]
-    // max_idx: [num_rois, c, pooled_h, pooled_w]
+    // out: [batch_size, num_rois, c, pooled_h, pooled_w]
+    // max_idx: [batch_size, num_rois, c, pooled_h, pooled_w]
     out_shape->clear();
     out_shape->push_back(
-      Shape4(bshape[0], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
+      Shape5(dshape[0], bshape[1], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
     out_shape->push_back(
-      Shape4(bshape[0], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
+      Shape5(dshape[0], bshape[1], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
     return true;
   }
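The new ROIPooling shape contract at a glance (toy sizes of my own, not from the patch):

```cpp
#include <cstdio>

// Sketch of the batched ROIPooling shapes. pooled_size and the feature
// shape below are hypothetical placeholders.
int main() {
  int dshape[4] = {8, 256, 38, 50};  // data: (batch, c, h, w)
  int bshape[3] = {8, 300, 5};       // bbox: (batch, num_rois, 5)
  int pooled[2] = {7, 7};
  // out / max_idx: (batch, num_rois, c, pooled_h, pooled_w)
  std::printf("out: (%d, %d, %d, %d, %d)\n",
              dshape[0], bshape[1], dshape[1], pooled[0], pooled[1]);
}
```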
diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc
index 9c5d7c1ca5d6..452a2abac3e2 100644
--- a/src/operator/roi_pooling.cc
+++ b/src/operator/roi_pooling.cc
@@ -18,10 +18,10 @@ using std::ceil;
 namespace mshadow {

 template<typename Dtype>
-inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
+inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
                            const Tensor<cpu, 4, Dtype> &data,
-                           const Tensor<cpu, 2, Dtype> &bbox,
-                           const Tensor<cpu, 4, Dtype> &max_idx,
+                           const Tensor<cpu, 3, Dtype> &bbox,
+                           const Tensor<cpu, 5, Dtype> &max_idx,
                            const float spatial_scale_,
                            const float pad_ratio_) {
   const Dtype *bottom_data = data.dptr_;
@@ -31,10 +31,10 @@ inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
   const int channels_ = data.size(1);
   const int height_ = data.size(2);
   const int width_ = data.size(3);
-  const int pooled_height_ = out.size(2);
-  const int pooled_width_ = out.size(3);
+  const int pooled_height_ = out.size(3);
+  const int pooled_width_ = out.size(4);

-  const int num_rois = bbox.size(0);
+  const int num_rois = bbox.size(1);
   const int batch_size = data.size(0);
   const int data_size = data.size(1) * data.size(2) * data.size(3);
   // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
@@ -101,11 +101,11 @@ inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
       }
       // Increment all data pointers by one channel
       batch_data += data.size(2) * data.size(3);
-      top_data += out.size(2) * out.size(3);
-      argmax_data += max_idx.size(2) * max_idx.size(3);
+      top_data += out.size(3) * out.size(4);
+      argmax_data += max_idx.size(3) * max_idx.size(4);
     }
     // Increment ROI data pointer
-    bottom_rois += bbox.size(1);
+    bottom_rois += bbox.size(2);
   }

   return;
@@ -113,9 +113,9 @@ inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,

 template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
-                               const Tensor<cpu, 4, Dtype> &out_grad,
-                               const Tensor<cpu, 2, Dtype> &bbox,
-                               const Tensor<cpu, 4, Dtype> &max_idx,
+                               const Tensor<cpu, 5, Dtype> &out_grad,
+                               const Tensor<cpu, 3, Dtype> &bbox,
+                               const Tensor<cpu, 5, Dtype> &max_idx,
                                const float spatial_scale_,
                                const float pad_ratio_) {
   const Dtype *top_diff = out_grad.dptr_;
@@ -127,10 +127,10 @@ inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
   const int channels_ = in_grad.size(1);
   const int height_ = in_grad.size(2);
   const int width_ = in_grad.size(3);
-  const int pooled_height_ = out_grad.size(2);
-  const int pooled_width_ = out_grad.size(3);
+  const int pooled_height_ = out_grad.size(3);
+  const int pooled_width_ = out_grad.size(4);

-  const int num_rois = bbox.size(0);
+  const int num_rois = bbox.size(1);

   for (int b = 0; b < batch_size_; ++b) {
     for (int c = 0; c < channels_; ++c) {
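In the backward pass the patch text read `out_grad.size(4)` and `out_grad.size(5)`, which would index past a 5-D tensor; it is corrected above to `size(3)`/`size(4)`, matching the forward pass and the `.cu` version below. The pointer arithmetic steps `bottom_rois` by `bbox.size(2) == 5` per ROI; equivalently (illustration only):

```cpp
// Sketch of walking a (batch, num_rois, 5) ROI buffer with a raw pointer,
// matching the bottom_rois += bbox.size(2) stride above.
inline const float* RoiAt(const float* bottom_rois,
                          int n, int r, int num_rois) {
  return bottom_rois + (n * num_rois + r) * 5;  // each ROI row has 5 floats
}
```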
diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu
index 677ab83efa61..38f5d67db5d5 100644
--- a/src/operator/roi_pooling.cu
+++ b/src/operator/roi_pooling.cu
@@ -89,10 +89,10 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
 }

 template<typename Dtype>
-inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
+inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
                            const Tensor<gpu, 4, Dtype> &data,
-                           const Tensor<gpu, 2, Dtype> &bbox,
-                           const Tensor<gpu, 4, Dtype> &max_idx,
+                           const Tensor<gpu, 3, Dtype> &bbox,
+                           const Tensor<gpu, 5, Dtype> &max_idx,
                            const float spatial_scale,
                            const float pad_ratio) {
   const Dtype *bottom_data = data.dptr_;
@@ -103,8 +103,8 @@ inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
   const int channels = data.size(1);
   const int height = data.size(2);
   const int width = data.size(3);
-  const int pooled_height = out.size(2);
-  const int pooled_width = out.size(3);
+  const int pooled_height = out.size(3);
+  const int pooled_width = out.size(4);
   const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
   dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
   dim3 dimBlock(kMaxThreadsPerBlock);
@@ -195,9 +195,9 @@ __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff,

 template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
-                               const Tensor<gpu, 4, Dtype> &out_grad,
-                               const Tensor<gpu, 2, Dtype> &bbox,
-                               const Tensor<gpu, 4, Dtype> &max_idx,
+                               const Tensor<gpu, 5, Dtype> &out_grad,
+                               const Tensor<gpu, 3, Dtype> &bbox,
+                               const Tensor<gpu, 5, Dtype> &max_idx,
                                const float spatial_scale,
                                const float pad_ratio) {
   const Dtype *top_diff = out_grad.dptr_;
@@ -205,12 +205,12 @@ inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
   Dtype *bottom_diff = in_grad.dptr_;
   Dtype *argmax_data = max_idx.dptr_;
   const int count = in_grad.shape_.Size();
-  const int num_rois = bbox.size(0);
+  const int num_rois = bbox.size(1);
   const int channels = in_grad.size(1);
   const int height = in_grad.size(2);
   const int width = in_grad.size(3);
-  const int pooled_height = out_grad.size(2);
-  const int pooled_width = out_grad.size(3);
+  const int pooled_height = out_grad.size(3);
+  const int pooled_width = out_grad.size(4);
   const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
   dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
   dim3 dimBlock(kMaxThreadsPerBlock);
@@ -224,10 +224,10 @@ inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
 }  // namespace cuda

 template<typename Dtype>
-inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
+inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
                            const Tensor<gpu, 4, Dtype> &data,
-                           const Tensor<gpu, 2, Dtype> &bbox,
-                           const Tensor<gpu, 4, Dtype> &max_idx,
+                           const Tensor<gpu, 3, Dtype> &bbox,
+                           const Tensor<gpu, 5, Dtype> &max_idx,
                            const float spatial_scale,
                            const float pad_ratio) {
   cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale, pad_ratio);
@@ -235,9 +235,9 @@ inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,

 template<typename Dtype>
 inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
-                               const Tensor<gpu, 4, Dtype> &out_grad,
-                               const Tensor<gpu, 2, Dtype> &bbox,
-                               const Tensor<gpu, 4, Dtype> &max_idx,
+                               const Tensor<gpu, 5, Dtype> &out_grad,
+                               const Tensor<gpu, 3, Dtype> &bbox,
+                               const Tensor<gpu, 5, Dtype> &max_idx,
                                const float spatial_scale,
                                const float pad_ratio) {
   cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale, pad_ratio);

From f0600fb8244225ac97067e7c0fe41929075bdcd8 Mon Sep 17 00:00:00 2001
From: PengfeiChen
Date: Thu, 8 Jun 2017 17:15:09 -0700
Subject: [PATCH 2/2] fix profiler for real layer name rather than just
 "category"

---
 include/mxnet/engine.h         | 14 ++++++++++----
 src/engine/naive_engine.cc     | 21 ++++++++++++++++-----
 src/engine/profiler.cc         |  5 +++--
 src/engine/profiler.h          |  2 ++
 src/engine/threaded_engine.cc  | 11 +++++++----
 src/engine/threaded_engine.h   | 13 ++++++++++---
 src/executor/graph_executor.cc | 15 +++++++++++++--
 src/executor/graph_executor.h  |  2 ++
 8 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index ed46c84cfe83..72f0bb45a63c 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -108,13 +108,15 @@ class MXNET_API Engine {
    * \param mutable_vars The variables that current operation will mutate.
    * \param prop Property of the function.
    * \param opr_name The operator name.
+   * \param attr_name The attribute name.
    * \return The new operator allocated.
    */
   virtual OprHandle NewOperator(AsyncFn fn,
                                 std::vector<VarHandle> const& const_vars,
                                 std::vector<VarHandle> const& mutable_vars,
                                 FnProperty prop = FnProperty::kNormal,
-                                const char* opr_name = nullptr) = 0;
+                                const char* opr_name = nullptr,
+                                const char* attr_name = nullptr) = 0;
   /*!
    * \brief Delete the given operator.
    * \param op The operator to delete.
@@ -143,13 +145,15 @@ class MXNET_API Engine {
    * \param prop Property of the function.
    * \param priority Priority of the action, as hint to the engine.
    * \param opr_name The operator name.
+   * \param attr_name The attribute name.
    */
   virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                          std::vector<VarHandle> const& const_vars,
                          std::vector<VarHandle> const& mutable_vars,
                          FnProperty prop = FnProperty::kNormal,
                          int priority = 0,
-                         const char* opr_name = nullptr) = 0;
+                         const char* opr_name = nullptr,
+                         const char* attr_name = nullptr) = 0;
   /*!
    * \brief Schedule the deletion of a variable.
    *
@@ -199,6 +203,7 @@ class MXNET_API Engine {
    * \param prop Property of the function.
    * \param priority Priority of the action, as hint to the engine.
    * \param opr_name The operator name.
+   * \param attr_name The attribute name.
    * \tparam SyncFn the synchronous function to be pushed.
    */
   inline void PushSync(SyncFn exec_fn, Context exec_ctx,
@@ -206,11 +211,12 @@ class MXNET_API Engine {
                        std::vector<VarHandle> const& mutable_vars,
                        FnProperty prop = FnProperty::kNormal,
                        int priority = 0,
-                       const char* opr_name = nullptr) {
+                       const char* opr_name = nullptr,
+                       const char* attr_name = nullptr) {
     this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
         exec_fn(ctx);
         on_complete();
-      }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
+      }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name, attr_name);
   }

   /*!
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index efb7bd44981b..b6eee3e7d615 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -21,6 +21,7 @@ class NaiveEngine final : public Engine {
     std::vector<VarHandle> mutable_vars;
     FnProperty prop;
     const char* opr_name;
+    const char* attr_name;
     /*! \brief indicate whether to profile this operator */
     bool profiling{false};
     /*! \brief operator execution statistics */
@@ -53,13 +54,15 @@ class NaiveEngine final : public Engine {
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
-                        const char* opr_name = nullptr) override {
+                        const char* opr_name = nullptr,
+                        const char* attr_name = nullptr) override {
     NaiveOpr *opr = new NaiveOpr();
    opr->fn = fn;
    opr->const_vars = const_vars;
    opr->mutable_vars = mutable_vars;
    opr->prop = prop;
    opr->opr_name = opr_name;
+    opr->attr_name = attr_name;
    return opr;
  }

@@ -81,6 +84,9 @@ class NaiveEngine final : public Engine {
      strncpy(opr->opr_stat->opr_name,
              opr->opr_name,
              sizeof(opr->opr_stat->opr_name) - 1);
+      strncpy(opr->opr_stat->attr_name,
+              opr->attr_name,
+              sizeof(opr->opr_stat->attr_name) - 1);
      SetOprStart(opr->opr_stat);
    }
    opr->fn(ctx, on_complete);
@@ -96,7 +102,8 @@ class NaiveEngine final : public Engine {
              opr->mutable_vars,
              opr->prop,
              priority,
-              PROFILER_MESSAGE(opr->opr_name));
+              PROFILER_MESSAGE(opr->opr_name),
+              PROFILER_MESSAGE(opr->attr_name));
  }

  void PushAsync(AsyncFn exec_fun,
@@ -105,7 +112,8 @@ class NaiveEngine final : public Engine {
                 std::vector<VarHandle> const& mutable_vars,
                 FnProperty prop = FnProperty::kNormal,
                 int priority = 0,
-                const char* opr_name = nullptr) override {
+                const char* opr_name = nullptr,
+                const char* attr_name = nullptr) override {
    CallbackOnComplete callback = CreateCallback(
        NaiveEngine::OnComplete, nullptr);
    this->req_completed_ = false;
@@ -114,10 +122,10 @@ class NaiveEngine final : public Engine {
    NaiveOpr *opr = nullptr;
    bool profiling = (profiler->GetState() == Profiler::kRunning) &&
                     (profiler->GetMode() == Profiler::kAllOperator) &&
-                    opr_name;
+                    opr_name && attr_name;
    if (profiling) {
      opr = NewOperator(exec_fun, const_vars, mutable_vars,
-                       prop, opr_name)->Cast<NaiveOpr>();
+                       prop, opr_name, attr_name)->Cast<NaiveOpr>();
      opr->profiling = profiling;
      opr->opr_stat = Profiler::Get()->AddOprStat(exec_ctx.dev_type, exec_ctx.dev_id);
      uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
@@ -125,6 +133,9 @@ class NaiveEngine final : public Engine {
      strncpy(opr->opr_stat->opr_name,
              opr->opr_name,
              sizeof(opr->opr_stat->opr_name) - 1);
+      strncpy(opr->opr_stat->attr_name,
+              opr->attr_name,
+              sizeof(opr->opr_stat->attr_name) - 1);
      SetOprStart(opr->opr_stat);
    }
#endif
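With the widened signature, a caller can report both the operator type and the concrete layer instance. A hypothetical call site (the lambda, variables, and names are illustrative; only the `PushSync` signature comes from this patch):

```cpp
#include <mxnet/engine.h>

// Hypothetical call site: opr_name is the operator category, attr_name the
// real layer name taken from the symbol's attributes.
void PushConvSketch(mxnet::Context exec_ctx,
                    mxnet::Engine::VarHandle in_var,
                    mxnet::Engine::VarHandle out_var) {
  mxnet::Engine::Get()->PushSync(
      [](mxnet::RunContext ctx) { /* run the kernel here */ },
      exec_ctx, {in_var}, {out_var},
      mxnet::FnProperty::kNormal, 0,
      "Convolution",   // opr_name: operator category
      "conv1_3x3");    // attr_name: real layer name
}
```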
diff --git a/src/engine/profiler.cc b/src/engine/profiler.cc
index 44099c397783..9c27f906bc9b 100644
--- a/src/engine/profiler.cc
+++ b/src/engine/profiler.cc
@@ -81,6 +81,7 @@ OprExecStat *Profiler::AddOprStat(int dev_type, uint32_t dev_id) {
   opr_stat->dev_type = dev_type;
   opr_stat->dev_id  = dev_id;
   opr_stat->opr_name[sizeof(opr_stat->opr_name)-1] = '\0';
+  opr_stat->attr_name[sizeof(opr_stat->attr_name)-1] = '\0';

   int idx;
   switch (dev_type) {
@@ -167,10 +168,10 @@ void Profiler::DumpProfile() {
         file << ",";
       }
       file << std::endl;
-      this->EmitEvent(&file, opr_stat->opr_name, "category", "B",
+      this->EmitEvent(&file, opr_stat->attr_name, opr_stat->opr_name, "B",
                       opr_stat->opr_start_rel_micros, pid, tid);
       file << ",\n";
-      this->EmitEvent(&file, opr_stat->opr_name, "category", "E",
+      this->EmitEvent(&file, opr_stat->attr_name, opr_stat->opr_name, "E",
                       opr_stat->opr_end_rel_micros, pid, tid);
     }
   }
diff --git a/src/engine/profiler.h b/src/engine/profiler.h
index f28d691e250d..8883332fccdf 100644
--- a/src/engine/profiler.h
+++ b/src/engine/profiler.h
@@ -20,6 +20,8 @@ namespace engine {
 struct OprExecStat {
   /*! \brief operation name */
   char opr_name[32];
+  /*! \brief layer name */
+  char attr_name[32];
   /*!
    * \brief operation execution start relative timestamp
    *        time unit is microsecond (10^-6 s)
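The `DumpProfile` change swaps the trace-event fields: the layer name (`attr_name`) becomes the event name and the operator name becomes its category, instead of every event being labelled "category". Both engines copy the names with `strncpy` into the fixed 32-byte fields declared in `OprExecStat`, relying on the pre-set terminator for safe truncation; a condensed sketch of that pattern:

```cpp
#include <cstring>

// Sketch of the fixed-width name handling above: strncpy into a 32-byte
// field, with the terminator pre-set so long layer names truncate safely.
struct StatSketch { char attr_name[32]; };

void SetAttrName(StatSketch* stat, const char* attr_name) {
  stat->attr_name[sizeof(stat->attr_name) - 1] = '\0';
  std::strncpy(stat->attr_name, attr_name, sizeof(stat->attr_name) - 1);
}
```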
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 3632a46ba80b..894827573ec8 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -187,9 +187,11 @@ ThreadedOpr* ThreadedEngine::NewOperator(
     std::vector<VarHandle> const& const_vars,
     std::vector<VarHandle> const& mutable_vars,
     FnProperty prop,
-    const char* opr_name) {
+    const char* opr_name,
+    const char* attr_name) {
   auto ret = ThreadedOpr::New();
   ret->opr_name = opr_name;
+  ret->attr_name = attr_name;
   ret->fn = std::move(fn);
   ret->prop = prop;
   ret->const_vars.resize(const_vars.size());
@@ -285,8 +287,9 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop,
                                int priority,
-                               const char* opr_name) {
-  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
+                               const char* opr_name,
+                               const char* attr_name) {
+  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, attr_name);
   opr->temporary = true;
 #if MXNET_USE_PROFILER
   Profiler *profiler = Profiler::Get();
@@ -403,7 +406,7 @@ void ThreadedEngine::OnCompleteStatic(
   OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
   ThreadedOpr *threaded_opr = opr_block->opr;
 #if MXNET_USE_PROFILER
-  if (opr_block->profiling && threaded_opr->opr_name) {
+  if (opr_block->profiling && threaded_opr->opr_name && threaded_opr->attr_name) {
     // record operator end timestamp
     SetOprEnd(opr_block->opr_stat);
   }
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 4612cc6e02bf..d330900b2daf 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -209,6 +209,8 @@ struct ThreadedOpr final : public Opr,
   FnProperty prop;
   /*! \brief The name of the operator */
   const char* opr_name{nullptr};
+  /*! \brief The name of the attribute */
+  const char* attr_name{nullptr};
   /*!
    * \brief Whether this is an temporary operator
    *        that can be deleted right after the operation completed.
@@ -243,7 +245,8 @@ class ThreadedEngine : public Engine {
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
-                        const char* opr_name = nullptr) override;
+                        const char* opr_name = nullptr,
+                        const char* attr_name = nullptr) override;
   void DeleteOperator(OprHandle op) override;
   void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
   void PushAsync(AsyncFn exec_fun, Context exec_ctx,
@@ -251,7 +254,8 @@ class ThreadedEngine : public Engine {
                  std::vector<VarHandle> const& mutable_vars,
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
-                 const char* opr_name = nullptr) override;
+                 const char* opr_name = nullptr,
+                 const char* attr_name = nullptr) override;
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
   void WaitForVar(VarHandle var) override;
   void WaitForAll() override;
@@ -294,7 +298,7 @@ class ThreadedEngine : public Engine {
   void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
     ThreadedOpr* threaded_opr = opr_block->opr;
 #if MXNET_USE_PROFILER
-    if (opr_block->profiling && threaded_opr->opr_name) {
+    if (opr_block->profiling && threaded_opr->opr_name && threaded_opr->attr_name) {
       const Context& ctx = opr_block->ctx;
       opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
       uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
@@ -302,6 +306,9 @@ class ThreadedEngine : public Engine {
       strncpy(opr_block->opr_stat->opr_name,
               threaded_opr->opr_name,
               sizeof(opr_block->opr_stat->opr_name) - 1);
+      strncpy(opr_block->opr_stat->attr_name,
+              threaded_opr->attr_name,
+              sizeof(opr_block->opr_stat->attr_name) - 1);
       // record operator start timestamp
       SetOprStart(opr_block->opr_stat);
     }
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index c5248e2cc741..1f6187d2d11e 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -584,8 +584,10 @@ void GraphExecutor::InitCachedOps() {
     if (inode.source->is_variable()) continue;
 #if MXNET_USE_PROFILER
     op_nodes_[nid].opr_name = inode.source->op()->name.c_str();
+    op_nodes_[nid].attr_name = inode.source->attrs.name.c_str();
 #else
     op_nodes_[nid].opr_name = nullptr;
+    op_nodes_[nid].attr_name = nullptr;
 #endif
     if (skip_plus_node.at(nid)) {
       op_nodes_[nid].skip_exec_node = true; continue;
@@ -674,7 +676,8 @@ void GraphExecutor::InitCachedOps() {
     // setup the vars
     op_nodes_[nid].cached_opr = Engine::Get()->NewOperator(
         exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
-        PROFILER_MESSAGE(op_nodes_[nid].opr_name));
+        PROFILER_MESSAGE(op_nodes_[nid].opr_name),
+        PROFILER_MESSAGE(op_nodes_[nid].attr_name));
     op_nodes_[nid].mutate_vars = mutate_vars;
     op_nodes_[nid].use_vars = use_vars;
   }
@@ -849,8 +852,10 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
   }
 #if MXNET_USE_PROFILER
   std::string opr_names = "[";
+  std::string attr_names = "[";
 #else
   std::string opr_names = "Bulk Execution";
+  std::string attr_names = "Bulk Execution";
 #endif

   const auto& idx = graph_.indexed_graph();
@@ -875,6 +880,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     ret.exec_list.push_back(exec.get());
 #if MXNET_USE_PROFILER
     opr_names += inode.source->op()->name + ",";
+    attr_names += inode.source->attrs.name + ",";
 #endif
   }

@@ -902,14 +908,19 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start,
 #if MXNET_USE_PROFILER
   opr_names.pop_back();
   opr_names += "]";
+  attr_names.pop_back();
+  attr_names += "]";
   // the lifetime of `opr_names.c_str()` is same with opr_names
   // you need to copy it out. (potential memory leak risk)
   char *p_opr_name = new char[opr_names.size() + 1];
   memcpy(p_opr_name, opr_names.c_str(), opr_names.size() + 1);
+  char *p_attr_name = new char[attr_names.size() + 1];
+  memcpy(p_attr_name, attr_names.c_str(), attr_names.size() + 1);
 #endif
   ret.opr = Engine::Get()->NewOperator(
     exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
-    PROFILER_MESSAGE(p_opr_name));
+    PROFILER_MESSAGE(p_opr_name),
+    PROFILER_MESSAGE(p_attr_name));
   return ret;
 }
 }  // namespace exec
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index d9c3a3e6aa47..793df14983b9 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -69,6 +69,8 @@ class GraphExecutor : public Executor {
     // The name of the operator
     const char* opr_name;
+    // The name of the layer attribute
+    const char* attr_name;
     // the context of the node
     Context ctx;
     // The executor
     std::shared_ptr<OpExecutor> exec;