diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index f9c35e0ff..772e8a667 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -639,45 +639,61 @@ class open_addressing_impl { template [[nodiscard]] OutputIt retrieve_all(OutputIt output_begin, cuda::stream_ref stream) const { - std::size_t temp_storage_bytes = 0; using temp_allocator_type = typename std::allocator_traits::template rebind_alloc; + + cuco::detail::index_type constexpr stride = std::numeric_limits::max(); + + cuco::detail::index_type h_num_out{0}; auto temp_allocator = temp_allocator_type{this->allocator()}; auto d_num_out = reinterpret_cast( std::allocator_traits::allocate(temp_allocator, sizeof(size_type))); - auto const begin = thrust::make_transform_iterator( - thrust::counting_iterator{0}, - open_addressing_ns::detail::get_slot(this->storage_ref())); - auto const is_filled = open_addressing_ns::detail::slot_is_filled{ - this->empty_key_sentinel(), this->erased_key_sentinel()}; - CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr, - temp_storage_bytes, - begin, - output_begin, - d_num_out, - this->capacity(), - is_filled, - stream.get())); - - // Allocate temporary storage - auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes); - - CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage, - temp_storage_bytes, - begin, - output_begin, - d_num_out, - this->capacity(), - is_filled, - stream.get())); - - size_type h_num_out; - CUCO_CUDA_TRY(cudaMemcpyAsync( - &h_num_out, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get())); - stream.wait(); + + // TODO: PR #580 to be reverted once https://github.com/NVIDIA/cccl/issues/1422 is resolved + for (cuco::detail::index_type offset = 0; + offset < static_cast(this->capacity()); + offset += stride) { + auto const num_items = + std::min(static_cast(this->capacity()) - offset, stride); + auto const begin = thrust::make_transform_iterator( + thrust::counting_iterator{static_cast(offset)}, + open_addressing_ns::detail::get_slot(this->storage_ref())); + auto const is_filled = open_addressing_ns::detail::slot_is_filled{ + this->empty_key_sentinel(), this->erased_key_sentinel()}; + + std::size_t temp_storage_bytes = 0; + + CUCO_CUDA_TRY(cub::DeviceSelect::If(nullptr, + temp_storage_bytes, + begin, + output_begin + h_num_out, + d_num_out, + static_cast(num_items), + is_filled, + stream.get())); + + // Allocate temporary storage + auto d_temp_storage = temp_allocator.allocate(temp_storage_bytes); + + CUCO_CUDA_TRY(cub::DeviceSelect::If(d_temp_storage, + temp_storage_bytes, + begin, + output_begin + h_num_out, + d_num_out, + static_cast(num_items), + is_filled, + stream.get())); + + size_type temp_count; + CUCO_CUDA_TRY(cudaMemcpyAsync( + &temp_count, d_num_out, sizeof(size_type), cudaMemcpyDeviceToHost, stream.get())); + stream.wait(); + h_num_out += temp_count; + temp_allocator.deallocate(d_temp_storage, temp_storage_bytes); + } + std::allocator_traits::deallocate( temp_allocator, reinterpret_cast(d_num_out), sizeof(size_type)); - temp_allocator.deallocate(d_temp_storage, temp_storage_bytes); return output_begin + h_num_out; } diff --git a/tests/static_set/large_input_test.cu b/tests/static_set/large_input_test.cu index 481762e5f..d4cef0201 100644 --- a/tests/static_set/large_input_test.cu +++ b/tests/static_set/large_input_test.cu @@ -53,14 +53,27 @@ void test_unique_sequence(Set& set, bool* res_begin, std::size_t num_keys) set.contains(keys_begin, keys_end, res_begin); REQUIRE(cuco::test::all_of(res_begin, res_begin + num_keys, thrust::identity{})); } + + SECTION("All inserted key/value pairs can be retrieved.") + { + auto output_keys = thrust::device_vector(num_keys); + + auto const keys_end = set.retrieve_all(output_keys.begin()); + REQUIRE(static_cast(std::distance(output_keys.begin(), keys_end)) == num_keys); + + thrust::sort(output_keys.begin(), keys_end); + + REQUIRE(cuco::test::equal(output_keys.begin(), + output_keys.end(), + thrust::counting_iterator(0), + thrust::equal_to{})); + } } TEMPLATE_TEST_CASE_SIG( - "Large input", + "cuco::static_set large input test", "", ((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize), - (int32_t, cuco::test::probe_sequence::double_hashing, 1), - (int32_t, cuco::test::probe_sequence::double_hashing, 2), (int64_t, cuco::test::probe_sequence::double_hashing, 1), (int64_t, cuco::test::probe_sequence::double_hashing, 2)) {