diff --git a/src/mock/ray/gcs_client/accessor.h b/src/mock/ray/gcs_client/accessor.h
index 12c2c2698fee..463f8939cce8 100644
--- a/src/mock/ray/gcs_client/accessor.h
+++ b/src/mock/ray/gcs_client/accessor.h
@@ -80,30 +80,17 @@ class MockNodeInfoAccessor : public NodeInfoAccessor {
               int64_t timeout_ms,
               const std::vector<NodeID> &node_ids),
              (override));
-  MOCK_METHOD(void,
-              AsyncSubscribeToNodeChange,
-              (std::function<void(NodeID, const rpc::GcsNodeInfo &)> subscribe,
-               StatusCallback done),
-              (override));
   MOCK_METHOD(
       void,
       AsyncSubscribeToNodeAddressAndLivenessChange,
      (std::function<void(NodeID, const rpc::GcsNodeAddressAndLiveness &)> subscribe,
       StatusCallback done),
      (override));
-  MOCK_METHOD(const rpc::GcsNodeInfo *,
-              Get,
-              (const NodeID &node_id, bool filter_dead_nodes),
-              (const, override));
-  MOCK_METHOD(const rpc::GcsNodeAddressAndLiveness *,
+  MOCK_METHOD(std::optional<rpc::GcsNodeAddressAndLiveness>,
              GetNodeAddressAndLiveness,
              (const NodeID &node_id, bool filter_dead_nodes),
              (const, override));
-  MOCK_METHOD((const absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> &),
-              GetAll,
-              (),
-              (const, override));
-  MOCK_METHOD((const absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness> &),
+  MOCK_METHOD((absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>),
              GetAllNodeAddressAndLiveness,
              (),
              (const, override));
@@ -114,6 +101,7 @@ class MockNodeInfoAccessor : public NodeInfoAccessor {
               std::vector<bool> &nodes_alive),
              (override));
   MOCK_METHOD(bool, IsNodeDead, (const NodeID &node_id), (const, override));
+  MOCK_METHOD(bool, IsNodeAlive, (const NodeID &node_id), (const, override));
   MOCK_METHOD(void, AsyncResubscribe, (), (override));
 };
diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h
index 8d755487ab30..b3c46fb51a8e 100644
--- a/src/ray/core_worker/core_worker.h
+++ b/src/ray/core_worker/core_worker.h
@@ -1945,7 +1945,7 @@ class CoreWorker {
   // the shutdown procedure without exposing additional public APIs.
   friend class CoreWorkerShutdownExecutor;
 
-  /// Used to block in certain spots if the GCS node cache is needed.
+  /// Used to block in certain spots if the GCS node address and liveness cache is needed.
   std::mutex gcs_client_node_cache_populated_mutex_;
   std::condition_variable gcs_client_node_cache_populated_cv_;
   bool gcs_client_node_cache_populated_ = false;
diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc
index 24b668678b94..aae14dd55795 100644
--- a/src/ray/core_worker/core_worker_process.cc
+++ b/src/ray/core_worker/core_worker_process.cc
@@ -588,9 +588,9 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
         if (object_locations.has_value()) {
           locations.reserve(object_locations->size());
           for (const auto &node_id : *object_locations) {
-            auto *node_info = core_worker->gcs_client_->Nodes().GetNodeAddressAndLiveness(
+            auto node_info = core_worker->gcs_client_->Nodes().GetNodeAddressAndLiveness(
                 node_id, /*filter_dead_nodes=*/false);
-            if (node_info == nullptr) {
+            if (!node_info) {
              // Unsure if the node is dead, so we need to confirm with the GCS. This should
              // be rare, the only foreseeable reasons are:
              //   1. We filled our cache after the GCS cleared the node info due to
diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc
index b55328668315..f7a9a9425e10 100644
--- a/src/ray/core_worker/task_manager.cc
+++ b/src/ray/core_worker/task_manager.cc
@@ -1162,7 +1162,7 @@ bool TaskManager::RetryTaskIfPossible(const TaskID &task_id,
     const auto node_info =
         gcs_client_->Nodes().GetNodeAddressAndLiveness(task_entry.GetNodeId(),
                                                        /*filter_dead_nodes=*/false);
-    is_preempted = node_info != nullptr && node_info->has_death_info() &&
+    is_preempted = node_info && node_info->has_death_info() &&
                    node_info->death_info().reason() ==
                        rpc::NodeDeathInfo::AUTOSCALER_DRAIN_PREEMPTED;
   }
diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc
index 14f80e4211c0..8978dc7d0a7e 100644
--- a/src/ray/core_worker/tests/task_manager_test.cc
+++ b/src/ray/core_worker/tests/task_manager_test.cc
@@ -2889,7 +2889,7 @@ TEST_F(TaskManagerTest, TestTaskRetriedOnNodePreemption) {
                                 rpc::NodeDeathInfo::AUTOSCALER_DRAIN_PREEMPTED);
   EXPECT_CALL(*mock_gcs_client_->mock_node_accessor,
               GetNodeAddressAndLiveness(node_id, false))
-      .WillOnce(::testing::Return(&node_info));
+      .WillOnce(::testing::Return(node_info));
 
   // Task should be retried because the node was preempted, even with 0 retries left
   rpc::RayErrorInfo node_died_error;
diff --git a/src/ray/core_worker_rpc_client/core_worker_client_pool.cc b/src/ray/core_worker_rpc_client/core_worker_client_pool.cc
index 3de28fad8519..7680174c1696 100644
--- a/src/ray/core_worker_rpc_client/core_worker_client_pool.cc
+++ b/src/ray/core_worker_rpc_client/core_worker_client_pool.cc
@@ -92,9 +92,9 @@ std::function<void()> CoreWorkerClientPool::GetDefaultUnavailableTimeoutCallback
     };
 
     if (gcs_client->Nodes().IsSubscribedToNodeChange()) {
-      auto *node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
+      auto node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
          node_id, /*filter_dead_nodes=*/false);
-      if (node_info == nullptr) {
+      if (!node_info) {
        // Node could be dead or info may have not made it to the subscriber cache yet.
        // Check with the GCS to confirm if the node is dead.
        gcs_check_node_alive();
diff --git a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc
index 727168213ef3..013e6c1fb1f0 100644
--- a/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc
+++ b/src/ray/core_worker_rpc_client/tests/core_worker_client_pool_test.cc
@@ -100,9 +100,7 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor {
 
   bool IsSubscribedToNodeChange() const override { return is_subscribed_to_node_change_; }
 
-  MOCK_METHOD(const rpc::GcsNodeInfo *, Get, (const NodeID &, bool), (const, override));
-
-  MOCK_METHOD(const rpc::GcsNodeAddressAndLiveness *,
+  MOCK_METHOD(std::optional<rpc::GcsNodeAddressAndLiveness>,
              GetNodeAddressAndLiveness,
              (const NodeID &, bool),
              (const, override));
@@ -212,16 +210,16 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) {
   if (is_subscribed_to_node_change_) {
     EXPECT_CALL(mock_node_accessor,
                 GetNodeAddressAndLiveness(worker_1_node_id, /*filter_dead_nodes=*/false))
-        .WillOnce(Return(nullptr))
-        .WillOnce(Return(&node_info_alive))
-        .WillOnce(Return(&node_info_dead));
+        .WillOnce(Return(std::nullopt))
+        .WillOnce(Return(node_info_alive))
+        .WillOnce(Return(node_info_dead));
     EXPECT_CALL(
         mock_node_accessor,
         AsyncGetAllNodeAddressAndLiveness(_, _, std::vector<NodeID>{worker_1_node_id}))
         .WillOnce(invoke_with_node_info_vector({node_info_alive}));
     EXPECT_CALL(mock_node_accessor,
                 GetNodeAddressAndLiveness(worker_2_node_id, /*filter_dead_nodes=*/false))
-        .WillOnce(Return(nullptr));
+        .WillOnce(Return(std::nullopt));
     EXPECT_CALL(
         mock_node_accessor,
         AsyncGetAllNodeAddressAndLiveness(_, _, std::vector<NodeID>{worker_2_node_id}))
@@ -279,7 +277,7 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, WorkerDeath) {
     EXPECT_CALL(gcs_client_.MockNodeAccessor(),
                 GetNodeAddressAndLiveness(_, /*filter_dead_nodes=*/false))
         .Times(2)
-        .WillRepeatedly(Return(&node_info_alive));
+        .WillRepeatedly(Return(node_info_alive));
   } else {
     EXPECT_CALL(gcs_client_.MockNodeAccessor(),
                 AsyncGetAllNodeAddressAndLiveness(_, _, _))
diff --git a/src/ray/gcs_rpc_client/accessor.cc b/src/ray/gcs_rpc_client/accessor.cc
index 68ec43e04464..68e85dc0e496 100644
--- a/src/ray/gcs_rpc_client/accessor.cc
+++ b/src/ray/gcs_rpc_client/accessor.cc
@@ -295,48 +295,6 @@ void NodeInfoAccessor::AsyncGetAll(const MultiItemCallback<rpc::GcsNodeInfo> &ca
      timeout_ms);
 }
 
-void NodeInfoAccessor::AsyncSubscribeToNodeChange(
-    std::function<void(NodeID, const rpc::GcsNodeInfo &)> subscribe,
-    StatusCallback done) {
-  /**
-  1. Subscribe to node info
-  2. Once the subscription is made, ask for all node info.
-  3. Once all node info is received, call done callback.
-  4. HandleNotification can handle conflicts between the subscription updates and
-  GetAllNodeInfo because nodes can only go from alive to dead, never back to alive.
-  Note that this only works because state is the only mutable field, otherwise we'd
-  have to queue processing subscription updates until the initial population from
-  AsyncGetAll is done.
-  */
-  RAY_CHECK(node_change_callback_address_and_liveness_ == nullptr)
-      << "Subscriber is already subscribed to GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL, "
-         "subscribing to GCS_NODE_INFO_CHANNEL in addition is a waste of resources and "
-         "likely a bug.";
-  RAY_CHECK(node_change_callback_ == nullptr);
-  node_change_callback_ = std::move(subscribe);
-  RAY_CHECK(node_change_callback_ != nullptr);
-
-  fetch_node_data_operation_ = [this](const StatusCallback &done_callback) {
-    AsyncGetAll(
-        [this, done_callback](const Status &status,
-                              std::vector<rpc::GcsNodeInfo> &&node_info_list) {
-          for (auto &node_info : node_info_list) {
-            HandleNotification(std::move(node_info));
-          }
-          if (done_callback) {
-            done_callback(status);
-          }
-        },
-        /*timeout_ms=*/-1);
-  };
-
-  client_impl_->GetGcsSubscriber().SubscribeAllNodeInfo(
-      /*subscribe=*/[this](
-          rpc::GcsNodeInfo &&data) { HandleNotification(std::move(data)); },
-      /*done=*/[this, done = std::move(done)](
-          const Status &) { fetch_node_data_operation_(done); });
-}
-
 void NodeInfoAccessor::AsyncSubscribeToNodeAddressAndLivenessChange(
     std::function<void(NodeID, const rpc::GcsNodeAddressAndLiveness &)> subscribe,
     StatusCallback done) {
@@ -350,11 +308,6 @@ void NodeInfoAccessor::AsyncSubscribeToNodeAddressAndLivenessChange(
   have to queue processing subscription updates until the initial population from
   AsyncGetAll is done.
   */
-  RAY_CHECK(node_change_callback_ == nullptr)
-      << "Subscriber is already subscribed to GCS_NODE_INFO_CHANNEL, "
-         "subscribing to GCS_NODE_ADDRESS_AND_LIVENESS_CHANNEL in addition is a waste of "
-         "resources and "
-         "likely a bug.";
   RAY_CHECK(node_change_callback_address_and_liveness_ == nullptr);
   node_change_callback_address_and_liveness_ = std::move(subscribe);
   RAY_CHECK(node_change_callback_address_and_liveness_ != nullptr);
@@ -383,38 +336,23 @@ void NodeInfoAccessor::AsyncSubscribeToNodeAddressAndLivenessChange(
          &) { fetch_node_address_and_liveness_data_operation_(done); });
 }
 
-const rpc::GcsNodeInfo *NodeInfoAccessor::Get(const NodeID &node_id,
-                                              bool filter_dead_nodes) const {
-  RAY_CHECK(!node_id.IsNil());
-  auto entry = node_cache_.find(node_id);
-  if (entry != node_cache_.end()) {
-    if (filter_dead_nodes && entry->second.state() == rpc::GcsNodeInfo::DEAD) {
-      return nullptr;
-    }
-    return &entry->second;
-  }
-  return nullptr;
-}
-
-const rpc::GcsNodeAddressAndLiveness *NodeInfoAccessor::GetNodeAddressAndLiveness(
+std::optional<rpc::GcsNodeAddressAndLiveness> NodeInfoAccessor::GetNodeAddressAndLiveness(
    const NodeID &node_id, bool filter_dead_nodes) const {
   RAY_CHECK(!node_id.IsNil());
+  absl::MutexLock lock(&node_cache_address_and_liveness_mutex_);
   auto entry = node_cache_address_and_liveness_.find(node_id);
   if (entry != node_cache_address_and_liveness_.end()) {
     if (filter_dead_nodes && entry->second.state() == rpc::GcsNodeInfo::DEAD) {
-      return nullptr;
+      return std::nullopt;
     }
-    return &entry->second;
+    return entry->second;
   }
-  return nullptr;
+  return std::nullopt;
 }
 
-const absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> &NodeInfoAccessor::GetAll() const {
-  return node_cache_;
-}
-
-const absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>
-    &NodeInfoAccessor::GetAllNodeAddressAndLiveness() const {
+absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>
+NodeInfoAccessor::GetAllNodeAddressAndLiveness() const {
+  absl::MutexLock lock(&node_cache_address_and_liveness_mutex_);
   return node_cache_address_and_liveness_;
 }
 
@@ -450,127 +388,75 @@ Status NodeInfoAccessor::CheckAlive(const std::vector<NodeID> &node_ids,
 }
 
 bool NodeInfoAccessor::IsNodeDead(const NodeID &node_id) const {
-  if (node_change_callback_ != nullptr) {
-    auto node_iter = node_cache_.find(node_id);
-    return node_iter != node_cache_.end() &&
-           node_iter->second.state() == rpc::GcsNodeInfo::DEAD;
-  } else {
-    auto node_iter = node_cache_address_and_liveness_.find(node_id);
-    return node_iter != node_cache_address_and_liveness_.end() &&
-           node_iter->second.state() == rpc::GcsNodeInfo::DEAD;
-  }
+  absl::MutexLock lock(&node_cache_address_and_liveness_mutex_);
+  auto node_iter = node_cache_address_and_liveness_.find(node_id);
+  return node_iter != node_cache_address_and_liveness_.end() &&
+         node_iter->second.state() == rpc::GcsNodeInfo::DEAD;
 }
 
-void NodeInfoAccessor::HandleNotification(rpc::GcsNodeInfo &&node_info) {
-  NodeID node_id = NodeID::FromBinary(node_info.node_id());
-  bool is_alive = (node_info.state() == rpc::GcsNodeInfo::ALIVE);
-  auto entry = node_cache_.find(node_id);
-  bool is_notif_new;
-  if (entry == node_cache_.end()) {
-    // If the entry is not in the cache, then the notification is new.
-    is_notif_new = true;
-  } else {
-    // If the entry is in the cache, then the notification is new if the node
-    // was alive and is now dead or resources have been updated.
-    bool was_alive = (entry->second.state() == rpc::GcsNodeInfo::ALIVE);
-    is_notif_new = was_alive && !is_alive;
-
-    // Once a node with a given ID has been removed, it should never be added
-    // again. If the entry was in the cache and the node was deleted, we should check
-    // that this new notification is not an insertion.
-    // However, when a new node(node-B) registers with GCS, it subscribes to all node
-    // information. It will subscribe to redis and then get all node information from GCS
-    // through RPC. If node-A fails after GCS replies to node-B, GCS will send another
-    // message(node-A is dead) to node-B through redis publish. Because RPC and redis
-    // subscribe are two different sessions, node-B may process node-A dead message first
-    // and then node-A alive message. So we use `RAY_LOG` instead of `RAY_CHECK ` as a
-    // workaround.
-    if (!was_alive && is_alive) {
-      RAY_LOG(INFO) << "Notification for addition of a node that was already removed:"
-                    << node_id;
-      return;
-    }
-  }
-
-  // Add the notification to our cache.
-  RAY_LOG(INFO).WithField(node_id)
-      << "Received notification for node, IsAlive = " << is_alive;
-
-  auto &node = node_cache_[node_id];
-  if (is_alive) {
-    node = std::move(node_info);
-  } else {
-    node.set_node_id(node_info.node_id());
-    node.set_state(rpc::GcsNodeInfo::DEAD);
-    node.mutable_death_info()->CopyFrom(node_info.death_info());
-    node.set_end_time_ms(node_info.end_time_ms());
-  }
-
-  // If the notification is new, call registered callback.
-  if (is_notif_new && node_change_callback_ != nullptr) {
-    node_change_callback_(node_id, node_cache_[node_id]);
-  }
+bool NodeInfoAccessor::IsNodeAlive(const NodeID &node_id) const {
+  absl::MutexLock lock(&node_cache_address_and_liveness_mutex_);
+  auto node_iter = node_cache_address_and_liveness_.find(node_id);
+  return node_iter != node_cache_address_and_liveness_.end() &&
+         node_iter->second.state() == rpc::GcsNodeInfo::ALIVE;
 }
 
 void NodeInfoAccessor::HandleNotification(rpc::GcsNodeAddressAndLiveness &&node_info) {
   NodeID node_id = NodeID::FromBinary(node_info.node_id());
   bool is_alive = (node_info.state() == rpc::GcsNodeInfo::ALIVE);
-  auto entry = node_cache_address_and_liveness_.find(node_id);
-  bool is_notif_new;
-  if (entry == node_cache_address_and_liveness_.end()) {
-    // If the entry is not in the cache, then the notification is new.
-    is_notif_new = true;
-  } else {
-    // If the entry is in the cache, then the notification is new if the node
-    // was alive and is now dead.
-    bool was_alive = (entry->second.state() == rpc::GcsNodeInfo::ALIVE);
-    is_notif_new = was_alive && !is_alive;
-
-    // Handle the same logic as in HandleNotification for preventing re-adding removed
-    // nodes
-    if (!was_alive && is_alive) {
-      RAY_LOG(INFO) << "Address and liveness notification for addition of a node that "
-                       "was already removed:"
-                    << node_id;
-      return;
+  std::optional<rpc::GcsNodeAddressAndLiveness> node_info_copy_for_callback;
+  {
+    absl::MutexLock lock(&node_cache_address_and_liveness_mutex_);
+
+    auto entry = node_cache_address_and_liveness_.find(node_id);
+    bool is_notif_new;
+    if (entry == node_cache_address_and_liveness_.end()) {
+      // If the entry is not in the cache, then the notification is new.
+      is_notif_new = true;
+    } else {
+      // If the entry is in the cache, then the notification is new if the node
+      // was alive and is now dead.
+      bool was_alive = (entry->second.state() == rpc::GcsNodeInfo::ALIVE);
+      is_notif_new = was_alive && !is_alive;
+
+      // Handle the same logic as in HandleNotification for preventing re-adding removed
+      // nodes
+      if (!was_alive && is_alive) {
+        RAY_LOG(INFO) << "Address and liveness notification for addition of a node that "
+                         "was already removed:"
+                      << node_id;
+        return;
+      }
     }
-  }
-
-  // Add the notification to our address and liveness cache.
-  RAY_LOG(INFO).WithField(node_id)
-      << "Received address and liveness notification for node, IsAlive = " << is_alive;
+    // Add the notification to our address and liveness cache.
+    RAY_LOG(INFO).WithField(node_id)
+        << "Received address and liveness notification for node, IsAlive = " << is_alive;
+
+    auto &node = node_cache_address_and_liveness_[node_id];
+    if (is_alive) {
+      node = std::move(node_info);
+    } else {
+      node.set_node_id(node_info.node_id());
+      node.set_state(rpc::GcsNodeInfo::DEAD);
+      if (node_info.has_death_info()) {
+        *node.mutable_death_info() = std::move(*node_info.mutable_death_info());
+      }
+    }
 
-  auto &node = node_cache_address_and_liveness_[node_id];
-  if (is_alive) {
-    node = std::move(node_info);
-  } else {
-    node.set_node_id(node_info.node_id());
-    node.set_state(rpc::GcsNodeInfo::DEAD);
-    if (node_info.has_death_info()) {
-      node.mutable_death_info()->CopyFrom(node_info.death_info());
+    if (is_notif_new && node_change_callback_address_and_liveness_ != nullptr) {
+      node_info_copy_for_callback = node;
     }
   }
 
   // If the notification is new, call registered callback.
-  if (is_notif_new && node_change_callback_address_and_liveness_ != nullptr) {
-    node_change_callback_address_and_liveness_(node_id,
-                                               node_cache_address_and_liveness_[node_id]);
+  if (node_info_copy_for_callback) {
+    node_change_callback_address_and_liveness_(node_id, *node_info_copy_for_callback);
   }
 }
 
 void NodeInfoAccessor::AsyncResubscribe() {
   RAY_LOG(DEBUG) << "Reestablishing subscription for node info.";
-  if (node_change_callback_ != nullptr) {
-    client_impl_->GetGcsSubscriber().SubscribeAllNodeInfo(
-        /*subscribe=*/[this](rpc::GcsNodeInfo
-                                 &&data) { HandleNotification(std::move(data)); },
-        /*done=*/
-        [this](const Status &) {
-          fetch_node_data_operation_([](const Status &) {
-            RAY_LOG(INFO) << "Finished fetching all node information for resubscription.";
-          });
-        });
-  }
   if (node_change_callback_address_and_liveness_ != nullptr) {
     client_impl_->GetGcsSubscriber().SubscribeAllNodeAddressAndLiveness(
         /*subscribe=*/[this](rpc::GcsNodeAddressAndLiveness
diff --git a/src/ray/gcs_rpc_client/accessor.h b/src/ray/gcs_rpc_client/accessor.h
index e1b5e09aa1a4..0a5a984b11ca 100644
--- a/src/ray/gcs_rpc_client/accessor.h
+++ b/src/ray/gcs_rpc_client/accessor.h
@@ -14,10 +14,12 @@
 #pragma once
 
 #include
+#include <optional>
 #include
 #include
 #include
 
+#include "absl/synchronization/mutex.h"
 #include "ray/common/gcs_callback_types.h"
 #include "ray/common/id.h"
 #include "ray/common/placement_group.h"
@@ -170,42 +172,27 @@ class NodeInfoAccessor {
      int64_t timeout_ms,
      const std::vector<NodeID> &node_ids = {});
 
-  /// Subscribe to node addition and removal events from GCS and cache those information.
-  ///
-  /// \param subscribe Callback that will be called if a node is
-  /// added or a node is removed. The callback needs to be idempotent because it will also
-  /// be called for existing nodes.
-  /// \param done Callback that will be called when subscription is complete.
-  virtual void AsyncSubscribeToNodeChange(
-      std::function<void(NodeID, const rpc::GcsNodeInfo &)> subscribe,
-      StatusCallback done);
-
   /// Get node information from local cache.
-  /// Non-thread safe.
-  /// Note, the local cache is only available if `AsyncSubscribeToNodeChange`
-  /// is called before.
+  /// Thread-safe.
+  /// Note, the local cache is only available if
+  /// `AsyncSubscribeToNodeAddressAndLivenessChange` is called before.
   ///
   /// \param node_id The ID of node to look up in local cache.
   /// \param filter_dead_nodes Whether or not if this method will filter dead nodes.
   /// \return The item returned by GCS. If the item to read doesn't exist or the node is
-  virtual  /// dead, this optional object is empty.
-      const rpc::GcsNodeInfo *
-      Get(const NodeID &node_id, bool filter_dead_nodes = true) const;
-
-  virtual  /// dead, this optional object is empty.
-      const rpc::GcsNodeAddressAndLiveness *
-      GetNodeAddressAndLiveness(const NodeID &node_id,
-                                bool filter_dead_nodes = true) const;
+  /// dead, this optional object is empty.
+  virtual std::optional<rpc::GcsNodeAddressAndLiveness> GetNodeAddressAndLiveness(
+      const NodeID &node_id, bool filter_dead_nodes = true) const;
 
   /// Get information of all nodes from local cache.
-  /// Non-thread safe.
-  /// Note, the local cache is only available if `AsyncSubscribeToNodeChange`
-  /// is called before.
+  /// Thread-safe.
+  /// Note, the local cache is only available if
+  /// `AsyncSubscribeToNodeAddressAndLivenessChange` is called before.
   ///
   /// \return All nodes in cache.
-  virtual const absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> &GetAll() const;
-  virtual const absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>
-      &GetAllNodeAddressAndLiveness() const;
+
+  virtual absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>
+  GetAllNodeAddressAndLiveness() const;
 
   /// Get information of all nodes from an RPC to GCS synchronously with optional filters.
   ///
@@ -216,9 +203,8 @@ class NodeInfoAccessor {
      std::optional node_selector = std::nullopt);
 
-  /// Subscribe to only critical node information changes. This method works similarly to
-  /// AsyncSubscribeToNodeChange but will only transmit address and liveness information
-  /// for each node and will exclude other information.
+  /// Subscribe to critical node information changes. This method transmits only address
+  /// and liveness information for each node, excluding other node metadata.
   ///
   /// \param subscribe Callback that will be called if a node is
   /// added or a node is removed. The callback needs to be idempotent because it will also
@@ -257,11 +243,17 @@ class NodeInfoAccessor {
   /// 2. The node is alive and we have that information in the cache.
   /// 3. The GCS has evicted the node from its dead node cache based on
   ///    maximum_gcs_dead_node_cached_count
-  /// Non-thread safe.
-  /// Note, the local cache is only available if `AsyncSubscribeToNodeChange` is called
-  /// before.
+  /// Hence we only return true if we're confident that the node is dead.
+  /// Thread-safe.
+  /// Note, the local cache is only available if
+  /// `AsyncSubscribeToNodeAddressAndLivenessChange` is called before.
   virtual bool IsNodeDead(const NodeID &node_id) const;
 
+  /// NOTE: This is NOT equivalent to !IsNodeDead(node_id) due to the gray area mentioned
+  /// in the comment above. Thus we only return true if we're confident that the node is
+  /// alive.
+  virtual bool IsNodeAlive(const NodeID &node_id) const;
+
   /// Reestablish subscription.
   /// This should be called when GCS server restarts from a failure.
   /// PubSub server restart will cause GCS server restart. In this case, we need to
@@ -269,39 +261,32 @@ class NodeInfoAccessor {
   /// server.
   virtual void AsyncResubscribe();
 
-  /// Add a node to accessor cache.
-  virtual void HandleNotification(rpc::GcsNodeInfo &&node_info);
-
   /// Add rpc::GcsNodeAddressAndLiveness information to accessor cache.
   virtual void HandleNotification(rpc::GcsNodeAddressAndLiveness &&node_info);
 
   virtual bool IsSubscribedToNodeChange() const {
-    return node_change_callback_ != nullptr ||
-           node_change_callback_address_and_liveness_ != nullptr;
+    return node_change_callback_address_and_liveness_ != nullptr;
   }
 
  private:
   /// Save the fetch data operations in these functions, so we can call them again when
   /// GCS server restarts from a failure.
-  FetchDataOperation fetch_node_data_operation_;
   FetchDataOperation fetch_node_address_and_liveness_data_operation_;
 
   GcsClient *client_impl_;
 
-  /// The callback to call when a new node is added or a node is removed.
-  std::function<void(NodeID, const rpc::GcsNodeInfo &)> node_change_callback_ = nullptr;
-
-  /// A cache for information about all nodes.
-  absl::flat_hash_map<NodeID, rpc::GcsNodeInfo> node_cache_;
-
   /// The callback to call when a new node is added or a node is removed when leveraging
   /// the GcsNodeAddressAndLiveness version of the node api
   std::function<void(NodeID, const rpc::GcsNodeAddressAndLiveness &)>
      node_change_callback_address_and_liveness_ = nullptr;
 
+  /// Mutex to protect node_cache_address_and_liveness_ for thread-safe access
+  mutable absl::Mutex node_cache_address_and_liveness_mutex_;
+
   /// A cache for information about all nodes when using the address and liveness api
   absl::flat_hash_map<NodeID, rpc::GcsNodeAddressAndLiveness>
-      node_cache_address_and_liveness_;
+      node_cache_address_and_liveness_
+          ABSL_GUARDED_BY(node_cache_address_and_liveness_mutex_);
 
   // TODO(dayshah): Need to refactor gcs client / accessor to avoid this.
   // https://github.com/ray-project/ray/issues/54805
diff --git a/src/ray/gcs_rpc_client/tests/accessor_test.cc b/src/ray/gcs_rpc_client/tests/accessor_test.cc
index 1ac66bf18151..5dc81cb6f72f 100644
--- a/src/ray/gcs_rpc_client/tests/accessor_test.cc
+++ b/src/ray/gcs_rpc_client/tests/accessor_test.cc
@@ -28,28 +28,34 @@ TEST(NodeInfoAccessorTest, TestHandleNotification) {
   NodeInfoAccessor accessor;
 
   int num_notifications = 0;
-  accessor.node_change_callback_ = [&](NodeID, const rpc::GcsNodeInfo &) {
-    num_notifications++;
-  };
+  accessor.node_change_callback_address_and_liveness_ =
+      [&](NodeID, const rpc::GcsNodeAddressAndLiveness &) { num_notifications++; };
   NodeID node_id = NodeID::FromRandom();
-  rpc::GcsNodeInfo node_info;
+  rpc::GcsNodeAddressAndLiveness node_info;
   node_info.set_node_id(node_id.Binary());
   node_info.set_state(rpc::GcsNodeInfo::ALIVE);
-  accessor.HandleNotification(rpc::GcsNodeInfo(node_info));
-  const auto *gotten_node_info = accessor.Get(node_id, /*filter_dead_nodes=*/false);
+  accessor.HandleNotification(rpc::GcsNodeAddressAndLiveness(node_info));
+  auto gotten_node_info =
+      accessor.GetNodeAddressAndLiveness(node_id, /*filter_dead_nodes=*/false);
+  ASSERT_TRUE(gotten_node_info.has_value());
   ASSERT_EQ(gotten_node_info->node_id(), node_id.Binary());
   ASSERT_EQ(gotten_node_info->state(), rpc::GcsNodeInfo::ALIVE);
 
   node_info.set_state(rpc::GcsNodeInfo::DEAD);
-  accessor.HandleNotification(rpc::GcsNodeInfo(node_info));
-  gotten_node_info = accessor.Get(node_id, /*filter_dead_nodes=*/false);
+  accessor.HandleNotification(rpc::GcsNodeAddressAndLiveness(node_info));
+  gotten_node_info =
+      accessor.GetNodeAddressAndLiveness(node_id, /*filter_dead_nodes=*/false);
+  ASSERT_TRUE(gotten_node_info.has_value());
   ASSERT_EQ(gotten_node_info->state(), rpc::GcsNodeInfo::DEAD);
-  ASSERT_EQ(accessor.Get(node_id, /*filter_dead_nodes=*/true), nullptr);
+  ASSERT_FALSE(accessor.GetNodeAddressAndLiveness(node_id, /*filter_dead_nodes=*/true)
+                   .has_value());
 
   node_info.set_state(rpc::GcsNodeInfo::ALIVE);
-  accessor.HandleNotification(rpc::GcsNodeInfo(node_info));
-  gotten_node_info = accessor.Get(node_id, /*filter_dead_nodes=*/false);
+  accessor.HandleNotification(rpc::GcsNodeAddressAndLiveness(node_info));
+  gotten_node_info =
+      accessor.GetNodeAddressAndLiveness(node_id, /*filter_dead_nodes=*/false);
+  ASSERT_TRUE(gotten_node_info.has_value());
   ASSERT_EQ(gotten_node_info->state(), rpc::GcsNodeInfo::DEAD);
 
   ASSERT_EQ(num_notifications, 2);
@@ -57,7 +63,7 @@ TEST(NodeInfoAccessorTest, TestHandleNotificationDeathInfo) {
   NodeInfoAccessor accessor;
 
-  rpc::GcsNodeInfo node_info;
+  rpc::GcsNodeAddressAndLiveness node_info;
   node_info.set_state(rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_DEAD);
 
   NodeID node_id = NodeID::FromRandom();
   node_info.set_node_id(node_id.Binary());
@@ -66,12 +72,10 @@ TEST(NodeInfoAccessorTest, TestHandleNotificationDeathInfo) {
   death_info->set_reason(rpc::NodeDeathInfo::EXPECTED_TERMINATION);
   death_info->set_reason_message("Test termination reason");
 
-  node_info.set_end_time_ms(12345678);
-
   accessor.HandleNotification(std::move(node_info));
 
-  auto cached_node = accessor.Get(node_id, false);
-  ASSERT_NE(cached_node, nullptr);
+  auto cached_node = accessor.GetNodeAddressAndLiveness(node_id, false);
+  ASSERT_TRUE(cached_node.has_value());
   ASSERT_EQ(cached_node->node_id(), node_id.Binary());
   ASSERT_EQ(cached_node->state(),
            rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_DEAD);
@@ -79,7 +83,6 @@ TEST(NodeInfoAccessorTest, TestHandleNotificationDeathInfo) {
   ASSERT_TRUE(cached_node->has_death_info());
   ASSERT_EQ(cached_node->death_info().reason(), rpc::NodeDeathInfo::EXPECTED_TERMINATION);
   ASSERT_EQ(cached_node->death_info().reason_message(), "Test termination reason");
-  ASSERT_EQ(cached_node->end_time_ms(), 12345678);
 }
 
 int main(int argc, char **argv) {
diff --git a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc
index 7d1efd6f1279..17713e08e6a6 100644
--- a/src/ray/gcs_rpc_client/tests/gcs_client_test.cc
+++ b/src/ray/gcs_rpc_client/tests/gcs_client_test.cc
@@ -368,10 +368,10 @@ class GcsClientTest : public ::testing::TestWithParam {
     return actors;
   }
 
-  bool SubscribeToNodeChange(
-      std::function<void(const NodeID &id, const rpc::GcsNodeInfo &data)> subscribe) {
+  bool SubscribeToNodeAddressAndLivenessChange(
+      std::function<void(const NodeID &id, const rpc::GcsNodeAddressAndLiveness &data)> subscribe) {
     std::promise<bool> promise;
-    gcs_client_->Nodes().AsyncSubscribeToNodeChange(
+    gcs_client_->Nodes().AsyncSubscribeToNodeAddressAndLivenessChange(
        subscribe, [&promise](Status status) { promise.set_value(status.ok()); });
     return WaitReady(promise.get_future(), timeout_ms_);
   }
@@ -608,15 +608,16 @@ TEST_P(GcsClientTest, TestNodeInfo) {
   // Subscribe to node addition and removal events from GCS.
   std::atomic<int> register_count(0);
   std::atomic<int> unregister_count(0);
-  auto on_subscribe = [&register_count, &unregister_count](const NodeID &node_id,
-                                                           const rpc::GcsNodeInfo &data) {
+  auto on_subscribe = [&register_count, &unregister_count](
+                          const NodeID &node_id,
+                          const rpc::GcsNodeAddressAndLiveness &data) {
     if (data.state() == rpc::GcsNodeInfo::ALIVE) {
      ++register_count;
    } else if (data.state() == rpc::GcsNodeInfo::DEAD) {
      ++unregister_count;
    }
  };
-  ASSERT_TRUE(SubscribeToNodeChange(on_subscribe));
+  ASSERT_TRUE(SubscribeToNodeAddressAndLivenessChange(on_subscribe));
 
   // Register local node to GCS.
   RegisterSelf(*gcs_node1_info);
@@ -631,9 +632,9 @@ TEST_P(GcsClientTest, TestNodeInfo) {
   // Get information of all nodes from GCS.
   std::vector<rpc::GcsNodeInfo> node_list = GetNodeInfoList();
   EXPECT_EQ(node_list.size(), 2);
-  ASSERT_TRUE(gcs_client_->Nodes().Get(node1_id));
-  ASSERT_TRUE(gcs_client_->Nodes().Get(node2_id));
-  EXPECT_EQ(gcs_client_->Nodes().GetAll().size(), 2);
+  ASSERT_TRUE(gcs_client_->Nodes().GetNodeAddressAndLiveness(node1_id).has_value());
+  ASSERT_TRUE(gcs_client_->Nodes().GetNodeAddressAndLiveness(node2_id).has_value());
+  EXPECT_EQ(gcs_client_->Nodes().GetAllNodeAddressAndLiveness().size(), 2);
 }
 
 TEST_P(GcsClientTest, TestUnregisterNode) {
@@ -801,11 +802,12 @@ TEST_P(GcsClientTest, TestNodeTableResubscribe) {
   // Test that subscription of the node table can still work when GCS server restarts.
   // Subscribe to node addition and removal events from GCS and cache those information.
   std::atomic<int> node_change_count(0);
-  auto node_subscribe = [&node_change_count](const NodeID &id,
-                                             const rpc::GcsNodeInfo &result) {
+  auto node_subscribe = [&node_change_count](
+                            const NodeID &id,
+                            const rpc::GcsNodeAddressAndLiveness &result) {
     ++node_change_count;
   };
-  ASSERT_TRUE(SubscribeToNodeChange(node_subscribe));
+  ASSERT_TRUE(SubscribeToNodeAddressAndLivenessChange(node_subscribe));
 
   auto node_info = GenNodeInfo(1);
   ASSERT_TRUE(RegisterNode(*node_info));
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 646198cce3f8..55547e4f5994 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -720,9 +720,9 @@ std::shared_ptr ObjectManager::GetRpcClient(
   if (it != remote_object_manager_clients_.end()) {
     return it->second;
   }
-  auto *node_info =
+  auto node_info =
      gcs_client_.Nodes().GetNodeAddressAndLiveness(node_id, /*filter_dead_nodes=*/true);
-  if (node_info == nullptr) {
+  if (!node_info) {
     return nullptr;
   }
   auto object_manager_client =
diff --git a/src/ray/object_manager/tests/object_manager_test.cc b/src/ray/object_manager/tests/object_manager_test.cc
index c983adfa0bbd..00a7e706fcd9 100644
--- a/src/ray/object_manager/tests/object_manager_test.cc
+++ b/src/ray/object_manager/tests/object_manager_test.cc
@@ -128,10 +128,10 @@ TEST_F(ObjectManagerTest, TestFreeObjectsLocalOnlyFalse) {
   node_info_map_[remote_node_id_] = remote_node_info;
   EXPECT_CALL(*mock_gcs_client_->mock_node_accessor, GetAllNodeAddressAndLiveness())
-      .WillOnce(::testing::ReturnRef(node_info_map_));
+      .WillOnce(::testing::Return(node_info_map_));
   EXPECT_CALL(*mock_gcs_client_->mock_node_accessor,
              GetNodeAddressAndLiveness(remote_node_id_, _))
-      .WillOnce(::testing::Return(&remote_node_info));
+      .WillOnce(::testing::Return(remote_node_info));
 
   fake_plasma_client_->objects_in_plasma_[object_id] =
      std::make_pair(std::vector(1), std::vector(1));
diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc
index 6199afd93c4f..df436cabe9c6 100644
--- a/src/ray/raylet/main.cc
+++ b/src/ray/raylet/main.cc
@@ -812,8 +812,7 @@ int main(int argc, char *argv[]) {
      node_manager_config.resource_config.GetResourceMap(),
      /*is_node_available_fn*/
      [&](ray::scheduling::NodeID id) {
-        return gcs_client->Nodes().GetNodeAddressAndLiveness(
-                   ray::NodeID::FromBinary(id.Binary())) != nullptr;
+        return gcs_client->Nodes().IsNodeAlive(ray::NodeID::FromBinary(id.Binary()));
      },
      /*get_used_object_store_memory*/
      [&]() {
@@ -838,8 +837,7 @@ int main(int argc, char *argv[]) {
   auto get_node_info_func = [&](const ray::NodeID &id)
      -> std::optional<ray::rpc::GcsNodeAddressAndLiveness> {
-    auto ptr = gcs_client->Nodes().GetNodeAddressAndLiveness(id);
-    return ptr ? std::optional<ray::rpc::GcsNodeAddressAndLiveness>(*ptr) : std::nullopt;
+    return gcs_client->Nodes().GetNodeAddressAndLiveness(id);
   };
   auto announce_infeasible_lease = [](const ray::RayLease &lease) {
     /// Publish the infeasible lease error to GCS so that drivers can subscribe to it
@@ -909,8 +907,7 @@ int main(int argc, char *argv[]) {
                                                    *local_lease_manager);
   auto raylet_client_factory = [&](const ray::NodeID &id) {
-    const ray::rpc::GcsNodeAddressAndLiveness *node_info =
-        gcs_client->Nodes().GetNodeAddressAndLiveness(id);
+    auto node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(id);
     RAY_CHECK(node_info) << "No GCS info for node " << id;
     auto addr = ray::rpc::RayletClientPool::GenerateRayletAddress(
        id, node_info->node_manager_address(), node_info->node_manager_port());
diff --git a/src/ray/raylet/scheduling/tests/cluster_lease_manager_test.cc b/src/ray/raylet/scheduling/tests/cluster_lease_manager_test.cc
index f61d78d43a88..5c1dcbff4a28 100644
--- a/src/ray/raylet/scheduling/tests/cluster_lease_manager_test.cc
+++ b/src/ray/raylet/scheduling/tests/cluster_lease_manager_test.cc
@@ -265,8 +265,7 @@ std::shared_ptr CreateSingleNodeScheduler(
      scheduling::NodeID(id),
      local_node_resources,
      /*is_node_available_fn*/
      [&gcs_client](scheduling::NodeID node_id) {
-        return gcs_client.Nodes().GetNodeAddressAndLiveness(
-                   NodeID::FromBinary(node_id.Binary())) != nullptr;
+        return gcs_client.Nodes().IsNodeAlive(NodeID::FromBinary(node_id.Binary()));
      });
 
   return scheduler;
@@ -432,9 +431,8 @@ class ClusterLeaseManagerTest : public ::testing::Test {
   void SetUp() {
     static rpc::GcsNodeAddressAndLiveness node_info;
-    ON_CALL(*gcs_client_->mock_node_accessor,
-            GetNodeAddressAndLiveness(::testing::_, ::testing::_))
-        .WillByDefault(::testing::Return(&node_info));
+    ON_CALL(*gcs_client_->mock_node_accessor, IsNodeAlive(::testing::_))
+        .WillByDefault(::testing::Return(true));
   }
 
   RayObject *MakeDummyArg() {
diff --git a/src/ray/raylet/scheduling/tests/cluster_resource_scheduler_test.cc b/src/ray/raylet/scheduling/tests/cluster_resource_scheduler_test.cc
index 2c8bceb8aced..4728cb03a4c8 100644
--- a/src/ray/raylet/scheduling/tests/cluster_resource_scheduler_test.cc
+++ b/src/ray/raylet/scheduling/tests/cluster_resource_scheduler_test.cc
@@ -106,14 +106,12 @@ class ClusterResourceSchedulerTest : public ::testing::Test {
     // policy.
     gcs_client_ = std::make_unique<gcs::MockGcsClient>();
     is_node_available_fn_ = [this](scheduling::NodeID node_id) {
-      return gcs_client_->Nodes().GetNodeAddressAndLiveness(
-                 NodeID::FromBinary(node_id.Binary())) != nullptr;
+      return gcs_client_->Nodes().IsNodeAlive(NodeID::FromBinary(node_id.Binary()));
     };
     node_name = NodeID::FromRandom().Binary();
     node_info.set_node_id(node_name);
-    ON_CALL(*gcs_client_->mock_node_accessor,
-            GetNodeAddressAndLiveness(::testing::_, ::testing::_))
-        .WillByDefault(::testing::Return(&node_info));
+    ON_CALL(*gcs_client_->mock_node_accessor, IsNodeAlive(::testing::_))
+        .WillByDefault(::testing::Return(true));
   }
 
   void Shutdown() {}
@@ -1092,10 +1090,9 @@ TEST_F(ClusterResourceSchedulerTest, DeadNodeTest) {
                                           std::string(),
                                           &violations,
                                           &is_infeasible));
-  EXPECT_CALL(*gcs_client_->mock_node_accessor,
-              GetNodeAddressAndLiveness(node_id, ::testing::_))
-      .WillOnce(::testing::Return(nullptr))
-      .WillOnce(::testing::Return(nullptr));
+  EXPECT_CALL(*gcs_client_->mock_node_accessor, IsNodeAlive(node_id))
+      .WillOnce(::testing::Return(false))
+      .WillOnce(::testing::Return(false));
   ASSERT_TRUE(resource_scheduler
                  .GetBestSchedulableNode(resource,
                                          LabelSelector(),
diff --git a/src/ray/raylet/tests/local_lease_manager_test.cc b/src/ray/raylet/tests/local_lease_manager_test.cc
index aca35b042fc5..d67bc6a599a4 100644
--- a/src/ray/raylet/tests/local_lease_manager_test.cc
+++ b/src/ray/raylet/tests/local_lease_manager_test.cc
@@ -260,7 +260,7 @@ std::shared_ptr CreateSingleNodeScheduler(
      scheduling::NodeID(id),
      local_node_resources,
      /*is_node_available_fn*/
      [&gcs_client](scheduling::NodeID node_id) {
-        return gcs_client.Nodes().Get(NodeID::FromBinary(node_id.Binary())) != nullptr;
+        return gcs_client.Nodes().IsNodeAlive(NodeID::FromBinary(node_id.Binary()));
      });
 
   return scheduler;
@@ -349,9 +349,10 @@ class LocalLeaseManagerTest : public ::testing::Test {
                /*get_time=*/[this]() { return current_time_ms_; })) {}
 
   void SetUp() override {
-    static rpc::GcsNodeInfo node_info;
-    ON_CALL(*gcs_client_->mock_node_accessor, Get(::testing::_, ::testing::_))
-        .WillByDefault(::testing::Return(&node_info));
+    static rpc::GcsNodeAddressAndLiveness node_info;
+    ON_CALL(*gcs_client_->mock_node_accessor,
+            GetNodeAddressAndLiveness(::testing::_, ::testing::_))
+        .WillByDefault(::testing::Return(node_info));
   }
 
   RayObject *MakeDummyArg() {
diff --git a/src/ray/raylet/tests/node_manager_test.cc b/src/ray/raylet/tests/node_manager_test.cc
index b53b0fe4f987..a52b27c0f573 100644
--- a/src/ray/raylet/tests/node_manager_test.cc
+++ b/src/ray/raylet/tests/node_manager_test.cc
@@ -354,8 +354,8 @@ class NodeManagerTest : public ::testing::Test {
        node_manager_config.resource_config.GetResourceMap(),
        /*is_node_available_fn*/
        [&](ray::scheduling::NodeID node_id) {
-          return mock_gcs_client_->Nodes().Get(NodeID::FromBinary(node_id.Binary())) !=
-                 nullptr;
+          return mock_gcs_client_->Nodes().IsNodeAlive(
+              NodeID::FromBinary(node_id.Binary()));
        },
        /*get_used_object_store_memory*/
        [&]() {
@@ -379,8 +379,7 @@ class NodeManagerTest : public ::testing::Test {
                                                 node_manager_config.labels);
   auto get_node_info_func = [&](const NodeID &node_id) {
-    auto ptr = mock_gcs_client_->Nodes().GetNodeAddressAndLiveness(node_id);
-    return ptr ? std::optional<rpc::GcsNodeAddressAndLiveness>(*ptr) : std::nullopt;
+    return mock_gcs_client_->Nodes().GetNodeAddressAndLiveness(node_id);
   };
 
   auto max_task_args_memory = static_cast(
diff --git a/src/ray/raylet/tests/placement_group_resource_manager_test.cc b/src/ray/raylet/tests/placement_group_resource_manager_test.cc
index 494e81941ee4..28e00934fa7e 100644
--- a/src/ray/raylet/tests/placement_group_resource_manager_test.cc
+++ b/src/ray/raylet/tests/placement_group_resource_manager_test.cc
@@ -75,14 +75,15 @@ class NewPlacementGroupResourceManagerTest : public ::testing::Test {
   std::shared_ptr<ClusterResourceScheduler> cluster_resource_scheduler_;
   std::unique_ptr<gcs::MockGcsClient> gcs_client_;
   std::function<bool(scheduling::NodeID)> is_node_available_fn_;
-  rpc::GcsNodeInfo node_info_;
+  rpc::GcsNodeAddressAndLiveness node_info_;
 
   void SetUp() {
     gcs_client_ = std::make_unique<gcs::MockGcsClient>();
     is_node_available_fn_ = [this](scheduling::NodeID node_id) {
-      return gcs_client_->Nodes().Get(NodeID::FromBinary(node_id.Binary())) != nullptr;
+      return gcs_client_->Nodes().IsNodeAlive(NodeID::FromBinary(node_id.Binary()));
     };
-    EXPECT_CALL(*gcs_client_->mock_node_accessor, Get(::testing::_, ::testing::_))
-        .WillRepeatedly(::testing::Return(&node_info_));
+    EXPECT_CALL(*gcs_client_->mock_node_accessor,
+                GetNodeAddressAndLiveness(::testing::_, ::testing::_))
+        .WillRepeatedly(::testing::Return(node_info_));
   }
 
   void InitLocalAvailableResource(
      absl::flat_hash_map &unit_resource) {
diff --git a/src/ray/raylet_rpc_client/raylet_client_pool.cc b/src/ray/raylet_rpc_client/raylet_client_pool.cc
index ca67bb04926b..a8724794d947 100644
--- a/src/ray/raylet_rpc_client/raylet_client_pool.cc
+++ b/src/ray/raylet_rpc_client/raylet_client_pool.cc
@@ -56,9 +56,9 @@ std::function<void()> RayletClientPool::GetDefaultUnavailableTimeoutCallback(
     };
 
     if (gcs_client->Nodes().IsSubscribedToNodeChange()) {
-      auto *node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
+      auto node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
          node_id, /*filter_dead_nodes=*/false);
-      if (node_info == nullptr) {
+      if (!node_info) {
        // Node could be dead or info may have not made it to the subscriber cache yet.
        // Check with the GCS to confirm if the node is dead.
        gcs_check_node_alive();
diff --git a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc
index 6dc002604a54..d8f6e8748e50 100644
--- a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc
+++ b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc
@@ -59,7 +59,7 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor {
 
   bool IsSubscribedToNodeChange() const override { return is_subscribed_to_node_change_; }
 
-  MOCK_METHOD(const rpc::GcsNodeAddressAndLiveness *,
+  MOCK_METHOD(std::optional<rpc::GcsNodeAddressAndLiveness>,
              GetNodeAddressAndLiveness,
              (const NodeID &, bool),
              (const, override));
@@ -162,9 +162,9 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) {
   EXPECT_CALL(
      mock_node_accessor,
      GetNodeAddressAndLiveness(raylet_client_1_node_id, /*filter_dead_nodes=*/false))
-      .WillOnce(Return(nullptr))
-      .WillOnce(Return(&node_info_alive))
-      .WillOnce(Return(&node_info_dead));
+      .WillOnce(Return(std::nullopt))
+      .WillOnce(Return(node_info_alive))
+      .WillOnce(Return(node_info_dead));
   EXPECT_CALL(mock_node_accessor,
              AsyncGetAllNodeAddressAndLiveness(
                  _, _, std::vector<NodeID>{raylet_client_1_node_id}))
      .WillOnce(invoke_with_node_info_vector({node_info_alive}));
   EXPECT_CALL(
      mock_node_accessor,
      GetNodeAddressAndLiveness(raylet_client_2_node_id, /*filter_dead_nodes=*/false))
-      .WillOnce(Return(nullptr));
+      .WillOnce(Return(std::nullopt));
   EXPECT_CALL(mock_node_accessor,
              AsyncGetAllNodeAddressAndLiveness(
                  _, _, std::vector<NodeID>{raylet_client_2_node_id}))
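
Usage note, not part of the patch above: a minimal caller-side sketch of the reworked accessor surface. It assumes a connected gcs::GcsClient pointer named gcs_client and a NodeID named node_id that the subscriber cache already knows about, and it only exercises calls that appear in the hunks above.

    // GetNodeAddressAndLiveness() now returns a value-type std::optional taken from the
    // mutex-guarded cache, so callers test the optional instead of a raw pointer.
    auto node_info = gcs_client->Nodes().GetNodeAddressAndLiveness(
        node_id, /*filter_dead_nodes=*/false);
    if (node_info) {
      // The returned copy still carries the address fields used by the client pools.
      RAY_LOG(INFO) << node_info->node_manager_address() << ":"
                    << node_info->node_manager_port();
    }
    // IsNodeAlive() is not the negation of IsNodeDead(): both return false when the
    // cache has no entry for the node, which is the behavior the is_node_available_fn
    // call sites in the raylet and the scheduler tests rely on.
    bool schedulable = gcs_client->Nodes().IsNodeAlive(node_id);
    (void)schedulable;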