diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh index 95f3a0757df1..b20b466d9c2b 100644 --- a/src/operator/tensor/sort_op-inl.cuh +++ b/src/operator/tensor/sort_op-inl.cuh @@ -87,7 +87,7 @@ template inline typename std::enable_if::value, size_t>::type SortPairsWorkspaceSize(const size_t num_keys) { size_t sortpairs_bytes = 0; - cub::DeviceRadixSort::SortPairs(NULL, sortpairs_bytes, + cub::DeviceRadixSort::SortPairs<__half, VDType>(NULL, sortpairs_bytes, NULL, NULL, NULL, NULL, num_keys); return sortpairs_bytes; } @@ -193,8 +193,8 @@ SortByKeyImpl(mshadow::Tensor keys, #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); thrust::device_ptr key_iter = thrust::device_pointer_cast(keys.dptr_); - thrust::device_ptr value_iter = thrust::device_pointer_cast( - reinterpret_cast(values.dptr_)); + thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast( + reinterpret_cast<__half*>(values.dptr_)); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), @@ -221,17 +221,17 @@ SortByKeyImpl(mshadow::Tensor keys, CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); - thrust::device_ptr key_iter = thrust::device_pointer_cast( - reinterpret_cast(keys.dptr_)); + thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast( + reinterpret_cast<__half*>(keys.dptr_)); thrust::device_ptr value_iter = thrust::device_pointer_cast(values.dptr_); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), - key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half()); + key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>()); } else { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), - key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half()); + key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>()); } MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey); #else @@ -251,18 +251,18 @@ SortByKeyImpl(mshadow::Tensor keys, CHECK_EQ(values.CheckContiguous(), true); #if CUDA_VERSION >= 9000 cudaStream_t stream = mshadow::Stream::GetStream(keys.stream_); - thrust::device_ptr key_iter = thrust::device_pointer_cast( - reinterpret_cast(keys.dptr_)); - thrust::device_ptr value_iter = thrust::device_pointer_cast( - reinterpret_cast(values.dptr_)); + thrust::device_ptr<__half> key_iter = thrust::device_pointer_cast( + reinterpret_cast<__half*>(keys.dptr_)); + thrust::device_ptr<__half> value_iter = thrust::device_pointer_cast( + reinterpret_cast<__half*>(values.dptr_)); if (is_ascend) { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), - key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half()); + key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::less_half<__half>()); } else { thrust::stable_sort_by_key( thrust::cuda::par.on(stream), - key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half()); + key_iter, key_iter + (keys.size(0)), value_iter.get(), cuda::greater_half<__half>()); } MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey); #else