diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ccaf1d6..5dccab73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ Full documentation for hipCUB is available at [https://rocm.docs.amd.com/project * Updated `thread_load` and `thread_store` to align hipCUB with CUB. * All kernels now have hidden symbol visibility. All symbols now have inline namespaces that include the library version, (for example, hipcub::HIPCUB_300400_NS::symbol instead of hipcub::symbol), letting the user link multiple libraries built with different versions of hipCUB. +### Resolved issues +* Fixed an issue where `Sort(keys, compare_op, valid_items, oob_default)` in `block_merge_sort.hpp` would not fill in elements that are out of range (items after `valid_items`) with `oob_default`. + ### Known issues * `BlockAdjacentDifference::FlagHeads`, `BlockAdjacentDifference::FlagTails` and `BlockAdjacentDifference::FlagHeadsAndTails` have been removed from hipCUB's CUB backend. They were already deprecated as of version 2.12.0 of hipCUB and they were removed from CCCL (CUB) as of CCCL's 2.6.0 release. diff --git a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp index 1593062f..8d3bf048 100644 --- a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp @@ -172,7 +172,8 @@ template + typename SynchronizationPolicy, + bool WARP_SORT = false> class BlockMergeSortStrategy { static_assert(PowerOfTwo::VALUE, @@ -387,7 +388,7 @@ class BlockMergeSortStrategy KeyT max_key = oob_default; #pragma unroll - for (int item = 1; item < ITEMS_PER_THREAD; ++item) + for (int item = WARP_SORT ? 1 : 0; item < ITEMS_PER_THREAD; ++item) { if (ITEMS_PER_THREAD * static_cast(linear_tid) + item < valid_items) { @@ -439,14 +440,16 @@ class BlockMergeSortStrategy int thread_idx_in_thread_group_being_merged = mask & linear_tid; + const int COMPARE_NUM = WARP_SORT ? valid_items : ITEMS_PER_THREAD * blockDim.x; + int diag = - (::rocprim::min)(valid_items, + (::rocprim::min)(COMPARE_NUM, ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); - int keys1_beg = (::rocprim::min)(valid_items, start); - int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size); + int keys1_beg = (::rocprim::min)(COMPARE_NUM, start); + int keys1_end = (::rocprim::min)(COMPARE_NUM, keys1_beg + size); int keys2_beg = keys1_end; - int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size); + int keys2_end = (::rocprim::min)(COMPARE_NUM, keys2_beg + size); int keys1_count = keys1_end - keys1_beg; int keys2_count = keys2_end - keys2_beg; diff --git a/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp b/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp index 5d7a6271..136e674f 100644 --- a/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp @@ -129,7 +129,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort> + WarpMergeSort, + true> { private: constexpr static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == HIPCUB_DEVICE_WARP_THREADS; @@ -140,7 +141,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort>; + WarpMergeSort, + true>; const unsigned int warp_id; const uint64_t member_mask;