Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions include/cuco/detail/static_multimap/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,75 @@ __global__ void contains(
}
}

/**
 * @brief Indicates whether the pairs in the range `[first, first + n)` are contained in the map if
 * `pred` of the corresponding stencil returns true.
 *
 * If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output_begin + i)` indicating
 * if the pair `*(first + i)` exists in the map. If `pred( *(stencil + i) )` is false, stores false
 * to `(output_begin + i)`.
 *
 * Uses the CUDA Cooperative Groups API to leverage groups of multiple threads to perform the
 * contains operation for each element. This provides a significant boost in throughput compared
 * to the non Cooperative Group `contains` at moderate to high load factors.
 *
 * Each tile of `tile_size` threads handles one input pair per loop iteration. The kernel indexes
 * with `blockIdx.x`/`threadIdx.x` only and declares a static shared array of `block_size` bools,
 * so it is written for a one-dimensional launch with `block_size` threads per block.
 *
 * @tparam block_size The size of the thread block
 * @tparam tile_size The number of threads in the Cooperative Groups
 * @tparam InputIt Device accessible input iterator
 * @tparam StencilIt Device accessible random access iterator whose value_type is
 * convertible to Predicate's argument type
 * @tparam OutputIt Device accessible output iterator assignable from `bool`
 * @tparam viewT Type of device view allowing access of hash map storage
 * @tparam PairEqual Binary callable type
 * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
 * argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>.
 *
 * @param first Beginning of the sequence of pairs
 * @param stencil Beginning of the stencil sequence
 * @param output_begin Beginning of the sequence of booleans for the presence of each pair
 * @param n Number of input pairs
 * @param view Device view used to access the hash map's slot storage
 * @param pair_equal The binary function to compare input pair and slot content for equality
 * @param pred Predicate to test on every element in the range `[stencil, stencil + n)`
 */
template <uint32_t block_size,
          uint32_t tile_size,
          typename InputIt,
          typename StencilIt,
          typename OutputIt,
          typename viewT,
          typename PairEqual,
          typename Predicate>
__global__ void pair_contains_if_n(InputIt first,
                                   StencilIt stencil,
                                   OutputIt output_begin,
                                   std::size_t n,
                                   viewT view,
                                   PairEqual pair_equal,
                                   Predicate pred)
{
  auto tile = cg::tiled_partition<tile_size>(cg::this_thread_block());

  // All index arithmetic is done in 64 bits so that neither the global thread id nor the loop
  // stride can wrap around for large inputs.
  auto const block_offset = static_cast<std::size_t>(block_size) * blockIdx.x;
  auto const loop_stride  = (static_cast<std::size_t>(gridDim.x) * block_size) / tile_size;

  // Index handled by the first thread of this block. It is the smallest index in the block and is
  // identical for every thread of the block, which makes it a block-uniform loop condition.
  auto block_first_idx = block_offset / tile_size;
  // Index of the pair handled by this thread's tile.
  auto idx = (block_offset + threadIdx.x) / tile_size;

  __shared__ bool writeBuffer[block_size];

  // The loop condition must not depend on the per-tile `idx`: the body contains a block-wide
  // barrier, which every thread of the block has to reach the same number of times. Tiles whose
  // `idx` is out of range keep iterating but skip all work.
  while (block_first_idx < n) {
    bool const in_range = idx < n;
    bool found          = false;

    if (in_range) {
      typename std::iterator_traits<InputIt>::value_type pair = *(first + idx);
      found = pred(*(stencil + idx)) ? view.pair_contains(tile, pair, pair_equal) : false;
    }

    /*
     * The ld.relaxed.gpu instruction used in view.find causes L1 to
     * flush more frequently, causing increased sector stores from L2 to global memory.
     * By writing results to shared memory and then synchronizing before writing back
     * to global, we no longer rely on L1, preventing the increase in sector stores from
     * L2 to global and improving performance.
     *
     * Each buffer slot is written and read back by the same thread (rank 0 of the tile).
     */
    bool const is_writer = in_range && (tile.thread_rank() == 0);
    if (is_writer) { writeBuffer[threadIdx.x / tile_size] = found; }
    __syncthreads();
    if (is_writer) { *(output_begin + idx) = writeBuffer[threadIdx.x / tile_size]; }

    idx += loop_stride;
    block_first_idx += loop_stride;
  }
}

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the multimap.
*
Expand Down
33 changes: 31 additions & 2 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ void static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::contains(

detail::contains<is_pair_contains, block_size, cg_size()>
<<<grid_size, block_size, 0, stream>>>(first, last, output_begin, view, key_equal);
CUCO_CUDA_TRY(cudaStreamSynchronize(stream));
}

template <typename Key,
Expand All @@ -145,7 +144,37 @@ void static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::pair_contains

detail::contains<is_pair_contains, block_size, cg_size()>
<<<grid_size, block_size, 0, stream>>>(first, last, output_begin, view, pair_equal);
CUCO_CUDA_TRY(cudaStreamSynchronize(stream));
}

