support KL2 multi-card training, *test=kunlun #43889

Merged 1 commit on Jul 15, 2022
5 changes: 4 additions & 1 deletion cmake/external/xpu.cmake
@@ -24,6 +24,9 @@ else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()

set(XPU_XCCL_BASE_URL
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.0")

if(WITH_AARCH64)
set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
set(XPU_XDNN_DIR_NAME "xdnn-kylin_aarch64")
@@ -76,7 +79,7 @@ set(XPU_XRE_URL
"${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz"
CACHE STRING "" FORCE)
set(XPU_XCCL_URL
"${XPU_BASE_URL_WITHOUT_DATE}/20220411/${XPU_XCCL_DIR_NAME}.tar.gz"
"${XPU_XCCL_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz"
CACHE STRING "" FORCE)
set(XPU_PACK_DEPENCE_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh"
39 changes: 33 additions & 6 deletions paddle/fluid/imperative/bkcl_context.cc
@@ -110,6 +110,10 @@ void BKCLParallelContext::Init() {
strategy_.local_rank_,
xpu_id,
ring_id);
compute_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
}
}

@@ -134,6 +138,11 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);

compute_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
comm_events_.emplace_back(
platform::XpuEventResourcePool::Instance().New(place_.device));
}

void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
@@ -213,9 +222,18 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
"but got ring id = %d, nrings = %d",
ring_id,
strategy_.nrings_));
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
auto compute_stream = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream = platform::BKCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context()
->stream();
auto event = compute_events_[ring_id].get();

// compute_stream-->event-->comm_stream
PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, compute_stream));
Contributor:
This requires a fairly recent runtime and driver version to support xpu_event_record, right?

Contributor Author:
As far as I can tell, xpu_event_record has been available since 2020: 3994660c (hanjinchen, 2020-07-03).

PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(comm_stream, event));
}

void BKCLParallelContext::WaitComm(int ring_id) {
@@ -230,9 +248,18 @@ void BKCLParallelContext::WaitComm(int ring_id) {
"but got ring id = %d, nrings = %d",
ring_id,
strategy_.nrings_));
auto comm_dev_ctx =
platform::BKCLCommContext::Instance().Get(ring_id, place_)->dev_context();
comm_dev_ctx->Wait();
auto comm_stream = platform::BKCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context()
->stream();
auto compute_stream = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto event = comm_events_[ring_id].get();

// comm_stream-->event-->compute_stream
PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, comm_stream));
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(compute_stream, event));
}

void BKCLParallelContext::SynchronizeCompute() {
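Note on the pattern above (illustration only, not part of the PR diff): WaitCompute() and WaitComm() replace the blocking dev_ctx->Wait() calls with device-side ordering between the compute stream and the per-ring comm stream, using the xpu_event_record / xpu_stream_wait_event calls shown above. A minimal sketch, assuming the XPUStream and XPUEvent runtime handle types used elsewhere in the XPU code:

// Sketch: make work queued later on `waiter` start only after everything
// already queued on `producer` has finished; the host thread does not block.
void StreamWaitStream(XPUStream producer, XPUStream waiter, XPUEvent event) {
  // Mark the current tail of the producer stream with the event ...
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, producer));
  // ... and have the waiter stream wait for that marker on the device:
  // producer --> event --> waiter.
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(waiter, event));
}

WaitCompute() applies this as compute_stream --> event --> comm_stream before a collective is launched; WaitComm() reverses the direction so the compute stream does not touch the fused buffer before the allreduce on the comm stream has completed.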
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/bkcl_context.h
@@ -19,6 +19,7 @@
#include <vector>

#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device/xpu/xpu_resource_pool.h"
#include "xpu/bkcl.h"

namespace paddle {
@@ -52,6 +53,13 @@ class BKCLParallelContext : public ParallelContext {
void WaitComm(int ring_id) override;

void SynchronizeCompute() override;

private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::XpuEventObject>> compute_events_;

// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std::vector<std::shared_ptr<platform::XpuEventObject>> comm_events_;
};

} // namespace imperative
69 changes: 20 additions & 49 deletions paddle/fluid/imperative/reducer.cc
@@ -21,6 +21,9 @@
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#ifdef PADDLE_WITH_XPU_BKCL
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
@@ -431,10 +434,6 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
VLOG(3) << "Start construct the Reducer ...";
nrings_ = parallel_ctx->GetNRings();
nranks_ = parallel_ctx->GetNRanks();
#ifdef PADDLE_WITH_XPU_BKCL
comm_pool_.reset(new ::ThreadPool(1));
comm_op_count_ = 0;
#endif
// initialize groups
InitializeGroups(group_indices);
for (size_t global_var_index = 0; global_var_index < vars_.size();
@@ -853,8 +852,23 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {

#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group_tensor.place())) {
// TODO(liuyuhui) support XPU set constant
VLOG(3) << "XPU doesn't support set_constant";
auto dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
if (HasGrad(var_index)) {
auto var_base = vars_[var_index]->GradVarBase();
auto tensor =
var_base->MutableVar()->GetMutable<framework::LoDTensor>();
group_tensor.ShareDataWith(*tensor).Resize(
{static_cast<int64_t>(length)});
} else {
group_tensor.Resize({static_cast<int64_t>(length)});
int r = xpu::constant(dev_ctx->x_context(),
reinterpret_cast<float *>(group_tensor.data()),
group_tensor.numel(),
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
}
}
#elif defined(PADDLE_WITH_CNCL)
if (platform::is_mlu_place(group_tensor.place())) {
@@ -948,33 +962,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
// so we expose WaitCompute() interface and call
// it here.
parallel_ctx_->WaitCompute(run_order);
#ifdef PADDLE_WITH_XPU_BKCL
{
std::lock_guard<std::mutex> lock(mutex_);
comm_op_count_ += 1; // lock
}
// TODO(liuyuhui): Add try catch to deal with exception later,
// otherwise the main thread will continue to run when an exception is
// thrown in comm_pool_.
auto next_group = next_group_;
comm_pool_->enqueue([this, run_order, next_group, &group] {
auto dev_id = place_.device;
platform::SetXPUDeviceId(dev_id);
FusedAllReduceSchedule(run_order, group, next_group);
{
std::lock_guard<std::mutex> lock(mutex_);
comm_op_count_ -= 1; // lock
cv_.notify_all();
}
});
#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
defined(PADDLE_WITH_CNCL)
FusedAllReduceSchedule(run_order, group, next_group_);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not compiled with BKCL or NCCL or CNCL or GLOO."));
#endif
}
}

@@ -997,17 +985,6 @@ void Reducer::FusedAllReduceSchedule(const int run_order,
// group.dense_tensors ---> group.dense_contents_
group.ConcatTensors(dev_context);

// NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support
// default stream for communicating, so there exist some problems in
// synchronization. And need to add a WaitComm there.
// TODO(liuyuhui): If BKCL support non-blocking communication, it should be
// fixed as multi gpus card training.
#ifdef PADDLE_WITH_XPU_BKCL
if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
parallel_ctx_->WaitComm(run_order);
}
#endif

group.DivNRanks(dev_context, nranks_);
// Start allreduce
parallel_ctx_->AllReduceByStream(
@@ -1135,12 +1112,6 @@ bool Reducer::HasGrad(size_t var_index) {
void Reducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
#ifdef PADDLE_WITH_XPU_BKCL
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return comm_op_count_ == 0; });
}
#endif

// Must prevent compute_stream_ starting until all comm streams have finished
for (int i = 0; i < nrings_; ++i) {
6 changes: 6 additions & 0 deletions paddle/fluid/platform/collective_helper.cc
@@ -347,6 +347,12 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
std::unique_ptr<XPUDeviceContext> dev_ctx(
new XPUDeviceContext(XPUPlace(dev_id)));
// used in BKCL as comm_stream, for every dev_id there is
// a comm_stream at each ring. this stream is passed as input var
// when calling collective comm commands like bkcl_all_reduce
XPUStream comm_stream;
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
dev_ctx->SetXPUStream(comm_stream);

BKCLCommImpl* c = new BKCLCommImpl;
c->set_ring_id(ring_id);
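Context for the comm_stream comment above (sketch only, not part of the PR; it just rearranges the calls shown in this diff): each ring gets its own XPUDeviceContext whose stream is replaced with a freshly created comm stream via the new SetXPUStream() hook, so BKCL collectives no longer share the default compute stream.

// Sketch: one dedicated comm stream per ring.
std::unique_ptr<XPUDeviceContext> dev_ctx(new XPUDeviceContext(XPUPlace(dev_id)));
XPUStream comm_stream;
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
dev_ctx->SetXPUStream(comm_stream);  // collectives issued through this context
                                     // now enqueue onto comm_stream

// WaitCompute()/WaitComm() later fetch this stream through
// BKCLCommContext::Instance().Get(ring_id, place)->dev_context()->stream()
// and order it against the compute stream with the events recorded above.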
1 change: 1 addition & 0 deletions paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -60,6 +60,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"bilinear_interp_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
Contributor:
I don't see a concrete *_op_xpu.cc file or a corresponding unit test under fluid/operator/, and I don't see where broadcast is called. Can this entry be left out?

Contributor Author:
The XPU call to broadcast is in dygraph/layers.py:_dygraph_call_func().

Contributor:
Doesn't all_reduce need to be added as well?

Contributor Author:
Not for now. allreduce currently calls the BKCL interface directly and is not wrapped as an operator.

{"cast",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
2 changes: 2 additions & 0 deletions paddle/phi/backends/all_context.h
@@ -23,7 +23,9 @@ limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif

#ifndef PADDLE_WITH_CUSTOM_KERNEL
// TODO(wilber): DeviceContextPool needs to include fluid file.
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu_context.cc
@@ -66,6 +66,8 @@ struct XPUContext::Impl {

const Place& GetPlace() const { return place_; }

void SetStream(XPUStream stream) { context_->xpu_stream = stream; }

xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
return context_;
@@ -115,6 +117,8 @@ XPUContext::~XPUContext() = default;

const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }

void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }

backends::xpu::XPUVersion XPUContext::xpu_version() const {
return impl_->xpu_version_;
}
2 changes: 2 additions & 0 deletions paddle/phi/backends/xpu/xpu_context.h
@@ -61,6 +61,8 @@ class XPUContext : public DeviceContext {

void SetL3Cache(int l3_size = 14155776);

void SetXPUStream(XPUStream stream);

private:
struct Impl;
std::unique_ptr<Impl> impl_;
2 changes: 2 additions & 0 deletions paddle/phi/core/kernel_utils.h
@@ -18,7 +18,9 @@
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"