Skip to content

Commit

Permalink
Fix retrieve_all for containers with large capacity (#580)
Browse files Browse the repository at this point in the history
Fix #576

This PR fixes the `retrieve_all` bug for large inputs by processing the
container's slots in chunks of at most `INT32_MAX` items, similar to the
streaming approach mentioned in NVIDIA/cccl#1422 (comment).

To be reverted once the CCCL fix is in place.
  • Loading branch information
PointKernel authored Aug 16, 2024
1 parent 6eaed1b commit abc5095
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 35 deletions.
80 changes: 48 additions & 32 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -639,45 +639,61 @@ class open_addressing_impl {
template <typename OutputIt>
[[nodiscard]] OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream) const
{
std::size_t temp_storage_bytes = 0;
using temp_allocator_type =
typename std::allocator_traits<allocator_type>::template rebind_alloc<char>;

cuco::detail::index_type constexpr stride = std::numeric_limits<int32_t>::max();

cuco::detail::index_type h_num_out{0};
auto temp_allocator = temp_allocator_type{this->allocator()};
auto d_num_out = reinterpret_cast<size_type*>(
std::allocator_traits<temp_allocator_type>::allocate(temp_allocator, sizeof(size_type)));
auto const begin = thrust::make_transform_iterator(
thrust::counting_iterator<size_type>{0},
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};
CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
begin,
output_begin,
d_num_out,
this->capacity(),
is_filled,
stream.get()));

// Allocate temporary storage
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes);

CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
temp_storage_bytes,
begin,
output_begin,
d_num_out,
this->capacity(),
is_filled,
stream.get()));

size_type h_num_out;
CUCO_CUDA_TRY(cudaMemcpyAsync(
&h_num_out, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get()));
stream.wait();

// TODO: PR #580 to be reverted once https://github.com/NVIDIA/cccl/issues/1422 is resolved
for (cuco::detail::index_type offset = 0;
offset < static_cast<cuco::detail::index_type>(this->capacity());
offset += stride) {
auto const num_items =
std::min(static_cast<cuco::detail::index_type>(this->capacity()) - offset, stride);
auto const begin = thrust::make_transform_iterator(
thrust::counting_iterator{static_cast<size_type>(offset)},
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};

std::size_t temp_storage_bytes = 0;

CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
begin,
output_begin + h_num_out,
d_num_out,
static_cast<int32_t>(num_items),
is_filled,
stream.get()));

// Allocate temporary storage
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes);

CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
temp_storage_bytes,
begin,
output_begin + h_num_out,
d_num_out,
static_cast<int32_t>(num_items),
is_filled,
stream.get()));

size_type temp_count;
CUCO_CUDA_TRY(cudaMemcpyAsync(
&temp_count, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get()));
stream.wait();
h_num_out += temp_count;
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes);
}

std::allocator_traits<temp_allocator_type>::deallocate(
temp_allocator, reinterpret_cast<char*>(d_num_out), sizeof(size_type));
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes);

return output_begin + h_num_out;
}
Expand Down
19 changes: 16 additions & 3 deletions tests/static_set/large_input_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,27 @@ void test_unique_sequence(Set& set, bool* res_begin, std::size_t num_keys)
set.contains(keys_begin, keys_end, res_begin);
REQUIRE(cuco::test::all_of(res_begin, res_begin + num_keys, thrust::identity{}));
}

SECTION("All inserted key/value pairs can be retrieved.")
{
  // One output slot per expected key
  thrust::device_vector<Key> retrieved(num_keys);

  // `retrieve_all` returns the iterator one past the last key it wrote
  auto const retrieved_end = set.retrieve_all(retrieved.begin());
  auto const num_retrieved =
    static_cast<std::size_t>(std::distance(retrieved.begin(), retrieved_end));
  REQUIRE(num_retrieved == num_keys);

  // Sort the retrieved keys so they can be compared element-wise against 0, 1, 2, ...
  thrust::sort(retrieved.begin(), retrieved_end);

  auto const expected_begin = thrust::counting_iterator<Key>(0);
  REQUIRE(cuco::test::equal(
    retrieved.begin(), retrieved.end(), expected_begin, thrust::equal_to<Key>{}));
}
}

TEMPLATE_TEST_CASE_SIG(
"Large input",
"cuco::static_set large input test",
"",
((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize),
(int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, cuco::test::probe_sequence::double_hashing, 2))
{
Expand Down

0 comments on commit abc5095

Please sign in to comment.