This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Optimize NMS part 2 #14352

Merged: 2 commits, merged on Mar 8, 2019
10 changes: 10 additions & 0 deletions src/operator/contrib/bounding_box-common.h
@@ -112,6 +112,16 @@ struct nms_impl {
}
};

namespace mshadow_op {
struct less_than : public mxnet_op::tunable {
  // returns 1 when a < b, otherwise 0
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a, DType b) {
    return static_cast<DType>(a < b);
  }
};  // struct less_than
} // namespace mshadow_op

} // namespace op
} // namespace mxnet

39 changes: 39 additions & 0 deletions src/operator/contrib/bounding_box-inl.cuh
@@ -280,6 +280,45 @@ void NMSApply(mshadow::Stream<gpu> *s,
}
}

__launch_bounds__(512)
__global__ void nms_calculate_batch_start_kernel(int32_t * batch_start,
                                                 int32_t * valid_batch_id,
                                                 size_t N,
                                                 int num_batch) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
Review discussion on this line:

@arcadiaphy (Member), Mar 7, 2019:
Using the __ldg intrinsic will fail to compile on some early CUDA architectures.

Member Author:
It will fail on sm 3.0 and earlier (so Fermi and the first Kepler). I can put an ifdef there, but do we care about those?

@arcadiaphy (Member):

Member Author:
Then we do ;-). I will introduce the guard, thanks!
    const int32_t my = __ldg(valid_batch_id + tid);
    if (my > previous) {
      for (int32_t current = previous + 1; current <= my; ++current) {
        batch_start[current] = tid;
      }
    }
    if (tid == N - 1) {
      for (int32_t current = my + 1; current <= num_batch; ++current) {
        batch_start[current] = tid + 1;
      }
    }
  }
}

inline void NMSCalculateBatchStart(mshadow::Stream<gpu> *s,
                                   mshadow::Tensor<gpu, 1, int32_t>* batch_start,
                                   mshadow::Tensor<gpu, 1, int32_t>* valid_batch_id,
                                   int num_batch) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  auto stream = mshadow::Stream<gpu>::GetStream(s);
  constexpr int block_size = 512;
  const int num_elements = valid_batch_id->size(0);
  const int blocks = (num_elements + block_size - 1) / block_size;
  nms_calculate_batch_start_kernel<<<blocks, block_size, 0, stream>>>(batch_start->dptr_,
                                                                      valid_batch_id->dptr_,
                                                                      num_elements,
                                                                      num_batch);
}

} // namespace op
} // namespace mxnet
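The architecture guard the author agreed to add in the review thread above could look roughly like the sketch below. This is a minimal illustration under the assumption that __ldg should only be used on sm_35 and newer; the helper name ldg_or_load is hypothetical and not part of this PR.

template <typename DType>
__device__ inline DType ldg_or_load(const DType* ptr) {
  // __ldg (read-only data cache load) is only available on sm_35 and newer
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  return __ldg(ptr);
#else
  return *ptr;  // fall back to a plain global load on Fermi / early Kepler
#endif
}

The two __ldg(valid_batch_id + ...) reads in nms_calculate_batch_start_kernel would then go through such a helper so the kernel still compiles for sm_30 and earlier.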

27 changes: 14 additions & 13 deletions src/operator/contrib/bounding_box-inl.h
@@ -162,15 +162,6 @@ int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
return j;
}

namespace mshadow_op {
struct less_than : public mxnet_op::tunable {
// a is x, b is sigma
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(a < b);
}
}; // struct equal_to
} // namespace mshadow_op

struct corner_to_center {
template<typename DType>
Expand Down Expand Up @@ -277,6 +268,19 @@ void NMSApply(mshadow::Stream<cpu> *s,
}
}

inline void NMSCalculateBatchStart(mshadow::Stream<cpu> *s,
                                   mshadow::Tensor<cpu, 1, int32_t>* batch_start,
                                   mshadow::Tensor<cpu, 1, int32_t>* valid_batch_id,
                                   int num_batch) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  for (int b = 0; b < num_batch + 1; b++) {
    slice<0>(*batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
        F<mshadow_op::less_than>(*valid_batch_id, ScalarExp<int32_t>(b)), 0);
  }
}

/*!
* \brief Assign output of nms by indexing input
*
@@ -435,10 +439,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,

// calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
}
mxnet::op::NMSCalculateBatchStart(s, &batch_start, &valid_batch_id, num_batch);

// pre-compute areas of candidates
areas = 0;
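To make the batch_start computation concrete, here is a small self-contained host-side sketch in plain C++ with made-up input values. It mirrors what both the CPU reduction and the GPU kernel compute: batch_start[b] holds the index of the first sorted element belonging to batch b, and batch_start[num_batch] holds the total number of valid elements.

#include <cstdio>
#include <vector>

// Same quantity as the reduce_keepdim<red::sum>(less_than(...)) expression and
// the nms_calculate_batch_start_kernel: batch_start[b] counts how many sorted
// valid_batch_id entries are smaller than b.
std::vector<int> calculate_batch_start(const std::vector<int>& valid_batch_id,
                                       int num_batch) {
  std::vector<int> batch_start(num_batch + 1, 0);
  for (int b = 0; b <= num_batch; ++b) {
    int count = 0;
    for (int id : valid_batch_id) count += (id < b);
    batch_start[b] = count;
  }
  return batch_start;
}

int main() {
  // e.g. 6 valid boxes spread across 3 batches (already sorted by batch)
  std::vector<int> valid_batch_id = {0, 0, 1, 1, 1, 2};
  for (int v : calculate_batch_start(valid_batch_id, 3)) printf("%d ", v);
  printf("\n");  // prints: 0 2 5 6
  return 0;
}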