// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/agent/agent_merge_sort.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_merge_sort.cuh>
#include <cub/block/block_store.cuh>
#include <cub/util_namespace.cuh>
#include <cub/util_type.cuh>

#include <thrust/system/cuda/detail/core/util.h>

#include <cuda/std/__cccl/dialect.h>

CUB_NAMESPACE_BEGIN
namespace detail
{
namespace merge
{
// Tuning policy for the merge agent: aggregates the compile-time knobs that pick the
// thread-block size, the amount of work per thread, and the block-wide load/store algorithms.
template <int ThreadsPerBlock,
          int ItemsPerThread,
          BlockLoadAlgorithm LoadAlgorithm,
          CacheLoadModifier LoadCacheModifier,
          BlockStoreAlgorithm StoreAlgorithm>
struct agent_policy_t
{
  // do not change data member names, policy_wrapper_t depends on it
  static constexpr int BLOCK_THREADS = ThreadsPerBlock;    // threads per thread block
  static constexpr int ITEMS_PER_THREAD = ItemsPerThread;  // elements each thread processes
  // Elements processed per tile (one tile == one thread block's worth of work)
  static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD;
  static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = LoadAlgorithm;
  static constexpr CacheLoadModifier LOAD_MODIFIER = LoadCacheModifier;
  static constexpr BlockStoreAlgorithm STORE_ALGORITHM = StoreAlgorithm;
};
| 47 | + |
// TODO(bgruber): can we unify this one with AgentMerge in agent_merge_sort.cuh?

// Block-level agent that merges one tile of two pre-sorted key sequences (and, optionally, their
// associated item sequences) into the output ranges, guided by merge-path partition points that a
// preceding partitioning step stored into `merge_partitions` (one entry per tile boundary).
//
// Template parameters:
//   Policy        - tuning policy (see agent_policy_t): BLOCK_THREADS, ITEMS_PER_THREAD,
//                   ITEMS_PER_TILE, load/store algorithms
//   KeysIt1/2     - iterators to the first/second sorted key range
//   ItemsIt1/2    - iterators to the associated item ranges; an item_type of NullType
//                   disables all item movement (keys-only merge)
//   KeysOutputIt  - output iterator for merged keys
//   ItemsOutputIt - output iterator for merged items
//   Offset        - integral type for global element offsets/counts
//   CompareOp     - strict-weak-ordering comparison functor on key_type
template <typename Policy,
          typename KeysIt1,
          typename ItemsIt1,
          typename KeysIt2,
          typename ItemsIt2,
          typename KeysOutputIt,
          typename ItemsOutputIt,
          typename Offset,
          typename CompareOp>
struct agent_t
{
  using policy = Policy;

  using key_type  = typename ::cuda::std::iterator_traits<KeysIt1>::value_type;
  using item_type = typename ::cuda::std::iterator_traits<ItemsIt1>::value_type;

  // Load iterators possibly wrapped to apply Policy's cache load modifier
  using keys_load_it1  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt1>::type;
  using keys_load_it2  = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, KeysIt2>::type;
  using items_load_it1 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt1>::type;
  using items_load_it2 = typename THRUST_NS_QUALIFIER::cuda_cub::core::LoadIterator<Policy, ItemsIt2>::type;

  // Block-wide collective load/store primitives configured from Policy
  using block_load_keys1  = typename BlockLoadType<Policy, keys_load_it1>::type;
  using block_load_keys2  = typename BlockLoadType<Policy, keys_load_it2>::type;
  using block_load_items1 = typename BlockLoadType<Policy, items_load_it1>::type;
  using block_load_items2 = typename BlockLoadType<Policy, items_load_it2>::type;

  using block_store_keys  = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;

