Skip to content

Commit 07fef97

Browse files
Use a constant for the amount of static SMEM (#2374)
1 parent fcf7c91 commit 07fef97

File tree

8 files changed

+23
-11
lines changed

8 files changed

+23
-11
lines changed

cub/benchmarks/bench/radix_sort/keys.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
******************************************************************************/
2727

2828
#include <cub/device/device_radix_sort.cuh>
29+
#include <cub/util_arch.cuh>
2930

3031
#include <cuda/std/type_traits>
3132

@@ -123,7 +124,7 @@ constexpr std::size_t max_temp_storage_size()
123124
template <typename KeyT, typename ValueT, typename OffsetT>
124125
constexpr bool fits_in_default_shared_memory()
125126
{
126-
return max_temp_storage_size<KeyT, ValueT, OffsetT>() < 48 * 1024;
127+
return max_temp_storage_size<KeyT, ValueT, OffsetT>() < cub::detail::max_smem_per_block;
127128
}
128129
#else // TUNE_BASE
129130
template <typename, typename, typename>

cub/benchmarks/bench/radix_sort/pairs.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
******************************************************************************/
2727

2828
#include <cub/device/device_radix_sort.cuh>
29+
#include <cub/util_arch.cuh>
2930

3031
#include <cuda/std/type_traits>
3132

@@ -121,7 +122,7 @@ constexpr std::size_t max_temp_storage_size()
121122
template <typename KeyT, typename ValueT, typename OffsetT>
122123
constexpr bool fits_in_default_shared_memory()
123124
{
124-
return max_temp_storage_size<KeyT, ValueT, OffsetT>() < 48 * 1024;
125+
return max_temp_storage_size<KeyT, ValueT, OffsetT>() < cub::detail::max_smem_per_block;
125126
}
126127
#else // TUNE_BASE
127128
template <typename, typename, typename>

cub/cub/util_arch.cuh

+11-2
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,21 @@ static_assert(CUB_MAX_DEVICES > 0, "CUB_MAX_DEVICES must be greater than 0.");
136136
# define CUB_PTX_PREFER_CONFLICT_OVER_PADDING CUB_PREFER_CONFLICT_OVER_PADDING(0)
137137
# endif
138138

139+
namespace detail
140+
{
141+
// The maximum amount of static shared memory available per thread block
142+
// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB
143+
static constexpr ::cuda::std::size_t max_smem_per_block = 48 * 1024;
144+
} // namespace detail
145+
139146
template <int NOMINAL_4B_BLOCK_THREADS, int NOMINAL_4B_ITEMS_PER_THREAD, typename T>
140147
struct RegBoundScaling
141148
{
142149
enum
143150
{
144151
ITEMS_PER_THREAD = CUB_MAX(1, NOMINAL_4B_ITEMS_PER_THREAD * 4 / CUB_MAX(4, sizeof(T))),
145-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, (((1024 * 48) / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
152+
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
153+
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
146154
};
147155
};
148156

@@ -153,7 +161,8 @@ struct MemBoundScaling
153161
{
154162
ITEMS_PER_THREAD =
155163
CUB_MAX(1, CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T), NOMINAL_4B_ITEMS_PER_THREAD * 2)),
156-
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, (((1024 * 48) / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
164+
BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS,
165+
((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32),
157166
};
158167
};
159168

cub/cub/util_type.cuh

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#endif // no system header
4545

4646
#include <cub/detail/uninitialized_copy.cuh>
47+
#include <cub/util_deprecated.cuh>
4748

4849
#include <cuda/std/cstdint>
4950
#include <cuda/std/limits>

cub/cub/util_vsmem.cuh

+1-4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
# pragma system_header
4343
#endif // no system header
4444

45+
#include <cub/util_arch.cuh>
4546
#include <cub/util_device.cuh>
4647
#include <cub/util_ptx.cuh>
4748
#include <cub/util_type.cuh>
@@ -67,10 +68,6 @@ struct vsmem_t
6768
void* gmem_ptr;
6869
};
6970

70-
// The maximum amount of static shared memory available per thread block
71-
// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB
72-
static constexpr std::size_t max_smem_per_block = 48 * 1024;
73-
7471
/**
7572
* @brief Class template that helps to prevent exceeding the available shared memory per thread block.
7673
*

cub/test/catch2_test_block_load.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <cub/block/block_load.cuh>
2929
#include <cub/iterator/cache_modified_input_iterator.cuh>
3030
#include <cub/util_allocator.cuh>
31+
#include <cub/util_arch.cuh>
3132

3233
#include "catch2_test_helper.h"
3334

@@ -113,7 +114,7 @@ void block_load(InputIteratorT input, OutputIteratorT output, int num_items)
113114
using input_t = cub::detail::value_t<InputIteratorT>;
114115
using block_load_t = cub::BlockLoad<input_t, ThreadsInBlock, ItemsPerThread, LoadAlgorithm>;
115116
using storage_t = typename block_load_t::TempStorage;
116-
constexpr bool sufficient_resources = sizeof(storage_t) <= 1024 * 48;
117+
constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block;
117118

118119
kernel<InputIteratorT, OutputIteratorT, ItemsPerThread, ThreadsInBlock, LoadAlgorithm>
119120
<<<1, ThreadsInBlock>>>(std::integral_constant<bool, sufficient_resources>{}, input, output, num_items);

cub/test/catch2_test_block_store.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cub/iterator/cache_modified_output_iterator.cuh>
3030
#include <cub/iterator/discard_output_iterator.cuh>
3131
#include <cub/util_allocator.cuh>
32+
#include <cub/util_arch.cuh>
3233

3334
#include "catch2_test_helper.h"
3435

@@ -114,7 +115,7 @@ void block_store(InputIteratorT input, OutputIteratorT output, int num_items)
114115
using input_t = cub::detail::value_t<InputIteratorT>;
115116
using block_store_t = cub::BlockStore<input_t, ThreadsInBlock, ItemsPerThread, StoreAlgorithm>;
116117
using storage_t = typename block_store_t::TempStorage;
117-
constexpr bool sufficient_resources = sizeof(storage_t) <= 1024 * 48;
118+
constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block;
118119

119120
kernel<InputIteratorT, OutputIteratorT, ItemsPerThread, ThreadsInBlock, StoreAlgorithm>
120121
<<<1, ThreadsInBlock>>>(std::integral_constant<bool, sufficient_resources>{}, input, output, num_items);

cub/test/test_block_radix_rank.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <cub/block/block_store.cuh>
3535
#include <cub/block/radix_rank_sort_operations.cuh>
3636
#include <cub/util_allocator.cuh>
37+
#include <cub/util_vsmem.cuh>
3738

3839
#include <algorithm>
3940
#include <iostream>
@@ -240,7 +241,7 @@ void Test()
240241
cub::detail::block_radix_rank_t<RankAlgorithm, BlockThreads, RadixBits, Descending, ScanAlgorithm>;
241242
using storage_t = typename block_radix_rank::TempStorage;
242243

243-
cub::Int2Type<(sizeof(storage_t) <= 48 * 1024)> fits_smem_capacity;
244+
cub::Int2Type<(sizeof(storage_t) <= cub::detail::max_smem_per_block)> fits_smem_capacity;
244245

245246
TestValid<RankAlgorithm, BlockThreads, ItemsPerThread, RadixBits, ScanAlgorithm, Descending, Key>(fits_smem_capacity);
246247
}

0 commit comments

Comments
 (0)