Skip to content

Commit 5ee81bb

Browse files
Allow different input value types for merge
As long as they are convertible to the value type of the first iterator. This weakens the publicly documented guarantees of equal value types to restore the old behavior of the thrust implementation replaced in #1817.
1 parent 223d3ea commit 5ee81bb

File tree

3 files changed

+12
-21
lines changed

3 files changed

+12
-21
lines changed

cub/cub/agent/agent_merge.cuh

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct agent_t
5959
{
6060
using policy = Policy;
6161

62+
// key and value type are taken from the first input sequence (consistent with old Thrust behavior)
6263
using key_type = typename ::cuda::std::iterator_traits<KeysIt1>::value_type;
6364
using item_type = typename ::cuda::std::iterator_traits<ItemsIt1>::value_type;
6465

cub/cub/agent/agent_merge_sort.cuh

+5-4
Original file line numberDiff line numberDiff line change
@@ -382,19 +382,20 @@ gmem_to_reg(T (&output)[ITEMS_PER_THREAD], It1 input1, It2 input2, int count1, i
382382
#pragma unroll
383383
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
384384
{
385-
int idx = BLOCK_THREADS * item + threadIdx.x;
386-
output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
385+
const int idx = BLOCK_THREADS * item + threadIdx.x;
386+
// It1 and It2 could have different value types. Convert after load.
387+
output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
387388
}
388389
}
389390
else
390391
{
391392
#pragma unroll
392393
for (int item = 0; item < ITEMS_PER_THREAD; ++item)
393394
{
394-
int idx = BLOCK_THREADS * item + threadIdx.x;
395+
const int idx = BLOCK_THREADS * item + threadIdx.x;
395396
if (idx < count1 + count2)
396397
{
397-
output[item] = (idx < count1) ? input1[idx] : input2[idx - count1];
398+
output[item] = (idx < count1) ? static_cast<T>(input1[idx]) : static_cast<T>(input2[idx - count1]);
398399
}
399400
}
400401
}

cub/cub/device/dispatch/dispatch_merge.cuh

+6-17
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,6 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void device_partition_merge_path_kernel(
6464
Offset* merge_partitions,
6565
CompareOp compare_op)
6666
{
67-
static_assert(
68-
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, value_t<KeyIt1>, value_t<KeyIt1>>::type,
69-
bool>::value,
70-
"Comparison operator must be convertible to bool");
71-
7267
// items_per_tile must be the same of the merge kernel later, so we have to consider whether a fallback agent will be
7368
// selected for the merge agent that changes the tile size
7469
constexpr int items_per_tile =
@@ -121,9 +116,12 @@ __launch_bounds__(
121116
Offset* merge_partitions,
122117
vsmem_t global_temp_storage)
123118
{
119+
// the merge agent loads keys into a local array of KeyIt1::value_type, on which the comparisons are performed
120+
using key_t = value_t<KeyIt1>;
121+
static_assert(::cuda::std::__invokable<CompareOp, key_t, key_t>::value,
122+
"Comparison operator cannot compare two keys");
124123
static_assert(
125-
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, value_t<KeyIt1>, value_t<KeyIt1>>::type,
126-
bool>::value,
124+
::cuda::std::is_convertible<typename ::cuda::std::__invoke_of<CompareOp, key_t, key_t>::type, bool>::value,
127125
"Comparison operator must be convertible to bool");
128126

129127
using MergeAgent = typename choose_merge_agent<
@@ -218,15 +216,6 @@ template <typename KeyIt1,
218216
typename PolicyHub = device_merge_policy_hub<value_t<KeyIt1>, value_t<ValueIt1>>>
219217
struct dispatch_t
220218
{
221-
using key_t = cub::detail::value_t<KeyIt1>;
222-
using value_t = cub::detail::value_t<ValueIt1>;
223-
224-
// Cannot check output iterators, since they could be discard iterators, which do not have the right value_type
225-
static_assert(::cuda::std::is_same<cub::detail::value_t<KeyIt2>, key_t>::value, "");
226-
static_assert(::cuda::std::is_same<cub::detail::value_t<ValueIt2>, value_t>::value, "");
227-
static_assert(::cuda::std::__invokable<CompareOp, key_t, key_t>::value,
228-
"Comparison operator cannot compare two keys");
229-
230219
void* d_temp_storage;
231220
std::size_t& temp_storage_bytes;
232221
KeyIt1 d_keys1;
@@ -351,7 +340,7 @@ struct dispatch_t
351340
{
352341
return error;
353342
}
354-
dispatch_t dispatch{std::forward<Args>(args)...};
343+
dispatch_t dispatch{::cuda::std::forward<Args>(args)...};
355344
error = CubDebug(PolicyHub::max_policy::Invoke(ptx_version, dispatch));
356345
if (cudaSuccess != error)
357346
{

0 commit comments

Comments
 (0)