From 0c5677ed333b3b503898240782ab206f902e9858 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 4 Nov 2019 23:38:50 -0800 Subject: [PATCH] Faster GPU NMS operator (#16542) * Adding second NMS op * NMS kernel * Removing second sort * Optimization * Adding out-of-place ability to SortByKey * Optimization pt2 * Optimizations pt3 * Do not recompute other boxes area every time * Sort only topk results during second sorting * Cleaning * Fixes from rebase * Fix lint and more fixes from rebase * Fix typo * Early exit in Triangle kernel * Fixes * Fix sort * Fix from rebase * Fix for the mixed naming convention * Fix the index_t with int comparisoon --- src/operator/contrib/bounding_box.cc | 1 + src/operator/contrib/bounding_box.cu | 689 ++++++++++++++++++++++++++- src/operator/tensor/sort_op-inl.cuh | 138 ++++-- src/operator/tensor/sort_op.h | 50 +- 4 files changed, 840 insertions(+), 38 deletions(-) diff --git a/src/operator/contrib/bounding_box.cc b/src/operator/contrib/bounding_box.cc index 3ab11bb2d6f9..8b1d53506c47 100644 --- a/src/operator/contrib/bounding_box.cc +++ b/src/operator/contrib/bounding_box.cc @@ -34,6 +34,7 @@ DMLC_REGISTER_PARAMETER(BoxOverlapParam); DMLC_REGISTER_PARAMETER(BipartiteMatchingParam); DMLC_REGISTER_PARAMETER(BoxDecodeParam); + NNVM_REGISTER_OP(_contrib_box_nms) .add_alias("_contrib_box_non_maximum_suppression") .describe(R"code(Apply non-maximum suppression to input. diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu index b20c570ea417..ffc48d7d5a44 100644 --- a/src/operator/contrib/bounding_box.cu +++ b/src/operator/contrib/bounding_box.cu @@ -24,14 +24,701 @@ * \author Joshua Zhang */ +#include + #include "./bounding_box-inl.cuh" #include "./bounding_box-inl.h" #include "../elemwise_op_common.h" namespace mxnet { namespace op { + +namespace { + +using mshadow::Tensor; +using mshadow::Stream; + +template +struct TempWorkspace { + index_t scores_temp_space; + DType* scores; + index_t scratch_space; + uint8_t* scratch; + index_t buffer_space; + DType* buffer; + index_t nms_scratch_space; + uint32_t* nms_scratch; + index_t indices_temp_spaces; + index_t* indices; +}; + +inline index_t ceil_div(index_t x, index_t y) { + return (x + y - 1) / y; +} + +inline index_t align(index_t x, index_t alignment) { + return ceil_div(x, alignment) * alignment; +} + +template +__global__ void FilterAndPrepareAuxDataKernel(const DType* data, DType* out, DType* scores, + index_t num_elements_per_batch, + const index_t element_width, + const index_t N, + const float threshold, + const int id_index, const int score_index, + const int background_id) { + index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + bool first_in_element = (tid % element_width == 0); + index_t start_of_my_element = tid - (tid % element_width); + + if (tid < N) { + DType my_score = data[start_of_my_element + score_index]; + bool filtered_out = my_score <= threshold; + if (id_index != -1 && background_id != -1) { + DType my_id = data[start_of_my_element + id_index]; + filtered_out = filtered_out || (my_id == background_id); + } + if (!filtered_out) { + out[tid] = data[tid]; + } else { + out[tid] = -1; + my_score = -1; + } + + if (first_in_element) { + index_t offset = tid / element_width; + scores[offset] = my_score; + } + } +} + +template +void FilterAndPrepareAuxData(const Tensor& data, + Tensor* out, + const TempWorkspace& workspace, + const BoxNMSParam& param, + Stream* s) { + const int n_threads = 512; + index_t N = data.shape_.Size(); + const auto blocks = 
ceil_div(N, n_threads); + FilterAndPrepareAuxDataKernel<<::GetStream(s)>>>( + data.dptr_, out->dptr_, workspace.scores, + data.shape_[1], data.shape_[2], N, + param.valid_thresh, param.id_index, + param.score_index, param.background_id); +} + +template +__global__ void CompactDataKernel(const index_t* indices, const DType* source, + DType* destination, const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int score_index, + const index_t N) { + const index_t tid_start = blockIdx.x * blockDim.x + threadIdx.x; + for (index_t tid = tid_start; tid < N; tid += blockDim.x * gridDim.x) { + const index_t my_element = tid / element_width; + const index_t my_element_in_batch = my_element % num_elements_per_batch; + if (check_topk && my_element_in_batch >= topk) { + destination[tid] = -1; + } else { + DType ret; + const index_t source_element = indices[my_element]; + DType score = 0; + if (check_score) { + score = source[source_element * element_width + score_index]; + } + if (score >= 0) { + ret = source[source_element * element_width + tid % element_width]; + } else { + ret = -1; + } + destination[tid] = ret; + } + } +} + +template +void CompactData(const Tensor& indices, + const Tensor& source, + Tensor* destination, + const index_t topk, + const int score_index, + Stream* s) { + const int n_threads = 512; + const index_t max_blocks = 320; + index_t N = source.shape_.Size(); + const auto blocks = std::min(ceil_div(N, n_threads), max_blocks); + if (topk > 0) { + CompactDataKernel<<::GetStream(s)>>>( + indices.dptr_, source.dptr_, + destination->dptr_, topk, + source.shape_[2], source.shape_[1], + score_index, N); + } else { + CompactDataKernel<<::GetStream(s)>>>( + indices.dptr_, source.dptr_, + destination->dptr_, topk, + source.shape_[2], source.shape_[1], + score_index, N); + } +} + +template +void WorkspaceForSort(const index_t num_elem, + const index_t topk, + const int alignment, + TempWorkspace* workspace) { + const index_t sort_scores_temp_space = + mxnet::op::SortByKeyWorkspaceSize(num_elem, false, false); + const index_t sort_topk_scores_temp_space = + mxnet::op::SortByKeyWorkspaceSize(topk, false, false); + workspace->scratch_space = align(std::max(sort_scores_temp_space, sort_topk_scores_temp_space), + alignment); +} + +template +__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result, + const index_t current_start, + const index_t num_elems, + const index_t num_batches, + const index_t num_blocks_per_row_batch, + const index_t num_blocks_per_row, + const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int coord_index, + const int class_index, + const int score_index, + const float threshold); + +template +__global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results, + DType * data, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elems, + const index_t start_index, + const index_t topk); + +template +__global__ void ReduceNMSResultRestKernel(DType* data, + const uint32_t* nms_results, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk, + const index_t num_blocks_per_batch); + +template +struct NMS { + static constexpr int THRESHOLD = 512; + + void operator()(Tensor* data, + Tensor* scratch, + const index_t topk, + const BoxNMSParam& param, + Stream* s) { + const int n_threads = 
512; + const index_t num_batches = data->shape_[0]; + const index_t num_elements_per_batch = data->shape_[1]; + const index_t element_width = data->shape_[2]; + for (index_t current_start = 0; current_start < topk; current_start += THRESHOLD) { + const index_t n_elems = topk - current_start; + const index_t num_blocks_per_row_batch = ceil_div(n_elems, n_threads); + const index_t num_blocks_per_row = num_blocks_per_row_batch * num_batches; + const index_t n_blocks = THRESHOLD / (sizeof(uint32_t) * 8) * num_blocks_per_row; + if (param.in_format == box_common_enum::kCorner) { + CalculateGreedyNMSResultsKernel + <<::GetStream(s)>>>( + data->dptr_, scratch->dptr_, current_start, n_elems, num_batches, + num_blocks_per_row_batch, num_blocks_per_row, topk, element_width, + num_elements_per_batch, param.coord_start, + param.force_suppress ? -1 : param.id_index, + param.score_index, param.overlap_thresh); + } else { + CalculateGreedyNMSResultsKernel + <<::GetStream(s)>>>( + data->dptr_, scratch->dptr_, current_start, n_elems, num_batches, + num_blocks_per_row_batch, num_blocks_per_row, topk, element_width, + num_elements_per_batch, param.coord_start, + param.force_suppress ? -1 : param.id_index, + param.score_index, param.overlap_thresh); + } + ReduceNMSResultTriangleKernel<<::GetStream(s)>>>( + scratch->dptr_, data->dptr_, param.score_index, + element_width, num_batches, num_elements_per_batch, + current_start, topk); + const index_t n_rest_elems = n_elems - THRESHOLD; + const index_t num_rest_blocks_per_batch = ceil_div(n_rest_elems, n_threads); + const index_t num_rest_blocks = num_rest_blocks_per_batch * num_batches; + if (n_rest_elems > 0) { + ReduceNMSResultRestKernel<<::GetStream(s)>>>( + data->dptr_, scratch->dptr_, param.score_index, element_width, + num_batches, num_elements_per_batch, current_start, topk, + num_rest_blocks_per_batch); + } + } + } +}; + +template +__device__ __forceinline__ DType calculate_area(const DType b0, const DType b1, + const DType b2, const DType b3) { + DType width = b2; + DType height = b3; + if (encode == box_common_enum::kCorner) { + width -= b0; + height -= b1; + } + if (width < 0 || height < 0) return 0; + return width * height; +} + +template +__device__ __forceinline__ DType calculate_intersection(const DType a0, const DType a1, + const DType a2, const DType a3, + const DType b0, const DType b1, + const DType b2, const DType b3) { + DType wx, wy; + if (encode == box_common_enum::kCorner) { + const DType left = a0 > b0 ? a0 : b0; + const DType bottom = a1 > b1 ? a1 : b1; + const DType right = a2 < b2 ? a2 : b2; + const DType top = a3 < b3 ? a3 : b3; + wx = right - left; + wy = top - bottom; + } else { + const DType al = 2 * a0 - a2; + const DType ar = 2 * a0 + a2; + const DType bl = 2 * b0 - b2; + const DType br = 2 * b0 + b2; + const DType left = bl > al ? bl : al; + const DType right = br < ar ? br : ar; + wx = right - left; + const DType ab = 2 * a1 - a3; + const DType at = 2 * a1 + a3; + const DType bb = 2 * b1 - b3; + const DType bt = 2 * b1 + b3; + const DType bottom = bb > ab ? bb : ab; + const DType top = bt < at ? 
bt : at; + wy = top - bottom; + wy = wy / 4; // To compensate for both wx and wy being 2x too large + } + if (wx <= 0 || wy <= 0) { + return 0; + } else { + return (wx * wy); + } +} + +template +__launch_bounds__(512) +__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result, + const index_t current_start, + const index_t num_elems, + const index_t num_batches, + const index_t num_blocks_per_row_batch, + const index_t num_blocks_per_row, + const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int coord_index, + const int class_index, + const int score_index, + const float threshold) { + constexpr int max_elem_width = 20; + constexpr int num_other_boxes = sizeof(uint32_t) * 8; + __shared__ DType other_boxes[max_elem_width * num_other_boxes]; + __shared__ DType other_boxes_areas[num_other_boxes]; + const index_t my_row = blockIdx.x / num_blocks_per_row; + const index_t my_block_offset_in_row = blockIdx.x % num_blocks_per_row; + const index_t my_block_offset_in_batch = my_block_offset_in_row % num_blocks_per_row_batch; + const index_t my_batch = (my_block_offset_in_row) / num_blocks_per_row_batch; + const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x + + current_start + threadIdx.x; + + // Load other boxes + const index_t offset = (my_batch * num_elements_per_batch + + current_start + my_row * num_other_boxes) * + element_width; + for (int i = threadIdx.x; i < element_width * num_other_boxes; i += blockDim.x) { + other_boxes[i] = data[offset + i]; + } + __syncthreads(); + + if (threadIdx.x < num_other_boxes) { + const int other_boxes_offset = element_width * threadIdx.x; + const DType their_area = calculate_area( + other_boxes[other_boxes_offset + coord_index + 0], + other_boxes[other_boxes_offset + coord_index + 1], + other_boxes[other_boxes_offset + coord_index + 2], + other_boxes[other_boxes_offset + coord_index + 3]); + other_boxes_areas[threadIdx.x] = their_area; + } + __syncthreads(); + + if (my_element_in_batch >= topk) return; + + DType my_box[4]; + DType my_class = -1; + DType my_score = -1; + const index_t my_offset = (my_batch * num_elements_per_batch + my_element_in_batch) * + element_width; + my_score = data[my_offset + score_index]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + my_box[i] = data[my_offset + coord_index + i]; + } + if (class_index != -1) { + my_class = data[my_offset + class_index]; + } + DType my_area = calculate_area(my_box[0], my_box[1], my_box[2], my_box[3]); + + uint32_t ret = 0; + if (my_score != -1) { +#pragma unroll + for (int i = 0; i < num_other_boxes; ++i) { + const int other_boxes_offset = element_width * i; + if ((class_index == -1 || my_class == other_boxes[other_boxes_offset + class_index]) && + other_boxes[other_boxes_offset + score_index] != -1) { + const DType their_area = other_boxes_areas[i]; + + const DType intersect = calculate_intersection( + my_box[0], my_box[1], my_box[2], my_box[3], + other_boxes[other_boxes_offset + coord_index + 0], + other_boxes[other_boxes_offset + coord_index + 1], + other_boxes[other_boxes_offset + coord_index + 2], + other_boxes[other_boxes_offset + coord_index + 3]); + if (intersect > threshold * (my_area + their_area - intersect)) { + ret = ret | (1u << i); + } + } + } + } + result[(my_row * num_batches + my_batch) * topk + my_element_in_batch] = ~ret; +} + +template +__launch_bounds__(NMS::THRESHOLD) +__global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results, + DType * data, + const index_t score_index, + 
const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk) { + constexpr int n_threads = NMS::THRESHOLD; + constexpr int warp_size = 32; + const index_t my_batch = blockIdx.x; + const index_t my_element_in_batch = threadIdx.x + start_index; + const index_t my_element = my_batch * topk + my_element_in_batch; + const int my_warp = threadIdx.x / warp_size; + const int my_lane = threadIdx.x % warp_size; + + __shared__ uint32_t current_valid_boxes[n_threads / warp_size]; + const uint32_t full_mask = 0xFFFFFFFF; + const uint32_t my_lane_mask = 1 << my_lane; + const uint32_t earlier_threads_mask = (1 << (my_lane + 1)) - 1; + uint32_t valid = my_lane_mask; + uint32_t valid_boxes = full_mask; + + uint32_t my_next_mask = my_element_in_batch < topk ? + nms_results[my_element]: + full_mask; +#pragma unroll + for (int i = 0; i < n_threads / warp_size; ++i) { + uint32_t my_mask = my_next_mask; + my_next_mask = (((i + 1) < n_threads / warp_size) && + (my_element_in_batch < topk)) ? + nms_results[(i + 1) * topk * num_batches + my_element]: + full_mask; + if (my_warp == i && !__all_sync(full_mask, my_mask == full_mask)) { + my_mask = my_mask | earlier_threads_mask; + // Loop over warp_size - 1 because the last + // thread does not contribute to the mask anyway +#pragma unroll + for (int j = 0; j < warp_size - 1; ++j) { + const uint32_t mask = __shfl_sync(full_mask, valid ? my_mask : full_mask, j); + valid = valid & mask; + } + valid_boxes = __ballot_sync(full_mask, valid); + } + if (my_lane == 0 && my_warp == i) { + current_valid_boxes[i] = valid_boxes; + } + __syncthreads(); + if ((my_warp > i) && (((~my_mask) & current_valid_boxes[i]) != 0)) { + valid = 0; + } + } + if (my_lane == 0) { + nms_results[my_element] = valid_boxes; + } + if (valid == 0) { + data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + + score_index] = -1; + } +} + +template +__launch_bounds__(512) +__global__ void ReduceNMSResultRestKernel(DType* data, + const uint32_t* nms_results, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk, + const index_t num_blocks_per_batch) { + constexpr int num_other_boxes = sizeof(uint32_t) * 8; + constexpr int num_iterations = NMS::THRESHOLD / num_other_boxes; + constexpr int warp_size = 32; + const index_t my_block_offset_in_batch = blockIdx.x % num_blocks_per_batch; + const index_t my_batch = blockIdx.x / num_blocks_per_batch; + const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x + + start_index + NMS::THRESHOLD + threadIdx.x; + const index_t my_element = my_batch * topk + my_element_in_batch; + + if (my_element_in_batch >= topk) return; + + bool valid = true; + +#pragma unroll + for (int i = 0; i < num_iterations; ++i) { + const uint32_t my_mask = nms_results[i * topk * num_batches + my_element]; + const uint32_t valid_boxes = nms_results[my_batch * topk + i * warp_size + start_index]; + + const bool no_hit = (valid_boxes & (~my_mask)) == 0; + valid = valid && no_hit; + } + + if (!valid) { + data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + + score_index] = -1; + } +} + +template +TempWorkspace GetWorkspace(const index_t num_batch, + const index_t num_elem, + const int width_elem, + const index_t topk, + const OpContext& ctx) { + TempWorkspace workspace; + Stream *s = ctx.get_stream(); + const int 
alignment = 128; + + // Get the workspace size + workspace.scores_temp_space = 2 * align(num_batch * num_elem * sizeof(DType), alignment); + workspace.indices_temp_spaces = 2 * align(num_batch * num_elem * sizeof(index_t), alignment); + WorkspaceForSort(num_elem, topk, alignment, &workspace); + // Place for a buffer + workspace.buffer_space = align(num_batch * num_elem * width_elem * sizeof(DType), alignment); + workspace.nms_scratch_space = align(NMS::THRESHOLD / (sizeof(uint32_t) * 8) * + num_batch * topk * sizeof(uint32_t), alignment); + + const index_t workspace_size = workspace.scores_temp_space + + workspace.scratch_space + + workspace.nms_scratch_space + + workspace.indices_temp_spaces; + + // Obtain the memory for workspace + Tensor scratch_memory = ctx.requested[box_nms_enum::kTempSpace] + .get_space_typed(mshadow::Shape1(workspace_size), s); + + // Populate workspace pointers + workspace.scores = scratch_memory.dptr_; + workspace.scratch = reinterpret_cast(workspace.scores) + + workspace.scores_temp_space; + workspace.buffer = reinterpret_cast(workspace.scratch + + workspace.scratch_space); + workspace.nms_scratch = reinterpret_cast( + reinterpret_cast(workspace.buffer) + + workspace.buffer_space); + workspace.indices = reinterpret_cast( + reinterpret_cast(workspace.nms_scratch) + + workspace.nms_scratch_space); + return workspace; +} + +template +__global__ void ExtractScoresKernel(const DType* data, DType* scores, + const index_t N, const int element_width, + const int score_index) { + const index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) { + scores[tid] = data[tid * element_width + score_index]; + } +} + +template +void CompactNMSResults(const Tensor& data, + Tensor* out, + Tensor* indices, + Tensor* scores, + Tensor* sorted_indices, + Tensor* sorted_scores, + Tensor* scratch, + const int score_index, + const index_t topk, + Stream* s) { + using mshadow::Shape1; + constexpr int n_threads = 512; + const index_t num_elements = scores->shape_.Size(); + const index_t num_elements_per_batch = data.shape_[1]; + const index_t num_batches = data.shape_[0]; + const int element_width = data.shape_[2]; + const index_t n_blocks = ceil_div(num_elements, n_threads); + ExtractScoresKernel<<::GetStream(s)>>>( + data.dptr_, scores->dptr_, num_elements, element_width, score_index); + *indices = mshadow::expr::range(0, num_elements); + for (index_t i = 0; i < num_batches; ++i) { + // Sort each batch separately + Tensor scores_batch(scores->dptr_ + i * num_elements_per_batch, + Shape1(topk), + s); + Tensor indices_batch(indices->dptr_ + i * num_elements_per_batch, + Shape1(topk), + s); + Tensor sorted_scores_batch(sorted_scores->dptr_ + i * num_elements_per_batch, + Shape1(topk), + s); + Tensor sorted_indices_batch(sorted_indices->dptr_ + i * num_elements_per_batch, + Shape1(topk), + s); + mxnet::op::SortByKey(scores_batch, indices_batch, false, scratch, + 0, 8 * sizeof(DType), &sorted_scores_batch, + &sorted_indices_batch); + } + CompactData(*sorted_indices, data, out, topk, score_index, s); +} + +} // namespace + +void BoxNMSForwardGPU_notemp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using mshadow::Shape1; + using mshadow::Shape2; + using mshadow::Shape3; + CHECK_NE(req[0], kAddTo) << "BoxNMS does not support kAddTo"; + CHECK_NE(req[0], kWriteInplace) << "BoxNMS does not support in place computation"; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U) << "BoxNMS 
output: [output, temp]"; + const BoxNMSParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_; + int indim = in_shape.ndim(); + int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2); + int num_elem = in_shape[indim - 2]; + int width_elem = in_shape[indim - 1]; + + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor data = inputs[box_nms_enum::kData] + .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); + Tensor out = outputs[box_nms_enum::kOut] + .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); + + // Special case for topk == 0 + if (param.topk == 0) { + if (req[0] != kNullOp && + req[0] != kWriteInplace) { + out = mshadow::expr::F(data); + } + return; + } + + index_t topk = param.topk > 0 ? std::min(param.topk, num_elem) : num_elem; + const auto& workspace = GetWorkspace(num_batch, num_elem, + width_elem, topk, ctx); + + FilterAndPrepareAuxData(data, &out, workspace, param, s); + Tensor scores(workspace.scores, Shape1(num_batch * num_elem), s); + Tensor sorted_scores(workspace.scores + scores.MSize(), + Shape1(num_batch * num_elem), s); + Tensor indices(workspace.indices, Shape1(num_batch * num_elem), s); + Tensor sorted_indices(workspace.indices + indices.MSize(), + Shape1(num_batch * num_elem), s); + Tensor scratch(reinterpret_cast(workspace.scratch), + Shape1(workspace.scratch_space), s); + Tensor buffer(workspace.buffer, + Shape3(num_batch, num_elem, width_elem), s); + Tensor nms_scratch(workspace.nms_scratch, + Shape2(NMS::THRESHOLD / (sizeof(uint32_t) * 8), + topk * num_batch), + s); + indices = mshadow::expr::range(0, num_batch * num_elem); + for (index_t i = 0; i < num_batch; ++i) { + // Sort each batch separately + Tensor scores_batch(scores.dptr_ + i * num_elem, + Shape1(num_elem), + s); + Tensor indices_batch(indices.dptr_ + i * num_elem, + Shape1(num_elem), + s); + Tensor sorted_scores_batch(sorted_scores.dptr_ + i * num_elem, + Shape1(num_elem), + s); + Tensor sorted_indices_batch(sorted_indices.dptr_ + i * num_elem, + Shape1(num_elem), + s); + mxnet::op::SortByKey(scores_batch, indices_batch, false, &scratch, 0, + 8 * sizeof(DType), &sorted_scores_batch, + &sorted_indices_batch); + } + CompactData(sorted_indices, out, &buffer, topk, -1, s); + NMS nms; + nms(&buffer, &nms_scratch, topk, param, s); + CompactNMSResults(buffer, &out, &indices, &scores, &sorted_indices, + &sorted_scores, &scratch, param.score_index, topk, s); + + // convert encoding + if (param.in_format != param.out_format) { + if (box_common_enum::kCenter == param.out_format) { + mxnet::op::mxnet_op::Kernel::Launch(s, num_batch * num_elem, + out.dptr_ + param.coord_start, width_elem); + } else { + mxnet::op::mxnet_op::Kernel::Launch(s, num_batch * num_elem, + out.dptr_ + param.coord_start, width_elem); + } + } + }); +} + +void BoxNMSForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]"; + if (req[1] == kNullOp) { + BoxNMSForwardGPU_notemp(attrs, ctx, inputs, req, outputs); + return; + } + BoxNMSForward(attrs, ctx, inputs, req, outputs); +} + + NNVM_REGISTER_OP(_contrib_box_nms) -.set_attr("FCompute", BoxNMSForward); +.set_attr("FCompute", BoxNMSForwardGPU); NNVM_REGISTER_OP(_backward_contrib_box_nms) 
.set_attr("FCompute", BoxNMSBackward); diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh index b20b466d9c2b..c157de99a4e1 100644 --- a/src/operator/tensor/sort_op-inl.cuh +++ b/src/operator/tensor/sort_op-inl.cuh @@ -95,13 +95,22 @@ SortPairsWorkspaceSize(const size_t num_keys) { template inline typename std::enable_if::value, size_t>::type -SortByKeyWorkspaceSize(const size_t num_keys) { +SortByKeyWorkspaceSize(const size_t num_keys, + const bool keys_in_place, + const bool values_in_place) { #ifdef SORT_WITH_THRUST return 0; #else size_t keys_bytes, values_bytes; WorkspaceSize4KeysAndValues(num_keys, &keys_bytes, &values_bytes); - return keys_bytes + values_bytes + SortPairsWorkspaceSize(num_keys); + size_t ret = SortPairsWorkspaceSize(num_keys); + if (keys_in_place) { + ret += keys_bytes; + } + if (values_in_place) { + ret += values_bytes; + } + return ret; #endif } @@ -111,7 +120,9 @@ inline typename std::enable_if::val SortByKeyImpl(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend, mshadow::Tensor* workspace, - const int begin_bit, const int end_bit) { + const int begin_bit, const int end_bit, + mshadow::Tensor* sorted_keys, + mshadow::Tensor* sorted_values) { CHECK_EQ(keys.CheckContiguous(), true); CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 7000 @@ -135,12 +146,29 @@ SortByKeyImpl(mshadow::Tensor keys, NULL, NULL, NULL, NULL, keys.size(0), begin_bit, end_bit, stream); } + + size_t required_storage = sortpairs_bytes + + (sorted_keys == nullptr ? keys_bytes : 0) + + (sorted_values == nullptr ? values_bytes : 0); + // Check that we have enough storage - CHECK_GE(workspace->size(0), keys_bytes + values_bytes + sortpairs_bytes); - // - KDType* keys_out_ptr = reinterpret_cast(workspace->dptr_); - VDType* values_out_ptr = reinterpret_cast(workspace->dptr_ + keys_bytes); - void* temp_storage = reinterpret_cast(workspace->dptr_ + keys_bytes + values_bytes); + CHECK_GE(workspace->size(0), required_storage) + << "Workspace given to SortByKey is too small: requested " << required_storage << + " B and got " << workspace->size(0) << " B."; + + size_t start_keys = 0; + size_t start_values = start_keys + + (sorted_keys == nullptr ? keys_bytes : 0); + size_t start_scratch = start_values + + (sorted_values == nullptr ? values_bytes : 0); + KDType* keys_out_ptr = sorted_keys == nullptr ? + reinterpret_cast(workspace->dptr_ + start_keys) : + sorted_keys->dptr_; + VDType* values_out_ptr = sorted_values == nullptr ? 
+ reinterpret_cast(workspace->dptr_ + start_values) : + sorted_values->dptr_; + + void* temp_storage = reinterpret_cast(workspace->dptr_ + start_scratch); // Sort if (is_ascend) { cub::DeviceRadixSort::SortPairs(temp_storage, sortpairs_bytes, @@ -152,17 +180,31 @@ SortByKeyImpl(mshadow::Tensor keys, keys.size(0), begin_bit, end_bit, stream); } // Copy result back to [keys, values] - mshadow::Tensor keys_out(keys_out_ptr, mshadow::Shape1(keys.size(0)), - keys.stream_); - mshadow::Tensor values_out(values_out_ptr, mshadow::Shape1(keys.size(0)), - keys.stream_); - mshadow::Copy(keys, keys_out, keys.stream_); - mshadow::Copy(values, values_out, values.stream_); + if (sorted_keys == nullptr) { + mshadow::Tensor keys_out(keys_out_ptr, mshadow::Shape1(keys.size(0)), + keys.stream_); + mshadow::Copy(keys, keys_out, keys.stream_); + } + if (sorted_values == nullptr) { + mshadow::Tensor values_out(values_out_ptr, mshadow::Shape1(keys.size(0)), + keys.stream_); + mshadow::Copy(values, values_out, values.stream_); + } } else { #endif // SORT_WITH_THRUST // No workspace, sort using thrust - thrust::device_ptr key_iter = thrust::device_pointer_cast(keys.dptr_); - thrust::device_ptr value_iter = thrust::device_pointer_cast(values.dptr_); + auto* k = &keys; + auto* v = &values; + if (sorted_keys != nullptr) { + k = sorted_keys; + mshadow::Copy(*sorted_keys, keys, keys.stream_); + } + if (sorted_values != nullptr) { + v = sorted_values; + mshadow::Copy(*sorted_values, values, values.stream_); + } + const auto key_iter = thrust::device_pointer_cast(k->dptr_); + const auto value_iter = thrust::device_pointer_cast(v->dptr_); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), @@ -187,14 +229,25 @@ inline typename std::enable_if<((!std::is_same::va SortByKeyImpl(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend, mshadow::Tensor* workspace, - const int begin_bit, const int end_bit) { + const int begin_bit, const int end_bit, + mshadow::Tensor* sorted_keys, + mshadow::Tensor* sorted_values) { CHECK_EQ(keys.CheckContiguous(), true); CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); - thrust::device_ptr key_iter = thrust::device_pointer_cast(keys.dptr_); - thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast( - reinterpret_cast<__half*>(values.dptr_)); + auto* k = &keys; + auto* v = &values; + if (sorted_keys != nullptr) { + k = sorted_keys; + mshadow::Copy(*sorted_keys, keys, keys.stream_); + } + if (sorted_values != nullptr) { + v = sorted_values; + mshadow::Copy(*sorted_values, values, values.stream_); + } + const auto key_iter = thrust::device_pointer_cast(k->dptr_); + const auto value_iter = thrust::device_pointer_cast(reinterpret_cast<__half*>(v->dptr_)); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), @@ -216,14 +269,25 @@ inline typename std::enable_if<(std::is_same::valu SortByKeyImpl(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend, mshadow::Tensor* workspace, - const int begin_bit, const int end_bit) { + const int begin_bit, const int end_bit, + mshadow::Tensor* sorted_keys, + mshadow::Tensor* sorted_values) { CHECK_EQ(keys.CheckContiguous(), true); CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); - thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast( - reinterpret_cast<__half*>(keys.dptr_)); - thrust::device_ptr value_iter = 
thrust::device_pointer_cast(values.dptr_); + auto* k = &keys; + auto* v = &values; + if (sorted_keys != nullptr) { + k = sorted_keys; + mshadow::Copy(*sorted_keys, keys, keys.stream_); + } + if (sorted_values != nullptr) { + v = sorted_values; + mshadow::Copy(*sorted_values, values, values.stream_); + } + const auto key_iter = thrust::device_pointer_cast(reinterpret_cast<__half*>(k->dptr_)); + const auto value_iter = thrust::device_pointer_cast(v->dptr_); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), @@ -246,15 +310,25 @@ inline typename std::enable_if<(std::is_same::valu SortByKeyImpl(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend, mshadow::Tensor* workspace, - const int begin_bit, const int end_bit) { + const int begin_bit, const int end_bit, + mshadow::Tensor* sorted_keys, + mshadow::Tensor* sorted_values) { CHECK_EQ(keys.CheckContiguous(), true); CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); - thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast( - reinterpret_cast<__half*>(keys.dptr_)); - thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast( - reinterpret_cast<__half*>(values.dptr_)); + auto* k = &keys; + auto* v = &values; + if (sorted_keys != nullptr) { + k = sorted_keys; + mshadow::Copy(*sorted_keys, keys, keys.stream_); + } + if (sorted_values != nullptr) { + v = sorted_values; + mshadow::Copy(*sorted_values, values, values.stream_); + } + const auto key_iter = thrust::device_pointer_cast(reinterpret_cast<__half*>(k->dptr_)); + const auto value_iter = thrust::device_pointer_cast(reinterpret_cast<__half*>(v->dptr_)); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), @@ -273,8 +347,10 @@ SortByKeyImpl(mshadow::Tensor keys, template inline void SortByKey(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend, mshadow::Tensor* workspace, - const int begin_bit, const int end_bit) { - SortByKeyImpl(keys, values, is_ascend, workspace, begin_bit, end_bit); + const int begin_bit, const int end_bit, + mshadow::Tensor* sorted_keys, + mshadow::Tensor* sorted_values) { + SortByKeyImpl(keys, values, is_ascend, workspace, begin_bit, end_bit, sorted_keys, sorted_values); } } // namespace op diff --git a/src/operator/tensor/sort_op.h b/src/operator/tensor/sort_op.h index 6d4675a0775a..11aea9db09ec 100644 --- a/src/operator/tensor/sort_op.h +++ b/src/operator/tensor/sort_op.h @@ -49,11 +49,17 @@ namespace op { * \param keys the keys to sort * \param values the values that sorts w.r.t the key * \param is_ascend whether to sort key in ascending order + * \param begin_bit The beginning bit of the different values in keys. Default 0. + * \param end_bit The ending bit of the different values in keys. Default to 8 * sizeof(dtype of key). + * \param sorted_keys If specified, keys will be sorted out of place. + * \param sorted_values If specified, values will be sorted out of place. 
*/ template inline void SortByKey(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend = true, mshadow::Tensor* workspace = NULL, - const int begin_bit = 0, const int end_bit = sizeof(KDType)*8) { + const int begin_bit = 0, const int end_bit = sizeof(KDType)*8, + mshadow::Tensor* sorted_keys = nullptr, + mshadow::Tensor* sorted_values = nullptr) { CHECK_EQ(keys.CheckContiguous(), true); CHECK_EQ(values.CheckContiguous(), true); CHECK_EQ(keys.size(0), values.size(0)) @@ -62,6 +68,12 @@ inline void SortByKey(mshadow::Tensor keys, mshadow::Tensor idx(keys.size(0)); std::vector keys_vec(keys.size(0)); std::vector values_vec(values.size(0)); + if (sorted_keys == nullptr) { + sorted_keys = &keys; + } + if (sorted_values == nullptr) { + sorted_values = &values; + } for (index_t i = 0; i < keys.size(0); i++) { idx[i] = i; keys_vec[i] = keys[i]; @@ -77,18 +89,28 @@ inline void SortByKey(mshadow::Tensor keys, mshadow::Tensor keys_vec[i2]; }); } for (index_t i = 0; i < values.size(0); i++) { - keys[i] = keys_vec[idx[i]]; - values[i] = values_vec[idx[i]]; + (*sorted_keys)[i] = keys_vec[idx[i]]; + (*sorted_values)[i] = values_vec[idx[i]]; } } /*! * \brief CPU/GPU: Return the amount of temporary storage in bytes required for SortByKey * \param num_keys number of keys to sort + * \param keys_in_place Whether the sorting of keys will happen in place. + * Default true. If set to false, subsequent + * call to SortByKey needs to specify the + * sorted_keys parameter. + * \param values_in_place Whether the sorting of values will happen in place. + * Default true. If set to false, subsequent + * call to SortByKey needs to specify the + * sorted_values parameter. */ template inline typename std::enable_if::value, size_t>::type -SortByKeyWorkspaceSize(const size_t num_keys) { +SortByKeyWorkspaceSize(const size_t num_keys, + const bool keys_in_place = true, + const bool values_in_place = true) { return 0; } @@ -97,18 +119,34 @@ SortByKeyWorkspaceSize(const size_t num_keys) { * \param keys the keys to sort * \param values the values that sorts w.r.t the key * \param is_ascend whether to sort key in ascending order + * \param begin_bit The beginning bit of the different values in keys. Default 0. + * \param end_bit The ending bit of the different values in keys. Default to 8 * sizeof(dtype of key). + * \param sorted_keys If specified, keys will be sorted out of place. + * \param sorted_values If specified, values will be sorted out of place. */ template inline void SortByKey(mshadow::Tensor keys, mshadow::Tensor values, bool is_ascend = true, mshadow::Tensor* workspace = NULL, - const int begin_bit = 0, const int end_bit = sizeof(KDType)*8); + const int begin_bit = 0, const int end_bit = sizeof(KDType)*8, + mshadow::Tensor* sorted_keys = nullptr, + mshadow::Tensor* sorted_values = nullptr); /*! * \brief CPU/GPU: Return the amount of temporary storage in bytes required for SortByKey * \param num_keys number of keys to sort + * \param keys_in_place Whether the sorting of keys will happen in place. + * Default true. If set to false, subsequent + * call to SortByKey needs to specify the + * sorted_keys parameter. + * \param values_in_place Whether the sorting of values will happen in place. + * Default true. If set to false, subsequent + * call to SortByKey needs to specify the + * sorted_values parameter. 
  */
 template <typename KDType, typename VDType, typename xpu>
 inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
-SortByKeyWorkspaceSize(const size_t num_keys);
+SortByKeyWorkspaceSize(const size_t num_keys,
+                       const bool keys_in_place = true,
+                       const bool values_in_place = true);
 }  // namespace op
 }  // namespace mxnet
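
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the
out-of-place SortByKey interface added above is meant to be called, mirroring the call
sites in bounding_box.cu. The explicit template arguments of SortByKeyWorkspaceSize, the
char-typed scratch tensor, the include path, and the pointer/stream parameters are
assumptions made for this example rather than part of the change.

    // Sketch only: sort scores (keys) and an index permutation (values) out of place,
    // leaving the unsorted inputs intact. All pointers are assumed to reference
    // preallocated GPU memory and `s` an active GPU stream.
    #include <mshadow/tensor.h>
    #include "sort_op.h"  // assumed path, relative to src/operator/tensor/

    using mshadow::Tensor;
    using mshadow::Shape1;
    using mshadow::gpu;

    inline void ArgsortScoresDescending(float* scores_ptr, mshadow::index_t* indices_ptr,
                                        float* sorted_scores_ptr,
                                        mshadow::index_t* sorted_indices_ptr,
                                        char* scratch_ptr, const mshadow::index_t num_elem,
                                        mshadow::Stream<gpu>* s) {
      // With both in_place flags set to false, the workspace only needs the cub scratch
      // space, not the extra key/value staging buffers (see SortByKeyWorkspaceSize above).
      const size_t scratch_bytes =
          mxnet::op::SortByKeyWorkspaceSize<float, mshadow::index_t, gpu>(
              num_elem, /*keys_in_place=*/false, /*values_in_place=*/false);

      Tensor<gpu, 1, float> scores(scores_ptr, Shape1(num_elem), s);
      Tensor<gpu, 1, mshadow::index_t> indices(indices_ptr, Shape1(num_elem), s);
      Tensor<gpu, 1, float> sorted_scores(sorted_scores_ptr, Shape1(num_elem), s);
      Tensor<gpu, 1, mshadow::index_t> sorted_indices(sorted_indices_ptr, Shape1(num_elem), s);
      Tensor<gpu, 1, char> scratch(scratch_ptr, Shape1(scratch_bytes), s);

      // Descending sort over the full key width; scores/indices are left untouched and
      // the sorted copies are written to sorted_scores/sorted_indices.
      mxnet::op::SortByKey(scores, indices, /*is_ascend=*/false, &scratch,
                           0, 8 * sizeof(float),
                           &sorted_scores, &sorted_indices);
    }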