Skip to content

Refactor thrust::[stable_]partition[_copy] to use cub::DevicePartition #1435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 145 additions & 78 deletions cub/cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <cub/block/block_store.cuh>
#include <cub/grid/grid_queue.cuh>
#include <cub/iterator/cache_modified_input_iterator.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/type_traits>

Expand Down Expand Up @@ -121,6 +122,19 @@ struct AgentSelectIfPolicy
* Thread block abstractions
******************************************************************************/

namespace detail
{
template <typename SelectedOutputItT, typename RejectedOutputItT>
struct partition_distinct_output_t
{
using selected_iterator_t = SelectedOutputItT;
using rejected_iterator_t = RejectedOutputItT;

selected_iterator_t selected_it;
rejected_iterator_t rejected_it;
};
} // namespace detail

/**
* @brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in
* device-wide selection
Expand All @@ -139,8 +153,8 @@ struct AgentSelectIfPolicy
* Random-access input iterator type for selections (NullType* if a selection functor or
* discontinuity flagging is to be used for selection)
*
* @tparam SelectedOutputIteratorT
* Random-access output iterator type for selection_flags items
* @tparam OutputIteratorWrapperT
* Either a random-access iterator or an instance of the `partition_distinct_output_t` template.
*
* @tparam SelectOpT
* Selection operator type (NullType if selections or discontinuity flagging is to be used for
Expand All @@ -159,7 +173,7 @@ struct AgentSelectIfPolicy
template <typename AgentSelectIfPolicyT,
typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename OutputIteratorWrapperT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
Expand Down Expand Up @@ -279,13 +293,13 @@ struct AgentSelectIf
// Per-thread fields
//---------------------------------------------------------------------

_TempStorage &temp_storage; ///< Reference to temp_storage
WrappedInputIteratorT d_in; ///< Input items
SelectedOutputIteratorT d_selected_out; ///< Unique output items
WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable)
_TempStorage& temp_storage; ///< Reference to temp_storage
WrappedInputIteratorT d_in; ///< Input items
OutputIteratorWrapperT d_selected_out; ///< Unique output items
WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable)
InequalityWrapper<EqualityOpT> inequality_op; ///< T inequality operator
SelectOpT select_op; ///< Selection operator
OffsetT num_items; ///< Total number of input items
SelectOpT select_op; ///< Selection operator
OffsetT num_items; ///< Total number of input items

//---------------------------------------------------------------------
// Constructor
Expand Down Expand Up @@ -316,7 +330,7 @@ struct AgentSelectIf
_CCCL_DEVICE _CCCL_FORCEINLINE AgentSelectIf(TempStorage &temp_storage,
InputIteratorT d_in,
FlagsInputIteratorT d_flags_in,
SelectedOutputIteratorT d_selected_out,
OutputIteratorWrapperT d_selected_out,
SelectOpT select_op,
EqualityOpT equality_op,
OffsetT num_items)
Expand Down Expand Up @@ -477,10 +491,10 @@ struct AgentSelectIf
//---------------------------------------------------------------------

/**
* Scatter flagged items to output offsets (specialized for direct scattering)
* Scatter flagged items to output offsets (specialized for direct scattering).
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterDirect(
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedDirect(
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
Expand Down Expand Up @@ -519,19 +533,17 @@ struct AgentSelectIf
* Marker type indicating whether to keep rejected items in the second partition
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterTwoPhase(InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int /*num_tile_items*/,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT /*num_rejected_prefix*/,
Int2Type<false> /*is_keep_rejects*/)
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterSelectedTwoPhase(
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_selections,
OffsetT num_selections_prefix)
{
CTA_SYNC();

// Compact and scatter items
#pragma unroll
// Compact and scatter items
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
Expand All @@ -550,7 +562,52 @@ struct AgentSelectIf
}

/**
* @brief Scatter flagged items to output offsets (specialized for two-phase scattering)
* @brief Scatter flagged items. Specialized for selection algorithm that simply discards rejected items
*
* @param num_tile_items
* Number of valid items in this tile
*
* @param num_tile_selections
* Number of selections in this tile
*
* @param num_selections_prefix
* Total number of selections prior to this tile
*
* @param num_rejected_prefix
* Total number of rejections prior to this tile
*
* @param num_selections
* Total number of selections including this tile
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void Scatter(
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
OffsetT num_selections,
Int2Type<false> /*is_keep_rejects*/)
{
// Do a two-phase scatter if two-phase is enabled and the average number of selection_flags items per thread is
// greater than one
if (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS))
{
ScatterSelectedTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
items, selection_flags, selection_indices, num_tile_selections, num_selections_prefix);
}
else
{
ScatterSelectedDirect<IS_LAST_TILE, IS_FIRST_TILE>(
items, selection_flags, selection_indices, num_selections);
}
}

