diff --git a/include/cuco/detail/probe_sequence_impl.cuh b/include/cuco/detail/probe_sequence_impl.cuh index f3dcc2436..688b2f28f 100644 --- a/include/cuco/detail/probe_sequence_impl.cuh +++ b/include/cuco/detail/probe_sequence_impl.cuh @@ -23,6 +23,8 @@ #include +#include + namespace cuco { namespace detail { @@ -208,6 +210,24 @@ class linear_probing_impl template __device__ __forceinline__ iterator initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const& k) noexcept + { + return const_cast(std::as_const(*this).initial_slot(g, k)); + } + + /** + * @brief Returns the initial slot for a given key `k`. + * + * If vector-load is enabled, the return slot is always even to avoid illegal memory access. + * + * @tparam ProbeKey Probe key type + * + * @param g the Cooperative Group for which the initial slot is needed + * @param k The key to get the slot for + * @return Pointer to the initial slot for `k` + */ + template + __device__ __forceinline__ const_iterator initial_slot( + cooperative_groups::thread_block_tile const& g, ProbeKey const& k) const noexcept { auto const hash_value = [&]() { auto const tmp = hash_(k); @@ -237,6 +257,19 @@ class linear_probing_impl * @return The next slot after `s` */ __device__ __forceinline__ iterator next_slot(iterator s) noexcept + { + return const_cast(std::as_const(*this).next_slot(s)); + } + + /** + * @brief Given a slot `s`, returns the next slot. + * + * If `s` is the last slot, wraps back around to the first slot. + * + * @param s The slot to advance + * @return The next slot after `s` + */ + __device__ __forceinline__ const_iterator next_slot(const_iterator s) const noexcept { std::size_t index = s - slots_; std::size_t offset; @@ -331,6 +364,25 @@ class double_hashing_impl template __device__ __forceinline__ iterator initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const& k) noexcept + { + return const_cast(std::as_const(*this).initial_slot(g, k)); + } + + /** + * @brief Returns the initial slot for a given key `k`. + * + * If vector-load is enabled, the return slot is always a multiple of (`cg_size` * `vector_width`) + * to avoid illegal memory access. + * + * @tparam ProbeKey Probe key type + * + * @param g the Cooperative Group for which the initial slot is needed + * @param k The key to get the slot for + * @return Pointer to the initial slot for `k` + */ + template + __device__ __forceinline__ const_iterator initial_slot( + cooperative_groups::thread_block_tile const& g, ProbeKey const& k) const noexcept { std::size_t index; auto const hash_value = hash1_(k); @@ -357,16 +409,29 @@ class double_hashing_impl * @return The next slot after `s` */ __device__ __forceinline__ iterator next_slot(iterator s) noexcept + { + return const_cast(std::as_const(*this).next_slot(s)); + } + + /** + * @brief Given a slot `s`, returns the next slot. + * + * If `s` is the last slot, wraps back around to the first slot. + * + * @param s The slot to advance + * @return The next slot after `s` + */ + __device__ __forceinline__ const_iterator next_slot(const_iterator s) const noexcept { std::size_t index = s - slots_; return &slots_[(index + step_size_) % capacity_]; } private: - Hash1 hash1_; ///< The first unary callable used to hash the key - Hash2 hash2_; ///< The second unary callable used to determine step size - std::size_t step_size_; ///< The step stride when searching for the next slot -}; // class double_hashing + Hash1 hash1_; ///< The first unary callable used to hash the key + Hash2 hash2_; ///< The second unary callable used to determine step size + mutable std::size_t step_size_; ///< The step stride when searching for the next slot +}; // class double_hashing /** * @brief Probe sequence used internally by hash map. diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index 9eb31db14..a1bb4f016 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -261,7 +261,7 @@ void static_map::contains(InputIt first, OutputIt output_begin, Hash hash, KeyEqual key_equal, - cudaStream_t stream) + cudaStream_t stream) const { auto num_keys = std::distance(first, last); if (num_keys == 0) { return; } diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index d0c9bc7b9..30fb56544 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -139,7 +139,7 @@ class static_multimap::device_view_ * @param current_slot The given slot to load from */ __device__ __forceinline__ void load_pair_array(value_type* arr, - const_iterator current_slot) noexcept + const_iterator current_slot) const noexcept { if constexpr (sizeof(value_type) == 4) { auto const tmp = *reinterpret_cast(current_slot); @@ -567,32 +567,33 @@ class static_multimap::device_view_ } /** - * @brief Indicates whether the key `k` exists in the map using vector loads. + * @brief Indicates whether the probe `element` exists in the map using vector loads. * - * If the key `k` was inserted into the map, `contains` returns - * true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to - * to leverage multiple threads to perform a single `contains` operation. This provides a - * significant boost in throughput compared to the non Cooperative Group based - * `contains` at moderate to high load factors. + * If `element` was inserted into the map, `contains` returns true. Otherwise, it returns false. + * Uses the CUDA Cooperative Groups API to leverage multiple threads to perform a single + * `contains` operation. This provides a significant boost in throughput compared to the non + * Cooperative Group based `contains` at moderate to high load factors. * + * @tparam is_pair_contains `true` if it's a `pair_contains` implementation * @tparam uses_vector_load Boolean flag indicating whether vector loads are used - * @tparam ProbeKey Probe key type - * @tparam KeyEqual Binary callable type + * @tparam ProbeT Probe data type + * @tparam Equal Binary callable type * * @param g The Cooperative Group used to perform the contains operation - * @param k The key to search for - * @param key_equal The binary callable used to compare two keys - * for equality - * @return A boolean indicating whether the key/value pair - * containing `k` was inserted + * @param element The probe element to search for + * @param equal The binary function to compare input element and slot content for equality + * @return A boolean indicating whether the key/value pair represented by `element` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( cooperative_groups::thread_block_tile const& g, - ProbeKey const& k, - KeyEqual key_equal) noexcept + ProbeT const& element, + Equal equal) const noexcept { - auto current_slot = initial_slot(g, k); + auto current_slot = [&]() { + if constexpr (is_pair_contains) { return initial_slot(g, element.first); } + if constexpr (not is_pair_contains) { return initial_slot(g, element); } + }(); while (true) { value_type arr[2]; @@ -602,8 +603,22 @@ class static_multimap::device_view_ detail::bitwise_compare(arr[0].first, this->get_empty_key_sentinel()); auto const second_slot_is_empty = detail::bitwise_compare(arr[1].first, this->get_empty_key_sentinel()); - auto const first_equals = (not first_slot_is_empty and key_equal(arr[0].first, k)); - auto const second_equals = (not second_slot_is_empty and key_equal(arr[1].first, k)); + auto const first_equals = [&]() { + if constexpr (is_pair_contains) { + return not first_slot_is_empty and equal(arr[0], element); + } + if constexpr (not is_pair_contains) { + return not first_slot_is_empty and equal(arr[0].first, element); + } + }(); + auto const second_equals = [&]() { + if constexpr (is_pair_contains) { + return not second_slot_is_empty and equal(arr[1], element); + } + if constexpr (not is_pair_contains) { + return not second_slot_is_empty and equal(arr[1].first, element); + } + }(); // the key we were searching for was found by one of the threads, so we return true if (g.any(first_equals or second_equals)) { return true; } @@ -618,32 +633,33 @@ class static_multimap::device_view_ } /** - * @brief Indicates whether the key `k` exists in the map using scalar loads. + * @brief Indicates whether `element` exists in the map using scalar loads. * - * If the key `k` was inserted into the map, `contains` returns - * true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to - * to leverage multiple threads to perform a single `contains` operation. This provides a - * significant boost in throughput compared to the non Cooperative Group - * `contains` at moderate to high load factors. + * If `element` was inserted into the map, `contains` returns true. Otherwise, it returns false. + * Uses the CUDA Cooperative Groups API to leverage multiple threads to perform a single + * `contains` operation. This provides a significant boost in throughput compared to the non + * Cooperative Group `contains` at moderate to high load factors. * + * @tparam is_pair_contains `true` if it's a `pair_contains` implementation * @tparam uses_vector_load Boolean flag indicating whether vector loads are used - * @tparam ProbeKey Probe key type - * @tparam KeyEqual Binary callable type + * @tparam ProbeT Probe data type + * @tparam Equal Binary callable type * * @param g The Cooperative Group used to perform the contains operation - * @param k The key to search for - * @param key_equal The binary callable used to compare two keys - * for equality - * @return A boolean indicating whether the key/value pair - * containing `k` was inserted + * @param element The probe element to search for + * @param equal The binary function to compare input element and slot content for equality + * @return A boolean indicating whether the key/value pair represented by `element` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( cooperative_groups::thread_block_tile const& g, - ProbeKey const& k, - KeyEqual key_equal) noexcept + ProbeT const& element, + Equal equal) const noexcept { - auto current_slot = initial_slot(g, k); + auto current_slot = [&]() { + if constexpr (is_pair_contains) { return initial_slot(g, element.first); } + if constexpr (not is_pair_contains) { return initial_slot(g, element); } + }(); while (true) { value_type slot_contents = *reinterpret_cast(current_slot); @@ -654,7 +670,14 @@ class static_multimap::device_view_ auto const slot_is_empty = detail::bitwise_compare(existing_key, this->get_empty_key_sentinel()); - auto const equals = (not slot_is_empty and key_equal(existing_key, k)); + auto const equals = [&]() { + if constexpr (is_pair_contains) { + return not slot_is_empty and equal(slot_contents, element); + } + if constexpr (not is_pair_contains) { + return not slot_is_empty and equal(existing_key, element); + } + }(); // the key we were searching for was found by one of the threads, so we return true if (g.any(equals)) { return true; } diff --git a/include/cuco/detail/static_multimap/kernels.cuh b/include/cuco/detail/static_multimap/kernels.cuh index 511d5c9cc..f3820bf64 100644 --- a/include/cuco/detail/static_multimap/kernels.cuh +++ b/include/cuco/detail/static_multimap/kernels.cuh @@ -25,6 +25,8 @@ #include +#include + namespace cuco { namespace detail { namespace cg = cooperative_groups; @@ -144,45 +146,50 @@ __global__ void insert_if_n(InputIt first, StencilIt s, std::size_t n, viewT vie } /** - * @brief Indicates whether the keys in the range `[first, last)` are contained in the map. + * @brief Indicates whether the elements in the range `[first, last)` are contained in the map. * - * Stores `true` or `false` to `(output + i)` indicating if the key `*(first + i)` exists in the + * Stores `true` or `false` to `(output + i)` indicating if the element `*(first + i)` exists in the * map. * * Uses the CUDA Cooperative Groups API to leverage groups of multiple threads to perform the - * contains operation for each key. This provides a significant boost in throughput compared + * contains operation for each element. This provides a significant boost in throughput compared * to the non Cooperative Group `contains` at moderate to high load factors. * + * @tparam is_pair_contains `true` if it's a `pair_contains` implementation * @tparam block_size The size of the thread block * @tparam tile_size The number of threads in the Cooperative Groups - * @tparam InputIt Device accessible input iterator whose `value_type` is - * convertible to the map's `key_type` - * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from `bool` + * @tparam InputIt Device accessible input iterator + * @tparam OutputIt Device accessible output iterator assignable from `bool` * @tparam viewT Type of device view allowing access of hash map storage - * @tparam KeyEqual Binary callable type - * @param first Beginning of the sequence of keys - * @param last End of the sequence of keys - * @param output_begin Beginning of the sequence of booleans for the presence of each key + * @tparam Equal Binary callable type + * + * @param first Beginning of the sequence of elements + * @param last End of the sequence of elements + * @param output_begin Beginning of the sequence of booleans for the presence of each element * @param view Device view used to access the hash map's slot storage - * @param key_equal The binary function to compare two keys for equality + * @param equal The binary function to compare input element and slot content for equality */ -template + typename Equal> __global__ void contains( - InputIt first, InputIt last, OutputIt output_begin, viewT view, KeyEqual key_equal) + InputIt first, InputIt last, OutputIt output_begin, viewT view, Equal equal) { - auto tile = cg::tiled_partition(cg::this_thread_block()); - auto tid = block_size * blockIdx.x + threadIdx.x; - auto key_idx = tid / tile_size; + auto tile = cg::tiled_partition(cg::this_thread_block()); + auto tid = block_size * blockIdx.x + threadIdx.x; + auto idx = tid / tile_size; __shared__ bool writeBuffer[block_size]; - while (first + key_idx < last) { - auto key = *(first + key_idx); - auto found = view.contains(tile, key, key_equal); + while (first + idx < last) { + typename std::iterator_traits::value_type element = *(first + idx); + auto found = [&]() { + if constexpr (is_pair_contains) { return view.pair_contains(tile, element, equal); } + if constexpr (not is_pair_contains) { return view.contains(tile, element, equal); } + }(); /* * The ld.relaxed.gpu instruction used in view.find causes L1 to @@ -193,10 +200,8 @@ __global__ void contains( */ if (tile.thread_rank() == 0) { writeBuffer[threadIdx.x / tile_size] = found; } __syncthreads(); - if (tile.thread_rank() == 0) { - *(output_begin + key_idx) = writeBuffer[threadIdx.x / tile_size]; - } - key_idx += (gridDim.x * block_size) / tile_size; + if (tile.thread_rank() == 0) { *(output_begin + idx) = writeBuffer[threadIdx.x / tile_size]; } + idx += (gridDim.x * block_size) / tile_size; } } diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 89edb434e..ddec2e4a2 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -113,16 +113,41 @@ void static_multimap::contains( auto const num_keys = std::distance(first, last); if (num_keys == 0) { return; } - auto constexpr block_size = 128; - auto constexpr stride = 1; + auto constexpr is_pair_contains = false; + auto constexpr block_size = 128; + auto constexpr stride = 1; auto const grid_size = (cg_size() * num_keys + stride * block_size - 1) / (stride * block_size); auto view = get_device_view(); - detail::contains + detail::contains <<>>(first, last, output_begin, view, key_equal); CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); } +template +template +void static_multimap::pair_contains( + InputIt first, InputIt last, OutputIt output_begin, PairEqual pair_equal, cudaStream_t stream) + const +{ + auto const num_pairs = std::distance(first, last); + if (num_pairs == 0) { return; } + + auto constexpr is_pair_contains = true; + auto constexpr block_size = 128; + auto constexpr stride = 1; + auto const grid_size = (cg_size() * num_pairs + stride * block_size - 1) / (stride * block_size); + auto view = get_device_view(); + + detail::contains + <<>>(first, last, output_begin, view, pair_equal); + CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); +} + template ::device_view::contains( cooperative_groups::thread_block_tile const& g, ProbeKey const& k, - KeyEqual key_equal) noexcept + KeyEqual key_equal) const noexcept +{ + constexpr bool is_pair_contains = false; + return impl_.contains(g, k, key_equal); +} + +template +template +__device__ __forceinline__ bool +static_multimap::device_view::pair_contains( + cooperative_groups::thread_block_tile const& g, + ProbePair const& p, + PairEqual pair_equal) const noexcept { - return impl_.contains(g, k, key_equal); + constexpr bool is_pair_contains = true; + return impl_.contains(g, p, pair_equal); } template std::iterator_traits::value_type::first_type + * and Key type. std::invoke_result::value_type::first_type, Key> + * must be well-formed. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam OutputIt Device accessible output iterator assignable from `bool` + * @tparam PairEqual Binary callable type used to compare input pair and slot content for equality + * + * @param first Beginning of the sequence of pairs + * @param last End of the sequence of pairs + * @param output_begin Beginning of the output sequence indicating whether each pair is present + * @param pair_equal The binary function to compare input pair and slot content for equality + * @param stream CUDA stream used for contains + */ + template + void pair_contains(InputIt first, + InputIt last, + OutputIt output_begin, + PairEqual pair_equal, + cudaStream_t stream = 0) const; + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap. * @@ -859,7 +887,37 @@ class static_multimap { __device__ __forceinline__ bool contains( cooperative_groups::thread_block_tile const& g, ProbeKey const& k, - KeyEqual key_equal = KeyEqual{}) noexcept; + KeyEqual key_equal = KeyEqual{}) const noexcept; + + /** + * @brief Indicates whether the pair `p` exists in the map. + * + * If the pair `p` was inserted into the map, `contains` returns + * true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to + * to leverage multiple threads to perform a single `contains` operation. This provides a + * significant boost in throughput compared to the non Cooperative Group + * `contains` at moderate to high load factors. + * + * ProbeSequence hashers should be callable with both ProbePair::first_type and Key type. + * `std::invoke_result` must be well-formed. + * + * If `pair_equal(p, slot_content)` returns true, `hash(p.first) == hash(slot_key)` must + * also be true. + * + * @tparam ProbePair Probe pair type + * @tparam PairEqual Binary callable type + * + * @param g The Cooperative Group used to perform the contains operation + * @param p The pair to search for + * @param pair_equal The binary callable used to compare input pair and slot content + * for equality + * @return A boolean indicating whether the input pair was inserted in the map + */ + template + __device__ __forceinline__ bool pair_contains( + cooperative_groups::thread_block_tile const& g, + ProbePair const& p, + PairEqual pair_equal) const noexcept; /** * @brief Counts the occurrence of a given key contained in multimap. diff --git a/tests/static_multimap/custom_type_test.cu b/tests/static_multimap/custom_type_test.cu index a12e4ec9a..40bdbe8ba 100644 --- a/tests/static_multimap/custom_type_test.cu +++ b/tests/static_multimap/custom_type_test.cu @@ -39,7 +39,7 @@ struct key_pair { }; struct hash_key_pair { - __device__ uint32_t operator()(key_pair k) { return k.a; }; + __device__ uint32_t operator()(key_pair k) const { return k.a; }; }; struct key_pair_equals { @@ -197,6 +197,7 @@ __inline__ void test_custom_key_value_type(Map& map, std::size_t num_pairs) thrust::device_vector contained(num_pairs); map.contains(key_begin, key_begin + num_pairs, contained.begin(), key_pair_equals{}, stream); + REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } diff --git a/tests/static_multimap/heterogeneous_lookup_test.cu b/tests/static_multimap/heterogeneous_lookup_test.cu index 6283af7ce..ebcdda5b6 100644 --- a/tests/static_multimap/heterogeneous_lookup_test.cu +++ b/tests/static_multimap/heterogeneous_lookup_test.cu @@ -64,7 +64,7 @@ struct key_triplet { // User-defined device hasher struct custom_hasher { template - __device__ uint32_t operator()(CustomKey const& k) + __device__ uint32_t operator()(CustomKey const& k) const { return thrust::raw_reference_cast(k).a; }; diff --git a/tests/static_multimap/pair_function_test.cu b/tests/static_multimap/pair_function_test.cu index d96f03b4e..c5442533b 100644 --- a/tests/static_multimap/pair_function_test.cu +++ b/tests/static_multimap/pair_function_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -56,6 +57,21 @@ __inline__ void test_pair_functions(Map& map, PairIt pair_begin, std::size_t num return cuco::pair_type{i, i}; }); + SECTION("pair_contains returns true for all inserted pairs and false for non-inserted ones.") + { + thrust::device_vector result(num_pairs); + auto res_begin = result.begin(); + map.pair_contains(pair_begin, pair_begin + num_pairs, res_begin, pair_equal{}); + + auto true_iter = thrust::make_constant_iterator(true); + auto false_iter = thrust::make_constant_iterator(false); + + REQUIRE( + cuco::test::equal(res_begin, res_begin + num_pairs / 2, true_iter, thrust::equal_to{})); + REQUIRE(cuco::test::equal( + res_begin + num_pairs / 2, res_begin + num_pairs, false_iter, thrust::equal_to{})); + } + SECTION("Output of pair_count and pair_retrieve should be coherent.") { auto num = map.pair_count(pair_begin, pair_begin + num_pairs, pair_equal{});