Skip to content

Commit

Permalink
Fix fleet executor stop
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Dec 16, 2021
1 parent 1bdab4f commit 1c03784
Show file tree
Hide file tree
Showing 16 changed files with 138 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace paddle {
namespace distributed {

AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
TaskNode* node, Carrier* carrier)
: ComputeInterceptor(interceptor_id, node, carrier) {
TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
run_per_steps_ = node->run_per_steps();
run_at_offset_ = node->run_at_offset();
reply_up_per_steps_ = node->reply_up_per_steps();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ namespace distributed {

class AmplifierInterceptor : public ComputeInterceptor {
public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node,
Carrier* carrier);
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);

private:
void RunOps() override;
Expand Down
30 changes: 19 additions & 11 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@ USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
MessageBus* msg_bus, framework::Scope* root_scope,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
runtime_graph_ = runtime_graph;
msg_bus_ = msg_bus;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
Expand All @@ -52,7 +51,9 @@ void Carrier::Release() {
// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
Expand All @@ -61,7 +62,7 @@ void Carrier::Release() {
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
msg_bus_->Send(stop_msg);
Send(stop_msg);
}

// TODO(wangxi): Maybe need a better to use thread.
Expand Down Expand Up @@ -113,10 +114,17 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
return iter->second.get();
}

void Carrier::Wait() {
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
}

void Carrier::Start() {
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));

for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
Expand All @@ -126,18 +134,17 @@ void Carrier::Start() {
start_msg.set_src_id(-1);
start_msg.set_dst_id(id);
start_msg.set_message_type(DATA_IS_READY);
msg_bus_->Send(start_msg);
Send(start_msg);
}

std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
Wait();
dev_ctx_->Wait();
}

std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
}
Expand All @@ -150,6 +157,7 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
"The interceptor id %lld has already been created! "
"The interceptor id should be unique.",
interceptor_id));
interceptor->RegisterCarrier(this);
auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
Expand Down Expand Up @@ -234,10 +242,10 @@ void Carrier::CreateInterceptors() {
std::unique_ptr<Interceptor> interceptor;
if (task_node->type().empty()) {
// TODO(wangxi): delete this in future
interceptor.reset(new Interceptor(interceptor_id, task_node, this));
interceptor.reset(new Interceptor(interceptor_id, task_node));
} else {
interceptor = InterceptorFactory::Create(task_node->type(),
interceptor_id, task_node, this);
interceptor_id, task_node);
}
interceptor->SetPlace(place_);
interceptor->SetMiniBatchScope(minibatch_scope_);
Expand Down
10 changes: 7 additions & 3 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ class InterceptorMessageServiceImpl;
class RuntimeGraph;
class MessageBus;

// A singleton MessageBus
class Carrier final {
public:
Carrier() = default;
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph, MessageBus* msg_bus,
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);

void Release();
void Wait();

// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
Expand All @@ -65,6 +65,9 @@ class Carrier final {
std::unique_ptr<Interceptor>);

void SetCreatingFlag(bool flag);
void SetMsgBus(const std::shared_ptr<MessageBus>& msg_bus) {
msg_bus_ = msg_bus;
}

std::condition_variable& GetCondVar();

Expand Down Expand Up @@ -107,7 +110,8 @@ class Carrier final {
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
MessageBus* msg_bus_;
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

} // namespace distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
namespace paddle {
namespace distributed {

ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node,
Carrier* carrier)
: Interceptor(interceptor_id, node, carrier) {
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace distributed {

class ComputeInterceptor : public Interceptor {
public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node, Carrier* carrier);
ComputeInterceptor(int64_t interceptor_id, TaskNode* node);

protected:
virtual void RunOps();
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
"Error occurs while parsing string to proto"));
msg_bus_ = std::make_unique<MessageBus>();
}