  // All shared-memory uses are aliased in one union: each phase of consume_tile uses exactly one
  // member at a time, with CTA_SYNC() separating the phases.
  union temp_storages
  {
    typename block_load_keys1::TempStorage load_keys1;
    typename block_load_keys2::TempStorage load_keys2;
    typename block_load_items1::TempStorage load_items1;
    typename block_load_items2::TempStorage load_items2;
    typename block_store_keys::TempStorage store_keys;
    typename block_store_items::TempStorage store_items;

    // +1 slot of padding; NOTE(review): presumably to avoid an out-of-bounds touch or bank
    // conflict at the tile boundary, mirroring the merge-sort agent — confirm against
    // agent_merge_sort.cuh
    key_type keys_shared[Policy::ITEMS_PER_TILE + 1];
    item_type items_shared[Policy::ITEMS_PER_TILE + 1];
  };

  struct TempStorage : Uninitialized<temp_storages>
  {};

  static constexpr int items_per_thread   = Policy::ITEMS_PER_THREAD;
  static constexpr int threads_per_block  = Policy::BLOCK_THREADS;
  static constexpr Offset items_per_tile  = Policy::ITEMS_PER_TILE; // as Offset for global index math

  // Per thread data
  temp_storages& storage;     // reference to this block's shared-memory union
  keys_load_it1 keys1_in;     // first sorted key range
  items_load_it1 items1_in;   // items associated with keys1_in
  Offset keys1_count;         // number of elements in the first range
  keys_load_it2 keys2_in;     // second sorted key range
  items_load_it2 items2_in;   // items associated with keys2_in
  Offset keys2_count;         // number of elements in the second range
  KeysOutputIt keys_out;      // merged keys destination
  ItemsOutputIt items_out;    // merged items destination
  CompareOp compare_op;       // ordering predicate
  Offset* merge_partitions;   // per-tile merge-path partition points into the keys1 range

  // Merges the tile_idx-th tile and writes its keys (and items, if enabled) to
  // [tile_base, tile_base + num_remaining) of the output.
  //   IsFullTile    - compile-time flag: tile holds exactly items_per_tile elements, so the
  //                   unguarded (faster) store path can be used
  //   tile_idx      - index of this tile (== blockIdx.x)
  //   tile_base     - global output offset of the tile's first element
  //   num_remaining - number of valid elements in this tile (== items_per_tile when full)
  template <bool IsFullTile>
  _CCCL_DEVICE _CCCL_FORCEINLINE void consume_tile(Offset tile_idx, Offset tile_base, int num_remaining)
  {
    // Partition points bound how much of keys1 this tile and the next one start at
    const Offset partition_beg = merge_partitions[tile_idx + 0];
    const Offset partition_end = merge_partitions[tile_idx + 1];

    // Global merge-path diagonals delimiting this tile (clamped to the total input size)
    const Offset diag0 = items_per_tile * tile_idx;
    const Offset diag1 = (cub::min)(keys1_count + keys2_count, diag0 + items_per_tile);

    // compute bounding box for keys1 & keys2
    const Offset keys1_beg = partition_beg;
    const Offset keys1_end = partition_end;
    // On a merge-path diagonal d, (keys1 consumed) + (keys2 consumed) == d, so the keys2 bounds
    // follow from the diagonals and the keys1 bounds
    const Offset keys2_beg = diag0 - keys1_beg;
    const Offset keys2_end = diag1 - keys1_end;

    // number of keys per tile
    const int num_keys1 = static_cast<int>(keys1_end - keys1_beg);
    const int num_keys2 = static_cast<int>(keys2_end - keys2_beg);

    // Stage both key subranges through registers into shared memory:
    // keys_shared[0 .. num_keys1) holds keys1, keys_shared[num_keys1 .. num_keys1+num_keys2) keys2
    key_type keys_loc[items_per_thread];
    gmem_to_reg<threads_per_block, IsFullTile>(
      keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, num_keys1, num_keys2);
    reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
    CTA_SYNC();

    // use binary search in shared memory to find merge path for each of thread.
    // we can use int type here, because the number of items in shared memory is limited
    const int diag0_loc = min<int>(num_keys1 + num_keys2, items_per_thread * threadIdx.x);

    // Per-thread merge path: how many keys1 elements precede this thread's local diagonal
    const int keys1_beg_loc =
      MergePath(&storage.keys_shared[0], &storage.keys_shared[num_keys1], num_keys1, num_keys2, diag0_loc, compare_op);
    const int keys1_end_loc = num_keys1;
    const int keys2_beg_loc = diag0_loc - keys1_beg_loc;
    const int keys2_end_loc = num_keys2;

    const int num_keys1_loc = keys1_end_loc - keys1_beg_loc;
    const int num_keys2_loc = keys2_end_loc - keys2_beg_loc;

    // perform serial merge
    // indices[] records, for each merged key, its source position in keys_shared; it is reused
    // below to gather the corresponding items without re-running the comparisons
    int indices[items_per_thread];
    cub::SerialMerge(
      &storage.keys_shared[0],
      keys1_beg_loc,
      keys2_beg_loc + num_keys1, // keys2 subrange lives at offset num_keys1 in shared memory
      num_keys1_loc,
      num_keys2_loc,
      keys_loc,
      indices,
      compare_op);
    CTA_SYNC();

    // write keys
    if (IsFullTile)
    {
      block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc);
    }
    else
    {
      // guarded store for the final, partially-filled tile
      block_store_keys{storage.store_keys}.Store(keys_out + tile_base, keys_loc, num_remaining);
    }

