Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec
* Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag.
* Added additional unit tests for `test_block_scan.hpp`
* Added additional unit tests for `test_block_sort.hpp`
* Added additional unit tests for `test_block_load_store.hpp`

### Changed

Expand Down
4 changes: 0 additions & 4 deletions test/rocprim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,6 @@ add_rocprim_test("rocprim.block_histogram" test_block_histogram.cpp)
add_rocprim_test("rocprim.block_load_store" test_block_load_store.cpp)
add_rocprim_test("rocprim.block_sort_merge" test_block_sort_merge.cpp)
add_rocprim_test("rocprim.block_sort_merge_stable" test_block_sort_merge_stable.cpp)
add_rocprim_test("rocprim.block_store_direct" test_block_store_direct.cpp)
add_rocprim_test("rocprim.block_store_striped" test_block_store_striped.cpp)
add_rocprim_test("rocprim.block_store_transpose" test_block_store_transpose.cpp)
add_rocprim_test("rocprim.block_store_vectorize" test_block_store_vectorize.cpp)
add_rocprim_test_parallel("rocprim.block_radix_rank" test_block_radix_rank.cpp.in)
add_rocprim_test_parallel("rocprim.block_radix_sort" test_block_radix_sort.cpp.in)
add_rocprim_test("rocprim.block_reduce" test_block_reduce.cpp)
Expand Down
259 changes: 259 additions & 0 deletions test/rocprim/test_block_load_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,179 @@ typed_test_def(suite_name, name_suffix, LoadStoreClassValid)

}

// Full-block round trip through load_store_storage_kernel: 113 blocks are
// launched and the device output must be element-wise equal to the input.
// NOTE(review): the kernel is defined outside this view; presumably it drives
// block_load/block_store with explicitly provided temporary storage -- confirm
// against its definition.
typed_test_def(suite_name, name_suffix, LoadStoreClassWithStorage)
{
    int device_id = test_common_utils::obtain_device_from_ctest();
    SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
    HIP_CHECK(hipSetDevice(device_id));

    // Compile-time test configuration taken from the fixture parameters.
    using Type = typename TestFixture::params::type;
    static constexpr size_t                      block_size       = TestFixture::params::block_size;
    static constexpr size_t                      items_per_thread = TestFixture::params::items_per_thread;
    static constexpr rocprim::block_load_method  load_method      = TestFixture::params::load_method;
    static constexpr rocprim::block_store_method store_method     = TestFixture::params::store_method;
    static constexpr auto                        items_per_block  = block_size * items_per_thread;

    const size_t total_items = items_per_block * 113;
    const auto   num_blocks  = total_items / items_per_block;

    // Only power-of-two block sizes up to the reported maximum are exercised;
    // anything else leaves the test early without running the kernel.
    const bool block_size_supported = block_size <= test_utils::get_max_block_size()
                                      && (block_size & (block_size - 1)) == 0;
    if(!block_size_supported)
    {
        return;
    }

    // The warp-transpose methods are only run when the block size is a
    // multiple of the device's warp size.
    const bool uses_warp_transpose
        = load_method == rocprim::block_load_method::block_load_warp_transpose
          || store_method == rocprim::block_store_method::block_store_warp_transpose;
    if(uses_warp_transpose)
    {
        unsigned int warp_size = 0;
        HIP_CHECK(::rocprim::host_warp_size(device_id, warp_size));
        if(block_size % warp_size != 0)
        {
            GTEST_SKIP() << "Cannot run test of block size " << block_size
                         << " on a device with warp size " << warp_size;
        }
    }

    // Every item of every block is considered valid in this test.
    const size_t valid = items_per_block;

    for(size_t run = 0; run < number_of_runs; ++run)
    {
        // The first random_seeds_count runs draw a seed from rand(); the
        // remaining runs use the fixed entries of `seeds`.
        const unsigned int seed = run < random_seeds_count ? rand() : seeds[run - random_seeds_count];
        SCOPED_TRACE(testing::Message() << "with seed = " << seed);

        // Host-side input data and the zero pattern used to pre-fill the output.
        const std::vector<Type> input
            = test_utils::get_random_data_wrapped<Type>(total_items, -100, 100, seed);
        const std::vector<Type> zeros(input.size(), static_cast<Type>(0));

        // Expected result: items whose position inside their block is below
        // `valid` are copied through, all others keep the zero fill.
        std::vector<Type> expected(input.size(), static_cast<Type>(0));
        for(size_t idx = 0; idx < total_items; ++idx)
        {
            const size_t position_in_block = idx % items_per_block;
            if(position_in_block < valid)
            {
                expected[idx] = input[idx];
            }
        }

        // Upload the input; the output buffer starts out zero-filled so that
        // any element the kernel does not write remains detectable.
        common::device_ptr<Type> device_input(input);
        common::device_ptr<Type> device_output(zeros);

        // One thread block per block of items.
        hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_storage_kernel<Type,
                                                                     load_method,
                                                                     store_method,
                                                                     block_size,
                                                                     items_per_thread>),
                           dim3(num_blocks),
                           dim3(block_size),
                           0,
                           0,
                           device_input.get(),
                           device_output.get());
        HIP_CHECK(hipGetLastError());

        // Download and compare against the host reference.
        const auto output = device_output.load();
        ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
    }
}

