From 1c5165c05d4f2c5414a13fe9dcfbe5becad5a13e Mon Sep 17 00:00:00 2001 From: NguyenNhuDi Date: Fri, 2 May 2025 09:56:56 -0600 Subject: [PATCH 1/3] fixed block mergesort for case with oob_defaul --- .../backend/rocprim/block/block_merge_sort.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp index 1593062f155..3aa1cde5993 100644 --- a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp @@ -387,7 +387,7 @@ class BlockMergeSortStrategy KeyT max_key = oob_default; #pragma unroll - for (int item = 1; item < ITEMS_PER_THREAD; ++item) + for (int item = 0; item < ITEMS_PER_THREAD; ++item) { if (ITEMS_PER_THREAD * static_cast(linear_tid) + item < valid_items) { @@ -439,14 +439,16 @@ class BlockMergeSortStrategy int thread_idx_in_thread_group_being_merged = mask & linear_tid; + const int ITEMS_PER_BLOCK = ITEMS_PER_THREAD * blockDim.x; + int diag = - (::rocprim::min)(valid_items, + (::rocprim::min)(ITEMS_PER_BLOCK, ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); - int keys1_beg = (::rocprim::min)(valid_items, start); - int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size); + int keys1_beg = (::rocprim::min)(ITEMS_PER_BLOCK, start); + int keys1_end = (::rocprim::min)(ITEMS_PER_BLOCK, keys1_beg + size); int keys2_beg = keys1_end; - int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size); + int keys2_end = (::rocprim::min)(ITEMS_PER_BLOCK, keys2_beg + size); int keys1_count = keys1_end - keys1_beg; int keys2_count = keys2_end - keys2_beg; From 372a08507a9375fd4410c8de303b1e614684c0d4 Mon Sep 17 00:00:00 2001 From: NguyenNhuDi Date: Fri, 2 May 2025 10:01:04 -0600 Subject: [PATCH 2/3] updated changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ccaf1d6918..5dccab7311e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ Full documentation for hipCUB is available at [https://rocm.docs.amd.com/project * Updated `thread_load` and `thread_store` to align hipCUB with CUB. * All kernels now have hidden symbol visibility. All symbols now have inline namespaces that include the library version, (for example, hipcub::HIPCUB_300400_NS::symbol instead of hipcub::symbol), letting the user link multiple libraries built with different versions of hipCUB. +### Resolved issues +* Fixed an issue where `Sort(keys, compare_op, valid_items, oob_default)` in `block_merge_sort.hpp` would not fill in elements that are out of range (items after `valid_items`) with `oob_default`. + ### Known issues * `BlockAdjacentDifference::FlagHeads`, `BlockAdjacentDifference::FlagTails` and `BlockAdjacentDifference::FlagHeadsAndTails` have been removed from hipCUB's CUB backend. They were already deprecated as of version 2.12.0 of hipCUB and they were removed from CCCL (CUB) as of CCCL's 2.6.0 release. From 69b931af63ac3249083f79dae0c3973abb73e685 Mon Sep 17 00:00:00 2001 From: NguyenNhuDi Date: Thu, 15 May 2025 09:42:06 -0600 Subject: [PATCH 3/3] fixed warp merge sort error caused by changes --- .../backend/rocprim/block/block_merge_sort.hpp | 15 ++++++++------- .../backend/rocprim/warp/warp_merge_sort.hpp | 6 ++++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp index 3aa1cde5993..8d3bf048a41 100644 --- a/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp @@ -172,7 +172,8 @@ template + typename SynchronizationPolicy, + bool WARP_SORT = false> class BlockMergeSortStrategy { static_assert(PowerOfTwo::VALUE, @@ -387,7 +388,7 @@ class BlockMergeSortStrategy KeyT max_key = oob_default; #pragma unroll - for (int item = 0; item < ITEMS_PER_THREAD; ++item) + for (int item = WARP_SORT ? 1 : 0; item < ITEMS_PER_THREAD; ++item) { if (ITEMS_PER_THREAD * static_cast(linear_tid) + item < valid_items) { @@ -439,16 +440,16 @@ class BlockMergeSortStrategy int thread_idx_in_thread_group_being_merged = mask & linear_tid; - const int ITEMS_PER_BLOCK = ITEMS_PER_THREAD * blockDim.x; + const int COMPARE_NUM = WARP_SORT ? valid_items : ITEMS_PER_THREAD * blockDim.x; int diag = - (::rocprim::min)(ITEMS_PER_BLOCK, + (::rocprim::min)(COMPARE_NUM, ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); - int keys1_beg = (::rocprim::min)(ITEMS_PER_BLOCK, start); - int keys1_end = (::rocprim::min)(ITEMS_PER_BLOCK, keys1_beg + size); + int keys1_beg = (::rocprim::min)(COMPARE_NUM, start); + int keys1_end = (::rocprim::min)(COMPARE_NUM, keys1_beg + size); int keys2_beg = keys1_end; - int keys2_end = (::rocprim::min)(ITEMS_PER_BLOCK, keys2_beg + size); + int keys2_end = (::rocprim::min)(COMPARE_NUM, keys2_beg + size); int keys1_count = keys1_end - keys1_beg; int keys2_count = keys2_end - keys2_beg; diff --git a/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp b/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp index 5d7a627191b..136e674f051 100644 --- a/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp +++ b/hipcub/include/hipcub/backend/rocprim/warp/warp_merge_sort.hpp @@ -129,7 +129,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort> + WarpMergeSort, + true> { private: constexpr static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == HIPCUB_DEVICE_WARP_THREADS; @@ -140,7 +141,8 @@ class WarpMergeSort ValueT, LOGICAL_WARP_THREADS, ITEMS_PER_THREAD, - WarpMergeSort>; + WarpMergeSort, + true>; const unsigned int warp_id; const uint64_t member_mask;