From ec9209c00c0410d612cba1b77dd3a7b44ed52e9a Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Sun, 19 May 2019 16:38:33 +0800 Subject: [PATCH 1/2] Squash commits to one --- python/ray/tests/conftest.py | 39 +----- python/ray/tests/test_actor.py | 9 +- python/ray/tests/test_basic.py | 21 ++- python/ray/tests/test_object_manager.py | 3 +- python/ray/tests/test_signal.py | 5 +- src/ray/gcs/client_test.cc | 76 ++++++++-- src/ray/gcs/format/gcs.fbs | 7 + src/ray/gcs/redis_module/ray_redis_module.cc | 138 +++++++++++++------ src/ray/gcs/tables.cc | 102 ++++++++++++-- src/ray/gcs/tables.h | 25 +++- src/ray/raylet/node_manager.cc | 5 +- 11 files changed, 311 insertions(+), 119 deletions(-) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 2e670fb0a84d..4f0846524003 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -79,24 +79,16 @@ def ray_start_10_cpus(request): @contextmanager -def _ray_start_cluster(**kwargs): +def _ray_start_cluster(num_nodes=0, do_init=False, **kwargs): init_kwargs = get_default_fixture_ray_kwargs() - num_nodes = 0 - do_init = False - # num_nodes & do_init are not arguments for ray.init, so delete them. - if "num_nodes" in kwargs: - num_nodes = kwargs["num_nodes"] - del kwargs["num_nodes"] - if "do_init" in kwargs: - do_init = kwargs["do_init"] - del kwargs["do_init"] - elif num_nodes > 0: + if num_nodes > 0: do_init = True init_kwargs.update(kwargs) cluster = Cluster() remote_nodes = [] - for _ in range(num_nodes): + for i in range(num_nodes): remote_nodes.append(cluster.add_node(**init_kwargs)) + # Make sure the driver is conencting to the head node. 
if do_init: ray.init(redis_address=cluster.redis_address) yield cluster @@ -116,14 +108,14 @@ def ray_start_cluster(request): @pytest.fixture def ray_start_cluster_head(request): param = getattr(request, "param", {}) - with _ray_start_cluster(do_init=True, num_nodes=1, **param) as res: + with _ray_start_cluster(num_nodes=1, do_init=True, **param) as res: yield res @pytest.fixture def ray_start_cluster_2_nodes(request): param = getattr(request, "param", {}) - with _ray_start_cluster(do_init=True, num_nodes=2, **param) as res: + with _ray_start_cluster(num_nodes=2, do_init=True, **param) as res: yield res @@ -161,22 +153,3 @@ def call_ray_start(request): ray.shutdown() # Kill the Ray cluster. subprocess.Popen(["ray", "stop"]).wait() - - -@pytest.fixture() -def two_node_cluster(): - internal_config = json.dumps({ - "initial_reconstruction_timeout_milliseconds": 200, - "num_heartbeats_timeout": 10, - }) - cluster = ray.tests.cluster_utils.Cluster( - head_node_args={"_internal_config": internal_config}) - for _ in range(2): - remote_node = cluster.add_node( - num_cpus=1, _internal_config=internal_config) - ray.init(redis_address=cluster.redis_address) - yield cluster, remote_node - - # The code after the yield will run as teardown code. 
- ray.shutdown() - cluster.shutdown() diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index d7da081fd18c..d4ee1ee7d15c 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -1995,9 +1995,10 @@ def method(self): def test_custom_label_placement(ray_start_cluster): cluster = ray_start_cluster - cluster.add_node(num_cpus=2, resources={"CustomResource1": 2}) + node = cluster.add_node(num_cpus=2, resources={"CustomResource1": 2}) cluster.add_node(num_cpus=2, resources={"CustomResource2": 2}) ray.init(redis_address=cluster.redis_address) + resource1_plasma_socket_name = node.plasma_store_socket_name @ray.remote(resources={"CustomResource1": 1}) class ResourceActor1(object): @@ -2009,17 +2010,15 @@ class ResourceActor2(object): def get_location(self): return ray.worker.global_worker.plasma_client.store_socket_name - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - # Create some actors. actors1 = [ResourceActor1.remote() for _ in range(2)] actors2 = [ResourceActor2.remote() for _ in range(2)] locations1 = ray.get([a.get_location.remote() for a in actors1]) locations2 = ray.get([a.get_location.remote() for a in actors2]) for location in locations1: - assert location == local_plasma + assert location == resource1_plasma_socket_name for location in locations2: - assert location != local_plasma + assert location != resource1_plasma_socket_name def test_creating_more_actors_than_resources(shutdown_only): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 3f8c7cb2b3a1..74ea996b1562 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1872,10 +1872,9 @@ def f(): def test_zero_cpus_actor(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) - cluster.add_node(num_cpus=2) + node = cluster.add_node(num_cpus=2) ray.init(redis_address=cluster.redis_address) - - local_plasma = 
ray.worker.global_worker.plasma_client.store_socket_name + available_plasma_socket_name = node.plasma_store_socket_name @ray.remote class Foo(object): @@ -1884,7 +1883,7 @@ def method(self): # Make sure tasks and actors run on the remote raylet. a = Foo.remote() - assert ray.get(a.method.remote()) != local_plasma + assert ray.get(a.method.remote()) == available_plasma_socket_name def test_fractional_resources(shutdown_only): @@ -2067,8 +2066,9 @@ def run_nested2(): def test_custom_resources(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=3, resources={"CustomResource": 0}) - cluster.add_node(num_cpus=3, resources={"CustomResource": 1}) + node = cluster.add_node(num_cpus=3, resources={"CustomResource": 1}) ray.init(redis_address=cluster.redis_address) + available_plasma_socket_name = node.plasma_store_socket_name @ray.remote def f(): @@ -2088,12 +2088,10 @@ def h(): # The f tasks should be scheduled on both raylets. assert len(set(ray.get([f.remote() for _ in range(50)]))) == 2 - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - # The g tasks should be scheduled only on the second raylet. raylet_ids = set(ray.get([g.remote() for _ in range(50)])) assert len(raylet_ids) == 1 - assert list(raylet_ids)[0] != local_plasma + assert list(raylet_ids)[0] == available_plasma_socket_name # Make sure that resource bookkeeping works when a task that uses a # custom resources gets blocked. 
@@ -2107,12 +2105,13 @@ def test_two_custom_resources(ray_start_cluster): "CustomResource1": 1, "CustomResource2": 2 }) - cluster.add_node( + node = cluster.add_node( num_cpus=3, resources={ "CustomResource1": 3, "CustomResource2": 4 }) ray.init(redis_address=cluster.redis_address) + available_plasma_socket_name = node.plasma_store_socket_name @ray.remote(resources={"CustomResource1": 1}) def f(): @@ -2143,12 +2142,10 @@ def k(): assert len(set(ray.get([f.remote() for _ in range(50)]))) == 2 assert len(set(ray.get([g.remote() for _ in range(50)]))) == 2 - local_plasma = ray.worker.global_worker.plasma_client.store_socket_name - # The h tasks should be scheduled only on the second raylet. raylet_ids = set(ray.get([h.remote() for _ in range(50)])) assert len(raylet_ids) == 1 - assert list(raylet_ids)[0] != local_plasma + assert list(raylet_ids)[0] == available_plasma_socket_name # Make sure that tasks with unsatisfied custom resource requirements do # not get scheduled. diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index e02e3d9a7d6e..6daf075f2a03 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -222,11 +222,12 @@ def test_object_transfer_retry(ray_start_cluster): object_store_memory = 10**8 cluster.add_node( object_store_memory=object_store_memory, _internal_config=config) + # Make sure the driver is connecting to the head node. 
+ ray.init(redis_address=cluster.redis_address) cluster.add_node( num_gpus=1, object_store_memory=object_store_memory, _internal_config=config) - ray.init(redis_address=cluster.redis_address) @ray.remote(num_gpus=1) def f(size): diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index fe2e74379245..aba798a8d103 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -274,7 +274,7 @@ def send_signals(self, value, count): assert len(result_list) == count -def test_signal_on_node_failure(two_node_cluster): +def test_signal_on_node_failure(ray_start_cluster_2_nodes): """Test actor checkpointing on a remote node.""" class ActorSignal(object): @@ -285,7 +285,8 @@ def local_plasma(self): return ray.worker.global_worker.plasma_client.store_socket_name # Place the actor on the remote node. - cluster, remote_node = two_node_cluster + cluster = ray_start_cluster_2_nodes + remote_node = list(cluster.worker_nodes)[0] actor_cls = ray.remote(max_reconstructions=0)(ActorSignal) actor = actor_cls.remote() # Try until we put an actor on a different node. diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index f7e25a4873ab..1850bfb3a042 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -70,7 +70,7 @@ class TestGcsWithAsio : public TestGcs { void Start() override { io_service_.run(); } void Stop() override { io_service_.stop(); } - private: + protected: boost::asio::io_service io_service_; // Give the event loop some work so that it's forced to run until Stop() is // called. @@ -292,11 +292,55 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Do a lookup at the object ID. RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); + // Test the Replace function of set. 
+ const std::string suffix = "-replaced"; + for (auto &manager : managers) { + auto old_data = std::make_shared(); + auto new_data = std::make_shared(); + old_data->manager = manager; + new_data->manager = manager + suffix; + // Check that we added the correct object entries. + auto add_callback = [object_id, new_data, suffix]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(new_data->manager, d.manager); + test->IncrementNumCallbacks(); + }; + auto no_trigger_failure = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + ASSERT_TRUE(false); + }; + auto trigger_failure = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + ASSERT_TRUE(true); + }; + auto no_trigger_add = [](gcs::AsyncGcsClient *client, const UniqueID &id, + const ObjectTableDataT &d) { ASSERT_TRUE(false); }; + RAY_CHECK_OK(client->object_table().Replace(driver_id, object_id, old_data, new_data, + add_callback, no_trigger_failure)); + // The old entry should be replaced with the new entry. + RAY_CHECK_OK(client->object_table().Replace(driver_id, object_id, old_data, new_data, + no_trigger_add, trigger_failure)); + } + + // Check that lookup returns the added object entries. + auto lookup_callback2 = [object_id, managers, suffix]( + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { + ASSERT_EQ(id, object_id); + ASSERT_EQ(data.size(), managers.size()); + for (auto &d : data) { + ASSERT_EQ(d.manager.find(suffix), managers[0].length()); + } + test->IncrementNumCallbacks(); + }; + + // Do a lookup at the object ID. + RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback2)); + for (auto &manager : managers) { auto data = std::make_shared(); - data->manager = manager; + data->manager = manager + suffix; // Check that we added the correct object entries. 
- auto remove_entry_callback = [object_id, data]( + auto remove_entry_callback = [object_id, data, suffix]( gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -307,7 +351,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli } // Check that the entries are removed. - auto lookup_callback2 = [object_id, managers]( + auto lookup_callback3 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const std::vector &data) { ASSERT_EQ(id, object_id); @@ -317,11 +361,11 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli }; // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback2)); + RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback3)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). test->Start(); - ASSERT_EQ(test->NumCallbacks(), managers.size() * 2 + 2); + ASSERT_EQ(test->NumCallbacks(), managers.size() * 3 + 3); } TEST_F(TestGcsWithAsio, TestSet) { @@ -1227,14 +1271,14 @@ void TestClientTableDisconnect(const DriverID &driver_id, // event will stop the event loop. client->client_table().RegisterClientAddedCallback( [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientTableNotification(client, id, data, /*is_insertion=*/true); + ClientTableNotification(client, id, data, /*is_connected=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. 
RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientTableNotification(client, id, data, /*is_insertion=*/false); + ClientTableNotification(client, id, data, /*is_connected=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the @@ -1282,7 +1326,13 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { } void TestClientTableMarkDisconnected(const DriverID &driver_id, - std::shared_ptr client) { + std::shared_ptr client, + std::shared_ptr dead_client) { + // Since marking a non-existing node will not trigger the ClientRemovedCallback, + // a new client is used to connect and mark to dead. + ClientTableDataT dead_client_info = dead_client->client_table().GetLocalClient(); + RAY_CHECK_OK(dead_client->client_table().Connect(dead_client_info)); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); local_client_info.node_manager_address = "127.0.0.1"; local_client_info.node_manager_port = 0; @@ -1290,7 +1340,7 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. - ClientID dead_client_id = ClientID::from_random(); + ClientID dead_client_id = ClientID::from_binary(dead_client_info.client_id); RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id)); // Make sure we only get a notification for the removal of the client we // marked as dead. 
@@ -1304,7 +1354,11 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { test = this; - TestClientTableMarkDisconnected(driver_id_, client_); + auto dead_client = + std::make_shared("127.0.0.1", 6379, command_type_, + /*is_test_client=*/true); + RAY_CHECK_OK(dead_client->Attach(io_service_)); + TestClientTableMarkDisconnected(driver_id_, client_, dead_client); } #undef TEST_MACRO diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7cf250247461..50944c0a6d09 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -281,6 +281,13 @@ table ClientTableData { resources_total_capacity: [double]; } +table SetReplaceEntryData { + // The old packed data to remove. + old_data: string; + // The new packed data to add. + new_data: string; +} + table HeartbeatTableData { // Node manager client id client_id: string; diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 0405367e15f0..c9e180435c1d 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -190,9 +190,9 @@ flatbuffers::Offset RedisStringToFlatbuf( /// \param mode the update mode, such as append or remove. /// \param data The appended/removed data. /// \return OK if there is no error during a publish. -int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, - RedisModuleString *id, GcsTableNotificationMode notification_mode, - RedisModuleString *data) { +Status PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, GcsTableNotificationMode notification_mode, + RedisModuleString *data) { // Serialize the notification to send. 
flatbuffers::FlatBufferBuilder fbb; auto data_flatbuf = RedisStringToFlatbuf(fbb, data); @@ -206,12 +206,11 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str, fbb.GetBufferPointer(), fbb.GetSize()); if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); + return Status::RedisError("Empty reply during PUBLISH to all subscriber."); } std::string notification_key; - REPLY_AND_RETURN_IF_NOT_OK( - GetBroadcastKey(ctx, pubsub_channel_str, id, ¬ification_key)); + RAY_RETURN_NOT_OK(GetBroadcastKey(ctx, pubsub_channel_str, id, ¬ification_key)); // Publish the data to any clients who requested notifications on this key. auto it = notification_map.find(notification_key); if (it != notification_map.end()) { @@ -224,11 +223,11 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleCallReply *reply = RedisModule_Call( ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize()); if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); + return Status::RedisError("Empty reply during PUBLISH"); } } } - return RedisModule_ReplyWithSimpleString(ctx, "OK"); + return Status::OK(); } // RAY.TABLE_ADD: @@ -253,23 +252,23 @@ int TableAdd_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, return REDISMODULE_OK; } -int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { +Status TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (argc != 5) { - return RedisModule_WrongArity(ctx); + return Status::RedisError("Wrong arity in TableAdd_DoPublish."); } RedisModuleString *pubsub_channel_str = argv[2]; RedisModuleString *id = argv[3]; RedisModuleString *data = argv[4]; TablePubsub pubsub_channel; - REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + 
RAY_RETURN_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the channel. return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsTableNotificationMode::APPEND_OR_ADD, data); } else { - return RedisModule_ReplyWithSimpleString(ctx, "OK"); + return Status::OK(); } } @@ -290,7 +289,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) /// \return The current value at the key, or OK if there is no value. int TableAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { TableAdd_DoWrite(ctx, argv, argc, /*mutated_key_str=*/nullptr); - return TableAdd_DoPublish(ctx, argv, argc); + REPLY_AND_RETURN_IF_NOT_OK(TableAdd_DoPublish(ctx, argv, argc)); + return RedisModule_ReplyWithSimpleString(ctx, "OK"); } #if RAY_USE_NEW_GCS @@ -356,20 +356,21 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, } } -int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*argc*/) { +Status TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, + int /*argc*/) { RedisModuleString *pubsub_channel_str = argv[2]; RedisModuleString *id = argv[3]; RedisModuleString *data = argv[4]; // Publish a message on the requested pubsub channel if necessary. TablePubsub pubsub_channel; - REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + RAY_RETURN_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. 
return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsTableNotificationMode::APPEND_OR_ADD, data); } else { - return RedisModule_ReplyWithSimpleString(ctx, "OK"); + return Status::OK(); } } @@ -397,7 +398,8 @@ int TableAppend_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int REDISMODULE_OK) { return REDISMODULE_ERR; } - return TableAppend_DoPublish(ctx, argv, argc); + REPLY_AND_RETURN_IF_NOT_OK(TableAppend_DoPublish(ctx, argv, argc)); + return RedisModule_ReplyWithSimpleString(ctx, "OK"); } #if RAY_USE_NEW_GCS @@ -409,13 +411,13 @@ int ChainTableAppend_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, } #endif -int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { +Status Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { RedisModuleString *pubsub_channel_str = argv[2]; RedisModuleString *id = argv[3]; RedisModuleString *data = argv[4]; // Publish a message on the requested pubsub channel if necessary. TablePubsub pubsub_channel; - REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + RAY_RETURN_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. 
@@ -424,14 +426,14 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { : GcsTableNotificationMode::REMOVE, data); } else { - return RedisModule_ReplyWithSimpleString(ctx, "OK"); + return Status::OK(); } } -int Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is_add, - bool *changed) { +Status Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is_add, + bool *changed) { if (argc != 5) { - return RedisModule_WrongArity(ctx); + return Status::RedisError("Wrong arity in Set_DoWrite."); } RedisModuleString *prefix_str = argv[1]; @@ -449,19 +451,19 @@ int Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is if (!is_add && *changed) { // try to delete the empty set. RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_WRITE)); + RAY_RETURN_NOT_OK(OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_WRITE)); auto size = RedisModule_ValueLength(key); if (size == 0) { - REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, - "ERR Failed to delete empty set."); + if (RedisModule_DeleteKey(key) != REDISMODULE_OK) { + return Status::RedisError("ERR Failed to delete empty set."); + } } } - return REDISMODULE_OK; + return Status::OK(); } else { - // the SADD/SREM command failed - RedisModule_ReplyWithCallReply(ctx, reply); - return REDISMODULE_ERR; + size_t reply_size = 0; + const char *reply_pointer = RedisModule_CallReplyStringPtr(reply, &reply_size); + return Status::RedisError(std::string(reply_pointer, reply_size)); } } @@ -481,11 +483,9 @@ int Set_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, bool is /// \return OK if the add succeeds, or an error message string if the add fails. 
int SetAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { bool changed; - if (Set_DoWrite(ctx, argv, argc, /*is_add=*/true, &changed) != REDISMODULE_OK) { - return REDISMODULE_ERR; - } + REPLY_AND_RETURN_IF_NOT_OK(Set_DoWrite(ctx, argv, argc, /*is_add=*/true, &changed)); if (changed) { - return Set_DoPublish(ctx, argv, /*is_add=*/true); + REPLY_AND_RETURN_IF_NOT_OK(Set_DoPublish(ctx, argv, /*is_add=*/true)); } return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -507,17 +507,71 @@ int SetAdd_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) /// fails. int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { bool changed; - if (Set_DoWrite(ctx, argv, argc, /*is_add=*/false, &changed) != REDISMODULE_OK) { - return REDISMODULE_ERR; - } + REPLY_AND_RETURN_IF_NOT_OK(Set_DoWrite(ctx, argv, argc, /*is_add=*/false, &changed)); if (changed) { - return Set_DoPublish(ctx, argv, /*is_add=*/false); + REPLY_AND_RETURN_IF_NOT_OK(Set_DoPublish(ctx, argv, /*is_add=*/false)); } else { RAY_LOG(ERROR) << "The entry to remove doesn't exist."; } return RedisModule_ReplyWithSimpleString(ctx, "OK"); } +/// Replace an entry from the set stored at a key with another entry. +/// It publishes two notifications about the update to all subscribers, if a pubsub +/// channel is provided and the replace operation succeeds. If the target removing +/// entry does not exist, this command returns a failure message, otherwise, OK. +/// +/// This is called from a client with the command: +// +/// RAY.SET_REPLACE +/// +/// \param table_prefix The prefix string for keys in this table. +/// \param pubsub_channel The pubsub channel name that notifications for this +/// key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key to remove from. +/// \param data The data contains the data to remove and data to add. 
+/// \return OK if the remove succeeds, or failure message if the operation fails. +int SetReplace_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + // Copy the arguements . + RedisModuleString *data = argv[4]; + std::vector new_argv(argv, argv + argc); + // Split the data into 2 parts. + size_t data_size = 0; + const uint8_t *data_string = + reinterpret_cast(RedisModule_StringPtrLen(data, &data_size)); + auto root = flatbuffers::GetRoot(data_string); + // Remove the old data. + auto old_data = + RedisModule_CreateString(ctx, root->old_data()->data(), root->old_data()->size()); + new_argv[4] = old_data; + bool changed = false; + bool is_add = false; + REPLY_AND_RETURN_IF_NOT_OK(Set_DoWrite(ctx, new_argv.data(), argc, is_add, &changed)); + + if (changed) { + REPLY_AND_RETURN_IF_NOT_OK(Set_DoPublish(ctx, new_argv.data(), is_add)); + } else { + static const char *replace_reply = "ERR trying to replace a none-existing entry"; + return RedisModule_ReplyWithStringBuffer(ctx, replace_reply, strlen(replace_reply)); + } + + // Add the new data. + auto new_data = + RedisModule_CreateString(ctx, root->new_data()->data(), root->new_data()->size()); + new_argv[4] = new_data; + is_add = true; + REPLY_AND_RETURN_IF_NOT_OK(Set_DoWrite(ctx, new_argv.data(), argc, is_add, &changed)); + + if (changed) { + REPLY_AND_RETURN_IF_NOT_OK(Set_DoPublish(ctx, new_argv.data(), is_add)); + } + return RedisModule_ReplyWithSimpleString(ctx, "OK"); +} + /// A helper function to create and finish a GcsTableEntry, based on the /// current value or values at the given key. 
/// @@ -874,6 +928,7 @@ AUTO_MEMORY(TableAdd_RedisCommand); AUTO_MEMORY(TableAppend_RedisCommand); AUTO_MEMORY(SetAdd_RedisCommand); AUTO_MEMORY(SetRemove_RedisCommand); +AUTO_MEMORY(SetReplace_RedisCommand); AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); @@ -917,6 +972,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.set_replace", SetReplace_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.table_lookup", TableLookup_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index dbd39349caf7..d94fafe7e3d7 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -316,6 +316,42 @@ Status Set::Remove(const DriverID &driver_id, const ID &id, std::move(callback)); } +template +Status Set::Replace(const DriverID &driver_id, const ID &id, + std::shared_ptr &old_data, + std::shared_ptr &new_data, const WriteCallback &done, + const FailureCallback &failure) { + num_removes_++; + auto callback = [this, id, new_data, done, failure](const std::string &data) { + if (data.length() == 0 && done != nullptr) { + (done)(client_, id, *new_data); + } else if (data.length() != 0 && failure != nullptr) { + (failure)(client_, id); + } + return true; + }; + + flatbuffers::FlatBufferBuilder fbb_old; + fbb_old.ForceDefaults(true); + fbb_old.Finish(Data::Pack(fbb_old, old_data.get())); + std::string old_data_string(reinterpret_cast(fbb_old.GetBufferPointer()), + fbb_old.GetSize()); + flatbuffers::FlatBufferBuilder fbb_new; + fbb_new.ForceDefaults(true); + fbb_new.Finish(Data::Pack(fbb_new, new_data.get())); + std::string new_data_string(reinterpret_cast(fbb_new.GetBufferPointer()), + fbb_new.GetSize()); + 
flatbuffers::FlatBufferBuilder fbb; + auto message = CreateSetReplaceEntryData(fbb, fbb.CreateString(old_data_string), + fbb.CreateString(new_data_string)); + fbb.Finish(message); + + return GetRedisContext(id)->RunAsync("RAY.SET_REPLACE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); + // Note: the Status from RunAsync is returned directly; the callback reports the result. +} + template <typename ID, typename Data> std::string Set<ID, Data>::DebugString() const { std::stringstream result; @@ -527,6 +563,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { RAY_CHECK(local_client.client_id == local_client_.client_id); local_client_ = local_client; + local_client_.entry_type = EntryType::INSERTION; // Construct the data to add to the client table. auto data = std::make_shared<ClientTableDataT>(local_client_); @@ -541,6 +578,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, + const GcsTableNotificationMode notification_mode, const std::vector<ClientTableDataT> &notifications) { RAY_CHECK(log_key == client_log_key_); std::unordered_map<std::string, ClientTableDataT> connected_nodes; @@ -557,12 +595,11 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } disconnected_nodes.emplace(notification.client_id, notification); } - } - for (const auto &pair : connected_nodes) { - HandleNotification(client, pair.second); - } - for (const auto &pair : disconnected_nodes) { - HandleNotification(client, pair.second); + // We only handle the message of APPEND_OR_ADD and ignore the deletion message.
+ //if (notification_mode == GcsTableNotificationMode::APPEND_OR_ADD) { + // for (const auto ¬ification : notifications) { + // HandleNotification(client, notification); + // } } }; // Callback to request notifications from the client table once we've @@ -574,7 +611,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { RAY_CHECK_OK(Subscribe(DriverID::nil(), client_id_, notification_callback, subscription_callback)); }; - return Append(DriverID::nil(), client_log_key_, data, add_callback); + return Add(DriverID::nil(), client_log_key_, data, add_callback); } Status ClientTable::Disconnect(const DisconnectCallback &callback) { @@ -582,23 +619,63 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { data->entry_type = EntryType::DELETION; auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + //auto remove_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, + // const ClientTableDataT &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::nil(), client_log_key_, id)); if (callback != nullptr) { callback(); } }; - RAY_RETURN_NOT_OK(Append(DriverID::nil(), client_log_key_, data, add_callback)); + auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + RAY_LOG(FATAL) << "This should not happen in ClientTable::Disconnect."; + }; + // Change is_connected of this entry from true to false. + auto new_data = std::make_shared(*data); + data->entry_type = EntryType::DELETION; + RAY_CHECK_OK(Replace(DriverID::nil(), client_log_key_, data, new_data, nullptr, + failure_callback)); // We successfully added the deletion entry. Mark ourselves as disconnected. 
disconnected_ = true; return Status::OK(); } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); + /*auto data = std::make_shared(); data->client_id = dead_client_id.binary(); data->entry_type = EntryType::DELETION; - return Append(DriverID::nil(), client_log_key_, data, nullptr); + return Append(DriverID::nil(), client_log_key_, data, nullptr);*/ + if (client_id_ == dead_client_id) { + return Disconnect(nullptr); + } else { + auto lookup_callback = [this, dead_client_id]( + AsyncGcsClient *client, const ClientID &id, + const std::vector &clients) { + std::shared_ptr data; + for (const auto &client_data : clients) { + if (client_data.client_id == dead_client_id.binary()) { + data = std::make_shared(client_data); + break; + } + } + if (data != nullptr) { + // Change is_connected of this entry from true to false. + auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + // To avoid race condition, say the monitor and the raylet both try to change + // the status, we just log an error message here. 
+ RAY_LOG(ERROR) << "ClientTable::MarkDisconnected failed, " + << "which may be caused by race condition."; + }; + auto new_data = std::make_shared(*data); + data->entry_type = EntryType::DELETION; + RAY_CHECK_OK(Replace(DriverID::nil(), client_log_key_, data, new_data, nullptr, + failure_callback)); + } else { + RAY_LOG(WARNING) << "Trying to mark a raylet node that is not registered."; + } + }; + return Lookup(lookup_callback); + } } void ClientTable::GetClient(const ClientID &client_id, @@ -618,12 +695,12 @@ const std::unordered_map &ClientTable::GetAllClients Status ClientTable::Lookup(const Callback &lookup) { RAY_CHECK(lookup != nullptr); - return Log::Lookup(DriverID::nil(), client_log_key_, lookup); + return Set::Lookup(DriverID::nil(), client_log_key_, lookup); } std::string ClientTable::DebugString() const { std::stringstream result; - result << Log::DebugString(); + result << Set::DebugString(); result << ", cache size: " << client_cache_.size() << ", num removed: " << removed_clients_.size(); return result.str(); @@ -676,6 +753,7 @@ template class Table; template class Table; template class Log; template class Log; +template class Set; template class Log; template class Log; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 056bf7b97ec7..f0cf5a9ee070 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -371,10 +371,16 @@ class SetInterface { public: using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; + /// The callback to call when a Replace call returns missing entry reply. 
+ using FailureCallback = std::function<void(AsyncGcsClient *client, const ID &id)>; virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Replace(const DriverID &driver_id, const ID &id, + std::shared_ptr<DataT> &old_data, + std::shared_ptr<DataT> &new_data, const WriteCallback &done, + const FailureCallback &failure) = 0; virtual ~SetInterface(){}; }; @@ -397,6 +403,7 @@ class Set : private Log, using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; using SubscriptionCallback = typename Log::SubscriptionCallback; + using FailureCallback = typename SetInterface<ID, Data>::FailureCallback; Set(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) {} @@ -428,6 +435,20 @@ class Set : private Log, Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); + /// Replace an entry with the new data. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the set entry to update in the GCS. + /// \param old_data The old data to remove from the set. The new data will not be + /// added if this data is not in the set. + /// \param new_data The new data to add to the set. + /// \param done Callback that is called once the data has been written to the GCS. + /// \param failure Callback that is called if the entry to remove does not exist. + /// \return Status + Status Replace(const DriverID &driver_id, const ID &id, + std::shared_ptr<DataT> &old_data, std::shared_ptr<DataT> &new_data, + const WriteCallback &done, const FailureCallback &failure); + Status Subscribe(const DriverID &driver_id, const ClientID &client_id, const NotificationCallback &subscribe, const SubscriptionCallback &done) { @@ -677,14 +698,14 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead.
A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : public Log { +class ClientTable : private Set { public: using ClientTableCallback = std::function; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) - : Log(contexts, client), + : Set(contexts, client), // We set the client log's key equal to nil so that all instances of // ClientTable have the same key. client_log_key_(), diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index efd190ba5b27..929367d2c3e2 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -384,8 +384,9 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address, client_data.node_manager_port); if (!status.ok()) { - // This is not a fatal error for raylet, but it should not happen. - // We need to broadcase this message. + // This is not a fatal error for raylet, it may happen in a race condition. + // For example, the target raylet node just disconnected after this node finished + // reading the information from GCS. 
std::string type = "raylet_connection_error"; std::ostringstream error_message; error_message << "Failed to connect to ray node " << client_id From 65f96a0d99e6db555678da968022e93b7fae9526 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Sun, 19 May 2019 19:58:10 +0800 Subject: [PATCH 2/2] Resolve conflict and fix gtest case --- src/ray/gcs/tables.cc | 35 +++++++++------------------------- src/ray/gcs/tables.h | 2 +- src/ray/raylet/node_manager.cc | 2 +- 3 files changed, 11 insertions(+), 28 deletions(-) diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index d94fafe7e3d7..c790b3c6c7af 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -581,25 +581,11 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { const GcsTableNotificationMode notification_mode, const std::vector<ClientTableDataT> &notifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map<std::string, ClientTableDataT> connected_nodes; - std::unordered_map<std::string, ClientTableDataT> disconnected_nodes; - for (auto &notification : notifications) { - // This is temporary fix for Issue 4140 to avoid connect to dead nodes. - // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); - } else { - auto iter = connected_nodes.find(notification.client_id); - if (iter != connected_nodes.end()) { - connected_nodes.erase(iter); - } - disconnected_nodes.emplace(notification.client_id, notification); - } // We only handle the message of APPEND_OR_ADD and ignore the deletion message.
- //if (notification_mode == GcsTableNotificationMode::APPEND_OR_ADD) { - // for (const auto ¬ification : notifications) { - // HandleNotification(client, notification); - // } + if (notification_mode == GcsTableNotificationMode::APPEND_OR_ADD) { + for (const auto ¬ification : notifications) { + HandleNotification(client, notification); + } } }; // Callback to request notifications from the client table once we've @@ -616,11 +602,8 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; - auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - //auto remove_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - // const ClientTableDataT &data) { + auto remove_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::nil(), client_log_key_, id)); if (callback != nullptr) { @@ -632,8 +615,8 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { }; // Change is_connected of this entry from true to false. auto new_data = std::make_shared(*data); - data->entry_type = EntryType::DELETION; - RAY_CHECK_OK(Replace(DriverID::nil(), client_log_key_, data, new_data, nullptr, + new_data->entry_type = EntryType::DELETION; + RAY_CHECK_OK(Replace(DriverID::nil(), client_log_key_, data, new_data, remove_callback, failure_callback)); // We successfully added the deletion entry. Mark ourselves as disconnected. 
disconnected_ = true; @@ -667,7 +650,7 @@ ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { << "which may be caused by race condition."; }; auto new_data = std::make_shared(*data); - data->entry_type = EntryType::DELETION; + new_data->entry_type = EntryType::DELETION; RAY_CHECK_OK(Replace(DriverID::nil(), client_log_key_, data, new_data, nullptr, failure_callback)); } else { diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index f0cf5a9ee070..12bd2b54bc7c 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -698,7 +698,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : private Set { +class ClientTable : public Set { public: using ClientTableCallback = std::function; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 929367d2c3e2..a1ba69e26204 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1330,7 +1330,7 @@ void NodeManager::ProcessSetResourceRequest( } auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); - RAY_CHECK_OK(gcs_client_->client_table().Append( + RAY_CHECK_OK(gcs_client_->client_table().Add( DriverID::nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); }