refactor and cleanup of mhp::sort() #664

Closed
wants to merge 4 commits into from
198 changes: 110 additions & 88 deletions include/dr/mhp/algorithms/sort.hpp
@@ -47,30 +47,25 @@ void local_sort(R &r, Compare &&comp) {
}
}

// TODO: quite a long function, refactor to make the code more clear
template <dr::distributed_range R, typename Compare>
void dist_sort(R &r, Compare &&comp) {
using valT = typename R::value_type;
/* helper routines factored out of dist_sort */
template <typename valT, typename Compare, typename Seg>
void splitters(Seg &lsegment, Compare &&comp,
std::vector<std::size_t> &vec_split_i,
std::vector<std::size_t> &vec_split_s) {

const std::size_t _comm_rank = default_comm().rank();
const std::size_t _comm_size = default_comm().size(); // dr-style ignore

auto &&lsegment = local_segment(r);
/* sort local segment */

__detail::local_sort(lsegment, comp);

std::vector<valT> vec_split_v(_comm_size - 1);
std::vector<valT> vec_lmedians(_comm_size + 1);
std::vector<valT> vec_gmedians((_comm_size + 1) * _comm_size);

const double _step_m = static_cast<double>(rng::size(lsegment)) /
static_cast<double>(_comm_size);

/* calculate splitting values and indices - find n-1 dividers splitting each
* segment into equal parts */
/* calculate splitting values and indices - find n-1 dividers splitting
* each segment into equal parts */

for (std::size_t _i = 0; _i < rng::size(vec_lmedians); _i++) {
// vec_lmedians[_i] = lsegment[(_i + 1) * _step_m];
vec_lmedians[_i] = lsegment[_i * _step_m];
}
vec_lmedians.back() = lsegment.back();
@@ -79,20 +74,10 @@ void dist_sort(R &r, Compare &&comp) {

rng::sort(rng::begin(vec_gmedians), rng::end(vec_gmedians), comp);

/* find splitting values - medians of dividers */

std::vector<valT> vec_split_v(_comm_size - 1);

for (std::size_t _i = 0; _i < _comm_size - 1; _i++) {
vec_split_v[_i] = vec_gmedians[(_i + 1) * (_comm_size + 1) - 1];
}

/* calculate splitting indices (start of buffers) and sizes of buffers to send
*/

std::vector<std::size_t> vec_split_i(_comm_size, 0);
std::vector<std::size_t> vec_split_s(_comm_size, 0);

std::size_t segidx = 0, vidx = 1;

while (vidx < _comm_size && segidx < rng::size(lsegment)) {
@@ -107,77 +92,18 @@ void dist_sort(R &r, Compare &&comp) {
}
assert(rng::size(lsegment) > vec_split_i[vidx - 1]);
vec_split_s[vidx - 1] = rng::size(lsegment) - vec_split_i[vidx - 1];
}

/* send data size to each node */
std::vector<std::size_t> vec_rsizes(_comm_size, 0);
std::vector<std::size_t> vec_rindices(_comm_size, 0); // recv buffers

default_comm().alltoall(vec_split_s, vec_rsizes, 1);

std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(),
vec_rindices.begin(), 0);

// const std::size_t _recv_elems =
// std::reduce(vec_rsizes.begin(), vec_rsizes.end());

const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back();

/* send and receive data belonging to each node, then redistribute
* data to achieve size of data equal to size of local segment */

std::vector<std::size_t> vec_recv_elems(_comm_size);
MPI_Request req_recvelems;
MPI_Status stat_recvelemes;

default_comm().i_all_gather(_recv_elems, vec_recv_elems, &req_recvelems);

#ifdef SYCL_LANGUAGE_VERSION
auto policy = dpl_policy();
sycl::usm_allocator<valT, sycl::usm::alloc::host> alloc(policy.queue());
std::vector<valT, decltype(alloc)> vec_recvdata(_recv_elems, alloc);
#else
std::vector<valT> vec_recvdata(_recv_elems);
#endif

default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
vec_rsizes, vec_rindices);

/* vec recvdata is partially sorted, implementation of merge on GPU is
* desirable */
__detail::local_sort(vec_recvdata, comp);

MPI_Wait(&req_recvelems, &stat_recvelemes);

const std::size_t _total_elems =
std::reduce(vec_recv_elems.begin(), vec_recv_elems.end());

assert(_total_elems == rng::size(r));

std::vector<int> vec_shift(_comm_size - 1);

const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size;
template <typename valT>
void shift_data(const int shift_left, const int shift_right, auto &vec_recvdata,
auto &vec_left, auto &vec_right) {

vec_shift[0] = desired_elems_num - vec_recv_elems[0];
for (std::size_t _i = 1; _i < _comm_size - 1; _i++) {
vec_shift[_i] = vec_shift[_i - 1] + desired_elems_num - vec_recv_elems[_i];
}

const int shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1];
const int shift_right =
_comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank];
const std::size_t _comm_rank = default_comm().rank();

MPI_Request req_l, req_r;
MPI_Status stat_l, stat_r;
const communicator::tag t = communicator::tag::halo_index;

#ifdef SYCL_LANGUAGE_VERSION
std::vector<valT, decltype(alloc)> vec_left(std::max(shift_left, 0), alloc);
std::vector<valT, decltype(alloc)> vec_right(std::max(shift_right, 0), alloc);
#else
std::vector<valT> vec_left(std::max(shift_left, 0));
std::vector<valT> vec_right(std::max(shift_right, 0));
#endif

if (static_cast<int>(rng::size(vec_recvdata)) < -shift_left) {
// Too little data in recv buffer to shift left - first get from right, then
// send left
@@ -223,7 +149,11 @@ void dist_sort(R &r, Compare &&comp) {
if (shift_right != 0)
MPI_Wait(&req_r, &stat_r);
}
}

template <typename valT>
void copy_results(auto &lsegment, const int shift_left, const int shift_right,
auto &vec_recvdata, auto &vec_left, auto &vec_right) {
const std::size_t invalidate_left = std::max(-shift_left, 0);
const std::size_t invalidate_right = std::max(-shift_right, 0);

@@ -243,7 +173,6 @@ void dist_sort(R &r, Compare &&comp) {
lsegment.data() + size_l + size_d, size_r);
e_d = sycl_queue().copy(vec_recvdata.data() + invalidate_left,
lsegment.data() + size_l, size_d);

if (size_l > 0)
e_l.wait();
if (size_r > 0)
@@ -263,7 +192,100 @@ void dist_sort(R &r, Compare &&comp) {
std::memcpy(lsegment.data() + size_l, vec_recvdata.data() + invalidate_left,
size_d * sizeof(valT));
}
}

template <dr::distributed_range R, typename Compare>
void dist_sort(R &r, Compare &&comp) {

using valT = typename R::value_type;

const std::size_t _comm_rank = default_comm().rank();
const std::size_t _comm_size = default_comm().size(); // dr-style ignore

#ifdef SYCL_LANGUAGE_VERSION
auto policy = dpl_policy();
sycl::usm_allocator<valT, sycl::usm::alloc::host> alloc(policy.queue());
#endif

auto &&lsegment = local_segment(r);

std::vector<std::size_t> vec_split_i(_comm_size, 0);
std::vector<std::size_t> vec_split_s(_comm_size, 0);
std::vector<std::size_t> vec_rsizes(_comm_size, 0);
std::vector<std::size_t> vec_rindices(_comm_size, 0);
std::vector<std::size_t> vec_recv_elems(_comm_size, 0);
std::size_t _total_elems = 0;

__detail::local_sort(lsegment, comp);

/* find splitting values - limits of areas to send to other processes */
__detail::splitters<valT>(lsegment, comp, vec_split_i, vec_split_s);
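
/* Note on how the splitters are computed: each rank samples _comm_size + 1
 * roughly equally spaced values from its sorted local segment (vec_lmedians),
 * the samples of all ranks are collected into vec_gmedians and sorted, and
 * every (_comm_size + 1)-th value becomes a global splitter. For example,
 * with _comm_size == 3 there are 12 gathered samples and the two splitters
 * are vec_gmedians[3] and vec_gmedians[7], i.e. (_i + 1) * (_comm_size + 1) - 1
 * for _i = 0, 1. */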

default_comm().alltoall(vec_split_s, vec_rsizes, 1);

/* compute receive offsets (displacements) for the incoming data */
std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(),
vec_rindices.begin(), 0);

const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back();
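
/* Illustrative example: with receive counts vec_rsizes = {4, 2, 3}, the
 * exclusive scan gives receive offsets vec_rindices = {0, 4, 6}, and the
 * number of elements received here is vec_rindices.back() +
 * vec_rsizes.back() = 6 + 3 = 9. */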

/* send and receive data belonging to each node, then redistribute the data
 * so that each node ends up with as many elements as its local segment */

MPI_Request req_recvelems;

default_comm().i_all_gather(_recv_elems, vec_recv_elems, &req_recvelems);

/* buffer for received data */
#ifdef SYCL_LANGUAGE_VERSION
std::vector<valT, decltype(alloc)> vec_recvdata(_recv_elems, alloc);
#else
std::vector<valT> vec_recvdata(_recv_elems);
#endif

/* send the data that does not belong to this process and receive the data
 * that does */
default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
vec_rsizes, vec_rindices);

/* TODO: vec_recvdata is partially sorted; an implementation of merge on the
 * GPU is desirable */
__detail::local_sort(vec_recvdata, comp);

MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE);

_total_elems = std::reduce(vec_recv_elems.begin(), vec_recv_elems.end());

/* prepare data for shift to neighboring processes */
std::vector<int> vec_shift(_comm_size - 1);

const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size;

vec_shift[0] = desired_elems_num - vec_recv_elems[0];
for (std::size_t _i = 1; _i < _comm_size - 1; _i++) {
vec_shift[_i] = vec_shift[_i - 1] + desired_elems_num - vec_recv_elems[_i];
}

const int shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1];
const int shift_right =
_comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank];
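
/* A positive shift_left / shift_right is the number of elements this rank
 * receives from that neighbour; a negative value is the number it sends.
 * Example with 3 ranks and vec_recv_elems = {5, 2, 5}: desired_elems_num = 4
 * and vec_shift = {-1, 1}, so rank 0 sends one element to the right, rank 1
 * receives one from each side, rank 2 sends one to the left, and every rank
 * ends up with 4 elements. */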

#ifdef SYCL_LANGUAGE_VERSION
std::vector<valT, decltype(alloc)> vec_left(std::max(shift_left, 0), alloc);
std::vector<valT, decltype(alloc)> vec_right(std::max(shift_right, 0), alloc);
#else
std::vector<valT> vec_left(std::max(shift_left, 0));
std::vector<valT> vec_right(std::max(shift_right, 0));
#endif

/* shift data between neighbours if necessary, so that the number of elements
 * held locally equals the size of the local segment */
__detail::shift_data<valT>(shift_left, shift_right, vec_recvdata, vec_left,
vec_right);

/* copy results into the local segment of the distributed range */
__detail::copy_results<valT>(lsegment, shift_left, shift_right, vec_recvdata,
vec_left, vec_right);
} // __detail::dist_sort

} // namespace __detail
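
For context, a minimal usage sketch of the sort being refactored here. It assumes the mhp runtime API as exposed by this repository (dr/mhp.hpp, dr::mhp::init/finalize, dr::mhp::distributed_vector, dr::mhp::iota, dr::mhp::sort); the exact header and signatures may differ, and on multi-process runs mhp::sort presumably dispatches to the __detail::dist_sort shown above.

#include <functional>

#include <dr/mhp.hpp>

int main() {
  dr::mhp::init();
  {
    // a distributed vector spread across all MPI ranks
    dr::mhp::distributed_vector<int> v(1000);
    dr::mhp::iota(v, 0);
    // sort the whole distributed range with an explicit comparator
    dr::mhp::sort(v, std::greater<int>{});
  }
  dr::mhp::finalize();
  return 0;
}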