Skip to content

Commit

Permalink
Fix NCCL reduce-scatterv. Closes #110.
Browse files Browse the repository at this point in the history
  • Loading branch information
ndryden committed Mar 5, 2021
1 parent 7ae563c commit f865c07
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions include/aluminum/nccl_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1124,12 +1124,10 @@ class NCCLBackend {
// Rank 0 is the root.
size_t count = std::accumulate(counts.begin(), counts.end(), 0);
std::vector<size_t> displs = excl_prefix_sum(counts);
// Need a temporary reduce buffer when we can't trash it.
T* tmp_redbuf = recvbuf;
// Need a temporary reduce buffer so we don't trash the entire thing.
T* tmp_redbuf = internal::get_gpu_memory<T>(count, stream);
if (sendbuf == internal::IN_PLACE<T>()) {
sendbuf = recvbuf;
} else {
tmp_redbuf = internal::get_gpu_memory<T>(count, stream);
}
do_reduce(sendbuf, tmp_redbuf, count, op, 0, comm, stream);
do_scatterv(tmp_redbuf, recvbuf, counts, displs, 0, comm, stream);
Expand Down

0 comments on commit f865c07

Please sign in to comment.