Skip to content

Commit 1be7184

Browse files
bernhardmgrubergevtushenko
authored andcommitted
Port thrust::merge[_by_key] to CUB (NVIDIA#1817)
* Refactor thrust/CUB merge * Port thurst::merge[_by_key] to cub::DeviceMerge Fixes NVIDIA#1763 Co-authored-by: Georgii Evtushenko <[email protected]>
1 parent 0db46cc commit 1be7184

15 files changed

+1697
-971
lines changed

cub/cub/agent/agent_merge.cuh

+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#pragma once
5+
6+
#include <cub/config.cuh>
7+
8+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
9+
# pragma GCC system_header
10+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
11+
# pragma clang system_header
12+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
13+
# pragma system_header
14+
#endif // no system header
15+
16+
#include <cub/agent/agent_merge_sort.cuh>
17+
#include <cub/block/block_load.cuh>
18+
#include <cub/block/block_merge_sort.cuh>
19+
#include <cub/block/block_store.cuh>
20+
#include <cub/util_namespace.cuh>
21+
#include <cub/util_type.cuh>
22+
23+
#include <thrust/system/cuda/detail/core/util.h>
24+
25+
#include <cuda/std/__cccl/dialect.h>
26+
27+
CUB_NAMESPACE_BEGIN
28+
namespace detail
29+
{
30+
namespace merge
31+
{
32+
template <int ThreadsPerBlock,
33+
int ItemsPerThread,
34+
BlockLoadAlgorithm LoadAlgorithm,
35+
CacheLoadModifier LoadCacheModifier,
36+
BlockStoreAlgorithm StoreAlgorithm>
37+
struct agent_policy_t
38+
{
39+
// do not change data member names, policy_wrapper_t depends on it
40+
static constexpr int BLOCK_THREADS = ThreadsPerBlock;
41+
static constexpr int ITEMS_PER_THREAD = ItemsPerThread;
42+
static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
43+
static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = LoadAlgorithm;
44+
static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
45+
static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
46+
};
47+
48+
// TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?
49+
template <typename Policy,
50+
typename KeysIt1,
51+
typename ItemsIt1,
52+
typename KeysIt2,
53+
typename ItemsIt2,
54+
typename KeysOutputIt,
55+
typename ItemsOutputIt,
56+
typename Offset,
57+
typename CompareOp>
58+
struct agent_t
59+
{
60+
using policy = Policy;
61+
62+
using key_type = typename ::cuda::std::iterator_traits<KeysIt1>::value_type;
63+
using item_type = typename ::cuda::std::iterator_traits<ItemsIt1>::value_type;
64+
65+
using keys_load_it1 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt1>::type;
66+
using keys_load_it2 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt2>::type;
67+
using items_load_it1 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt1>::type;
68+
using items_load_it2 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt2>::type;
69+
70+
using block_load_keys1 = typename BlockLoadType<Policy, keys_load_it1>::type;
71+
using block_load_keys2 = typename BlockLoadType<Policy, keys_load_it2>::type;
72+
using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
73+
using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;
74+
75+
using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
76+
using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
77+
78+
union temp_storages
79+
{
80+
typename block_load_keys1::TempStorage load_keys1;
81+
typename block_load_keys2::TempStorage load_keys2;
82+
typename block_load_items1::TempStorage load_items1;
83+
typename block_load_items2::TempStorage load_items2;
84+
typename block_store_keys::TempStorage store_keys;
85+
typename block_store_items::TempStorage store_items;
86+
87+
key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
88+
item_type items_shared[Policy::ITEMS_PER_TILE + 1];
89+
};
90+
91+
struct TempStorage : Uninitialized<temp_storages>
92+
{};
93+
94+
static constexpr int items_per_thread = Policy::ITEMS_PER_THREAD;
95+
static constexpr int threads_per_block = Policy::BLOCK_THREADS;
96+
static constexpr Offset items_per_tile = Policy::ITEMS_PER_TILE;
97+
98+
// Per thread data
99+
temp_storages& storage;
100+
keys_load_it1 keys1_in;
101+
items_load_it1 items1_in;
102+
Offset keys1_count;
103+
keys_load_it2 keys2_in;
104+
items_load_it2 items2_in;
105+
Offset keys2_count;
106+
KeysOutputIt keys_out;
107+
ItemsOutputIt items_out;
108+
CompareOp compare_op;
109+
Offset* merge_partitions;
110+
111+
template <bool IsFullTile>
112+
_CCCL_DEVICE _CCCL_FORCEINLINE void consume_tile(Offset tile_idx, Offset tile_base, int num_remaining)
113+
{
114+
const Offset partition_beg = merge_partitions[tile_idx + 0];
115+
const Offset partition_end = merge_partitions[tile_idx + 1];
116+
117+
const Offset diag0 = items_per_tile * tile_idx;
118+
const Offset diag1 = (cub::min)(keys1_count + keys2_count, diag0 + items_per_tile);
119+
120+
// compute bounding box for keys1 & keys2
121+
const Offset keys1_beg = partition_beg;
122+
const Offset keys1_end = partition_end;
123+
const Offset keys2_beg = diag0 - keys1_beg;
124+
const Offset keys2_end = diag1 - keys1_end;
125+
126+
// number of keys per tile
127+
const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
128+
const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);
129+
130+
key_type keys_loc[items_per_thread];
131+
gmem_to_reg<threads_per_block, IsFullTile>(
132+
keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
133+
reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
134+
CTA_SYNC();
135+
136+
// use binary search in shared memory to find merge path for each of thread.
137+
// we can use int type here, because the number of items in shared memory is limited
138+
const int diag0_loc = min<int>(num_keys1 + num_keys2, items_per_thread * threadIdx.x);
139+
140+
const int keys1_beg_loc =
141+
MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
142+
const int keys1_end_loc = num_keys1;
143+
const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
144+
const int keys2_end_loc = num_keys2;
145+
146+
const int num_keys1_loc = keys1_end_loc - keys1_beg_loc;
147+
const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;
148+
149+
// perform serial merge
150+
int indices[items_per_thread];
151+
cub::SerialMerge(
152+
&storage.keys_shared[0],
153+
keys1_beg_loc,
154+
keys2_beg_loc + num_keys1,
155+
num_keys1_loc,
156+
num_keys2_loc,
157+
keys_loc,
158+
indices,
159+
compare_op);
160+
CTA_SYNC();
161+
162+
// write keys
163+
if (IsFullTile)
164+
{
165+
block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc);
166+
}
167+
else
168+
{
169+
block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc, num_remaining);
170+
}
171+
172+
// if items are provided, merge them
173+
static constexpr bool have_items = !std::is_same<item_type, NullType>::value;
174+
#ifdef _CCCL_CUDACC_BELOW_11_8
175+
if (have_items) // nvcc 11.1 cannot handle #pragma unroll inside if constexpr but 11.8 can.
176+
// nvcc versions between may work
177+
#else
178+
_CCCL_IF_CONSTEXPR (have_items)
179+
#endif
180+
{
181+
item_type items_loc[items_per_thread];
182+
gmem_to_reg<threads_per_block, IsFullTile>(
183+
items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
184+
CTA_SYNC(); // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
185+
reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
186+
CTA_SYNC();
187+
188+
// gather items from shared mem
189+
#pragma unroll
190+
for (int i = 0; i < items_per_thread; ++i)
191+
{
192+
items_loc[i] = storage.items_shared[indices[i]];
193+
}
194+
CTA_SYNC();
195+
196+
// write from reg to gmem
197+
if (IsFullTile)
198+
{
199+
block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc);
200+
}
201+
else
202+
{
203+
block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc, num_remaining);
204+
}
205+
}
206+
}
207+
208+
_CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
209+
{
210+
// XXX with 8.5 chaging type to Offset (or long long) results in error!
211+
// TODO(bgruber): is the above still true?
212+
const int tile_idx = static_cast<int>(blockIdx.x);
213+
const Offset tile_base = tile_idx * items_per_tile;
214+
// TODO(bgruber): random mixing of int and Offset
215+
const int items_in_tile =
216+
static_cast<int>(cub::min(static_cast<Offset>(items_per_tile), keys1_count + keys2_count - tile_base));
217+
if (items_in_tile == items_per_tile)
218+
{
219+
consume_tile<true>(tile_idx, tile_base, items_per_tile); // full tile
220+
}
221+
else
222+
{
223+
consume_tile<false>(tile_idx, tile_base, items_in_tile); // partial tile
224+
}
225+
}
226+
};
227+
} // namespace merge
228+
} // namespace detail
229+
CUB_NAMESPACE_END

0 commit comments

Comments
 (0)