diff --git a/cub/agent/agent_select_if.cuh b/cub/agent/agent_select_if.cuh index 807ea18c6d..aded8589e2 100644 --- a/cub/agent/agent_select_if.cuh +++ b/cub/agent/agent_select_if.cuh @@ -107,10 +107,6 @@ struct AgentSelectIf // The input value type using InputT = cub::detail::value_t; - // The output value type - using OutputT = - cub::detail::non_void_value_t; - // The flag value type using FlagT = cub::detail::value_t; @@ -156,7 +152,7 @@ struct AgentSelectIf FlagsInputIteratorT>; // Parameterized BlockLoad type for input data - using BlockLoadT = BlockLoad; @@ -168,7 +164,7 @@ struct AgentSelectIf AgentSelectIfPolicyT::LOAD_ALGORITHM>; // Parameterized BlockDiscontinuity type for items - using BlockDiscontinuityT = BlockDiscontinuity; + using BlockDiscontinuityT = BlockDiscontinuity; // Parameterized BlockScan type using BlockScanT = @@ -179,7 +175,7 @@ struct AgentSelectIf TilePrefixCallbackOp; // Item exchange type - typedef OutputT ItemExchangeT[TILE_ITEMS]; + typedef InputT ItemExchangeT[TILE_ITEMS]; // Shared memory type for this thread block union _TempStorage @@ -254,7 +250,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT /*tile_offset*/, OffsetT num_tile_items, - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -277,7 +273,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, - OutputT (&/*items*/)[ITEMS_PER_THREAD], + InputT (&/*items*/)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -311,7 +307,7 @@ struct AgentSelectIf __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { @@ -324,7 +320,7 @@ struct AgentSelectIf } else { - OutputT tile_predecessor; + InputT tile_predecessor; if (threadIdx.x == 0) tile_predecessor = d_in[tile_offset - 1]; @@ -353,7 +349,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterDirect( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], OffsetT num_selections) @@ -378,7 +374,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterTwoPhase( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int /*num_tile_items*/, ///< Number of valid items in this tile @@ -414,7 +410,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void ScatterTwoPhase( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile @@ -454,7 +450,7 @@ struct AgentSelectIf num_items - num_rejected_prefix - rejection_idx - 1 : num_selections_prefix + selection_idx; - OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; + InputT item = temp_storage.raw_exchange.Alias()[item_idx]; if (!IS_LAST_TILE || (item_idx < num_tile_items)) { @@ -469,7 +465,7 @@ struct AgentSelectIf */ template __device__ __forceinline__ void Scatter( - OutputT (&items)[ITEMS_PER_THREAD], + InputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile @@ -515,7 +511,7 @@ struct AgentSelectIf OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { - OutputT items[ITEMS_PER_THREAD]; + InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; @@ -575,7 +571,7 @@ struct AgentSelectIf OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { - OutputT items[ITEMS_PER_THREAD]; + InputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index b188c75fae..ce485a26ed 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -1278,9 +1278,11 @@ struct DispatchRadixSort : const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS; int num_passes = cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS); OffsetT num_portions = static_cast(cub::DivideAndRoundUp(num_items, PORTION_SIZE)); - PortionOffsetT max_num_blocks = cub::DivideAndRoundUp(CUB_MIN(num_items, PORTION_SIZE), - ONESWEEP_TILE_ITEMS); - + PortionOffsetT max_num_blocks = cub::DivideAndRoundUp( + static_cast( + CUB_MIN(num_items, static_cast(PORTION_SIZE))), + ONESWEEP_TILE_ITEMS); + size_t value_size = KEYS_ONLY ? 0 : sizeof(ValueT); size_t allocation_sizes[] = { @@ -1355,7 +1357,10 @@ struct DispatchRadixSort : int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS); for (OffsetT portion = 0; portion < num_portions; ++portion) { - PortionOffsetT portion_num_items = CUB_MIN(num_items - portion * PORTION_SIZE, PORTION_SIZE); + PortionOffsetT portion_num_items = + static_cast( + CUB_MIN(num_items - portion * PORTION_SIZE, + static_cast(PORTION_SIZE))); PortionOffsetT num_blocks = cub::DivideAndRoundUp(portion_num_items, ONESWEEP_TILE_ITEMS); if (CubDebug(error = cudaMemsetAsync( diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index 94a17419a8..5654ba29a3 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -129,10 +129,8 @@ struct DispatchSelectIf * Types and constants ******************************************************************************/ - // The output value type - using OutputT = - cub::detail::non_void_value_t>; + // The input value type + using InputT = cub::detail::value_t; // The flag value type using FlagT = cub::detail::value_t; @@ -155,7 +153,7 @@ struct DispatchSelectIf { enum { NOMINAL_4B_ITEMS_PER_THREAD = 10, - ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(OutputT)))), + ITEMS_PER_THREAD = CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))), }; typedef AgentSelectIfPolicy< diff --git a/test/test_device_radix_sort.cu b/test/test_device_radix_sort.cu index 698ee6b4ae..e2e2967423 100644 --- a/test/test_device_radix_sort.cu +++ b/test/test_device_radix_sort.cu @@ -35,6 +35,7 @@ #include #include +#include #include #include #include @@ -283,9 +284,11 @@ cudaError_t Dispatch( cudaStream_t stream, bool debug_synchronous) { + AssertTrue(num_items < std::numeric_limits::max()); + return DeviceSegmentedRadixSort::SortPairs( d_temp_storage, temp_storage_bytes, - d_keys, d_values, num_items, + d_keys, d_values, static_cast(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit, stream, debug_synchronous); } @@ -317,13 +320,15 @@ cudaError_t Dispatch( cudaStream_t stream, bool debug_synchronous) { + AssertTrue(num_items < std::numeric_limits::max()); + KeyT const *const_keys_itr = d_keys.Current(); ValueT const *const_values_itr = d_values.Current(); cudaError_t retval = DeviceSegmentedRadixSort::SortPairs( d_temp_storage, temp_storage_bytes, const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(), - num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets, + static_cast(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit, stream, debug_synchronous); d_keys.selector ^= 1; @@ -359,9 +364,11 @@ cudaError_t Dispatch( cudaStream_t stream, bool debug_synchronous) { + AssertTrue(num_items < std::numeric_limits::max()); + return DeviceSegmentedRadixSort::SortPairsDescending( d_temp_storage, temp_storage_bytes, - d_keys, d_values, num_items, + d_keys, d_values, static_cast(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit, stream, debug_synchronous); } @@ -393,13 +400,15 @@ cudaError_t Dispatch( cudaStream_t stream, bool debug_synchronous) { + AssertTrue(num_items < std::numeric_limits::max()); + KeyT const *const_keys_itr = d_keys.Current(); ValueT const *const_values_itr = d_values.Current(); cudaError_t retval = DeviceSegmentedRadixSort::SortPairsDescending( d_temp_storage, temp_storage_bytes, const_keys_itr, d_keys.Alternate(), const_values_itr, d_values.Alternate(), - num_items, num_segments, d_segment_begin_offsets, d_segment_end_offsets, + static_cast(num_items), num_segments, d_segment_begin_offsets, d_segment_end_offsets, begin_bit, end_bit, stream, debug_synchronous); d_keys.selector ^= 1; diff --git a/test/test_device_select_if.cu b/test/test_device_select_if.cu index c3cc1d8e2a..dc4290e233 100644 --- a/test/test_device_select_if.cu +++ b/test/test_device_select_if.cu @@ -41,6 +41,12 @@ #include #include +#include +#include +#include +#include +#include + #include "test_util.h" using namespace cub; @@ -652,6 +658,74 @@ void Test( } } +template +struct pair_to_col_t +{ + __host__ __device__ T0 operator()(const thrust::tuple &in) + { + return thrust::get<0>(in); + } +}; + +template +struct select_t { + __host__ __device__ bool operator()(const thrust::tuple &in) { + return static_cast(thrust::get<0>(in)) > thrust::get<1>(in); + } +}; + +template +void TestMixedOp(int num_items) +{ + const T0 target_value = static_cast(42); + thrust::device_vector col_a(num_items, target_value); + thrust::device_vector col_b(num_items, static_cast(4.2)); + + thrust::device_vector result(num_items); + + auto in = thrust::make_zip_iterator(col_a.begin(), col_b.begin()); + auto out = thrust::make_transform_output_iterator(result.begin(), pair_to_col_t{}); + + void *d_tmp_storage {}; + std::size_t tmp_storage_size{}; + cub::DeviceSelect::If( + d_tmp_storage, tmp_storage_size, + in, out, thrust::make_discard_iterator(), + num_items, select_t{}, + 0, true); + + thrust::device_vector tmp_storage(tmp_storage_size); + d_tmp_storage = thrust::raw_pointer_cast(tmp_storage.data()); + + cub::DeviceSelect::If( + d_tmp_storage, tmp_storage_size, + in, out, thrust::make_discard_iterator(), + num_items, select_t{}, + 0, true); + + AssertEquals(num_items, thrust::count(result.begin(), result.end(), target_value)); +} + +/** + * Test different input sizes + */ +template +void TestMixed(int num_items) +{ + if (num_items < 0) + { + TestMixedOp(0); + TestMixedOp(1); + TestMixedOp(100); + TestMixedOp(10000); + TestMixedOp(1000000); + } + else + { + TestMixedOp(num_items); + } +} + //--------------------------------------------------------------------- // Main //--------------------------------------------------------------------- @@ -708,6 +782,8 @@ int main(int argc, char** argv) Test(num_items); Test(num_items); + TestMixed(num_items); + return 0; } diff --git a/test/test_device_select_unique.cu b/test/test_device_select_unique.cu index 497b43fc36..c7f6679278 100644 --- a/test/test_device_select_unique.cu +++ b/test/test_device_select_unique.cu @@ -40,6 +40,10 @@ #include #include +#include +#include +#include + #include "test_util.h" using namespace cub; @@ -474,6 +478,57 @@ void Test( } } +template +void TestIteratorOp(int num_items) +{ + void *d_temp_storage{}; + std::size_t temp_storage_size{}; + + thrust::device_vector num_selected(1); + + auto in = thrust::make_counting_iterator(static_cast(0)); + auto out = thrust::make_discard_iterator(); + + CubDebugExit(cub::DeviceSelect::Unique(d_temp_storage, + temp_storage_size, + in, + out, + num_selected.begin(), + num_items, + 0, + true)); + + thrust::device_vector temp_storage(temp_storage_size); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + CubDebugExit(cub::DeviceSelect::Unique(d_temp_storage, + temp_storage_size, + in, + out, + num_selected.begin(), + num_items, + 0, + true)); + + AssertEquals(num_selected[0], num_items); +} + +template +void TestIterator(int num_items) +{ + if (num_items < 0) + { + TestIteratorOp(0); + TestIteratorOp(1); + TestIteratorOp(100); + TestIteratorOp(10000); + TestIteratorOp(1000000); + } + else + { + TestIteratorOp(num_items); + } +} //--------------------------------------------------------------------- @@ -536,6 +591,8 @@ int main(int argc, char** argv) Test(num_items); Test(num_items); + TestIterator(num_items); + return 0; }