Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
config.max_lineage_size),
actor_registry_(),
node_manager_server_(config.node_manager_port, io_service, *this),
node_manager_server_("NodeManager", config.node_manager_port),
node_manager_service_(io_service, *this),
client_call_manager_(io_service) {
RAY_CHECK(heartbeat_period_.count() > 0);
// Initialize the resource map with own cluster resource configuration.
Expand All @@ -118,6 +119,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,

RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manger rpc server.
node_manager_server_.RegisterService(node_manager_service_);
node_manager_server_.Run();
}

Expand Down
5 changes: 4 additions & 1 deletion src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
std::unordered_map<ActorID, ActorCheckpointID> checkpoint_id_to_restore_;

/// The RPC server.
rpc::NodeManagerServer node_manager_server_;
rpc::GrpcServer node_manager_server_;

/// The RPC service.
rpc::NodeManagerGrpcService node_manager_service_;

/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
rpc::ClientCallManager client_call_manager_;
Expand Down
17 changes: 12 additions & 5 deletions src/ray/rpc/grpc_server.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ray/rpc/grpc_server.h"
#include <grpcpp/impl/service_type.h>

namespace ray {
namespace rpc {
Expand All @@ -9,17 +10,18 @@ void GrpcServer::Run() {
grpc::ServerBuilder builder;
// TODO(hchen): Add options for authentication.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
// Allow subclasses to register concrete services.
RegisterServices(builder);
// Register all the services to this server.
for (auto &entry : services_) {
builder.RegisterService(&entry.get());
}
// Get hold of the completion queue used for the asynchronous communication
// with the gRPC runtime.
cq_ = builder.AddCompletionQueue();
// Build and start server.
server_ = builder.BuildAndStart();
RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << ".";

// Allow subclasses to initialize the server call factories.
InitServerCallFactories(&server_call_factories_and_concurrencies_);
// Create calls for all the server call factories.
for (auto &entry : server_call_factories_and_concurrencies_) {
for (int i = 0; i < entry.second; i++) {
// Create and request calls from the factory.
Expand All @@ -31,6 +33,11 @@ void GrpcServer::Run() {
polling_thread.detach();
}

void GrpcServer::RegisterService(GrpcService &service) {
services_.emplace_back(service.GetGrpcService());
service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_);
}

void GrpcServer::PollEventsFromCompletionQueue() {
void *tag;
bool ok;
Expand All @@ -48,7 +55,7 @@ void GrpcServer::PollEventsFromCompletionQueue() {
// incoming request.
server_call->GetFactory().CreateCall();
server_call->SetState(ServerCallState::PROCESSING);
main_service_.post([server_call] { server_call->HandleRequest(); });
server_call->HandleRequest();
break;
case ServerCallState::SENDING_REPLY:
// The reply has been sent, this call can be deleted now.
Expand Down
77 changes: 52 additions & 25 deletions src/ray/rpc/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
namespace ray {
namespace rpc {

/// Base class that represents an abstract gRPC server.
class GrpcService;

/// Class that represents an gRPC server.
///
/// A `GrpcServer` listens on a specific port. It owns
/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
Expand All @@ -28,11 +30,7 @@ class GrpcServer {
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcServer(const std::string &name, const uint32_t port,
boost::asio::io_service &main_service)
: name_(name), port_(port), main_service_(main_service) {}
GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {}

/// Destruct this gRPC server.
~GrpcServer() {
Expand All @@ -46,36 +44,25 @@ class GrpcServer {
/// Get the port of this gRPC server.
int GetPort() const { return port_; }

protected:
/// Subclasses should implement this method and register one or multiple gRPC services
/// to the given `ServerBuilder`.
/// Register a grpc service. Multiple services can be registered to the same server.
/// Note that the `service` registered must remain valid for the lifetime of the
/// `GrpcServer`, as it holds the underlying `grpc::Service`.
///
/// \param[in] builder The `ServerBuilder` instance to register services to.
virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;

/// Subclasses should implement this method to initialize the `ServerCallFactory`
/// instances, as well as specify maximum number of concurrent requests that gRPC
/// server can "accept" (not "handle"). Each factory will be used to create
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
/// handle an incoming request.
///
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
/// and the maximum number of concurrent requests that gRPC server can accept.
virtual void InitServerCallFactories(
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) = 0;
/// \param[in] service A `GrpcService` to register to this server.
void RegisterService(GrpcService &service);

protected:
/// This function runs in a background thread. It keeps polling events from the
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
/// via the `ServerCall` objects.
void PollEventsFromCompletionQueue();

/// The main event loop, to which the service handler functions will be posted.
boost::asio::io_service &main_service_;
/// Name of this server, used for logging and debugging purpose.
const std::string name_;
/// Port of this server.
int port_;
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
std::vector<std::reference_wrapper<grpc::Service>> services_;
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
/// gRPC server can accept.
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
Expand All @@ -86,6 +73,46 @@ class GrpcServer {
std::unique_ptr<grpc::Server> server_;
};

/// Base class that represents an abstract gRPC service.
///
/// Subclass should implement `InitServerCallFactories` to decide
/// which kinds of requests this service should accept.
class GrpcService {
public:
/// Constructor.
///
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {}

/// Destruct this gRPC service.
~GrpcService() {}

protected:
/// Return the underlying grpc::Service object for this class.
/// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`.
virtual grpc::Service &GetGrpcService() = 0;

/// Subclasses should implement this method to initialize the `ServerCallFactory`
/// instances, as well as specify maximum number of concurrent requests that gRPC
/// server can "accept" (not "handle"). Each factory will be used to create
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
/// handle an incoming request.
///
/// \param[in] cq The grpc completion queue.
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
/// and the maximum number of concurrent requests that gRPC server can accept.
virtual void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) = 0;

/// The main event loop, to which the service handler functions will be posted.
boost::asio::io_service &main_service_;

friend class GrpcServer;
};

} // namespace rpc
} // namespace ray

Expand Down
25 changes: 12 additions & 13 deletions src/ray/rpc/node_manager_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,31 @@ class NodeManagerServiceHandler {
RequestDoneCallback done_callback) = 0;
};

/// The `GrpcServer` for `NodeManagerService`.
class NodeManagerServer : public GrpcServer {
/// The `GrpcService` for `NodeManagerService`.
class NodeManagerGrpcService : public GrpcService {
public:
/// Constructor.
///
/// \param[in] port See super class.
/// \param[in] main_service See super class.
/// \param[in] io_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
NodeManagerServiceHandler &service_handler)
: GrpcServer("NodeManager", port, main_service),
service_handler_(service_handler){};
NodeManagerGrpcService(boost::asio::io_service &io_service,
NodeManagerServiceHandler &service_handler)
: GrpcService(io_service), service_handler_(service_handler){};

void RegisterServices(grpc::ServerBuilder &builder) override {
/// Register `NodeManagerService`.
builder.RegisterService(&service_);
}
protected:
grpc::Service &GetGrpcService() override { return service_; }

void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the factory for `ForwardTask` requests.
std::unique_ptr<ServerCallFactory> forward_task_call_factory(
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
ForwardTaskRequest, ForwardTaskReply>(
service_, &NodeManagerService::AsyncService::RequestForwardTask,
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq,
main_service_));

// Set `ForwardTask`'s accept concurrency to 100.
server_call_factories_and_concurrencies->emplace_back(
Expand All @@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer {
private:
/// The grpc async service object.
NodeManagerService::AsyncService service_;

/// The service handler that actually handle the requests.
NodeManagerServiceHandler &service_handler_;
};
Expand Down
26 changes: 21 additions & 5 deletions src/ray/rpc/server_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall {
/// \param[in] factory The factory which created this call.
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
/// \param[in] io_service The event loop.
ServerCallImpl(
const ServerCallFactory &factory, ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function)
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
boost::asio::io_service &io_service)
: state_(ServerCallState::PENDING),
factory_(factory),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
response_writer_(&context_) {}
response_writer_(&context_),
io_service_(io_service) {}

ServerCallState GetState() const override { return state_; }

void SetState(const ServerCallState &new_state) override { state_ = new_state; }

void HandleRequest() override {
io_service_.post([this] { HandleRequestImpl(); });
}

void HandleRequestImpl() {
state_ = ServerCallState::PROCESSING;
(service_handler_.*handle_request_function_)(request_, &reply_,
[this](Status status) {
Expand Down Expand Up @@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall {
/// The reponse writer.
grpc::ServerAsyncResponseWriter<Reply> response_writer_;

/// The event loop.
boost::asio::io_service &io_service_;

/// The request message.
Request request_;

Expand Down Expand Up @@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory {
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
/// \param[in] cq The `CompletionQueue`.
/// \param[in] io_service The event loop.
ServerCallFactoryImpl(
AsyncService &service,
RequestCallFunction<GrpcService, Request, Reply> request_call_function,
ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
boost::asio::io_service &io_service)
: service_(service),
request_call_function_(request_call_function),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
cq_(cq) {}
cq_(cq),
io_service_(io_service) {}

ServerCall *CreateCall() const override {
// Create a new `ServerCall`. This object will eventually be deleted by
// `GrpcServer::PollEventsFromCompletionQueue`.
auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
*this, service_handler_, handle_request_function_);
*this, service_handler_, handle_request_function_, io_service_);
/// Request gRPC runtime to starting accepting this kind of request, using the call as
/// the tag.
(service_.*request_call_function_)(&call->context_, &call->request_,
Expand All @@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory {

/// The `CompletionQueue`.
const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;

/// The event loop.
boost::asio::io_service &io_service_;
};

} // namespace rpc
Expand Down