Skip to content

Commit 327420b

Browse files
Cleanup CUB util_type.cuh (#1863)
* Cleanup CUB util_type.cuh * Replace all uses of __conditional_t in CUB if _If, which does not need to instantiate the type not selected.
1 parent 91b78d8 commit 327420b

40 files changed

+289
-321
lines changed

cub/benchmarks/bench/radix_sort/keys.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
#include <cub/device/device_radix_sort.cuh>
2929

30+
#include <cuda/std/type_traits>
31+
3032
#include <nvbench_helper.cuh>
3133

3234
// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
@@ -46,7 +48,7 @@ struct policy_hub_t
4648
{
4749
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4850

49-
using DominantT = cub::detail::conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
51+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
5052

5153
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5254
{

cub/benchmarks/bench/radix_sort/pairs.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
#include <cub/device/device_radix_sort.cuh>
2929

30+
#include <cuda/std/type_traits>
31+
3032
#include <nvbench_helper.cuh>
3133

3234
// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
@@ -44,7 +46,7 @@ struct policy_hub_t
4446
{
4547
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4648

47-
using DominantT = cub::detail::conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
49+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
4850

4951
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5052
{

cub/cub/agent/agent_histogram.cuh

+5-3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include <cub/iterator/cache_modified_input_iterator.cuh>
5050
#include <cub/util_type.cuh>
5151

52+
#include <cuda/std/type_traits>
53+
5254
#include <iterator>
5355

5456
CUB_NAMESPACE_BEGIN
@@ -225,9 +227,9 @@ struct AgentHistogram
225227
// Wrap the native input pointer with CacheModifiedInputIterator
226228
// or directly use the supplied input iterator type
227229
using WrappedSampleIteratorT =
228-
cub::detail::conditional_t<std::is_pointer<SampleIteratorT>::value,
229-
CacheModifiedInputIterator<LOAD_MODIFIER, SampleT, OffsetT>,
230-
SampleIteratorT>;
230+
::cuda::std::_If<std::is_pointer<SampleIteratorT>::value,
231+
CacheModifiedInputIterator<LOAD_MODIFIER, SampleT, OffsetT>,
232+
SampleIteratorT>;
231233

232234
/// Pixel input iterator type (for applying cache modifier)
233235
using WrappedPixelIteratorT = CacheModifiedInputIterator<LOAD_MODIFIER, PixelT, OffsetT>;

cub/cub/agent/agent_radix_sort_onesweep.cuh

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include <cub/util_ptx.cuh>
5050
#include <cub/util_type.cuh>
5151

52+
#include <cuda/std/type_traits>
53+
5254
CUB_NAMESPACE_BEGIN
5355

5456
/** \brief cub::RadixSortStoreAlgorithm enumerates different algorithms to write
@@ -146,10 +148,10 @@ struct AgentRadixSortOnesweep
146148
|| RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
147149
"for onesweep agent, the ranking algorithm must warp-strided key arrangement");
148150

149-
using BlockRadixRankT = cub::detail::conditional_t<
151+
using BlockRadixRankT = ::cuda::std::_If<
150152
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
151153
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR, RANK_NUM_PARTS>,
152-
cub::detail::conditional_t<
154+
::cuda::std::_If<
153155
RANK_ALGORITHM == RADIX_RANK_MATCH,
154156
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM>,
155157
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, false, SCAN_ALGORITHM, WARP_MATCH_ANY, RANK_NUM_PARTS>>>;

cub/cub/agent/agent_reduce.cuh

+5-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252
#include <cub/util_type.cuh>
5353

54+
#include <cuda/std/type_traits>
55+
5456
#include <iterator>
5557

5658
_CCCL_SUPPRESS_DEPRECATED_PUSH
@@ -145,9 +147,9 @@ struct AgentReduce
145147
// Wrap the native input pointer with CacheModifiedInputIterator
146148
// or directly use the supplied input iterator type
147149
using WrappedInputIteratorT =
148-
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
149-
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
150-
InputIteratorT>;
150+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
151+
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
152+
InputIteratorT>;
151153

152154
/// Constants
153155
static constexpr int BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS;

cub/cub/agent/agent_reduce_by_key.cuh

+11-9
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252
#include <cub/iterator/constant_input_iterator.cuh>
5353

54+
#include <cuda/std/type_traits>
55+
5456
#include <iterator>
5557

5658
CUB_NAMESPACE_BEGIN
@@ -225,27 +227,27 @@ struct AgentReduceByKey
225227
// CacheModifiedValuesInputIterator or directly use the supplied input
226228
// iterator type
227229
using WrappedKeysInputIteratorT =
228-
cub::detail::conditional_t<std::is_pointer<KeysInputIteratorT>::value,
229-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
230-
KeysInputIteratorT>;
230+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
231+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
232+
KeysInputIteratorT>;
231233

232234
// Cache-modified Input iterator wrapper type (for applying cache modifier)
233235
// for values Wrap the native input pointer with
234236
// CacheModifiedValuesInputIterator or directly use the supplied input
235237
// iterator type
236238
using WrappedValuesInputIteratorT =
237-
cub::detail::conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
238-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
239-
ValuesInputIteratorT>;
239+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
240+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
241+
ValuesInputIteratorT>;
240242

241243
// Cache-modified Input iterator wrapper type (for applying cache modifier)
242244
// for fixup values Wrap the native input pointer with
243245
// CacheModifiedValuesInputIterator or directly use the supplied input
244246
// iterator type
245247
using WrappedFixupInputIteratorT =
246-
cub::detail::conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
247-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
248-
AggregatesOutputIteratorT>;
248+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
249+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
250+
AggregatesOutputIteratorT>;
249251

250252
// Reduce-value-by-segment scan operator
251253
using ReduceBySegmentOpT = ReduceBySegmentOp<ReductionOpT>;

cub/cub/agent/agent_rle.cuh

+6-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
#include <cub/iterator/cache_modified_input_iterator.cuh>
5555
#include <cub/iterator/constant_input_iterator.cuh>
5656

57+
#include <cuda/std/type_traits>
58+
5759
#include <iterator>
5860

5961
CUB_NAMESPACE_BEGIN
@@ -231,9 +233,9 @@ struct AgentRle
231233
// Wrap the native input pointer with CacheModifiedVLengthnputIterator
232234
// Directly use the supplied input iterator type
233235
using WrappedInputIteratorT =
234-
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
235-
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
236-
InputIteratorT>;
236+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
237+
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
238+
InputIteratorT>;
237239

238240
// Parameterized BlockLoad type for data
239241
using BlockLoadT =
@@ -257,7 +259,7 @@ struct AgentRle
257259
using WarpExchangePairs = WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>;
258260

259261
using WarpExchangePairsStorage =
260-
cub::detail::conditional_t<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
262+
::cuda::std::_If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
261263

262264
using WarpExchangeOffsets = WarpExchange<OffsetT, ITEMS_PER_THREAD>;
263265
using WarpExchangeLengths = WarpExchange<LengthT, ITEMS_PER_THREAD>;

cub/cub/agent/agent_scan.cuh

+5-3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
#include <cub/grid/grid_queue.cuh>
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252

53+
#include <cuda/std/type_traits>
54+
5355
#include <iterator>
5456

5557
CUB_NAMESPACE_BEGIN
@@ -157,9 +159,9 @@ struct AgentScan
157159
// Wrap the native input pointer with CacheModifiedInputIterator
158160
// or directly use the supplied input iterator type
159161
using WrappedInputIteratorT =
160-
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
InputIteratorT>;
162+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
163+
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
164+
InputIteratorT>;
163165

164166
// Constants
165167
enum

cub/cub/agent/agent_scan_by_key.cuh

+8-6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
#include <cub/iterator/cache_modified_input_iterator.cuh>
5151
#include <cub/util_type.cuh>
5252

53+
#include <cuda/std/type_traits>
54+
5355
#include <iterator>
5456

5557
CUB_NAMESPACE_BEGIN
@@ -152,14 +154,14 @@ struct AgentScanByKey
152154
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
153155

154156
using WrappedKeysInputIteratorT =
155-
cub::detail::conditional_t<std::is_pointer<KeysInputIteratorT>::value,
156-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
157-
KeysInputIteratorT>;
157+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
158+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
159+
KeysInputIteratorT>;
158160

159161
using WrappedValuesInputIteratorT =
160-
cub::detail::conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
ValuesInputIteratorT>;
162+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
163+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
164+
ValuesInputIteratorT>;
163165

164166
using BlockLoadKeysT = BlockLoad<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
165167

cub/cub/agent/agent_segment_fixup.cuh

+9-7
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
#include <cub/iterator/cache_modified_input_iterator.cuh>
5353
#include <cub/iterator/constant_input_iterator.cuh>
5454

55+
#include <cuda/std/type_traits>
56+
5557
#include <iterator>
5658

5759
CUB_NAMESPACE_BEGIN
@@ -171,18 +173,18 @@ struct AgentSegmentFixup
171173
// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
172174
// Wrap the native input pointer with CacheModifiedValuesInputIterator
173175
// or directly use the supplied input iterator type
174-
using WrappedPairsInputIteratorT = cub::detail::conditional_t<
175-
std::is_pointer<PairsInputIteratorT>::value,
176-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
177-
PairsInputIteratorT>;
176+
using WrappedPairsInputIteratorT =
177+
::cuda::std::_If<std::is_pointer<PairsInputIteratorT>::value,
178+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
179+
PairsInputIteratorT>;
178180

179181
// Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values
180182
// Wrap the native input pointer with CacheModifiedValuesInputIterator
181183
// or directly use the supplied input iterator type
182184
using WrappedFixupInputIteratorT =
183-
cub::detail::conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
184-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
185-
AggregatesOutputIteratorT>;
185+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
186+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
187+
AggregatesOutputIteratorT>;
186188

187189
// Reduce-value-by-segment scan operator
188190
using ReduceBySegmentOpT = ReduceByKeyOp<cub::Sum>;

cub/cub/agent/agent_select_if.cuh

+6-6
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,17 @@ struct AgentSelectIf
219219
// Wrap the native input pointer with CacheModifiedValuesInputIterator
220220
// or directly use the supplied input iterator type
221221
using WrappedInputIteratorT =
222-
cub::detail::conditional_t<::cuda::std::is_pointer<InputIteratorT>::value,
223-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224-
InputIteratorT>;
222+
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
223+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224+
InputIteratorT>;
225225

226226
// Cache-modified Input iterator wrapper type (for applying cache modifier) for values
227227
// Wrap the native input pointer with CacheModifiedValuesInputIterator
228228
// or directly use the supplied input iterator type
229229
using WrappedFlagsInputIteratorT =
230-
cub::detail::conditional_t<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232-
FlagsInputIteratorT>;
230+
::cuda::std::_If<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232+
FlagsInputIteratorT>;
233233

234234
// Parameterized BlockLoad type for input data
235235
using BlockLoadT = BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentSelectIfPolicyT::LOAD_ALGORITHM>;

cub/cub/agent/agent_spmv_orig.cuh

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
#include <cub/thread/thread_search.cuh>
5353
#include <cub/util_type.cuh>
5454

55+
#include <cuda/std/type_traits>
56+
5557
#include <iterator>
5658

5759
CUB_NAMESPACE_BEGIN
@@ -264,7 +266,7 @@ struct AgentSpmv
264266
{
265267
// Value type to pair with index type OffsetT
266268
// (NullType if loading values directly during merge)
267-
using MergeValueT = cub::detail::conditional_t<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
269+
using MergeValueT = ::cuda::std::_If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
268270

269271
OffsetT row_end_offset;
270272
MergeValueT nonzero;

cub/cub/agent/agent_three_way_partition.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ struct AgentThreeWayPartition
197197
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
198198

199199
using WrappedInputIteratorT =
200-
cub::detail::conditional_t<std::is_pointer<InputIteratorT>::value,
201-
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202-
InputIteratorT>;
200+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
201+
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202+
InputIteratorT>;
203203

204204
// Parameterized BlockLoad type for input data
205205
using BlockLoadT = cub::BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, PolicyT::LOAD_ALGORITHM>;

0 commit comments

Comments
 (0)