Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec
* Added initial value support for warp- and block-level inclusive scan.
* Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag.
* Added additional unit tests for `test_block_load.hpp`
* Added additional unit tests for `test_block_rank.hpp`
* Added additional unit tests for `test_block_scan.hpp`
* Added additional unit tests for `test_block_sort.hpp`
* Added additional unit tests for `test_block_store.hpp`
* Added the missing `rank_keys_desc` overload that takes a `digit_extractor` parameter to `block_radix_rank_match.hpp`

### Changed

Expand Down
13 changes: 13 additions & 0 deletions rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,19 @@ class block_radix_rank_match
rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
}

/// \brief Ranks \p keys in descending digit order, using a caller-supplied digit extractor.
///
/// Each key is mapped to a digit by \p digit_extractor; that digit is then mirrored to
/// `radix_digits - 1 - digit` before being handed to `rank_keys_impl`.
/// NOTE(review): this presumably relies on `rank_keys_impl` ranking ascending by digit,
/// so that mirroring the digits yields a descending ranking - confirm against `rank_keys_impl`.
///
/// \tparam Key            type of the keys to rank.
/// \tparam ItemsPerThread number of keys held by each thread.
/// \tparam DigitExtractor callable invoked as `digit_extractor(key)`; its result is
///                        converted to `unsigned int`.
/// \param [in]  keys            keys owned by the calling thread.
/// \param [out] ranks           receives the rank computed for each entry of \p keys.
/// \param       storage         temporary storage; `storage.get()` is forwarded to `rank_keys_impl`.
/// \param [in]  digit_extractor maps a key to its digit.
template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread],
                                   unsigned int (&ranks)[ItemsPerThread],
                                   storage_type& storage,
                                   DigitExtractor digit_extractor)
{
    // Wrap the user extractor so that digit d becomes (radix_digits - 1 - d).
    auto mirrored_digit = [&digit_extractor](const Key& key)
    {
        const unsigned int digit = digit_extractor(key);
        return radix_digits - 1 - digit;
    };

    rank_keys_impl(keys, ranks, storage.get(), mirrored_digit);
}

template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
unsigned int (&ranks)[ItemsPerThread],
Expand Down
235 changes: 233 additions & 2 deletions test/rocprim/test_block_radix_rank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,14 @@ static constexpr size_t n_sizes = 12;
// Per-configuration test parameters. Entry `i` of every table below describes test
// configuration `i` (reported as "TestID = i"); static_for indexes them with `First`.

// Number of keys each thread ranks (ItemsPerThread template argument).
static constexpr unsigned int items_per_thread[n_sizes] = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
// Whether the configuration ranks in descending order (Descending template argument).
// Stored as unsigned int but only ever holds true/false.
static constexpr unsigned int rank_desc[n_sizes]
= {false, false, false, false, false, false, true, true, true, true, true, true};
// Whether the kernel calls the overload that takes an explicit storage argument
// (UseStorage template argument); false selects the overload without storage.
static constexpr unsigned int use_storage[n_sizes]
= {false, true, false, true, false, true, false, true, false, true, false, true};
// Bit masks used by the digit-extractor tests: the extracted digit is `key_bits & end_bits[i]`.
static constexpr unsigned int end_bits[n_sizes]
= {0x1, 0x3, 0x7, 0xf, 0x1, 0x3, 0x7, 0xf, 0x1, 0x3, 0x7, 0xf};
// NOTE(review): presumably the first bit of the ranked digit for the begin-bit/pass-bits
// tests; its consumer is not visible in this view - confirm against static_for::run().
static constexpr unsigned int pass_start_bit[n_sizes] = {0, 0, 0, 6, 2, 1, 0, 0, 0, 1, 4, 7};
// MaxRadixBits template argument for the begin-bit/pass-bits tests.
static constexpr unsigned int max_radix_bits[n_sizes] = {4, 3, 5, 3, 1, 5, 4, 2, 4, 3, 1, 2};
// MaxRadixBits template argument for the digit-extractor tests; entry `i` equals the
// number of set bits in end_bits[i].
static constexpr unsigned int max_radix_bits_extractor[n_sizes] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
// NOTE(review): presumably the number of bits ranked per pass, with 0 meaning "derive it
// from max_radix_bits"; its consumer is not visible in this view - confirm.
static constexpr unsigned int pass_radix_bits[n_sizes] = {0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 1};

