Skip to content

Commit 0122db7

Browse files
Replace all uses of __conditional_t in CUB if _If
_If does not need to instantiate the type not selected.
1 parent 0afc149 commit 0122db7

36 files changed

+204
-164
lines changed

cub/benchmarks/bench/radix_sort/keys.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
#include <cub/device/device_radix_sort.cuh>
2929

30+
#include <cuda/std/type_traits>
31+
3032
#include <nvbench_helper.cuh>
3133

3234
// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
@@ -46,7 +48,7 @@ struct policy_hub_t
4648
{
4749
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4850

49-
using DominantT = ::cuda::std::__conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
51+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
5052

5153
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5254
{

cub/benchmarks/bench/radix_sort/pairs.cu

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
#include <cub/device/device_radix_sort.cuh>
2929

30+
#include <cuda/std/type_traits>
31+
3032
#include <nvbench_helper.cuh>
3133

3234
// %//RANGE//% TUNE_RADIX_BITS bits 8:9:1
@@ -44,7 +46,7 @@ struct policy_hub_t
4446
{
4547
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4648

47-
using DominantT = ::cuda::std::__conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
49+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
4850

4951
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5052
{

cub/cub/agent/agent_histogram.cuh

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include <cub/iterator/cache_modified_input_iterator.cuh>
5050
#include <cub/util_type.cuh>
5151

52+
#include <cuda/std/type_traits>
53+
5254
#include <iterator>
5355

5456
CUB_NAMESPACE_BEGIN

cub/cub/agent/agent_radix_sort_onesweep.cuh

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include <cub/util_ptx.cuh>
5050
#include <cub/util_type.cuh>
5151

52+
#include <cuda/std/type_traits>
53+
5254
CUB_NAMESPACE_BEGIN
5355

5456
/** \brief cub::RadixSortStoreAlgorithm enumerates different algorithms to write

cub/cub/agent/agent_reduce.cuh

+5-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252
#include <cub/util_type.cuh>
5353

54+
#include <cuda/std/type_traits>
55+
5456
#include <iterator>
5557

5658
_CCCL_SUPPRESS_DEPRECATED_PUSH
@@ -145,9 +147,9 @@ struct AgentReduce
145147
// Wrap the native input pointer with CacheModifiedInputIterator
146148
// or directly use the supplied input iterator type
147149
using WrappedInputIteratorT =
148-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
149-
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
150-
InputIteratorT>;
150+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
151+
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
152+
InputIteratorT>;
151153

152154
/// Constants
153155
static constexpr int BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS;

cub/cub/agent/agent_reduce_by_key.cuh

+13-11
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252
#include <cub/iterator/constant_input_iterator.cuh>
5353

54+
#include <cuda/std/type_traits>
55+
5456
#include <iterator>
5557

5658
CUB_NAMESPACE_BEGIN
@@ -225,27 +227,27 @@ struct AgentReduceByKey
225227
// CacheModifiedValuesInputIterator or directly use the supplied input
226228
// iterator type
227229
using WrappedKeysInputIteratorT =
228-
::cuda::std::__conditional_t<std::is_pointer<KeysInputIteratorT>::value,
229-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
230-
KeysInputIteratorT>;
230+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
231+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
232+
KeysInputIteratorT>;
231233

232234
// Cache-modified Input iterator wrapper type (for applying cache modifier)
233235
// for values Wrap the native input pointer with
234236
// CacheModifiedValuesInputIterator or directly use the supplied input
235237
// iterator type
236-
using WrappedValuesInputIteratorT = ::cuda::std::__conditional_t<
237-
std::is_pointer<ValuesInputIteratorT>::value,
238-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
239-
ValuesInputIteratorT>;
238+
using WrappedValuesInputIteratorT =
239+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
240+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
241+
ValuesInputIteratorT>;
240242

241243
// Cache-modified Input iterator wrapper type (for applying cache modifier)
242244
// for fixup values Wrap the native input pointer with
243245
// CacheModifiedValuesInputIterator or directly use the supplied input
244246
// iterator type
245-
using WrappedFixupInputIteratorT = ::cuda::std::__conditional_t<
246-
std::is_pointer<AggregatesOutputIteratorT>::value,
247-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
248-
AggregatesOutputIteratorT>;
247+
using WrappedFixupInputIteratorT =
248+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
249+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
250+
AggregatesOutputIteratorT>;
249251

250252
// Reduce-value-by-segment scan operator
251253
using ReduceBySegmentOpT = ReduceBySegmentOp<ReductionOpT>;

cub/cub/agent/agent_rle.cuh

+6-4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
#include <cub/iterator/cache_modified_input_iterator.cuh>
5555
#include <cub/iterator/constant_input_iterator.cuh>
5656

57+
#include <cuda/std/type_traits>
58+
5759
#include <iterator>
5860

5961
CUB_NAMESPACE_BEGIN
@@ -231,9 +233,9 @@ struct AgentRle
231233
// Wrap the native input pointer with CacheModifiedVLengthnputIterator
232234
// Directly use the supplied input iterator type
233235
using WrappedInputIteratorT =
234-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
235-
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
236-
InputIteratorT>;
236+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
237+
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
238+
InputIteratorT>;
237239

238240
// Parameterized BlockLoad type for data
239241
using BlockLoadT =
@@ -257,7 +259,7 @@ struct AgentRle
257259
using WarpExchangePairs = WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>;
258260

259261
using WarpExchangePairsStorage =
260-
::cuda::std::__conditional_t<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
262+
::cuda::std::_If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
261263

262264
using WarpExchangeOffsets = WarpExchange<OffsetT, ITEMS_PER_THREAD>;
263265
using WarpExchangeLengths = WarpExchange<LengthT, ITEMS_PER_THREAD>;

cub/cub/agent/agent_scan.cuh

+5-3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
#include <cub/grid/grid_queue.cuh>
5151
#include <cub/iterator/cache_modified_input_iterator.cuh>
5252

53+
#include <cuda/std/type_traits>
54+
5355
#include <iterator>
5456

5557
CUB_NAMESPACE_BEGIN
@@ -157,9 +159,9 @@ struct AgentScan
157159
// Wrap the native input pointer with CacheModifiedInputIterator
158160
// or directly use the supplied input iterator type
159161
using WrappedInputIteratorT =
160-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
InputIteratorT>;
162+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
163+
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
164+
InputIteratorT>;
163165

164166
// Constants
165167
enum

cub/cub/agent/agent_scan_by_key.cuh

+8-6
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
#include <cub/iterator/cache_modified_input_iterator.cuh>
5151
#include <cub/util_type.cuh>
5252

53+
#include <cuda/std/type_traits>
54+
5355
#include <iterator>
5456

5557
CUB_NAMESPACE_BEGIN
@@ -152,14 +154,14 @@ struct AgentScanByKey
152154
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
153155

154156
using WrappedKeysInputIteratorT =
155-
::cuda::std::__conditional_t<std::is_pointer<KeysInputIteratorT>::value,
156-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
157-
KeysInputIteratorT>;
157+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
158+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
159+
KeysInputIteratorT>;
158160

159161
using WrappedValuesInputIteratorT =
160-
::cuda::std::__conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
ValuesInputIteratorT>;
162+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
163+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
164+
ValuesInputIteratorT>;
163165

164166
using BlockLoadKeysT = BlockLoad<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
165167

cub/cub/agent/agent_segment_fixup.cuh

+9-7
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
#include <cub/iterator/cache_modified_input_iterator.cuh>
5353
#include <cub/iterator/constant_input_iterator.cuh>
5454

55+
#include <cuda/std/type_traits>
56+
5557
#include <iterator>
5658

5759
CUB_NAMESPACE_BEGIN
@@ -171,18 +173,18 @@ struct AgentSegmentFixup
171173
// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
172174
// Wrap the native input pointer with CacheModifiedValuesInputIterator
173175
// or directly use the supplied input iterator type
174-
using WrappedPairsInputIteratorT = ::cuda::std::__conditional_t<
175-
std::is_pointer<PairsInputIteratorT>::value,
176-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
177-
PairsInputIteratorT>;
176+
using WrappedPairsInputIteratorT =
177+
::cuda::std::_If<std::is_pointer<PairsInputIteratorT>::value,
178+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
179+
PairsInputIteratorT>;
178180

179181
// Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values
180182
// Wrap the native input pointer with CacheModifiedValuesInputIterator
181183
// or directly use the supplied input iterator type
182184
using WrappedFixupInputIteratorT =
183-
::cuda::std::__conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
184-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
185-
AggregatesOutputIteratorT>;
185+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
186+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
187+
AggregatesOutputIteratorT>;
186188

187189
// Reduce-value-by-segment scan operator
188190
using ReduceBySegmentOpT = ReduceByKeyOp<cub::Sum>;

cub/cub/agent/agent_select_if.cuh

+6-6
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,17 @@ struct AgentSelectIf
219219
// Wrap the native input pointer with CacheModifiedValuesInputIterator
220220
// or directly use the supplied input iterator type
221221
using WrappedInputIteratorT =
222-
::cuda::std::__conditional_t<::cuda::std::is_pointer<InputIteratorT>::value,
223-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224-
InputIteratorT>;
222+
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
223+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224+
InputIteratorT>;
225225

226226
// Cache-modified Input iterator wrapper type (for applying cache modifier) for values
227227
// Wrap the native input pointer with CacheModifiedValuesInputIterator
228228
// or directly use the supplied input iterator type
229229
using WrappedFlagsInputIteratorT =
230-
::cuda::std::__conditional_t<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232-
FlagsInputIteratorT>;
230+
::cuda::std::_If<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232+
FlagsInputIteratorT>;
233233

234234
// Parameterized BlockLoad type for input data
235235
using BlockLoadT = BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentSelectIfPolicyT::LOAD_ALGORITHM>;

cub/cub/agent/agent_spmv_orig.cuh

+3-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
#include <cub/thread/thread_search.cuh>
5353
#include <cub/util_type.cuh>
5454

55+
#include <cuda/std/type_traits>
56+
5557
#include <iterator>
5658

5759
CUB_NAMESPACE_BEGIN
@@ -264,7 +266,7 @@ struct AgentSpmv
264266
{
265267
// Value type to pair with index type OffsetT
266268
// (NullType if loading values directly during merge)
267-
using MergeValueT = ::cuda::std::__conditional_t<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
269+
using MergeValueT = ::cuda::std::_If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
268270

269271
OffsetT row_end_offset;
270272
MergeValueT nonzero;

cub/cub/agent/agent_three_way_partition.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ struct AgentThreeWayPartition
197197
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
198198

199199
using WrappedInputIteratorT =
200-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
201-
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202-
InputIteratorT>;
200+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
201+
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202+
InputIteratorT>;
203203

204204
// Parameterized BlockLoad type for input data
205205
using BlockLoadT = cub::BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, PolicyT::LOAD_ALGORITHM>;

cub/cub/agent/single_pass_scan_operators.cuh

+15-19
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
#include <cub/util_temporary_storage.cuh>
5252
#include <cub/warp/warp_reduce.cuh>
5353

54+
#include <cuda/std/type_traits>
55+
5456
#include <iterator>
5557

5658
#include <nv/target>
@@ -476,16 +478,16 @@ using default_no_delay_t = default_no_delay_constructor_t::delay_t;
476478

477479
template <class T>
478480
using default_delay_constructor_t =
479-
::cuda::std::__conditional_t<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
481+
::cuda::std::_If<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
480482

481483
template <class T>
482484
using default_delay_t = typename default_delay_constructor_t<T>::delay_t;
483485

484486
template <class KeyT, class ValueT>
485487
using default_reduce_by_key_delay_constructor_t =
486-
::cuda::std::__conditional_t<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
487-
reduce_by_key_delay_constructor_t<350, 450>,
488-
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;
488+
::cuda::std::_If<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
489+
reduce_by_key_delay_constructor_t<350, 450>,
490+
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;
489491
} // namespace detail
490492

491493
/**
@@ -503,16 +505,13 @@ template <typename T>
503505
struct ScanTileState<T, true>
504506
{
505507
// Status word type
506-
using StatusWord = ::cuda::std::__conditional_t<
508+
using StatusWord = ::cuda::std::_If<
507509
sizeof(T) == 8,
508510
unsigned long long,
509-
::cuda::std::__conditional_t<sizeof(T) == 4,
510-
unsigned int,
511-
::cuda::std::__conditional_t<sizeof(T) == 2, unsigned short, unsigned char>>>;
511+
::cuda::std::_If<sizeof(T) == 4, unsigned int, ::cuda::std::_If<sizeof(T) == 2, unsigned short, unsigned char>>>;
512512

513513
// Unit word type
514-
using TxnWord = ::cuda::std::
515-
__conditional_t<sizeof(T) == 8, ulonglong2, ::cuda::std::__conditional_t<sizeof(T) == 4, uint2, unsigned int>>;
514+
using TxnWord = ::cuda::std::_If<sizeof(T) == 8, ulonglong2, ::cuda::std::_If<sizeof(T) == 4, uint2, unsigned int>>;
516515

517516
// Device word type
518517
struct TileDescriptor
@@ -889,18 +888,15 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
889888
};
890889

891890
// Status word type
892-
using StatusWord = ::cuda::std::__conditional_t<
891+
using StatusWord = ::cuda::std::_If<
893892
STATUS_WORD_SIZE == 8,
894893
unsigned long long,
895-
::cuda::std::__conditional_t<STATUS_WORD_SIZE == 4,
896-
unsigned int,
897-
::cuda::std::__conditional_t<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;
894+
::cuda::std::
895+
_If<STATUS_WORD_SIZE == 4, unsigned int, ::cuda::std::_If<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;
898896

899897
// Status word type
900-
using TxnWord =
901-
::cuda::std::__conditional_t<TXN_WORD_SIZE == 16,
902-
ulonglong2,
903-
::cuda::std::__conditional_t<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;
898+
using TxnWord = ::cuda::std::
899+
_If<TXN_WORD_SIZE == 16, ulonglong2, ::cuda::std::_If<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;
904900

905901
// Device word type (for when sizeof(ValueT) == sizeof(KeyT))
906902
struct TileDescriptorBigStatus
@@ -920,7 +916,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
920916

921917
// Device word type
922918
using TileDescriptor =
923-
::cuda::std::__conditional_t<sizeof(ValueT) == sizeof(KeyT), TileDescriptorBigStatus, TileDescriptorLittleStatus>;
919+
::cuda::std::_If<sizeof(ValueT) == sizeof(KeyT), TileDescriptorBigStatus, TileDescriptorLittleStatus>;
924920

925921
// Device storage
926922
TxnWord* d_tile_descriptors;

0 commit comments

Comments
 (0)