diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 1f42068e5b9a..90b5ed6a56c0 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -97,7 +97,7 @@ std::shared_ptr> ClientConnection::Create( std::shared_ptr> self( new ClientConnection(message_handler, std::move(socket))); // Let our manager process our new connection. - client_handler(self); + client_handler(*self); return self; } diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 55efed3dd2f6..0f65e43277f5 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -62,7 +62,7 @@ template class ClientConnection; template -using ClientHandler = std::function>)>; +using ClientHandler = std::function &)>; template using MessageHandler = std::function>, int64_t, const uint8_t *)>; diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index f29c1f6ffb8b..d4aed268ce8d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -7,7 +7,7 @@ namespace ray { namespace gcs { AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) { - context_.reset(new RedisContext()); + context_ = std::make_shared(); client_table_.reset(new ClientTable(context_, this, client_id)); object_table_.reset(new ObjectTable(context_, this)); actor_table_.reset(new ActorTable(context_, this)); diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 679177e45c0a..92ff7f854c31 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -93,9 +93,9 @@ void TestTableLookup(const JobID &job_id, std::shared_ptr c // Check that we added the correct task. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, - const std::shared_ptr d) { + const protocol::TaskT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d->task_specification); + ASSERT_EQ(data->task_specification, d.task_specification); }; // Check that the lookup returns the added task. @@ -139,9 +139,9 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr cli data->manager = manager; // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, - const std::shared_ptr d) { + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d->manager); + ASSERT_EQ(data->manager, d.manager); }; RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback)); } @@ -222,7 +222,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr c // Check that we added the correct task. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const std::shared_ptr d) { + const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -265,8 +265,8 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) { // Task table callbacks. void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const std::shared_ptr data) { - ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED); + const TaskTableDataT &data) { + ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED); } void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 43778f70b9d2..df061700be3a 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -100,15 +100,13 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) { int64_t RedisCallbackManager::add(const RedisCallback &function) { num_callbacks += 1; - callbacks_.emplace(num_callbacks, std::unique_ptr( - new RedisCallback(function))); + callbacks_.emplace(num_callbacks, function); return num_callbacks; } -RedisCallbackManager::RedisCallback &RedisCallbackManager::get( - int64_t callback_index) { +RedisCallback &RedisCallbackManager::get(int64_t callback_index) { RAY_CHECK(callbacks_.find(callback_index) != callbacks_.end()); - return *callbacks_[callback_index]; + return callbacks_[callback_index]; } void RedisCallbackManager::remove(int64_t callback_index) { @@ -185,7 +183,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data, int64_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, - int64_t callback_index, int log_length) { + RedisCallback redisCallback, int log_length) { + int64_t callback_index = + redisCallback != nullptr ? RedisCallbackManager::instance().add(redisCallback) : -1; if (length > 0) { if (log_length >= 0) { std::string redis_command = command + " %d %d %b %b %d"; @@ -222,10 +222,11 @@ Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, Status RedisContext::SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, - int64_t callback_index) { + const RedisCallback &redisCallback) { RAY_CHECK(pubsub_channel != TablePubsub_NO_PUBLISH) << "Client requested subscribe on a table that does not support pubsub"; + int64_t callback_index = RedisCallbackManager::instance().add(redisCallback); int status = 0; if (client_id.is_nil()) { // Subscribe to all messages. diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 4d9a296eee60..c2371dff5809 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -18,13 +18,13 @@ struct aeEventLoop; namespace ray { namespace gcs { +/// Every callback should take in a vector of the results from the Redis +/// operation and return a bool indicating whether the callback should be +/// deleted once called. +using RedisCallback = std::function; class RedisCallbackManager { public: - /// Every callback should take in a vector of the results from the Redis - /// operation and return a bool indicating whether the callback should be - /// deleted once called. - using RedisCallback = std::function; static RedisCallbackManager &instance() { static RedisCallbackManager instance; @@ -44,7 +44,7 @@ class RedisCallbackManager { ~RedisCallbackManager() { printf("shut down callback manager\n"); } int64_t num_callbacks; - std::unordered_map> callbacks_; + std::unordered_map callbacks_; }; class RedisContext { @@ -70,11 +70,11 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data, int64_t length, const TablePrefix prefix, - const TablePubsub pubsub_channel, int64_t callback_index, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, - int64_t callback_index); + const RedisCallback &redisCallback); redisAsyncContext *async_context() { return async_context_; } redisAsyncContext *subscribe_context() { return subscribe_context_; }; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index be73aedf890e..3c014798b930 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -9,76 +9,66 @@ namespace gcs { template Status Log::Append(const JobID &job_id, const ID &id, - std::shared_ptr data, const WriteCallback &done) { - auto d = std::shared_ptr( - new CallbackData({id, data, nullptr, nullptr, this, client_})); - int64_t callback_index = - RedisCallbackManager::instance().add([d, done](const std::string &data) { - RAY_CHECK(data.empty()); - if (done != nullptr) { - (done)(d->client, d->id, d->data); - } - return true; - }); + std::shared_ptr &dataT, const WriteCallback &done) { + auto callback = [this, id, dataT, done](const std::string &data) { + RAY_CHECK(data.empty()); + if (done != nullptr) { + (done)(client_, id, *dataT); + } + return true; + }; flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, data.get())); + fbb.Finish(Data::Pack(fbb, dataT.get())); return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(), - prefix_, pubsub_channel_, callback_index); + prefix_, pubsub_channel_, std::move(callback)); } template Status Log::AppendAt(const JobID &job_id, const ID &id, - std::shared_ptr data, const WriteCallback &done, + std::shared_ptr &dataT, const WriteCallback &done, const WriteCallback &failure, int log_length) { - auto d = std::shared_ptr( - new CallbackData({id, data, nullptr, nullptr, this, client_})); - int64_t callback_index = - RedisCallbackManager::instance().add([d, done, failure](const std::string &data) { - if (data.empty()) { - if (done != nullptr) { - (done)(d->client, d->id, d->data); - } - } else { - if (failure != nullptr) { - (failure)(d->client, d->id, d->data); - } - } - return true; - }); + auto callback = [this, id, dataT, done, failure](const std::string &data) { + if (data.empty()) { + if (done != nullptr) { + (done)(client_, id, *dataT); + } + } else { + if (failure != nullptr) { + (failure)(client_, id, *dataT); + } + } + return true; + }; flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, data.get())); + fbb.Finish(Data::Pack(fbb, dataT.get())); return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(), - prefix_, pubsub_channel_, callback_index, log_length); + prefix_, pubsub_channel_, std::move(callback), log_length); } template Status Log::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) { - auto d = std::shared_ptr( - new CallbackData({id, nullptr, lookup, nullptr, this, client_})); - int64_t callback_index = - RedisCallbackManager::instance().add([d](const std::string &data) { - if (d->callback != nullptr) { - std::vector results; - if (!data.empty()) { - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == d->id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); - } - } - (d->callback)(d->client, d->id, results); + auto callback = [this, id, lookup](const std::string &data) { + if (lookup != nullptr) { + std::vector results; + if (!data.empty()) { + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + for (size_t i = 0; i < root->entries()->size(); i++) { + DataT result; + auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); + data_root->UnPackTo(&result); + results.emplace_back(std::move(result)); } - return true; - }); + } + lookup(client_, id, results); + } + return true; + }; std::vector nil; return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_, - pubsub_channel_, callback_index); + pubsub_channel_, std::move(callback)); } template @@ -87,42 +77,38 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) << "Client called Subscribe twice on the same table"; - auto d = std::shared_ptr( - new CallbackData({client_id, nullptr, subscribe, done, this, client_})); - int64_t callback_index = - RedisCallbackManager::instance().add([d](const std::string &data) { - if (data.empty()) { - // No notification data is provided. This is the callback for the - // initial subscription request. - if (d->subscription_callback != nullptr) { - (d->subscription_callback)(d->client); - } - } else { - // Data is provided. This is the callback for a message. - if (d->callback != nullptr) { - // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id = UniqueID::nil(); - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); - } - (d->callback)(d->client, id, results); - } + auto callback = [this, subscribe, done](const std::string &data) { + if (data.empty()) { + // No notification data is provided. This is the callback for the + // initial subscription request. + if (done != nullptr) { + done(client_); + } + } else { + // Data is provided. This is the callback for a message. + if (subscribe != nullptr) { + // Parse the notification. + auto root = flatbuffers::GetRoot(data.data()); + ID id = UniqueID::nil(); + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); } - // We do not delete the callback after calling it since there may be - // more subscription messages. - return false; - }); - subscribe_callback_index_ = callback_index; - return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index); + std::vector results; + for (size_t i = 0; i < root->entries()->size(); i++) { + DataT result; + auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); + data_root->UnPackTo(&result); + results.emplace_back(std::move(result)); + } + subscribe(client_, id, results); + } + } + // We do not delete the callback after calling it since there may be + // more subscription messages. + return false; + }; + subscribe_callback_index_ = 1; + return context_->SubscribeAsync(client_id, pubsub_channel_, std::move(callback)); } template @@ -131,8 +117,7 @@ Status Log::RequestNotifications(const JobID &job_id, const ID &id, RAY_CHECK(subscribe_callback_index_ >= 0) << "Client requested notifications on a key before Subscribe completed"; return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(), - client_id.size(), prefix_, pubsub_channel_, - /*callback_index=*/-1); + client_id.size(), prefix_, pubsub_channel_, nullptr); } template @@ -141,27 +126,23 @@ Status Log::CancelNotifications(const JobID &job_id, const ID &id, RAY_CHECK(subscribe_callback_index_ >= 0) << "Client canceled notifications on a key before Subscribe completed"; return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(), - client_id.size(), prefix_, pubsub_channel_, - /*callback_index=*/-1); + client_id.size(), prefix_, pubsub_channel_, nullptr); } template Status Table::Add(const JobID &job_id, const ID &id, - std::shared_ptr data, const WriteCallback &done) { - auto d = std::shared_ptr( - new CallbackData({id, data, nullptr, nullptr, this, client_})); - int64_t callback_index = - RedisCallbackManager::instance().add([d, done](const std::string &data) { - if (done != nullptr) { - (done)(d->client, d->id, d->data); - } - return true; - }); + std::shared_ptr &dataT, const WriteCallback &done) { + auto callback = [this, id, dataT, done](const std::string &data) { + if (done != nullptr) { + (done)(client_, id, *dataT); + } + return true; + }; flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, data.get())); + fbb.Finish(Data::Pack(fbb, dataT.get())); return context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(), fbb.GetSize(), - prefix_, pubsub_channel_, callback_index); + prefix_, pubsub_channel_, std::move(callback)); } template @@ -259,9 +240,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, - const std::shared_ptr data) { - auto connected_client_id = ClientID::from_binary(data->client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { + auto connected_client_id = ClientID::from_binary(data.client_id); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } @@ -282,7 +262,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - std::shared_ptr data) { + const ClientTableDataT &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); @@ -311,7 +291,7 @@ Status ClientTable::Disconnect() { auto data = std::make_shared(local_client_); data->is_insertion = false; auto add_callback = [this](AsyncGcsClient *client, const ClientID &id, - std::shared_ptr data) { + const ClientTableDataT &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(JobID::nil(), client_log_key_, id)); }; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index d5c4df088aa7..a90519953f9a 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -57,8 +57,8 @@ class Log : virtual public PubsubInterface { using Callback = std::function &data)>; /// The callback to call when a write to a key succeeds. - using WriteCallback = std::function data)>; + using WriteCallback = + std::function; /// The callback to call when a SUBSCRIBE call completes and we are ready to /// request and receive notifications. using SubscriptionCallback = std::function; @@ -89,7 +89,7 @@ class Log : virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const JobID &job_id, const ID &id, std::shared_ptr data, + Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -105,7 +105,7 @@ class Log : virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. /// \return Status - Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr data, + Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -187,7 +187,7 @@ class TableInterface { public: using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr data, + virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -212,17 +212,6 @@ class Table : private Log, /// request and receive notifications. using SubscriptionCallback = typename Log::SubscriptionCallback; - struct CallbackData { - ID id; - std::shared_ptr data; - Callback callback; - // An optional callback to call for subscription operations, where the - // first message is a notification of subscription success. - SubscriptionCallback subscription_callback; - Log *log; - AsyncGcsClient *client; - }; - Table(const std::shared_ptr &context, AsyncGcsClient *client) : Log(context, client) {} @@ -237,7 +226,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const JobID &job_id, const ID &id, std::shared_ptr data, + Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -358,19 +347,18 @@ class TaskTable : public Table { Status TestAndUpdate(const JobID &job_id, const TaskID &id, std::shared_ptr data, const TestAndUpdateCallback &callback) { - int64_t callback_index = RedisCallbackManager::instance().add( - [this, callback, id](const std::string &data) { - auto result = std::make_shared(); - auto root = flatbuffers::GetRoot(data.data()); - root->UnPackTo(result.get()); - callback(client_, id, *result, root->updated()); - return true; - }); + auto redisCallback = [this, callback, id](const std::string &data) { + auto result = std::make_shared(); + auto root = flatbuffers::GetRoot(data.data()); + root->UnPackTo(result.get()); + callback(client_, id, *result, root->updated()); + return true; + }; flatbuffers::FlatBufferBuilder fbb; fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get())); RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id, fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, callback_index)); + pubsub_channel_, redisCallback)); return Status::OK(); } @@ -499,8 +487,7 @@ class ClientTable : private Log { /// Handle a client table notification. void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, - const std::shared_ptr client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); /// The key at which the log of client information is stored. This key must /// be kept the same across all instances of the ClientTable, so that all diff --git a/src/ray/gcs/task_table.cc b/src/ray/gcs/task_table.cc index 1e3471cc41f5..6a82c06fa2d4 100644 --- a/src/ray/gcs/task_table.cc +++ b/src/ray/gcs/task_table.cc @@ -46,9 +46,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) { TaskSpec *spec = execution_spec.Spec(); auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task), static_cast(Task_state(task))); - return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data, - [](gcs::AsyncGcsClient *client, const TaskID &id, - std::shared_ptr data) {}); + return gcs_client->task_table().Add( + ray::JobID::nil(), TaskSpec_task_id(spec), data, + [](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {}); } // TODO(pcm): This is a helper method that should go away once we get rid of diff --git a/src/ray/object_manager/connection_pool.cc b/src/ray/object_manager/connection_pool.cc index efe5d68891f4..ec06b6d9aac0 100644 --- a/src/ray/object_manager/connection_pool.cc +++ b/src/ray/object_manager/connection_pool.cc @@ -53,7 +53,7 @@ ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &clien } ray::Status ConnectionPool::ReleaseSender(ConnectionType type, - std::shared_ptr conn) { + std::shared_ptr &conn) { std::unique_lock guard(connection_mutex); SenderMapType &conn_map = (type == ConnectionType::MESSAGE) ? available_message_send_connections_ @@ -64,20 +64,21 @@ ray::Status ConnectionPool::ReleaseSender(ConnectionType type, void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id, std::shared_ptr conn) { - conn_map[client_id].push_back(conn); + conn_map[client_id].push_back(std::move(conn)); } void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id, std::shared_ptr conn) { - conn_map[client_id].push_back(conn); + conn_map[client_id].push_back(std::move(conn)); } void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id, - std::shared_ptr conn) { - if (conn_map.count(client_id) == 0) { + std::shared_ptr &conn) { + auto it = conn_map.find(client_id); + if (it == conn_map.end()) { return; } - std::vector> &connections = conn_map[client_id]; + auto &connections = it->second; int64_t pos = std::find(connections.begin(), connections.end(), conn) - connections.begin(); if (pos >= (int64_t)connections.size()) { @@ -87,15 +88,16 @@ void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id } uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) { - if (conn_map.count(client_id) == 0) { + auto it = conn_map.find(client_id); + if (it == conn_map.end()) { return 0; - }; - return conn_map[client_id].size(); + } + return it->second.size(); } std::shared_ptr ConnectionPool::Borrow(SenderMapType &conn_map, const ClientID &client_id) { - std::shared_ptr conn = conn_map[client_id].back(); + std::shared_ptr conn = std::move(conn_map[client_id].back()); conn_map[client_id].pop_back(); RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size(); return conn; @@ -103,7 +105,7 @@ std::shared_ptr ConnectionPool::Borrow(SenderMapType &conn_map void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id, std::shared_ptr conn) { - conn_map[client_id].push_back(conn); + conn_map[client_id].push_back(std::move(conn)); RAY_LOG(DEBUG) << "Return " << client_id << " " << conn_map[client_id].size(); } diff --git a/src/ray/object_manager/connection_pool.h b/src/ray/object_manager/connection_pool.h index 15774a28798c..132083d55895 100644 --- a/src/ray/object_manager/connection_pool.h +++ b/src/ray/object_manager/connection_pool.h @@ -74,7 +74,7 @@ class ConnectionPool { /// \param type The type of connection. /// \param conn The actual connection. /// \return Status of invoking this method. - ray::Status ReleaseSender(ConnectionType type, std::shared_ptr conn); + ray::Status ReleaseSender(ConnectionType type, std::shared_ptr &conn); // TODO(hme): Implement with error handling. /// Remove a sender connection. This is invoked if the connection is no longer @@ -106,7 +106,7 @@ class ConnectionPool { /// Removes the given receiver for ClientID from the given map. void Remove(ReceiverMapType &conn_map, const ClientID &client_id, - std::shared_ptr conn); + std::shared_ptr &conn); /// Returns the count of sender connections to ClientID. uint64_t Count(SenderMapType &conn_map, const ClientID &client_id); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index e7a6c8504c4a..d6063a39925a 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -16,8 +16,8 @@ ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id, data->is_eviction = false; data->object_size = object_info.data_size; ray::Status status = gcs_client_->object_table().Append( - job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id, - const std::shared_ptr data) { + job_id, object_id, data, + [](gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) { // Do nothing. }); return status; diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index cb857fec4f0b..42ef8164aee7 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -110,8 +110,8 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) { } void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) { - pull_requests_[object_id] = std::shared_ptr( - new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms))); + pull_requests_[object_id] = std::make_shared( + *main_service_, boost::posix_time::milliseconds(wait_ms)); pull_requests_[object_id]->async_wait( [this, object_id](const boost::system::error_code &error_code) { pull_requests_.erase(object_id); @@ -184,7 +184,7 @@ ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id, } ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id, - std::shared_ptr conn) { + std::shared_ptr &conn) { flatbuffers::FlatBufferBuilder fbb; auto message = object_manager_protocol::CreatePullRequestMessage( fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary())); @@ -209,7 +209,7 @@ ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &clien Status status = object_directory_->GetInformation( client_id, [this, object_id, client_id](const RemoteConnectionInfo &info) { - ObjectInfoT object_info = local_objects_[object_id]; + const ObjectInfoT &object_info = local_objects_[object_id]; uint64_t data_size = static_cast(object_info.data_size + object_info.metadata_size); uint64_t metadata_size = static_cast(object_info.metadata_size); @@ -251,7 +251,7 @@ void ObjectManager::ExecuteSendObject(const ClientID &client_id, ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index, - std::shared_ptr conn) { + std::shared_ptr &conn) { std::pair chunk_status = buffer_pool_.GetChunk(object_id, data_size, metadata_size, chunk_index); ObjectBufferPool::ChunkInfo chunk_info = chunk_status.first; @@ -276,7 +276,7 @@ ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id, ray::Status ObjectManager::SendObjectData(const ObjectID &object_id, const ObjectBufferPool::ChunkInfo &chunk_info, - std::shared_ptr conn) { + std::shared_ptr &conn) { boost::system::error_code ec; std::vector buffer; buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length)); @@ -328,11 +328,11 @@ std::shared_ptr ObjectManager::CreateSenderConnection( return conn; } -void ObjectManager::ProcessNewClient(std::shared_ptr conn) { - conn->ProcessMessages(); +void ObjectManager::ProcessNewClient(TcpClientConnection &conn) { + conn.ProcessMessages(); } -void ObjectManager::ProcessClientMessage(std::shared_ptr conn, +void ObjectManager::ProcessClientMessage(std::shared_ptr &conn, int64_t message_type, const uint8_t *message) { switch (message_type) { case object_manager_protocol::MessageType_PushRequest: { @@ -389,7 +389,7 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con conn->ProcessMessages(); } -void ObjectManager::ReceivePushRequest(std::shared_ptr conn, +void ObjectManager::ReceivePushRequest(std::shared_ptr &conn, const uint8_t *message) { // Serialize. auto object_header = @@ -400,14 +400,14 @@ void ObjectManager::ReceivePushRequest(std::shared_ptr conn uint64_t metadata_size = object_header->metadata_size(); receive_service_.post([this, object_id, data_size, metadata_size, chunk_index, conn]() { ExecuteReceiveObject(conn->GetClientID(), object_id, data_size, metadata_size, - chunk_index, conn); + chunk_index, *conn); }); } void ObjectManager::ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index, - std::shared_ptr conn) { + TcpClientConnection &conn) { RAY_LOG(DEBUG) << "ExecuteReceiveObject " << client_id << " " << object_id << " " << chunk_index; @@ -419,7 +419,7 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id, std::vector buffer; buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length)); boost::system::error_code ec; - conn->ReadBuffer(buffer, ec); + conn.ReadBuffer(buffer, ec); if (ec.value() == 0) { buffer_pool_.SealChunk(object_id, chunk_index); } else { @@ -435,13 +435,13 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id, std::vector buffer; buffer.push_back(asio::buffer(mutable_vec, buffer_length)); boost::system::error_code ec; - conn->ReadBuffer(buffer, ec); + conn.ReadBuffer(buffer, ec); if (ec.value() != 0) { RAY_LOG(ERROR) << ec.message(); } // TODO(hme): If the object isn't local, create a pull request for this chunk. } - conn->ProcessMessages(); + conn.ProcessMessages(); RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " " << "/" << config_.max_receives; } diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 117a3073d414..d34d50762f24 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -110,7 +110,7 @@ class ObjectManager { /// /// \param conn The connection. /// \return Status of whether the connection was successfully established. - void ProcessNewClient(std::shared_ptr conn); + void ProcessNewClient(TcpClientConnection &conn); /// Process messages sent from other nodes. We only establish /// transfer connections using this method; all other transfer communication @@ -119,7 +119,7 @@ class ObjectManager { /// \param conn The connection. /// \param message_type The message type. /// \param message A pointer set to the beginning of the message. - void ProcessClientMessage(std::shared_ptr conn, + void ProcessClientMessage(std::shared_ptr &conn, int64_t message_type, const uint8_t *message); /// Cancels all requests (Push/Pull) associated with the given ObjectID. @@ -226,7 +226,7 @@ class ObjectManager { /// Synchronously send a pull request via remote object manager connection. /// Executes on main_service_ thread. ray::Status PullSendRequest(const ObjectID &object_id, - std::shared_ptr conn); + std::shared_ptr &conn); std::shared_ptr CreateSenderConnection( ConnectionPool::ConnectionType type, RemoteConnectionInfo info); @@ -241,23 +241,22 @@ class ObjectManager { /// Executes on send_service_ thread pool. ray::Status SendObjectHeaders(const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size, uint64_t chunk_index, - std::shared_ptr conn); + std::shared_ptr &conn); /// This method initiates the actual object transfer. /// Executes on send_service_ thread pool. ray::Status SendObjectData(const ObjectID &object_id, const ObjectBufferPool::ChunkInfo &chunk_info, - std::shared_ptr conn); + std::shared_ptr &conn); /// Invoked when a remote object manager pushes an object to this object manager. /// This will invoke the object receive on the receive_service_ thread pool. - void ReceivePushRequest(std::shared_ptr conn, + void ReceivePushRequest(std::shared_ptr &conn, const uint8_t *message); /// Execute a receive on the receive_service_ thread pool. void ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id, uint64_t data_size, uint64_t metadata_size, - uint64_t chunk_index, - std::shared_ptr conn); + uint64_t chunk_index, TcpClientConnection &conn); /// Handles receiving a pull request message. void ReceivePullRequest(std::shared_ptr &conn, diff --git a/src/ray/object_manager/object_manager_client_connection.cc b/src/ray/object_manager/object_manager_client_connection.cc index b904e5d90ab5..1fcfb7a44304 100644 --- a/src/ray/object_manager/object_manager_client_connection.cc +++ b/src/ray/object_manager/object_manager_client_connection.cc @@ -11,7 +11,7 @@ std::shared_ptr SenderConnection::Create( RAY_CHECK_OK(TcpConnect(socket, ip, port)); std::shared_ptr conn = std::make_shared(std::move(socket)); - return std::make_shared(conn, client_id); + return std::make_shared(std::move(conn), client_id); }; SenderConnection::SenderConnection(std::shared_ptr conn, diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 350b37c4caed..827d19818979 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -65,9 +65,7 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = - [this](std::shared_ptr client) { - object_manager_.ProcessNewClient(client); - }; + [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; MessageHandler message_handler = [this]( std::shared_ptr client, int64_t message_type, const uint8_t *message) { diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 259c3ea82287..faef6850a465 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -56,9 +56,7 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = - [this](std::shared_ptr client) { - object_manager_.ProcessNewClient(client); - }; + [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; MessageHandler message_handler = [this]( std::shared_ptr client, int64_t message_type, const uint8_t *message) { diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 592c26481c24..9cc9cad1b919 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -258,8 +258,9 @@ Status LineageCache::Flush() { // Write back all ready tasks whose arguments have been committed to the GCS. gcs::raylet::TaskTable::WriteCallback task_callback = [this]( - ray::gcs::AsyncGcsClient *client, const TaskID &id, - const std::shared_ptr data) { HandleEntryCommitted(id); }; + ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + HandleEntryCommitted(id); + }; for (const auto &ready_task_id : ready_task_ids) { auto task = lineage_.GetEntry(ready_task_id); // TODO(swang): Make this better... diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 9a51a3cf9321..1a19feb5c2ca 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -23,7 +23,7 @@ class MockGcs : public gcs::TableInterface, } Status Add(const JobID &job_id, const TaskID &task_id, - std::shared_ptr task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; callbacks_.push_back( @@ -38,7 +38,7 @@ class MockGcs : public gcs::TableInterface, bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - std::shared_ptr data) { + const protocol::TaskT &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -63,7 +63,7 @@ class MockGcs : public gcs::TableInterface, void Flush() { for (const auto &callback : callbacks_) { - callback.first(NULL, callback.second, task_table_[callback.second]); + callback.first(NULL, callback.second, *task_table_[callback.second]); } callbacks_.clear(); } @@ -86,7 +86,7 @@ class LineageCacheTest : public ::testing::Test { LineageCacheTest() : mock_gcs_(), lineage_cache_(ClientID::from_random(), mock_gcs_, mock_gcs_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - std::shared_ptr data) { + const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); }); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 82740f1f5985..9f44d466f5c6 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -156,7 +156,7 @@ void NodeManager::Heartbeat() { ray::Status status = heartbeat_table.Add( UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, [](ray::gcs::AsyncGcsClient *client, const ClientID &id, - std::shared_ptr data) { + const HeartbeatTableDataT &data) { RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback"; }); @@ -279,9 +279,9 @@ void NodeManager::HandleActorCreation(const ActorID &actor_id, } } -void NodeManager::ProcessNewClient(std::shared_ptr client) { +void NodeManager::ProcessNewClient(LocalClientConnection &client) { // The new client is a worker, so begin listening for messages. - client->ProcessMessages(); + client.ProcessMessages(); } void NodeManager::DispatchTasks() { @@ -309,9 +309,9 @@ void NodeManager::DispatchTasks() { } } -void NodeManager::ProcessClientMessage(std::shared_ptr client, - int64_t message_type, - const uint8_t *message_data) { +void NodeManager::ProcessClientMessage( + const std::shared_ptr &client, int64_t message_type, + const uint8_t *message_data) { RAY_LOG(DEBUG) << "Message of type " << message_type; switch (message_type) { @@ -319,7 +319,7 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl auto message = flatbuffers::GetRoot(message_data); if (message->is_worker()) { // Create a new worker from the registration request. - std::shared_ptr worker(new Worker(message->worker_pid(), client)); + auto worker = std::make_shared(message->worker_pid(), client); // Register the new worker. worker_pool_.RegisterWorker(std::move(worker)); } @@ -329,10 +329,10 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl RAY_CHECK(worker); // If the worker was assigned a task, mark it as finished. if (!worker->GetAssignedTaskId().is_nil()) { - FinishAssignedTask(worker); + FinishAssignedTask(*worker); } // Return the worker to the idle pool. - worker_pool_.PushWorker(worker); + worker_pool_.PushWorker(std::move(worker)); // Call task dispatch to assign work to the new worker. DispatchTasks(); @@ -436,14 +436,13 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl client->ProcessMessages(); } -void NodeManager::ProcessNewNodeManager( - std::shared_ptr node_manager_client) { - node_manager_client->ProcessMessages(); +void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) { + node_manager_client.ProcessMessages(); } -void NodeManager::ProcessNodeManagerMessage( - std::shared_ptr node_manager_client, int64_t message_type, - const uint8_t *message_data) { +void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, + int64_t message_type, + const uint8_t *message_data) { switch (message_type) { case protocol::MessageType_ForwardTaskRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -458,7 +457,7 @@ void NodeManager::ProcessNodeManagerMessage( default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; } - node_manager_client->ProcessMessages(); + node_manager_client.ProcessMessages(); } void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) { @@ -639,8 +638,8 @@ void NodeManager::AssignTask(Task &task) { } } -void NodeManager::FinishAssignedTask(std::shared_ptr worker) { - TaskID task_id = worker->GetAssignedTaskId(); +void NodeManager::FinishAssignedTask(Worker &worker) { + TaskID task_id = worker.GetAssignedTaskId(); RAY_LOG(DEBUG) << "Finished task " << task_id; auto tasks = local_queues_.RemoveTasks({task_id}); auto task = *tasks.begin(); @@ -648,7 +647,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr worker) { if (task.GetTaskSpecification().IsActorCreationTask()) { // If this was an actor creation task, then convert the worker to an actor. auto actor_id = task.GetTaskSpecification().ActorCreationId(); - worker->AssignActorId(actor_id); + worker.AssignActorId(actor_id); // Publish the actor creation event to all other nodes so that methods for // the actor will be forwarded directly to this node. @@ -684,7 +683,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr worker) { } // Unset the worker's assigned task. - worker->AssignTaskId(TaskID::nil()); + worker.AssignTaskId(TaskID::nil()); } void NodeManager::ResubmitTask(const TaskID &task_id) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 3cf77327d08d..268e61f2301b 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -37,7 +37,7 @@ class NodeManager { std::shared_ptr gcs_client); /// Process a new client connection. - void ProcessNewClient(std::shared_ptr client); + void ProcessNewClient(LocalClientConnection &client); /// Process a message from a client. This method is responsible for /// explicitly listening for more messages from the client if the client is @@ -46,12 +46,12 @@ class NodeManager { /// \param client The client that sent the message. /// \param message_type The message type (e.g., a flatbuffer enum). /// \param message A pointer to the message data. - void ProcessClientMessage(std::shared_ptr client, + void ProcessClientMessage(const std::shared_ptr &client, int64_t message_type, const uint8_t *message); - void ProcessNewNodeManager(std::shared_ptr node_manager_client); + void ProcessNewNodeManager(TcpClientConnection &node_manager_client); - void ProcessNodeManagerMessage(std::shared_ptr node_manager_client, + void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, int64_t message_type, const uint8_t *message); ray::Status RegisterGcs(); @@ -69,7 +69,7 @@ class NodeManager { /// Assign a task. The task is assumed to not be queued in local_queues_. void AssignTask(Task &task); /// Handle a worker finishing its assigned task. - void FinishAssignedTask(std::shared_ptr worker); + void FinishAssignedTask(Worker &worker); /// Schedule tasks. void ScheduleTasks(); /// Handle a task whose local dependencies were missing and are now available. diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 21ad56541240..bc23eb54d299 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -86,14 +86,12 @@ void Raylet::DoAcceptNodeManager() { void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { if (!error) { - ClientHandler client_handler = - [this](std::shared_ptr client) { - node_manager_.ProcessNewNodeManager(client); - }; + ClientHandler client_handler = [this]( + TcpClientConnection &client) { node_manager_.ProcessNewNodeManager(client); }; MessageHandler message_handler = [this]( std::shared_ptr client, int64_t message_type, const uint8_t *message) { - node_manager_.ProcessNodeManagerMessage(client, message_type, message); + node_manager_.ProcessNodeManagerMessage(*client, message_type, message); }; // Accept a new local client and dispatch it to the node manager. auto new_connection = TcpClientConnection::Create(client_handler, message_handler, @@ -111,9 +109,7 @@ void Raylet::DoAcceptObjectManager() { void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = - [this](std::shared_ptr client) { - object_manager_.ProcessNewClient(client); - }; + [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; MessageHandler message_handler = [this]( std::shared_ptr client, int64_t message_type, const uint8_t *message) { @@ -134,9 +130,7 @@ void Raylet::HandleAccept(const boost::system::error_code &error) { if (!error) { // TODO: typedef these handlers. ClientHandler client_handler = - [this](std::shared_ptr client) { - node_manager_.ProcessNewClient(client); - }; + [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; MessageHandler message_handler = [this]( std::shared_ptr client, int64_t message_type, const uint8_t *message) { diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 4b58b5a637f8..1a15dd6c22c0 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -87,14 +87,15 @@ void WorkerPool::StartWorker(bool force_start) { } void WorkerPool::RegisterWorker(std::shared_ptr worker) { - RAY_LOG(DEBUG) << "Registering worker with pid " << worker->Pid(); - registered_workers_.push_back(worker); - RAY_CHECK(started_worker_pids_.count(worker->Pid()) > 0); - started_worker_pids_.erase(worker->Pid()); + auto pid = worker->Pid(); + RAY_LOG(DEBUG) << "Registering worker with pid " << pid; + registered_workers_.push_back(std::move(worker)); + RAY_CHECK(started_worker_pids_.count(pid) > 0); + started_worker_pids_.erase(pid); } std::shared_ptr WorkerPool::GetRegisteredWorker( - std::shared_ptr connection) const { + const std::shared_ptr &connection) const { for (auto it = registered_workers_.begin(); it != registered_workers_.end(); it++) { if ((*it)->Connection() == connection) { return (*it); @@ -135,7 +136,7 @@ std::shared_ptr WorkerPool::PopWorker(const ActorID &actor_id) { // A helper function to remove a worker from a list. Returns true if the worker // was found and removed. bool removeWorker(std::list> &worker_pool, - std::shared_ptr worker) { + const std::shared_ptr &worker) { for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) { if (*it == worker) { worker_pool.erase(it); diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 8b6ef1e54d24..d1c6def3d473 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -60,7 +60,7 @@ class WorkerPool { /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. std::shared_ptr GetRegisteredWorker( - std::shared_ptr connection) const; + const std::shared_ptr &connection) const; /// Disconnect a registered worker. /// diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 8d6c526b4945..28c5ef730d82 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -30,8 +30,8 @@ class WorkerPoolTest : public ::testing::Test { WorkerPoolTest() : worker_pool_({}), io_service_() {} std::shared_ptr CreateWorker(pid_t pid) { - std::function)> client_handler = [this]( - std::shared_ptr client) { HandleNewClient(client); }; + std::function client_handler = + [this](LocalClientConnection &client) { HandleNewClient(client); }; std::function, int64_t, const uint8_t *)> message_handler = [this](std::shared_ptr client, int64_t message_type, const uint8_t *message) { @@ -49,7 +49,7 @@ class WorkerPoolTest : public ::testing::Test { boost::asio::io_service io_service_; private: - void HandleNewClient(std::shared_ptr){}; + void HandleNewClient(LocalClientConnection &){}; void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; }; diff --git a/test/runtest.py b/test/runtest.py index 1f95250cdf76..469d21b7708e 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2036,7 +2036,7 @@ def testWorkers(self): @ray.remote def f(): - return id(ray.worker.global_worker) + return id(ray.worker.global_worker), os.getpid() # Wait until all of the workers have started. worker_ids = set()