template<typename T,
Expand All @@ -71,6 +77,7 @@ template<typename T,
__global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const items_input,
unsigned int* const ranks_output,
const bool descending,
const bool use_storage,
const unsigned int start_bit,
const unsigned int radix_bits)
{
Expand Down Expand Up @@ -107,11 +114,108 @@ __global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const ite

if(descending)
{
block_rank_type().rank_keys_desc(keys, ranks, storage.rank, start_bit, radix_bits);
if (use_storage)
block_rank_type().rank_keys_desc(keys, ranks, storage.rank, start_bit, radix_bits);
else
block_rank_type().rank_keys_desc(keys, ranks, start_bit, radix_bits);
}
else
{
block_rank_type().rank_keys(keys, ranks, storage.rank, start_bit, radix_bits);
if (use_storage)
block_rank_type().rank_keys(keys, ranks, storage.rank, start_bit, radix_bits);
else
block_rank_type().rank_keys(keys, ranks, start_bit, radix_bits);
}

if ROCPRIM_IF_CONSTEXPR(warp_striped)
{
// See the comment above.
rocprim::syncthreads();
ranks_exchange_type().warp_striped_to_blocked(ranks, ranks, storage.ranks_exchange);
}
rocprim::block_store_direct_blocked(lid, ranks_output + block_offset, ranks);
}

template<typename T,
unsigned int BlockSize,
unsigned int ItemsPerThread,
unsigned int MaxRadixBits,
rocprim::block_radix_rank_algorithm Algorithm>
__global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const items_input,
unsigned int* const ranks_output,
const bool descending,
const bool use_storage,
const unsigned int last_bits)
{
using block_rank_type = rocprim::block_radix_rank<BlockSize, MaxRadixBits, Algorithm>;
using keys_exchange_type = rocprim::block_exchange<T, BlockSize, ItemsPerThread>;
using ranks_exchange_type = rocprim::block_exchange<unsigned int, BlockSize, ItemsPerThread>;

constexpr bool warp_striped = Algorithm == rocprim::block_radix_rank_algorithm::match;

constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
const unsigned int lid = threadIdx.x;
const unsigned int block_offset = blockIdx.x * items_per_block;

ROCPRIM_SHARED_MEMORY union
{
typename keys_exchange_type::storage_type keys_exchange;
typename block_rank_type::storage_type rank;
typename ranks_exchange_type::storage_type ranks_exchange;
} storage;

T keys[ItemsPerThread];
unsigned int ranks[ItemsPerThread];

rocprim::block_load_direct_blocked(lid, items_input + block_offset, keys);
if ROCPRIM_IF_CONSTEXPR(warp_striped)
{
// block_radix_rank_match requires warp striped input and output. Instead of using
// rocprim::block_load_direct_warp_striped though, we load directly and exchange the
// values manually, as we can also test with block sizes that do not divide the hardware
// warp size that way.
keys_exchange_type().blocked_to_warp_striped(keys, keys, storage.keys_exchange);
rocprim::syncthreads();
}

union converter{
T in;
uint64_t out;
};

if(descending)
{
if (use_storage)
block_rank_type().rank_keys_desc(keys, ranks, storage.rank, [=](const T & key){
converter c;
c.in = key;
uint64_t out = c.out & last_bits;
return out;
});
else
block_rank_type().rank_keys_desc(keys, ranks, [=](const T & key){
converter c;
c.in = key;
uint64_t out = c.out & last_bits;
return out;
});
}
else
{
if (use_storage)
block_rank_type().rank_keys(keys, ranks, storage.rank, [=](const T & key){
converter c;
c.in = key;
uint64_t out = c.out & last_bits;
return out;
});
else
block_rank_type().rank_keys(keys, ranks, [=](const T & key){
converter c;
c.in = key;
uint64_t out = c.out & last_bits;
return out;
});
}

if ROCPRIM_IF_CONSTEXPR(warp_striped)
Expand All @@ -130,6 +234,7 @@ template<typename T,
unsigned int MaxRadixBits,
unsigned int RadixBits,
bool Descending,
bool UseStorage,
rocprim::block_radix_rank_algorithm Algorithm>
void test_block_radix_rank()
{
Expand All @@ -141,6 +246,7 @@ void test_block_radix_rank()
constexpr size_t radix_bits = RadixBits;
constexpr size_t end_bit = start_bit + radix_bits;
constexpr bool descending = Descending;
constexpr bool use_storage = UseStorage;
constexpr rocprim::block_radix_rank_algorithm algorithm = Algorithm;

const size_t grid_size = 23;
Expand Down Expand Up @@ -205,6 +311,7 @@ void test_block_radix_rank()
d_keys_input.get(),
d_ranks_output.get(),
descending,
use_storage,
start_bit,
radix_bits);
HIP_CHECK(hipGetLastError());
Expand All @@ -216,6 +323,111 @@ void test_block_radix_rank()
}
}

