Commit

… move-compare-op-to-phi
From00 committed Mar 2, 2022
2 parents 01e3b8e + 0925804 commit a345a80
Showing 272 changed files with 10,021 additions and 5,120 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -49,6 +49,9 @@ tools/__pycache__
# This file is automatically generated.
# TODO(zhiqiang) Move this file to build directory.
paddle/infrt/dialect/pd_ops.td
paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td
paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td
tools/infrt/kernels.json
paddle/infrt/dialect/pd_ops_info.h
.lit_test_times.txt
paddle/infrt/tests/dialect/Output
20 changes: 19 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroup.h
@@ -96,7 +96,25 @@ class ProcessGroup {
std::vector<Tensor>& /* tensors */,
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
"ProcessGroup%s does not support broadcast", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<Tensor>& tensors /* tensors */, int dst_rank) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<Tensor>& tensors /* tensors */, int src_rank) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}

protected:
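Editor's note, not part of the commit: the three new virtuals follow the same fallback pattern as the existing collectives — the base class throws InvalidArgument until a backend overrides them. Below is a hedged caller-side sketch, assuming `pg` and `rank` come from the surrounding program, that the collective API's Tensor is the paddle::experimental::Tensor alias, and reusing the `full` helper that this commit's Barrier itself calls.

// Hedged sketch: pg is a concrete backend (e.g. ProcessGroupNCCL) and rank is
// this process's rank; both are assumed to be set up elsewhere, and the
// ProcessGroup / phi API headers shown in this commit are assumed included.
void PointToPointSketch(std::shared_ptr<paddle::distributed::ProcessGroup> pg,
                        int rank) {
  std::vector<paddle::experimental::Tensor> buf;
  buf.push_back(paddle::experimental::full({1}, 0, phi::DataType::FLOAT32,
                                           phi::Backend::GPU));
  if (rank == 0) {
    auto task = pg->Send(buf, /*dst_rank=*/1);  // throws InvalidArgument if
                                                // the backend lacks Send
  } else if (rank == 1) {
    auto task = pg->Recv(buf, /*src_rank=*/0);
  }
  pg->Barrier();  // default-constructed BarrierOptions
}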
156 changes: 156 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -14,6 +14,9 @@

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"

DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator);
@@ -139,6 +142,14 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout));
}
}

if (!barrierTensors_.empty()) {
// If the work is used as a barrier, we should also block the CPU
for (auto& place : places_) {
platform::CUDADeviceGuard gpuGuard(place);
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
}
return true;
}

@@ -193,6 +204,10 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
nccl_ids.resize(1);
auto& nccl_id = nccl_ids.front();

for (auto& place : places) {
used_place_ids_.insert(place.GetDeviceId());
}

if (rank_ == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGetUniqueId(&nccl_id));
}
@@ -274,6 +289,54 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
return task;
}

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
const auto places = GetPlaceList(tensors);
const auto key = GetKeyFromPlaces(places);

{
std::lock_guard<std::mutex> lock(mutex_);
if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
CreateNCCLManagerCache(key, places);
}
}

auto& nccl_comms = places_to_ncclcomm_[key];

SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);

auto task = CreateTask(places, rank_, op_type, tensors);

// construct an uninitialized device guard
platform::CUDADeviceGuard cuda_guard;

if (FLAGS_use_stream_safe_cuda_allocator) {
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
memory::RecordStream(dense_tensor->Holder(),
places_to_ctx_[key][i]->stream());
}
}

{
platform::NCCLGroupGuard nccl_guard;
for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
const auto& nccl_stream = places_to_ctx_[key][i]->stream();
fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
}
}

for (size_t i = 0; i < tensors.size(); ++i) {
cuda_guard.SetDevice(places[i]);
task->control_events_[i].Record(*places_to_ctx_[key][i]);
}
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
PADDLE_ENFORCE_EQ(
@@ -317,5 +380,98 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
CommType::BROADCAST);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) {
std::vector<phi::GPUPlace> places;

if (!opts.place_ids.empty()) {
for (auto place_id : opts.place_ids) {
places.emplace_back(place_id);
}
} else if (!used_place_ids_.empty()) {
for (auto place_id : used_place_ids_) {
places.emplace_back(place_id);
}
} else {
auto numGPUs = GetSize();
int place_id = static_cast<int>(rank_ % numGPUs);
places.emplace_back(place_id);
}

std::vector<Tensor> barrierTensors;
barrierTensors.reserve(places.size());

platform::CUDADeviceGuard gpuGuard;
for (auto& place : places) {
gpuGuard.SetDeviceIndex(place.GetDeviceId());
auto dt = full({1}, 0, phi::DataType::FLOAT32, phi::Backend::GPU);
barrierTensors.push_back(dt);
}
auto task = ProcessGroupNCCL::AllReduce(barrierTensors);
auto nccl_task = dynamic_cast<ProcessGroupNCCL::NCCLTask*>(task.get());
nccl_task->barrierTensors_ = std::move(barrierTensors);
return task;
}

void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
const size_t num_devices) {
PADDLE_ENFORCE_EQ(
tensors.size() == 0, false,
platform::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE(
tensors.size(), num_devices,
platform::errors::InvalidArgument(
"Tensor list mustn't be larger than the number of available GPUs."));

std::set<Place> used_devices;

for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(t.is_cuda() && t.is_dense_tensor(), true,
platform::errors::InvalidArgument(
"Tensors must be CUDA and dense tensor."));

const auto inserted = used_devices.insert(t.inner_place()).second;
PADDLE_ENFORCE_EQ(inserted, true,
platform::errors::InvalidArgument(
"Tensors must be on distinct GPU devices."));
}
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
std::vector<Tensor>& tensors, int dst_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

auto task = PointToPoint(
tensors,
[&](Tensor& input, ncclComm_t comm, const gpuStream_t& stream,
int dst_rank) {
auto input_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
return platform::dynload::ncclSend(
input_tensor->data(), input_tensor->numel(),
platform::ToNCCLDataType(input.type()), dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
std::vector<Tensor>& tensors, int src_rank) {
CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

auto task = PointToPoint(
tensors,
[&](Tensor& output, ncclComm_t comm, const gpuStream_t& stream,
int src_rank) {
auto output_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
return platform::dynload::ncclRecv(
output_tensor->data(), output_tensor->numel(),
platform::ToNCCLDataType(output.type()), src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}

} // namespace distributed
} // namespace paddle
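Editor's note, not part of the commit: Barrier above works by all-reducing a dummy one-element GPU tensor on each participating device and stashing those tensors in the returned task, so that NCCLTask::Wait additionally calls cudaDeviceSynchronize. A hedged sketch of the same synchronization done by hand through the public API; pg is assumed to be an initialized ProcessGroupNCCL, and AllreduceOptions is the existing struct from Types.h.

// Hedged sketch only; mirrors what Barrier plus the barrierTensors_ branch of
// NCCLTask::Wait do, using calls that appear in this file.
void ManualBarrierSketch(paddle::distributed::ProcessGroupNCCL& pg) {
  std::vector<paddle::experimental::Tensor> dummy;
  dummy.push_back(paddle::experimental::full({1}, 0, phi::DataType::FLOAT32,
                                             phi::Backend::GPU));
  auto task = pg.AllReduce(dummy, paddle::distributed::AllreduceOptions());
  // The extra host-side block that Wait performs when barrierTensors_ is set:
  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}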
17 changes: 17 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -65,6 +65,7 @@ class ProcessGroupNCCL : public ProcessGroup {
virtual ~NCCLTask();

std::vector<EventManager> control_events_;
std::vector<Tensor> barrierTensors_;

protected:
std::vector<Place> places_;
@@ -88,6 +89,15 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<Tensor>& tensors,
const BroadcastOptions& = BroadcastOptions()) override;

std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) override;

std::shared_ptr<ProcessGroup::Task> Send(std::vector<Tensor>& tensors,
int dst_rank) override;

std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
int src_rank) override;

protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places, int rank, CommType opType,
@@ -106,6 +116,8 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<std::unique_ptr<CUDADeviceContext>>>
places_to_ctx_;

std::set<int> used_place_ids_;

private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root, // NOLINT
int server_fd);
@@ -118,6 +130,11 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<Tensor>& outputs, // NOLINT
Fn fn, CommType op_type);

template <typename Fn>
std::shared_ptr<ProcessGroup::Task> PointToPoint(
std::vector<Tensor>& tensors, // NOLINT
Fn fn, int dst_rank, CommType op_type);

void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
};
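Editor's note, not part of the commit: the new PointToPoint template mirrors Collective but dispatches to a single peer rank. The shape of its Fn callable is not spelled out in the header; inferred from the Send/Recv call sites in ProcessGroupNCCL.cc it is roughly the following (illustrative only, not an official typedef; requires <functional>).

// Inferred callable shape for PointToPoint's Fn parameter.
using P2PFn = std::function<ncclResult_t(
    Tensor& tensor, ncclComm_t comm, const gpuStream_t& stream, int peer_rank)>;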
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/collective/Types.h
@@ -32,5 +32,9 @@ struct BroadcastOptions {
int source_root = 0;
};

struct BarrierOptions {
std::vector<int> place_ids;
};

} // namespace distributed
} // namespace paddle
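Editor's note, not part of the commit: place_ids lets a caller pin the barrier to specific devices; when it is empty, ProcessGroupNCCL::Barrier falls back to used_place_ids_ and finally to rank_ % GetSize(). A hedged usage sketch, assuming pg is a ProcessGroupNCCL created elsewhere.

// Barrier only the first two GPUs; illustration, not code from the commit.
paddle::distributed::BarrierOptions opts;
opts.place_ids = {0, 1};
pg->Barrier(opts);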
24 changes: 19 additions & 5 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
@@ -46,7 +48,8 @@ void Carrier::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place) {
int64_t num_micro_batches, const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_id_to_node_ = interceptor_id_to_node;
@@ -60,7 +63,7 @@
microbatch_scopes_.resize(num_micro_batches);
for (int i = 0; i < num_micro_batches; ++i) {
microbatch_scopes_[i] = &minibatch_scope_->NewScope();
CopyParameters(i, program);
CopyParameters(i, program, inference_root_scope_vars);
}

// TODO(fleet_exe dev): thread pool
@@ -80,12 +83,23 @@ void Carrier::Release() {

Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

void Carrier::CopyParameters(int microbatch_id,
const framework::ProgramDesc& program) {
void Carrier::CopyParameters(
int microbatch_id, const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);

std::map<std::string, int> inference_root_scope_var_map;
for (auto var_name : inference_root_scope_vars) {
inference_root_scope_var_map.insert({var_name, 1});
}
for (auto& var : global_block.AllVars()) {
if (var->Persistable() && microbatch_id == 0) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name << " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
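Editor's note, not part of the commit: the new inference_root_scope_vars argument forces the named variables to be created in the root scope even when they are not persistable, which is what inference feed/fetch variables need. A hedged sketch of an Init call with the new parameter; the variable names and the surrounding arguments are hypothetical stand-ins prepared by the caller.

// Illustration only: rank, id_to_rank, id_to_node, program, scope and place
// are assumed to exist as before; the variable names are made up.
carrier.Init(rank, id_to_rank, id_to_node, program, scope,
             /*num_micro_batches=*/1, place,
             /*inference_root_scope_vars=*/{"feed_0", "fetch_0"});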
7 changes: 5 additions & 2 deletions paddle/fluid/distributed/fleet_executor/carrier.h
@@ -57,9 +57,12 @@ class Carrier final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
const framework::ProgramDesc& program, framework::Scope* scope,
int64_t num_micro_batches, const platform::Place& place);
int64_t num_micro_batches, const platform::Place& place,
const std::vector<std::string>& inference_root_scope_vars = {});

void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
void CopyParameters(
int microbatch_id, const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars);

void Release();
void Wait();
(The remaining changed files in this commit are not shown in this excerpt.)

1 comment on commit a345a80

@paddle-bot-old

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
