Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions include/cuco/detail/probe_sequence_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include <cooperative_groups.h>

#include <utility>

namespace cuco {
namespace detail {

Expand Down Expand Up @@ -208,6 +210,24 @@ class linear_probing_impl
template <typename ProbeKey>
__device__ __forceinline__ iterator
initial_slot(cooperative_groups::thread_block_tile<cg_size> const& g, ProbeKey const& k) noexcept
{
  // Non-const overload: delegate to the const twin to avoid duplicating the
  // probing logic, then strip constness from the returned slot pointer.
  auto const& const_self = std::as_const(*this);
  return const_cast<iterator>(const_self.initial_slot(g, k));
}

/**
* @brief Returns the initial slot for a given key `k`.
*
* If vector-load is enabled, the return slot is always even to avoid illegal memory access.
*
* @tparam ProbeKey Probe key type
*
* @param g the Cooperative Group for which the initial slot is needed
* @param k The key to get the slot for
* @return Pointer to the initial slot for `k`
*/
template <typename ProbeKey>
__device__ __forceinline__ const_iterator initial_slot(
cooperative_groups::thread_block_tile<cg_size> const& g, ProbeKey const& k) const noexcept
{
auto const hash_value = [&]() {
auto const tmp = hash_(k);
Expand Down Expand Up @@ -237,6 +257,19 @@ class linear_probing_impl
* @return The next slot after `s`
*/
__device__ __forceinline__ iterator next_slot(iterator s) noexcept
{
  // Non-const overload implemented in terms of the const version
  // ("const twin" delegation); cast away constness of the result only.
  auto const& const_self = std::as_const(*this);
  return const_cast<iterator>(const_self.next_slot(s));
}

/**
* @brief Given a slot `s`, returns the next slot.
*
* If `s` is the last slot, wraps back around to the first slot.
*
* @param s The slot to advance
* @return The next slot after `s`
*/
__device__ __forceinline__ const_iterator next_slot(const_iterator s) const noexcept
{
std::size_t index = s - slots_;
std::size_t offset;
Expand Down Expand Up @@ -331,6 +364,25 @@ class double_hashing_impl
template <typename ProbeKey>
__device__ __forceinline__ iterator
initial_slot(cooperative_groups::thread_block_tile<cg_size> const& g, ProbeKey const& k) noexcept
{
  // Forward to the const overload so the double-hashing slot computation
  // lives in one place; the const_cast only removes constness of the pointer.
  auto const& const_self = std::as_const(*this);
  return const_cast<iterator>(const_self.initial_slot(g, k));
}

/**
* @brief Returns the initial slot for a given key `k`.
*
* If vector-load is enabled, the return slot is always a multiple of (`cg_size` * `vector_width`)
* to avoid illegal memory access.
*
* @tparam ProbeKey Probe key type
*
* @param g the Cooperative Group for which the initial slot is needed
* @param k The key to get the slot for
* @return Pointer to the initial slot for `k`
*/
template <typename ProbeKey>
__device__ __forceinline__ const_iterator initial_slot(
cooperative_groups::thread_block_tile<cg_size> const& g, ProbeKey const& k) const noexcept
{
std::size_t index;
auto const hash_value = hash1_(k);
Expand All @@ -357,16 +409,29 @@ class double_hashing_impl
* @return The next slot after `s`
*/
__device__ __forceinline__ iterator next_slot(iterator s) noexcept
{
  // Reuse the const implementation of the probe-step advance; drop the
  // constness of the returned slot for the mutable view.
  auto const& const_self = std::as_const(*this);
  return const_cast<iterator>(const_self.next_slot(s));
}

/**
 * @brief Given a slot `s`, returns the next slot in the probe sequence.
 *
 * Advances by `step_size_` and wraps modulo `capacity_`, so the last slot
 * wraps back around to the beginning of the slot storage.
 *
 * @param s The slot to advance
 * @return The next slot after `s`
 */
__device__ __forceinline__ const_iterator next_slot(const_iterator s) const noexcept
{
  auto const current_index = static_cast<std::size_t>(s - slots_);
  auto const next_index    = (current_index + step_size_) % capacity_;
  return slots_ + next_index;
}

private:
Hash1 hash1_; ///< The first unary callable used to hash the key
Hash2 hash2_; ///< The second unary callable used to determine step size
std::size_t step_size_; ///< The step stride when searching for the next slot
}; // class double_hashing
Hash1 hash1_; ///< The first unary callable used to hash the key
Hash2 hash2_; ///< The second unary callable used to determine step size
mutable std::size_t step_size_; ///< The step stride when searching for the next slot
}; // class double_hashing

/**
* @brief Probe sequence used internally by hash map.
Expand Down
2 changes: 1 addition & 1 deletion include/cuco/detail/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ void static_map<Key, Value, Scope, Allocator>::contains(InputIt first,
OutputIt output_begin,
Hash hash,
KeyEqual key_equal,
cudaStream_t stream)
cudaStream_t stream) const
{
auto num_keys = std::distance(first, last);
if (num_keys == 0) { return; }
Expand Down
99 changes: 61 additions & 38 deletions include/cuco/detail/static_multimap/device_view_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
* @param current_slot The given slot to load from
*/
__device__ __forceinline__ void load_pair_array(value_type* arr,
const_iterator current_slot) noexcept
const_iterator current_slot) const noexcept
{
if constexpr (sizeof(value_type) == 4) {
auto const tmp = *reinterpret_cast<ushort4 const*>(current_slot);
Expand Down Expand Up @@ -567,32 +567,33 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
}

/**
* @brief Indicates whether the key `k` exists in the map using vector loads.
* @brief Indicates whether the probe `element` exists in the map using vector loads.
*
* If the key `k` was inserted into the map, `contains` returns
* true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to
* to leverage multiple threads to perform a single `contains` operation. This provides a
* significant boost in throughput compared to the non Cooperative Group based
* `contains` at moderate to high load factors.
* If `element` was inserted into the map, `contains` returns true. Otherwise, it returns false.
* Uses the CUDA Cooperative Groups API to leverage multiple threads to perform a single
* `contains` operation. This provides a significant boost in throughput compared to the non
* Cooperative Group based `contains` at moderate to high load factors.
*
* @tparam is_pair_contains `true` if it's a `pair_contains` implementation
* @tparam uses_vector_load Boolean flag indicating whether vector loads are used
* @tparam ProbeKey Probe key type
* @tparam KeyEqual Binary callable type
* @tparam ProbeT Probe data type
* @tparam Equal Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
* @param element The probe element to search for
* @param equal The binary function to compare input element and slot content for equality
* @return A boolean indicating whether the key/value pair represented by `element` was inserted
*/
template <bool uses_vector_load, typename ProbeKey, typename KeyEqual>
template <bool is_pair_contains, bool uses_vector_load, typename ProbeT, typename Equal>
__device__ __forceinline__ std::enable_if_t<uses_vector_load, bool> contains(
cooperative_groups::thread_block_tile<ProbeSequence::cg_size> const& g,
ProbeKey const& k,
KeyEqual key_equal) noexcept
ProbeT const& element,
Equal equal) const noexcept
{
auto current_slot = initial_slot(g, k);
auto current_slot = [&]() {
if constexpr (is_pair_contains) { return initial_slot(g, element.first); }
if constexpr (not is_pair_contains) { return initial_slot(g, element); }
}();

while (true) {
value_type arr[2];
Expand All @@ -602,8 +603,22 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
detail::bitwise_compare(arr[0].first, this->get_empty_key_sentinel());
auto const second_slot_is_empty =
detail::bitwise_compare(arr[1].first, this->get_empty_key_sentinel());
auto const first_equals = (not first_slot_is_empty and key_equal(arr[0].first, k));
auto const second_equals = (not second_slot_is_empty and key_equal(arr[1].first, k));
auto const first_equals = [&]() {
if constexpr (is_pair_contains) {
return not first_slot_is_empty and equal(arr[0], element);
}
if constexpr (not is_pair_contains) {
return not first_slot_is_empty and equal(arr[0].first, element);
}
}();
auto const second_equals = [&]() {
if constexpr (is_pair_contains) {
return not second_slot_is_empty and equal(arr[1], element);
}
if constexpr (not is_pair_contains) {
return not second_slot_is_empty and equal(arr[1].first, element);
}
}();

// the key we were searching for was found by one of the threads, so we return true
if (g.any(first_equals or second_equals)) { return true; }
Expand All @@ -618,32 +633,33 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
}

/**
* @brief Indicates whether the key `k` exists in the map using scalar loads.
* @brief Indicates whether `element` exists in the map using scalar loads.
*
* If the key `k` was inserted into the map, `contains` returns
* true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to
* to leverage multiple threads to perform a single `contains` operation. This provides a
* significant boost in throughput compared to the non Cooperative Group
* `contains` at moderate to high load factors.
* If `element` was inserted into the map, `contains` returns true. Otherwise, it returns false.
* Uses the CUDA Cooperative Groups API to leverage multiple threads to perform a single
* `contains` operation. This provides a significant boost in throughput compared to the non
* Cooperative Group `contains` at moderate to high load factors.
*
* @tparam is_pair_contains `true` if it's a `pair_contains` implementation
* @tparam uses_vector_load Boolean flag indicating whether vector loads are used
* @tparam ProbeKey Probe key type
* @tparam KeyEqual Binary callable type
* @tparam ProbeT Probe data type
* @tparam Equal Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
* @param element The probe element to search for
* @param equal The binary function to compare input element and slot content for equality
* @return A boolean indicating whether the key/value pair represented by `element` was inserted
*/
template <bool uses_vector_load, typename ProbeKey, typename KeyEqual>
template <bool is_pair_contains, bool uses_vector_load, typename ProbeT, typename Equal>
__device__ __forceinline__ std::enable_if_t<not uses_vector_load, bool> contains(
cooperative_groups::thread_block_tile<ProbeSequence::cg_size> const& g,
ProbeKey const& k,
KeyEqual key_equal) noexcept
ProbeT const& element,
Equal equal) const noexcept
{
auto current_slot = initial_slot(g, k);
auto current_slot = [&]() {
if constexpr (is_pair_contains) { return initial_slot(g, element.first); }
if constexpr (not is_pair_contains) { return initial_slot(g, element); }
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And also all other places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I recall, we need this to silence the compiler warnings with cuda-11.0/11.2. Reverting back to the if constexpr (not ...) expression.

}();

while (true) {
value_type slot_contents = *reinterpret_cast<value_type const*>(current_slot);
Expand All @@ -654,7 +670,14 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
auto const slot_is_empty =
detail::bitwise_compare(existing_key, this->get_empty_key_sentinel());

auto const equals = (not slot_is_empty and key_equal(existing_key, k));
auto const equals = [&]() {
if constexpr (is_pair_contains) {
return not slot_is_empty and equal(slot_contents, element);
}
if constexpr (not is_pair_contains) {
return not slot_is_empty and equal(existing_key, element);
}
}();

// the key we were searching for was found by one of the threads, so we return true
if (g.any(equals)) { return true; }
Expand Down
53 changes: 29 additions & 24 deletions include/cuco/detail/static_multimap/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

#include <cooperative_groups/memcpy_async.h>

#include <iterator>

namespace cuco {
namespace detail {
namespace cg = cooperative_groups;
Expand Down Expand Up @@ -144,45 +146,50 @@ __global__ void insert_if_n(InputIt first, StencilIt s, std::size_t n, viewT vie
}

/**
* @brief Indicates whether the keys in the range `[first, last)` are contained in the map.
* @brief Indicates whether the elements in the range `[first, last)` are contained in the map.
*
* Stores `true` or `false` to `(output + i)` indicating if the key `*(first + i)` exists in the
* Stores `true` or `false` to `(output + i)` indicating if the element `*(first + i)` exists in the
* map.
*
* Uses the CUDA Cooperative Groups API to leverage groups of multiple threads to perform the
* contains operation for each key. This provides a significant boost in throughput compared
* contains operation for each element. This provides a significant boost in throughput compared
* to the non Cooperative Group `contains` at moderate to high load factors.
*
* @tparam is_pair_contains `true` if it's a `pair_contains` implementation
* @tparam block_size The size of the thread block
* @tparam tile_size The number of threads in the Cooperative Groups
* @tparam InputIt Device accessible input iterator whose `value_type` is
* convertible to the map's `key_type`
* @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from `bool`
* @tparam InputIt Device accessible input iterator
* @tparam OutputIt Device accessible output iterator assignable from `bool`
* @tparam viewT Type of device view allowing access of hash map storage
* @tparam KeyEqual Binary callable type
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param output_begin Beginning of the sequence of booleans for the presence of each key
* @tparam Equal Binary callable type
*
* @param first Beginning of the sequence of elements
* @param last End of the sequence of elements
* @param output_begin Beginning of the sequence of booleans for the presence of each element
* @param view Device view used to access the hash map's slot storage
* @param key_equal The binary function to compare two keys for equality
* @param equal The binary function to compare input element and slot content for equality
*/
template <uint32_t block_size,
template <bool is_pair_contains,
uint32_t block_size,
uint32_t tile_size,
typename InputIt,
typename OutputIt,
typename viewT,
typename KeyEqual>
typename Equal>
__global__ void contains(
InputIt first, InputIt last, OutputIt output_begin, viewT view, KeyEqual key_equal)
InputIt first, InputIt last, OutputIt output_begin, viewT view, Equal equal)
{
auto tile = cg::tiled_partition<tile_size>(cg::this_thread_block());
auto tid = block_size * blockIdx.x + threadIdx.x;
auto key_idx = tid / tile_size;
auto tile = cg::tiled_partition<tile_size>(cg::this_thread_block());
auto tid = block_size * blockIdx.x + threadIdx.x;
auto idx = tid / tile_size;
__shared__ bool writeBuffer[block_size];

while (first + key_idx < last) {
auto key = *(first + key_idx);
auto found = view.contains(tile, key, key_equal);
while (first + idx < last) {
typename std::iterator_traits<InputIt>::value_type element = *(first + idx);
auto found = [&]() {
if constexpr (is_pair_contains) { return view.pair_contains(tile, element, equal); }
if constexpr (not is_pair_contains) { return view.contains(tile, element, equal); }
}();

/*
* The ld.relaxed.gpu instruction used in view.find causes L1 to
Expand All @@ -193,10 +200,8 @@ __global__ void contains(
*/
if (tile.thread_rank() == 0) { writeBuffer[threadIdx.x / tile_size] = found; }
__syncthreads();
if (tile.thread_rank() == 0) {
*(output_begin + key_idx) = writeBuffer[threadIdx.x / tile_size];
}
key_idx += (gridDim.x * block_size) / tile_size;
if (tile.thread_rank() == 0) { *(output_begin + idx) = writeBuffer[threadIdx.x / tile_size]; }
idx += (gridDim.x * block_size) / tile_size;
}
}

Expand Down
Loading