Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ This is a complete list of affected functions and how their default accumulator
* Added the `rocprim::merge_inplace` function for merging in-place.
* Added initial value support for warp- and block-level inclusive scan.
* Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag.
* Added additional unit tests for `test_block_load.hpp`
* Added additional unit tests for `test_block_rank.hpp`
* Added additional unit tests for `test_block_scan.hpp`
* Added additional unit tests for `test_block_sort.hpp`
* Added additional unit tests for `test_block_store.hpp`
* Added missing `rank_keys_desc` with `digit_extractor` parameter for `block_radix_rank_match.hpp`

### Changed

Expand Down
13 changes: 13 additions & 0 deletions rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ class block_radix_rank_match
rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
}

template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread],
unsigned int (&ranks)[ItemsPerThread],
storage_type& storage,
DigitExtractor digit_extractor)
{
rank_keys_impl(keys, ranks, storage.get(),
[&digit_extractor](const Key & key){
const unsigned int digit = digit_extractor(key);
return radix_digits - 1 - digit;
});
}

template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
unsigned int (&ranks)[ItemsPerThread],
Expand Down
259 changes: 259 additions & 0 deletions test/rocprim/test_block_load_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,179 @@ typed_test_def(suite_name, name_suffix, LoadStoreClassValid)

}

typed_test_def(suite_name, name_suffix, LoadStoreClassWithStorage)
{
int device_id = test_common_utils::obtain_device_from_ctest();
SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
HIP_CHECK(hipSetDevice(device_id));

using Type = typename TestFixture::params::type;
static constexpr size_t block_size = TestFixture::params::block_size;
static constexpr rocprim::block_load_method load_method = TestFixture::params::load_method;
static constexpr rocprim::block_store_method store_method = TestFixture::params::store_method;
static constexpr size_t items_per_thread = TestFixture::params::items_per_thread;
static constexpr auto items_per_block = block_size * items_per_thread;
const size_t size = items_per_block * 113;
const auto grid_size = size / items_per_block;
// Given block size not supported
if(block_size > test_utils::get_max_block_size() || (block_size & (block_size - 1)) != 0)
{
return;
}

if(load_method == rocprim::block_load_method::block_load_warp_transpose
|| store_method == rocprim::block_store_method::block_store_warp_transpose)
{
unsigned int host_warp_size;
HIP_CHECK(::rocprim::host_warp_size(device_id, host_warp_size));
if(block_size % host_warp_size != 0)
{
GTEST_SKIP() << "Cannot run test of block size " << block_size
<< " on a device with warp size " << host_warp_size;
}
}

const size_t valid = items_per_block;

for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++)
{
unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);

// Generate data
std::vector<Type> input
= test_utils::get_random_data_wrapped<Type>(size, -100, 100, seed_value);
std::vector<Type> output(input.size(), (Type)0);

// Calculate expected results on host
std::vector<Type> expected(input.size(), (Type)0);
for (size_t i = 0; i < 113; i++)
{
size_t block_offset = i * items_per_block;
for (size_t j = 0; j < items_per_block; j++)
{
if (j < valid)
{
expected[j + block_offset] = input[j + block_offset];
}
}
}

// Preparing device
common::device_ptr<Type> device_input(input);
// Have to initialize output for unvalid data to make sure they are not changed
common::device_ptr<Type> device_output(output);

// Running kernel
hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_storage_kernel<Type,
load_method,
store_method,
block_size,
items_per_thread>),
dim3(grid_size),
dim3(block_size),
0,
0,
device_input.get(),
device_output.get());
HIP_CHECK(hipGetLastError());

// Reading results from device
output = device_output.load();

// Validating results
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
}

}

typed_test_def(suite_name, name_suffix, LoadStoreClassValidWithStorage)
{
int device_id = test_common_utils::obtain_device_from_ctest();
SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
HIP_CHECK(hipSetDevice(device_id));

using Type = typename TestFixture::params::type;
static constexpr size_t block_size = TestFixture::params::block_size;
static constexpr rocprim::block_load_method load_method = TestFixture::params::load_method;
static constexpr rocprim::block_store_method store_method = TestFixture::params::store_method;
static constexpr size_t items_per_thread = TestFixture::params::items_per_thread;
static constexpr auto items_per_block = block_size * items_per_thread;
const size_t size = items_per_block * 113;
const auto grid_size = size / items_per_block;
// Given block size not supported
if(block_size > test_utils::get_max_block_size() || (block_size & (block_size - 1)) != 0)
{
return;
}

if(load_method == rocprim::block_load_method::block_load_warp_transpose
|| store_method == rocprim::block_store_method::block_store_warp_transpose)
{
unsigned int host_warp_size;
HIP_CHECK(::rocprim::host_warp_size(device_id, host_warp_size));
if(block_size % host_warp_size != 0)
{
GTEST_SKIP() << "Cannot run test of block size " << block_size
<< " on a device with warp size " << host_warp_size;
}
}

const size_t valid = items_per_block - 32;

for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++)
{
unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);

