diff --git a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl index 047ec7987..da1dc5dbd 100644 --- a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl +++ b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl @@ -95,7 +95,7 @@ __host__ __device__ constexpr linear_probing::linear_probing(Hash template template -__host__ __device__ constexpr auto linear_probing::with_hash_function( +__host__ __device__ constexpr auto linear_probing::rebind_hash_function( NewHash const& hash) const noexcept { return linear_probing{hash}; @@ -148,17 +148,9 @@ __host__ __device__ constexpr double_hashing::double_hashi { } -template -template -__host__ __device__ constexpr auto double_hashing::with_hash_function( - NewHash1 const& hash1, NewHash2 const& hash2) const noexcept -{ - return double_hashing{hash1, hash2}; -} - template template -__host__ __device__ constexpr auto double_hashing::with_hash_function( +__host__ __device__ constexpr auto double_hashing::rebind_hash_function( NewHash const& hash) const { static_assert(cuco::is_tuple_like::value, diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 3756a641b..3dbdebe57 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -311,22 +311,44 @@ template template -__host__ __device__ auto constexpr static_map_ref::with_operators(NewOperators...) - const noexcept +__host__ __device__ constexpr auto +static_map_ref::with_operators( + NewOperators...) const noexcept { return static_map_ref{ cuco::empty_key{this->empty_key_sentinel()}, cuco::empty_value{this->empty_value_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; +} + +template +template +__host__ __device__ constexpr auto +static_map_ref:: + rebind_hash_function(NewHash const& hash) const +{ + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); + return static_map_ref, + StorageRef, + Operators...>{cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + probing_scheme, + {}, + this->storage_ref()}; } template cuco::empty_value{this->empty_value_sentinel()}, cuco::erased_key{this->erased_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), scope, storage_ref_type{this->window_extent(), memory_to_use}}; } diff --git a/include/cuco/detail/static_multimap/static_multimap_ref.inl b/include/cuco/detail/static_multimap/static_multimap_ref.inl index 3bbf90e3a..f2adbfaa2 100644 --- a/include/cuco/detail/static_multimap/static_multimap_ref.inl +++ b/include/cuco/detail/static_multimap/static_multimap_ref.inl @@ -327,6 +327,33 @@ __host__ __device__ auto constexpr static_multimap_ref< impl_.storage_ref()}; } +template +template +__host__ __device__ constexpr auto +static_multimap_ref:: + rebind_hash_function(NewHash const& hash) const +{ + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); + return static_multimap_ref, + StorageRef, + Operators...>{cuco::empty_key{this->empty_key_sentinel()}, + cuco::empty_value{this->empty_value_sentinel()}, + this->key_eq(), + probing_scheme, + {}, + this->storage_ref()}; +} + template { return impl_->count(first, last, - ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash), + ref(op::count).with_key_eq(probe_key_equal).rebind_hash_function(probe_hash), stream); } @@ -333,7 +333,7 @@ static_multiset return impl_->count_outer( first, last, - ref(op::count).with_key_eq(probe_key_equal).with_hash_function(probe_hash), + ref(op::count).with_key_eq(probe_key_equal).rebind_hash_function(probe_hash), stream); } diff --git a/include/cuco/detail/static_multiset/static_multiset_ref.inl b/include/cuco/detail/static_multiset/static_multiset_ref.inl index aa25cdf70..db9f6c3ce 100644 --- a/include/cuco/detail/static_multiset/static_multiset_ref.inl +++ b/include/cuco/detail/static_multiset/static_multiset_ref.inl @@ -272,9 +272,9 @@ static_multiset_ref{ cuco::empty_key{this->empty_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template { cuco::empty_key{this->empty_key_sentinel()}, key_equal, - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template __host__ __device__ constexpr auto static_multiset_ref:: - with_hash_function(NewHash const& hash) const + rebind_hash_function(NewHash const& hash) const { - auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash); + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); return static_multiset_ref, StorageRef, Operators...>{cuco::empty_key{this->empty_key_sentinel()}, - this->impl_.key_eq(), + this->key_eq(), probing_scheme, {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } namespace detail { diff --git a/include/cuco/detail/static_set/static_set_ref.inl b/include/cuco/detail/static_set/static_set_ref.inl index 7e2882a0a..e872f7522 100644 --- a/include/cuco/detail/static_set/static_set_ref.inl +++ b/include/cuco/detail/static_set/static_set_ref.inl @@ -269,9 +269,9 @@ static_set_ref::w return static_set_ref{ cuco::empty_key{this->empty_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template ::w return static_set_ref{ cuco::empty_key{this->empty_key_sentinel()}, key_equal, - this->impl_.probing_scheme(), + this->probing_scheme(), {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template template __host__ __device__ constexpr auto -static_set_ref::with_hash_function( +static_set_ref::rebind_hash_function( NewHash const& hash) const { - auto const probing_scheme = this->impl_.probing_scheme().with_hash_function(hash); + auto const probing_scheme = this->probing_scheme().rebind_hash_function(hash); return static_set_ref, StorageRef, Operators...>{cuco::empty_key{this->empty_key_sentinel()}, - this->impl_.key_eq(), + this->key_eq(), probing_scheme, {}, - this->impl_.storage_ref()}; + this->storage_ref()}; } template ::m cuco::empty_key{this->empty_key_sentinel()}, cuco::erased_key{this->erased_key_sentinel()}, this->key_eq(), - this->impl_.probing_scheme(), + this->probing_scheme(), scope, storage_ref_type{this->window_extent(), memory_to_use}}; } diff --git a/include/cuco/probing_scheme.cuh b/include/cuco/probing_scheme.cuh index 4885ad63d..4b77de7f5 100644 --- a/include/cuco/probing_scheme.cuh +++ b/include/cuco/probing_scheme.cuh @@ -62,7 +62,7 @@ class linear_probing : private detail::probing_scheme_base { * @return Copy of the current probing method */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function( + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function( NewHash const& hash) const noexcept; /** @@ -145,22 +145,6 @@ class double_hashing : private detail::probing_scheme_base { */ __host__ __device__ constexpr double_hashing(cuco::pair const& hash); - /** - *@brief Makes a copy of the current probing method with the given hasher - * - * @tparam NewHash1 First new hasher type - * @tparam NewHash2 Second new hasher type - * - * @param hash1 First hasher - * @param hash2 second hasher - * - * @return Copy of the current probing method - */ - template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash1 const& hash1, - NewHash2 const& hash2 = { - 1}) const noexcept; - /** *@brief Makes a copy of the current probing method with the given hasher * @@ -174,7 +158,7 @@ class double_hashing : private detail::probing_scheme_base { */ template ::value>> - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; /** * @brief Operator to return a probing iterator diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 1da1e501a..52ef89d73 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -261,6 +261,18 @@ class static_map_ref [[nodiscard]] __host__ __device__ constexpr auto with_operators( NewOperators... ops) const noexcept; + /** + * @brief Makes a copy of the current device reference with the given hasher + * + * @tparam NewHash The new hasher type + * + * @param hash New hasher + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; + /** * @brief Makes a copy of the current device reference using non-owned memory * diff --git a/include/cuco/static_multimap_ref.cuh b/include/cuco/static_multimap_ref.cuh index b23925b86..56721319e 100644 --- a/include/cuco/static_multimap_ref.cuh +++ b/include/cuco/static_multimap_ref.cuh @@ -260,6 +260,18 @@ class static_multimap_ref [[nodiscard]] __host__ __device__ constexpr auto with_operators( NewOperators... ops) const noexcept; + /** + * @brief Makes a copy of the current device reference with the given hasher + * + * @tparam NewHash The new hasher type + * + * @param hash New hasher + * + * @return Copy of the current device ref + */ + template + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; + /** * @brief Makes a copy of the current device reference using non-owned memory * diff --git a/include/cuco/static_multiset_ref.cuh b/include/cuco/static_multiset_ref.cuh index bf0588f2f..52d4c0fa6 100644 --- a/include/cuco/static_multiset_ref.cuh +++ b/include/cuco/static_multiset_ref.cuh @@ -254,7 +254,7 @@ class static_multiset_ref NewKeyEqual const& key_equal) const noexcept; /** - * @brief Makes a copy of the current device reference with given hasher + * @brief Makes a copy of the current device reference with the given hasher * * @tparam NewHash The new hasher type * @@ -263,7 +263,7 @@ class static_multiset_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; private: impl_type impl_; diff --git a/include/cuco/static_set_ref.cuh b/include/cuco/static_set_ref.cuh index 1271cb756..197d50d85 100644 --- a/include/cuco/static_set_ref.cuh +++ b/include/cuco/static_set_ref.cuh @@ -252,7 +252,7 @@ class static_set_ref NewKeyEqual const& key_equal) const noexcept; /** - * @brief Makes a copy of the current device reference with given hasher + * @brief Makes a copy of the current device reference with the given hasher * * @tparam NewHash The new hasher type * @@ -261,7 +261,7 @@ class static_set_ref * @return Copy of the current device ref */ template - [[nodiscard]] __host__ __device__ constexpr auto with_hash_function(NewHash const& hash) const; + [[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const; /** * @brief Makes a copy of the current device reference using non-owned memory diff --git a/tests/static_multiset/custom_count_test.cu b/tests/static_multiset/custom_count_test.cu index f92b91aad..69ce41da6 100644 --- a/tests/static_multiset/custom_count_test.cu +++ b/tests/static_multiset/custom_count_test.cu @@ -61,21 +61,27 @@ void test_custom_count(Set& set, size_type num_keys) { using Key = typename Set::key_type; + auto const hash = []() { + if constexpr (cuco::is_double_hashing::value) { + return cuco::pair{custom_hash{}, custom_hash{}}; + } else { + return custom_hash{}; + } + }(); + auto query_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), cuda::proclaim_return_type([] __device__(auto i) { return static_cast(i * XXX); })); SECTION("Count of empty set should be zero.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == 0); } SECTION("Outer count of empty set should be the same as input size.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } @@ -84,15 +90,13 @@ void test_custom_count(Set& set, size_type num_keys) SECTION("Count of n unique keys should be n.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } SECTION("Outer count of n unique keys should be n.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } @@ -102,15 +106,13 @@ void test_custom_count(Set& set, size_type num_keys) SECTION("Count of a key whose multiplicity equals n should be n.") { - auto const count = - set.count(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == num_keys); } SECTION("Outer count of a key whose multiplicity equals n should be n + input_size - 1.") { - auto const count = - set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, custom_hash{}); + auto const count = set.count_outer(query_begin, query_begin + num_keys, custom_key_eq{}, hash); REQUIRE(count == 2 * num_keys - 1); } }