/// \brief Tests the digit-extractor overloads of `rank_keys` / `rank_keys_desc`.
///
/// For every seed, random keys are generated and a reference ranking is computed on the
/// host: within each block of `BlockSize * ItemsPerThread` keys, indices are stably sorted
/// by the digit `key_bits & EndBits`, and the rank of a key is its position in that sorted
/// sequence. The device kernel is then launched with the same mask and its ranks are
/// compared with the reference.
///
/// \tparam T              key type.
/// \tparam BlockSize      number of threads per block.
/// \tparam ItemsPerThread number of keys ranked by each thread.
/// \tparam MaxRadixBits   forwarded to `rocprim::block_radix_rank` through the kernel.
/// \tparam EndBits        bit mask selecting the digit from the key's bit representation.
/// \tparam Descending     rank in descending digit order when true.
/// \tparam UseStorage     forwarded to the kernel, which uses it to choose between the
///                        overloads with and without an explicit storage argument.
/// \tparam Algorithm      block radix rank algorithm under test.
template<typename T,
         unsigned int BlockSize,
         unsigned int ItemsPerThread,
         unsigned int MaxRadixBits,
         unsigned int EndBits,
         bool         Descending,
         bool         UseStorage,
         rocprim::block_radix_rank_algorithm Algorithm>
