Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #444 from senior-zero/fix-main/github/select_if_mi…
Browse files Browse the repository at this point in the history
…xed_types

Fix select if for mixed types
  • Loading branch information
alliepiper authored Mar 18, 2022
2 parents 3e27978 + 94c3e1f commit cdcec9c
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 31 deletions.
32 changes: 14 additions & 18 deletions cub/agent/agent_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ struct AgentSelectIf
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT, InputT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;

Expand Down Expand Up @@ -156,7 +152,7 @@ struct AgentSelectIf
FlagsInputIteratorT>;

// Parameterized BlockLoad type for input data
using BlockLoadT = BlockLoad<OutputT,
using BlockLoadT = BlockLoad<InputT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
AgentSelectIfPolicyT::LOAD_ALGORITHM>;
Expand All @@ -168,7 +164,7 @@ struct AgentSelectIf
AgentSelectIfPolicyT::LOAD_ALGORITHM>;

// Parameterized BlockDiscontinuity type for items
using BlockDiscontinuityT = BlockDiscontinuity<OutputT, BLOCK_THREADS>;
using BlockDiscontinuityT = BlockDiscontinuity<InputT, BLOCK_THREADS>;

// Parameterized BlockScan type
using BlockScanT =
Expand All @@ -179,7 +175,7 @@ struct AgentSelectIf
TilePrefixCallbackOp<OffsetT, cub::Sum, ScanTileStateT>;

// Item exchange type
typedef OutputT ItemExchangeT[TILE_ITEMS];
typedef InputT ItemExchangeT[TILE_ITEMS];

// Shared memory type for this thread block
union _TempStorage
Expand Down Expand Up @@ -254,7 +250,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT /*tile_offset*/,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_OP> /*select_method*/)
{
Expand All @@ -277,7 +273,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&/*items*/)[ITEMS_PER_THREAD],
InputT (&/*items*/)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_SELECT_FLAGS> /*select_method*/)
{
Expand Down Expand Up @@ -311,7 +307,7 @@ struct AgentSelectIf
__device__ __forceinline__ void InitializeSelections(
OffsetT tile_offset,
OffsetT num_tile_items,
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
Int2Type<USE_DISCONTINUITY> /*select_method*/)
{
Expand All @@ -324,7 +320,7 @@ struct AgentSelectIf
}
else
{
OutputT tile_predecessor;
InputT tile_predecessor;
if (threadIdx.x == 0)
tile_predecessor = d_in[tile_offset - 1];

Expand Down Expand Up @@ -353,7 +349,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterDirect(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
OffsetT num_selections)
Expand All @@ -378,7 +374,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int /*num_tile_items*/, ///< Number of valid items in this tile
Expand Down Expand Up @@ -414,7 +410,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -454,7 +450,7 @@ struct AgentSelectIf
num_items - num_rejected_prefix - rejection_idx - 1 :
num_selections_prefix + selection_idx;

OutputT item = temp_storage.raw_exchange.Alias()[item_idx];
InputT item = temp_storage.raw_exchange.Alias()[item_idx];

if (!IS_LAST_TILE || (item_idx < num_tile_items))
{
Expand All @@ -469,7 +465,7 @@ struct AgentSelectIf
*/
template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
__device__ __forceinline__ void Scatter(
OutputT (&items)[ITEMS_PER_THREAD],
InputT (&items)[ITEMS_PER_THREAD],
OffsetT (&selection_flags)[ITEMS_PER_THREAD],
OffsetT (&selection_indices)[ITEMS_PER_THREAD],
int num_tile_items, ///< Number of valid items in this tile
Expand Down Expand Up @@ -515,7 +511,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down Expand Up @@ -575,7 +571,7 @@ struct AgentSelectIf
OffsetT tile_offset, ///< Tile offset
ScanTileStateT& tile_state) ///< Global tile state descriptor
{
OutputT items[ITEMS_PER_THREAD];
InputT items[ITEMS_PER_THREAD];
OffsetT selection_flags[ITEMS_PER_THREAD];
OffsetT selection_indices[ITEMS_PER_THREAD];

Expand Down
13 changes: 9 additions & 4 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,11 @@ struct DispatchRadixSort :
const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS;
int num_passes = cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS);
OffsetT num_portions = static_cast<OffsetT>(cub::DivideAndRoundUp(num_items, PORTION_SIZE));
PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(CUB_MIN(num_items, PORTION_SIZE),
ONESWEEP_TILE_ITEMS);

PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(
static_cast<int>(
CUB_MIN(num_items, static_cast<OffsetT>(PORTION_SIZE))),
ONESWEEP_TILE_ITEMS);

size_t value_size = KEYS_ONLY ? 0 : sizeof(ValueT);
size_t allocation_sizes[] =
{
Expand Down Expand Up @@ -1355,7 +1357,10 @@ struct DispatchRadixSort :
int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
PortionOffsetT portion_num_items = CUB_MIN(num_items - portion * PORTION_SIZE, PORTION_SIZE);
PortionOffsetT portion_num_items =
static_cast<PortionOffsetT>(
CUB_MIN(num_items - portion * PORTION_SIZE,
static_cast<OffsetT>(PORTION_SIZE)));
PortionOffsetT num_blocks =
cub::DivideAndRoundUp(portion_num_items, ONESWEEP_TILE_ITEMS);
if (CubDebug(error = cudaMemsetAsync(
Expand Down
8 changes: 3 additions & 5 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,8 @@ struct DispatchSelectIf
* Types and constants
******************************************************************************/

// The output value type
using OutputT =
cub::detail::non_void_value_t<SelectedOutputIteratorT,
cub::detail::value_t<InputIteratorT>>;
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;
Expand All @@ -155,7 +153,7 @@ struct DispatchSelectIf
{
enum {
NOMINAL_4B_ITEMS_PER_THREAD = 10,
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))),
ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))),
};

typedef AgentSelectIfPolicy<
Expand Down
17 changes: 13 additions & 4 deletions test/test_device_radix_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include <algorithm>
#include <cstdio>
#include <limits>
#include <memory>
#include <random>
#include <type_traits>
Expand Down Expand Up @@ -283,9 +284,11 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

return DeviceSegmentedRadixSort::SortPairs(
d_temp_storage, temp_storage_bytes,
d_keys, d_values, num_items,
d_keys, d_values, static_cast<int>(num_items),
num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);
}
Expand Down Expand Up @@ -317,13 +320,15 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

KeyT const *const_keys_itr = d_keys.Current();
ValueT const *const_values_itr = d_values.Current();

cudaError_t retval = DeviceSegmentedRadixSort::SortPairs(
d_temp_storage, temp_storage_bytes,
const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(),
num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets,
static_cast<int>(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);

d_keys.selector ^= 1;
Expand Down Expand Up @@ -359,9 +364,11 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

return DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes,
d_keys, d_values, num_items,
d_keys, d_values, static_cast<int>(num_items),
num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);
}
Expand Down Expand Up @@ -393,13 +400,15 @@ cudaError_t Dispatch(
cudaStream_t stream,
bool debug_synchronous)
{
AssertTrue(num_items < std::numeric_limits<int>::max());

KeyT const *const_keys_itr = d_keys.Current();
ValueT const *const_values_itr = d_values.Current();

cudaError_t retval = DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes,
const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(),
num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets,
static_cast<int>(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets,
begin_bit, end_bit, stream, debug_synchronous);

d_keys.selector ^= 1;
Expand Down
76 changes: 76 additions & 0 deletions test/test_device_select_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
#include <cub/device/device_partition.cuh>
#include <cub/iterator/counting_input_iterator.cuh>

#include <thrust/count.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/device_vector.h>

#include "test_util.h"

using namespace cub;
Expand Down Expand Up @@ -652,6 +658,74 @@ void Test(
}
}

