diff --git a/src/ray/raylet_rpc_client/raylet_client_pool.cc b/src/ray/raylet_rpc_client/raylet_client_pool.cc index 438aa27b4d7c..ca67bb04926b 100644 --- a/src/ray/raylet_rpc_client/raylet_client_pool.cc +++ b/src/ray/raylet_rpc_client/raylet_client_pool.cc @@ -29,9 +29,9 @@ std::function RayletClientPool::GetDefaultUnavailableTimeoutCallback( const NodeID node_id = NodeID::FromBinary(addr.node_id()); auto gcs_check_node_alive = [node_id, addr, raylet_client_pool, gcs_client]() { - gcs_client->Nodes().AsyncGetAll( - [addr, node_id, raylet_client_pool](const Status &status, - std::vector &&nodes) { + gcs_client->Nodes().AsyncGetAllNodeAddressAndLiveness( + [addr, node_id, raylet_client_pool]( + const Status &status, std::vector &&nodes) { if (!status.ok()) { // Will try again when unavailable timeout callback is retried. RAY_LOG(INFO) << "Failed to get node info from GCS"; @@ -56,7 +56,8 @@ std::function RayletClientPool::GetDefaultUnavailableTimeoutCallback( }; if (gcs_client->Nodes().IsSubscribedToNodeChange()) { - auto *node_info = gcs_client->Nodes().Get(node_id, /*filter_dead_nodes=*/false); + auto *node_info = gcs_client->Nodes().GetNodeAddressAndLiveness( + node_id, /*filter_dead_nodes=*/false); if (node_info == nullptr) { // Node could be dead or info may have not made it to the subscriber cache yet. // Check with the GCS to confirm if the node is dead. diff --git a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc index 12681e8069de..1104763d2a66 100644 --- a/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc +++ b/src/ray/raylet_rpc_client/tests/raylet_client_pool_test.cc @@ -59,11 +59,14 @@ class MockGcsClientNodeAccessor : public gcs::NodeInfoAccessor { bool IsSubscribedToNodeChange() const override { return is_subscribed_to_node_change_; } - MOCK_METHOD(const GcsNodeInfo *, Get, (const NodeID &, bool), (const, override)); + MOCK_METHOD(const rpc::GcsNodeAddressAndLiveness *, + GetNodeAddressAndLiveness, + (const NodeID &, bool), + (const, override)); MOCK_METHOD(void, - AsyncGetAll, - (const gcs::MultiItemCallback &, + AsyncGetAllNodeAddressAndLiveness, + (const gcs::MultiItemCallback &, int64_t, const std::vector &), (override)); @@ -118,13 +121,16 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) { // had to discard to keep its cache size in check, should disconnect. auto &mock_node_accessor = gcs_client_.MockNodeAccessor(); - auto invoke_with_node_info_vector = [](std::vector node_info_vector) { - return Invoke([node_info_vector](const gcs::MultiItemCallback &callback, - int64_t, - const std::vector &) { - callback(Status::OK(), node_info_vector); - }); - }; + auto invoke_with_node_info_vector = + [](std::vector node_info_vector) { + return Invoke( + [node_info_vector]( + const gcs::MultiItemCallback &callback, + int64_t, + const std::vector &) { + callback(Status::OK(), node_info_vector); + }); + }; auto raylet_client_1_address = CreateRandomAddress("1"); auto raylet_client_2_address = CreateRandomAddress("2"); @@ -140,33 +146,39 @@ TEST_P(DefaultUnavailableTimeoutCallbackTest, NodeDeath) { ASSERT_TRUE( CheckRayletClientPoolHasClient(*raylet_client_pool_, raylet_client_2_node_id)); - GcsNodeInfo node_info_alive; + GcsNodeAddressAndLiveness node_info_alive; node_info_alive.set_state(GcsNodeInfo::ALIVE); - GcsNodeInfo node_info_dead; + GcsNodeAddressAndLiveness node_info_dead; node_info_dead.set_state(GcsNodeInfo::DEAD); if (is_subscribed_to_node_change_) { - EXPECT_CALL(mock_node_accessor, - Get(raylet_client_1_node_id, /*filter_dead_nodes=*/false)) + EXPECT_CALL( + mock_node_accessor, + GetNodeAddressAndLiveness(raylet_client_1_node_id, /*filter_dead_nodes=*/false)) .WillOnce(Return(nullptr)) .WillOnce(Return(&node_info_alive)) .WillOnce(Return(&node_info_dead)); EXPECT_CALL(mock_node_accessor, - AsyncGetAll(_, _, std::vector{raylet_client_1_node_id})) + AsyncGetAllNodeAddressAndLiveness( + _, _, std::vector{raylet_client_1_node_id})) .WillOnce(invoke_with_node_info_vector({node_info_alive})); - EXPECT_CALL(mock_node_accessor, - Get(raylet_client_2_node_id, /*filter_dead_nodes=*/false)) + EXPECT_CALL( + mock_node_accessor, + GetNodeAddressAndLiveness(raylet_client_2_node_id, /*filter_dead_nodes=*/false)) .WillOnce(Return(nullptr)); EXPECT_CALL(mock_node_accessor, - AsyncGetAll(_, _, std::vector{raylet_client_2_node_id})) + AsyncGetAllNodeAddressAndLiveness( + _, _, std::vector{raylet_client_2_node_id})) .WillOnce(invoke_with_node_info_vector({})); } else { EXPECT_CALL(mock_node_accessor, - AsyncGetAll(_, _, std::vector{raylet_client_1_node_id})) + AsyncGetAllNodeAddressAndLiveness( + _, _, std::vector{raylet_client_1_node_id})) .WillOnce(invoke_with_node_info_vector({node_info_alive})) .WillOnce(invoke_with_node_info_vector({node_info_alive})) .WillOnce(invoke_with_node_info_vector({node_info_dead})); EXPECT_CALL(mock_node_accessor, - AsyncGetAll(_, _, std::vector{raylet_client_2_node_id})) + AsyncGetAllNodeAddressAndLiveness( + _, _, std::vector{raylet_client_2_node_id})) .WillOnce(invoke_with_node_info_vector({})); }