diff --git a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp index 12ad70041..f6e043709 100644 --- a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp +++ b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp @@ -235,11 +235,12 @@ class block_sort_bitonic return !r; }; wsort.sort(kv..., compare_function2); - + #pragma unroll for(unsigned int length = ::rocprim::warp_size(); length < Size; length *= 2) { bool dir = (flat_tid & (length * 2)) != 0; + #pragma unroll for(unsigned int k = length; k > 0; k /= 2) { copy_to_shared(kv..., flat_tid, storage); diff --git a/rocprim/include/rocprim/device/config_types.hpp b/rocprim/include/rocprim/device/config_types.hpp index cad90a0b5..b01c38710 100644 --- a/rocprim/include/rocprim/device/config_types.hpp +++ b/rocprim/include/rocprim/device/config_types.hpp @@ -52,6 +52,39 @@ struct kernel_config namespace detail { +template< + unsigned int MaxBlockSize, + unsigned int SharedMemoryPerThread, + // Most kernels require block sizes not smaller than warp + unsigned int MinBlockSize = ::rocprim::warp_size(), + // Can fit in shared memory? + // Although GPUs have 64KiB, 32KiB is used here as a "soft" limit, + // because some additional memory may be required in kernels + bool = (MaxBlockSize * SharedMemoryPerThread <= (1u << 15)) +> +struct limit_block_size +{ + // No, then try to decrease block size + static constexpr unsigned int value = + limit_block_size< + detail::next_power_of_two(MaxBlockSize) / 2, + SharedMemoryPerThread, + MinBlockSize + >::value; +}; + +template< + unsigned int MaxBlockSize, + unsigned int SharedMemoryPerThread, + unsigned int MinBlockSize +> +struct limit_block_size +{ + static_assert(MaxBlockSize >= MinBlockSize, "Data is too large, it cannot fit in shared memory"); + + static constexpr unsigned int value = MaxBlockSize; +}; + template using void_t = void; diff --git a/rocprim/include/rocprim/device/detail/device_binary_search.hpp b/rocprim/include/rocprim/device/detail/device_binary_search.hpp index 02effbd16..5e4b54683 100644 --- a/rocprim/include/rocprim/device/detail/device_binary_search.hpp +++ b/rocprim/include/rocprim/device/detail/device_binary_search.hpp @@ -27,6 +27,7 @@ namespace detail { template +ROCPRIM_DEVICE inline Size get_binary_search_middle(Size left, Size right) { // Instead of `/ 2` we use `* 33 / 64`, i.e. the middle is slightly moved. @@ -40,7 +41,7 @@ Size get_binary_search_middle(Size left, Size right) } template -ROCPRIM_DEVICE +ROCPRIM_DEVICE inline Size lower_bound_n(RandomAccessIterator first, Size size, const T& value, @@ -64,7 +65,7 @@ Size lower_bound_n(RandomAccessIterator first, } template -ROCPRIM_DEVICE +ROCPRIM_DEVICE inline Size upper_bound_n(RandomAccessIterator first, Size size, const T& value, @@ -90,7 +91,7 @@ Size upper_bound_n(RandomAccessIterator first, struct lower_bound_search_op { template - ROCPRIM_DEVICE + ROCPRIM_DEVICE inline Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const { return lower_bound_n(haystack, size, value, compare_op); @@ -100,7 +101,7 @@ struct lower_bound_search_op struct upper_bound_search_op { template - ROCPRIM_DEVICE + ROCPRIM_DEVICE inline Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const { return upper_bound_n(haystack, size, value, compare_op); @@ -110,7 +111,7 @@ struct upper_bound_search_op struct binary_search_op { template - ROCPRIM_DEVICE + ROCPRIM_DEVICE inline bool operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const { const Size n = lower_bound_n(haystack, size, value, compare_op); diff --git a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp index 23d1f684c..a1c5b1174 100644 --- a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp +++ b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp @@ -323,21 +323,28 @@ class lookback_scan_prefix_op T get_prefix() { flag_type flag; - T partial_prefix; unsigned int previous_block_id = block_id_ - ::rocprim::lane_id() - 1; + bool is_prefix_initialized = false; + T prefix; - // reduce last warp_size() number of prefixes to - // get the complete prefix for this block. - reduce_partial_prefixes(previous_block_id, flag, partial_prefix); - T prefix = partial_prefix; - - // while we don't load a complete prefix, reduce partial prefixes - while(::rocprim::detail::warp_all(flag != PREFIX_COMPLETE)) + do { - previous_block_id -= ::rocprim::warp_size(); + // reduce last warp_size() number of prefixes to + // get the complete prefix for this block. + T partial_prefix; reduce_partial_prefixes(previous_block_id, flag, partial_prefix); - prefix = scan_op_(partial_prefix, prefix); - } + if(!is_prefix_initialized) + { + prefix = partial_prefix; + is_prefix_initialized = true; + } + else + { + prefix = scan_op_(partial_prefix, prefix); + } + previous_block_id -= ::rocprim::warp_size(); + // while we don't load a complete prefix, reduce partial prefixes + } while(::rocprim::detail::warp_all(flag != PREFIX_COMPLETE)); return prefix; } diff --git a/rocprim/include/rocprim/device/device_merge_sort_config.hpp b/rocprim/include/rocprim/device/device_merge_sort_config.hpp index 3f932c241..5cccea7d2 100644 --- a/rocprim/include/rocprim/device/device_merge_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_merge_sort_config.hpp @@ -42,47 +42,28 @@ using merge_sort_config = kernel_config; namespace detail { -// TODO investigate why some tests fail with block size > 256 template struct merge_sort_config_803 { - // static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); - // static constexpr unsigned int item_scale = - // ::rocprim::detail::ceiling_div(key_value_size, 8); - - // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; - using type = merge_sort_config<256U>; + using type = merge_sort_config::value>; }; template struct merge_sort_config_803 { - // static constexpr unsigned int item_scale = - // ::rocprim::detail::ceiling_div(sizeof(Key), 8); - - // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; - using type = merge_sort_config<256U>; + using type = merge_sort_config::value>; }; template struct merge_sort_config_900 { - // static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); - // static constexpr unsigned int item_scale = - // ::rocprim::detail::ceiling_div(key_value_size, 16); - - // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; - using type = merge_sort_config<256U>; + using type = merge_sort_config::value>; }; template struct merge_sort_config_900 { - // static constexpr unsigned int item_scale = - // ::rocprim::detail::ceiling_div(sizeof(Key), 16); - - // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; - using type = merge_sort_config<256U>; + using type = merge_sort_config::value>; }; template diff --git a/rocprim/include/rocprim/device/device_radix_sort_config.hpp b/rocprim/include/rocprim/device/device_radix_sort_config.hpp index 6a29f01d1..f2134d8aa 100644 --- a/rocprim/include/rocprim/device/device_radix_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_radix_sort_config.hpp @@ -92,7 +92,13 @@ struct radix_sort_config_803 (sizeof(Key) == 8 && sizeof(Value) <= 8), radix_sort_config<7, 6, scan, kernel_config<256, 13> > >, - radix_sort_config<7, 6, scan, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)> > + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value)>::value, + ::rocprim::max(1u, 15u / item_scale) + > + > >; }; @@ -130,7 +136,13 @@ struct radix_sort_config_900 (sizeof(Key) == 8 && sizeof(Value) <= 8), radix_sort_config<7, 6, scan, kernel_config<256, 15> > >, - radix_sort_config<7, 6, scan, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)> > + radix_sort_config< + 6, 4, scan, + kernel_config< + limit_block_size<256U, sizeof(Value)>::value, + ::rocprim::max(1u, 15u / item_scale) + > + > >; };