// Partial-block round trip through load_store_storage_valid_kernel: for each
// of the 113 blocks only the first `valid` items are expected to be copied to
// the output, the remaining items must keep their initial zero value.
// NOTE(review): the kernel is defined outside this view; presumably it drives
// the range-guarded block_load/block_store overloads with explicitly provided
// temporary storage -- confirm against its definition.
typed_test_def(suite_name, name_suffix, LoadStoreClassValidWithStorage)
{
    int device_id = test_common_utils::obtain_device_from_ctest();
    SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
    HIP_CHECK(hipSetDevice(device_id));

    // Compile-time test configuration taken from the fixture parameters.
    using Type = typename TestFixture::params::type;
    static constexpr size_t                      block_size       = TestFixture::params::block_size;
    static constexpr rocprim::block_load_method  load_method      = TestFixture::params::load_method;
    static constexpr rocprim::block_store_method store_method     = TestFixture::params::store_method;
    static constexpr size_t                      items_per_thread = TestFixture::params::items_per_thread;
    static constexpr auto                        items_per_block  = block_size * items_per_thread;

    const size_t size      = items_per_block * 113;
    const auto   grid_size = size / items_per_block;

    // Given block size not supported
    if(block_size > test_utils::get_max_block_size() || (block_size & (block_size - 1)) != 0)
    {
        return;
    }

    // The warp-transpose methods are only run when the block size is a
    // multiple of the device's warp size.
    if(load_method == rocprim::block_load_method::block_load_warp_transpose
       || store_method == rocprim::block_store_method::block_store_warp_transpose)
    {
        unsigned int host_warp_size = 0;
        HIP_CHECK(::rocprim::host_warp_size(device_id, host_warp_size));
        if(block_size % host_warp_size != 0)
        {
            GTEST_SKIP() << "Cannot run test of block size " << block_size
                         << " on a device with warp size " << host_warp_size;
        }
    }

    // The last `invalid_tail` items of every block are treated as out of range.
    // The subtraction saturates at zero: `items_per_block` is unsigned, so a
    // plain `items_per_block - 32` would wrap around to a huge value for
    // configurations with fewer than 32 items per block and silently turn this
    // into a full-block test.
    static constexpr size_t invalid_tail = 32;
    const size_t            valid
        = items_per_block > invalid_tail ? items_per_block - invalid_tail : 0;

    for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++)
    {
        // The first random_seeds_count runs draw a seed from rand(); the
        // remaining runs use the fixed entries of `seeds`.
        unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count];
        SCOPED_TRACE(testing::Message() << "with seed = " << seed_value);

        // Generate data
        std::vector<Type> input
            = test_utils::get_random_data_wrapped<Type>(size, -100, 100, seed_value);
        std::vector<Type> output(input.size(), static_cast<Type>(0));

        // Calculate expected results on host: the first `valid` items of each
        // block are copied, the tail keeps the zero fill. `valid` never
        // exceeds items_per_block, so the inner loop stays inside the block.
        std::vector<Type> expected(input.size(), static_cast<Type>(0));
        for(size_t block = 0; block < grid_size; ++block)
        {
            const size_t block_offset = block * items_per_block;
            for(size_t j = 0; j < valid; ++j)
            {
                expected[block_offset + j] = input[block_offset + j];
            }
        }

        // Preparing device
        common::device_ptr<Type> device_input(input);
        // The output must be initialized so that invalid (out-of-range) items
        // can be checked to be left unchanged by the kernel.
        common::device_ptr<Type> device_output(output);

        // Running kernel
        hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_storage_valid_kernel<Type,
                                                                           load_method,
                                                                           store_method,
                                                                           block_size,
                                                                           items_per_thread>),
                           dim3(grid_size),
                           dim3(block_size),
                           0,
                           0,
                           device_input.get(),
                           device_output.get(),
                           valid);
        HIP_CHECK(hipGetLastError());

        // Reading results from device
        output = device_output.load();

        // Validating results
        ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
    }
}

