Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the problem in fleet executor stop #38114

Merged
merged 2 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ void Carrier::Release() {
// otherwise Derived object will be destructed before thread complete.

// Sending STOP msg to the source interceptor
MessageBus& msg_bus = MessageBus::Instance();
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
Expand All @@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
msg_bus.Send(stop_msg);
Send(stop_msg);
}

// TODO(wangxi): Maybe need a better to use thread.
Expand Down Expand Up @@ -113,11 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get();
}

void Carrier::Wait() {
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
}

void Carrier::Start() {
MessageBus& msg_bus = MessageBus::Instance();
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));

for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
Expand All @@ -127,18 +134,21 @@ void Carrier::Start() {
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
msg_bus.Send(start_msg);
Send(start_msg);
}

std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
Wait();
dev_ctx_->Wait();
}

std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
LiYuRio marked this conversation as resolved.
Show resolved Hide resolved
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor> interceptor) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
Expand All @@ -147,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! "
"The interceptor id should be unique.",
interceptor_id));
interceptor->RegisterCarrier(this);
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
Expand Down
22 changes: 12 additions & 10 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,19 @@ namespace distributed {
class TaskNode;
class InterceptorMessageServiceImpl;
class RuntimeGraph;
class MessageBus;

// A singleton MessageBus
class Carrier final {
public:
static Carrier& Instance() {
static Carrier carrier;
return carrier;
}

Carrier() = default;
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);

~Carrier();
void Release();
void Wait();

// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
Expand All @@ -68,22 +65,25 @@ class Carrier final {
std::unique_ptr<Interceptor>);

void SetCreatingFlag(bool flag);
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}

std::condition_variable& GetCondVar();

void Start();

bool IsInit() const;

bool Send(const InterceptorMessage& msg) const;

// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
// simultaneously, which will lead to a random hang for some sync ops.
std::mutex run;

DISABLE_COPY_AND_ASSIGN(Carrier);

private:
Carrier() = default;
DISABLE_COPY_AND_ASSIGN(Carrier);

// create each Interceptor
void CreateInterceptors();
Expand All @@ -110,6 +110,8 @@ class Carrier final {
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

} // namespace distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}

void ComputeInterceptor::RunOps() {
Carrier& carrier_instance = Carrier::Instance();
std::unique_lock<std::mutex> lock(carrier_instance.run);
std::unique_lock<std::mutex> lock(carrier_->run);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
Expand Down
36 changes: 22 additions & 14 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto"));
}

FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
}

Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
}

void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope,
Expand Down Expand Up @@ -61,15 +69,17 @@ void FleetExecutor::Init(
CopyParameters(i, program_desc);
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
InitCarrier();
InitMessageBus();
}

void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
}
}

Expand Down Expand Up @@ -103,24 +113,22 @@ void FleetExecutor::InitMessageBus() {
VLOG(3) << "The number of ranks are "
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
MessageBus& message_bus_instance = MessageBus::Instance();
if (!message_bus_instance.IsInit()) {
message_bus_instance.Init(runtime_graph_->intercepter_id_to_rank(),
rank_to_addr, addr);
if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
addr);
}
}

void FleetExecutor::Run() {
// Run
Carrier& carrier_instance = Carrier::Instance();
MessageBus& message_bus_instance = MessageBus::Instance();
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ(
carrier_instance.IsInit(), true,
carrier.IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
message_bus_instance.IsInit(), true,
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start();
carrier.Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class Scope;

namespace distributed {
class RuntimeGraph;
class Carrier;
class MessageBus;
class TaskNode;
class Carrier;

class FleetExecutor final {
public:
Expand All @@ -42,6 +42,8 @@ class FleetExecutor final {
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
Expand All @@ -54,6 +56,9 @@ class FleetExecutor final {
framework::Scope* minibatch_scope_;
platform::Place place_;
std::vector<framework::Scope*> microbatch_scopes_;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
};

} // namespace distributed
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
Expand Down Expand Up @@ -46,8 +45,9 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
}

void Interceptor::StopCarrier() {
Carrier& carrier_instance = Carrier::Instance();
std::condition_variable& cond_var = carrier_instance.GetCondVar();
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
std::condition_variable& cond_var = carrier_->GetCondVar();
// probably double notify, but ok for ut
cond_var.notify_all();
}
Expand All @@ -73,9 +73,11 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
}

bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
return MessageBus::Instance().Send(msg);
return carrier_->Send(msg);
}

void Interceptor::PoolTheMailbox() {
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class GarbageCollector;
namespace distributed {

class TaskNode;
class Carrier;

class Interceptor {
public:
Expand Down Expand Up @@ -77,6 +78,7 @@ class Interceptor {
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc;
}
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }

TaskNode* GetTaskNode() const { return node_; }

Expand All @@ -100,6 +102,8 @@ class Interceptor {
std::vector<framework::Scope*> microbatch_scopes_{};
std::shared_ptr<framework::GarbageCollector> gc_{nullptr};

Carrier* carrier_;

private:
// pool the local mailbox, parse the Message
void PoolTheMailbox();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
FleetExecutor::GetCarrier().EnqueueInterceptorMessage(*request);
response->set_rst(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

突然想到,这其实可以set_rst为EnqueueInterceptorMessage的返回值,下个pr可以改一下。

// call interceptor manager's method to handle the message
Carrier::Instance().EnqueueInterceptorMessage(*request);
}

} // namespace distributed
Expand Down
7 changes: 2 additions & 5 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ void MessageBus::Init(
bool MessageBus::IsInit() const { return is_init_; }

MessageBus::~MessageBus() {
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
Expand Down Expand Up @@ -245,7 +241,8 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {

bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src)
return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message);
return FleetExecutor::GetCarrier().EnqueueInterceptorMessage(
interceptor_message);
}

} // namespace distributed
Expand Down
17 changes: 5 additions & 12 deletions paddle/fluid/distributed/fleet_executor/message_bus.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ class Carrier;
// A singleton MessageBus
class MessageBus final {
public:
static MessageBus& Instance() {
static MessageBus msg_bus;
return msg_bus;
}
MessageBus() = default;
~MessageBus();

void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
Expand All @@ -53,12 +51,8 @@ class MessageBus final {
// called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message);

~MessageBus();

DISABLE_COPY_AND_ASSIGN(MessageBus);

private:
MessageBus() = default;
DISABLE_COPY_AND_ASSIGN(MessageBus);

// function keep listen the port and handle the message
void ListenPort();
Expand All @@ -72,12 +66,11 @@ class MessageBus final {
bool SendInterRank(const InterceptorMessage& interceptor_message);
#endif

bool is_init_{false};

// send the message intra rank (dst is the same rank with src)
bool SendIntraRank(const InterceptorMessage& interceptor_message);

bool is_init_{false};
std::once_flag once_flag_;

// handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;

Expand Down
Loading