Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions projects/rocprim/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec
* Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag.
* Added tests and documentation to `lookback_scan_state`. It is still in the `detail` namespace.

### Optimizations

* Improved performance of `rocprim::device_select` and `rocprim::device_partition` when using multiple streams on the MI3XX architecture.

### Changed

* Changed the parameters `long_radix_bits` and `LongRadixBits` from `segmented_radix_sort` to `radix_bits` and `RadixBits` respectively.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
#include "../../block/block_store.hpp"
#include "../../block/block_scan.hpp"
#include "../../block/block_discontinuity.hpp"
#include "ordered_block_id.hpp"

#include "../config_types.hpp"
#include "device_config_helper.hpp"
#include "lookback_scan_state.hpp"
#include "rocprim/intrinsics/thread.hpp"
#include "rocprim/type_traits.hpp"
#include "rocprim/types/tuple.hpp"

Expand Down Expand Up @@ -914,7 +916,8 @@ template<select_method SelectMethod,
class Key,
class Value,
class FlagType,
class OffsetType>
class OffsetType,
class BlockIdWrapper>
struct partition_kernel_impl_
{

Expand All @@ -934,6 +937,7 @@ struct partition_kernel_impl_
using block_scan_offset_type
= ::rocprim::block_scan<OffsetType, block_size, params.block_scan_method>;
using block_discontinuity_key_type = ::rocprim::block_discontinuity<Key, block_size>;
using ordered_block_id = BlockIdWrapper;

// Memory required for 2-phase scatter
using exchange_keys_storage_type = Key[items_per_block];
Expand All @@ -952,6 +956,7 @@ struct partition_kernel_impl_
typename block_load_flag_type::storage_type load_flags;
typename block_discontinuity_key_type::storage_type discontinuity_values;
typename block_scan_offset_type::storage_type scan_offsets;
typename ordered_block_id::storage_type block_id;
};

template<
Expand All @@ -976,6 +981,7 @@ struct partition_kernel_impl_
InequalityOp,
OffsetLookbackScanState,
const unsigned int,
ordered_block_id,
storage_type&,
UnaryPredicates...)
-> std::enable_if_t<!is_lookback_kernel_runnable<OffsetLookbackScanState>()>
Expand Down Expand Up @@ -1005,6 +1011,7 @@ struct partition_kernel_impl_
InequalityOp inequality_op,
OffsetLookbackScanState offset_scan_state,
const unsigned int number_of_blocks,
ordered_block_id block_id,
storage_type& storage,
UnaryPredicates... predicates)
-> std::enable_if_t<is_lookback_kernel_runnable<OffsetLookbackScanState>()>
Expand All @@ -1023,9 +1030,11 @@ struct partition_kernel_impl_
size_t prev_selected_count_values[sizeof...(UnaryPredicates)]{};
load_selected_count(prev_selected_count, prev_selected_count_values);

const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
const auto flat_block_id = ::rocprim::detail::block_id<0>();
const auto block_offset = flat_block_id * items_per_block;
const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
const auto flat_block_id = block_id.get(flat_block_thread_id, storage.block_id);
::rocprim::syncthreads(); // sync threads to reuse shared memory

const auto block_offset = flat_block_id * items_per_block;
const unsigned int valid_in_global_last_block
= total_size - prev_processed - items_per_block * (number_of_blocks - 1);
const bool is_last_launch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
#define ROCPRIM_DEVICE_DETAIL_ORDERED_BLOCK_ID_HPP_

#include <type_traits>
#include <limits>

#include "../../detail/temp_storage.hpp"
#include "../../detail/various.hpp"
#include "../../intrinsics.hpp"
#include "../../types.hpp"
#include "../../intrinsics/atomic.hpp"
#include "../../intrinsics/thread.hpp"

BEGIN_ROCPRIM_NAMESPACE

Expand Down Expand Up @@ -86,6 +84,104 @@ struct ordered_block_id
id_type* id;
};

template<class T = unsigned int, bool UsingOrderedBlockId = false>
struct block_id_wrapper;

template<class T>
struct block_id_wrapper<T, false>
{
using id_type = T;

// shared memory temporary storage type
struct storage_type
{};

ROCPRIM_HOST
static inline block_id_wrapper create(id_type* /*id*/)
{
block_id_wrapper ordered_id;
return ordered_id;
}

ROCPRIM_HOST
static inline size_t get_storage_size()
{
return 0;
}

ROCPRIM_HOST
static inline detail::temp_storage::layout get_temp_storage_layout()
{
return detail::temp_storage::layout{get_storage_size(), 0};
}

ROCPRIM_DEVICE ROCPRIM_INLINE
void reset()
{}

ROCPRIM_HOST ROCPRIM_INLINE
hipError_t reset_from_host(const hipStream_t /*stream*/)
{
return hipSuccess;
}

ROCPRIM_DEVICE ROCPRIM_INLINE
id_type get(unsigned int /*tid*/, storage_type& /*storage*/)
{
return ::rocprim::detail::block_id<0>();
}
};

template<class T>
struct block_id_wrapper<T, true>
{
using id_type = T;

using storage_type = typename ::rocprim::detail::ordered_block_id<id_type>::storage_type;

ROCPRIM_HOST
static inline block_id_wrapper create(id_type* id)
{
block_id_wrapper id_wrapper;
id_wrapper.ordered_id = detail::ordered_block_id<id_type>::create(id);
return id_wrapper;
}

ROCPRIM_HOST
static inline size_t get_storage_size()
{
return ::rocprim::detail::ordered_block_id<id_type>::get_storage_size();
}

ROCPRIM_HOST
static inline detail::temp_storage::layout get_temp_storage_layout()
{
return ::rocprim::detail::ordered_block_id<id_type>::get_temp_storage_layout();
}

ROCPRIM_DEVICE ROCPRIM_INLINE
void reset()
{
ordered_id.reset();
}

ROCPRIM_HOST ROCPRIM_INLINE
hipError_t reset_from_host(const hipStream_t stream)
{
return hipMemsetAsync(ordered_id.id, 0, sizeof(id_type), stream);
}

ROCPRIM_DEVICE ROCPRIM_INLINE
id_type get(unsigned int tid, storage_type& storage)
{
auto id = ordered_id.get(tid, storage);
::rocprim::syncthreads();
return id;
}

::rocprim::detail::ordered_block_id<id_type> ordered_id;
};

} // end of detail namespace

END_ROCPRIM_NAMESPACE
Expand Down
Loading