fix potential floating number overflow, enable float16 (apache#12118)
* fix potential floating number overflow, enable float16

* fix cuda impl

* fix cuda impl (continued)

* fix template substitution for windows

* half_t: instantiate operator+ fix

* remove ambiguous operator+ for mshadow half_t

* fix (continued)

* use int32_t as indices

* use overload

* try remove ambiguous function overloading

* thrust version limit

* change sizeof cast from floor to ceil when allocating buffers

* cleaner

* fix alignment of pointers
zhreshold authored and anirudh2290 committed Sep 19, 2018
1 parent 95b0f94 commit 74f45fb
Showing 4 changed files with 184 additions and 66 deletions.
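The commit message above bundles several related steps. The underlying problem is that sorted indices were stored in the tensor's DType, and float16 can only represent integers exactly up to 2048 (float32 up to 2^24), so index values overflow or lose precision on large inputs. The patch stores indices as int32_t inside the same DType-typed workspace and rounds the int32 region up to a whole number of DType elements. The following sketch illustrates that ceil-rounded split; the names WorkspaceSplit and SplitWorkspace are illustrative, not part of the patch.

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of the workspace layout used after this change: one DType-typed
// buffer holds an int32 index region first, then the DType score/area region.
// The offset is rounded up (ceil) to whole DType elements so the DType
// pointers that follow stay element-aligned. Hypothetical helper, not MXNet code.
template <typename DType>
struct WorkspaceSplit {
  int32_t* indices;  // int32 index region at the start of the workspace
  DType*   values;   // DType region starting at the next whole DType element
};

template <typename DType>
WorkspaceSplit<DType> SplitWorkspace(DType* workspace, std::size_t num_int32) {
  // ceil(num_int32 * sizeof(int32_t) / sizeof(DType)); equivalent to the
  // patch's (bytes - 1) / sizeof(DType) + 1 form for non-zero sizes.
  std::size_t int32_offset =
      (num_int32 * sizeof(int32_t) + sizeof(DType) - 1) / sizeof(DType);
  return {reinterpret_cast<int32_t*>(workspace), workspace + int32_offset};
}
```

Rounding down instead (for example, three int32 entries carved out of a double-typed workspace occupy 1.5 doubles) would under-allocate the index region and let the DType buffers overlap it, which is what the "floor to ceil" bullet refers to.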
4 changes: 2 additions & 2 deletions src/operator/contrib/bounding_box-inl.cuh
@@ -45,9 +45,9 @@ struct valid_score {

template<typename DType>
int FilterScores(mshadow::Tensor<gpu, 1, DType> out_scores,
mshadow::Tensor<gpu, 1, DType> out_sorted_index,
mshadow::Tensor<gpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<gpu, 1, DType> scores,
mshadow::Tensor<gpu, 1, DType> sorted_index,
mshadow::Tensor<gpu, 1, int32_t> sorted_index,
float valid_thresh) {
valid_score<DType> pred(static_cast<DType>(valid_thresh));
DType * end_scores = thrust::copy_if(thrust::device, scores.dptr_, scores.dptr_ + scores.MSize(),
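For context on the truncated hunk above: the GPU FilterScores compacts valid scores with thrust::copy_if and, through the stencil overload, compacts the matching indices (now int32_t) in lock-step. A minimal stand-alone sketch of that pattern, assuming it is compiled as CUDA; the names are illustrative and this is not the exact MXNet kernel:

```cpp
#include <cstdint>
#include <thrust/copy.h>
#include <thrust/execution_policy.h>

// Predicate: keep entries whose score exceeds the threshold.
template <typename DType>
struct above_thresh {
  DType thresh;
  __host__ __device__ bool operator()(const DType& s) const { return s > thresh; }
};

// Compact scores and their int32 indices in lock-step and return how many
// entries survive. Sketch only; the real FilterScores operates on mshadow
// tensors that were allocated by the caller.
template <typename DType>
int FilterScoresSketch(DType* scores, int32_t* indices, int n,
                       DType* out_scores, int32_t* out_indices, DType thresh) {
  above_thresh<DType> pred{thresh};
  DType* out_end = thrust::copy_if(thrust::device, scores, scores + n,
                                   out_scores, pred);
  // Stencil overload: copy indices[i] wherever scores[i] passes the predicate.
  thrust::copy_if(thrust::device, indices, indices + n, scores, out_indices, pred);
  return static_cast<int>(out_end - out_scores);
}
```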
86 changes: 47 additions & 39 deletions src/operator/contrib/bounding_box-inl.h
@@ -150,9 +150,9 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {

template<typename DType>
int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
mshadow::Tensor<cpu, 1, DType> out_sorted_index,
mshadow::Tensor<cpu, 1, int32_t> out_sorted_index,
mshadow::Tensor<cpu, 1, DType> scores,
mshadow::Tensor<cpu, 1, DType> sorted_index,
mshadow::Tensor<cpu, 1, int32_t> sorted_index,
float valid_thresh) {
index_t j = 0;
for (index_t i = 0; i < scores.size(0); i++) {
@@ -230,7 +230,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {

/*!
* \brief compute areas specialized for nms to reduce computation
*
*
* \param i the launched thread index (total thread num_batch * topk)
* \param out 1d array for areas (size num_batch * num_elem)
* \param in 1st coordinate of 1st box (buffer + coord_start)
@@ -243,7 +243,7 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
struct compute_area {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
const DType *indices, const DType *batch_start,
const int32_t *indices, const int32_t *batch_start,
int topk, int num_elem, int stride, int encode) {
int b = i / topk;
int k = i % topk;
@@ -302,7 +302,7 @@ MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
*/
struct nms_impl {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *index, const DType *batch_start,
MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
const DType *input, const DType *areas,
int k, int ref, int num,
int stride, int offset_box, int offset_id,
@@ -326,8 +326,7 @@ struct nms_impl {
intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode);
int ref_area_offset = static_cast<int>(index[ref]);
int pos_area_offset = static_cast<int>(index[pos]);
DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] -
intersect);
DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect);
if (iou > thresh) {
index[pos] = -1;
}
@@ -336,7 +335,7 @@ struct nms_impl {

/*!
* \brief Assign output of nms by indexing input
*
*
* \param i the launched thread index (total num_batch)
* \param out output array [cls, conf, b0, b1, b2, b3]
* \param record book keeping the selected index for backward
@@ -349,7 +348,7 @@ struct nms_impl {
struct nms_assign {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input,
const DType *index, const DType *batch_start,
const int32_t *index, const int32_t *batch_start,
int k, int num, int stride) {
int count = 0;
for (int j = 0; j < k; ++j) {
@@ -404,7 +403,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2);
int num_elem = in_shape[indim - 2];
int width_elem = in_shape[indim - 1];
MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 3, DType> data = inputs[box_nms_enum::kData]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
Tensor<xpu, 3, DType> out = outputs[box_nms_enum::kOut]
@@ -415,25 +414,33 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
// prepare workspace
Shape<1> sort_index_shape = Shape1(num_batch * num_elem);
Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem);
index_t workspace_size = 4 * sort_index_shape.Size();
Shape<1> batch_start_shape = Shape1(num_batch + 1);
workspace_size += batch_start_shape.Size();

// index
index_t int32_size = sort_index_shape.Size() * 3 + batch_start_shape.Size();
index_t dtype_size = sort_index_shape.Size() * 2;
if (req[0] == kWriteInplace) {
workspace_size += buffer_shape.Size();
dtype_size += buffer_shape.Size();
}
// round up (ceil) so the int32 index region occupies a whole number of DType elements
index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1;
index_t workspace_size = int32_offset + dtype_size;
Tensor<xpu, 1, DType> workspace = ctx.requested[box_nms_enum::kTempSpace]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
Tensor<xpu, 1, DType> sorted_index(workspace.dptr_, sort_index_shape, s);
Tensor<xpu, 1, DType> scores(sorted_index.dptr_ + sorted_index.MSize(),
Tensor<xpu, 1, int32_t> sorted_index(
reinterpret_cast<int32_t*>(workspace.dptr_), sort_index_shape, s);
Tensor<xpu, 1, int32_t> all_sorted_index(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);
Tensor<xpu, 1, DType> batch_id(scores.dptr_ + scores.MSize(), sort_index_shape,
s);
Tensor<xpu, 1, DType> areas(batch_id.dptr_ + batch_id.MSize(), sort_index_shape, s);
Tensor<xpu, 1, DType> batch_start(areas.dptr_ + areas.MSize(), batch_start_shape, s);
Tensor<xpu, 1, int32_t> batch_id(
all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s);
Tensor<xpu, 1, int32_t> batch_start(batch_id.dptr_ + batch_id.MSize(), batch_start_shape, s);
Tensor<xpu, 1, DType> scores(workspace.dptr_ + int32_offset,
sort_index_shape, s);
Tensor<xpu, 1, DType> areas(scores.dptr_ + scores.MSize(), sort_index_shape, s);
Tensor<xpu, 3, DType> buffer = data;
if (req[0] == kWriteInplace) {
// make copy
buffer = Tensor<xpu, 3, DType>(batch_start.dptr_ + batch_start.MSize(), buffer_shape, s);
buffer = Tensor<xpu, 3, DType>(areas.dptr_ + areas.MSize(), buffer_shape, s);
buffer = F<mshadow_op::identity>(data);
}

@@ -451,10 +458,10 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
}

// use batch_id and areas as temporary storage
Tensor<xpu, 1, DType> all_scores = batch_id;
Tensor<xpu, 1, DType> all_sorted_index = areas;
Tensor<xpu, 1, DType> all_scores = areas;
// Tensor<xpu, 1, DType> all_sorted_index = areas;
all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
all_sorted_index = range<DType>(0, num_batch * num_elem);
all_sorted_index = range<int32_t>(0, num_batch * num_elem);

// filter scores but keep original sorted_index value
// move valid score and index to the front, return valid size
@@ -474,19 +481,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,
// only sort the valid scores and batch_id
Shape<1> valid_score_shape = Shape1(num_valid);
Tensor<xpu, 1, DType> valid_scores(scores.dptr_, valid_score_shape, s);
Tensor<xpu, 1, DType> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
Tensor<xpu, 1, DType> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
Tensor<xpu, 1, int32_t> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
Tensor<xpu, 1, int32_t> valid_batch_id(batch_id.dptr_, valid_score_shape, s);

// sort index by batch_id then score (stable sort)
mxnet::op::SortByKey(valid_scores, valid_sorted_index, false);
valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
mxnet::op::SortByKey(valid_batch_id, valid_sorted_index, true);

// calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
valid_batch_id = F<mshadow_op::floor>(valid_sorted_index / ScalarExp<DType>(num_elem));
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
F<mshadow_op::less_than>(valid_batch_id, ScalarExp<DType>(b)), 0);
F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
}

// pre-compute areas of candidates
@@ -721,11 +728,11 @@ inline bool MatchingShape(const nnvm::NodeAttrs& attrs,
struct bipartite_matching {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType *row_marker, DType *col_marker,
const DType *scores, const DType *sorted_index,
const DType *scores, const int32_t *sorted_index,
int num_batch, int num_row, int num_col,
float threshold, bool is_ascend, int topk) {
int stride = num_row * num_col;
const DType *index = sorted_index + i * stride;
const int32_t *index = sorted_index + i * stride;
const DType *score = scores + i * stride;
DType *rmarker = row_marker + i * num_row;
DType *cmarker = col_marker + i * num_col;
@@ -769,31 +776,32 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
int row = dshape[dshape.ndim() - 2];
int col = dshape[dshape.ndim() - 1];
int batch_size = dshape.Size() / row / col;
MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> scores = inputs[0]
.get_with_shape<xpu, 1, DType>(Shape1(dshape.Size()), s);
Tensor<xpu, 2, DType> row_marker = outputs[0]
.get_with_shape<xpu, 2, DType>(Shape2(batch_size, row), s);
Tensor<xpu, 2, DType> col_marker = outputs[1]
.get_with_shape<xpu, 2, DType>(Shape2(batch_size, col), s);
Shape<1> sort_index_shape = Shape1(dshape.Size());
index_t workspace_size = sort_index_shape.Size() * 3;
index_t workspace_size = sort_index_shape.Size();
workspace_size += ((sort_index_shape.Size() * sizeof(int32_t) - 1) / sizeof(DType)) * 2;
Tensor<xpu, 1, DType> workspace = ctx.requested[0]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
Tensor<xpu, 1, DType> sorted_index(workspace.dptr_,
sort_index_shape, s);
Tensor<xpu, 1, DType> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
Tensor<xpu, 1, DType> scores_copy(workspace.dptr_,
sort_index_shape, s);
Tensor<xpu, 1, DType> scores_copy(batch_id.dptr_ + batch_id.MSize(),
Tensor<xpu, 1, int32_t> sorted_index(reinterpret_cast<int32_t*>(
scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s);
Tensor<xpu, 1, int32_t> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);

// sort according to score
scores_copy = F<mshadow_op::identity>(scores);
sorted_index = range<DType>(0, dshape.Size());
sorted_index = range<int32_t>(0, dshape.Size());
mxnet::op::SortByKey(scores_copy, sorted_index, param.is_ascend);
batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, scores_copy, true);
batch_id = F<mshadow_op::floor>(sorted_index / ScalarExp<DType>(row * col));
batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, sorted_index, true);

// bipartite matching, parallelization is limited to batch_size
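The header diff above swaps floor-divided floating-point batch ids for plain integer division on int32 indices, while keeping the same two-pass ordering trick: sort all indices by score, recompute each index's batch id, then stable-sort by batch id so every batch's indices stay ordered by score. A small CPU-only sketch of that idea using only the standard library (illustrative, not MXNet code):

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Return indices grouped by batch and sorted by descending score within each
// batch, mirroring the SortByKey / batch_id / SortByKey sequence in
// BoxNMSForward. num_elem is the number of candidates per batch.
std::vector<int32_t> SortWithinBatches(const std::vector<float>& scores,
                                       int32_t num_elem) {
  std::vector<int32_t> idx(scores.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Pass 1: order all indices by descending score.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int32_t a, int32_t b) { return scores[a] > scores[b]; });
  // Pass 2: stable sort by batch id (index / num_elem); stability preserves
  // the per-batch score order established in pass 1.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int32_t a, int32_t b) {
                     return a / num_elem < b / num_elem;
                   });
  return idx;
}
```

Doing the batch-id computation in int32 rather than DType is what removes the floor() calls and the float16 precision hazard.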
135 changes: 122 additions & 13 deletions src/operator/tensor/sort_op-inl.cuh
@@ -24,6 +24,7 @@
*/
#ifndef MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
#define MXNET_OPERATOR_TENSOR_SORT_OP_INL_CUH_
#include <type_traits>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#if defined(_MSC_VER) && __CUDACC_VER_MAJOR__ == 8 && __CUDACC_VER_BUILD__ != 44
@@ -40,6 +41,29 @@

namespace mxnet {
namespace op {
namespace cuda {
template<typename T>
struct less_half
{
typedef T first_argument_type;
typedef T second_argument_type;
typedef bool result_type;
__host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
}
};

template<typename T>
struct greater_half
{
typedef T first_argument_type;
typedef T second_argument_type;
typedef bool result_type;
__host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
}
};
}

template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
@@ -57,9 +81,12 @@ SortByKeyWorkspaceSize(const size_t num_keys) {
}

template<typename KDType, typename VDType>
inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
inline typename std::enable_if<!(std::is_same<KDType,mshadow::half::half_t>::value ||
std::is_same<VDType,mshadow::half::half_t>::value), void>::type
SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 7000
@@ -128,18 +155,100 @@ inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu,
#endif
}

template<typename DType>
inline void SortByKey(mshadow::Tensor<gpu, 1, mshadow::half::half_t> keys,
mshadow::Tensor<gpu, 1, DType> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
LOG(FATAL) << "SortByKey for half_t is not implemented!";
template<typename KDType, typename VDType>
inline typename std::enable_if<((!std::is_same<KDType,mshadow::half::half_t>::value) &&
std::is_same<VDType,mshadow::half::half_t>::value), void>::type
SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::less<KDType>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter.get(), key_iter.get() + (keys.size(0)), value_iter.get(), thrust::greater<KDType>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
LOG(FATAL) << "SortByKey with fp16 values is only supported for CUDA version >= 9.0";
#endif
}

template<typename KDType, typename VDType>
inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
(!std::is_same<VDType,mshadow::half::half_t>::value)), void>::type
SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(keys.dptr_));
thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
LOG(FATAL) << "SortByKey with fp16 keys is only supported for CUDA version >= 9.0";
#endif
}

// use thrust sorting when keys or values are half_t
template<typename KDType, typename VDType>
inline typename std::enable_if<(std::is_same<KDType,mshadow::half::half_t>::value &&
std::is_same<VDType,mshadow::half::half_t>::value), void>::type
SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
mshadow::Tensor<gpu, 1, VDType> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
CHECK_EQ(keys.CheckContiguous(), true);
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(keys.dptr_));
thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
LOG(FATAL) << "SortByKey with fp16 keys and values is only supported for CUDA version >= 9.0";
#endif
}

template<typename DType>
inline void SortByKey(mshadow::Tensor<gpu, 1, DType> keys,
mshadow::Tensor<gpu, 1, mshadow::half::half_t> values, bool is_ascend,
mshadow::Tensor<gpu, 1, char>* workspace, const int begin_bit, const int end_bit) {
LOG(FATAL) << "SortByKey for half_t is not implemented!";
template<typename KDType, typename VDType>
inline void SortByKey(mshadow::Tensor<gpu, 1, KDType> keys, mshadow::Tensor<gpu, 1, VDType> values,
bool is_ascend, mshadow::Tensor<gpu, 1, char>* workspace,
const int begin_bit, const int end_bit) {
SortByKeyImpl(keys, values, is_ascend, workspace, begin_bit, end_bit);
}

} // namespace op
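The new sort_op code selects one of four SortByKeyImpl overloads with std::enable_if, depending on whether the key type, the value type, both, or neither is mshadow half_t, and the public SortByKey simply forwards to whichever overload is viable. A compact sketch of that dispatch pattern with stand-in types (HalfLike plays the role of mshadow::half::half_t; these are not the MXNet signatures):

```cpp
#include <iostream>
#include <type_traits>

struct HalfLike {};  // stand-in for mshadow::half::half_t

// Selected when the key type is half-like: route through thrust with a
// custom comparator (as less_half / greater_half do in the patch).
template <typename K>
typename std::enable_if<std::is_same<K, HalfLike>::value, void>::type
SortImpl(const K&) { std::cout << "half path (thrust + custom comparator)\n"; }

// Selected for every other key type: keep the existing radix-sort path.
template <typename K>
typename std::enable_if<!std::is_same<K, HalfLike>::value, void>::type
SortImpl(const K&) { std::cout << "generic path (radix sort)\n"; }

// Single public entry point, analogous to the forwarding SortByKey above.
template <typename K>
void Sort(const K& key) { SortImpl(key); }

int main() {
  Sort(HalfLike{});  // picks the half path
  Sort(3.0f);        // picks the generic path
}
```

Splitting the conditions this way keeps the fast radix-sort path untouched for ordinary key types while routing half_t keys and values through thrust::stable_sort_by_key with the less_half / greater_half comparators defined at the top of the file.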