diff --git a/projects/rocprim/CHANGELOG.md b/projects/rocprim/CHANGELOG.md index 041a513b75e..1322576ffd9 100644 --- a/projects/rocprim/CHANGELOG.md +++ b/projects/rocprim/CHANGELOG.md @@ -26,6 +26,10 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag. * Added tests and documentation to `lookback_scan_state`. It is still in the `detail` namespace. +### Optimizations + +* Improved performance of `rocprim::device_select` and `rocprim::device_partition` when using multiple streams on the MI3XX architecture. + ### Changed * Changed the parameters `long_radix_bits` and `LongRadixBits` from `segmented_radix_sort` to `radix_bits` and `RadixBits` respectively. diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_partition.hpp index ce0930eb908..b796050c7b3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -33,10 +33,12 @@ #include "../../block/block_store.hpp" #include "../../block/block_scan.hpp" #include "../../block/block_discontinuity.hpp" +#include "ordered_block_id.hpp" #include "../config_types.hpp" #include "device_config_helper.hpp" #include "lookback_scan_state.hpp" +#include "rocprim/intrinsics/thread.hpp" #include "rocprim/type_traits.hpp" #include "rocprim/types/tuple.hpp" @@ -914,7 +916,8 @@ template + class OffsetType, + class BlockIdWrapper> struct partition_kernel_impl_ { @@ -934,6 +937,7 @@ struct partition_kernel_impl_ using block_scan_offset_type = ::rocprim::block_scan; using block_discontinuity_key_type = ::rocprim::block_discontinuity; + using ordered_block_id = BlockIdWrapper; // Memory required for 2-phase scatter using exchange_keys_storage_type = Key[items_per_block]; @@ -952,6 +956,7 @@ struct partition_kernel_impl_ typename block_load_flag_type::storage_type load_flags; typename block_discontinuity_key_type::storage_type discontinuity_values; typename block_scan_offset_type::storage_type scan_offsets; + typename ordered_block_id::storage_type block_id; }; template< @@ -976,6 +981,7 @@ struct partition_kernel_impl_ InequalityOp, OffsetLookbackScanState, const unsigned int, + ordered_block_id, storage_type&, UnaryPredicates...) -> std::enable_if_t()> @@ -1005,6 +1011,7 @@ struct partition_kernel_impl_ InequalityOp inequality_op, OffsetLookbackScanState offset_scan_state, const unsigned int number_of_blocks, + ordered_block_id block_id, storage_type& storage, UnaryPredicates... predicates) -> std::enable_if_t()> @@ -1023,9 +1030,11 @@ struct partition_kernel_impl_ size_t prev_selected_count_values[sizeof...(UnaryPredicates)]{}; load_selected_count(prev_selected_count, prev_selected_count_values); - const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); - const auto flat_block_id = ::rocprim::detail::block_id<0>(); - const auto block_offset = flat_block_id * items_per_block; + const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_id = block_id.get(flat_block_thread_id, storage.block_id); + ::rocprim::syncthreads(); // sync threads to reuse shared memory + + const auto block_offset = flat_block_id * items_per_block; const unsigned int valid_in_global_last_block = total_size - prev_processed - items_per_block * (number_of_blocks - 1); const bool is_last_launch diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/ordered_block_id.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/ordered_block_id.hpp index 4683f8f9757..d67b503e3b9 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/ordered_block_id.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/ordered_block_id.hpp @@ -22,12 +22,10 @@ #define ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_ #include -#include #include "../../detail/temp_storage.hpp" -#include "../../detail/various.hpp" -#include "../../intrinsics.hpp" -#include "../../types.hpp" +#include "../../intrinsics/atomic.hpp" +#include "../../intrinsics/thread.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -86,6 +84,104 @@ struct ordered_block_id id_type* id; }; +template +struct block_id_wrapper; + +template +struct block_id_wrapper +{ + using id_type = T; + + // shared memory temporary storage type + struct storage_type + {}; + + ROCPRIM_HOST + static inline block_id_wrapper create(id_type* /*id*/) + { + block_id_wrapper ordered_id; + return ordered_id; + } + + ROCPRIM_HOST + static inline size_t get_storage_size() + { + return 0; + } + + ROCPRIM_HOST + static inline detail::temp_storage::layout get_temp_storage_layout() + { + return detail::temp_storage::layout{get_storage_size(), 0}; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void reset() + {} + + ROCPRIM_HOST ROCPRIM_INLINE + hipError_t reset_from_host(const hipStream_t /*stream*/) + { + return hipSuccess; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + id_type get(unsigned int /*tid*/, storage_type& /*storage*/) + { + return ::rocprim::detail::block_id<0>(); + } +}; + +template +struct block_id_wrapper +{ + using id_type = T; + + using storage_type = typename ::rocprim::detail::ordered_block_id::storage_type; + + ROCPRIM_HOST + static inline block_id_wrapper create(id_type* id) + { + block_id_wrapper id_wrapper; + id_wrapper.ordered_id = detail::ordered_block_id::create(id); + return id_wrapper; + } + + ROCPRIM_HOST + static inline size_t get_storage_size() + { + return ::rocprim::detail::ordered_block_id::get_storage_size(); + } + + ROCPRIM_HOST + static inline detail::temp_storage::layout get_temp_storage_layout() + { + return ::rocprim::detail::ordered_block_id::get_temp_storage_layout(); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void reset() + { + ordered_id.reset(); + } + + ROCPRIM_HOST ROCPRIM_INLINE + hipError_t reset_from_host(const hipStream_t stream) + { + return hipMemsetAsync(ordered_id.id, 0, sizeof(id_type), stream); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + id_type get(unsigned int tid, storage_type& storage) + { + auto id = ordered_id.get(tid, storage); + ::rocprim::syncthreads(); + return id; + } + + ::rocprim::detail::ordered_block_id ordered_id; +}; + } // end of detail namespace END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index e7e8d6d3ff2..2714864400b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -41,6 +41,7 @@ #include "device_transform.hpp" #include "rocprim/detail/virtual_shared_memory.hpp" #include "rocprim/device/config_types.hpp" +#include "rocprim/device/detail/ordered_block_id.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -60,22 +61,23 @@ template -ROCPRIM_KERNEL - ROCPRIM_LAUNCH_BOUNDS(device_params().kernel_config.block_size) void - partition_kernel(KeyIterator keys_input, - ValueIterator values_input, - FlagIterator flags, - OutputKeyIterator keys_output, - OutputValueIterator values_output, - size_t* selected_count, - size_t* prev_selected_count, - size_t prev_processed, - const size_t total_size, - InequalityOp inequality_op, - OffsetLookbackScanState offset_scan_state, - const unsigned int number_of_blocks, - detail::vsmem_t vsmem, +ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS(device_params().kernel_config.block_size) void + partition_kernel(KeyIterator keys_input, + ValueIterator values_input, + FlagIterator flags, + OutputKeyIterator keys_output, + OutputValueIterator values_output, + size_t* selected_count, + size_t* prev_selected_count, + size_t prev_processed, + const size_t total_size, + InequalityOp inequality_op, + OffsetLookbackScanState offset_scan_state, + const unsigned int number_of_blocks, + detail::vsmem_t vsmem, + BlockIdWrapper block_id, UnaryPredicates... predicates) { using offset_type = typename OffsetLookbackScanState::value_type; @@ -92,7 +94,8 @@ ROCPRIM_KERNEL key_type, value_type, flag_type, - offset_type>; + offset_type, + BlockIdWrapper>; using VSmemHelperT = detail::vsmem_helper_impl; ROCPRIM_SHARED_MEMORY typename VSmemHelperT::static_temp_storage_t static_temp_storage; @@ -112,6 +115,7 @@ ROCPRIM_KERNEL inequality_op, offset_scan_state, number_of_blocks, + block_id, storage, predicates...); } @@ -122,7 +126,8 @@ template + class OffsetLookbackScanState, + class BlockIdWrapper> inline size_t get_partition_vsmem_size_per_block() { using offset_type = typename OffsetLookbackScanState::value_type; @@ -132,13 +137,15 @@ inline size_t get_partition_vsmem_size_per_block() Key, Value, FlagType, - offset_type>; + offset_type, + BlockIdWrapper>; using PartitionVSmemHelperT = detail::vsmem_helper_impl; return PartitionVSmemHelperT::vsmem_per_block; } template; + + typename block_id_wrapper_type::id_type* block_id_pool = nullptr; + bool use_sleep; ROCPRIM_RETURN_ON_ERROR(is_sleep_scan_state_used(stream, use_sleep)); @@ -245,7 +256,8 @@ inline hipError_t partition_impl(void* temporary_storage, key_type, value_type, flag_type, - offset_scan_state_with_sleep_type>(); + offset_scan_state_with_sleep_type, + block_id_wrapper_type>(); } else { @@ -255,11 +267,13 @@ inline hipError_t partition_impl(void* temporary_storage, key_type, value_type, flag_type, - offset_scan_state_type>(); + offset_scan_state_type, + block_id_wrapper_type>(); } virtual_shared_memory_size *= number_of_blocks; - + // temporary storage partition + result = detail::temp_storage::partition( temporary_storage, storage_size, @@ -271,15 +285,19 @@ inline hipError_t partition_impl(void* temporary_storage, // They have the same base type, so there is no padding between the types. detail::temp_storage::ptr_aligned_array(&selected_count, selected_count_size), detail::temp_storage::ptr_aligned_array(&prev_selected_count, selected_count_size), + detail::temp_storage::ptr_aligned_array(&block_id_pool, block_id_wrapper_type::get_storage_size()), // vsmem detail::temp_storage::make_partition(&vsmem, virtual_shared_memory_size, cache_line_size))); + if(result != hipSuccess || temporary_storage == nullptr) { return result; } + auto block_id = block_id_wrapper_type::create(block_id_pool); + // Start point for time measurements std::chrono::steady_clock::time_point start; @@ -371,6 +389,15 @@ inline hipError_t partition_impl(void* temporary_storage, current_number_of_blocks, start); + if (UsingOrderedBlockId) + { + result = hipMemsetAsync(block_id_pool, 0, sizeof(unsigned int), stream); + if (result != hipSuccess) + { + return result; + } + } + if(debug_synchronous) start = std::chrono::steady_clock::now(); with_scan_state( @@ -391,6 +418,7 @@ inline hipError_t partition_impl(void* temporary_storage, scan_state, current_number_of_blocks, detail::vsmem_t{vsmem}, + block_id, predicates...); }); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", size, start); @@ -515,7 +543,8 @@ template + class Predicate, + bool UsingOrderedBlockId = false> inline hipError_t partition_two_way(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -542,6 +571,7 @@ inline hipError_t partition_two_way(void* temporary_storag const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only return detail::partition_impl(temporary_storage, storage_size, @@ -655,7 +685,8 @@ template + typename SelectedCountOutputIterator, + bool UsingOrderedBlockId = false> inline hipError_t partition_two_way(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -680,6 +711,7 @@ inline hipError_t partition_two_way(void* temporary_storag const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only return detail::partition_impl(temporary_storage, storage_size, @@ -779,8 +811,8 @@ template< class InputIterator, class FlagIterator, class OutputIterator, - class SelectedCountOutputIterator -> + class SelectedCountOutputIterator, + bool UsingOrderedBlockId = false> inline hipError_t partition(void * temporary_storage, size_t& storage_size, @@ -804,20 +836,22 @@ hipError_t partition(void * temporary_storage, using output_value_iterator_tuple = tuple<::rocprim::empty_type*, ::rocprim::empty_type*>; const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only - return detail::partition_impl( - temporary_storage, - storage_size, - input, - no_input_values, - flags, - keys_output, - no_output_values, - selected_count_output, - size, - inequality_op_type(), - stream, - debug_synchronous, - unary_predicate_type()); + return detail::partition_impl(temporary_storage, + storage_size, + input, + no_input_values, + flags, + keys_output, + no_output_values, + selected_count_output, + size, + inequality_op_type(), + stream, + debug_synchronous, + unary_predicate_type()); } /// \brief Parallel select primitive for device level using selection predicate. @@ -910,8 +944,8 @@ template< class InputIterator, class OutputIterator, class SelectedCountOutputIterator, - class UnaryPredicate -> + class UnaryPredicate, + bool UsingOrderedBlockId = false> inline hipError_t partition(void * temporary_storage, size_t& storage_size, @@ -938,6 +972,7 @@ hipError_t partition(void * temporary_storage, const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only return detail::partition_impl(temporary_storage, storage_size, @@ -1087,7 +1122,8 @@ template < typename UnselectedOutputIterator, typename SelectedCountOutputIterator, typename FirstUnaryPredicate, - typename SecondUnaryPredicate> + typename SecondUnaryPredicate, + bool UsingOrderedBlockId = false> inline hipError_t partition_three_way(void * temporary_storage, size_t& storage_size, @@ -1120,6 +1156,7 @@ hipError_t partition_three_way(void * temporary_storage, output_key_iterator_tuple output{ output_first_part, output_second_part, output_unselected }; return detail::partition_impl(temporary_storage, storage_size, diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_select.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_select.hpp index a7d65d1bf75..51cbca6b1e4 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_select.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_select.hpp @@ -125,7 +125,8 @@ template< class InputIterator, class FlagIterator, class OutputIterator, - class SelectedCountOutputIterator + class SelectedCountOutputIterator, + bool UsingOrderedBlockId = false > inline hipError_t select(void * temporary_storage, @@ -151,20 +152,22 @@ hipError_t select(void * temporary_storage, using output_value_iterator_tuple = tuple<::rocprim::empty_type*, ::rocprim::empty_type*>; const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only - return detail::partition_impl( - temporary_storage, - storage_size, - input, - no_values, - flags, - output_tuple, - no_output_values, - selected_count_output, - size, - inequality_op_type(), - stream, - debug_synchronous, - unary_predicate_type()); + return detail::partition_impl(temporary_storage, + storage_size, + input, + no_values, + flags, + output_tuple, + no_output_values, + selected_count_output, + size, + inequality_op_type(), + stream, + debug_synchronous, + unary_predicate_type()); } /// \brief Parallel select primitive for device level using selection operator. @@ -257,8 +260,8 @@ template< class InputIterator, class OutputIterator, class SelectedCountOutputIterator, - class UnaryPredicate -> + class UnaryPredicate, + bool UsingOrderedBlockId = false> inline hipError_t select(void * temporary_storage, size_t& storage_size, @@ -284,20 +287,22 @@ hipError_t select(void * temporary_storage, using output_value_iterator_tuple = tuple<::rocprim::empty_type*, ::rocprim::empty_type*>; const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only - return detail::partition_impl( - temporary_storage, - storage_size, - input, - no_values, - flags, - output_tuple, - no_output_values, - selected_count_output, - size, - inequality_op_type(), - stream, - debug_synchronous, - predicate); + return detail::partition_impl(temporary_storage, + storage_size, + input, + no_values, + flags, + output_tuple, + no_output_values, + selected_count_output, + size, + inequality_op_type(), + stream, + debug_synchronous, + predicate); } /// \brief Parallel select primitive for device level using a range of pre-selected flags. @@ -396,7 +401,8 @@ template + class UnaryPredicate, + bool UsingOrderedBlockId = false> inline hipError_t select(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -420,6 +426,7 @@ inline hipError_t select(void* temporary_storage, const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only return detail::partition_impl(temporary_storage, storage_size, @@ -516,8 +523,8 @@ template< class InputIterator, class OutputIterator, class UniqueCountOutputIterator, - class EqualityOp = ::rocprim::equal_to::value_type> -> + class EqualityOp = ::rocprim::equal_to::value_type>, + bool UsingOrderedBlockId = false> inline hipError_t unique(void * temporary_storage, size_t& storage_size, @@ -546,20 +553,22 @@ hipError_t unique(void * temporary_storage, using output_value_iterator_tuple = tuple<::rocprim::empty_type*, ::rocprim::empty_type*>; const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only - return detail::partition_impl( - temporary_storage, - storage_size, - input, - no_values, - flags, - output_tuple, - no_output_values, - unique_count_output, - size, - inequality_op, - stream, - debug_synchronous, - unary_predicate_type()); + return detail::partition_impl(temporary_storage, + storage_size, + input, + no_values, + flags, + output_tuple, + no_output_values, + unique_count_output, + size, + inequality_op, + stream, + debug_synchronous, + unary_predicate_type()); } /// \brief Device-level parallel unique by key primitive. @@ -615,7 +624,8 @@ template ::value_type>> + = ::rocprim::equal_to::value_type>, + bool UsingOrderedBlockId = false> inline hipError_t unique_by_key(void* temporary_storage, size_t& storage_size, const KeyIterator keys_input, @@ -644,6 +654,7 @@ inline hipError_t unique_by_key(void* temporary_storag const output_value_iterator_tuple output_value_tuple{values_output, nullptr}; return detail::partition_impl(temporary_storage, storage_size,