Skip to content

Commit

Permalink
[fleet_executor] Add barrier rpc (PaddlePaddle#38799)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored Jan 10, 2022
1 parent 492e6dd commit cd2855b
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 34 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})

Expand All @@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

add_subdirectory(test)
endif()
1 change: 0 additions & 1 deletion paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
// Set current running carrier
if (*GlobalVal<std::string>::Get() != carrier_id) {
GlobalVal<std::string>::Set(new std::string(carrier_id));
// TODO(liyurui): Move barrier to service
GlobalVal<MessageBus>::Get()->Barrier();
}
carrier->Start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ message InterceptorMessage {

message InterceptorResponse { optional bool rst = 1 [ default = false ]; }

service TheInterceptorMessageService {
rpc InterceptorMessageService(InterceptorMessage)
service MessageService {
rpc ReceiveInterceptorMessage(InterceptorMessage)
returns (InterceptorResponse);
rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
32 changes: 14 additions & 18 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,9 @@ void MessageBus::Barrier() {

bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
IncreaseBarrierCount();
return true;
} else {
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}

void MessageBus::ListenPort() {
Expand All @@ -185,10 +176,9 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message
PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_,
brpc::SERVER_DOESNT_OWN_SERVICE),
0, platform::errors::Unavailable(
"Message bus: init brpc service error."));
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
platform::errors::Unavailable("Message bus: init brpc service error."));

// start the server
const char* ip_for_brpc = addr_.c_str();
Expand Down Expand Up @@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel);
MessageService_Stub stub(&channel);
InterceptorResponse response;
brpc::Controller ctrl;
ctrl.set_log_id(0);
stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
NULL);
}
if (!ctrl.Failed()) {
if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success.";
Expand All @@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
return false;
}
}

#endif

} // namespace distributed
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/message_bus.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif

#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
Expand Down Expand Up @@ -83,7 +83,7 @@ class MessageBus final {

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
InterceptorMessageServiceImpl interceptor_message_service_;
MessageServiceImpl message_service_;
// brpc server
brpc::Server server_;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,37 @@
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

namespace paddle {
namespace distributed {

void InterceptorMessageServiceImpl::InterceptorMessageService(
void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
VLOG(3) << "Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag);
}

void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
<< request->src_id() << " to rank " << request->dst_id();
GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
response->set_rst(true);
}

} // namespace distributed
} // namespace paddle
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
namespace paddle {
namespace distributed {

class InterceptorMessageServiceImpl : public TheInterceptorMessageService {
class MessageServiceImpl : public MessageService {
public:
InterceptorMessageServiceImpl() {}
virtual ~InterceptorMessageServiceImpl() {}
virtual void InterceptorMessageService(
MessageServiceImpl() {}
virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
Expand Down

0 comments on commit cd2855b

Please sign in to comment.