diff --git a/projects/hipcub/CHANGELOG.md b/projects/hipcub/CHANGELOG.md index f89a8c12073..a388d28f69d 100644 --- a/projects/hipcub/CHANGELOG.md +++ b/projects/hipcub/CHANGELOG.md @@ -38,6 +38,10 @@ Full documentation for hipCUB is available at [https://rocm.docs.amd.com/project * The `hipcub::detail::accumulator_t` in rocPRIM backend has been changed to utilise `rocprim::accumulator_t`. * The usage of `rocprim::invoke_result_binary_op_t` has been replaced with `rocprim::accumulator_t`. +### Resolved issues +* Fixed an issue where `Sort(keys, compare_op, valid_items, oob_default)` in `block_merge_sort.hpp` would not fill in elements that are out of range (items after `valid_items`) with `oob_default`. +* Fixed an issue where `ScatterToStripedFlagged` in `block_exhange.hpp` was calling the wrong function. + ### Known issues * `BlockAdjacentDifference::FlagHeads`, `BlockAdjacentDifference::FlagTails` and `BlockAdjacentDifference::FlagHeadsAndTails` have been removed from hipCUB's CUB backend. They were already deprecated as of version 2.12.0 of hipCUB and they were removed from CCCL (CUB) as of CCCL's 2.6.0 release. diff --git a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_exchange.hpp b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_exchange.hpp index 4f90ac8ea1f..0053ae4a915 100644 --- a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_exchange.hpp +++ b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_exchange.hpp @@ -210,7 +210,7 @@ class BlockExchange OffsetT (&ranks)[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks ValidFlag (&is_valid)[ITEMS_PER_THREAD]) ///< [in] Corresponding flag denoting item validity { - ScatterToStriped(items, items, ranks, is_valid); + ScatterToStripedFlagged(items, items, ranks, is_valid); } #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp index 9b9d78acc18..610cceee439 100644 --- a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp +++ b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp @@ -172,7 +172,8 @@ template + typename SynchronizationPolicy, + bool WARP_SORT = false> class BlockMergeSortStrategy { static_assert(PowerOfTwo::VALUE, @@ -387,7 +388,7 @@ class BlockMergeSortStrategy KeyT max_key = oob_default; #pragma unroll - for (int item = 1; item < ITEMS_PER_THREAD; ++item) + for (int item = WARP_SORT ? 1 : 0; item < ITEMS_PER_THREAD; ++item) { if (ITEMS_PER_THREAD * static_cast(linear_tid) + item < valid_items) { @@ -439,14 +440,16 @@ class BlockMergeSortStrategy int thread_idx_in_thread_group_being_merged = mask & linear_tid; + const int COMPARE_NUM = WARP_SORT ? valid_items : ITEMS_PER_THREAD * blockDim.x; + int diag = - (::rocprim::min)(valid_items, + (::rocprim::min)(COMPARE_NUM, ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); - int keys1_beg = (::rocprim::min)(valid_items, start); - int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size); + int keys1_beg = (::rocprim::min)(COMPARE_NUM, start); + int keys1_end = (::rocprim::min)(COMPARE_NUM, keys1_beg + size); int keys2_beg = keys1_end; - int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size); + int keys2_end = (::rocprim::min)(COMPARE_NUM, keys2_beg + size); int keys1_count = keys1_end - keys1_beg; int keys2_count = keys2_end - keys2_beg; diff --git a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp index 1bcf908d91e..86d3b4e0b98 100644 --- a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp +++ b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp @@ -129,7 +129,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort> + WarpMergeSort, + true> { private: constexpr static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == HIPCUB_DEVICE_WARP_THREADS; @@ -140,7 +141,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort>; + WarpMergeSort, + true>; const unsigned int warp_id; const uint64_t member_mask;