diff --git a/include/cuco/detail/static_multimap/kernels.cuh b/include/cuco/detail/static_multimap/kernels.cuh
index f3820bf64..7aadfb8b9 100644
--- a/include/cuco/detail/static_multimap/kernels.cuh
+++ b/include/cuco/detail/static_multimap/kernels.cuh
@@ -205,6 +205,75 @@ __global__ void contains(
}
}
+/**
+ * @brief Indicates whether the pairs in the range `[first, first + n)` are contained in the map if
+ * `pred` of the corresponding stencil returns true.
+ *
+ * If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)` indicating
+ * if the pair `*(first + i)` exists in the map. If `pred( *(stencil + i) )` is false, stores false
+ * to `(output_begin + i)`.
+ *
+ * Uses the CUDA Cooperative Groups API to leverage groups of multiple threads to perform the
+ * contains operation for each element. This provides a significant boost in throughput compared
+ * to the non Cooperative Group `contains` at moderate to high load factors.
+ *
+ * @tparam block_size The size of the thread block
+ * @tparam tile_size The number of threads in the Cooperative Groups
+ * @tparam InputIt Device accessible input iterator
+ * @tparam StencilIt Device accessible random access iterator whose value_type is
+ * convertible to Predicate's argument type
+ * @tparam OutputIt Device accessible output iterator assignable from `bool`
+ * @tparam viewT Type of device view allowing access of hash map storage
+ * @tparam PairEqual Binary callable type
+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
+ * argument type is convertible from std::iterator_traits<StencilIt>::value_type.
+ *
+ * @param first Beginning of the sequence of elements
+ * @param stencil Beginning of the stencil sequence
+ * @param output_begin Beginning of the sequence of booleans for the presence of each element
+ * @param n Number of elements to query
+ * @param view Device view used to access the hash map's slot storage
+ * @param pair_equal The binary function to compare input element and slot content for equality
+ * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
+ */
+template
+__global__ void pair_contains_if_n(InputIt first,
+                                   StencilIt stencil,
+                                   OutputIt output_begin,
+                                   std::size_t n,
+                                   viewT view,
+                                   PairEqual pair_equal,
+                                   Predicate pred)
+{
+  auto tile = cg::tiled_partition(cg::this_thread_block());
+  auto tid  = block_size * blockIdx.x + threadIdx.x;
+  // One input element per tile: `idx` is the element handled by this thread's tile.
+  auto idx = tid / tile_size;
+  // Position of this thread's tile inside the block; also the tile's slot in `writeBuffer`.
+  auto const tile_in_block = threadIdx.x / tile_size;
+  // Number of elements the whole grid advances per loop iteration.
+  auto const loop_stride = (gridDim.x * block_size) / tile_size;
+  __shared__ bool writeBuffer[block_size];
+
+  // `idx - tile_in_block` is the element index of the first tile of this block, so the loop
+  // condition evaluates identically for every thread of the block. This guarantees that all
+  // threads of the block reach the `__syncthreads()` below the same number of times, even when
+  // only some tiles of the (last) block still have an element to process.
+  // NOTE(review): relies on block_size being a multiple of tile_size - confirm at the call site.
+  while (idx - tile_in_block < n) {
+    bool found = false;
+    // `idx` is uniform within a tile, so the whole tile takes the same branch and the
+    // cooperative lookup is entered by all of its threads or by none.
+    if (idx < n) {
+      typename std::iterator_traits::value_type pair = *(first + idx);
+      found = pred(*(stencil + idx)) ? view.pair_contains(tile, pair, pair_equal) : false;
+    }
+
+    /*
+     * The ld.relaxed.gpu instruction used in the view lookup causes L1 to
+     * flush more frequently, causing increased sector stores from L2 to global memory.
+     * By writing results to shared memory and then synchronizing before writing back
+     * to global, we no longer rely on L1, preventing the increase in sector stores from
+     * L2 to global and improving performance.
+     */
+    if (tile.thread_rank() == 0) { writeBuffer[tile_in_block] = found; }
+    __syncthreads();
+    // Tiles past the end of the input only take part in the barrier; they must not store.
+    if (idx < n && tile.thread_rank() == 0) { *(output_begin + idx) = writeBuffer[tile_in_block]; }
+    idx += loop_stride;
+  }
+}
+
/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the multimap.
*
diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl
index ddec2e4a2..b22bce577 100644
--- a/include/cuco/detail/static_multimap/static_multimap.inl
+++ b/include/cuco/detail/static_multimap/static_multimap.inl
@@ -121,7 +121,6 @@ void static_multimap::contains(
detail::contains
<<>>(first, last, output_begin, view, key_equal);
- CUCO_CUDA_TRY(cudaStreamSynchronize(stream));
}
template ::pair_contains
detail::contains
<<>>(first, last, output_begin, view, pair_equal);
- CUCO_CUDA_TRY(cudaStreamSynchronize(stream));
+}
+
+template
+template
+void static_multimap::pair_contains_if(
+ InputIt first,
+ InputIt last,
+ StencilIt stencil,
+ OutputIt output_begin,
+ PairEqual pair_equal,
+ Predicate pred,
+ cudaStream_t stream) const
+{
+ auto const num_pairs = std::distance(first, last);  // number of input pairs to query
+ if (num_pairs == 0) { return; }  // empty input: return without launching a kernel
+
+ auto constexpr block_size = 128;  // threads per block used to size the grid
+ auto constexpr stride = 1;  // scales the per-block element count in the ceil-div below
+ auto const grid_size = (cg_size() * num_pairs + stride * block_size - 1) / (stride * block_size);  // ceil-div
+ auto view = get_device_view();  // device view handed to the kernel
+
+ detail::pair_contains_if_n<<>>(  // NOTE(review): no stream synchronization follows this launch here
+ first, stencil, output_begin, num_pairs, view, pair_equal, pred);
}
template std::iterator_traits::value_type::first_type
+ * and Key type. std::invoke_result::value_type::first_type, Key>
+ * must be well-formed.
+ *
+ * @tparam InputIt Device accessible random access input iterator
+ * @tparam StencilIt Device accessible random access iterator whose value_type is
+ * convertible to Predicate's argument type
+ * @tparam OutputIt Device accessible output iterator assignable from `bool`
+ * @tparam PairEqual Binary callable type used to compare input pair and slot content for equality
+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
+ * argument type is convertible from std::iterator_traits<StencilIt>::value_type.
+ *
+ * @param first Beginning of the sequence of pairs
+ * @param last End of the sequence of pairs
+ * @param stencil Beginning of the stencil sequence
+ * @param output_begin Beginning of the output sequence indicating whether each pair is present
+ * @param pair_equal The binary function to compare input pair and slot content for equality
+ * @param pred Predicate to test on every element in the range `[stencil, stencil +
+ * std::distance(first, last))`
+ * @param stream CUDA stream used for contains
+ */
+ template
+ void pair_contains_if(InputIt first,
+ InputIt last,
+ StencilIt stencil,
+ OutputIt output_begin,
+ PairEqual pair_equal,
+ Predicate pred,
+ cudaStream_t stream = 0) const;
+
/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the multimap.
*
diff --git a/tests/static_multimap/pair_function_test.cu b/tests/static_multimap/pair_function_test.cu
index c5442533b..81d90bf68 100644
--- a/tests/static_multimap/pair_function_test.cu
+++ b/tests/static_multimap/pair_function_test.cu
@@ -72,6 +72,25 @@ __inline__ void test_pair_functions(Map& map, PairIt pair_begin, std::size_t num
res_begin + num_pairs / 2, res_begin + num_pairs, false_iter, thrust::equal_to{}));
}
+ SECTION("pair_contains_if checks the input pair only if the corresponding stencil returns true.")
+ {
+ thrust::device_vector result(num_pairs);
+ auto res_begin = result.begin();
+ auto pred = [num_pairs] __device__(int32_t s) {
+ if (s < num_pairs / 2) { return false; }
+ return true;
+ };
+ auto count_iter = thrust::make_counting_iterator(0);
+ map.pair_contains_if(
+ pair_begin, pair_begin + num_pairs, count_iter, res_begin, pair_equal{}, pred);
+
+ auto false_iter = thrust::make_constant_iterator(false);
+
+ // All false since the stencil of the first half is false and the second half is not inserted
+ REQUIRE(
+ cuco::test::equal(res_begin, res_begin + num_pairs, false_iter, thrust::equal_to{}));
+ }
+
SECTION("Output of pair_count and pair_retrieve should be coherent.")
{
auto num = map.pair_count(pair_begin, pair_begin + num_pairs, pair_equal{});