// Host-side launcher for the stencil-predicated pair lookup. Enqueues
// `detail::pair_contains_if_n` on `stream` and returns without synchronizing the stream.
template <typename Key,
          typename Value,
          cuda::thread_scope Scope,
          typename Allocator,
          class ProbeSequence>
template <typename InputIt,
          typename StencilIt,
          typename OutputIt,
          typename PairEqual,
          typename Predicate>
void static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::pair_contains_if(
  InputIt first,
  InputIt last,
  StencilIt stencil,
  OutputIt output_begin,
  PairEqual pair_equal,
  Predicate pred,
  cudaStream_t stream) const
{
  auto const input_size = std::distance(first, last);

  // Nothing to look up: do not launch a kernel for an empty range.
  if (input_size == 0) { return; }

  auto constexpr threads_per_block = 128;
  auto constexpr stride            = 1;
  auto constexpr work_per_block    = stride * threads_per_block;

  // Ceiling division, so that num_blocks * work_per_block >= cg_size() * input_size.
  auto const num_blocks = (cg_size() * input_size + work_per_block - 1) / work_per_block;

  auto device_view = get_device_view();

  detail::pair_contains_if_n<threads_per_block, cg_size()>
    <<<num_blocks, threads_per_block, 0, stream>>>(
      first, stencil, output_begin, input_size, device_view, pair_equal, pred);
}

template <typename Key,
Expand Down
44 changes: 44 additions & 0 deletions include/cuco/static_multimap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,50 @@ class static_multimap {
PairEqual pair_equal,
cudaStream_t stream = 0) const;

/**
* @brief Indicates whether the pairs in the range `[first, last)` are contained in the map if
* `pred` of the corresponding stencil returns true.
*
* If `pred( *(stencil + i) )` is true, stores `true` or `false` to `(output + i)` indicating if
* the pair `*(first + i)` exists in the map. If `pred( *(stencil + i) )` is false, stores false
* to `(output + i)`.
*
* If `[first, last)` is empty, the function returns immediately and no kernel is launched.
* Otherwise the lookup kernel is enqueued on `stream` and this function returns without
* synchronizing `stream`; callers that read the output from the host are expected to synchronize
* the stream themselves first.
*
* The stencil and output sequences are accessed at every offset `i` in
* `[0, std::distance(first, last))`, so both must cover at least that many elements.
*
* ProbeSequence hashers should be callable with both
* <tt>std::iterator_traits<InputIt>::value_type::first_type</tt>
* and Key type. <tt>std::invoke_result<KeyEqual,
* std::iterator_traits<InputIt>::value_type::first_type, Key></tt>
* must be well-formed.
*
* @tparam InputIt Device accessible random access input iterator
* @tparam StencilIt Device accessible random access iterator whose value_type is
* convertible to Predicate's argument type
* @tparam OutputIt Device accessible output iterator assignable from `bool`
* @tparam PairEqual Binary callable type used to compare input pair and slot content for equality
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>.
*
* @param first Beginning of the sequence of pairs
* @param last End of the sequence of pairs
* @param stencil Beginning of the stencil sequence
* @param output_begin Beginning of the output sequence indicating whether each pair is present
* @param pair_equal The binary function to compare input pair and slot content for equality
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param stream CUDA stream on which the lookup kernel is launched
*/
template <typename InputIt,
typename StencilIt,
typename OutputIt,
typename PairEqual,
typename Predicate>
void pair_contains_if(InputIt first,
InputIt last,
StencilIt stencil,
OutputIt output_begin,
PairEqual pair_equal,
Predicate pred,
cudaStream_t stream = 0) const;

/**
* @brief Counts the occurrences of keys in `[first, last)` contained in the multimap.
*
Expand Down
19 changes: 19 additions & 0 deletions tests/static_multimap/pair_function_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ __inline__ void test_pair_functions(Map& map, PairIt pair_begin, std::size_t num
res_begin + num_pairs / 2, res_begin + num_pairs, false_iter, thrust::equal_to<bool>{}));
}

SECTION("pair_contains_if checks the input pair only if the corresponding stencil returns true.")
{
  // The stencil is the sequence 0, 1, 2, ...; the predicate accepts only values that are not
  // below num_pairs / 2, i.e. it rejects the first half of the inputs.
  auto in_second_half = [num_pairs] __device__(int32_t s) { return !(s < num_pairs / 2); };
  auto stencil_begin  = thrust::make_counting_iterator<int32_t>(0);

  thrust::device_vector<bool> contained(num_pairs);
  auto out_begin = contained.begin();

  map.pair_contains_if(pair_begin,
                       pair_begin + num_pairs,
                       stencil_begin,
                       out_begin,
                       pair_equal<Key, Value>{},
                       in_second_half);

  // All false since the stencil of the first half is false and the second half is not inserted
  auto all_false = thrust::make_constant_iterator(false);
  REQUIRE(
    cuco::test::equal(out_begin, out_begin + num_pairs, all_false, thrust::equal_to<bool>{}));
}

SECTION("Output of pair_count and pair_retrieve should be coherent.")
{
auto num = map.pair_count(pair_begin, pair_begin + num_pairs, pair_equal<Key, Value>{});
Expand Down