diff --git a/BUILD.bazel b/BUILD.bazel index 90b0f536a10b..36f02e292fa1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -535,7 +535,7 @@ flatbuffer_py_library( "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", - "GcsTableEntry.py", + "GcsEntry.py", "HeartbeatBatchTableData.py", "HeartbeatTableData.py", "Language.py", diff --git a/doc/source/conf.py b/doc/source/conf.py index b0ae3416d4ab..98fb3e0d02dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -29,7 +29,7 @@ "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", - "ray.core.generated.GcsTableEntry", + "ray.core.generated.GcsEntry", "ray.core.generated.HeartbeatBatchTableData", "ray.core.generated.HeartbeatTableData", "ray.core.generated.Language", diff --git a/java/BUILD.bazel b/java/BUILD.bazel index f86df8d40f96..f3ae6f063304 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -160,7 +160,7 @@ flatbuffers_generated_files = [ "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", - "GcsTableEntry.java", + "GcsEntry.java", "HeartbeatBatchTableData.java", "HeartbeatTableData.java", "Language.java", diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 15eec6c81136..cadd197ec73f 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -9,7 +9,7 @@ from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsTableEntry import GcsTableEntry +from ray.core.generated.GcsEntry import GcsEntry from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData from ray.core.generated.Language import Language @@ -25,7 +25,7 @@ "ClientTableData", "DriverTableData", "ErrorTableData", - "GcsTableEntry", + "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", "Language", diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 09a154d7b548..c9e0424b3eb8 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,8 +101,7 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatBatchTableData. @@ -208,8 +207,7 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. """ - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) driver_data = gcs_entries.Entries(0) message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( driver_data, 0) diff --git a/python/ray/state.py b/python/ray/state.py index 6b2c8a4ef8bc..14ba49987ec4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -41,7 +41,7 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) ordered_client_ids = [] @@ -248,8 +248,7 @@ def _object_table(self, object_id): object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entry.EntriesLength() > 0 @@ -307,8 +306,7 @@ def _task_table(self, task_id): "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entries.EntriesLength() == 1 @@ -431,8 +429,7 @@ def _profile_table(self, batch_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) profile_events = [] for i in range(gcs_entries.EntriesLength()): @@ -815,9 +812,8 @@ def available_resources(self): ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = ( - ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0)) + gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( + data, 0)) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatTableData. GetRootAsHeartbeatTableData(heartbeat_data, 0)) @@ -871,8 +867,7 @@ def _error_messages(self, driver_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) error_messages = [] for i in range(gcs_entries.EntriesLength()): error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( @@ -934,8 +929,7 @@ def actor_checkpoint_info(self, actor_id): ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) entry = ( ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( gcs_entry.Entries(0), 0)) diff --git a/python/ray/worker.py b/python/ray/worker.py index 7786c742d9b1..7505120574a6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( msg["data"], 0) assert gcs_entry.EntriesLength() == 1 error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index d9b5087c4719..3d1c6602740c 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, profile_table_.reset(new ProfileTable(shard_contexts_, this)); actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this)); actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this)); + resource_table_.reset(new DynamicResourceTable({primary_context_}, this)); command_type_ = command_type; // TODO(swang): Call the client table's Connect() method here. To do this, @@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() { return *actor_checkpoint_id_table_; } +DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; } + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index d47d9a6e8b24..c9f5b4bca624 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient { ProfileTable &profile_table(); ActorCheckpointTable &actor_checkpoint_table(); ActorCheckpointIdTable &actor_checkpoint_id_table(); + DynamicResourceTable &resource_table(); // We also need something to export generic code to run on workers from the // driver (to set the PYTHONPATH) @@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr client_table_; std::unique_ptr actor_checkpoint_table_; std::unique_ptr actor_checkpoint_id_table_; + std::unique_ptr resource_table_; // The following contexts write to the data shard std::vector> shard_contexts_; std::vector> shard_asio_async_clients_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 1b43bcc23c08..4eb34a95328a 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -657,13 +657,12 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE); + ASSERT_EQ(change_mode, GcsChangeMode::REMOVE); } ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. @@ -894,10 +893,9 @@ void TestSetSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [object_id2, managers2]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. @@ -1111,10 +1109,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because notifications @@ -1307,6 +1304,161 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { TestClientTableMarkDisconnected(driver_id_, client_); } +void TestHashTable(const DriverID &driver_id, + std::shared_ptr client) { + const int expected_count = 14; + ClientID client_id = ClientID::FromRandom(); + // Prepare the first resource map: data_map1. + auto cpu_data = std::make_shared(); + cpu_data->resource_name = "CPU"; + cpu_data->resource_capacity = 100; + auto gpu_data = std::make_shared(); + gpu_data->resource_name = "GPU"; + gpu_data->resource_capacity = 2; + DynamicResourceTable::DataMap data_map1; + data_map1.emplace("CPU", cpu_data); + data_map1.emplace("GPU", gpu_data); + // Prepare the second resource map: data_map2 which decreases CPU, + // increases GPU and add a new CUSTOM compared to data_map1. + auto data_cpu = std::make_shared(); + data_cpu->resource_name = "CPU"; + data_cpu->resource_capacity = 50; + auto data_gpu = std::make_shared(); + data_gpu->resource_name = "GPU"; + data_gpu->resource_capacity = 10; + auto data_custom = std::make_shared(); + data_custom->resource_name = "CUSTOM"; + data_custom->resource_capacity = 2; + DynamicResourceTable::DataMap data_map2; + data_map2.emplace("CPU", data_cpu); + data_map2.emplace("GPU", data_gpu); + data_map2.emplace("CUSTOM", data_custom); + data_map2["CPU"]->resource_capacity = 50; + // This is a common comparison function for the test. + auto compare_test = [](const DynamicResourceTable::DataMap &data1, + const DynamicResourceTable::DataMap &data2) { + ASSERT_EQ(data1.size(), data2.size()); + for (const auto &data : data1) { + auto iter = data2.find(data.first); + ASSERT_TRUE(iter != data2.end()); + ASSERT_EQ(iter->second->resource_name, data.second->resource_name); + ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + } + }; + auto subscribe_callback = [](AsyncGcsClient *client) { + ASSERT_TRUE(true); + test->IncrementNumCallbacks(); + }; + auto notification_callback = [data_map1, data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode, + const DynamicResourceTable::DataMap &data) { + if (change_mode == GcsChangeMode::REMOVE) { + ASSERT_EQ(data.size(), 2); + ASSERT_TRUE(data.find("GPU") != data.end()); + ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end()); + // The key "None-Existent" will not appear in the notification. + } else { + if (data.size() == 2) { + compare_test(data_map1, data); + } else if (data.size() == 3) { + compare_test(data_map2, data); + } else { + ASSERT_TRUE(false); + } + } + test->IncrementNumCallbacks(); + // It is not sure which of the notification or lookup callback will come first. + if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + // Step 0: Subscribe the change of the hash table. + RAY_CHECK_OK(client->resource_table().Subscribe( + driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->resource_table().RequestNotifications( + driver_id, client_id, client->client_table().GetLocalClientId())); + + // Step 1: Add elements to the hash table. + auto update_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1)); + + // Step 2: Decrease one element, increase one and add a new one. + RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr)); + auto lookup_callback2 = [data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map2, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2)); + std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); + auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id, + const std::vector &callback_data) { + for (int i = 0; i < callback_data.size(); ++i) { + // All deleting keys exist in this argument even if the key doesn't exist. + ASSERT_EQ(callback_data[i], delete_keys[i]); + } + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys, + remove_callback)); + DynamicResourceTable::DataMap data_map3(data_map2); + data_map3.erase("GPU"); + data_map3.erase("CUSTOM"); + auto lookup_callback3 = [data_map3, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map3, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3)); + + // Step 3: Reset the the resources to data_map1. + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback4 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4)); + + // Step 4: Removing all elements will remove the home Hash table from GCS. + RAY_CHECK_OK(client->resource_table().RemoveEntries( + driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); + auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + ASSERT_EQ(callback_data.size(), 0); + test->IncrementNumCallbacks(); + // It is not sure which of notification or lookup callback will come first. + if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5)); + test->Start(); + ASSERT_EQ(test->NumCallbacks(), expected_count); +} + +TEST_F(TestGcsWithAsio, TestHashTable) { + test = this; + TestHashTable(driver_id_, client_); +} + #undef TEST_MACRO } // namespace gcs diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index b81f388d88c5..614c80b27672 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -22,6 +22,7 @@ enum TablePrefix:int { TASK_LEASE, ACTOR_CHECKPOINT, ACTOR_CHECKPOINT_ID, + NODE_RESOURCE, } // The channel that Add operations to the Table should be published on, if any. @@ -37,6 +38,7 @@ enum TablePubsub:int { ERROR_INFO, TASK_LEASE, DRIVER, + NODE_RESOURCE, } // Enum for the entry type in the ClientTable @@ -113,13 +115,13 @@ table ResourcePair { value: double; } -enum GcsTableNotificationMode:int { +enum GcsChangeMode:int { APPEND_OR_ADD = 0, REMOVE, } -table GcsTableEntry { - notification_mode: GcsTableNotificationMode; +table GcsEntry { + change_mode: GcsChangeMode; id: string; entries: [string]; } diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 23e611e400df..e059787472f1 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -179,32 +179,20 @@ flatbuffers::Offset RedisStringToFlatbuf( return fbb.CreateString(redis_string_str, redis_string_size); } -/// Publish a notification for an entry update at a key. This publishes a -/// notification to all subscribers of the table, as well as every client that -/// has requested notifications for this key. +/// Helper method to publish formatted data to target channel. /// /// \param pubsub_channel_str The pubsub channel name that notifications for /// this key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key that the notification is about. -/// \param mode the update mode, such as append or remove. -/// \param data The appended/removed data. +/// \param data_buffer The data to publish, which is a GcsEntry buffer. /// \return OK if there is no error during a publish. -int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, - RedisModuleString *id, GcsTableNotificationMode notification_mode, - RedisModuleString *data) { - // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = - CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - +int PublishDataHelper(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, RedisModuleString *data_buffer) { // Write the data back to any subscribers that are listening to all table // notifications. - RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str, - fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -221,8 +209,8 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st // will be garbage collected by redis. auto channel = RedisModule_CreateString(ctx, client_channel.data(), client_channel.size()); - RedisModuleCallReply *reply = RedisModule_Call( - ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", channel, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -231,6 +219,31 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return RedisModule_ReplyWithSimpleString(ctx, "OK"); } +/// Publish a notification for an entry update at a key. This publishes a +/// notification to all subscribers of the table, as well as every client that +/// has requested notifications for this key. +/// +/// \param pubsub_channel_str The pubsub channel name that notifications for +/// this key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key that the notification is about. +/// \param mode the update mode, such as append or remove. +/// \param data The appended/removed data. +/// \return OK if there is no error during a publish. +int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, GcsChangeMode change_mode, + RedisModuleString *data) { + // Serialize the notification to send. + flatbuffers::FlatBufferBuilder fbb; + auto data_flatbuf = RedisStringToFlatbuf(fbb, data); + auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), + fbb.CreateVector(&data_flatbuf, 1)); + fbb.Finish(message); + auto data_buffer = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); +} + // RAY.TABLE_ADD: // TableAdd_RedisCommand: the actual command handler. // (helper) TableAdd_DoWrite: performs the write to redis state. @@ -266,8 +279,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -366,8 +379,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -419,10 +432,9 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - is_add ? GcsTableNotificationMode::APPEND_OR_ADD - : GcsTableNotificationMode::REMOVE, - data); + return PublishTableUpdate( + ctx, pubsub_channel_str, id, + is_add ? GcsChangeMode::APPEND_OR_ADD : GcsChangeMode::REMOVE, data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -518,7 +530,125 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar return RedisModule_ReplyWithSimpleString(ctx, "OK"); } -/// A helper function to create and finish a GcsTableEntry, based on the +int Hash_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv) { + RedisModuleString *pubsub_channel_str = argv[2]; + RedisModuleString *id = argv[3]; + RedisModuleString *data = argv[4]; + // Publish a message on the requested pubsub channel if necessary. + TablePubsub pubsub_channel; + REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + if (pubsub_channel != TablePubsub::NO_PUBLISH) { + // All other pubsub channels write the data back directly onto the + // channel. + return PublishDataHelper(ctx, pubsub_channel_str, id, data); + } else { + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + } +} + +/// Do the hash table write operation. This is called from by HashUpdate_RedisCommand. +/// +/// \param change_mode Output the mode of the operation: APPEND_OR_ADD or REMOVE. +/// \param deleted_data Output data if the deleted data is not the same as required. +int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + GcsChangeMode *change_mode, RedisModuleString **changed_data) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + RedisModuleString *prefix_str = argv[1]; + RedisModuleString *id = argv[3]; + RedisModuleString *update_data = argv[4]; + + RedisModuleKey *key; + REPLY_AND_RETURN_IF_NOT_OK(OpenPrefixedKey( + &key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, nullptr)); + int type = RedisModule_KeyType(key); + REPLY_AND_RETURN_IF_FALSE( + type == REDISMODULE_KEYTYPE_HASH || type == REDISMODULE_KEYTYPE_EMPTY, + "HashUpdate_DoWrite: entries must be a hash or an empty hash"); + + size_t update_data_len = 0; + const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); + + auto data_vec = flatbuffers::GetRoot(update_data_buf); + *change_mode = data_vec->change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { + // This code path means they are updating command. + size_t total_size = data_vec->entries()->size(); + REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); + for (int i = 0; i < total_size; i += 2) { + // Reconstruct a key-value pair from a flattened list. + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + RedisModuleString *entry_value = + RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), + data_vec->entries()->Get(i + 1)->size()); + // Returning 0 if key exists(still updated), 1 if the key is created. + RAY_IGNORE_EXPR( + RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); + } + *changed_data = update_data; + } else { + // This code path means the command wants to remove the entries. + size_t total_size = data_vec->entries()->size(); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data; + for (int i = 0; i < total_size; i++) { + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, + REDISMODULE_HASH_DELETE, NULL); + if (deleted_num != 0) { + // The corresponding key is removed. + data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), + data_vec->entries()->Get(i)->size())); + } + } + auto message = + CreateGcsEntry(fbb, data_vec->change_mode(), + fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), + fbb.CreateVector(data)); + fbb.Finish(message); + *changed_data = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + auto size = RedisModule_ValueLength(key); + if (size == 0) { + REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, + "ERR Failed to delete empty hash."); + } + } + return REDISMODULE_OK; +} + +/// Update entries for a hash table. +/// +/// This is called from a client with the command: +// +/// RAY.HASH_UPDATE +/// +/// \param table_prefix The prefix string for keys in this table. +/// \param pubsub_channel The pubsub channel name that notifications for this +/// key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key to remove from. +/// \param data The GcsEntry flatbugger data used to update this hash table. +/// 1). For deletion, this is a list of keys. +/// 2). For updating, this is a list of pairs with each key followed by the value. +/// \return OK if the remove succeeds, or an error message string if the remove +/// fails. +int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + GcsChangeMode mode; + RedisModuleString *changed_data = nullptr; + if (HashUpdate_DoWrite(ctx, argv, argc, &mode, &changed_data) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } + // Replace the data with the changed data to do the publish. + std::vector new_argv(argv, argv + argc); + new_argv[4] = changed_data; + return Hash_DoPublish(ctx, new_argv.data()); +} + +/// A helper function to create and finish a GcsEntry, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -528,7 +658,7 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsTableEntry. +/// \param fbb A flatbuffer builder used to build the GcsEntry. Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *prefix_str, RedisModuleString *entry_id, flatbuffers::FlatBufferBuilder &fbb) { @@ -539,12 +669,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(&data, 1)); + auto message = + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_LIST: + case REDISMODULE_KEYTYPE_HASH: case REDISMODULE_KEYTYPE_SET: { RedisModule_CloseKey(table_key); // Close the key before executing the command. NOTE(swang): According to @@ -561,10 +692,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, case REDISMODULE_KEYTYPE_SET: reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str); break; + case REDISMODULE_KEYTYPE_HASH: + reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); + break; } // Build the flatbuffer from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { - return Status::RedisError("Empty list or wrong type"); + return Status::RedisError("Empty list/set/hash or wrong type"); } std::vector> data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { @@ -574,13 +708,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, data.push_back(fbb.CreateString(element_str, len)); } auto message = - CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsTableEntry( - fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), + auto message = CreateGcsEntry( + fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(std::vector>())); fbb.Finish(message); } break; @@ -637,6 +771,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str return Status::RedisError("Key does not exist."); } auto key_type = RedisModule_KeyType(delete_key); + // Set/Hash will delete itself when the length is 0. if (key_type == REDISMODULE_KEYTYPE_STRING || key_type == REDISMODULE_KEYTYPE_LIST) { // Current Table or Log only has this two types of entries. RAY_RETURN_NOT_OK( @@ -873,6 +1008,7 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int // Wrap all Redis commands with Redis' auto memory management. AUTO_MEMORY(TableAdd_RedisCommand); +AUTO_MEMORY(HashUpdate_RedisCommand); AUTO_MEMORY(TableAppend_RedisCommand); AUTO_MEMORY(SetAdd_RedisCommand); AUTO_MEMORY(SetRemove_RedisCommand); @@ -929,6 +1065,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.hash_update", HashUpdate_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications", TableRequestNotifications_RedisCommand, "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ffc44daa049a..e20384a04bdc 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -92,7 +92,7 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, std::vector results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { DataT result; @@ -114,9 +114,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, - const GcsTableNotificationMode notification_mode, + const GcsChangeMode change_mode, const std::vector &data) { - RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE); + RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; return Subscribe(driver_id, client_id, subscribe_wrapper, done); @@ -141,7 +141,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); ID id; if (root->id()->size() > 0) { id = from_flatbuf(*root->id()); @@ -153,7 +153,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien data_root->UnPackTo(&result); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->notification_mode(), results); + subscribe(client_, id, root->change_mode(), results); } } }; @@ -339,6 +339,155 @@ std::string Set::DebugString() const { return result.str(); } +template +Status Hash::Update(const DriverID &driver_id, const ID &id, + const DataMap &data_map, const HashCallback &done) { + num_adds_++; + auto callback = [this, id, data_map, done](const CallbackReply &reply) { + if (done != nullptr) { + (done)(client_, id, data_map); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(data_map.size() * 2); + for (auto const &pair : data_map) { + // Add the key. + data_vec.push_back(fbb.CreateString(pair.first)); + flatbuffers::FlatBufferBuilder fbb_data; + fbb_data.ForceDefaults(true); + fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); + std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), + fbb_data.GetSize()); + // Add the value. + data_vec.push_back(fbb.CreateString(data)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) { + num_removes_++; + auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) { + if (remove_callback != nullptr) { + (remove_callback)(client_, id, keys); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(keys.size()); + // Add the keys. + for (auto const &key : keys) { + data_vec.push_back(fbb.CreateString(key)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), + fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +std::string Hash::DebugString() const { + std::stringstream result; + result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ + << ", num removes: " << num_removes_; + return result.str(); +} + +template +Status Hash::Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) { + num_lookups_++; + auto callback = [this, id, lookup](const CallbackReply &reply) { + if (lookup != nullptr) { + DataMap results; + if (!reply.IsNil()) { + const auto data = reply.ReadAsString(); + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + results.emplace(key, std::move(result)); + } + } + lookup(client_, id, results); + } + }; + std::vector nil; + return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), + prefix_, pubsub_channel_, std::move(callback)); +} + +template +Status Hash::Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) { + RAY_CHECK(subscribe_callback_index_ == -1) + << "Client called Subscribe twice on the same table"; + auto callback = [this, subscribe, done](const CallbackReply &reply) { + const auto data = reply.ReadAsPubsubData(); + if (data.empty()) { + // No notification data is provided. This is the callback for the + // initial subscription request. + if (done != nullptr) { + done(client_); + } + } else { + // Data is provided. This is the callback for a message. + if (subscribe != nullptr) { + // Parse the notification. + auto root = flatbuffers::GetRoot(data.data()); + DataMap data_map; + ID id; + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); + } + if (root->change_mode() == GcsChangeMode::REMOVE) { + for (size_t i = 0; i < root->entries()->size(); i++) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + data_map.emplace(key, std::shared_ptr()); + } + } else { + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + data_map.emplace(key, std::move(result)); + } + } + subscribe(client_, id, root->change_mode(), data_map); + } + } + }; + + subscribe_callback_index_ = 1; + for (auto &context : shard_contexts_) { + RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback, + &subscribe_callback_index_)); + } + return Status::OK(); +} + Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); @@ -696,6 +845,9 @@ template class Log; template class Table; template class Table; +template class Log; +template class Hash; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index af42509bda96..6a1d502a7f54 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -75,9 +75,9 @@ class Log : public LogInterface, virtual public PubsubInterface { using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -214,7 +214,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// to subscribe to all modifications, or to subscribe only to keys that it /// requests notifications for. This may only be called once per Log /// instance. This function is different from public version due to - /// an additional parameter notification_mode in NotificationCallback. Therefore this + /// an additional parameter change_mode in NotificationCallback. Therefore this /// function supports notifications of remove operations. /// /// \param driver_id The ID of the job (= driver). @@ -451,6 +451,157 @@ class Set : private Log, using Log::num_lookups_; }; +template +class HashInterface { + public: + using DataT = typename Data::NativeTableType; + using DataMap = std::unordered_map>; + // Reuse Log's SubscriptionCallback when Subscribe is successfully called. + using SubscriptionCallback = typename Log::SubscriptionCallback; + + /// The callback function used by function Update & Lookup. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. + /// \param data Map data contains the change to the Hash Table. + /// \return Void + using HashCallback = + std::function; + + /// The callback function used by function RemoveEntries. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. + /// \param keys The keys that are moved from this Hash Table. + /// \return Void + using HashRemoveCallback = std::function &keys)>; + + /// The notification function used by function Subscribe. + /// + /// \param client The client on which the Subscribe is called. + /// \param change_mode The mode to identify the data is removed or updated. + /// \param data Map data contains the change to the Hash Table. + /// \return Void + using HashNotificationCallback = + std::function; + + /// Add entries of a hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is added to the GCS. + /// \param pairs Map data to add to the hash table. + /// \param done HashCallback that is called once the request data has been written to + /// the GCS. + /// \return Status + virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + const HashCallback &done) = 0; + + /// Remove entries from the hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is removed from the GCS. + /// \param keys The entry keys of the hash table. + /// \param remove_callback HashRemoveCallback that is called once the data has been + /// written to the GCS no matter whether the key exists in the hash table. + /// \return Status + virtual Status RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) = 0; + + /// Lookup the map data of a hash table. + /// + /// \param driver_id The ID of the job (= driver). + /// \param id The ID of the data that is looked up in the GCS. + /// \param lookup HashCallback that is called after lookup. If the callback is + /// called with an empty hash table, then there was no data in the callback. + /// \return Status + virtual Status Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) = 0; + + /// Subscribe to any Update or Remove operations to this hash table. + /// + /// \param driver_id The ID of the driver. + /// \param client_id The type of update to listen to. If this is nil, then a + /// message for each Update to the table will be received. Else, only + /// messages for the given client will be received. In the latter + /// case, the client may request notifications on specific keys in the + /// table via `RequestNotifications`. + /// \param subscribe HashNotificationCallback that is called on each received message. + /// \param done SubscriptionCallback that is called when subscription is complete and + /// we are ready to receive messages. + /// \return Status + virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) = 0; + + virtual ~HashInterface(){}; +}; + +template +class Hash : private Log, + public HashInterface, + virtual public PubsubInterface { + public: + using DataT = typename Log::DataT; + using DataMap = std::unordered_map>; + using HashCallback = typename HashInterface::HashCallback; + using HashRemoveCallback = typename HashInterface::HashRemoveCallback; + using HashNotificationCallback = + typename HashInterface::HashNotificationCallback; + using SubscriptionCallback = typename Log::SubscriptionCallback; + + Hash(const std::vector> &contexts, AsyncGcsClient *client) + : Log(contexts, client) {} + + using Log::RequestNotifications; + using Log::CancelNotifications; + + Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + const HashCallback &done) override; + + Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) override; + + Status Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) override; + + Status RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) override; + + /// Returns debug string for class. + /// + /// \return string. + std::string DebugString() const; + + protected: + using Log::shard_contexts_; + using Log::client_; + using Log::pubsub_channel_; + using Log::prefix_; + using Log::subscribe_callback_index_; + using Log::GetRedisContext; + + int64_t num_adds_ = 0; + int64_t num_removes_ = 0; + using Log::num_lookups_; +}; + +class DynamicResourceTable : public Hash { + public: + DynamicResourceTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Hash(contexts, client) { + pubsub_channel_ = TablePubsub::NODE_RESOURCE; + prefix_ = TablePrefix::NODE_RESOURCE; + }; + + virtual ~DynamicResourceTable(){}; +}; + class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 1f05559f4b87..d2496dceb8bf 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -11,16 +11,16 @@ namespace { /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. -void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, +void UpdateObjectLocations(const GcsChangeMode change_mode, const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. - // with GcsTableNotificationMode, we can determine whether the update mode is + // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { ClientID client_id = ClientID::FromBinary(object_table_data.manager); - if (notification_mode != GcsTableNotificationMode::REMOVE) { + if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { client_ids->erase(client_id); @@ -41,7 +41,7 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this]( gcs::AsyncGcsClient *client, const ObjectID &object_id, - const GcsTableNotificationMode notification_mode, + const GcsChangeMode change_mode, const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); @@ -54,8 +54,7 @@ void ObjectDirectory::RegisterBackend() { it->second.subscribed = true; // Update entries for this object. - UpdateObjectLocations(notification_mode, location_updates, - gcs_client_->client_table(), + UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), &it->second.current_object_locations); // Copy the callbacks so that the callbacks can unsubscribe without interrupting // looping over the callbacks. @@ -135,8 +134,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) { if (listener.second.current_object_locations.count(client_id) > 0) { // If the subscribed object has the removed client as a location, update // its locations with an empty update so that the location will be removed. - UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {}, - gcs_client_->client_table(), + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(), &listener.second.current_object_locations); // Re-call all the subscribed callbacks for the object, since its // locations have changed. @@ -213,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; - UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates, + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, gcs_client_->client_table(), &client_ids); // It is safe to call the callback directly since this is already running // in the GCS client's lookup callback stack.