Skip to content

Commit

Permalink
Fix MPI-CUDA and RMA builds.
Browse files Browse the repository at this point in the history
  • Loading branch information
ndryden committed Mar 2, 2023
1 parent b98fb0b commit 5116016
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions include/aluminum/mpi_cuda/rma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ class RMA {
std::map<int, Connection *> m_connections;

void wait_for_completion(mpi::AlMPIReq& req) {
if (req == NULL_REQUEST) {
if (req == MPIBackend::null_req) {
return;
}
while (!req->load(std::memory_order_acquire)) {}
req = NULL_REQUEST;
req = MPIBackend::null_req;
}

public:
Expand Down Expand Up @@ -148,21 +148,21 @@ class RMA {

void notify(int dst_rank) {
auto conn = get_connection(dst_rank);
mpi::AlMPIReq req = get_free_request();
mpi::AlMPIReq req = mpi::get_free_request();
conn->notify(req);
wait_for_completion(req);
}

void wait(int dst_rank) {
auto conn = get_connection(dst_rank);
mpi::AlMPIReq req = get_free_request();
mpi::AlMPIReq req = mpi::get_free_request();
conn->wait(req);
wait_for_completion(req);
}

void sync(int peer) {
auto conn = get_connection(peer);
mpi::AlMPIReq req = get_free_request();
mpi::AlMPIReq req = mpi::get_free_request();
conn->sync(req);
wait_for_completion(req);
}
Expand All @@ -171,12 +171,12 @@ class RMA {
mpi::AlMPIReq *requests = new mpi::AlMPIReq[num_peers];
for (int i = 0; i < num_peers; ++i) {
auto conn = get_connection(peers[i]);
mpi::AlMPIReq req = get_free_request();
mpi::AlMPIReq req = mpi::get_free_request();
conn->sync(req);
requests[i] = req;
}
for (int i = 0; i < num_peers; ++i) {
wait_for_completion(req);
wait_for_completion(requests[i]);
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/test_utils_mpi_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ CommWrapper<Al::MPICUDABackend>::CommWrapper(MPI_Comm mpi_comm) {
template <>
void complete_operations<Al::MPICUDABackend>(
typename Al::MPICUDABackend::comm_type& comm) {
AL_FORCE_CHECK_CUDA_NOSYNC(cudaStreamSynchronize(comm.get_stream()));
AL_FORCE_CHECK_GPU_NOSYNC(cudaStreamSynchronize(comm.get_stream()));
}

0 comments on commit 5116016

Please sign in to comment.