Skip to content

Commit

Permalink
[fleet_executor] Support multi carriers (#38650)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored Jan 4, 2022
1 parent 2d2609e commit 2273471
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 199 deletions.
25 changes: 6 additions & 19 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);

void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;

// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
Expand All @@ -45,14 +43,12 @@ void Carrier::Init(
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
Expand Down Expand Up @@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
Expand Down Expand Up @@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);

// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);

auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
Expand All @@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (interceptor_ids_.empty()) return;
if (interceptor_id_to_node_.empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (int64_t interceptor_id : interceptor_ids_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
Expand Down
10 changes: 5 additions & 5 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
class RuntimeGraph;
class MessageBus;

// TODO(liyurui): Add CarrierId instead of std::string

class Carrier final {
public:
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {}
~Carrier();
void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
Expand Down Expand Up @@ -109,7 +109,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
int64_t carrier_id_;
std::string carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
Expand Down
48 changes: 19 additions & 29 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
for (const auto& carrier_id : carrier_ids_) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
}
}

void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
const std::string& carrier_id, const framework::ProgramDesc& program_desc,
framework::Scope* scope, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0,
platform::errors::InvalidArgument(
Expand All @@ -58,19 +59,13 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
}
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) {
unique_op.release();
}
Expand All @@ -87,27 +82,23 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier_ids_.insert(carrier_id);
GlobalVal<std::string>::Set(carrier_id);
// TODO(liyurui): Maybe message bus should be created only once
InitCarrier(carrier);
InitMessageBus();

// Wait for all message bus connected.
msg_bus_->Barrier();
}

void FleetExecutor::InitCarrier() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
// Prepares a single carrier for use by this executor: attaches the executor's
// message bus (msg_bus_) to it, then calls Carrier::Init with the current rank
// from exe_desc_, the runtime graph's interceptor-id-to-rank and
// interceptor-id-to-node maps, the executor's scopes and its place.
// `carrier` is dereferenced without a null check, so callers must pass a
// valid pointer.
void FleetExecutor::InitCarrier(Carrier* carrier) {
// The bus is attached before Init is called.
// NOTE(review): presumably Init relies on the bus being set — confirm.
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}

void FleetExecutor::InitMessageBus() {
Expand Down Expand Up @@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
}
}

void FleetExecutor::Run() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
}
void FleetExecutor::Run(const std::string& carrier_id) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
GlobalVal<std::string>::Set(carrier_id);
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,17 @@ class FleetExecutor final {
FleetExecutor() = delete;
explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
void Run(const std::string& carrier_id);

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus();
void InitCarrier();
void InitCarrier(Carrier* carrier);
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
Expand All @@ -57,6 +58,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor share one message bus; a shared_ptr is
// used to manage its lifetime and avoid race conditions.
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_set<std::string> carrier_ids_;
};

} // namespace distributed
Expand Down
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/fleet_executor/global_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@
namespace paddle {
namespace distributed {

// TODO(liyurui): Change this file to global.h
// Holds one value of type T per template instantiation, reachable from
// anywhere through static accessors. The value lives in a function-local
// static, so it is default-constructed on first access.
// NOTE(review): no synchronization between Get and Set is visible here;
// confirm that callers do not read and write concurrently.
template <typename T>
class GlobalVal final {
 public:
  // Returns a copy of the value currently stored for T.
  static T Get() {
    T* slot = GetPtr();
    return *slot;
  }

  // Replaces the stored value with `val` and returns `val` to the caller.
  static T Set(T val) {
    *GetPtr() = val;
    return val;
  }

 private:
  // Returns the address of the storage slot for this instantiation.
  static T* GetPtr() {
    static T value;
    return &value;
  }
};

template <typename KeyT, typename ValueT>
class GlobalMap final {
public:
Expand All @@ -26,6 +44,7 @@ class GlobalMap final {
item, platform::errors::NotFound("This value is not in global map."));
return item;
}

template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
Expand All @@ -37,6 +56,34 @@ class GlobalMap final {
return item;
}

private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
return &id_to_ptr[id];
}
};

template <typename KeyT, typename ValueT>
class ThreadSafeGlobalMap final {
public:
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound(
"This value is not in thread safe global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}

private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
// TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
const auto& carrier_id = GlobalVal<std::string>::Get();
bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
Expand Down
11 changes: 0 additions & 11 deletions paddle/fluid/distributed/fleet_executor/runtime_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
Expand All @@ -47,19 +43,12 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const;

private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
};

} // namespace distributed
Expand Down
3 changes: 0 additions & 3 deletions paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)

set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();

Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});

auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
};

TEST(ComputeInterceptor, Compute) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});

auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
Expand Down
Loading

0 comments on commit 2273471

Please sign in to comment.