Skip to content

Commit 7fc7838

Browse files
Replace all uses of __conditional_t in CUB with _If
_If does not need to instantiate the type not selected.
1 parent 800b5df commit 7fc7838

34 files changed

+149
-164
lines changed

cub/benchmarks/bench/radix_sort/keys.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct policy_hub_t
4646
{
4747
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4848

49-
using DominantT = ::cuda::std::__conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
49+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
5050

5151
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5252
{

cub/benchmarks/bench/radix_sort/pairs.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct policy_hub_t
4444
{
4545
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;
4646

47-
using DominantT = ::cuda::std::__conditional_t<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
47+
using DominantT = ::cuda::std::_If<(sizeof(ValueT) > sizeof(KeyT)), ValueT, KeyT>;
4848

4949
struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t>
5050
{

cub/cub/agent/agent_reduce.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ struct AgentReduce
145145
// Wrap the native input pointer with CacheModifiedInputIterator
146146
// or directly use the supplied input iterator type
147147
using WrappedInputIteratorT =
148-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
149-
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
150-
InputIteratorT>;
148+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
149+
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
150+
InputIteratorT>;
151151

152152
/// Constants
153153
static constexpr int BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS;

cub/cub/agent/agent_reduce_by_key.cuh

+11-11
Original file line numberDiff line numberDiff line change
@@ -225,27 +225,27 @@ struct AgentReduceByKey
225225
// CacheModifiedValuesInputIterator or directly use the supplied input
226226
// iterator type
227227
using WrappedKeysInputIteratorT =
228-
::cuda::std::__conditional_t<std::is_pointer<KeysInputIteratorT>::value,
229-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
230-
KeysInputIteratorT>;
228+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
229+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
230+
KeysInputIteratorT>;
231231

232232
// Cache-modified Input iterator wrapper type (for applying cache modifier)
233233
// for values Wrap the native input pointer with
234234
// CacheModifiedValuesInputIterator or directly use the supplied input
235235
// iterator type
236-
using WrappedValuesInputIteratorT = ::cuda::std::__conditional_t<
237-
std::is_pointer<ValuesInputIteratorT>::value,
238-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
239-
ValuesInputIteratorT>;
236+
using WrappedValuesInputIteratorT =
237+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
238+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
239+
ValuesInputIteratorT>;
240240

241241
// Cache-modified Input iterator wrapper type (for applying cache modifier)
242242
// for fixup values Wrap the native input pointer with
243243
// CacheModifiedValuesInputIterator or directly use the supplied input
244244
// iterator type
245-
using WrappedFixupInputIteratorT = ::cuda::std::__conditional_t<
246-
std::is_pointer<AggregatesOutputIteratorT>::value,
247-
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
248-
AggregatesOutputIteratorT>;
245+
using WrappedFixupInputIteratorT =
246+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
247+
CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
248+
AggregatesOutputIteratorT>;
249249

250250
// Reduce-value-by-segment scan operator
251251
using ReduceBySegmentOpT = ReduceBySegmentOp<ReductionOpT>;

cub/cub/agent/agent_rle.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ struct AgentRle
231231
// Wrap the native input pointer with CacheModifiedInputIterator
232232
// Directly use the supplied input iterator type
233233
using WrappedInputIteratorT =
234-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
235-
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
236-
InputIteratorT>;
234+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
235+
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,
236+
InputIteratorT>;
237237

238238
// Parameterized BlockLoad type for data
239239
using BlockLoadT =
@@ -257,7 +257,7 @@ struct AgentRle
257257
using WarpExchangePairs = WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>;
258258

259259
using WarpExchangePairsStorage =
260-
::cuda::std::__conditional_t<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
260+
::cuda::std::_If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>;
261261

262262
using WarpExchangeOffsets = WarpExchange<OffsetT, ITEMS_PER_THREAD>;
263263
using WarpExchangeLengths = WarpExchange<LengthT, ITEMS_PER_THREAD>;

cub/cub/agent/agent_scan.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ struct AgentScan
157157
// Wrap the native input pointer with CacheModifiedInputIterator
158158
// or directly use the supplied input iterator type
159159
using WrappedInputIteratorT =
160-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
InputIteratorT>;
160+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
161+
CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162+
InputIteratorT>;
163163

164164
// Constants
165165
enum

cub/cub/agent/agent_scan_by_key.cuh

+6-6
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,14 @@ struct AgentScanByKey
152152
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
153153

154154
using WrappedKeysInputIteratorT =
155-
::cuda::std::__conditional_t<std::is_pointer<KeysInputIteratorT>::value,
156-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
157-
KeysInputIteratorT>;
155+
::cuda::std::_If<std::is_pointer<KeysInputIteratorT>::value,
156+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>,
157+
KeysInputIteratorT>;
158158

159159
using WrappedValuesInputIteratorT =
160-
::cuda::std::__conditional_t<std::is_pointer<ValuesInputIteratorT>::value,
161-
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162-
ValuesInputIteratorT>;
160+
::cuda::std::_If<std::is_pointer<ValuesInputIteratorT>::value,
161+
CacheModifiedInputIterator<AgentScanByKeyPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
162+
ValuesInputIteratorT>;
163163

164164
using BlockLoadKeysT = BlockLoad<KeyT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentScanByKeyPolicyT::LOAD_ALGORITHM>;
165165

cub/cub/agent/agent_segment_fixup.cuh

+7-7
Original file line numberDiff line numberDiff line change
@@ -171,18 +171,18 @@ struct AgentSegmentFixup
171171
// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
172172
// Wrap the native input pointer with CacheModifiedValuesInputIterator
173173
// or directly use the supplied input iterator type
174-
using WrappedPairsInputIteratorT = ::cuda::std::__conditional_t<
175-
std::is_pointer<PairsInputIteratorT>::value,
176-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
177-
PairsInputIteratorT>;
174+
using WrappedPairsInputIteratorT =
175+
::cuda::std::_If<std::is_pointer<PairsInputIteratorT>::value,
176+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,
177+
PairsInputIteratorT>;
178178

179179
// Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values
180180
// Wrap the native input pointer with CacheModifiedValuesInputIterator
181181
// or directly use the supplied input iterator type
182182
using WrappedFixupInputIteratorT =
183-
::cuda::std::__conditional_t<std::is_pointer<AggregatesOutputIteratorT>::value,
184-
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
185-
AggregatesOutputIteratorT>;
183+
::cuda::std::_If<std::is_pointer<AggregatesOutputIteratorT>::value,
184+
CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,
185+
AggregatesOutputIteratorT>;
186186

187187
// Reduce-value-by-segment scan operator
188188
using ReduceBySegmentOpT = ReduceByKeyOp<cub::Sum>;

cub/cub/agent/agent_select_if.cuh

+6-6
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,17 @@ struct AgentSelectIf
219219
// Wrap the native input pointer with CacheModifiedValuesInputIterator
220220
// or directly use the supplied input iterator type
221221
using WrappedInputIteratorT =
222-
::cuda::std::__conditional_t<::cuda::std::is_pointer<InputIteratorT>::value,
223-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224-
InputIteratorT>;
222+
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
223+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,
224+
InputIteratorT>;
225225

226226
// Cache-modified Input iterator wrapper type (for applying cache modifier) for values
227227
// Wrap the native input pointer with CacheModifiedValuesInputIterator
228228
// or directly use the supplied input iterator type
229229
using WrappedFlagsInputIteratorT =
230-
::cuda::std::__conditional_t<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231-
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232-
FlagsInputIteratorT>;
230+
::cuda::std::_If<::cuda::std::is_pointer<FlagsInputIteratorT>::value,
231+
CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,
232+
FlagsInputIteratorT>;
233233

234234
// Parameterized BlockLoad type for input data
235235
using BlockLoadT = BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, AgentSelectIfPolicyT::LOAD_ALGORITHM>;

cub/cub/agent/agent_spmv_orig.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ struct AgentSpmv
264264
{
265265
// Value type to pair with index type OffsetT
266266
// (NullType if loading values directly during merge)
267-
using MergeValueT = ::cuda::std::__conditional_t<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
267+
using MergeValueT = ::cuda::std::_If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>;
268268

269269
OffsetT row_end_offset;
270270
MergeValueT nonzero;

cub/cub/agent/agent_three_way_partition.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ struct AgentThreeWayPartition
197197
static constexpr int TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD;
198198

199199
using WrappedInputIteratorT =
200-
::cuda::std::__conditional_t<std::is_pointer<InputIteratorT>::value,
201-
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202-
InputIteratorT>;
200+
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
201+
cub::CacheModifiedInputIterator<PolicyT::LOAD_MODIFIER, InputT, OffsetT>,
202+
InputIteratorT>;
203203

204204
// Parameterized BlockLoad type for input data
205205
using BlockLoadT = cub::BlockLoad<InputT, BLOCK_THREADS, ITEMS_PER_THREAD, PolicyT::LOAD_ALGORITHM>;

cub/cub/agent/single_pass_scan_operators.cuh

+13-19
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,16 @@ using default_no_delay_t = default_no_delay_constructor_t::delay_t;
476476

477477
template <class T>
478478
using default_delay_constructor_t =
479-
::cuda::std::__conditional_t<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
479+
::cuda::std::_If<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
480480

481481
template <class T>
482482
using default_delay_t = typename default_delay_constructor_t<T>::delay_t;
483483

484484
template <class KeyT, class ValueT>
485485
using default_reduce_by_key_delay_constructor_t =
486-
::cuda::std::__conditional_t<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
487-
reduce_by_key_delay_constructor_t<350, 450>,
488-
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;
486+
::cuda::std::_If<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
487+
reduce_by_key_delay_constructor_t<350, 450>,
488+
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;
489489
} // namespace detail
490490

491491
/**
@@ -503,16 +503,13 @@ template <typename T>
503503
struct ScanTileState<T, true>
504504
{
505505
// Status word type
506-
using StatusWord = ::cuda::std::__conditional_t<
506+
using StatusWord = ::cuda::std::_If<
507507
sizeof(T) == 8,
508508
unsigned long long,
509-
::cuda::std::__conditional_t<sizeof(T) == 4,
510-
unsigned int,
511-
::cuda::std::__conditional_t<sizeof(T) == 2, unsigned short, unsigned char>>>;
509+
::cuda::std::_If<sizeof(T) == 4, unsigned int, ::cuda::std::_If<sizeof(T) == 2, unsigned short, unsigned char>>>;
512510

513511
// Unit word type
514-
using TxnWord = ::cuda::std::
515-
__conditional_t<sizeof(T) == 8, ulonglong2, ::cuda::std::__conditional_t<sizeof(T) == 4, uint2, unsigned int>>;
512+
using TxnWord = ::cuda::std::_If<sizeof(T) == 8, ulonglong2, ::cuda::std::_If<sizeof(T) == 4, uint2, unsigned int>>;
516513

517514
// Device word type
518515
struct TileDescriptor
@@ -889,18 +886,15 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
889886
};
890887

891888
// Status word type
892-
using StatusWord = ::cuda::std::__conditional_t<
889+
using StatusWord = ::cuda::std::_If<
893890
STATUS_WORD_SIZE == 8,
894891
unsigned long long,
895-
::cuda::std::__conditional_t<STATUS_WORD_SIZE == 4,
896-
unsigned int,
897-
::cuda::std::__conditional_t<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;
892+
::cuda::std::
893+
_If<STATUS_WORD_SIZE == 4, unsigned int, ::cuda::std::_If<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;
898894

899895
// Unit word type
900-
using TxnWord =
901-
::cuda::std::__conditional_t<TXN_WORD_SIZE == 16,
902-
ulonglong2,
903-
::cuda::std::__conditional_t<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;
896+
using TxnWord = ::cuda::std::
897+
_If<TXN_WORD_SIZE == 16, ulonglong2, ::cuda::std::_If<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;
904898

905899
// Device word type (for when sizeof(ValueT) == sizeof(KeyT))
906900
struct TileDescriptorBigStatus
@@ -920,7 +914,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
920914

921915
// Device word type
922916
using TileDescriptor =
923-
::cuda::std::__conditional_t<sizeof(ValueT) == sizeof(KeyT), TileDescriptorBigStatus, TileDescriptorLittleStatus>;
917+
::cuda::std::_If<sizeof(ValueT) == sizeof(KeyT), TileDescriptorBigStatus, TileDescriptorLittleStatus>;
924918

925919
// Device storage
926920
TxnWord* d_tile_descriptors;

cub/cub/block/block_histogram.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ private:
199199

200200
/// Internal specialization.
201201
using InternalBlockHistogram =
202-
::cuda::std::__conditional_t<ALGORITHM == BLOCK_HISTO_SORT,
203-
BlockHistogramSort<T, BLOCK_DIM_X, ITEMS_PER_THREAD, BINS, BLOCK_DIM_Y, BLOCK_DIM_Z>,
204-
BlockHistogramAtomic<BINS>>;
202+
::cuda::std::_If<ALGORITHM == BLOCK_HISTO_SORT,
203+
BlockHistogramSort<T, BLOCK_DIM_X, ITEMS_PER_THREAD, BINS, BLOCK_DIM_Y, BLOCK_DIM_Z>,
204+
BlockHistogramAtomic<BINS>>;
205205

206206
/// Shared memory storage layout type for BlockHistogram
207207
typedef typename InternalBlockHistogram::TempStorage _TempStorage;

cub/cub/block/block_radix_rank.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ private:
221221

222222
// Integer type for packing DigitCounters into columns of shared memory banks
223223
using PackedCounter =
224-
::cuda::std::__conditional_t<SMEM_CONFIG == cudaSharedMemBankSizeEightByte, unsigned long long, unsigned int>;
224+
::cuda::std::_If<SMEM_CONFIG == cudaSharedMemBankSizeEightByte, unsigned long long, unsigned int>;
225225

226226
static constexpr DigitCounter max_tile_size = ::cuda::std::numeric_limits<DigitCounter>::max();
227227

@@ -1195,16 +1195,16 @@ namespace detail
11951195
// - Support multi-dimensional thread blocks in the rest of implementations
11961196
// - Repurpose BlockRadixRank as an entry name with the algorithm template parameter
11971197
template <RadixRankAlgorithm RankAlgorithm, int BlockDimX, int RadixBits, bool IsDescending, BlockScanAlgorithm ScanAlgorithm>
1198-
using block_radix_rank_t = ::cuda::std::__conditional_t<
1198+
using block_radix_rank_t = ::cuda::std::_If<
11991199
RankAlgorithm == RADIX_RANK_BASIC,
12001200
BlockRadixRank<BlockDimX, RadixBits, IsDescending, false, ScanAlgorithm>,
1201-
::cuda::std::__conditional_t<
1201+
::cuda::std::_If<
12021202
RankAlgorithm == RADIX_RANK_MEMOIZE,
12031203
BlockRadixRank<BlockDimX, RadixBits, IsDescending, true, ScanAlgorithm>,
1204-
::cuda::std::__conditional_t<
1204+
::cuda::std::_If<
12051205
RankAlgorithm == RADIX_RANK_MATCH,
12061206
BlockRadixRankMatch<BlockDimX, RadixBits, IsDescending, ScanAlgorithm>,
1207-
::cuda::std::__conditional_t<
1207+
::cuda::std::_If<
12081208
RankAlgorithm == RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
12091209
BlockRadixRankMatchEarlyCounts<BlockDimX, RadixBits, IsDescending, ScanAlgorithm, WARP_MATCH_ANY>,
12101210
BlockRadixRankMatchEarlyCounts<BlockDimX, RadixBits, IsDescending, ScanAlgorithm, WARP_MATCH_ATOMIC_OR>>>>>;

cub/cub/block/block_reduce.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,11 @@ private:
253253

254254
/// Internal specialization type
255255
using InternalBlockReduce =
256-
::cuda::std::__conditional_t<ALGORITHM == BLOCK_REDUCE_WARP_REDUCTIONS,
257-
WarpReductions,
258-
::cuda::std::__conditional_t<ALGORITHM == BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY,
259-
RakingCommutativeOnly,
260-
Raking>>; // BlockReduceRaking
256+
::cuda::std::_If<ALGORITHM == BLOCK_REDUCE_WARP_REDUCTIONS,
257+
WarpReductions,
258+
::cuda::std::_If<ALGORITHM == BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY,
259+
RakingCommutativeOnly,
260+
Raking>>; // BlockReduceRaking
261261

262262
/// Shared memory storage layout type for BlockReduce
263263
typedef typename InternalBlockReduce::TempStorage _TempStorage;

cub/cub/block/block_scan.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ private:
252252
Raking;
253253

254254
/// Define the delegate type for the desired algorithm
255-
using InternalBlockScan = ::cuda::std::__conditional_t<SAFE_ALGORITHM == BLOCK_SCAN_WARP_SCANS, WarpScans, Raking>;
255+
using InternalBlockScan = ::cuda::std::_If<SAFE_ALGORITHM == BLOCK_SCAN_WARP_SCANS, WarpScans, Raking>;
256256

257257
/// Shared memory storage layout type for BlockScan
258258
typedef typename InternalBlockScan::TempStorage _TempStorage;

cub/cub/device/dispatch/dispatch_batch_memcpy.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,11 @@ struct DispatchBatchMemcpy : SelectedPolicy
454454
// The number of thread blocks (or tiles) required to process all of the given buffers
455455
BlockOffsetT num_tiles = DivideAndRoundUp(num_buffers, TILE_SIZE);
456456

457-
using BlevBufferSrcsOutT = ::cuda::std::__conditional_t<IsMemcpy, void*, cub::detail::value_t<InputBufferIt>>;
458-
using BlevBufferDstOutT = ::cuda::std::__conditional_t<IsMemcpy, void*, cub::detail::value_t<OutputBufferIt>>;
459-
using BlevBufferSrcsOutItT = BlevBufferSrcsOutT*;
460-
using BlevBufferDstsOutItT = BlevBufferDstOutT*;
461-
using BlevBufferSizesOutItT = BufferSizeT*;
457+
using BlevBufferSrcsOutT = ::cuda::std::_If<IsMemcpy, void*, cub::detail::value_t<InputBufferIt>>;
458+
using BlevBufferDstOutT = ::cuda::std::_If<IsMemcpy, void*, cub::detail::value_t<OutputBufferIt>>;
459+
using BlevBufferSrcsOutItT = BlevBufferSrcsOutT*;
460+
using BlevBufferDstsOutItT = BlevBufferDstOutT*;
461+
using BlevBufferSizesOutItT = BufferSizeT*;
462462
using BlevBufferTileOffsetsOutItT = BlockOffsetT*;
463463

464464
temporary_storage::layout<MEM_NUM_ALLOCATIONS> temporary_storage_layout;

0 commit comments

Comments
 (0)