FleetExecutor::~FleetExecutor() {
Expand Down Expand Up @@ -70,14 +69,16 @@ void FleetExecutor::Init(
CopyParameters(i, program_desc);
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
InitCarrier();
InitMessageBus();
}

void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.Init(runtime_graph_, msg_bus_.get(), root_scope_, minibatch_scope_,
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
}
}
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ class FleetExecutor final {
framework::Scope* minibatch_scope_;
platform::Place place_;
std::vector<framework::Scope*> microbatch_scopes_;
std::unique_ptr<MessageBus> msg_bus_;
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
};

} // namespace distributed
Expand Down
14 changes: 8 additions & 6 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
namespace paddle {
namespace distributed {

Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node,
Carrier* carrier)
: interceptor_id_(interceptor_id), node_(node), carrier_(carrier) {
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) {
interceptor_thread_ = std::thread([this]() {
VLOG(3) << "Interceptor " << interceptor_id_
<< " starts the thread pooling it's local mailbox.";
Expand All @@ -46,6 +45,8 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
}

void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
std::condition_variable& cond_var = carrier_->GetCondVar();
// probably double notify, but ok for ut
cond_var.notify_all();
Expand All @@ -72,6 +73,8 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
}

bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
"Carrier is not registered."));
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
return carrier_->Send(msg);
Expand Down Expand Up @@ -128,14 +131,13 @@ static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {

std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
int64_t id,
TaskNode* node,
Carrier* carrier) {
TaskNode* node) {
auto& interceptor_map = GetInterceptorMap();
auto iter = interceptor_map.find(type);
PADDLE_ENFORCE_NE(
iter, interceptor_map.end(),
platform::errors::NotFound("interceptor %s is not register", type));
return iter->second(id, node, carrier);
return iter->second(id, node);
}

void InterceptorFactory::Register(
Expand Down
14 changes: 6 additions & 8 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Interceptor {
public:
Interceptor() = delete;

Interceptor(int64_t interceptor_id, TaskNode* node, Carrier* carrier);
Interceptor(int64_t interceptor_id, TaskNode* node);

virtual ~Interceptor();

Expand Down Expand Up @@ -78,6 +78,7 @@ class Interceptor {
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc;
}
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }

TaskNode* GetTaskNode() const { return node_; }

Expand Down Expand Up @@ -139,22 +140,19 @@ class Interceptor {
class InterceptorFactory {
public:
using CreateInterceptorFunc = std::unique_ptr<Interceptor> (*)(int64_t,
TaskNode*,
Carrier*);
TaskNode*);
using CreateInterceptorMap =
std::unordered_map<std::string, CreateInterceptorFunc>;

static void Register(const std::string& type, CreateInterceptorFunc func);

static std::unique_ptr<Interceptor> Create(const std::string& type,
int64_t id, TaskNode* node,
Carrier*);
int64_t id, TaskNode* node);
};

template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node,
Carrier* carrier) {
return std::make_unique<InterceptorClass>(id, node, carrier);
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
return std::make_unique<InterceptorClass>(id, node);
}

#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include "gtest/gtest.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand Down Expand Up @@ -61,10 +62,12 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();

Carrier carrier;
// TODO(liyurui): Remove singleton when move SendIntra into Carrier
Carrier& carrier = FleetExecutor::GetCarrier();

MessageBus msg_bus;
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "");
carrier.SetMsgBus(msg_bus);

// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
Expand All @@ -76,9 +79,8 @@ TEST(ComputeInterceptor, Compute) {
node_b->AddUpstreamTask(0);

auto* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a, &carrier));
carrier.SetInterceptor(
1, InterceptorFactory::Create("Compute", 1, node_b, &carrier));
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));

a->SetPlace(place);
a->SetMicroBatchScope(scopes);
Expand All @@ -91,6 +93,9 @@ TEST(ComputeInterceptor, Compute) {
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);

carrier.Wait();
carrier.Release();
}

} // namespace distributed
Expand Down
Loading

0 comments on commit 1c03784

Please sign in to comment.