void test_block_radix_extractor_rank()
{
    constexpr size_t block_size       = BlockSize;
    constexpr size_t items_per_thread = ItemsPerThread;
    constexpr size_t items_per_block  = block_size * items_per_thread;
    constexpr size_t max_radix_bits   = MaxRadixBits;
    constexpr size_t end_bits         = EndBits;
    constexpr bool   descending       = Descending;
    constexpr bool   use_storage      = UseStorage;
    constexpr rocprim::block_radix_rank_algorithm algorithm = Algorithm;

    const size_t grid_size = 23;
    const size_t size      = items_per_block * grid_size;

    SCOPED_TRACE(testing::Message() << "with block_size = " << block_size);
    SCOPED_TRACE(testing::Message() << "with items_per_thread = " << items_per_thread);
    SCOPED_TRACE(testing::Message() << "with descending = " << (descending ? "true" : "false"));
    SCOPED_TRACE(testing::Message() << "with use_storage = " << (use_storage ? "true" : "false"));
    SCOPED_TRACE(testing::Message() << "with max_radix_bits = " << MaxRadixBits);
    SCOPED_TRACE(testing::Message() << "with end_bits = " << end_bits);
    // Report the actual grid size (this trace previously printed `size`).
    SCOPED_TRACE(testing::Message() << "with grid_size = " << grid_size);
    SCOPED_TRACE(testing::Message() << "with size = " << size);

    // Reinterprets the bytes of a key as an unsigned integer, as the kernel's extractor does.
    union converter
    {
        T        in;
        uint64_t out;
    };

    // Host-side counterpart of the digit extractor passed to the kernel.
    const auto extract_digit = [&](const T& key) -> uint64_t
    {
        converter c;
        // Clear all bytes first so that bytes not covered by T are well defined.
        c.out = 0;
        c.in  = key;
        return c.out & end_bits;
    };

    for(size_t seed_index = 0; seed_index < number_of_runs; ++seed_index)
    {
        seed_type seed_value
            = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
        SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);

        // Generate data
        std::vector<T> keys_input
            = test_utils::get_random_data_wrapped<T>(size,
                                                     common::generate_limits<T>::min(),
                                                     common::generate_limits<T>::max(),
                                                     seed_value);

        // Calculate expected results on host
        std::vector<unsigned int> expected(size);
        for(size_t block = 0; block < grid_size; ++block)
        {
            const size_t block_offset = block * items_per_block;

            // Perform an 'argsort', which gives a sorted sequence of indices into `keys_input`.
            std::vector<unsigned int> indices(items_per_block);
            std::iota(indices.begin(), indices.end(), 0u);

            // The sort must be stable: keys with equal digits keep their input order.
            std::stable_sort(indices.begin(),
                             indices.end(),
                             [&](const unsigned int lhs, const unsigned int rhs)
                             {
                                 const uint64_t left = extract_digit(keys_input[block_offset + lhs]);
                                 const uint64_t right
                                     = extract_digit(keys_input[block_offset + rhs]);
                                 return descending ? right < left : left < right;
                             });

            // Invert the sorted indices sequence to obtain the ranks.
            for(size_t j = 0; j < items_per_block; ++j)
            {
                expected[block_offset + indices[j]] = static_cast<unsigned int>(j);
            }
        }

        common::device_ptr<T>            d_keys_input(keys_input);
        common::device_ptr<unsigned int> d_ranks_output(size);

        // Running kernel
        hipLaunchKernelGGL(
            HIP_KERNEL_NAME(
                rank_kernel<T, block_size, items_per_thread, max_radix_bits, algorithm>),
            dim3(grid_size),
            dim3(block_size),
            0,
            0,
            d_keys_input.get(),
            d_ranks_output.get(),
            descending,
            use_storage,
            end_bits);
        HIP_CHECK(hipGetLastError());

        // Getting results to host
        auto ranks_output = d_ranks_output.load();

        ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(ranks_output, expected));
    }
}

template<unsigned int First,
unsigned int Last,
typename T,
Expand All @@ -237,10 +449,27 @@ struct static_for
max_radix_bits[First],
radix_bits,
rank_desc[First],
use_storage[First],
Algorithm>();
}
static_for<First + 1, Last, T, BlockSize, Algorithm>::run();
}

static void run_extractor()
{
{
SCOPED_TRACE(testing::Message() << "TestID = " << First);
test_block_radix_extractor_rank<T,
BlockSize,
items_per_thread[First],
max_radix_bits_extractor[First],
end_bits[First],
rank_desc[First],
use_storage[First],
Algorithm>();
}
static_for<First + 1, Last, T, BlockSize, Algorithm>::run_extractor();
}
};

template<unsigned int Last,
Expand All @@ -250,6 +479,7 @@ template<unsigned int Last,
struct static_for<Last, Last, T, BlockSize, Algorithm>
{
static void run() {}
static void run_extractor() {}
};

template<rocprim::block_radix_rank_algorithm Algorithm, typename TestFixture>
Expand All @@ -264,6 +494,7 @@ void test_block_radix_rank_algorithm()
}

static_for<0, n_sizes, type, block_size, Algorithm>::run();
static_for<0, n_sizes, type, block_size, Algorithm>::run_extractor();
}

#endif // TEST_BLOCK_RADIX_RANK_KERNELS_HPP_