Skip to content

Commit

Permalink
[hybrid performance] Optimize tensor parallel plus pipeline parallel …
Browse files Browse the repository at this point in the history
…send recv size (#34110)
  • Loading branch information
wangxicoding authored Jul 13, 2021
1 parent 651aad0 commit 348d043
Show file tree
Hide file tree
Showing 12 changed files with 801 additions and 30 deletions.
85 changes: 85 additions & 0 deletions paddle/fluid/operators/collective/partial_allgather_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/partial_allgather_op.h"

namespace paddle {
namespace operators {

class PartialAllGatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "PartialAllGather");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Input", "Out", "PartialAllGather");
int nranks = ctx->Attrs().Get<int>("nranks");
int rank = ctx->Attrs().Get<int>("rank");

PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
"The value of nranks should be >=2."));
PADDLE_ENFORCE_EQ(
(rank >= 0 && rank < nranks), true,
platform::errors::InvalidArgument(
"The rank (%d) for partial_allgather op must >=0 and <nranks (%d)",
rank, nranks));

framework::DDim dim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", dim);
}
};

class PartialAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be partial allgather");
AddOutput("Out", "(Tensor) the allgather result");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
#if defined(PADDLE_WITH_ASCEND_CL)
AddAttr<std::string>("tag", "(string default tag) tag for all gather.")
.SetDefault("tag");
#endif
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddAttr<int>("nranks",
"Total trainer count of the distributed training job");
AddAttr<int>("rank", "Rand of the distributed training job");
AddComment(R"DOC(
PartialAllGather Operator.
Divide the Input into nranks copies and only use the rank part.
Each rank receives the aggregation of data from all ranks in the order of the ranks.
reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allgather
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(partial_allgather, ops::PartialAllGatherOp,
ops::PartialAllGatherOpMaker);

REGISTER_OP_CPU_KERNEL(partial_allgather,
ops::PartialAllGatherOpCPUKernel<float>,
ops::PartialAllGatherOpCPUKernel<double>,
ops::PartialAllGatherOpCPUKernel<int>,
ops::PartialAllGatherOpCPUKernel<int64_t>,
ops::PartialAllGatherOpCPUKernel<plat::float16>);
91 changes: 91 additions & 0 deletions paddle/fluid/operators/collective/partial_allgather_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/partial_allgather_op.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int64_t numel = in->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());

int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

PADDLE_ENFORCE_EQ(
nranks, comm->nranks(),
platform::errors::InvalidArgument("nranks: %s should equal to %s",
nranks, comm->nranks()));
PADDLE_ENFORCE_EQ(rank, comm->rank(),
platform::errors::InvalidArgument(
"rank: %s should equal to %s", rank, comm->rank()));
PADDLE_ENFORCE_EQ(
(numel % nranks), 0,
platform::errors::InvalidArgument(
"The input numel (%d) must be divisible by nranks(%d)", numel,
nranks));

framework::DDim dims = in->dims();
out->mutable_data<T>(dims, place);

int64_t send_numel = numel / nranks;
int offset = send_numel * rank;
const T* send_buff = in->data<T>() + offset;
T* recv_buff = out->data<T>();

gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}

PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(partial_allgather,
ops::PartialAllGatherOpCUDAKernel<float>,
ops::PartialAllGatherOpCUDAKernel<double>,
ops::PartialAllGatherOpCUDAKernel<int>,
ops::PartialAllGatherOpCUDAKernel<int64_t>,
ops::PartialAllGatherOpCUDAKernel<plat::float16>);
39 changes: 39 additions & 0 deletions paddle/fluid/operators/collective/partial_allgather_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class PartialAllGatherOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support partial_allgather for cpu kernel now."));
}
};

} // namespace operators
} // namespace paddle
131 changes: 131 additions & 0 deletions paddle/fluid/operators/collective/partial_recv_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/partial_recv_op.h"
#include <string>

namespace paddle {
namespace operators {

class PartialRecvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "PartialRecv");
int peer = ctx->Attrs().Get<int>("peer");
int ring_id = ctx->Attrs().Get<int>("ring_id");
int num = ctx->Attrs().Get<int>("num");
int id = ctx->Attrs().Get<int>("id");
auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");

PADDLE_ENFORCE_GE(
peer, 0,
platform::errors::InvalidArgument(
"The peer (%d) for partial_recv op must be non-negative.", peer));
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for partial_recv op must be non-negative.",
ring_id));
PADDLE_ENFORCE_GE(num, 1,
platform::errors::InvalidArgument(
"The num (%d) for partial_send op must >=1", num));
PADDLE_ENFORCE_EQ(
(id >= 0 && id < num), true,
platform::errors::InvalidArgument(
"The id (%d) for partial_send op must >=0 and <num (%d)", id, num));
PADDLE_ENFORCE_GE(out_shape.size(), 1,
platform::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));

for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for partial_recv must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));
}
auto out_dims = framework::make_ddim(out_shape);
int numel = framework::product(out_dims);
PADDLE_ENFORCE_EQ(
(numel % num), 0,
platform::errors::InvalidArgument(
"The output numel (%d) must be divisible by num(%d)", numel, num));

ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(dtype);
return framework::OpKernelType(type, ctx.GetPlace());
}
};

class PartialRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("Out", "(Tensor) tensor to receive.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
AddAttr<int>("dtype", "(int default 5('float32')) data type of tensor.")
.SetDefault(5);
#if defined(PADDLE_WITH_ASCEND_CL)
AddAttr<std::string>("tag", "(string default tag) tag for broadcasting.")
.SetDefault("tag");
AddAttr<int>("srTag", "(string default tag) tag for broadcasting.")
.SetDefault(0);
#endif
AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
.SetDefault(std::vector<int>());
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddAttr<int>("num", "(int default 1) The number of Output to be cut.")
.SetDefault(1);
AddAttr<int>("id",
"(int default 0) ID of the part to be recv after Output cut.")
.SetDefault(0);
AddComment(R"DOC(
Recv Operator.
Divide the Output into num copies and only recv the id part.
Reference: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#sendrecv
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(partial_recv, ops::PartialRecvOp,
ops::PartialRecvOpMaker);

REGISTER_OP_CPU_KERNEL(partial_recv, ops::PartialRecvOpCPUKernel<float>,
ops::PartialRecvOpCPUKernel<double>,
ops::PartialRecvOpCPUKernel<int>,
ops::PartialRecvOpCPUKernel<int64_t>,
ops::PartialRecvOpCPUKernel<plat::float16>);
Loading

0 comments on commit 348d043

Please sign in to comment.