
Add fp16 support for topk (#15560)
* fp16 for topk

* indentation and mem alloc

* fix cuda code

* fix cuda sort routine for half precision

* remove ambiguous half type
anirudhacharya authored and sxjscience committed Aug 23, 2019
1 parent 1eb1925 commit cba7c4e
Showing 3 changed files with 92 additions and 62 deletions.
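The heart of this change is the dispatch-macro swap in ordering_op-inl.h below: the forward and backward paths previously used MXNET_NO_FLOAT16_TYPE_SWITCH, which refuses to instantiate kernels for float16, and now use MSHADOW_TYPE_SWITCH, which includes a float16 branch. A minimal sketch of how such a dtype-switch macro dispatches a templated body follows; this is an illustration, not mshadow's actual macro, and half_stub is a hypothetical stand-in for mshadow::half::half_t.

// sketch_type_switch.cc: illustrative only, not mshadow's macro.
#include <cstdio>
#include <cstdint>

enum TypeFlag { kFloat32, kFloat64, kFloat16 };
struct half_stub { uint16_t bits; };  // hypothetical stand-in for mshadow::half::half_t

// Expands the same body once per dtype; including (or omitting) the kFloat16
// case is exactly what separates fp16-enabled dispatch from fp16-rejecting
// dispatch.
#define SKETCH_TYPE_SWITCH(flag, DType, ...)                        \
  switch (flag) {                                                   \
    case kFloat32: { typedef float     DType; __VA_ARGS__ } break;  \
    case kFloat64: { typedef double    DType; __VA_ARGS__ } break;  \
    case kFloat16: { typedef half_stub DType; __VA_ARGS__ } break;  \
  }

int main() {
  int flag = kFloat16;
  SKETCH_TYPE_SWITCH(flag, DType, {
    printf("dispatched with sizeof(DType) == %zu\n", sizeof(DType));  // 2 for fp16
  });
  return 0;
}

With the dispatch widened, the rest of the commit makes the GPU sort path accept half-precision keys and values.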
22 changes: 7 additions & 15 deletions src/operator/tensor/ordering_op-inl.h
@@ -377,7 +377,6 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
* \param resource temporary resource handler
* \param src the Source blob
* \param ret the destination blobs
- * \param k the K elements to keep
* \param param the topk parameters
* \tparam xpu the device type.
* \tparam DType type of the output value/mask.
@@ -563,13 +562,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
- MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
})
});
} else {
- MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKImpl<xpu, DType, index_t>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
});
}
@@ -695,13 +694,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnBoth) {
- MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
});
});
} else if (param.ret_typ == topk_enum::kReturnValue) {
- MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+ MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKBackwardImpl<xpu, DType, index_t>(ctx, inputs, req, outputs, param);
});
} else {
@@ -732,7 +731,6 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
- int data_type = -1;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
CHECK_EQ(in_size, 1);
@@ -756,15 +754,9 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
CHECK(type_assign(&(*out_attrs)[0], param.dtype))
<< "Failed to set the type of ret_indices.";
} else {
- CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
-     << (*in_attrs)[0];
- CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
-     << (*out_attrs)[0];
- CHECK(type_assign(&(*in_attrs)[0], data_type)) << "Incompatible dtype of input, in_attrs[0]="
-     << (*in_attrs)[0];
- CHECK(type_assign(&(*out_attrs)[0], data_type)) << "Incompatible dtype of output, out_attrs[0]="
-     << (*out_attrs)[0];
- if (data_type == -1) return false;
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+ return out_attrs->at(0) != -1;
}
return true;
}
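The TopKType change above drops the scratch data_type variable and four manual type_assign checks in favor of two TYPE_ASSIGN_CHECK calls that propagate the dtype in both directions between in_attrs[0] and out_attrs[0], returning false only while the type is still unknown. A minimal sketch of the assumed semantics; type_assign below is modelled on MXNet's helper, where -1 means "unknown", and is illustrative rather than the library source.

// sketch_type_assign.cc: assumed semantics of bidirectional type inference.
#include <cassert>

// Succeeds if the target is unknown (fills it in) or already agrees;
// fails only on a genuine conflict between two known dtypes.
bool type_assign(int *target, int source) {
  if (source == -1) return true;                         // nothing to propagate
  if (*target == -1) { *target = source; return true; }  // fill in the unknown side
  return *target == source;                              // both known: must match
}

int main() {
  int in_dtype = 2 /* say, float16 */, out_dtype = -1;
  assert(type_assign(&out_dtype, in_dtype));  // forward: output inherits fp16
  assert(type_assign(&in_dtype, out_dtype));  // backward: already consistent
  assert(out_dtype == 2 && in_dtype == 2);
  return 0;
}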
68 changes: 46 additions & 22 deletions src/operator/tensor/sort_op-inl.cuh
@@ -60,24 +60,48 @@ struct greater_half
typedef T second_argument_type;
typedef bool result_type;
__host__ __device__ bool operator()(const T &lhs, const T &rhs) const {
- return static_cast<mshadow::half::half_t>(lhs) < static_cast<mshadow::half::half_t>(rhs);
+ return static_cast<mshadow::half::half_t>(lhs) > static_cast<mshadow::half::half_t>(rhs);
}
};
}

+ #ifndef SORT_WITH_THRUST
+ template <typename KDType, typename VDType>
+ inline void WorkspaceSize4KeysAndValues(
+   const size_t num_keys, size_t *pKeys_bytes, size_t *pValues_bytes) {
+   const size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
+   *pKeys_bytes = PadBytes(num_keys * sizeof(KDType), alignment);
+   *pValues_bytes = PadBytes(num_keys * sizeof(VDType), alignment);
+ }
+
+ template <typename KDType, typename VDType>
+ inline typename std::enable_if<!std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+ SortPairsWorkspaceSize(const size_t num_keys) {
+   size_t sortpairs_bytes = 0;
+   cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
+     NULL, NULL, NULL, NULL, num_keys);
+   return sortpairs_bytes;
+ }
+
+ template <typename KDType, typename VDType>
+ inline typename std::enable_if<std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
+ SortPairsWorkspaceSize(const size_t num_keys) {
+   size_t sortpairs_bytes = 0;
+   cub::DeviceRadixSort::SortPairs<__half, VDType>(NULL, sortpairs_bytes,
+     NULL, NULL, NULL, NULL, num_keys);
+   return sortpairs_bytes;
+ }
+ #endif

template <typename KDType, typename VDType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, size_t>::type
SortByKeyWorkspaceSize(const size_t num_keys) {
#ifdef SORT_WITH_THRUST
return 0;
#else
- size_t sortpairs_bytes = 0;
- cub::DeviceRadixSort::SortPairs<KDType, VDType>(NULL, sortpairs_bytes,
-   NULL, NULL, NULL, NULL, num_keys);
- size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
- size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment);
- size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment);
- return (keys_bytes + values_bytes + sortpairs_bytes);
+ size_t keys_bytes, values_bytes;
+ WorkspaceSize4KeysAndValues<KDType, VDType>(num_keys, &keys_bytes, &values_bytes);
+ return keys_bytes + values_bytes + SortPairsWorkspaceSize<KDType, VDType>(num_keys);
#endif
}
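The refactor above splits workspace sizing into two helpers: WorkspaceSize4KeysAndValues pads the key and value buffers to a common alignment, and SortPairsWorkspaceSize queries cub for its scratch size, with the std::enable_if overload substituting CUDA's native __half for mshadow::half::half_t because cub::DeviceRadixSort has no specialization for the mshadow type. A runnable sketch of the sizing arithmetic follows; pad_bytes is a hypothetical stand-in for MXNet's PadBytes, assumed to round up to a multiple of the alignment.

// sketch_workspace.cc: how one flat allocation is carved into
// padded keys + padded values + sort scratch.
#include <algorithm>
#include <cstdint>
#include <cstdio>

size_t pad_bytes(size_t nbytes, size_t alignment) {
  return ((nbytes + alignment - 1) / alignment) * alignment;  // round up
}

template <typename KDType, typename VDType>
size_t workspace_bytes(size_t num_keys, size_t sortpairs_bytes) {
  const size_t alignment = std::max(sizeof(KDType), sizeof(VDType));
  const size_t keys_bytes   = pad_bytes(num_keys * sizeof(KDType), alignment);
  const size_t values_bytes = pad_bytes(num_keys * sizeof(VDType), alignment);
  return keys_bytes + values_bytes + sortpairs_bytes;  // one contiguous buffer
}

int main() {
  // 1000 two-byte (fp16-sized) keys with int32 values, plus whatever scratch
  // cub::DeviceRadixSort::SortPairs reported (4096 here as a placeholder).
  printf("%zu bytes\n", workspace_bytes<uint16_t, int>(1000, 4096));
  return 0;
}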

@@ -142,11 +166,11 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + keys.size(0), value_iter, thrust::less<KDType>());
+ key_iter, key_iter + keys.size(0), value_iter.get(), thrust::less<KDType>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + keys.size(0), value_iter, thrust::greater<KDType>());
+ key_iter, key_iter + keys.size(0), value_iter.get(), thrust::greater<KDType>());
}
#ifndef SORT_WITH_THRUST
}
@@ -169,8 +193,8 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
- thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-   reinterpret_cast<half*>(values.dptr_));
+ thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+   reinterpret_cast<__half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
@@ -197,17 +221,17 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
- thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-   reinterpret_cast<half*>(keys.dptr_));
+ thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+   reinterpret_cast<__half*>(keys.dptr_));
thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+ key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+ key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
@@ -227,18 +251,18 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
- thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
-   reinterpret_cast<half*>(keys.dptr_));
- thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
-   reinterpret_cast<half*>(values.dptr_));
+ thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
+   reinterpret_cast<__half*>(keys.dptr_));
+ thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
+   reinterpret_cast<__half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + (keys.size(0)), value_iter, cuda::less_half<half>());
+ key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
- key_iter, key_iter + (keys.size(0)), value_iter, cuda::greater_half<half>());
+ key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
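The hunks above make the thrust fallback work in half precision by reinterpreting mshadow::half::half_t buffers as CUDA's native __half and handing thrust the cuda::less_half / cuda::greater_half comparators, since __half lacks reliable built-in ordering operators across toolkits; the first hunk in this file also fixes greater_half, which mistakenly compared with < instead of >. A minimal standalone sketch of the same pattern, assuming CUDA 9+ with thrust; greater_half here converts through float rather than through half_t, but the shape of the comparator is the same.

// sketch_half_sort.cu: descending stable sort of fp16 keys via a custom
// comparator, mirroring the cuda::greater_half approach.
#include <cuda_fp16.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <cstdio>

struct greater_half {
  __host__ __device__ bool operator()(const __half &lhs, const __half &rhs) const {
    return __half2float(lhs) > __half2float(rhs);  // compare in a wider type
  }
};

int main() {
  thrust::device_vector<__half> keys(4);
  thrust::device_vector<int> vals(4);
  const float init[4] = {1.0f, 3.0f, 2.0f, 0.5f};
  for (int i = 0; i < 4; ++i) { keys[i] = __float2half(init[i]); vals[i] = i; }
  // stable sort of values by fp16 keys, largest key first
  thrust::stable_sort_by_key(keys.begin(), keys.end(), vals.begin(), greater_half());
  for (int i = 0; i < 4; ++i) printf("%d ", static_cast<int>(vals[i]));  // 1 2 0 3
  printf("\n");
  return 0;
}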
64 changes: 39 additions & 25 deletions tests/python/unittest/test_operator.py
@@ -4304,15 +4304,6 @@ def get_large_matrix():

large_matrix_npy = get_large_matrix()

-    for axis in [1, 3, None]:
-        K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
-        for k in K:
-            for is_ascend in [True, False]:
-                b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
-                out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
-                check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
-                check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])

    for axis in [1, 3, None]:
        for is_ascend in [True, False]:
            b = mx.sym.sort(a, axis=axis, is_ascend=is_ascend)
@@ -4332,22 +4323,6 @@ def get_large_matrix():
ret_typ="indices", k=5,
is_ascend=is_ascend)])

-    b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
-    check_symbolic_backward(sym=b, location={'a': a_npy},
-                            out_grads=[np.random.normal(size=(5, 5, 5, 3))],
-                            expected=[np.zeros((5, 5, 5, 5))])
-    check_symbolic_forward(b, location={'a': a_npy},
-                           expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
-                                             is_ascend=False)])
-
-    b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
-    check_symbolic_backward(sym=b, location={'a': a_npy},
-                            out_grads=[np.random.normal(size=(5, 5, 5, 5))],
-                            expected=[np.zeros((5, 5, 5, 5))])
-    check_symbolic_forward(b, location={'a': a_npy},
-                           expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
-                                             is_ascend=True)])

    b = mx.sym.argsort(a, axis=1, is_ascend=False)
    check_symbolic_backward(sym=b, location={'a': a_npy},
                            out_grads=[np.random.normal(size=(5, 5, 5, 5))],
@@ -4372,6 +4347,45 @@ def get_large_matrix():
                           expected=[gt_topk(dat=a_npy, axis=1, ret_typ="indices", k=1,
                                             is_ascend=True)])

+    for dtype in [np.float16, np.float32, np.float64]:
+        dshape = (5, 5, 5, 5)
+        a_npy = np.arange(np.prod(dshape)).astype(dtype)
+        np.random.shuffle(a_npy)
+        a_npy = a_npy.reshape(dshape)
+        a = mx.sym.Variable('a')
+        for axis in [1, 3, None]:
+            K = [1, 3, 5, 7] if axis is None else [1, 3, 5]
+            for k in K:
+                for is_ascend in [True, False]:
+                    b = mx.sym.topk(a, axis=axis, is_ascend=is_ascend, ret_typ="value", k=k)
+                    out_npy = gt_topk(dat=a_npy, axis=axis, ret_typ="value", k=k, is_ascend=is_ascend)
+                    check_numeric_gradient(b, location={'a': a_npy}, numeric_eps=1e-2, ctx=ctx)
+                    check_symbolic_forward(b, location={'a': a_npy}, expected=[out_npy])
+
+        b = mx.sym.topk(a, axis=1, is_ascend=is_ascend, ret_typ="indices", k=5)
+        check_symbolic_backward(sym=b, location={'a': large_matrix_npy},
+                                out_grads=[np.random.normal(size=(100, 5))],
+                                expected=[np.zeros((100, 300096))])
+        check_symbolic_forward(b, location={'a': large_matrix_npy},
+                               expected=[gt_topk(dat=large_matrix_npy, axis=1,
+                                                 ret_typ="indices", k=5, is_ascend=is_ascend)])
+
+        b = mx.sym.topk(a, axis=3, is_ascend=is_ascend, ret_typ="indices", k=3)
+        check_symbolic_backward(sym=b, location={'a': a_npy},
+                                out_grads=[np.random.normal(size=(5, 5, 5, 3))],
+                                expected=[np.zeros((5, 5, 5, 5))])
+        check_symbolic_forward(b, location={'a': a_npy},
+                               expected=[gt_topk(dat=a_npy, axis=3, ret_typ="indices", k=3,
+                                                 is_ascend=False)])
+
+        b = mx.sym.topk(a, axis=1, is_ascend=True, ret_typ="mask", k=3)
+        check_symbolic_backward(sym=b, location={'a': a_npy},
+                                out_grads=[np.random.normal(size=(5, 5, 5, 5))],
+                                expected=[np.zeros((5, 5, 5, 5))])
+        check_symbolic_forward(b, location={'a': a_npy},
+                               expected=[gt_topk(dat=a_npy, axis=1, ret_typ="mask", k=3,
+                                                 is_ascend=True)])


@with_seed()
def test_blockgrad():