typed_test_def(suite_name, name_suffix, LoadStoreClassDefault)
{
int device_id = test_common_utils::obtain_device_from_ctest();
Expand Down Expand Up @@ -277,3 +450,89 @@ typed_test_def(suite_name, name_suffix, LoadStoreClassDefault)
ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
}
}

// Partial-block round trip with a fill value through
// load_store_valid_default_storage_kernel: for each of the 113 blocks the
// first `valid` items are expected to match the input and every other item is
// expected to equal the fill value passed to the kernel.
// NOTE(review): the kernel is defined outside this view; presumably it drives
// the block_load overload taking a default value, with explicitly provided
// temporary storage -- confirm against its definition.
typed_test_def(suite_name, name_suffix, LoadStoreClassDefaultWithStorage)
{
    int device_id = test_common_utils::obtain_device_from_ctest();
    SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
    HIP_CHECK(hipSetDevice(device_id));

    // Compile-time test configuration taken from the fixture parameters.
    using Type = typename TestFixture::params::type;
    static constexpr size_t                      block_size       = TestFixture::params::block_size;
    static constexpr size_t                      items_per_thread = TestFixture::params::items_per_thread;
    static constexpr rocprim::block_load_method  load_method      = TestFixture::params::load_method;
    static constexpr rocprim::block_store_method store_method     = TestFixture::params::store_method;
    static constexpr auto                        items_per_block  = block_size * items_per_thread;

    const size_t total_items = items_per_block * 113;
    const auto   num_blocks  = total_items / items_per_block;

    // Only power-of-two block sizes up to the reported maximum are exercised;
    // anything else leaves the test early without running the kernel.
    const bool block_size_supported = block_size <= test_utils::get_max_block_size()
                                      && (block_size & (block_size - 1)) == 0;
    if(!block_size_supported)
    {
        return;
    }

    // The warp-transpose methods are only run when the block size is a
    // multiple of the device's warp size.
    const bool uses_warp_transpose
        = load_method == rocprim::block_load_method::block_load_warp_transpose
          || store_method == rocprim::block_store_method::block_store_warp_transpose;
    if(uses_warp_transpose)
    {
        unsigned int warp_size = 0;
        HIP_CHECK(::rocprim::host_warp_size(device_id, warp_size));
        if(block_size % warp_size != 0)
        {
            GTEST_SKIP() << "Cannot run test of block size " << block_size
                         << " on a device with warp size " << warp_size;
        }
    }

    // Items per block that count as in range, and the value expected in every
    // out-of-range slot.
    const size_t valid      = items_per_thread + 1;
    Type         fill_value = static_cast<Type>(-1);

    for(size_t run = 0; run < number_of_runs; ++run)
    {
        // The first random_seeds_count runs draw a seed from rand(); the
        // remaining runs use the fixed entries of `seeds`.
        const unsigned int seed = run < random_seeds_count ? rand() : seeds[run - random_seeds_count];
        SCOPED_TRACE(testing::Message() << "with seed = " << seed);

        // Host-side input data.
        const std::vector<Type> input
            = test_utils::get_random_data_wrapped<Type>(total_items, -100, 100, seed);

        // Expected result: start from the fill value everywhere, then copy
        // through the items whose position inside their block is below `valid`.
        std::vector<Type> expected(input.size(), fill_value);
        for(size_t idx = 0; idx < total_items; ++idx)
        {
            const size_t position_in_block = idx % items_per_block;
            if(position_in_block < valid)
            {
                expected[idx] = input[idx];
            }
        }

        // Upload the input; the output buffer is only sized, not pre-filled.
        common::device_ptr<Type> device_input(input);
        common::device_ptr<Type> device_output(total_items);

        // One thread block per block of items.
        hipLaunchKernelGGL(HIP_KERNEL_NAME(load_store_valid_default_storage_kernel<Type,
                                                                                   load_method,
                                                                                   store_method,
                                                                                   block_size,
                                                                                   items_per_thread>),
                           dim3(num_blocks),
                           dim3(block_size),
                           0,
                           0,
                           device_input.get(),
                           device_output.get(),
                           valid,
                           fill_value);
        HIP_CHECK(hipGetLastError());

        // Download and compare against the host reference.
        const auto output = device_output.load();
        ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected));
    }
}
Loading