diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index b427feff4..d7b128d88 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -207,6 +207,47 @@ void static_multimapcontains_async(first, last, output_begin, ref(op::contains), stream); } +template +template +void static_multimap:: + contains_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const +{ + this->contains_if_async(first, last, stencil, pred, output_begin, stream); + stream.wait(); +} + +template +template +void static_multimap:: + contains_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream) const noexcept +{ + impl_->contains_if_async(first, last, stencil, pred, output_begin, ref(op::contains), stream); +} + template std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator assignable from `bool` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of booleans for the presence of each key + * @param stream Stream used for executing the kernels + */ + template + void contains_if(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const; + + /** + * @brief Asynchronously indicates whether the keys in the range `[first, last)` are contained in + * the map if `pred` of the corresponding stencil returns true. + * + * @note If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)` + * indicating if the key `*(first + i)` is present in the map. If `pred( *(stencil + i) )` is + * false, stores `false` to `(output_begin + i)`. + * + * @tparam InputIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and + * argument type is convertible from std::iterator_traits::value_type + * @tparam OutputIt Device accessible output iterator assignable from `bool` + * + * @param first Beginning of the sequence of keys + * @param last End of the sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + + * std::distance(first, last))` + * @param output_begin Beginning of the sequence of booleans for the presence of each key + * @param stream Stream used for executing the kernels + */ + template + void contains_if_async(InputIt first, + InputIt last, + StencilIt stencil, + Predicate pred, + OutputIt output_begin, + cuda::stream_ref stream = {}) const noexcept; + /** * @brief Gets the maximum number of elements the hash map can hold. * diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f2f882f01..dde1317b0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -111,7 +111,7 @@ ConfigureTest(STATIC_MULTIMAP_TEST static_multimap/custom_pair_retrieve_test.cu static_multimap/custom_type_test.cu static_multimap/heterogeneous_lookup_test.cu - static_multimap/insert_test.cu + static_multimap/insert_contains_test.cu static_multimap/insert_if_test.cu static_multimap/multiplicity_test.cu static_multimap/non_match_test.cu diff --git a/tests/static_multimap/insert_test.cu b/tests/static_multimap/insert_contains_test.cu similarity index 83% rename from tests/static_multimap/insert_test.cu rename to tests/static_multimap/insert_contains_test.cu index 6b8535a9d..a3fa36648 100644 --- a/tests/static_multimap/insert_test.cu +++ b/tests/static_multimap/insert_contains_test.cu @@ -57,10 +57,28 @@ void test_insert(Map& map, std::size_t num_keys) map.contains(keys_begin, keys_begin + num_keys, d_contained.begin()); REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } + + SECTION("Conditional contains should return true on even inputs.") + { + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); + auto zip_equal = cuda::proclaim_return_type( + [] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); }); + + map.contains_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator(0), + is_even, + d_contained.begin()); + auto gold_iter = + thrust::make_transform_iterator(thrust::counting_iterator(0), is_even); + auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_contained.begin(), gold_iter)); + REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); + } } TEMPLATE_TEST_CASE_SIG( - "static_multimap insert test", + "static_multimap insert/contains test", "", ((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize), Key,