From 84d8fdd6575b5b44be3bdf5d1e71f2769e1756b7 Mon Sep 17 00:00:00 2001 From: NguyenNhuDi Date: Wed, 16 Apr 2025 11:22:16 -0600 Subject: [PATCH 1/2] added addition unit tests and implement missing function in match --- .../block/detail/block_radix_rank_match.hpp | 13 + test/rocprim/test_block_radix_rank.hpp | 235 +++++++++++++++++- 2 files changed, 246 insertions(+), 2 deletions(-) diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp index f92b3bad1..5690e05bf 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp @@ -269,6 +269,19 @@ class block_radix_rank_match rank_keys_impl(keys, ranks, storage.get(), begin_bit, pass_bits); } + template + ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + DigitExtractor digit_extractor) + { + rank_keys_impl(keys, ranks, storage.get(), + [&digit_extractor](const Key & key){ + const unsigned int digit = digit_extractor(key); + return radix_digits - 1 - digit; + }); + } + template ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread], unsigned int (&ranks)[ItemsPerThread], diff --git a/test/rocprim/test_block_radix_rank.hpp b/test/rocprim/test_block_radix_rank.hpp index e078fdcfa..9e0bc35b5 100644 --- a/test/rocprim/test_block_radix_rank.hpp +++ b/test/rocprim/test_block_radix_rank.hpp @@ -59,8 +59,14 @@ static constexpr size_t n_sizes = 12; static constexpr unsigned int items_per_thread[n_sizes] = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3}; static constexpr unsigned int rank_desc[n_sizes] = {false, false, false, false, false, false, true, true, true, true, true, true}; + // = {false, false, false, false, false, false, false, false, false, false, false, false}; +static constexpr unsigned int use_storage[n_sizes] + = {false, true, false, true, false, true, false, 
true, false, true, false, true}; +static constexpr unsigned int end_bits[n_sizes] + = {0x1, 0x3, 0x7, 0xf, 0x1, 0x3, 0x7, 0xf, 0x1, 0x3, 0x7, 0xf}; static constexpr unsigned int pass_start_bit[n_sizes] = {0, 0, 0, 6, 2, 1, 0, 0, 0, 1, 4, 7}; static constexpr unsigned int max_radix_bits[n_sizes] = {4, 3, 5, 3, 1, 5, 4, 2, 4, 3, 1, 2}; +static constexpr unsigned int max_radix_bits_extractor[n_sizes] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}; static constexpr unsigned int pass_radix_bits[n_sizes] = {0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 1}; template +__global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const items_input, + unsigned int* const ranks_output, + const bool descending, + const bool use_storage, + const unsigned int last_bits) +{ + using block_rank_type = rocprim::block_radix_rank; + using keys_exchange_type = rocprim::block_exchange; + using ranks_exchange_type = rocprim::block_exchange; + + constexpr bool warp_striped = Algorithm == rocprim::block_radix_rank_algorithm::match; + + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * items_per_block; + + ROCPRIM_SHARED_MEMORY union + { + typename keys_exchange_type::storage_type keys_exchange; + typename block_rank_type::storage_type rank; + typename ranks_exchange_type::storage_type ranks_exchange; + } storage; + + T keys[ItemsPerThread]; + unsigned int ranks[ItemsPerThread]; + + rocprim::block_load_direct_blocked(lid, items_input + block_offset, keys); + if ROCPRIM_IF_CONSTEXPR(warp_striped) + { + // block_radix_rank_match requires warp striped input and output. Instead of using + // rocprim::block_load_direct_warp_striped though, we load directly and exchange the + // values manually, as we can also test with block sizes that do not divide the hardware + // warp size that way. 
+ keys_exchange_type().blocked_to_warp_striped(keys, keys, storage.keys_exchange); + rocprim::syncthreads(); + } + + union converter{ + T in; + uint64_t out; + }; + + if(descending) + { + if (use_storage) + block_rank_type().rank_keys_desc(keys, ranks, storage.rank, [=](const T & key){ + converter c; + c.in = key; + uint64_t out = c.out & last_bits; + return out; + }); + else + block_rank_type().rank_keys_desc(keys, ranks, [=](const T & key){ + converter c; + c.in = key; + uint64_t out = c.out & last_bits; + return out; + }); + } + else + { + if (use_storage) + block_rank_type().rank_keys(keys, ranks, storage.rank, [=](const T & key){ + converter c; + c.in = key; + uint64_t out = c.out & last_bits; + return out; + }); + else + block_rank_type().rank_keys(keys, ranks, [=](const T & key){ + converter c; + c.in = key; + uint64_t out = c.out & last_bits; + return out; + }); } if ROCPRIM_IF_CONSTEXPR(warp_striped) @@ -130,6 +234,7 @@ template void test_block_radix_rank() { @@ -141,6 +246,7 @@ void test_block_radix_rank() constexpr size_t radix_bits = RadixBits; constexpr size_t end_bit = start_bit + radix_bits; constexpr bool descending = Descending; + constexpr bool use_storage = UseStorage; constexpr rocprim::block_radix_rank_algorithm algorithm = Algorithm; const size_t grid_size = 23; @@ -205,6 +311,7 @@ void test_block_radix_rank() d_keys_input.get(), d_ranks_output.get(), descending, + use_storage, start_bit, radix_bits); HIP_CHECK(hipGetLastError()); @@ -216,6 +323,111 @@ void test_block_radix_rank() } } +template +void test_block_radix_extractor_rank() +{ + constexpr size_t block_size = BlockSize; + constexpr size_t items_per_thread = ItemsPerThread; + constexpr size_t items_per_block = block_size * items_per_thread; + constexpr size_t max_radix_bits = MaxRadixBits; + constexpr size_t end_bits = EndBits; + constexpr bool descending = Descending; + constexpr bool use_storage = UseStorage; + constexpr rocprim::block_radix_rank_algorithm algorithm = Algorithm; + + 
const size_t grid_size = 23; + const size_t size = items_per_block * grid_size; + + SCOPED_TRACE(testing::Message() << "with block_size = " << block_size); + SCOPED_TRACE(testing::Message() << "with items_per_thread = " << items_per_thread); + SCOPED_TRACE(testing::Message() << "with descending = " << (descending ? "true" : "false")); + SCOPED_TRACE(testing::Message() << "with max_radix_bits = " << MaxRadixBits); + SCOPED_TRACE(testing::Message() << "with grid_size = " << grid_size); + SCOPED_TRACE(testing::Message() << "with size = " << size); + + for(size_t seed_index = 0; seed_index < number_of_runs; ++seed_index) + { + seed_type seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector keys_input + = test_utils::get_random_data_wrapped(size, + common::generate_limits::min(), + common::generate_limits::max(), + seed_value); + + + union converter{ + T in; + uint64_t out; + }; + // Calculated expected results on host + std::vector expected(size); + for(size_t i = 0; i < grid_size; ++i) + { + size_t block_offset = i * items_per_block; + + // Perform an 'argsort', which gives a sorted sequence of indices into `keys_input`. + std::vector indices(items_per_block); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort( + indices.begin(), + indices.end(), + [&](const int& i, const int& j) + { + converter c; + c.in = keys_input[block_offset + i]; + uint64_t left = c.out & end_bits; + + c.in = keys_input[block_offset + j]; + + uint64_t right = c.out & end_bits; + + return descending ? right < left : left < right; + }); + + // Invert the sorted indices sequence to obtain the ranks. 
+ for(size_t j = 0; j < items_per_block; ++j) + { + expected[block_offset + indices[j]] = static_cast(j); + } + } + + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_ranks_output(size); + + // Running kernel + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + rank_kernel), + dim3(grid_size), + dim3(block_size), + 0, + 0, + d_keys_input.get(), + d_ranks_output.get(), + descending, + use_storage, + end_bits); + HIP_CHECK(hipGetLastError()); + + // Getting results to host + auto ranks_output = d_ranks_output.load(); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(ranks_output, expected)); + } +} + template(); } static_for::run(); } + + static void run_extractor() + { + { + SCOPED_TRACE(testing::Message() << "TestID = " << First); + test_block_radix_extractor_rank(); + } + static_for::run_extractor(); + } }; template { static void run() {} + static void run_extractor() {} }; template @@ -264,6 +494,7 @@ void test_block_radix_rank_algorithm() } static_for<0, n_sizes, type, block_size, Algorithm>::run(); + static_for<0, n_sizes, type, block_size, Algorithm>::run_extractor(); } #endif // TEST_BLOCK_RADIX_RANK_KERNELS_HPP_ From ac642310e13577516438b71adbdabc2cfcca0e75 Mon Sep 17 00:00:00 2001 From: NguyenNhuDi Date: Wed, 16 Apr 2025 11:24:46 -0600 Subject: [PATCH 2/2] updated changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e00641f08..d04be22ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,9 +19,11 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Added initial value support for warp- and block-level inclusive scan. * Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag. 
* Added additional unit tests for `test_block_load.hpp` +* Added additional unit tests for `test_block_radix_rank.hpp` * Added additional unit tests for `test_block_scan.hpp` * Added additional unit tests for `test_block_sort.hpp` * Added additional unit tests for `test_block_store.hpp` +* Added missing `rank_keys_desc` with `digit_extractor` parameter for `block_radix_rank_match.hpp` ### Changed