/**
* @brief Scatter flagged items. Specialized for partitioning algorithm that writes rejected items to a second
* partition.
*
* @param num_tile_items
* Number of valid items in this tile
Expand All @@ -568,13 +625,14 @@ struct AgentSelectIf
* Marker type indicating whether to keep rejected items in the second partition
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterTwoPhase(InputT (&items)[ITEMS_PER_THREAD],
_CCCL_DEVICE _CCCL_FORCEINLINE void Scatter(InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
OffsetT num_selections,
Int2Type<true> /*is_keep_rejects*/)
{
CTA_SYNC();
Expand All @@ -595,76 +653,83 @@ struct AgentSelectIf
temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
}

// Ensure all threads finished scattering to shared memory
CTA_SYNC();

// Gather items from shared memory and scatter to global
ScatterPartitionsToGlobal<IS_LAST_TILE, IS_FIRST_TILE>(
num_tile_items, tile_num_rejections, num_selections_prefix, num_rejected_prefix, d_selected_out);
}

/**
* @brief Second phase of scattering partitioned items to global memory. Specialized for partitioning to two
* distinct partitions.
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE, typename SelectedItT, typename RejectedItT>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal(
int num_tile_items,
int tile_num_rejections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
detail::partition_distinct_output_t<SelectedItT, RejectedItT> partitioned_out_it_wrapper)
{
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x;
int rejection_idx = item_idx;
int selection_idx = item_idx - tile_num_rejections;
OffsetT scatter_offset = (item_idx < tile_num_rejections) ?
num_items - num_rejected_prefix - rejection_idx - 1 :
num_selections_prefix + selection_idx;
int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x;
int rejection_idx = item_idx;
int selection_idx = item_idx - tile_num_rejections;
OffsetT scatter_offset =
(item_idx < tile_num_rejections)
? num_rejected_prefix + rejection_idx
: num_selections_prefix + selection_idx;

InputT item = temp_storage.raw_exchange.Alias()[item_idx];

if (!IS_LAST_TILE || (item_idx < num_tile_items))
{
d_selected_out[scatter_offset] = item;
if (item_idx >= tile_num_rejections)
{
partitioned_out_it_wrapper.selected_it[scatter_offset] = item;
}
else
{
partitioned_out_it_wrapper.rejected_it[scatter_offset] = item;
}
}
}
}

/**
* @brief Scatter flagged items
*
* @param num_tile_items
* Number of valid items in this tile
*
* @param num_tile_selections
* Number of selections in this tile
*
* @param num_selections_prefix
* Total number of selections prior to this tile
*
* @param num_rejected_prefix
* Total number of rejections prior to this tile
*
* @param num_selections
* Total number of selections including this tile
* @brief Second phase of scattering partitioned items to global memory. Specialized for partitioning to a single
* iterator, where selected items are written in order from the beginning of the itereator and rejected items are
* writtem from the iterators end backwards.
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
_CCCL_DEVICE _CCCL_FORCEINLINE void Scatter(InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items,
int num_tile_selections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
OffsetT num_selections)
template <bool IS_LAST_TILE, bool IS_FIRST_TILE, typename PartitionedOutputItT>
_CCCL_DEVICE _CCCL_FORCEINLINE void ScatterPartitionsToGlobal(
int num_tile_items,
int tile_num_rejections,
OffsetT num_selections_prefix,
OffsetT num_rejected_prefix,
PartitionedOutputItT partitioned_out_it)
{
// Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selection_flags items per thread is greater than one
if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS)))
{
ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
items,
selection_flags,
selection_indices,
num_tile_items,
num_tile_selections,
num_selections_prefix,
num_rejected_prefix,
Int2Type<KEEP_REJECTS>());
}
else
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>(
items,
selection_flags,
selection_indices,
num_selections);
int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x;
int rejection_idx = item_idx;
int selection_idx = item_idx - tile_num_rejections;
OffsetT scatter_offset =
(item_idx < tile_num_rejections)
? num_items - num_rejected_prefix - rejection_idx - 1
: num_selections_prefix + selection_idx;

InputT item = temp_storage.raw_exchange.Alias()[item_idx];

if (!IS_LAST_TILE || (item_idx < num_tile_items))
{
partitioned_out_it[scatter_offset] = item;
}
}
}

Expand Down Expand Up @@ -736,7 +801,8 @@ struct AgentSelectIf
num_tile_selections,
0,
0,
num_tile_selections);
num_tile_selections,
cub::Int2Type<KEEP_REJECTS>{});

return num_tile_selections;
}
Expand Down Expand Up @@ -791,7 +857,7 @@ struct AgentSelectIf
OffsetT num_tile_selections = prefix_op.GetBlockAggregate();
OffsetT num_selections = prefix_op.GetInclusivePrefix();
OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix();
OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix;
OffsetT num_rejected_prefix = tile_offset - num_selections_prefix;

// Discount any out-of-bounds selections
if (IS_LAST_TILE)
Expand All @@ -810,7 +876,8 @@ struct AgentSelectIf
num_tile_selections,
num_selections_prefix,
num_rejected_prefix,
num_selections);
num_selections,
cub::Int2Type<KEEP_REJECTS>{});

return num_selections;
}
Expand Down
Loading