
Commit

Merge pull request #59 from LLNL/pipeline-progress
Progress engine pipeline
ndryden authored Nov 12, 2019
2 parents bbb39a3 + f3723ca commit 5f5d615
Showing 19 changed files with 715 additions and 209 deletions.
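This commit reworks the progress engine so that each state's step() method returns a PEAction instead of a bool, which is what allows operations queued on the same stream to be pipelined. The PEAction type itself is defined in one of the changed files that is not shown below; the sketch here is inferred only from the three values used in these diffs (cont, advance, complete) and is an assumption, not the actual definition.

  // Sketch only: names and semantics inferred from the usages in the diffs
  // below; the real definition lives elsewhere in this commit.
  enum class PEAction {
    cont,      // not finished yet; keep calling step() on this state
    advance,   // this stage finished; the next pipelined state may begin
    complete   // the whole operation finished; the engine can retire this state
  };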
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -11,7 +11,7 @@ cmake_minimum_required(VERSION 3.9)

set(ALUMINUM_VERSION_MAJOR 0)
set(ALUMINUM_VERSION_MINOR 3)
-set(ALUMINUM_VERSION_PATCH 1)
+set(ALUMINUM_VERSION_PATCH 2)
set(ALUMINUM_VERSION "${ALUMINUM_VERSION_MAJOR}.${ALUMINUM_VERSION_MINOR}.${ALUMINUM_VERSION_PATCH}")

project(ALUMINUM VERSION ${ALUMINUM_VERSION} LANGUAGES CXX)
27 changes: 15 additions & 12 deletions src/mpi_cuda/allgather.hpp
@@ -76,18 +76,20 @@ class AllgatherAlState : public AlState {
release_pinned_memory(host_mem_);
}

-bool step() override {
-if (!ag_started_) {
-// Check if mem xfer complete
+PEAction step() override {
+if (!mem_xfer_done_) {
if (d2h_event_.query()) {
-MPI_Iallgather(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
-host_mem_, count_, mpi::TypeMap<T>(), comm_, &req_);
-ag_started_ = true;
-}
-else {
-return false;
+mem_xfer_done_ = true;
+return PEAction::advance;
+} else {
+return PEAction::cont;
}
}
+if (!ag_started_) {
+MPI_Iallgather(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
+host_mem_, count_, mpi::TypeMap<T>(), comm_, &req_);
+ag_started_ = true;
+}

if (!ag_done_) {
// Wait for the all2all to complete
@@ -98,16 +100,16 @@ class AllgatherAlState : public AlState {
gpuwait_.signal();
}
else {
-return false;
+return PEAction::cont;
}
}

// Wait for host-to-device memcopy; cleanup
if (h2d_event_.query()) {
-return true;
+return PEAction::complete;
}

-return false;
+return PEAction::cont;
}

bool needs_completion() const override { return false; }
@@ -126,6 +128,7 @@ class AllgatherAlState : public AlState {
MPI_Comm comm_;
MPI_Request req_ = MPI_REQUEST_NULL;

+bool mem_xfer_done_ = false;
bool ag_started_ = false;
bool ag_done_ = false;

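In the new allgather state, step() no longer waits for the device-to-host copy and posts the MPI call in the same branch: it returns PEAction::advance exactly once when the copy has finished, and only on a later call posts MPI_Iallgather. Returning advance presumably signals the progress engine that the next operation queued behind this one may start stepping, which is the pipelining this PR adds; the engine itself is not part of the diff shown here. The loop below is only an illustrative sketch of how such return values could be consumed, with every name invented for the example.

  // Illustrative sketch only, not Aluminum's progress engine.
  #include <algorithm>
  #include <memory>
  #include <vector>

  struct StateSketch {                  // stand-in for AlState in this example
    virtual ~StateSketch() = default;
    virtual PEAction step() = 0;        // uses the PEAction sketch from above
  };

  void poll_pipeline(std::vector<std::unique_ptr<StateSketch>>& pipeline) {
    // Only the first `runnable` states are polled; a state returning
    // PEAction::advance unblocks the state queued behind it.
    size_t runnable = 1;
    while (!pipeline.empty()) {
      runnable = std::min(std::max<size_t>(runnable, 1), pipeline.size());
      for (size_t i = 0; i < runnable; ++i) {
        PEAction action = pipeline[i]->step();
        if (action == PEAction::advance) {
          runnable = std::min(runnable + 1, pipeline.size());
        } else if (action == PEAction::complete) {
          pipeline.erase(pipeline.begin() + i);
          --runnable;
          break;  // rescan with the shortened pipeline
        }
        // PEAction::cont: nothing to do now; poll this state again next pass.
      }
    }
  }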
24 changes: 14 additions & 10 deletions src/mpi_cuda/allreduce.hpp
@@ -86,31 +86,34 @@ class HostTransferState : public AlState {
release_pinned_memory(host_mem);
}

-bool step() override {
-if (!ar_started) {
-// Wait for memory to get to the host.
+PEAction step() override {
+if (!mem_xfer_done) {
if (d2h_event.query()) {
-host_ar->setup();
-ar_started = true;
+mem_xfer_done = true;
+return PEAction::advance;
} else {
-return false;
+return PEAction::cont;
}
}
+if (!ar_started) {
+host_ar->setup();
+ar_started = true;
+}
if (!ar_done) {
// Wait for the allreduce to complete.
-if (host_ar->step()) {
+if (host_ar->step() == PEAction::complete) {
ar_done = true;
// Mark the sync as done to wake up the device.
gpu_wait.signal();
} else {
-return false;
+return PEAction::cont;
}
}
// Wait for the memcpy back to device to complete so we can clean up.
if (h2d_event.query()) {
-return true;
+return PEAction::complete;
}
-return false;
+return PEAction::cont;
}
bool needs_completion() const override { return false; }
void* get_compute_stream() const override { return compute_stream; }
@@ -124,6 +127,7 @@ class HostTransferState : public AlState {
mpi::MPIAlState<T>* host_ar;
cuda::FastEvent d2h_event, h2d_event;
cuda::GPUWait gpu_wait;
+bool mem_xfer_done = false;
bool ar_started = false;
bool ar_done = false;
cudaStream_t compute_stream;
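The host-transfer allreduce wrapper gets the same restructuring, and because it drives a nested host MPI state it now compares host_ar->step() against PEAction::complete. All of these states lean on cuda::FastEvent and cuda::GPUWait, Aluminum internals that the diff uses but never shows: d2h_event.query() gates the transition out of the first stage, and gpu_wait.signal() releases the GPU stream that is blocked until the host-side work finishes. The sketch below is only a guess at the kind of primitives such classes wrap (a timing-disabled CUDA event polled from the host, and a flag in mapped pinned memory that a small kernel spins on); it is not Aluminum's implementation.

  // Hypothetical sketch of event-query and GPU-wait primitives; Aluminum's
  // cuda::FastEvent and cuda::GPUWait may be implemented quite differently.
  #include <cuda_runtime.h>

  __global__ void wait_for_flag(volatile int* flag) {
    while (*flag == 0) { /* spin until the host writes 1 */ }
  }

  class FastEventSketch {
   public:
    FastEventSketch() { cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming); }
    ~FastEventSketch() { cudaEventDestroy(ev_); }
    void record(cudaStream_t stream) { cudaEventRecord(ev_, stream); }
    // True once all work captured by record() has finished.
    bool query() { return cudaEventQuery(ev_) == cudaSuccess; }
   private:
    cudaEvent_t ev_;
  };

  class GPUWaitSketch {
   public:
    GPUWaitSketch() {
      cudaHostAlloc(reinterpret_cast<void**>(&host_flag_), sizeof(int),
                    cudaHostAllocMapped);
      *host_flag_ = 0;
      cudaHostGetDevicePointer(reinterpret_cast<void**>(&dev_flag_), host_flag_, 0);
    }
    ~GPUWaitSketch() { cudaFreeHost(host_flag_); }
    // Enqueue a wait; the stream stalls at this point until signal() is called.
    void wait(cudaStream_t stream) { wait_for_flag<<<1, 1, 0, stream>>>(dev_flag_); }
    // Called from the host (e.g. from step()) to release the waiting stream.
    void signal() { *static_cast<volatile int*>(host_flag_) = 1; }
   private:
    int* host_flag_ = nullptr;
    int* dev_flag_ = nullptr;
  };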
27 changes: 15 additions & 12 deletions src/mpi_cuda/alltoall.hpp
@@ -65,18 +65,20 @@ class AlltoallAlState : public AlState {
release_pinned_memory(host_mem_);
}

-bool step() override {
-if (!a2a_started_) {
-// Check if mem xfer complete
+PEAction step() override {
+if (!mem_xfer_done_) {
if (d2h_event_.query()) {
-MPI_Ialltoall(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
-host_mem_, count_, mpi::TypeMap<T>(), comm_, &req_);
-a2a_started_ = true;
-}
-else {
-return false;
+mem_xfer_done_ = true;
+return PEAction::advance;
+} else {
+return PEAction::cont;
}
}
+if (!a2a_started_) {
+MPI_Ialltoall(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
+host_mem_, count_, mpi::TypeMap<T>(), comm_, &req_);
+a2a_started_ = true;
+}

if (!a2a_done_) {
// Wait for the all2all to complete
@@ -87,16 +89,16 @@ class AlltoallAlState : public AlState {
gpuwait_.signal();
}
else {
-return false;
+return PEAction::cont;
}
}

// Wait for host-to-device memcopy; cleanup
if (h2d_event_.query()) {
-return true;
+return PEAction::complete;
}

-return false;
+return PEAction::cont;
}

bool needs_completion() const override { return false; }
@@ -115,6 +117,7 @@ class AlltoallAlState : public AlState {
MPI_Comm comm_;
MPI_Request req_ = MPI_REQUEST_NULL;

+bool mem_xfer_done_ = false;
bool a2a_started_ = false;
bool a2a_done_ = false;

31 changes: 18 additions & 13 deletions src/mpi_cuda/bcast.hpp
@@ -70,18 +70,22 @@ class BcastAlState : public AlState {
release_pinned_memory(host_mem_);
}

-bool step() override {
-if (!bcast_started_) {
+PEAction step() override {
+if (!mem_xfer_done_) {
if (d2h_event_.query()) {
-MPI_Ibcast(host_mem_, count_, mpi::TypeMap<T>(),
-root_, comm_, &req_);
-if (rank_ == root_) {
-gpuwait_.signal();
-}
-bcast_started_ = true;
+mem_xfer_done_ = true;
+return PEAction::advance;
} else {
-return false;
+return PEAction::cont;
}
}
+if (!bcast_started_) {
+MPI_Ibcast(host_mem_, count_, mpi::TypeMap<T>(),
+root_, comm_, &req_);
+if (rank_ == root_) {
+gpuwait_.signal();
+}
+bcast_started_ = true;
+}

if (!bcast_done_) {
@@ -93,19 +97,19 @@
if (rank_ != root_) {
gpuwait_.signal();
} else {
-return true;
+return PEAction::complete;
}
}
else {
-return false;
+return PEAction::cont;
}
}

// Wait for host-to-device memcopy; cleanup
if (h2d_event_.query()) {
-return true;
+return PEAction::complete;
}
-return false;
+return PEAction::cont;
}

bool needs_completion() const override { return false; }
@@ -126,6 +130,7 @@
MPI_Comm comm_;
MPI_Request req_ = MPI_REQUEST_NULL;

+bool mem_xfer_done_ = false;
bool bcast_started_ = false;
bool bcast_done_ = false;

45 changes: 24 additions & 21 deletions src/mpi_cuda/gather.hpp
@@ -74,25 +74,27 @@ class GatherAlState : public AlState {
release_pinned_memory(host_mem_);
}

-bool step() override {
-if (!gather_started_) {
-// Check if mem xfer complete
+PEAction step() override {
+if (!mem_xfer_done_) {
if (d2h_event_.query()) {
-if (root_ == rank_) {
-MPI_Igather(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
-host_mem_, count_, mpi::TypeMap<T>(),
-root_, comm_, &req_);
-} else {
-MPI_Igather(host_mem_, count_, mpi::TypeMap<T>(),
-host_mem_, count_, mpi::TypeMap<T>(),
-root_, comm_, &req_);
-gpuwait_.signal();
-}
-gather_started_ = true;
+mem_xfer_done_ = true;
+return PEAction::advance;
+} else {
+return PEAction::cont;
}
-else {
-return false;
}
+if (!gather_started_) {
+if (root_ == rank_) {
+MPI_Igather(MPI_IN_PLACE, count_, mpi::TypeMap<T>(),
+host_mem_, count_, mpi::TypeMap<T>(),
+root_, comm_, &req_);
+} else {
+MPI_Igather(host_mem_, count_, mpi::TypeMap<T>(),
+host_mem_, count_, mpi::TypeMap<T>(),
+root_, comm_, &req_);
+gpuwait_.signal();
+}
+gather_started_ = true;
}

if (!gather_done_) {
@@ -104,24 +106,24 @@
if (rank_ == root_)
gpuwait_.signal();
else
-return true;
+return PEAction::complete;
}
else {
-return false;
+return PEAction::cont;
}
}
else if (rank_ != root_) {
// Paranoia, in case step() is ever called again after returning
// 'true' for the first time.
-return true;
+return PEAction::complete;
}

// Wait for host-to-device memcopy; cleanup
if (h2d_event_.query()) {
-return true;
+return PEAction::complete;
}

-return false;
+return PEAction::cont;
}

bool needs_completion() const override { return false; }
@@ -142,6 +144,7 @@
MPI_Comm comm_;
MPI_Request req_ = MPI_REQUEST_NULL;

+bool mem_xfer_done_ = false;
bool gather_started_ = false;
bool gather_done_ = false;

(Diffs for the remaining 13 changed files are not shown.)

0 comments on commit 5f5d615
