Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
remove ambiguous half type
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya committed Aug 22, 2019
1 parent c9cd5c8 commit 8b0884f
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/operator/tensor/sort_op-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ template <typename KDType, typename VDType>
inline typename std::enable_if<std::is_same<KDType, mshadow::half::half_t>::value, size_t>::type
SortPairsWorkspaceSize(const size_t num_keys) {
size_t sortpairs_bytes = 0;
cub::DeviceRadixSort::SortPairs<half, VDType>(NULL, sortpairs_bytes,
cub::DeviceRadixSort::SortPairs<__half, VDType>(NULL, sortpairs_bytes,
NULL, NULL, NULL, NULL, num_keys);
return sortpairs_bytes;
}
Expand Down Expand Up @@ -193,8 +193,8 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(values.dptr_));
thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<__half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
Expand All @@ -221,17 +221,17 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(keys.dptr_));
thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<__half*>(keys.dptr_));
thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<half>());
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<half>());
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
Expand All @@ -251,18 +251,18 @@ SortByKeyImpl(mshadow::Tensor<gpu, 1, KDType> keys,
CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 9000
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(keys.stream_);
thrust::device_ptr<half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(keys.dptr_));
thrust::device_ptr<half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<half*>(values.dptr_));
thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast(
reinterpret_cast<__half*>(keys.dptr_));
thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast(
reinterpret_cast<__half*>(values.dptr_));
if (is_ascend) {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<half>());
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>());
} else {
thrust::stable_sort_by_key(
thrust::cuda::par.on(stream),
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<half>());
key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>());
}
MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey);
#else
Expand Down

0 comments on commit 8b0884f

Please sign in to comment.