Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add background class in box_nms
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed Feb 6, 2019
1 parent 9e14f14 commit 943ea0d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 36 deletions.
26 changes: 10 additions & 16 deletions src/operator/contrib/bounding_box-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,20 @@ namespace mxnet {
namespace op {

// Predicate functor: treats any non-zero value as "keep".
// Used as the stencil predicate for thrust::copy_if in CopyIf below.
template<typename DType>
struct valid_value {
  __host__ __device__ bool operator()(const DType x) {
    return static_cast<bool>(x);
  }
};

template<typename DType>
int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
mshadow::Tensor<gpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<gpu, 1, DType> scores,
mshadow::Tensor<gpu, 1, int32_t> sorted_index,
float valid_thresh) {
valid_score<DType> pred(static_cast<DType>(valid_thresh));
DType * end_scores = thrust::copy_if(thrust::device, scores.dptr_, scores.dptr_ + scores.MSize(),
out_scores.dptr_, pred);
thrust::copy_if(thrust::device, sorted_index.dptr_, sorted_index.dptr_ + sorted_index.MSize(),
scores.dptr_, out_sorted_index.dptr_, pred);
return end_scores - out_scores.dptr_;
template<typename DType, typename FType>
int CopyIf(mshadow::Tensor<gpu, 1, DType> out,
mshadow::Tensor<gpu, 1, DType> value,
mshadow::Tensor<gpu, 1, FType> flag) {
valid_value<FType> pred;
DType *end_out = thrust::copy_if(thrust::device, value.dptr_, value.dptr_ + value.MSize(),
flag.dptr_, out.dptr_, pred);
return end_out - out.dptr_;
}

} // namespace op
Expand Down
72 changes: 54 additions & 18 deletions src/operator/contrib/bounding_box-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct BoxNMSParam : public dmlc::Parameter<BoxNMSParam> {
int coord_start;
int score_index;
int id_index;
int background_id;
bool force_suppress;
int in_format;
int out_format;
Expand All @@ -70,6 +71,8 @@ struct BoxNMSParam : public dmlc::Parameter<BoxNMSParam> {
.describe("Index of the scores/confidence of boxes.");
DMLC_DECLARE_FIELD(id_index).set_default(-1)
.describe("Optional, index of the class categories, -1 to disable.");
DMLC_DECLARE_FIELD(background_id).set_default(-1)
.describe("Optional, id of the background class which will be ignored in nms.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Optional, if set false and id_index is provided, nms will only apply"
" to boxes belongs to the same category");
Expand Down Expand Up @@ -106,7 +109,7 @@ inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs,
<< ishape << " provided";
int width_elem = ishape[indim - 1];
int expected = 5;
if (param.id_index > 0) {
if (param.id_index >= 0) {
expected += 1;
}
CHECK_GE(width_elem, expected)
Expand Down Expand Up @@ -148,17 +151,14 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {
return static_cast<uint32_t>(1);
}

template<typename DType>
int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
mshadow::Tensor<cpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<cpu, 1, DType> scores,
mshadow::Tensor<cpu, 1, int32_t> sorted_index,
float valid_thresh) {
template<typename DType, typename FType>
int CopyIf(mshadow::Tensor<cpu, 1, DType> out,
mshadow::Tensor<cpu, 1, DType> value,
mshadow::Tensor<cpu, 1, FType> flag) {
index_t j = 0;
for (index_t i = 0; i < scores.size(0); i++) {
if (scores[i] > valid_thresh) {
out_scores[j] = scores[i];
out_sorted_index[j] = sorted_index[i];
for (index_t i = 0; i < flag.size(0); i++) {
if (static_cast<bool>(flag[i])) {
out[j] = value[i];
j++;
}
}
Expand All @@ -167,12 +167,32 @@ int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,

namespace mshadow_op {
// Element-wise predicate functors. Each returns 1 (cast to DType) when the
// predicate holds and 0 otherwise, so the results can be combined
// arithmetically (e.g. with bool_and) and used as a CopyIf flag tensor.

struct less_than : public mxnet_op::tunable {
  // 1 if a < b, else 0.
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return static_cast<DType>(a < b);
  }
};  // struct less_than

struct greater_than : public mxnet_op::tunable {
  // 1 if a > b, else 0.
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return static_cast<DType>(a > b);
  }
};  // struct greater_than

struct not_equal : public mxnet_op::tunable {
  // 1 if a != b, else 0.
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return static_cast<DType>(a != b);
  }
};  // struct not_equal

struct bool_and : public mxnet_op::tunable {
  // 1 if both a and b are non-zero, else 0.
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return static_cast<DType>(a && b);
  }
};  // struct bool_and
}  // namespace mshadow_op

struct corner_to_center {
Expand Down Expand Up @@ -403,6 +423,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2);
int num_elem = in_shape[indim - 2];
int width_elem = in_shape[indim - 1];
bool class_exist = param.id_index >= 0;
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 3, DType> data = inputs[box_nms_enum::kData]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
Expand All @@ -418,7 +439,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,

// index
index_t int32_size = sort_index_shape.Size() * 3 + batch_start_shape.Size();
index_t dtype_size = sort_index_shape.Size() * 2;
index_t dtype_size = sort_index_shape.Size() * 3;
if (req[0] == kWriteInplace) {
dtype_size += buffer_shape.Size();
}
Expand All @@ -437,6 +458,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DType> scores(workspace.dptr_ + int32_offset,
sort_index_shape, s);
Tensor<xpu, 1, DType> areas(scores.dptr_ + scores.MSize(), sort_index_shape, s);
Tensor<xpu, 1, DType> classes(areas.dptr_ + areas.MSize(), sort_index_shape, s);
Tensor<xpu, 3, DType> buffer = data;
if (req[0] == kWriteInplace) {
// make copy
Expand All @@ -457,16 +479,30 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
return;
}

// use batch_id and areas as temporary storage
// use classes, areas and scores as temporary storage
Tensor<xpu, 1, DType> all_scores = areas;
// Tensor<xpu, 1, DType> all_sorted_index = areas;
all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
all_sorted_index = range<int32_t>(0, num_batch * num_elem);
Tensor<xpu, 1, DType> all_classes = classes;
if (class_exist) {
all_classes = reshape(slice<2>(buffer, id_index, id_index + 1), classes.shape_);
}

// filter scores but keep original sorted_index value
// move valid score and index to the front, return valid size
int num_valid = mxnet::op::FilterScores(scores, sorted_index, all_scores, all_sorted_index,
param.valid_thresh);
Tensor<xpu, 1, DType> valid_box = scores;
if (class_exist) {
valid_box = F<mshadow_op::bool_and>(
F<mshadow_op::greater_than>(all_scores, ScalarExp<DType>(param.valid_thresh)),
F<mshadow_op::not_equal>(all_classes, ScalarExp<DType>(param.background_id)));
} else {
valid_box = F<mshadow_op::greater_than>(all_scores, ScalarExp<DType>(param.valid_thresh));
}
classes = F<mshadow_op::identity>(valid_box);
valid_box = classes;
int num_valid = mxnet::op::CopyIf(scores, all_scores, valid_box);
mxnet::op::CopyIf(sorted_index, all_sorted_index, valid_box);

// if everything is filtered, output -1
if (num_valid == 0) {
record = -1;
Expand Down
8 changes: 6 additions & 2 deletions src/operator/contrib/bounding_box.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ NNVM_REGISTER_OP(_contrib_box_nms)
.describe(R"code(Apply non-maximum suppression to input.
The output will be sorted in descending order according to `score`. Boxes with
overlaps larger than `overlap_thresh` and smaller scores will be removed and
filled with -1, the corresponding position will be recorded for backward propagation.
overlaps larger than `overlap_thresh`, smaller scores and background boxes
will be removed and filled with -1, the corresponding position will be recorded
for backward propagation.
During back-propagation, the gradient will be copied to the original
position according to the input index. For positions that have been suppressed,
Expand All @@ -60,6 +61,9 @@ additional elements are allowed.
- `id_index`: optional, use -1 to ignore, useful if `force_suppress=False`, which means
we will skip highly overlapped boxes if one is `apple` while the other is `car`.
- `background_id`: optional, default=-1, class id for background boxes, useful
when `id_index >= 0`, which means boxes with background id will be filtered before nms.
- `coord_start`: required, default=2, the starting index of the 4 coordinates.
Two formats are supported:
Expand Down

0 comments on commit 943ea0d

Please sign in to comment.