Support multi carriers #38650

Merged (1 commit) on Jan 4, 2022

25 changes: 6 additions & 19 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -30,11 +30,9 @@ USE_INTERCEPTOR(Amplifier);

void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;

// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
@@ -45,14 +43,12 @@ void Carrier::Init(
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
@@ -156,9 +152,7 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(msg);
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
@@ -192,9 +186,6 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);

// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);

auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
@@ -220,19 +211,15 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (interceptor_ids_.empty()) return;
if (interceptor_id_to_node_.empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (int64_t interceptor_id : interceptor_ids_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
10 changes: 5 additions & 5 deletions paddle/fluid/distributed/fleet_executor/carrier.h
@@ -43,17 +43,17 @@ class InterceptorMessageServiceImpl;
class RuntimeGraph;
class MessageBus;

// TODO(liyurui): Add CarrierId instead of std::string

class Carrier final {
public:
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {}
~Carrier();
void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
@@ -109,7 +109,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
int64_t carrier_id_;
std::string carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
int thread_num_;
48 changes: 19 additions & 29 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -36,14 +36,15 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
for (const auto& carrier_id : carrier_ids_) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
}
}

void FleetExecutor::Init(
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
const std::string& carrier_id, const framework::ProgramDesc& program_desc,
framework::Scope* scope, const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
PADDLE_ENFORCE_GT(task_nodes.size(), 0,
platform::errors::InvalidArgument(
@@ -58,19 +59,13 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
}
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);

Contributor: Do we share the same runtime graph across different programs here? Would it be better to turn this into a map, so that each carrier id gets its own runtime graph?

Contributor Author: Yes, this does need to be changed later. (A rough sketch of that idea follows this file's diff.)

runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) {
unique_op.release();
}
@@ -87,27 +82,23 @@
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier();
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier_ids_.insert(carrier_id);
GlobalVal<std::string>::Set(carrier_id);
// TODO(liyurui): Maybe message bus should be created only once
InitCarrier(carrier);
InitMessageBus();

// Wait for all message bus connected.
msg_bus_->Barrier();
}

void FleetExecutor::InitCarrier() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
void FleetExecutor::InitCarrier(Carrier* carrier) {
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(),
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}

void FleetExecutor::InitMessageBus() {
@@ -145,10 +136,9 @@ void FleetExecutor::InitMessageBus() {
}
}

void FleetExecutor::Run() {
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
}
void FleetExecutor::Run(const std::string& carrier_id) {
GlobalMap<std::string, Carrier>::Get(carrier_id)->Start();
GlobalVal<std::string>::Set(carrier_id);
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
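
Following up on the review exchange above: below is a minimal, hypothetical sketch of the suggested direction of keeping one runtime graph per carrier id. Neither CarrierGraphRegistry nor its members exist in this PR; the merged code still uses a single shared runtime_graph_ member.

// Hypothetical sketch, not part of this PR: one RuntimeGraph per carrier id,
// as suggested in the review discussion above.
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"

namespace paddle {
namespace distributed {

class CarrierGraphRegistry {
 public:
  // Returns the graph for carrier_id, creating it on first use.
  std::shared_ptr<RuntimeGraph> GetOrCreate(const std::string& carrier_id) {
    auto iter = carrier_id_to_graph_.find(carrier_id);
    if (iter != carrier_id_to_graph_.end()) {
      return iter->second;
    }
    auto graph = std::make_shared<RuntimeGraph>();
    carrier_id_to_graph_.emplace(carrier_id, graph);
    return graph;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<RuntimeGraph>>
      carrier_id_to_graph_;
};

}  // namespace distributed
}  // namespace paddle
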
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -37,16 +37,17 @@ class FleetExecutor final {
FleetExecutor() = delete;
explicit FleetExecutor(const std::string& exe_desc_str);
~FleetExecutor();
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
void Init(const std::string& carrier_id,
const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place,
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
void Run(const std::string& carrier_id);

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
void InitMessageBus();
void InitCarrier();
void InitCarrier(Carrier* carrier);
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
@@ -57,6 +58,7 @@
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
std::unordered_set<std::string> carrier_ids_;
};

} // namespace distributed
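
To make the reworked interface concrete, a hedged usage sketch follows. Init now takes the carrier id as its first argument and registers one carrier per call, while Run starts only the carrier selected by that id; the carrier id, place, and all inputs below are placeholders rather than values from this PR.

// Illustrative driver for the new carrier-id-based API; all concrete values
// (carrier id, place, inputs) are placeholders.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

using paddle::distributed::FleetExecutor;
using paddle::distributed::TaskNode;

void RunOneCarrier(
    const std::string& exe_desc_str,
    const paddle::framework::ProgramDesc& program_desc,
    paddle::framework::Scope* scope,
    const std::vector<TaskNode*>& task_nodes,
    const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
  const std::string carrier_id = "carrier_0";  // any unique string id
  paddle::platform::Place place = paddle::platform::CPUPlace();

  FleetExecutor fleet_exe(exe_desc_str);
  // Each Init call now registers one carrier under its id.
  fleet_exe.Init(carrier_id, program_desc, scope, place, task_nodes,
                 task_id_to_rank);
  // Run starts only the carrier selected by its id.
  fleet_exe.Run(carrier_id);
}

Calling Init repeatedly with different carrier ids on the same FleetExecutor is how several carriers would coexist, which is what the new carrier_ids_ set tracks.
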
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/fleet_executor/global_map.h
Expand Up @@ -17,6 +17,24 @@
namespace paddle {
namespace distributed {

// TODO(liyurui): Change this file to global.h
template <typename T>
class GlobalVal final {
public:
static T Get() { return *GetPtr(); }
static T Set(T val) {

Contributor: Why does this function have a return value?

auto* ptr = GetPtr();
*ptr = val;
return val;
}

private:
static T* GetPtr() {
static T value;
return &value;
}
};

template <typename KeyT, typename ValueT>
class GlobalMap final {
public:
@@ -26,6 +44,7 @@ class GlobalMap final {
item, platform::errors::NotFound("This value is not in global map."));
return item;
}

template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
@@ -37,6 +56,34 @@ class GlobalMap final {
return item;
}

private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
return &id_to_ptr[id];
}
};

template <typename KeyT, typename ValueT>
class ThreadSafeGlobalMap final {
public:
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound(
"This value is not in thread safe global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in thread safe global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}

private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex;
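
For context, a small usage sketch of the helpers added above; the carrier id and rank map are illustrative only. A Carrier is constructed inside the process-wide GlobalMap under its string id, and GlobalVal publishes the id of the carrier that should currently receive messages, which is how the interceptor message service finds it.

// Usage sketch for GlobalMap / GlobalVal; "carrier_0" and the rank map are
// illustrative values only.
#include <string>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"

using paddle::distributed::Carrier;
using paddle::distributed::GlobalMap;
using paddle::distributed::GlobalVal;

void RegisterCarrierSketch() {
  const std::string carrier_id = "carrier_0";

  // Construct the Carrier inside the global map, keyed by its string id.
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(/*rank=*/0, /*interceptor_id_to_rank=*/{{0, 0}, {1, 0}});

  // Publish the id so the message service can route incoming messages.
  GlobalVal<std::string>::Set(carrier_id);

  // Later lookups use the same keys.
  const std::string current_id = GlobalVal<std::string>::Get();
  Carrier* same_carrier = GlobalMap<std::string, Carrier>::Get(current_id);
  (void)same_carrier;
}

The ThreadSafeGlobalMap variant wraps the same pattern in a mutex for call sites that may be reached from multiple threads.
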
@@ -29,14 +29,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
// TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
const auto& carrier_id = GlobalVal<std::string>::Get();
bool flag = GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
11 changes: 0 additions & 11 deletions paddle/fluid/distributed/fleet_executor/runtime_graph.h
@@ -35,10 +35,6 @@ class RuntimeGraph final {
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids() const {
return carrier_id_to_interceptor_ids_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
@@ -47,19 +43,12 @@ void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
void SetCarrierIdToInterceptorIds(
const std::unordered_map<int64_t, std::unordered_set<int64_t>>&
carrier_id_to_interceptor_ids) {
carrier_id_to_interceptor_ids_ = carrier_id_to_interceptor_ids;
}
std::string DebugString() const;

private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids_;
};

} // namespace distributed
Expand Down
3 changes: 0 additions & 3 deletions paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
@@ -13,9 +13,6 @@ cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_
set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)

set_source_files_properties(interceptor_pass_the_parcel_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pass_the_parcel_test SRCS interceptor_pass_the_parcel_test.cc DEPS fleet_executor ${BRPC_DEPS})

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
@@ -62,8 +62,10 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();

Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}}, {0, 1});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});

auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
@@ -47,8 +47,10 @@ class StartInterceptor : public Interceptor {
};

TEST(ComputeInterceptor, Compute) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Create(0, 0);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}, {0, 1, 2});
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});

auto msg_bus = std::make_shared<MessageBus>();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");