template<class T0, class T1>
struct pair_to_col_t
{
__host__ __device__ T0 operator()(const thrust::tuple<T0, T1> &in)
{
return thrust::get<0>(in);
}
};

template<class T0, class T1>
struct select_t {
__host__ __device__ bool operator()(const thrust::tuple<T0, T1> &in) {
return static_cast<T1>(thrust::get<0>(in)) > thrust::get<1>(in);
}
};

template <typename T0, typename T1>
void TestMixedOp(int num_items)
{
const T0 target_value = static_cast<T0>(42);
thrust::device_vector<T0> col_a(num_items, target_value);
thrust::device_vector<T1> col_b(num_items, static_cast<T1>(4.2));

thrust::device_vector<T0> result(num_items);

auto in = thrust::make_zip_iterator(col_a.begin(), col_b.begin());
auto out = thrust::make_transform_output_iterator(result.begin(), pair_to_col_t<T0, T1>{});

void *d_tmp_storage {};
std::size_t tmp_storage_size{};
cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<T0, T1>{},
0, true);

thrust::device_vector<char> tmp_storage(tmp_storage_size);
d_tmp_storage = thrust::raw_pointer_cast(tmp_storage.data());

cub::DeviceSelect::If(
d_tmp_storage, tmp_storage_size,
in, out, thrust::make_discard_iterator(),
num_items, select_t<T0, T1>{},
0, true);

AssertEquals(num_items, thrust::count(result.begin(), result.end(), target_value));
}

/**
* Test different input sizes
*/
template <typename T0, typename T1>
void TestMixed(int num_items)
{
if (num_items < 0)
{
TestMixedOp<T0, T1>(0);
TestMixedOp<T0, T1>(1);
TestMixedOp<T0, T1>(100);
TestMixedOp<T0, T1>(10000);
TestMixedOp<T0, T1>(1000000);
}
else
{
TestMixedOp<T0, T1>(num_items);
}
}

//---------------------------------------------------------------------
// Main
//---------------------------------------------------------------------
Expand Down Expand Up @@ -708,6 +782,8 @@ int main(int argc, char** argv)
Test<TestFoo>(num_items);
Test<TestBar>(num_items);

TestMixed<int, double>(num_items);

return 0;
}

Expand Down
Loading

0 comments on commit cdcec9c

Please sign in to comment.