Skip to content

Commit

Permalink
Add multimap conditional contains
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 6, 2024
1 parent bdb7fba commit b82a666
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 2 deletions.
41 changes: 41 additions & 0 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,47 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
impl_->contains_async(first, last, output_begin, ref(op::contains), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
contains_if(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const
{
this->contains_if_async(first, last, stencil, pred, output_begin, stream);
stream.wait();
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
contains_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream) const noexcept
{
impl_->contains_if_async(first, last, stencil, pred, output_begin, ref(op::contains), stream);
}

template <class Key,
class T,
class Extent,
Expand Down
64 changes: 64 additions & 0 deletions include/cuco/static_multimap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,70 @@ class static_multimap {
OutputIt output_begin,
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief Indicates whether the keys in the range `[first, last)` are contained in the map if
* `pred` of the corresponding stencil returns true.
*
* @note If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)`
* indicating if the key `*(first + i)` is present in the map. If `pred( *(stencil + i) )` is
* false, stores `false` to `(output_begin + i)`.
* @note This function synchronizes the given stream. For asynchronous execution use
* `contains_if_async`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam OutputIt Device accessible output iterator assignable from `bool`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param output_begin Beginning of the sequence of booleans for the presence of each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void contains_if(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Asynchronously indicates whether the keys in the range `[first, last)` are contained in
* the map if `pred` of the corresponding stencil returns true.
*
* @note If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)`
* indicating if the key `*(first + i)` is present in the map. If `pred( *(stencil + i) )` is
* false, stores `false` to `(output_begin + i)`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam OutputIt Device accessible output iterator assignable from `bool`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param output_begin Beginning of the sequence of booleans for the presence of each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename StencilIt, typename Predicate, typename OutputIt>
void contains_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
OutputIt output_begin,
cuda::stream_ref stream = {}) const noexcept;

/**
* @brief Gets the maximum number of elements the hash map can hold.
*
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ ConfigureTest(STATIC_MULTIMAP_TEST
static_multimap/custom_pair_retrieve_test.cu
static_multimap/custom_type_test.cu
static_multimap/heterogeneous_lookup_test.cu
static_multimap/insert_test.cu
static_multimap/insert_contains_test.cu
static_multimap/insert_if_test.cu
static_multimap/multiplicity_test.cu
static_multimap/non_match_test.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,28 @@ void test_insert(Map& map, std::size_t num_keys)
map.contains(keys_begin, keys_begin + num_keys, d_contained.begin());
REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{}));
}

SECTION("Conditional contains should return true on even inputs.")
{
auto is_even =
cuda::proclaim_return_type<bool>([] __device__(auto const& i) { return i % 2 == 0; });
auto zip_equal = cuda::proclaim_return_type<bool>(
[] __device__(auto const& p) { return thrust::get<0>(p) == thrust::get<1>(p); });

map.contains_if(keys_begin,
keys_begin + num_keys,
thrust::counting_iterator<std::size_t>(0),
is_even,
d_contained.begin());
auto gold_iter =
thrust::make_transform_iterator(thrust::counting_iterator<std::size_t>(0), is_even);
auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_contained.begin(), gold_iter));
REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}
}

TEMPLATE_TEST_CASE_SIG(
"static_multimap insert test",
"static_multimap insert/contains test",
"",
((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
Key,
Expand Down

0 comments on commit b82a666

Please sign in to comment.