// Generate data
std::vector<Type> input
= test_utils::get_random_data_wrapped<Type>(size, -100, 100, seed_value);
std::vector<Type> output(input.size(), (Type)0);

// Calculate expected results on host
std::vector<Type> expected(input.size(), (Type)0);
for (size_t i = 0; i < 113; i++)
{
size_t block_offset = i * items_per_block;
for (size_t j = 0; j < items_per_block; j++)
{
if (j < valid)
{
expected[j + block_offset] = input[j + block_offset];
}
}
}

// Preparing device
common::device_ptr<Type> device_input(input);
// Have to initialize output for unvalid data to make sure they are not changed
common::device_ptr<Type> device_output(output);

// Running kernel
hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_storage_valid_kernel<Type,
load_method,
store_method,
block_size,
items_per_thread>),
dim3(grid_size),
dim3(block_size),
0,
0,
device_input.get(),
device_output.get(),
valid);
HIP_CHECK(hipGetLastError());

// Reading results from device
output = device_output.load();

// Validating results
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
}

}

typed_test_def(suite_name, name_suffix, LoadStoreClassDefault)
{
int device_id = test_common_utils::obtain_device_from_ctest();
Expand Down Expand Up @@ -277,3 +450,89 @@ typed_test_def(suite_name, name_suffix, LoadStoreClassDefault)
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
}
}

typed_test_def(suite_name, name_suffix, LoadStoreClassDefaultWithStorage)
{
int device_id = test_common_utils::obtain_device_from_ctest();
SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
HIP_CHECK(hipSetDevice(device_id));

using Type = typename TestFixture::params::type;
static constexpr size_t block_size = TestFixture::params::block_size;
static constexpr rocprim::block_load_method load_method = TestFixture::params::load_method;
static constexpr rocprim::block_store_method store_method = TestFixture::params::store_method;
static constexpr size_t items_per_thread = TestFixture::params::items_per_thread;
static constexpr auto items_per_block = block_size * items_per_thread;
const size_t size = items_per_block * 113;
const auto grid_size = size / items_per_block;
// Given block size not supported
if(block_size > test_utils::get_max_block_size() || (block_size & (block_size - 1)) != 0)
{
return;
}

if(load_method == rocprim::block_load_method::block_load_warp_transpose
|| store_method == rocprim::block_store_method::block_store_warp_transpose)
{
unsigned int host_warp_size;
HIP_CHECK(::rocprim::host_warp_size(device_id, host_warp_size));
if(block_size % host_warp_size != 0)
{
GTEST_SKIP() << "Cannot run test of block size " << block_size
<< " on a device with warp size " << host_warp_size;
}
}

const size_t valid = items_per_thread + 1;
Type _default = (Type)-1;

for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++)
{
unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);

// Generate data
std::vector<Type> input
= test_utils::get_random_data_wrapped<Type>(size, -100, 100, seed_value);

// Calculate expected results on host
std::vector<Type> expected(input.size(), _default);
for (size_t i = 0; i < 113; i++)
{
size_t block_offset = i * items_per_block;
for (size_t j = 0; j < items_per_block; j++)
{
if (j < valid)
{
expected[j + block_offset] = input[j + block_offset];
}
}
}

// Preparing device
common::device_ptr<Type> device_input(input);
common::device_ptr<Type> device_output(size);

// Running kernel
hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_valid_default_storage_kernel<Type,
load_method,
store_method,
block_size,
items_per_thread>),
dim3(grid_size),
dim3(block_size),
0,
0,
device_input.get(),
device_output.get(),
valid,
_default);
HIP_CHECK(hipGetLastError());

// Reading results from device
const auto output = device_output.load();

// Validating results
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
}
}
Loading