Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix retrieve_all for containers with large capacity #580

Merged
merged 8 commits into from
Aug 16, 2024
79 changes: 47 additions & 32 deletions include/cuco/detail/open_addressing/open_addressing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -639,45 +639,60 @@ class open_addressing_impl {
template <typename OutputIt>
[[nodiscard]] OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream) const
{
std::size_t temp_storage_bytes = 0;
using temp_allocator_type =
typename std::allocator_traits<allocator_type>::template rebind_alloc<char>;

cuco::detail::index_type constexpr stride = std::numeric_limits<int32_t>::max();
PointKernel marked this conversation as resolved.
Show resolved Hide resolved

cuco::detail::index_type h_num_out{0};
auto temp_allocator = temp_allocator_type{this->allocator()};
auto d_num_out = reinterpret_cast<size_type*>(
std::allocator_traits<temp_allocator_type>::allocate(temp_allocator, sizeof(size_type)));
auto const begin = thrust::make_transform_iterator(
thrust::counting_iterator<size_type>{0},
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};
CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
begin,
output_begin,
d_num_out,
this->capacity(),
is_filled,
stream.get()));

// Allocate temporary storage
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes);

CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
temp_storage_bytes,
begin,
output_begin,
d_num_out,
this->capacity(),
is_filled,
stream.get()));

size_type h_num_out;
CUCO_CUDA_TRY(cudaMemcpyAsync(
&h_num_out, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get()));
stream.wait();

for (cuco::detail::index_type offset = 0;
offset < static_cast<cuco::detail::index_type>(this->capacity());
offset += stride) {
auto const num_items =
std::min(static_cast<cuco::detail::index_type>(this->capacity()) - offset, stride);
auto const begin = thrust::make_transform_iterator(
thrust::counting_iterator{static_cast<size_type>(offset)},
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
this->empty_key_sentinel(), this->erased_key_sentinel()};

std::size_t temp_storage_bytes = 0;

CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr,
temp_storage_bytes,
begin,
output_begin + h_num_out,
d_num_out,
static_cast<int32_t>(num_items),
is_filled,
stream.get()));

// Allocate temporary storage
auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes);

CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage,
temp_storage_bytes,
begin,
output_begin + h_num_out,
d_num_out,
static_cast<int32_t>(num_items),
is_filled,
stream.get()));

size_type temp_count;
CUCO_CUDA_TRY(cudaMemcpyAsync(
&temp_count, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get()));
stream.wait();
h_num_out += temp_count;
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes);
}

std::allocator_traits<temp_allocator_type>::deallocate(
temp_allocator, reinterpret_cast<char*>(d_num_out), sizeof(size_type));
temp_allocator.deallocate(d_temp_storage, temp_storage_bytes);

return output_begin + h_num_out;
}
Expand Down
19 changes: 16 additions & 3 deletions tests/static_set/large_input_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,27 @@ void test_unique_sequence(Set& set, bool* res_begin, std::size_t num_keys)
set.contains(keys_begin, keys_end, res_begin);
REQUIRE(cuco::test::all_of(res_begin, res_begin + num_keys, thrust::identity{}));
}

SECTION("All inserted key/value pairs can be retrieved.")
{
auto output_keys = thrust::device_vector<Key>(num_keys);

auto const keys_end = set.retrieve_all(output_keys.begin());
REQUIRE(static_cast<std::size_t>(std::distance(output_keys.begin(), keys_end)) == num_keys);

thrust::sort(output_keys.begin(), keys_end);

REQUIRE(cuco::test::equal(output_keys.begin(),
output_keys.end(),
thrust::counting_iterator<Key>(0),
thrust::equal_to<Key>{}));
}
}

TEMPLATE_TEST_CASE_SIG(
"Large input",
"cuco::static_set large input test",
"",
((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize),
(int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, cuco::test::probe_sequence::double_hashing, 2))
{
Expand Down
Loading