From d503bb462ca1559766f1ee172ec0dbf6767396a7 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 7 Mar 2019 18:57:08 -0800 Subject: [PATCH] Optimize NMS part 2 (#14352) * Optimize NMS part 2 * Guarding ldg intrinsics --- src/operator/contrib/bounding_box-common.h | 10 +++++ src/operator/contrib/bounding_box-inl.cuh | 44 ++++++++++++++++++++++ src/operator/contrib/bounding_box-inl.h | 27 ++++++------- 3 files changed, 68 insertions(+), 13 deletions(-) diff --git a/src/operator/contrib/bounding_box-common.h b/src/operator/contrib/bounding_box-common.h index 70215ab25d64..4c9b1b86d10c 100644 --- a/src/operator/contrib/bounding_box-common.h +++ b/src/operator/contrib/bounding_box-common.h @@ -112,6 +112,16 @@ struct nms_impl { } }; +namespace mshadow_op { +struct less_than : public mxnet_op::tunable { + // a is x, b is sigma + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return static_cast(a < b); + } +}; // struct equal_to +} // namespace mshadow_op + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh index 4b7cf3476448..e7f5567f469a 100644 --- a/src/operator/contrib/bounding_box-inl.cuh +++ b/src/operator/contrib/bounding_box-inl.cuh @@ -280,6 +280,50 @@ void NMSApply(mshadow::Stream *s, } } +__launch_bounds__(512) +__global__ void nms_calculate_batch_start_kernel(int32_t * batch_start, + int32_t * valid_batch_id, + size_t N, + int num_batch) { + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) { +#if __CUDA_ARCH__ >= 350 + const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1; + const int32_t my = __ldg(valid_batch_id + tid); +#else + const int32_t previous = tid > 0 ? valid_batch_id[tid - 1] : -1; + const int32_t my = valid_batch_id[tid]; +#endif + if (my > previous) { + for (int32_t current = previous + 1; current <= my; ++current) { + batch_start[current] = tid; + } + } + if (tid == N - 1) { + for (int32_t current = my + 1; current <= num_batch; ++current) { + batch_start[current] = tid + 1; + } + } + } +} + +inline void NMSCalculateBatchStart(mshadow::Stream *s, + mshadow::Tensor* batch_start, + mshadow::Tensor* valid_batch_id, + int num_batch) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + auto stream = mshadow::Stream::GetStream(s); + constexpr int block_size = 512; + const int num_elements = valid_batch_id->size(0); + const int blocks = (num_elements + block_size - 1) / block_size; + nms_calculate_batch_start_kernel<<>>(batch_start->dptr_, + valid_batch_id->dptr_, + num_elements, + num_batch); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 35ab19d01a19..8610dcca8e10 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -162,15 +162,6 @@ int FilterScores(mshadow::Tensor out_scores, return j; } -namespace mshadow_op { -struct less_than : public mxnet_op::tunable { - // a is x, b is sigma - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return static_cast(a < b); - } -}; // struct equal_to -} // namespace mshadow_op struct corner_to_center { template @@ -277,6 +268,19 @@ void NMSApply(mshadow::Stream *s, } } +inline void NMSCalculateBatchStart(mshadow::Stream *s, + mshadow::Tensor* batch_start, + mshadow::Tensor* valid_batch_id, + int num_batch) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + for (int b = 0; b < num_batch + 1; b++) { + slice<0>(*batch_start, b, b + 1) = reduce_keepdim( + F(*valid_batch_id, ScalarExp(b)), 0); + } +} + /*! * \brief Assign output of nms by indexing input * @@ -435,10 +439,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, // calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index valid_batch_id = (valid_sorted_index / ScalarExp(num_elem)); - for (int b = 0; b < num_batch + 1; b++) { - slice<0>(batch_start, b, b + 1) = reduce_keepdim( - F(valid_batch_id, ScalarExp(b)), 0); - } + mxnet::op::NMSCalculateBatchStart(s, &batch_start, &valid_batch_id, num_batch); // pre-compute areas of candidates areas = 0;