diff --git a/include/cuco/detail/static_multimap/kernels.cuh b/include/cuco/detail/static_multimap/kernels.cuh index f3820bf64..7aadfb8b9 100644 --- a/include/cuco/detail/static_multimap/kernels.cuh +++ b/include/cuco/detail/static_multimap/kernels.cuh @@ -205,6 +205,89 @@ __global__ void contains( } } +/** + * @brief Indicates whether the pairs in the range `[first, first + n)` are contained in the map if + * `pred` of the corresponding stencil returns true. + * + * If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)` indicating + * if the pair `*(first + i)` exists in the map. If `pred( *(stencil + i) )` is false, stores false + * to `(output_begin + i)`. + * + * Uses the CUDA Cooperative Groups API to leverage groups of multiple threads to perform the + * contains operation for each element. This provides a significant boost in throughput compared + * to the non Cooperative Group `contains` at moderate to high load factors. + * + * @tparam block_size The size of the thread block + * @tparam tile_size The number of threads in the Cooperative Groups + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam OutputIt Device accessible output iterator assignable from `bool` + * @tparam viewT Type of device view allowing access of hash map storage + * @tparam PairEqual Binary callable type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits<StencilIt>::value_type. 
+ * + * @param first Beginning of the sequence of elements + * @param stencil Beginning of the stencil sequence + * @param output_begin Beginning of the sequence of booleans for the presence of each element + * @param n Number of elements in the input sequence + * @param view Device view used to access the hash map's slot storage + * @param pair_equal The binary function to compare input element and slot content for equality + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` + */ +template +__global__ void pair_contains_if_n(InputIt first, + StencilIt stencil, + OutputIt output_begin, + std::size_t n, + viewT view, + PairEqual pair_equal, + Predicate pred) +{ + auto tile = cg::tiled_partition(cg::this_thread_block()); + __shared__ bool writeBuffer[block_size]; + + // Index arithmetic is done in std::size_t so it cannot wrap before being compared with `n`. + std::size_t const block_first_thread = std::size_t{blockIdx.x} * block_size; + std::size_t const stride = (std::size_t{gridDim.x} * block_size) / tile_size; + // Lowest index handled by this block; every thread of the block computes the same value. + std::size_t block_idx = block_first_thread / tile_size; + std::size_t idx = (block_first_thread + threadIdx.x) / tile_size; + + // Loop on the block-uniform index: all threads of a block run the same number of iterations + // and therefore all reach `__syncthreads()`. Tiles whose `idx` is past the end skip the work. + while (block_idx < n) { + bool const active = idx < n; + bool found = false; + if (active) { + typename std::iterator_traits::value_type pair = *(first + idx); + found = pred(*(stencil + idx)) ? view.pair_contains(tile, pair, pair_equal) : false; + } + bool const is_writer = active && tile.thread_rank() == 0; + + /* + * The ld.relaxed.gpu instruction used in view.find causes L1 to + * flush more frequently, causing increased sector stores from L2 to global memory. + * By writing results to shared memory and then synchronizing before writing back + * to global, we no longer rely on L1, preventing the increase in sector stores from + * L2 to global and improving performance. + */ + if (is_writer) { writeBuffer[threadIdx.x / tile_size] = found; } + __syncthreads(); + if (is_writer) { *(output_begin + idx) = writeBuffer[threadIdx.x / tile_size]; } + idx += stride; + block_idx += stride; + } +} + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap. 
* diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index ddec2e4a2..b22bce577 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -121,7 +121,6 @@ void static_multimap::contains( detail::contains <<>>(first, last, output_begin, view, key_equal); - CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); } template ::pair_contains detail::contains <<>>(first, last, output_begin, view, pair_equal); - CUCO_CUDA_TRY(cudaStreamSynchronize(stream)); +} + +template +template +void static_multimap::pair_contains_if( + InputIt first, + InputIt last, + StencilIt stencil, + OutputIt output_begin, + PairEqual pair_equal, + Predicate pred, + cudaStream_t stream) const +{ + auto const num_pairs = std::distance(first, last); + if (num_pairs == 0) { return; } + + auto constexpr block_size = 128; + auto constexpr stride = 1; + auto const grid_size = (cg_size() * num_pairs + stride * block_size - 1) / (stride * block_size); + auto view = get_device_view(); + + detail::pair_contains_if_n<<>>( + first, stencil, output_begin, num_pairs, view, pair_equal, pred); } template std::iterator_traits::value_type::first_type + * and Key type. std::invoke_result::value_type::first_type, Key> + * must be well-formed. + * + * @tparam InputIt Device accessible random access input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam OutputIt Device accessible output iterator assignable from `bool` + * @tparam PairEqual Binary callable type used to compare input pair and slot content for equality + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits<StencilIt>::value_type. 
+ * + * @param first Beginning of the sequence of pairs + * @param last End of the sequence of pairs + * @param stencil Beginning of the stencil sequence + * @param output_begin Beginning of the output sequence indicating whether each pair is present + * @param pair_equal The binary function to compare input pair and slot content for equality + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param stream CUDA stream used for contains (the call returns without synchronizing it) + */ + template + void pair_contains_if(InputIt first, + InputIt last, + StencilIt stencil, + OutputIt output_begin, + PairEqual pair_equal, + Predicate pred, + cudaStream_t stream = 0) const; + /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap. * diff --git a/tests/static_multimap/pair_function_test.cu b/tests/static_multimap/pair_function_test.cu index c5442533b..81d90bf68 100644 --- a/tests/static_multimap/pair_function_test.cu +++ b/tests/static_multimap/pair_function_test.cu @@ -72,6 +72,25 @@ __inline__ void test_pair_functions(Map& map, PairIt pair_begin, std::size_t num res_begin + num_pairs / 2, res_begin + num_pairs, false_iter, thrust::equal_to{})); } + SECTION("pair_contains_if checks the input pair only if the corresponding stencil returns true.") + { + thrust::device_vector result(num_pairs); + auto res_begin = result.begin(); + auto pred = [num_pairs] __device__(int32_t s) { + if (s < num_pairs / 2) { return false; } + return true; + }; + auto count_iter = thrust::make_counting_iterator(0); + map.pair_contains_if( + pair_begin, pair_begin + num_pairs, count_iter, res_begin, pair_equal{}, pred); + + auto false_iter = thrust::make_constant_iterator(false); + + // All false: `pred` rejects the first half of the stencil and the second half is not inserted + REQUIRE( + cuco::test::equal(res_begin, res_begin + num_pairs, false_iter, thrust::equal_to{})); + } + SECTION("Output of pair_count and pair_retrieve should be 
coherent.") { auto num = map.pair_count(pair_begin, pair_begin + num_pairs, pair_equal{});