    // if items are provided, merge them
    static constexpr bool have_items = !std::is_same<item_type, NullType>::value;
#ifdef _CCCL_CUDACC_BELOW_11_8
    if (have_items) // nvcc 11.1 cannot handle #pragma unroll inside if constexpr but 11.8 can.
                    // nvcc versions between may work
#else
    _CCCL_IF_CONSTEXPR (have_items)
#endif
    {
      item_type items_loc[items_per_thread];
      gmem_to_reg<threads_per_block, IsFullTile>(
        items_loc, items1_in + keys1_beg, items2_in + keys2_beg, num_keys1, num_keys2);
      CTA_SYNC(); // block_store_keys above uses shared memory, so make sure all threads are done before we write to it
      reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
      CTA_SYNC();

      // gather items from shared mem using the permutation computed by the key merge
#pragma unroll
      for (int i = 0; i < items_per_thread; ++i)
      {
        items_loc[i] = storage.items_shared[indices[i]];
      }
      CTA_SYNC();

      // write from reg to gmem
      if (IsFullTile)
      {
        block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc);
      }
      else
      {
        block_store_items{storage.store_items}.Store(items_out + tile_base, items_loc, num_remaining);
      }
    }
  }

  // Kernel entry point for this block: each thread block processes exactly one tile, selected
  // by blockIdx.x; dispatches to the full- or partial-tile specialization of consume_tile.
  _CCCL_DEVICE _CCCL_FORCEINLINE void operator()()
  {
    // XXX with 8.5 chaging type to Offset (or long long) results in error!
    // TODO(bgruber): is the above still true?
    const int tile_idx = static_cast<int>(blockIdx.x);
    const Offset tile_base = tile_idx * items_per_tile;
    // TODO(bgruber): random mixing of int and Offset
    // Number of valid elements in this tile; only the last tile can hold fewer than
    // items_per_tile, so the value fits in int
    const int items_in_tile =
      static_cast<int>(cub::min(static_cast<Offset>(items_per_tile), keys1_count + keys2_count - tile_base));
    if (items_in_tile == items_per_tile)
    {
      consume_tile<true>(tile_idx, tile_base, items_per_tile); // full tile
    }
    else
    {
      consume_tile<false>(tile_idx, tile_base, items_in_tile); // partial tile
    }
  }
};
} // namespace merge
} // namespace detail
CUB_NAMESPACE_END