From f865c073366600016bfeb1c5759cb9eb21460805 Mon Sep 17 00:00:00 2001 From: Nikoli Dryden Date: Fri, 26 Feb 2021 07:17:45 -0800 Subject: [PATCH] Fix NCCL reduce-scatterv. Closes #110. --- include/aluminum/nccl_impl.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/aluminum/nccl_impl.hpp b/include/aluminum/nccl_impl.hpp index 97106501..71589ead 100644 --- a/include/aluminum/nccl_impl.hpp +++ b/include/aluminum/nccl_impl.hpp @@ -1124,12 +1124,10 @@ class NCCLBackend { // Rank 0 is the root. size_t count = std::accumulate(counts.begin(), counts.end(), 0); std::vector<size_t> displs = excl_prefix_sum(counts); - // Need a temporary reduce buffer when we can't trash it. - T* tmp_redbuf = recvbuf; + // Need a temporary reduce buffer so we don't trash the entire thing. + T* tmp_redbuf = internal::get_gpu_memory<T>(count, stream); if (sendbuf == internal::IN_PLACE<T>()) { sendbuf = recvbuf; - } else { - tmp_redbuf = internal::get_gpu_memory<T>(count, stream); } do_reduce(sendbuf, tmp_redbuf, count, op, 0, comm, stream); do_scatterv(tmp_redbuf, recvbuf, counts, displs, 0, comm, stream);