From c9cd5c86554b3a5ffe953aa2b8b4ffefb897a523 Mon Sep 17 00:00:00 2001 From: Anirudh Acharya Date: Wed, 21 Aug 2019 23:11:39 +0000 Subject: [PATCH] fix cuda sort routine for half precision --- src/operator/tensor/sort_op-inl.cuh | 38 +++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/src/operator/tensor/sort_op-inl.cuh b/src/operator/tensor/sort_op-inl.cuh index 36625568bc05..95f3a0757df1 100644 --- a/src/operator/tensor/sort_op-inl.cuh +++ b/src/operator/tensor/sort_op-inl.cuh @@ -65,19 +65,43 @@ struct greater_half }; } +#ifndef SORT_WITH_THRUST +template +inline void WorkspaceSize4KeysAndValues( + const size_t num_keys, size_t *pKeys_bytes, size_t *pValues_bytes) { + const size_t alignment = std::max(sizeof(KDType), sizeof(VDType)); + *pKeys_bytes = PadBytes(num_keys * sizeof(KDType), alignment); + *pValues_bytes = PadBytes(num_keys * sizeof(VDType), alignment); +} + +template +inline typename std::enable_if::value, size_t>::type +SortPairsWorkspaceSize(const size_t num_keys) { + size_t sortpairs_bytes = 0; + cub::DeviceRadixSort::SortPairs(NULL, sortpairs_bytes, + NULL, NULL, NULL, NULL, num_keys); + return sortpairs_bytes; +} + +template +inline typename std::enable_if::value, size_t>::type +SortPairsWorkspaceSize(const size_t num_keys) { + size_t sortpairs_bytes = 0; + cub::DeviceRadixSort::SortPairs(NULL, sortpairs_bytes, + NULL, NULL, NULL, NULL, num_keys); + return sortpairs_bytes; +} +#endif + template inline typename std::enable_if::value, size_t>::type SortByKeyWorkspaceSize(const size_t num_keys) { #ifdef SORT_WITH_THRUST return 0; #else - size_t sortpairs_bytes = 0; - cub::DeviceRadixSort::SortPairs(NULL, sortpairs_bytes, - NULL, NULL, NULL, NULL, num_keys); - size_t alignment = std::max(sizeof(KDType), sizeof(VDType)); - size_t keys_bytes = PadBytes(num_keys*sizeof(KDType), alignment); - size_t values_bytes = PadBytes(num_keys*sizeof(VDType), alignment); - return (keys_bytes + values_bytes + sortpairs_bytes); + size_t keys_bytes, values_bytes; + WorkspaceSize4KeysAndValues(num_keys, &keys_bytes, &values_bytes); + return keys_bytes + values_bytes + SortPairsWorkspaceSize(num_keys); #endif }