From 2afc1560c43a710ce4432761a3df7ca815ed47ca Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 19 Jun 2019 21:07:57 +0800 Subject: [PATCH 01/34] gcs.proto --- BUILD.bazel | 36 ++++- bazel/ray_deps_setup.bzl | 4 +- src/ray/protobuf/gcs.proto | 283 +++++++++++++++++++++++++++++++++++++ 3 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index da36eec0cf57..f07dd3cdf300 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,22 +1,44 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# Node manager gRPC lib. -grpc_proto_library( - name = "node_manager_grpc_lib", +proto_library( + name = "gcs_proto", + srcs = ["src/ray/protobuf/gcs.proto"], +) + +cc_proto_library( + name = "gcs_cc_proto", + deps = [":gcs_proto"], +) + +proto_library( + name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], ) +cc_proto_library( + name = "node_manager_cc_proto", + deps = ["node_manager_proto"], +) + +# Node manager gRPC lib. +cc_grpc_library( + name = "node_manager_cc_grpc", + srcs = [":node_manager_proto"], + deps = [":node_manager_cc_proto"], + grpc_only = True, +) + # Node manager server and client. cc_library( - name = "node_manager_rpc_lib", + name = "node_manager_rpc", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -25,7 +47,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_grpc_lib", + ":node_manager_cc_grpc", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -114,7 +136,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc_lib", + ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index e6dc21585699..1d0a4e7632f4 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,7 +105,7 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", ], - strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", ) diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto new file mode 100644 index 000000000000..eaa317f6917a --- /dev/null +++ b/src/ray/protobuf/gcs.proto @@ -0,0 +1,283 @@ +syntax = "proto3"; + +package ray.rpc; + +enum Language { + PYTHON = 0; + CPP = 1; + JAVA = 2; +} + +// These indexes are mapped to strings in ray_redis_module.cc. +enum TablePrefix { + UNUSED = 0; + TASK = 1; + RAYLET_TASK = 2; + CLIENT = 3; + OBJECT = 4; + ACTOR = 5; + FUNCTION = 6; + TASK_RECONSTRUCTION = 7; + HEARTBEAT = 8; + HEARTBEAT_BATCH = 9; + ERROR_INFO = 10; + DRIVER = 11; + PROFILE = 12; + TASK_LEASE = 13; + ACTOR_CHECKPOINT = 14; + ACTOR_CHECKPOINT_ID = 15; + NODE_RESOURCE = 16; +} + +// The channel that Add operations to the Table should be published on, if any. +enum TablePubsub { + NO_PUBLISH = 0; + TASK_PUBSUB = 1; + RAYLET_TASK_PUBSUB = 2; + CLIENT_PUBSUB = 3; + OBJECT_PUBSUB = 4; + ACTOR_PUBSUB = 5; + HEARTBEAT_PUBSUB = 6; + HEARTBEAT_BATCH_PUBSUB = 7; + ERROR_INFO_PUBSUB = 8; + TASK_LEASE_PUBSUB = 9; + DRIVER_PUBSUB = 10; + NODE_RESOURCE_PUBSUB = 11; +} + +// Enum for the entry type in the ClientTable +enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; +} + +enum GcsChangeMode { + APPEND_OR_ADD = 0; + REMOVE = 1; +} + +message GcsEntry { + GcsChangeMode change_mode = 1; + string id = 2; + repeated string entries = 3; +} + +message FunctionTableData { + Language language = 1; + string name = 2; + string data = 3; +} + +message ObjectTableData { + // The size of the object. + uint64 object_size = 1; + // The node manager ID that this object appeared on or was evicted by. + string manager = 2; +} + +message TaskReconstructionData { + // The number of times this task has been reconstructed so far. + uint64 num_reconstructions = 1; + // The node manager that is trying to reconstruct the task. + string node_manager_id = 2; +} + +enum SchedulingState { + NONE = 0; + WAITING = 1; + SCHEDULED = 2; + QUEUED = 4; + RUNNING = 8; + DONE = 16; + LOST = 32; + RECONSTRUCTING = 64; +} + +message ClassTableData {} + +message ActorTableData { + enum ActorState { + // Actor is alive. + ALIVE = 0; + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1; + // Actor is already dead and won't be reconstructed. + DEAD = 2; + } + // The ID of the actor that was created. + string actor_id = 1; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + string actor_creation_dummy_object_id = 2; + // The ID of the driver that created the actor. + string driver_id = 3; + // The ID of the node manager that created the actor. + string node_manager_id = 4; + // Current state of this actor. + ActorState state = 5; + // Max number of times this actor should be reconstructed. + uint64 max_reconstructions = 6; + // Remaining number of reconstructions. + uint64 remaining_reconstructions = 7; +} + +message ErrorTableData { + // The ID of the driver that the error is for. + string driver_id = 1; + // The type of the error. + string type = 2; + // The error message. + string error_message = 3; + // The timestamp of the error message. + double timestamp = 4; +} + +message CustomSerializerData {} + +message ConfigTableData {} + +message ProfileEvent { + // The type of the event. + string event_type = 1; + // The start time of the event. + double start_time = 2; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + double end_time = 3; + // Additional data associated with the event. This data must be serialized + // using JSON. + string extra_data = 4; +} + +message ProfileTableData { + // The type of the component that generated the event, e.g., worker or + // object_manager, or node_manager. + string component_type = 1; + // An identifier for the component that generated the event. + string component_id = 2; + // An identifier for the node that generated the event. + string node_ip_address = 3; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + repeated ProfileEvent profile_events = 4; +} + +message RayResource { + // The type of the resource. + string resource_name = 1; + // The total capacity of this resource type. + double resource_capacity = 2; +} + +message ClientTableData { + // The client ID of the client that the message is about. + string client_id = 1; + // The IP address of the client's node manager. + string node_manager_address = 2; + // The IPC socket name of the client's raylet. + string raylet_socket_name = 3; + // The IPC socket name of the client's plasma store. + string object_store_socket_name = 4; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + int32 node_manager_port = 5; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + int32 object_manager_port = 6; + // Enum to store the entry type in the log + EntryType entry_type = 7; + repeated string resources_total_label = 8; + repeated double resources_total_capacity = 9; +} + +message HeartbeatTableData { + // Node manager client id + string client_id = 1; + // Resource capacity currently available on this node manager. + repeated string resources_available_label = 2; + repeated double resources_available_capacity = 3; + // Total resource capacity configured for this node manager. + repeated string resources_total_label = 4; + repeated double resources_total_capacity = 5; + // Aggregate outstanding resource load on this node manager. + repeated string resource_load_label = 6; + repeated double resource_load_capacity = 7; +} + +message HeartbeatBatchTableData { + repeated HeartbeatTableData batch = 1; +} + +// Data for a lease on task execution. +message TaskLeaseData { + // Node manager client ID. + string node_manager_id = 1; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + uint64 acquired_at = 2; + // The period that the lease is active for. + uint64 timeout = 3; +} + +message DriverTableData { + // The driver ID. + string driver_id = 1; + // Whether it's dead. + bool is_dead = 2; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +message ActorCheckpointData { + // ID of this actor. + string actor_id = 1; + // The dummy object ID of actor's most recently executed task. + string execution_dependency = 2; + // A list of IDs of this actor's handles. + repeated string handle_ids = 3; + // The task counters of the above handles. + repeated uint64 task_counters = 4; + // The frontier dependencies of the above handles. + repeated string frontier_dependencies = 5; + // A list of unreleased dummy objects from this actor. + repeated string unreleased_dummy_objects = 6; + // The numbers of dependencies for the above unreleased dummy objects. + repeated uint32 num_dummy_object_dependencies = 7; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +message ActorCheckpointIdData { + // ID of this actor. + string actor_id = 1; + // IDs of this actor's available checkpoints. + repeated string checkpoint_ids = 2; + // A list of the timestamps for each of the above `checkpoint_ids`. + repeated uint64 timestamps = 3; +} + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 2; +} From d5afdf526ad906e0514fe7a787e7b738ec130a1b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 20 Jun 2019 18:09:40 +0800 Subject: [PATCH 02/34] compilable --- BUILD.bazel | 3 + src/ray/gcs/format/gcs.fbs | 306 +---------------------- src/ray/gcs/gcs_entry.h | 113 +++++++++ src/ray/gcs/redis_context.cc | 4 +- src/ray/gcs/redis_context.h | 14 +- src/ray/gcs/tables.cc | 450 ++++++++++++++-------------------- src/ray/gcs/tables.h | 185 +++++++------- src/ray/gcs/util.h | 11 + src/ray/protobuf/gcs.proto | 8 + src/ray/rpc/message_wrapper.h | 63 +++++ src/ray/rpc/util.h | 26 ++ 11 files changed, 515 insertions(+), 668 deletions(-) create mode 100644 src/ray/gcs/gcs_entry.h create mode 100644 src/ray/gcs/util.h create mode 100644 src/ray/rpc/message_wrapper.h diff --git a/BUILD.bazel b/BUILD.bazel index f07dd3cdf300..0f3085ba6699 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -442,9 +442,12 @@ cc_library( copts = COPTS, includes = [ "src/ray/gcs/format", + "src/ray/gcs/gcs_entry.h", ], deps = [ + ":gcs_cc_proto", ":gcs_fbs", + ":node_manager_rpc", ":hiredis", ":node_manager_fbs", ":ray_common", diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 90476da73425..0b35cce81a0d 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,52 +1,7 @@ enum Language:int { - PYTHON = 0, - CPP = 1, - JAVA = 2 -} - -// These indexes are mapped to strings in ray_redis_module.cc. -enum TablePrefix:int { - UNUSED = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - FUNCTION, - TASK_RECONSTRUCTION, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - DRIVER, - PROFILE, - TASK_LEASE, - ACTOR_CHECKPOINT, - ACTOR_CHECKPOINT_ID, - NODE_RESOURCE, -} - -// The channel that Add operations to the Table should be published on, if any. -enum TablePubsub:int { - NO_PUBLISH = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - TASK_LEASE, - DRIVER, - NODE_RESOURCE, -} - -// Enum for the entry type in the ClientTable -enum EntryType:int { - INSERTION = 0, - DELETION, - RES_CREATEUPDATE, - RES_DELETE, + PYTHON=0, + JAVA=1, + CPP=2, } table Arg { @@ -119,258 +74,3 @@ table ResourcePair { // The quantity of the resource. value: double; } - -enum GcsChangeMode:int { - APPEND_OR_ADD = 0, - REMOVE, -} - -table GcsEntry { - change_mode: GcsChangeMode; - id: string; - entries: [string]; -} - -table FunctionTableData { - language: Language; - name: string; - data: string; -} - -table ObjectTableData { - // The size of the object. - object_size: long; - // The node manager ID that this object appeared on or was evicted by. - manager: string; -} - -table TaskReconstructionData { - // The number of times this task has been reconstructed so far. - num_reconstructions: int; - // The node manager that is trying to reconstruct the task. - node_manager_id: string; -} - -enum SchedulingState:int { - NONE = 0, - WAITING = 1, - SCHEDULED = 2, - QUEUED = 4, - RUNNING = 8, - DONE = 16, - LOST = 32, - RECONSTRUCTING = 64 -} - -table TaskTableData { - // The state of the task. - scheduling_state: SchedulingState; - // A raylet ID. - raylet_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // The number of times the task was spilled back by raylets. - spillback_count: long; - // A string of bytes representing the task specification. - task_info: string; - // TODO(pcm): This is at the moment duplicated in task_info, remove that one - updated: bool; -} - -table TaskTableTestAndUpdate { - test_raylet_id: string; - test_state_bitmask: SchedulingState; - update_state: SchedulingState; -} - -table ClassTableData { -} - -enum ActorState:int { - // Actor is alive. - ALIVE = 0, - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1, - // Actor is already dead and won't be reconstructed. - DEAD = 2 -} - -table ActorTableData { - // The ID of the actor that was created. - actor_id: string; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - actor_creation_dummy_object_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the node manager that created the actor. - node_manager_id: string; - // Current state of this actor. - state: ActorState; - // Max number of times this actor should be reconstructed. - max_reconstructions: int; - // Remaining number of reconstructions. - remaining_reconstructions: int; -} - -table ErrorTableData { - // The ID of the driver that the error is for. - driver_id: string; - // The type of the error. - type: string; - // The error message. - error_message: string; - // The timestamp of the error message. - timestamp: double; -} - -table CustomSerializerData { -} - -table ConfigTableData { -} - -table ProfileEvent { - // The type of the event. - event_type: string; - // The start time of the event. - start_time: double; - // The end time of the event. If the event is a point event, then this should - // be the same as the start time. - end_time: double; - // Additional data associated with the event. This data must be serialized - // using JSON. - extra_data: string; -} - -table ProfileTableData { - // The type of the component that generated the event, e.g., worker or - // object_manager, or node_manager. - component_type: string; - // An identifier for the component that generated the event. - component_id: string; - // An identifier for the node that generated the event. - node_ip_address: string; - // This is a batch of profiling events. We batch these together for - // performance reasons because a single task may generate many events, and - // we don't want each event to require a GCS command. - profile_events: [ProfileEvent]; -} - -table RayResource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. - resource_capacity: double; -} - -table ClientTableData { - // The client ID of the client that the message is about. - client_id: string; - // The IP address of the client's node manager. - node_manager_address: string; - // The IPC socket name of the client's raylet. - raylet_socket_name: string; - // The IPC socket name of the client's plasma store. - object_store_socket_name: string; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - node_manager_port: int; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - object_manager_port: int; - // Enum to store the entry type in the log - entry_type: EntryType = INSERTION; - resources_total_label: [string]; - resources_total_capacity: [double]; -} - -table HeartbeatTableData { - // Node manager client id - client_id: string; - // Resource capacity currently available on this node manager. - resources_available_label: [string]; - resources_available_capacity: [double]; - // Total resource capacity configured for this node manager. - resources_total_label: [string]; - resources_total_capacity: [double]; - // Aggregate outstanding resource load on this node manager. - resource_load_label: [string]; - resource_load_capacity: [double]; -} - -table HeartbeatBatchTableData { - batch: [HeartbeatTableData]; -} - -// Data for a lease on task execution. -table TaskLeaseData { - // Node manager client ID. - node_manager_id: string; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - acquired_at: long; - // The period that the lease is active for. - timeout: long; -} - -table DriverTableData { - // The driver ID. - driver_id: string; - // Whether it's dead. - is_dead: bool; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -table ActorCheckpointData { - // ID of this actor. - actor_id: string; - // The dummy object ID of actor's most recently executed task. - execution_dependency: string; - // A list of IDs of this actor's handles. - handle_ids: [string]; - // The task counters of the above handles. - task_counters: [long]; - // The frontier dependencies of the above handles. - frontier_dependencies: [string]; - // A list of unreleased dummy objects from this actor. - unreleased_dummy_objects: [string]; - // The numbers of dependencies for the above unreleased dummy objects. - num_dummy_object_dependencies: [int]; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -table ActorCheckpointIdData { - // ID of this actor. - actor_id: string; - // IDs of this actor's available checkpoints. - // Note, this is a long string that concatenates all the IDs. - checkpoint_ids: string; - // A list of the timestamps for each of the above `checkpoint_ids`. - timestamps: [long]; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. -enum ErrorType:int { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 1, - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 2, - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 3, -} diff --git a/src/ray/gcs/gcs_entry.h b/src/ray/gcs/gcs_entry.h new file mode 100644 index 000000000000..2ba4b4fe7144 --- /dev/null +++ b/src/ray/gcs/gcs_entry.h @@ -0,0 +1,113 @@ +#ifndef RAY_GCS_GCS_ENTRY_H +#define RAY_GCS_GCS_ENTRY_H + +#include "ray/protobuf/gcs.pb.h" +#include "ray/rpc/message_wrapper.h" +#include "ray/rpc/util.h" + +namespace ray { + +namespace gcs { + +template +class GcsEntry : public rpc::ConstMessageWrapper { + public: + explicit GcsEntry(const rpc::GcsEntry &message) : ConstMessageWrapper(message) {} + + explicit GcsEntry(const std::string &data) + : ConstMessageWrapper(ParseGcsEntryMessage(data)) {} + + GcsEntry(const ID &id, const rpc::GcsChangeMode &change_mode, + const std::vector &entries) + : ConstMessageWrapper(CreateGcsEntryMessage(id, change_mode, entries)) {} + + const ID GetId() { return ID::FromBinary(message_->id()); } + + const rpc::GcsChangeMode GetChangeMode() { return message_->change_mode(); } + + const std::vector GetEntries() { +// return rpc::VectorFromProtobuf(message_->entries()); +// XXX + return {}; + } + + private: + inline static std::unique_ptr ParseGcsEntryMessage( + const std::string &data) { + auto *gcs_entry = new rpc::GcsEntry(); + gcs_entry->ParseFromString(data); + return std::unique_ptr(gcs_entry); + } + + static inline std::unique_ptr CreateGcsEntryMessage( + const ID &id, const rpc::GcsChangeMode &change_mode, + const std::vector &entries) { + auto *gcs_entry = new rpc::GcsEntry(); + gcs_entry->set_id(id.ToBinary()); + gcs_entry->set_change_mode(change_mode); + for (const auto &entry : entries) { + std::string str; + entry.SerializeToString(&str); + gcs_entry->add_entries(std::move(str)); + } + return std::unique_ptr(gcs_entry); + } +}; + +template +class GcsMapEntry : public rpc::ConstMessageWrapper { + public: + explicit GcsMapEntry(const rpc::GcsMapEntry &message) : ConstMessageWrapper(message) {} + + explicit GcsMapEntry(const std::string &data) + : ConstMessageWrapper(ParseGcsEntryMessage(data)) {} + + GcsMapEntry(const ID &id, const rpc::GcsChangeMode &change_mode, + const std::unordered_map> &entries) + : ConstMessageWrapper(CreateGcsEntryMessage(id, change_mode, entries)) {} + + const ID GetId() { return ID::FromBinary(message_->id()); } + + const rpc::GcsChangeMode GetChangeMode() { return message_->change_mode(); } + + const std::unordered_map> GetEntries() { + std::unordered_map> map; + for (const auto &pair : message_->entries()) { + if (pair.second.empty()) { + map.emplace(pair.first, nullptr); + } else { + // XXX + auto entry = std::make_shared(); + map.emplace(pair.first, std::move(entry)); + } + } + return map; + } + + private: + inline static std::unique_ptr ParseGcsEntryMessage( + const std::string &data) { + auto *gcs_entry = new rpc::GcsMapEntry(); + gcs_entry->ParseFromString(data); + return std::unique_ptr(gcs_entry); + } + + static inline std::unique_ptr CreateGcsEntryMessage( + const ID &id, const rpc::GcsChangeMode &change_mode, + const std::unordered_map> &entries) { + auto *gcs_entry = new rpc::GcsMapEntry(); + gcs_entry->set_id(id.Binary()); + gcs_entry->set_change_mode(change_mode); + for (const auto &entry : entries) { + std::string str; + entry.second->SerializeToString(&str); + (*gcs_entry->mutable_entries())[entry.first] = std::move(str); + } + return std::unique_ptr(gcs_entry); + } +}; +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_GCS_ENTRY_H diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index ae6cb6088cec..60606e031905 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -263,10 +263,10 @@ Status RedisContext::RunArgvAsync(const std::vector &args) { } Status RedisContext::SubscribeAsync(const ClientID &client_id, - const TablePubsub pubsub_channel, + const rpc::TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index) { - RAY_CHECK(pubsub_channel != TablePubsub::NO_PUBLISH) + RAY_CHECK(pubsub_channel != rpc::TablePubsub::NO_PUBLISH) << "Client requested subscribe on a table that does not support pubsub"; int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, true); diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index fc42e5cd98c2..e1cde567e835 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -126,9 +126,9 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, - int64_t length, const TablePrefix prefix, - const TablePubsub pubsub_channel, RedisCallback redisCallback, + Status RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const rpc::TablePrefix prefix, + const rpc::TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); /// Run an arbitrary Redis command without a callback. @@ -144,7 +144,7 @@ class RedisContext { /// \param redisCallback The callback function that the notification calls. /// \param out_callback_index The output pointer to callback index. /// \return Status. - Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, + Status SubscribeAsync(const ClientID &client_id, const rpc::TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index); redisContext *sync_context() { return context_; } redisAsyncContext *async_context() { return async_context_; } @@ -158,8 +158,8 @@ class RedisContext { template Status RedisContext::RunAsync(const std::string &command, const ID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, + const void *data, size_t length, + const rpc::TablePrefix prefix, const rpc::TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 33f1615580a6..a3f313bca9e0 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,6 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" +#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -39,48 +40,46 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str; + data->SerializeToString(&str); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } } else { if (failure != nullptr) { - (failure)(client_, id, *dataT); + (failure)(client_, id, *data); } } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + std::string str; + data->SerializeToString(&str); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback), log_length); } template @@ -89,19 +88,13 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; if (!reply.IsNil()) { - const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); - } + GcsEntry gcs_entry(reply.ReadAsString()); + RAY_CHECK(gcs_entry.GetId() == id); + lookup(client_, id, gcs_entry.GetEntries()); + } else { + lookup(client_, id, {}); } - lookup(client_, id, results); } }; std::vector nil; @@ -114,9 +107,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, - const GcsChangeMode change_mode, - const std::vector &data) { - RAY_CHECK(change_mode != GcsChangeMode::REMOVE); + const rpc::GcsChangeMode change_mode, + const std::vector &data) { + RAY_CHECK(change_mode != rpc::GcsChangeMode::REMOVE); subscribe(client, id, data); }; return Subscribe(driver_id, client_id, subscribe_wrapper, done); @@ -141,19 +134,8 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); - } - subscribe(client_, id, root->change_mode(), results); + GcsEntry gcs_entry(data); + subscribe(client_, gcs_entry.GetId(), gcs_entry.GetChangeMode(), gcs_entry.GetEntries()); } } }; @@ -234,19 +216,18 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str; + data->SerializeToString(&str); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -255,7 +236,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -277,7 +258,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -299,36 +280,32 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str; + data->SerializeToString(&str); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str; + data->SerializeToString(&str); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -348,25 +325,10 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(data_map.size() * 2); - for (auto const &pair : data_map) { - // Add the key. - data_vec.push_back(fbb.CreateString(pair.first)); - flatbuffers::FlatBufferBuilder fbb_data; - fbb_data.ForceDefaults(true); - fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); - std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), - fbb_data.GetSize()); - // Add the value. - data_vec.push_back(fbb.CreateString(data)); - } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, + GcsMapEntry gcs_map_entry(id, rpc::GcsChangeMode::APPEND_OR_ADD, data_map); + std::string str = gcs_map_entry.Serialize(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), + str.size(), prefix_, pubsub_channel_, std::move(callback)); } @@ -380,18 +342,12 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(keys.size()); - // Add the keys. - for (auto const &key : keys) { - data_vec.push_back(fbb.CreateString(key)); - } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), - fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, +// GcsEntry gcs_entry(id, rpc::GcsChangeMode::REMOVE, keys); +// std::string str = gcs_entry.Serialize(); +// XXX + std::string str; + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), + str.size(), prefix_, pubsub_channel_, std::move(callback)); } @@ -412,20 +368,13 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - results.emplace(key, std::move(result)); - } + GcsMapEntry gcs_map_entry(reply.ReadAsString()); + RAY_CHECK(gcs_map_entry.GetId() == id); +// lookup(client_, id, gcs_map_entry.GetEntries()); +// XXX + } else { + lookup(client_, id, {}); } - lookup(client_, id, results); } }; std::vector nil; @@ -451,31 +400,8 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - DataMap data_map; - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - if (root->change_mode() == GcsChangeMode::REMOVE) { - for (size_t i = 0; i < root->entries()->size(); i++) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - data_map.emplace(key, std::shared_ptr()); - } - } else { - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - data_map.emplace(key, std::move(result)); - } - } - subscribe(client_, id, root->change_mode(), data_map); + GcsMapEntry gcs_map_entry(data); + subscribe(client_, gcs_map_entry.GetId(), gcs_map_entry.GetChangeMode(), gcs_map_entry.GetEntries()); } } }; @@ -490,36 +416,34 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->type = type; - data->error_message = error_message; - data->timestamp = timestamp; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_type(type); + data->set_error_message(error_message); + data->set_timestamp(timestamp); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } std::string ErrorTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } -Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - auto data = std::make_shared(); - // There is some room for optimization here because the Append function will just - // call "Pack" and undo the "UnPack". - profile_events.UnPackTo(data.get()); - +Status ProfileTable::AddProfileEventBatch(const rpc::ProfileTableData &profile_events) { + auto data = std::make_shared(); + data->CopyFrom(profile_events); + // XXX return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } std::string ProfileTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->is_dead = is_dead; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -527,7 +451,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && (entry.second.entry_type() == rpc::EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -537,7 +461,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == rpc::EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -549,7 +473,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == rpc::EntryType::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -559,15 +483,15 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && entry.second.entry_type() == rpc::EntryType::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableDataT &data) { - ClientID client_id = ClientID::FromBinary(data.client_id); + const rpc::ClientTableData &data) { + ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -578,16 +502,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); - bool is_deleted = (data.entry_type == EntryType::DELETION); - bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != rpc::EntryType::DELETION); + bool is_deleted = (data.entry_type() == rpc::EntryType::DELETION); + bool is_res_modified = ((data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) || + (data.entry_type() == rpc::EntryType::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (entry->second.entry_type == EntryType::DELETION) { - RAY_CHECK((data.entry_type == EntryType::DELETION)) + if (entry->second.entry_type() == rpc::EntryType::DELETION) { + RAY_CHECK((data.entry_type() == rpc::EntryType::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -595,64 +519,65 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type == EntryType::INSERTION) || - (data.entry_type == EntryType::DELETION)) { + if ((data.entry_type() == rpc::EntryType::INSERTION) || + (data.entry_type() == rpc::EntryType::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)) { + } else if ((data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) || + (data.entry_type() == rpc::EntryType::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - ClientTableDataT &cache_data = client_cache_[client_id]; + rpc::ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { - auto const &resource_name = data.resources_total_label[i]; - auto const &capacity = data.resources_total_capacity[i]; + for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { + auto const &resource_name = data.resources_total_label(i); + auto const &capacity = data.resources_total_capacity(i); // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label.begin(), - cache_data.resources_total_label.end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label.end()) { - auto index = std::distance(cache_data.resources_total_label.begin(), + std::find(cache_data.resources_total_label().begin(), + cache_data.resources_total_label().end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label().end()) { + auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_capacity[index] = capacity; + if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { +// cache_data.mutable_resources_total_capacity()[index] = capacity; +// XXX:639 } // .. delete if deletion call. - else if (data.entry_type == EntryType::RES_DELETE) { - cache_data.resources_total_label.erase( - cache_data.resources_total_label.begin() + index); - cache_data.resources_total_capacity.erase( - cache_data.resources_total_capacity.begin() + index); + else if (data.entry_type() == rpc::EntryType::RES_DELETE) { + cache_data.mutable_resources_total_label()->erase( + cache_data.resources_total_label().begin() + index); + cache_data.mutable_resources_total_capacity()->erase( + cache_data.resources_total_capacity().begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_label.push_back(resource_name); - cache_data.resources_total_capacity.push_back(capacity); + if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { + cache_data.add_resources_total_label(resource_name); + cache_data.add_resources_total_capacity(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableDataT &cache_data = client_cache_[client_id]; + rpc::ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type == EntryType::INSERTION) { + if (data.entry_type() == rpc::EntryType::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type == EntryType::DELETION) { + } else if (data.entry_type() == rpc::EntryType::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -660,11 +585,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_DELETE) { + } else if (data.entry_type() == rpc::EntryType::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -672,54 +597,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const rpc::ClientTableData &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } +const rpc::ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableDataT &local_client) { +Status ClientTable::Connect(const rpc::ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id == local_client_.client_id); + RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::INSERTION; + auto data = std::make_shared(local_client_); + data->set_entry_type(rpc::EntryType::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableDataT &data) { + const rpc::ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); + if (notification.entry_type() != rpc::EntryType::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); } else { - auto iter = connected_nodes.find(notification.client_id); + auto iter = connected_nodes.find(notification.client_id()); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id, notification); + disconnected_nodes.emplace(notification.client_id(), notification); } } for (const auto &pair : connected_nodes) { @@ -742,10 +667,10 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(local_client_); + data->set_entry_type(rpc::EntryType::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { + const rpc::ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -759,24 +684,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->client_id = dead_client_id.Binary(); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(); + data->set_client_id(dead_client_id.Binary()); + data->set_entry_type(rpc::EntryType::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableDataT &client_info) const { + rpc::ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::Nil().Binary(); + client_info.set_client_id(ClientID::Nil().Binary()); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ -787,7 +712,7 @@ Status ClientTable::Lookup(const Callback &lookup) { std::string ClientTable::DebugString() const { std::stringstream result; - result << Log::DebugString(); + result << Log::DebugString(); result << ", cache size: " << client_cache_.size() << ", num removed: " << removed_clients_.size(); return result.str(); @@ -798,55 +723,56 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { - std::shared_ptr copy = - std::make_shared(data); - copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.Binary(); + const rpc::ActorCheckpointIdData &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->add_timestamps(current_sys_time_ms()); + copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps.size() > num_to_keep) { + while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &checkpoint_id = - ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " + const auto &to_delete = + ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; - copy->timestamps.erase(copy->timestamps.begin()); - copy->checkpoint_ids.erase(0, kUniqueIDSize); - client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); + copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); + copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); + client_->actor_checkpoint_table().Delete(driver_id, to_delete); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); - data->actor_id = id.Binary(); - data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.Binary(); + std::shared_ptr data = + std::make_shared(); + data->set_actor_id(id.Binary()); + data->add_timestamps(current_sys_time_ms()); + *data->add_checkpoint_ids() = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); } -template class Log; -template class Set; -template class Log; -template class Table; -template class Table; -template class Log; -template class Log; -template class Table; -template class Table; -template class Table; -template class Log; -template class Log; -template class Log; -template class Log; -template class Table; -template class Table; - -template class Log; -template class Hash; +template class Log; +template class Set; +// template class Log; +// template class Table; +template class Log; +template class Table; +template class Log; +template class Log; +template class Table; +template class Table; +template class Table; +template class Log; +template class Log; +template class Log; +template class Log; +template class Table; +template class Table; + +template class Log; +template class Hash; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 6a1d502a7f54..11934d45ac8a 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -9,12 +9,14 @@ #include "ray/common/constants.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/gcs/gcs_entry.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +//#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" // TODO(rkn): Remove this include. -#include "ray/raylet/format/node_manager_generated.h" +//#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -48,13 +50,12 @@ class PubsubInterface { template class LogInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -72,12 +73,11 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; + const std::vector &data)>; using NotificationCallback = std::function &data)>; + const rpc::GcsChangeMode change_mode, + const std::vector &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -86,7 +86,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -98,8 +98,8 @@ class Log : public LogInterface, virtual public PubsubInterface { Log(const std::vector> &contexts, AsyncGcsClient *client) : shard_contexts_(contexts), client_(client), - pubsub_channel_(TablePubsub::NO_PUBLISH), - prefix_(TablePrefix::UNUSED), + pubsub_channel_(rpc::TablePubsub::NO_PUBLISH), + prefix_(rpc::TablePrefix::UNUSED), subscribe_callback_index_(-1){}; /// Append a log entry to a key. @@ -111,7 +111,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -126,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. /// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -238,12 +238,12 @@ class Log : public LogInterface, virtual public PubsubInterface { AsyncGcsClient *client_; /// The pubsub channel to subscribe to for notifications about keys in this /// table. If no notifications are required, this should be set to - /// TablePubsub_NO_PUBLISH. If notifications are required, then this must be + /// rpc::TablePubsub_NO_PUBLISH. If notifications are required, then this must be /// unique across all instances of Log. - TablePubsub pubsub_channel_; + rpc::TablePubsub pubsub_channel_; /// The prefix to use for keys in this table. This must be unique across all /// instances of Log. - TablePrefix prefix_; + rpc::TablePrefix prefix_; /// The index in the RedisCallbackManager for the callback that is called /// when we receive notifications. This is >= 0 iff we have subscribed to the /// table, otherwise -1. @@ -259,10 +259,9 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -280,9 +279,8 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -305,7 +303,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -369,12 +367,11 @@ class Table : private Log, template class SetInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -392,7 +389,6 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -414,7 +410,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -425,7 +421,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -454,8 +450,7 @@ class Set : private Log, template class HashInterface { public: - using DataT = typename Data::NativeTableType; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -485,7 +480,7 @@ class HashInterface { /// \return Void using HashNotificationCallback = std::function; + const rpc::GcsChangeMode change_mode, const DataMap &data)>; /// Add entries of a hash table. /// @@ -544,8 +539,7 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -590,59 +584,59 @@ class Hash : private Log, using Log::num_lookups_; }; -class DynamicResourceTable : public Hash { +class DynamicResourceTable : public Hash { public: DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE; - prefix_ = TablePrefix::NODE_RESOURCE; + pubsub_channel_ = rpc::TablePubsub::NODE_RESOURCE_PUBSUB; + prefix_ = rpc::TablePrefix::NODE_RESOURCE; }; virtual ~DynamicResourceTable(){}; }; -class ObjectTable : public Set { +class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT; - prefix_ = TablePrefix::OBJECT; + pubsub_channel_ = rpc::TablePubsub::OBJECT_PUBSUB; + prefix_ = rpc::TablePrefix::OBJECT; }; virtual ~ObjectTable(){}; }; -class HeartbeatTable : public Table { +class HeartbeatTable : public Table { public: HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT; - prefix_ = TablePrefix::HEARTBEAT; + pubsub_channel_ = rpc::TablePubsub::HEARTBEAT_PUBSUB; + prefix_ = rpc::TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} }; -class HeartbeatBatchTable : public Table { +class HeartbeatBatchTable : public Table { public: HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; - prefix_ = TablePrefix::HEARTBEAT_BATCH; + pubsub_channel_ = rpc::TablePubsub::HEARTBEAT_BATCH_PUBSUB; + prefix_ = rpc::TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} }; -class DriverTable : public Log { +class DriverTable : public Log { public: DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER; - prefix_ = TablePrefix::DRIVER; + pubsub_channel_ = rpc::TablePubsub::DRIVER_PUBSUB; + prefix_ = rpc::TablePrefix::DRIVER; }; virtual ~DriverTable() {} @@ -655,54 +649,56 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { +class FunctionTable : public Table { public: FunctionTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::NO_PUBLISH; - prefix_ = TablePrefix::FUNCTION; + pubsub_channel_ = rpc::TablePubsub::NO_PUBLISH; + prefix_ = rpc::TablePrefix::FUNCTION; }; }; -using ClassTable = Table; +using ClassTable = Table; /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). /// These may be followed by a DEAD entry, which means that the actor has failed and will /// not be reconstructed. -class ActorTable : public Log { +class ActorTable : public Log { public: ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR; - prefix_ = TablePrefix::ACTOR; + pubsub_channel_ = rpc::TablePubsub::ACTOR_PUBSUB; + prefix_ = rpc::TablePrefix::ACTOR; } }; -class TaskReconstructionLog : public Log { +class TaskReconstructionLog : public Log { public: TaskReconstructionLog(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - prefix_ = TablePrefix::TASK_RECONSTRUCTION; + prefix_ = rpc::TablePrefix::TASK_RECONSTRUCTION; } }; -class TaskLeaseTable : public Table { +class TaskLeaseTable : public Table { public: TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE; - prefix_ = TablePrefix::TASK_LEASE; + pubsub_channel_ = rpc::TablePubsub::TASK_LEASE_PUBSUB; + prefix_ = rpc::TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { - RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); + std::shared_ptr &data, + const WriteCallback &done) override { + RAY_RETURN_NOT_OK( + (Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. It's okay if this command fails // since the lease entry itself contains the expiration period. In the // worst case, if the command fails, then a client that looks up the lease @@ -710,28 +706,28 @@ class TaskLeaseTable : public Table { // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.Binary(), - std::to_string(data->timeout)}; + rpc::TablePrefix_Name(prefix_) + id.Binary(), + std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } }; -class ActorCheckpointTable : public Table { +class ActorCheckpointTable : public Table { public: ActorCheckpointTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - prefix_ = TablePrefix::ACTOR_CHECKPOINT; + prefix_ = rpc::TablePrefix::ACTOR_CHECKPOINT; }; }; -class ActorCheckpointIdTable : public Table { +class ActorCheckpointIdTable : public Table { public: ActorCheckpointIdTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - prefix_ = TablePrefix::ACTOR_CHECKPOINT_ID; + prefix_ = rpc::TablePrefix::ACTOR_CHECKPOINT_ID; }; /// Add a checkpoint id to an actor, and remove a previous checkpoint if the @@ -747,13 +743,13 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK; - prefix_ = TablePrefix::RAYLET_TASK; + pubsub_channel_ = rpc::TablePubsub::RAYLET_TASK_PUBSUB; + prefix_ = rpc::TablePrefix::RAYLET_TASK; } TaskTable(const std::vector> &contexts, @@ -765,13 +761,13 @@ class TaskTable : public Table { } // namespace raylet -class ErrorTable : private Log { +class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO; - prefix_ = TablePrefix::ERROR_INFO; + pubsub_channel_ = rpc::TablePubsub::ERROR_INFO_PUBSUB; + prefix_ = rpc::TablePrefix::ERROR_INFO; }; /// Push an error message for a specific job. @@ -795,19 +791,19 @@ class ErrorTable : private Log { std::string DebugString() const; }; -class ProfileTable : private Log { +class ProfileTable : private Log { public: ProfileTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - prefix_ = TablePrefix::PROFILE; + prefix_ = rpc::TablePrefix::PROFILE; }; /// Add a batch of profiling events to the profile table. /// /// \param profile_events The profile events to record. /// \return Status. - Status AddProfileEventBatch(const ProfileTableData &profile_events); + Status AddProfileEventBatch(const rpc::ProfileTableData &profile_events); /// Returns debug string for class. /// @@ -815,9 +811,9 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; +using CustomSerializerTable = Table; -using ConfigTable = Table; +using ConfigTable = Table; /// \class ClientTable /// @@ -828,10 +824,10 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : public Log { +class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const rpc::ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -842,11 +838,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT; - prefix_ = TablePrefix::CLIENT; + pubsub_channel_ = rpc::TablePubsub::CLIENT_PUBSUB; + prefix_ = rpc::TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.Binary(); + local_client_.set_client_id(client_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -855,7 +851,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableDataT &local_client); + ray::Status Connect(const rpc::ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -898,7 +894,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableDataT &client_info) const; + void GetClient(const ClientID &client, rpc::ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -908,7 +904,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableDataT &GetLocalClient() const; + const rpc::ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. /// @@ -919,7 +915,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -940,15 +936,16 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); + void HandleNotification(AsyncGcsClient *client, + const rpc::ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); + void HandleConnected(AsyncGcsClient *client, const rpc::ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableDataT local_client_; + rpc::ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -958,7 +955,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/gcs/util.h b/src/ray/gcs/util.h new file mode 100644 index 000000000000..b536463f1190 --- /dev/null +++ b/src/ray/gcs/util.h @@ -0,0 +1,11 @@ +#ifndef ANT_RAY_SRC_RAY_GCS_UTIL_H_ +#define ANT_RAY_SRC_RAY_GCS_UTIL_H_ + +namespace ray { +namespace gcs { + +}; + +}; // namespace ray + +#endif // ANT_RAY_SRC_RAY_GCS_UTIL_H_ diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index eaa317f6917a..482017dd9aba 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -64,6 +64,12 @@ message GcsEntry { repeated string entries = 3; } +message GcsMapEntry { + GcsChangeMode change_mode = 1; + string id = 2; + map entries = 3; +} + message FunctionTableData { Language language = 1; string name = 2; @@ -95,6 +101,8 @@ enum SchedulingState { RECONSTRUCTING = 64; } +message TaskTableData {} + message ClassTableData {} message ActorTableData { diff --git a/src/ray/rpc/message_wrapper.h b/src/ray/rpc/message_wrapper.h new file mode 100644 index 000000000000..c9feb5809375 --- /dev/null +++ b/src/ray/rpc/message_wrapper.h @@ -0,0 +1,63 @@ +#ifndef RAY_RPC_WRAPPER_H +#define RAY_RPC_WRAPPER_H + +#include +#include + +namespace ray { + +namespace rpc { + +template +class MessageWrapper { + public: + explicit MessageWrapper(Message &message) : message_(&message) {} + + explicit MessageWrapper(std::unique_ptr message) + : message_unique_ptr_(std::move(message)), message_(message_unique_ptr_.get()) {} + + MessageWrapper(const MessageWrapper &from) + : MessageWrapper(std::unique_ptr(new Message(from.GetMessage()))) {} + + const Message &GetMessage() const { return *message_; } + + const std::string Serialize() const { + std::string ret; + message_->SerializeToString(&ret); + return ret; + } + + protected: + std::unique_ptr message_unique_ptr_; + Message *message_; +}; + +template +class ConstMessageWrapper { + public: + explicit ConstMessageWrapper(const Message &message) : message_(&message) {} + + explicit ConstMessageWrapper(std::unique_ptr message) + : message_unique_ptr_(std::move(message)), message_(message_unique_ptr_.get()) {} + + ConstMessageWrapper(const ConstMessageWrapper &from) + : ConstMessageWrapper(std::unique_ptr(new Message(from.GetMessage()))) {} + + const Message &GetMessage() const { return *message_; } + + const std::string Serialize() const { + std::string ret; + message_->SerializeToString(&ret); + return ret; + } + + protected: + std::unique_ptr message_unique_ptr_; + const Message *message_; +}; + +} // namespace rpc + +} // namespace ray + +#endif // RAY_RPC_WRAPPER_H diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 6ecc6c3c4a34..f356b2a88b5a 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,6 +1,7 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H +#include #include #include "ray/common/status.h" @@ -27,6 +28,31 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } +template +inline std::unordered_map MapFromProtobuf(::google::protobuf::Map pb_map) { + return std::unordered_map(pb_map.begin(), pb_map.end()); +} + +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedPtrField &pb_repeated) { + std::vector vector(static_cast(pb_repeated.size())); + for (const auto &item : pb_repeated) { + vector.push_back(item); + } + return vector; +} + +template +inline std::vector IdVectorFromProtobuf( + const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) { + auto str_vec = VectorFromProtobuf(pb_repeated); + std::vector ret(str_vec.size()); + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(ret), + &ID::FromBinary); + return ret; +} + } // namespace rpc } // namespace ray From 4d2574aa5d1d5b978b21639dfac15507bb86d93b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 20 Jun 2019 18:53:05 +0800 Subject: [PATCH 03/34] remove unused --- src/ray/gcs/client.h | 3 -- src/ray/gcs/tables.h | 6 ---- src/ray/protobuf/gcs.proto | 60 +++++++++++++++----------------------- 3 files changed, 24 insertions(+), 45 deletions(-) diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index c9f5b4bca624..92b7073d8fc7 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -47,8 +47,6 @@ class RAY_EXPORT AsyncGcsClient { inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver inline ClassTable &class_table(); - inline CustomSerializerTable &custom_serializer_table(); - inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -82,7 +80,6 @@ class RAY_EXPORT AsyncGcsClient { private: std::unique_ptr function_table_; - std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 11934d45ac8a..a11486b5b49b 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -659,8 +659,6 @@ class FunctionTable : public Table { }; }; -using ClassTable = Table; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -811,10 +809,6 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; - -using ConfigTable = Table; - /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 482017dd9aba..530f5a740999 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -45,14 +45,6 @@ enum TablePubsub { NODE_RESOURCE_PUBSUB = 11; } -// Enum for the entry type in the ClientTable -enum EntryType { - INSERTION = 0; - DELETION = 1; - RES_CREATEUPDATE = 2; - RES_DELETE = 3; -} - enum GcsChangeMode { APPEND_OR_ADD = 0; REMOVE = 1; @@ -90,19 +82,10 @@ message TaskReconstructionData { string node_manager_id = 2; } -enum SchedulingState { - NONE = 0; - WAITING = 1; - SCHEDULED = 2; - QUEUED = 4; - RUNNING = 8; - DONE = 16; - LOST = 32; - RECONSTRUCTING = 64; +message TaskTableData { + string task = 1; } -message TaskTableData {} - message ClassTableData {} message ActorTableData { @@ -144,24 +127,21 @@ message ErrorTableData { double timestamp = 4; } -message CustomSerializerData {} - -message ConfigTableData {} - -message ProfileEvent { - // The type of the event. - string event_type = 1; - // The start time of the event. - double start_time = 2; - // The end time of the event. If the event is a point event, then this should - // be the same as the start time. - double end_time = 3; - // Additional data associated with the event. This data must be serialized - // using JSON. - string extra_data = 4; -} - message ProfileTableData { + + message ProfileEvent { + // The type of the event. + string event_type = 1; + // The start time of the event. + double start_time = 2; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + double end_time = 3; + // Additional data associated with the event. This data must be serialized + // using JSON. + string extra_data = 4; + } + // The type of the component that generated the event, e.g., worker or // object_manager, or node_manager. string component_type = 1; @@ -182,6 +162,14 @@ message RayResource { double resource_capacity = 2; } +// Enum for the entry type in the ClientTable +enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; +} + message ClientTableData { // The client ID of the client that the message is about. string client_id = 1; From bf2c6039d4c155f24660d743cacfed04673c1755 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 20 Jun 2019 19:14:14 +0800 Subject: [PATCH 04/34] move entrytype inside clienttabledata --- src/ray/gcs/client.cc | 2 -- src/ray/gcs/client.h | 1 - src/ray/gcs/tables.cc | 57 +++++++++++++++++++------------------- src/ray/protobuf/gcs.proto | 19 ++++++------- 4 files changed, 36 insertions(+), 43 deletions(-) diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c9b1e138575d..57b8694e17d6 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -208,8 +208,6 @@ ClientTable &AsyncGcsClient::client_table() { return *client_table_; } FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } -ClassTable &AsyncGcsClient::class_table() { return *class_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 92b7073d8fc7..6addd9789fde 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -46,7 +46,6 @@ class RAY_EXPORT AsyncGcsClient { inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver - inline ClassTable &class_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index a3f313bca9e0..18a3f73d8d47 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -370,8 +370,7 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, const auto data = reply.ReadAsString(); GcsMapEntry gcs_map_entry(reply.ReadAsString()); RAY_CHECK(gcs_map_entry.GetId() == id); -// lookup(client_, id, gcs_map_entry.GetEntries()); -// XXX + lookup(client_, id, gcs_map_entry.GetEntries()); } else { lookup(client_, id, {}); } @@ -451,7 +450,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type() == rpc::EntryType::INSERTION)) { + if (!entry.first.IsNil() && (entry.second.entry_type() == rpc::ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -461,7 +460,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type() == rpc::EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == rpc::ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -473,7 +472,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type() == rpc::EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -483,7 +482,7 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type() == rpc::EntryType::RES_DELETE) { + if (!entry.first.IsNil() && entry.second.entry_type() == rpc::ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } @@ -502,16 +501,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type() != rpc::EntryType::DELETION); - bool is_deleted = (data.entry_type() == rpc::EntryType::DELETION); - bool is_res_modified = ((data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) || - (data.entry_type() == rpc::EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != rpc::ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == rpc::ClientTableData::DELETION); + bool is_res_modified = ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == rpc::ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (entry->second.entry_type() == rpc::EntryType::DELETION) { - RAY_CHECK((data.entry_type() == rpc::EntryType::DELETION)) + if (entry->second.entry_type() == rpc::ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == rpc::ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -519,18 +518,18 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type() == rpc::EntryType::INSERTION) || - (data.entry_type() == rpc::EntryType::DELETION)) { + if ((data.entry_type() == rpc::ClientTableData::INSERTION) || + (data.entry_type() == rpc::ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type()) + << client_id << ". ClientTableData: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) || - (data.entry_type() == rpc::EntryType::RES_DELETE)) { + } else if ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == rpc::ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type()) + << client_id << ". ClientTableData: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; rpc::ClientTableData &cache_data = client_cache_[client_id]; @@ -547,12 +546,12 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { + if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { // cache_data.mutable_resources_total_capacity()[index] = capacity; // XXX:639 } // .. delete if deletion call. - else if (data.entry_type() == rpc::EntryType::RES_DELETE) { + else if (data.entry_type() == rpc::ClientTableData::RES_DELETE) { cache_data.mutable_resources_total_label()->erase( cache_data.resources_total_label().begin() + index); cache_data.mutable_resources_total_capacity()->erase( @@ -561,7 +560,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { + if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { cache_data.add_resources_total_label(resource_name); cache_data.add_resources_total_capacity(capacity); } @@ -572,12 +571,12 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // If the notification is new, call any registered callbacks. rpc::ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type() == rpc::EntryType::INSERTION) { + if (data.entry_type() == rpc::ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type() == rpc::EntryType::DELETION) { + } else if (data.entry_type() == rpc::ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -585,11 +584,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == rpc::EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == rpc::EntryType::RES_DELETE) { + } else if (data.entry_type() == rpc::ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -619,7 +618,7 @@ Status ClientTable::Connect(const rpc::ClientTableData &local_client) { // Construct the data to add to the client table. auto data = std::make_shared(local_client_); - data->set_entry_type(rpc::EntryType::INSERTION); + data->set_entry_type(rpc::ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, @@ -637,7 +636,7 @@ Status ClientTable::Connect(const rpc::ClientTableData &local_client) { for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type() != rpc::EntryType::DELETION) { + if (notification.entry_type() != rpc::ClientTableData::DELETION) { connected_nodes.emplace(notification.client_id(), notification); } else { auto iter = connected_nodes.find(notification.client_id()); @@ -668,7 +667,7 @@ Status ClientTable::Connect(const rpc::ClientTableData &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); - data->set_entry_type(rpc::EntryType::DELETION); + data->set_entry_type(rpc::ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const rpc::ClientTableData &data) { HandleConnected(client, data); @@ -686,7 +685,7 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); data->set_client_id(dead_client_id.Binary()); - data->set_entry_type(rpc::EntryType::DELETION); + data->set_entry_type(rpc::ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 530f5a740999..66d886b5ca15 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -86,8 +86,6 @@ message TaskTableData { string task = 1; } -message ClassTableData {} - message ActorTableData { enum ActorState { // Actor is alive. @@ -128,7 +126,6 @@ message ErrorTableData { } message ProfileTableData { - message ProfileEvent { // The type of the event. string event_type = 1; @@ -162,15 +159,15 @@ message RayResource { double resource_capacity = 2; } -// Enum for the entry type in the ClientTable -enum EntryType { - INSERTION = 0; - DELETION = 1; - RES_CREATEUPDATE = 2; - RES_DELETE = 3; -} - message ClientTableData { + // Enum for the entry type in the ClientTable + enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; + } + // The client ID of the client that the message is about. string client_id = 1; // The IP address of the client's node manager. From a7d334acdfa50cbd08254b479f35f699c16c673b Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 20 Jun 2019 21:44:53 +0800 Subject: [PATCH 05/34] redis module compilable --- BUILD.bazel | 1 + src/ray/gcs/gcs_entry.h | 52 ------ src/ray/gcs/redis_module/ray_redis_module.cc | 178 +++++++------------ src/ray/protobuf/gcs.proto | 69 ++++--- 4 files changed, 102 insertions(+), 198 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 0f3085ba6699..7bdf2b7ace4f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -705,6 +705,7 @@ cc_binary( visibility = ["//java:__subpackages__"], deps = [ ":ray_common", + ":gcs_cc_proto", ], ) diff --git a/src/ray/gcs/gcs_entry.h b/src/ray/gcs/gcs_entry.h index 2ba4b4fe7144..8a86ae49001a 100644 --- a/src/ray/gcs/gcs_entry.h +++ b/src/ray/gcs/gcs_entry.h @@ -54,58 +54,6 @@ class GcsEntry : public rpc::ConstMessageWrapper { } }; -template -class GcsMapEntry : public rpc::ConstMessageWrapper { - public: - explicit GcsMapEntry(const rpc::GcsMapEntry &message) : ConstMessageWrapper(message) {} - - explicit GcsMapEntry(const std::string &data) - : ConstMessageWrapper(ParseGcsEntryMessage(data)) {} - - GcsMapEntry(const ID &id, const rpc::GcsChangeMode &change_mode, - const std::unordered_map> &entries) - : ConstMessageWrapper(CreateGcsEntryMessage(id, change_mode, entries)) {} - - const ID GetId() { return ID::FromBinary(message_->id()); } - - const rpc::GcsChangeMode GetChangeMode() { return message_->change_mode(); } - - const std::unordered_map> GetEntries() { - std::unordered_map> map; - for (const auto &pair : message_->entries()) { - if (pair.second.empty()) { - map.emplace(pair.first, nullptr); - } else { - // XXX - auto entry = std::make_shared(); - map.emplace(pair.first, std::move(entry)); - } - } - return map; - } - - private: - inline static std::unique_ptr ParseGcsEntryMessage( - const std::string &data) { - auto *gcs_entry = new rpc::GcsMapEntry(); - gcs_entry->ParseFromString(data); - return std::unique_ptr(gcs_entry); - } - - static inline std::unique_ptr CreateGcsEntryMessage( - const ID &id, const rpc::GcsChangeMode &change_mode, - const std::unordered_map> &entries) { - auto *gcs_entry = new rpc::GcsMapEntry(); - gcs_entry->set_id(id.Binary()); - gcs_entry->set_change_mode(change_mode); - for (const auto &entry : entries) { - std::string str; - entry.second->SerializeToString(&str); - (*gcs_entry->mutable_entries())[entry.first] = std::move(str); - } - return std::unique_ptr(gcs_entry); - } -}; } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e291b7ffdb32..1bb258cd5a14 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,11 +5,16 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; +using ray::rpc::GcsChangeMode; +using ray::rpc::GcsEntry; +using ray::rpc::TablePrefix; +using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. Specifically, via @@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long > static_cast(TablePubsub::MAX) || - pubsub_channel_long < static_cast(TablePubsub::MIN)) { + if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || + pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long > static_cast(TablePrefix::MAX) || - table_prefix_long < static_cast(TablePrefix::MIN)) { + if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || + table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); + return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = - OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); + *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, + mutated_key_str); return Status::OK(); } @@ -179,6 +184,20 @@ flatbuffers::Offset RedisStringToFlatbuf( return fbb.CreateString(redis_string_str, redis_string_size); } +inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, + const std::vector &entries, + GcsEntry *result) { + const char *data; + size_t size; + data = RedisModule_StringPtrLen(id, &size); + result->set_id(data, size); + result->set_change_mode(change_mode); + for (const auto &entry : entries) { + data = RedisModule_StringPtrLen(entry, &size); + result->add_entries(data, size); + } +} + /// Helper method to publish formatted data to target channel. /// /// \param pubsub_channel_str The pubsub channel name that notifications for @@ -234,13 +253,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - auto data_buffer = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + CreateGcsEntry(id, change_mode, {data}, &gcs_entry); + std::string str = gcs_entry.SerializeAsString(); + auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -570,19 +586,22 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); + GcsEntry gcs_entry; + gcs_entry.ParseFromArray(update_data_buf, update_data_len); + *change_mode = gcs_entry.change_mode(); + auto data_vec = flatbuffers::GetRoot(update_data_buf); *change_mode = data_vec->change_mode(); if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. - size_t total_size = data_vec->entries()->size(); + size_t total_size = gcs_entry.entries_size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); - RedisModuleString *entry_value = - RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), - data_vec->entries()->Get(i + 1)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + RedisModuleString *entry_value = RedisModule_CreateString( + ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); // Returning 0 if key exists(still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -590,27 +609,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries. - size_t total_size = data_vec->entries()->size(); - flatbuffers::FlatBufferBuilder fbb; - std::vector> data; + GcsEntry updated; + updated.set_id(gcs_entry.id()); + updated.set_change_mode(gcs_entry.change_mode()); + + size_t total_size = gcs_entry.entries_size(); for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), - data_vec->entries()->Get(i)->size())); + updated.add_entries(gcs_entry.entries(i)); } } - auto message = - CreateGcsEntry(fbb, data_vec->change_mode(), - fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), - fbb.CreateVector(data)); - fbb.Finish(message); - *changed_data = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + + // Serialize updated data. + std::string str = updated.SerializeAsString(); + *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -661,18 +678,15 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param fbb A flatbuffer builder used to build the GcsEntry. Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *prefix_str, RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { // Build the flatbuffer from the string data. + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); - fbb.Finish(message); + gcs_entry->add_entries(data_buf, data_len); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -700,23 +714,17 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - std::vector> data; + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); + std::vector data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - data.push_back(fbb.CreateString(element_str, len)); + gcs_entry->add_entries(element_str, len); } - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); - fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsEntry( - fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(std::vector>())); - fbb.Finish(message); + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -753,10 +761,11 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int RedisModule_ReplyWithNull(ctx); } else { // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToFlatbuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } return REDISMODULE_OK; } @@ -870,10 +879,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToFlatbuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); return RedisModule_ReplyWithNull(ctx); } @@ -940,53 +950,6 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = - flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); - if (!is_nil_result) { - do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); - } - - if (do_update) { - REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), - "mutate_scheduling_state failed"); - } - REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -1016,7 +979,6 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); -AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1082,12 +1044,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 66d886b5ca15..66a7e51d5f16 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -10,39 +10,43 @@ enum Language { // These indexes are mapped to strings in ray_redis_module.cc. enum TablePrefix { - UNUSED = 0; - TASK = 1; - RAYLET_TASK = 2; - CLIENT = 3; - OBJECT = 4; - ACTOR = 5; - FUNCTION = 6; - TASK_RECONSTRUCTION = 7; - HEARTBEAT = 8; - HEARTBEAT_BATCH = 9; - ERROR_INFO = 10; - DRIVER = 11; - PROFILE = 12; - TASK_LEASE = 13; - ACTOR_CHECKPOINT = 14; - ACTOR_CHECKPOINT_ID = 15; - NODE_RESOURCE = 16; + TABLE_PREFIX_MIN = 0; + UNUSED = 1; + TASK = 2; + RAYLET_TASK = 3; + CLIENT = 4; + OBJECT = 5; + ACTOR = 6; + FUNCTION = 7; + TASK_RECONSTRUCTION = 8; + HEARTBEAT = 9; + HEARTBEAT_BATCH = 10; + ERROR_INFO = 11; + DRIVER = 12; + PROFILE = 13; + TASK_LEASE = 14; + ACTOR_CHECKPOINT = 15; + ACTOR_CHECKPOINT_ID = 16; + NODE_RESOURCE = 17; + TABLE_PREFIX_MAX = 18; } // The channel that Add operations to the Table should be published on, if any. enum TablePubsub { - NO_PUBLISH = 0; - TASK_PUBSUB = 1; - RAYLET_TASK_PUBSUB = 2; - CLIENT_PUBSUB = 3; - OBJECT_PUBSUB = 4; - ACTOR_PUBSUB = 5; - HEARTBEAT_PUBSUB = 6; - HEARTBEAT_BATCH_PUBSUB = 7; - ERROR_INFO_PUBSUB = 8; - TASK_LEASE_PUBSUB = 9; - DRIVER_PUBSUB = 10; - NODE_RESOURCE_PUBSUB = 11; + TABLE_PUBSUB_MIN = 0; + NO_PUBLISH = 1; + TASK_PUBSUB = 2; + RAYLET_TASK_PUBSUB = 3; + CLIENT_PUBSUB = 4; + OBJECT_PUBSUB = 5; + ACTOR_PUBSUB = 6; + HEARTBEAT_PUBSUB = 7; + HEARTBEAT_BATCH_PUBSUB = 8; + ERROR_INFO_PUBSUB = 9; + TASK_LEASE_PUBSUB = 10; + DRIVER_PUBSUB = 11; + NODE_RESOURCE_PUBSUB = 12; + TABLE_PUBSUB_MAX = 13; } enum GcsChangeMode { @@ -56,12 +60,6 @@ message GcsEntry { repeated string entries = 3; } -message GcsMapEntry { - GcsChangeMode change_mode = 1; - string id = 2; - map entries = 3; -} - message FunctionTableData { Language language = 1; string name = 2; @@ -184,6 +182,7 @@ message ClientTableData { int32 object_manager_port = 6; // Enum to store the entry type in the log EntryType entry_type = 7; + repeated string resources_total_label = 8; repeated double resources_total_capacity = 9; } From 4c0bb647865098467ccab9f71ad8ca8a59d8a3c3 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 20 Jun 2019 22:06:32 +0800 Subject: [PATCH 06/34] remove gcsmapentry --- src/ray/gcs/tables.cc | 150 ++++++++++++++++++++++++++---------------- 1 file changed, 94 insertions(+), 56 deletions(-) diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 18a3f73d8d47..e90ae0b8b1cc 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -135,7 +135,8 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien if (subscribe != nullptr) { // Parse the notification. GcsEntry gcs_entry(data); - subscribe(client_, gcs_entry.GetId(), gcs_entry.GetChangeMode(), gcs_entry.GetEntries()); + subscribe(client_, gcs_entry.GetId(), gcs_entry.GetChangeMode(), + gcs_entry.GetEntries()); } } }; @@ -325,11 +326,16 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - GcsMapEntry gcs_map_entry(id, rpc::GcsChangeMode::APPEND_OR_ADD, data_map); - std::string str = gcs_map_entry.Serialize(); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), - str.size(), prefix_, pubsub_channel_, - std::move(callback)); + rpc::GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(rpc::GcsChangeMode::APPEND_OR_ADD); + for (const auto &pair : data_map) { + gcs_entry.add_entries(pair.first); + gcs_entry.add_entries(pair.second->SerializeAsString()); + } + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -342,13 +348,15 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; -// GcsEntry gcs_entry(id, rpc::GcsChangeMode::REMOVE, keys); -// std::string str = gcs_entry.Serialize(); -// XXX - std::string str; - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), - str.size(), prefix_, pubsub_channel_, - std::move(callback)); + rpc::GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(rpc::GcsChangeMode::REMOVE); + for (const auto &key : keys) { + gcs_entry.add_entries(key); + } + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -368,9 +376,19 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - GcsMapEntry gcs_map_entry(reply.ReadAsString()); - RAY_CHECK(gcs_map_entry.GetId() == id); - lookup(client_, id, gcs_map_entry.GetEntries()); + rpc::GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + DataMap result; + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + result.emplace(key, std::move(value)); + } + + lookup(client_, id, result); } else { lookup(client_, id, {}); } @@ -399,8 +417,24 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - GcsMapEntry gcs_map_entry(data); - subscribe(client_, gcs_map_entry.GetId(), gcs_map_entry.GetChangeMode(), gcs_map_entry.GetEntries()); + rpc::GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + DataMap data_map; + if (gcs_entry.change_mode() == rpc::GcsChangeMode::REMOVE) { + for (const auto &key : gcs_entry.entries()) { + data_map.emplace(key, std::shared_ptr()); + } + } else { + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + data_map.emplace(key, std::move(value)); + } + } + subscribe(client_, id, gcs_entry.change_mode(), data_map); } } }; @@ -450,7 +484,8 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type() == rpc::ClientTableData::INSERTION)) { + if (!entry.first.IsNil() && + (entry.second.entry_type() == rpc::ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -460,7 +495,8 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type() == rpc::ClientTableData::DELETION) { + if (!entry.first.IsNil() && + entry.second.entry_type() == rpc::ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -482,7 +518,8 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type() == rpc::ClientTableData::RES_DELETE) { + if (!entry.first.IsNil() && + entry.second.entry_type() == rpc::ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } @@ -503,8 +540,9 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // was alive and is now dead or resources have been updated. bool was_not_deleted = (entry->second.entry_type() != rpc::ClientTableData::DELETION); bool is_deleted = (data.entry_type() == rpc::ClientTableData::DELETION); - bool is_res_modified = ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == rpc::ClientTableData::RES_DELETE)); + bool is_res_modified = + ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == rpc::ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check @@ -547,8 +585,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, existing_resource_label); // Resource already exists, set capacity if updation call.. if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { -// cache_data.mutable_resources_total_capacity()[index] = capacity; -// XXX:639 + // cache_data.mutable_resources_total_capacity()[index] = capacity; + // XXX:639 } // .. delete if deletion call. else if (data.entry_type() == rpc::ClientTableData::RES_DELETE) { @@ -596,7 +634,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const rpc::ClientTableData &data) { +void ClientTable::HandleConnected(AsyncGcsClient *client, + const rpc::ClientTableData &data) { auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; @@ -627,32 +666,32 @@ Status ClientTable::Connect(const rpc::ClientTableData &local_client) { HandleConnected(client, data); // Callback for a notification from the client table. - auto notification_callback = [this]( - AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { - RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; - for (auto ¬ification : notifications) { - // This is temporary fix for Issue 4140 to avoid connect to dead nodes. - // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type() != rpc::ClientTableData::DELETION) { - connected_nodes.emplace(notification.client_id(), notification); - } else { - auto iter = connected_nodes.find(notification.client_id()); - if (iter != connected_nodes.end()) { - connected_nodes.erase(iter); + auto notification_callback = + [this](AsyncGcsClient *client, const UniqueID &log_key, + const std::vector ¬ifications) { + RAY_CHECK(log_key == client_log_key_); + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; + for (auto ¬ification : notifications) { + // This is temporary fix for Issue 4140 to avoid connect to dead nodes. + // TODO(yuhguo): remove this temporary fix after GCS entry is removable. + if (notification.entry_type() != rpc::ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); + } else { + auto iter = connected_nodes.find(notification.client_id()); + if (iter != connected_nodes.end()) { + connected_nodes.erase(iter); + } + disconnected_nodes.emplace(notification.client_id(), notification); + } } - disconnected_nodes.emplace(notification.client_id(), notification); - } - } - for (const auto &pair : connected_nodes) { - HandleNotification(client, pair.second); - } - for (const auto &pair : disconnected_nodes) { - HandleNotification(client, pair.second); - } - }; + for (const auto &pair : connected_nodes) { + HandleNotification(client, pair.second); + } + for (const auto &pair : disconnected_nodes) { + HandleNotification(client, pair.second); + } + }; // Callback to request notifications from the client table once we've // successfully subscribed. auto subscription_callback = [this](AsyncGcsClient *c) { @@ -700,7 +739,8 @@ void ClientTable::GetClient(const ClientID &client_id, } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() + const { return client_cache_; } @@ -730,10 +770,8 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &to_delete = - ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " - << actor_id; + const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); client_->actor_checkpoint_table().Delete(driver_id, to_delete); From 486303c07afd7049ee33d8bb4026433f8160dda9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 17:20:48 +0800 Subject: [PATCH 07/34] raylet compilable --- src/ray/gcs/tables.cc | 2 - src/ray/gcs/tables.h | 17 +- src/ray/object_manager/object_directory.cc | 34 ++-- src/ray/object_manager/object_manager.cc | 44 ++--- src/ray/object_manager/object_manager.h | 7 +- src/ray/raylet/actor_registration.cc | 51 ++--- src/ray/raylet/actor_registration.h | 22 ++- src/ray/raylet/lineage_cache.cc | 11 +- src/ray/raylet/lineage_cache.h | 6 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 209 ++++++++++----------- src/ray/raylet/node_manager.h | 26 ++- src/ray/raylet/raylet.cc | 24 +-- src/ray/raylet/raylet.h | 2 + src/ray/raylet/raylet_client.cc | 7 +- src/ray/raylet/raylet_client.h | 5 +- src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 + src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 + src/ray/rpc/util.h | 12 +- 22 files changed, 280 insertions(+), 244 deletions(-) diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index e90ae0b8b1cc..ea20aa6109fd 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -792,8 +792,6 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -// template class Log; -// template class Table; template class Log; template class Table; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index a11486b5b49b..399b2f242af1 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -12,10 +12,7 @@ #include "ray/gcs/gcs_entry.h" #include "ray/util/logging.h" -//#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -// TODO(rkn): Remove this include. -//#include "ray/raylet/format/node_manager_generated.h" #include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -24,6 +21,20 @@ namespace ray { namespace gcs { +using rpc::ObjectTableData; +using rpc::TaskTableData; +using rpc::ActorTableData; +using rpc::TaskReconstructionData; +using rpc::TaskLeaseData; +using rpc::HeartbeatTableData; +using rpc::ErrorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ProfileTableData; +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; + + class RedisContext; class AsyncGcsClient; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 5b6794a505d3..6ffda660a300 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,18 +8,22 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::GcsChangeMode; +using ray::rpc::ObjectTableData; +using ray::rpc::ClientTableData; + /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -42,7 +46,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -79,9 +83,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; @@ -92,9 +96,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; @@ -102,14 +106,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - ClientTableDataT client_data; + ClientTableData client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::FromBinary(client_data.client_id); + ClientID result_client_id = ClientID::FromBinary(client_data.client_id()); if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.entry_type == EntryType::INSERTION) { - connection_info.ip = client_data.node_manager_address; - connection_info.port = static_cast(client_data.object_manager_port); + if (client_data.entry_type() == ClientTableData::INSERTION) { + connection_info.ip = client_data.node_manager_address(); + connection_info.port = static_cast(client_data.object_manager_port()); } } } @@ -208,7 +212,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, status = gcs_client_->object_table().Lookup( DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_updates) { + const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 954162c21aef..a8b6fbaa70a7 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_send"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + ProfileEvent profile_event; + profile_event.set_event_type("transfer_send"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + "\"]"); profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_receive"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + ProfileEvent profile_event; + profile_event.set_event_type("transfer_receive"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + "\"]"); profile_events_.push_back(profile_event); } @@ -801,11 +801,11 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEventT profile_event; - profile_event.event_type = "receive_pull_request"; - profile_event.start_time = current_sys_time_seconds(); - profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; + ProfileEvent profile_event; + profile_event.set_event_type("receive_pull_request"); + profile_event.set_start_time(current_sys_time_seconds()); + profile_event.set_end_time(profile_event.start_time()); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -938,13 +938,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { - ProfileTableDataT profile_info; - profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.Binary(); +ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + ProfileTableData profile_info; + profile_info.set_component_type("object_manager"); + profile_info.set_component_id(client_id_.Binary()); for (auto const &profile_event : profile_events_) { - profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + profile_info.add_profile_events()->CopyFrom(profile_event); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6318250ae3e8..1f2c2c8f78f0 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -28,6 +28,9 @@ namespace ray { +using rpc::ProfileTableData; +using ProfileEvent = rpc::ProfileTableData::ProfileEvent; + struct ObjectManagerConfig { /// The port that the object manager should use to listen for connections /// from other object managers. If this is 0, the object manager will choose @@ -180,7 +183,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableDataT GetAndResetProfilingInfo(); + ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. /// @@ -412,7 +415,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index cc587bc4d74e..7f940006b5be 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,34 +8,35 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { + execution_dependency_( + ObjectID::FromBinary(checkpoint_data.execution_dependency())) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); + for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.task_counter = checkpoint_data.task_counters(i); frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); } // Restore `dummy_objects_`. - for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id()); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,15 +44,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions; + return actor_table_data_.max_reconstructions(); } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions; + return actor_table_data_.remaining_reconstructions(); } const std::unordered_map @@ -96,7 +97,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -109,18 +110,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. - auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.Binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); + auto checkpoint_data = std::make_shared(); + checkpoint_data->set_actor_id(actor_id.Binary()); + checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.Binary()); - checkpoint_data->task_counters.push_back(frontier.second.task_counter); - checkpoint_data->frontier_dependencies.push_back( + checkpoint_data->add_handle_ids(frontier.first.Binary()); + checkpoint_data->add_task_counters(frontier.second.task_counter); + checkpoint_data->add_frontier_dependencies( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); - checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); + checkpoint_data->add_num_dummy_object_dependencies(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 8d7ce2a449ec..c90636087158 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,13 +4,17 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::ActorTableData; +using ActorState = rpc::ActorTableData::ActorState; +using rpc::ActorCheckpointData; + /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -23,13 +27,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - ActorRegistration(const ActorTableDataT &actor_table_data); + explicit ActorRegistration(const ActorTableData &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data); + ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -46,15 +50,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableDataT &GetTableData() const { return actor_table_data_; } + const ActorTableData &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState &GetState() const { return actor_table_data_.state; } + const ActorState GetState() const { return actor_table_data_.state(); } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.state = state; } + void SetState(const ActorState &state) { actor_table_data_.set_state(state); } /// Get the actor's node manager location. /// @@ -131,13 +135,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. /// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableDataT actor_table_data_; + ActorTableData actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 32dddada5244..5959c9b30bb8 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -176,7 +176,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -292,15 +292,14 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... flatbuffers::FlatBufferBuilder fbb; auto message = task->TaskData().ToFlatbuffer(fbb); fbb.Finish(message); - auto task_data = std::make_shared(); - auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); - root->UnPackTo(task_data.get()); + auto task_data = std::make_shared(); + task_data->set_task(reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -365,8 +364,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } - - return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 5436fa372fa4..7956022a77ea 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::TaskTableData; + /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -221,7 +223,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -319,7 +321,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 62ecb00b819f..0a853260887e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id && - data.entry_type == EntryType::DELETION) { + if (client_id.Binary() == data.client_id() && + data.entry_type() == ClientTableData::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,10 +84,9 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->batch.push_back(std::unique_ptr( - new HeartbeatTableDataT(heartbeat.second))); + batch->add_batch()->CopyFrom(heartbeat.second); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index c69cc9f003e0..4f05b537a295 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,6 +11,10 @@ namespace ray { namespace raylet { +using rpc::HeartbeatTableData; +using rpc::HeartbeatBatchTableData; +using rpc::ClientTableData; + class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -35,7 +39,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data); + const HeartbeatTableData &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -50,7 +54,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fc364539ccce..ab408cd53d86 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ActorState::ALIVE) { + if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -130,7 +130,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -139,8 +139,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); + const TaskLeaseData &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -150,7 +150,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. - reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -164,7 +164,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -177,34 +177,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceCreateUpdated(data); }; + const ClientTableData &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceDeleted(data); }; + const ClientTableData &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { + const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -215,7 +215,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -251,12 +251,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) - << " " << entry.is_dead; - if (entry.is_dead) { - auto driver_id = DriverID::FromBinary(entry.driver_id); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id()) + << " " << entry.is_dead(); + if (entry.is_dead()) { + auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -288,26 +288,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.Binary(); + heartbeat_data->set_client_id(my_client_id.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->resources_available_label.push_back(resource_pair.first); - heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_available_label(resource_pair.first); + heartbeat_data->add_resources_available_capacity(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->resources_total_label.push_back(resource_pair.first); - heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_total_label(resource_pair.first); + heartbeat_data->add_resources_total_capacity(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->resource_load_label.push_back(resource_pair.first); - heartbeat_data->resource_load_capacity.push_back(resource_pair.second); + heartbeat_data->add_resource_load_label(resource_pair.first); + heartbeat_data->add_resource_load_capacity(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -335,13 +335,9 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events.size() > 0) { + if (profile_info.profile_events_size() > 0) { flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_info); - fbb.Finish(message); - auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); - - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } // Reset the timer. @@ -358,8 +354,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ClientAdded(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -378,19 +374,19 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address, - client_data.node_manager_port, client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address(), + client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet resources_total(rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { +void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. - const ClientID client_id = ClientID::FromBinary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -418,7 +414,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.GetState() == ActorState::ALIVE) { + actor_entry.second.GetState() == ActorTableData::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first << " is disconnected, because its node " << client_id << " is removed from cluster. It may be reconstructed."; @@ -436,14 +432,14 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { lineage_cache_.FlushAllUncommittedTasks(); } -void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". Updating resource map."; - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set(rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -472,12 +468,12 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { return; } -void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceDeleted(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set(rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -523,7 +519,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() { } void NodeManager::HeartbeatAdded(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { // Locate the client id in remote client table and update available resources based on // the received heartbeat information. auto it = cluster_resource_map_.find(client_id); @@ -535,10 +531,12 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } SchedulingResources &remote_resources = it->second; - ResourceSet remote_available(heartbeat_data.resources_available_label, - heartbeat_data.resources_available_capacity); - ResourceSet remote_load(heartbeat_data.resource_load_label, - heartbeat_data.resource_load_capacity); + ResourceSet remote_available( + rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()), + rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity())); + ResourceSet remote_load( + rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()), + rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity())); // TODO(atumanov): assert that the load is a non-empty ResourceSet. remote_resources.SetAvailableResources(std::move(remote_available)); // Extract the load information and save it locally. @@ -563,40 +561,40 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } } -void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) { +void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. - for (const auto &heartbeat_data : heartbeat_batch.batch) { - const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); + for (const auto &heartbeat_data : heartbeat_batch.batch()) { + const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); if (client_id == local_client_id) { // Skip heartbeats from self. continue; } - HeartbeatAdded(client_id, *heartbeat_data); + HeartbeatAdded(client_id, heartbeat_data); } } void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback) { // Copy the actor notification data. - auto actor_notification = std::make_shared(data); + auto actor_notification = std::make_shared(data); // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions - - actor_notification->remaining_reconstructions); - if (actor_notification->state != ActorState::ALIVE) { + int log_length = 2 * (actor_notification->max_reconstructions() - + actor_notification->remaining_reconstructions()); + if (actor_notification->state() != ActorTableData::ALIVE) { // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } // If we successful appended a record to the GCS table of the actor that // has died, signal this to anyone receiving signals from this actor. auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -633,11 +631,11 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", state = " << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorState::ALIVE) { + if (actor_registration.GetState() == ActorTableData::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -664,7 +662,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorState::DEAD) { + } else if (actor_registration.GetState() == ActorTableData::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -673,7 +671,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. This will mark the tasks as waiting for actor @@ -863,8 +861,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorState::RECONSTRUCTING - : ActorState::DEAD; + ? ActorTableData::RECONSTRUCTING + : ActorTableData::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -873,8 +871,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.state = new_state; + ActorTableData new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.set_state(new_state); if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -885,7 +883,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1160,7 +1158,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointDataT &data) { + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1244,19 +1242,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableDataT object - ClientTableDataT data; + // Add the new resource to a skeleton ClientTableData object + ClientTableData data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.resources_total_label = std::vector{resource_name}; - data.resources_total_capacity = std::vector{capacity}; + data.add_resources_total_label(resource_name); + data.add_resources_total_capacity(capacity); // Set the correct flag for entry_type if (is_deletion) { - data.entry_type = EntryType::RES_DELETE; + data.set_entry_type(ClientTableData::RES_DELETE); } else { - data.entry_type = EntryType::RES_CREATEUPDATE; + data.set_entry_type(ClientTableData::RES_CREATEUPDATE); } // Submit to the client table. This calls the ResourceCreateUpdated callback, which @@ -1265,7 +1263,7 @@ void NodeManager::ProcessSetResourceRequest( if (not worker) { worker = worker_pool_.GetRegisteredDriver(client); } - auto data_shared_ptr = std::make_shared(data); + auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); @@ -1370,7 +1368,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << EnumNameErrorType(error_type) << "."; + << ErrorType_Name(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1488,9 +1486,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // If we have already seen this actor and this actor is not being reconstructed, // its location is known. bool location_known = - seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; + seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING; if (location_known) { - if (actor_entry->second.GetState() == ActorState::DEAD) { + if (actor_entry->second.GetState() == ActorTableData::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); @@ -1535,7 +1533,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // we missed the creation notification. auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. @@ -1861,11 +1859,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { +ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); auto actor_id = task.GetTaskSpecification().ActorCreationId(); auto actor_entry = actor_registry_.find(actor_id); - ActorTableDataT new_actor_data; + ActorTableData new_actor_data; // TODO(swang): If this is an actor that was reconstructed, and previous // actor notifications were delayed, then this node may not have an entry for // the actor in actor_regisry_. Then, the fields for the number of @@ -1873,32 +1871,28 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. - new_actor_data.actor_id = actor_id.Binary(); - new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().Binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); - new_actor_data.max_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_actor_id(actor_id.Binary()); + new_actor_data.set_actor_creation_dummy_object_id(task.GetTaskSpecification().ActorDummyObject().Binary()); + new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); + new_actor_data.set_max_reconstructions(task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.remaining_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_remaining_reconstructions(task.GetTaskSpecification().MaxActorReconstructions()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING); // Copy the static fields from the current actor entry. new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.remaining_reconstructions--; + new_actor_data.set_remaining_reconstructions(new_actor_data.remaining_reconstructions() - 1); } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. - new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().Binary(); - new_actor_data.state = ActorState::ALIVE; + new_actor_data.set_node_manager_id(gcs_client_->client_table().GetLocalClientId().Binary()); + new_actor_data.set_state(ActorTableData::ALIVE); return new_actor_data; } @@ -1934,7 +1928,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, - const ActorCheckpointDataT &checkpoint_data) { + const ActorCheckpointData &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -1948,7 +1942,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -1965,7 +1959,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -2004,10 +1998,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - const Task task(task_data); + auto message = flatbuffers::GetRoot(task_data.task().data()); + const Task task(*message); ResubmitTask(task); }, /*failure_callback=*/ @@ -2035,7 +2030,7 @@ void NodeManager::ResubmitTask(const Task &task) { if (task.GetTaskSpecification().IsActorCreationTask()) { const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); const auto it = actor_registry_.find(actor_id); - if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { + if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) { // If the actor is still alive, then do not resubmit the task. If the // actor actually is dead and a result is needed, then reconstruction // for this task will be triggered again. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 61613358330c..efdd5d238bf7 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,7 +10,6 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" -#include "ray/gcs/format/util.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -26,6 +25,13 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; +using rpc::HeartbeatTableData; +using rpc::HeartbeatBatchTableData; +using rpc::DriverTableData; +using rpc::ActorTableData; +using rpc::ErrorType; + struct NodeManagerConfig { /// The node's resource configuration. ResourceSet resource_config; @@ -112,22 +118,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param data Data associated with the new client. /// \return Void. - void ClientAdded(const ClientTableDataT &data); + void ClientAdded(const ClientTableData &data); /// Handler for the removal of a GCS client. /// \param client_data Data associated with the removed client. /// \return Void. - void ClientRemoved(const ClientTableDataT &client_data); + void ClientRemoved(const ClientTableData &client_data); /// Handler for the addition or updation of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceCreateUpdated(const ClientTableDataT &client_data); + void ResourceCreateUpdated(const ClientTableData &client_data); /// Handler for the deletion of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceDeleted(const ClientTableDataT &client_data); + void ResourceDeleted(const ClientTableData &client_data); /// Evaluates the local infeasible queue to check if any tasks can be scheduled. /// This is called whenever there's an update to the resources on the local client. @@ -150,11 +156,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param id The ID of the node manager that sent the heartbeat. /// \param data The heartbeat data including load information. /// \return Void. - void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data); + void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data); /// Handler for a heartbeat batch notification from the GCS /// /// \param heartbeat_batch The batch of heartbeat data. - void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch); + void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch); /// Methods for task scheduling. @@ -206,7 +212,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Helper function to produce actor table data for a newly created actor. /// /// \param task The actor creation task that created the actor. - ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); + ActorTableData CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -317,7 +323,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param failure_callback An optional callback to call if the publish is /// unsuccessful. void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -346,7 +352,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 473e6c263ffe..cbf9b25213ca 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = node_ip_address; - client_info.raylet_socket_name = raylet_socket_name; - client_info.object_store_socket_name = object_store_socket_name; - client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_.GetServerPort(); + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(node_ip_address); + client_info.set_raylet_socket_name(raylet_socket_name); + client_info.set_object_store_socket_name(object_store_socket_name); + client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); + client_info.set_node_manager_port(node_manager_.GetServerPort()); // Add resource information. for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.resources_total_label.push_back(resource_pair.first); - client_info.resources_total_capacity.push_back(resource_pair.second); + client_info.add_resources_total_label(resource_pair.first); + client_info.add_resources_total_capacity(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address << ":" - << client_info.node_manager_port << " object manager at " - << client_info.node_manager_address << ":" - << client_info.object_manager_port; + << " started on " << client_info.node_manager_address() << ":" + << client_info.node_manager_port() << " object manager at " + << client_info.node_manager_address() << ":" + << client_info.object_manager_port(); ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 26fe74b2b622..9367a5054591 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; + class Task; class NodeManager; diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 801bb9112241..63dbcc3e08db 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -334,10 +334,11 @@ ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); } -ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) { +ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_events); - fbb.Finish(message); +// auto message = CreateProfileTableData(fbb, &profile_events); +// fbb.Finish(message); +// XXX auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); // Don't be too strict for profile errors. Just create logs and prevent it from crash. diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 8b4dfad5b37a..5437455bffc6 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -8,6 +8,7 @@ #include "ray/common/status.h" #include "ray/raylet/task_spec.h" +#include "ray/protobuf/gcs.pb.h" using ray::ActorCheckpointID; using ray::ActorID; @@ -17,6 +18,8 @@ using ray::ObjectID; using ray::TaskID; using ray::UniqueID; +using ray::rpc::ProfileTableData; + using MessageType = ray::protocol::MessageType; using ResourceMappingType = std::unordered_map>>; @@ -137,7 +140,7 @@ class RayletClient { /// /// \param profile_events A batch of profiling event information. /// \return ray::Status. - ray::Status PushProfileEvents(const ProfileTableDataT &profile_events); + ray::Status PushProfileEvents(const ProfileTableData &profile_events); /// Free a list of objects from object stores. /// diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 97c86ea73cd8..bf5c1acfaa37 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. - auto reconstruction_entry = std::make_shared(); - reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.Binary(); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->set_num_reconstructions(reconstruction_attempt); + reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cd969cc2706e..a194443e1425 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,6 +17,8 @@ namespace ray { namespace raylet { +using rpc::TaskReconstructionData; + class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index c5155b96b0c1..89028c733d0d 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.Hex(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = it->second.lease_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(client_id_.Hex()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(it->second.lease_period); RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 3788a5eae7ae..a96558295234 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,6 +13,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index f356b2a88b5a..5e928dd1de43 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -36,11 +36,13 @@ inline std::unordered_map MapFromProtobuf(::google::protobuf::Map pb template inline std::vector VectorFromProtobuf( const ::google::protobuf::RepeatedPtrField &pb_repeated) { - std::vector vector(static_cast(pb_repeated.size())); - for (const auto &item : pb_repeated) { - vector.push_back(item); - } - return vector; + return std::vector(pb_repeated.begin(), pb_repeated.end()); +} + +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedField &pb_repeated) { + return std::vector(pb_repeated.begin(), pb_repeated.end()); } template From 035038777b464b6553b7a55e7db67618583229cc Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 17:45:06 +0800 Subject: [PATCH 08/34] remove gcs_entry wrapper remove gcs_entry wrapper --- BUILD.bazel | 1 - src/ray/gcs/client.cc | 2 - src/ray/gcs/client.h | 2 - src/ray/gcs/gcs_entry.h | 61 ----------- src/ray/gcs/tables.cc | 188 ++++++++++++++++++---------------- src/ray/gcs/tables.h | 134 ++++++++++++------------ src/ray/rpc/message_wrapper.h | 63 ------------ 7 files changed, 165 insertions(+), 286 deletions(-) delete mode 100644 src/ray/gcs/gcs_entry.h delete mode 100644 src/ray/rpc/message_wrapper.h diff --git a/BUILD.bazel b/BUILD.bazel index 7bdf2b7ace4f..3801cb95bc7e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -442,7 +442,6 @@ cc_library( copts = COPTS, includes = [ "src/ray/gcs/format", - "src/ray/gcs/gcs_entry.h", ], deps = [ ":gcs_cc_proto", diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 57b8694e17d6..6de29bb52764 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,8 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } -FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 6addd9789fde..5e70025b39a0 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,7 +44,6 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. Status Attach(boost::asio::io_service &io_service); - inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); @@ -78,7 +77,6 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: - std::unique_ptr function_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/gcs_entry.h b/src/ray/gcs/gcs_entry.h deleted file mode 100644 index 8a86ae49001a..000000000000 --- a/src/ray/gcs/gcs_entry.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef RAY_GCS_GCS_ENTRY_H -#define RAY_GCS_GCS_ENTRY_H - -#include "ray/protobuf/gcs.pb.h" -#include "ray/rpc/message_wrapper.h" -#include "ray/rpc/util.h" - -namespace ray { - -namespace gcs { - -template -class GcsEntry : public rpc::ConstMessageWrapper { - public: - explicit GcsEntry(const rpc::GcsEntry &message) : ConstMessageWrapper(message) {} - - explicit GcsEntry(const std::string &data) - : ConstMessageWrapper(ParseGcsEntryMessage(data)) {} - - GcsEntry(const ID &id, const rpc::GcsChangeMode &change_mode, - const std::vector &entries) - : ConstMessageWrapper(CreateGcsEntryMessage(id, change_mode, entries)) {} - - const ID GetId() { return ID::FromBinary(message_->id()); } - - const rpc::GcsChangeMode GetChangeMode() { return message_->change_mode(); } - - const std::vector GetEntries() { -// return rpc::VectorFromProtobuf(message_->entries()); -// XXX - return {}; - } - - private: - inline static std::unique_ptr ParseGcsEntryMessage( - const std::string &data) { - auto *gcs_entry = new rpc::GcsEntry(); - gcs_entry->ParseFromString(data); - return std::unique_ptr(gcs_entry); - } - - static inline std::unique_ptr CreateGcsEntryMessage( - const ID &id, const rpc::GcsChangeMode &change_mode, - const std::vector &entries) { - auto *gcs_entry = new rpc::GcsEntry(); - gcs_entry->set_id(id.ToBinary()); - gcs_entry->set_change_mode(change_mode); - for (const auto &entry : entries) { - std::string str; - entry.SerializeToString(&str); - gcs_entry->add_entries(std::move(str)); - } - return std::unique_ptr(gcs_entry); - } -}; - -} // namespace gcs - -} // namespace ray - -#endif // RAY_GCS_GCS_ENTRY_H diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ea20aa6109fd..b6cb3cf6514b 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -88,13 +88,18 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { + std::vector results; if (!reply.IsNil()) { - GcsEntry gcs_entry(reply.ReadAsString()); - RAY_CHECK(gcs_entry.GetId() == id); - lookup(client_, id, gcs_entry.GetEntries()); - } else { - lookup(client_, id, {}); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data data; + data.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(data)); + } } + lookup(client_, id, results); } }; std::vector nil; @@ -107,9 +112,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, - const rpc::GcsChangeMode change_mode, + const GcsChangeMode change_mode, const std::vector &data) { - RAY_CHECK(change_mode != rpc::GcsChangeMode::REMOVE); + RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; return Subscribe(driver_id, client_id, subscribe_wrapper, done); @@ -134,9 +139,16 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - GcsEntry gcs_entry(data); - subscribe(client_, gcs_entry.GetId(), gcs_entry.GetChangeMode(), - gcs_entry.GetEntries()); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + std::vector results; + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data result; + result.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(result)); + } + subscribe(client_, id, gcs_entry.change_mode(), results); } } }; @@ -326,9 +338,9 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - rpc::GcsEntry gcs_entry; + GcsEntry gcs_entry; gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(rpc::GcsChangeMode::APPEND_OR_ADD); + gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); for (const auto &pair : data_map) { gcs_entry.add_entries(pair.first); gcs_entry.add_entries(pair.second->SerializeAsString()); @@ -348,9 +360,9 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - rpc::GcsEntry gcs_entry; + GcsEntry gcs_entry; gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(rpc::GcsChangeMode::REMOVE); + gcs_entry.set_change_mode(GcsChangeMode::REMOVE); for (const auto &key : keys) { gcs_entry.add_entries(key); } @@ -376,7 +388,7 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - rpc::GcsEntry gcs_entry; + GcsEntry gcs_entry; gcs_entry.ParseFromString(reply.ReadAsString()); RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); RAY_CHECK(gcs_entry.entries_size() % 2 == 0); @@ -417,11 +429,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - rpc::GcsEntry gcs_entry; + GcsEntry gcs_entry; gcs_entry.ParseFromString(data); ID id = ID::FromBinary(gcs_entry.id()); DataMap data_map; - if (gcs_entry.change_mode() == rpc::GcsChangeMode::REMOVE) { + if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { for (const auto &key : gcs_entry.entries()) { data_map.emplace(key, std::shared_ptr()); } @@ -449,7 +461,7 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); + auto data = std::make_shared(); data->set_driver_id(driver_id.Binary()); data->set_type(type); data->set_error_message(error_message); @@ -458,11 +470,11 @@ Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::strin } std::string ErrorTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } -Status ProfileTable::AddProfileEventBatch(const rpc::ProfileTableData &profile_events) { - auto data = std::make_shared(); +Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { + auto data = std::make_shared(); data->CopyFrom(profile_events); // XXX return Append(DriverID::Nil(), UniqueID::FromRandom(), data, @@ -470,11 +482,11 @@ Status ProfileTable::AddProfileEventBatch(const rpc::ProfileTableData &profile_e } std::string ProfileTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); + auto data = std::make_shared(); data->set_driver_id(driver_id.Binary()); data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); @@ -485,7 +497,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type() == rpc::ClientTableData::INSERTION)) { + (entry.second.entry_type() == ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -496,7 +508,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - entry.second.entry_type() == rpc::ClientTableData::DELETION) { + entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -508,7 +520,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE)) { + (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -519,14 +531,14 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - entry.second.entry_type() == rpc::ClientTableData::RES_DELETE) { + entry.second.entry_type() == ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const rpc::ClientTableData &data) { + const ClientTableData &data) { ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. @@ -538,17 +550,17 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type() != rpc::ClientTableData::DELETION); - bool is_deleted = (data.entry_type() == rpc::ClientTableData::DELETION); + bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == ClientTableData::DELETION); bool is_res_modified = - ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == rpc::ClientTableData::RES_DELETE)); + ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (entry->second.entry_type() == rpc::ClientTableData::DELETION) { - RAY_CHECK((data.entry_type() == rpc::ClientTableData::DELETION)) + if (entry->second.entry_type() == ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -556,21 +568,21 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type() == rpc::ClientTableData::INSERTION) || - (data.entry_type() == rpc::ClientTableData::DELETION)) { + if ((data.entry_type() == ClientTableData::INSERTION) || + (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " << client_id << ". ClientTableData: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == rpc::ClientTableData::RES_DELETE)) { + } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " << client_id << ". ClientTableData: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - rpc::ClientTableData &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { auto const &resource_name = data.resources_total_label(i); @@ -584,12 +596,12 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { // cache_data.mutable_resources_total_capacity()[index] = capacity; // XXX:639 } // .. delete if deletion call. - else if (data.entry_type() == rpc::ClientTableData::RES_DELETE) { + else if (data.entry_type() == ClientTableData::RES_DELETE) { cache_data.mutable_resources_total_label()->erase( cache_data.resources_total_label().begin() + index); cache_data.mutable_resources_total_capacity()->erase( @@ -598,7 +610,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { cache_data.add_resources_total_label(resource_name); cache_data.add_resources_total_capacity(capacity); } @@ -607,14 +619,14 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } // If the notification is new, call any registered callbacks. - rpc::ClientTableData &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type() == rpc::ClientTableData::INSERTION) { + if (data.entry_type() == ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type() == rpc::ClientTableData::DELETION) { + } else if (data.entry_type() == ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -622,11 +634,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == rpc::ClientTableData::RES_CREATEUPDATE) { + } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == rpc::ClientTableData::RES_DELETE) { + } else if (data.entry_type() == ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -635,7 +647,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } void ClientTable::HandleConnected(AsyncGcsClient *client, - const rpc::ClientTableData &data) { + const ClientTableData &data) { auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; @@ -643,39 +655,39 @@ void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const rpc::ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const rpc::ClientTableData &local_client) { +Status ClientTable::Connect(const ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->set_entry_type(rpc::ClientTableData::INSERTION); + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const rpc::ClientTableData &data) { + const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type() != rpc::ClientTableData::DELETION) { + if (notification.entry_type() != ClientTableData::DELETION) { connected_nodes.emplace(notification.client_id(), notification); } else { auto iter = connected_nodes.find(notification.client_id()); @@ -705,10 +717,10 @@ Status ClientTable::Connect(const rpc::ClientTableData &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->set_entry_type(rpc::ClientTableData::DELETION); + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const rpc::ClientTableData &data) { + const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -722,14 +734,14 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); + auto data = std::make_shared(); data->set_client_id(dead_client_id.Binary()); - data->set_entry_type(rpc::ClientTableData::DELETION); + data->set_entry_type(ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - rpc::ClientTableData &client_info) const { + ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { @@ -739,7 +751,7 @@ void ClientTable::GetClient(const ClientID &client_id, } } -const std::unordered_map &ClientTable::GetAllClients() +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ -751,7 +763,7 @@ Status ClientTable::Lookup(const Callback &lookup) { std::string ClientTable::DebugString() const { std::stringstream result; - result << Log::DebugString(); + result << Log::DebugString(); result << ", cache size: " << client_cache_.size() << ", num removed: " << removed_clients_.size(); return result.str(); @@ -762,9 +774,9 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const rpc::ActorCheckpointIdData &data) { - std::shared_ptr copy = - std::make_shared(data); + const ActorCheckpointIdData &data) { + std::shared_ptr copy = + std::make_shared(data); copy->add_timestamps(current_sys_time_ms()); copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); @@ -780,8 +792,8 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); + std::shared_ptr data = + std::make_shared(); data->set_actor_id(id.Binary()); data->add_timestamps(current_sys_time_ms()); *data->add_checkpoint_ids() = checkpoint_id.Binary(); @@ -790,24 +802,24 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, return Lookup(driver_id, actor_id, lookup_callback, failure_callback); } -template class Log; -template class Set; -template class Log; -template class Table; -template class Log; -template class Log; -template class Table; -template class Table; -template class Table; -template class Log; -template class Log; -template class Log; -template class Log; -template class Table; -template class Table; - -template class Log; -template class Hash; +template class Log; +template class Set; +template class Log; +template class Table; +template class Log; +template class Log; +template class Table; +template class Table; +template class Table; +template class Log; +template class Log; +template class Log; +template class Log; +template class Table; +template class Table; + +template class Log; +template class Hash; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 399b2f242af1..43b80df0cfcf 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -9,7 +9,6 @@ #include "ray/common/constants.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/gcs_entry.h" #include "ray/util/logging.h" #include "ray/gcs/redis_context.h" @@ -21,6 +20,10 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::GcsChangeMode; +using rpc::GcsEntry; using rpc::ObjectTableData; using rpc::TaskTableData; using rpc::ActorTableData; @@ -33,6 +36,9 @@ using rpc::DriverTableData; using rpc::ProfileTableData; using rpc::ActorCheckpointData; using rpc::ActorCheckpointIdData; +using rpc::RayResource; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatBatchTableData; class RedisContext; @@ -87,7 +93,7 @@ class Log : public LogInterface, virtual public PubsubInterface { using Callback = std::function &data)>; using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; @@ -109,8 +115,8 @@ class Log : public LogInterface, virtual public PubsubInterface { Log(const std::vector> &contexts, AsyncGcsClient *client) : shard_contexts_(contexts), client_(client), - pubsub_channel_(rpc::TablePubsub::NO_PUBLISH), - prefix_(rpc::TablePrefix::UNUSED), + pubsub_channel_(TablePubsub::NO_PUBLISH), + prefix_(TablePrefix::UNUSED), subscribe_callback_index_(-1){}; /// Append a log entry to a key. @@ -249,12 +255,12 @@ class Log : public LogInterface, virtual public PubsubInterface { AsyncGcsClient *client_; /// The pubsub channel to subscribe to for notifications about keys in this /// table. If no notifications are required, this should be set to - /// rpc::TablePubsub_NO_PUBLISH. If notifications are required, then this must be + /// TablePubsub_NO_PUBLISH. If notifications are required, then this must be /// unique across all instances of Log. - rpc::TablePubsub pubsub_channel_; + TablePubsub pubsub_channel_; /// The prefix to use for keys in this table. This must be unique across all /// instances of Log. - rpc::TablePrefix prefix_; + TablePrefix prefix_; /// The index in the RedisCallbackManager for the callback that is called /// when we receive notifications. This is >= 0 iff we have subscribed to the /// table, otherwise -1. @@ -491,7 +497,7 @@ class HashInterface { /// \return Void using HashNotificationCallback = std::function; + const GcsChangeMode change_mode, const DataMap &data)>; /// Add entries of a hash table. /// @@ -595,59 +601,59 @@ class Hash : private Log, using Log::num_lookups_; }; -class DynamicResourceTable : public Hash { +class DynamicResourceTable : public Hash { public: DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::NODE_RESOURCE_PUBSUB; - prefix_ = rpc::TablePrefix::NODE_RESOURCE; + pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; + prefix_ = TablePrefix::NODE_RESOURCE; }; virtual ~DynamicResourceTable(){}; }; -class ObjectTable : public Set { +class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::OBJECT_PUBSUB; - prefix_ = rpc::TablePrefix::OBJECT; + pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; + prefix_ = TablePrefix::OBJECT; }; virtual ~ObjectTable(){}; }; -class HeartbeatTable : public Table { +class HeartbeatTable : public Table { public: HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::HEARTBEAT_PUBSUB; - prefix_ = rpc::TablePrefix::HEARTBEAT; + pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; + prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} }; -class HeartbeatBatchTable : public Table { +class HeartbeatBatchTable : public Table { public: HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::HEARTBEAT_BATCH_PUBSUB; - prefix_ = rpc::TablePrefix::HEARTBEAT_BATCH; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; + prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} }; -class DriverTable : public Log { +class DriverTable : public Log { public: DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::DRIVER_PUBSUB; - prefix_ = rpc::TablePrefix::DRIVER; + pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; + prefix_ = TablePrefix::DRIVER; }; virtual ~DriverTable() {} @@ -660,54 +666,44 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { - public: - FunctionTable(const std::vector> &contexts, - AsyncGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::NO_PUBLISH; - prefix_ = rpc::TablePrefix::FUNCTION; - }; -}; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). /// These may be followed by a DEAD entry, which means that the actor has failed and will /// not be reconstructed. -class ActorTable : public Log { +class ActorTable : public Log { public: ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::ACTOR_PUBSUB; - prefix_ = rpc::TablePrefix::ACTOR; + pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; + prefix_ = TablePrefix::ACTOR; } }; -class TaskReconstructionLog : public Log { +class TaskReconstructionLog : public Log { public: TaskReconstructionLog(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - prefix_ = rpc::TablePrefix::TASK_RECONSTRUCTION; + prefix_ = TablePrefix::TASK_RECONSTRUCTION; } }; -class TaskLeaseTable : public Table { +class TaskLeaseTable : public Table { public: TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::TASK_LEASE_PUBSUB; - prefix_ = rpc::TablePrefix::TASK_LEASE; + pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; + prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK( - (Table::Add(driver_id, id, data, done))); + (Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. It's okay if this command fails // since the lease entry itself contains the expiration period. In the // worst case, if the command fails, then a client that looks up the lease @@ -715,28 +711,28 @@ class TaskLeaseTable : public Table { // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. std::vector args = {"PEXPIRE", - rpc::TablePrefix_Name(prefix_) + id.Binary(), + TablePrefix_Name(prefix_) + id.Binary(), std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } }; -class ActorCheckpointTable : public Table { +class ActorCheckpointTable : public Table { public: ActorCheckpointTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - prefix_ = rpc::TablePrefix::ACTOR_CHECKPOINT; + prefix_ = TablePrefix::ACTOR_CHECKPOINT; }; }; -class ActorCheckpointIdTable : public Table { +class ActorCheckpointIdTable : public Table { public: ActorCheckpointIdTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - prefix_ = rpc::TablePrefix::ACTOR_CHECKPOINT_ID; + prefix_ = TablePrefix::ACTOR_CHECKPOINT_ID; }; /// Add a checkpoint id to an actor, and remove a previous checkpoint if the @@ -752,13 +748,13 @@ class ActorCheckpointIdTable : public Table namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::RAYLET_TASK_PUBSUB; - prefix_ = rpc::TablePrefix::RAYLET_TASK; + pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; + prefix_ = TablePrefix::RAYLET_TASK; } TaskTable(const std::vector> &contexts, @@ -770,13 +766,13 @@ class TaskTable : public Table { } // namespace raylet -class ErrorTable : private Log { +class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = rpc::TablePubsub::ERROR_INFO_PUBSUB; - prefix_ = rpc::TablePrefix::ERROR_INFO; + pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; + prefix_ = TablePrefix::ERROR_INFO; }; /// Push an error message for a specific job. @@ -800,19 +796,19 @@ class ErrorTable : private Log { std::string DebugString() const; }; -class ProfileTable : private Log { +class ProfileTable : private Log { public: ProfileTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - prefix_ = rpc::TablePrefix::PROFILE; + prefix_ = TablePrefix::PROFILE; }; /// Add a batch of profiling events to the profile table. /// /// \param profile_events The profile events to record. /// \return Status. - Status AddProfileEventBatch(const rpc::ProfileTableData &profile_events); + Status AddProfileEventBatch(const ProfileTableData &profile_events); /// Returns debug string for class. /// @@ -829,10 +825,10 @@ class ProfileTable : private Log { /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : public Log { +class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -843,8 +839,8 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = rpc::TablePubsub::CLIENT_PUBSUB; - prefix_ = rpc::TablePrefix::CLIENT; + pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; + prefix_ = TablePrefix::CLIENT; // Set the local client's ID. local_client_.set_client_id(client_id.Binary()); @@ -856,7 +852,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const rpc::ClientTableData &local_client); + ray::Status Connect(const ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -899,7 +895,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, rpc::ClientTableData &client_info) const; + void GetClient(const ClientID &client, ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -909,7 +905,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const rpc::ClientTableData &GetLocalClient() const; + const ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. /// @@ -920,7 +916,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -942,15 +938,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. void HandleNotification(AsyncGcsClient *client, - const rpc::ClientTableData ¬ifications); + const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const rpc::ClientTableData &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - rpc::ClientTableData local_client_; + ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -960,7 +956,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/rpc/message_wrapper.h b/src/ray/rpc/message_wrapper.h deleted file mode 100644 index c9feb5809375..000000000000 --- a/src/ray/rpc/message_wrapper.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef RAY_RPC_WRAPPER_H -#define RAY_RPC_WRAPPER_H - -#include -#include - -namespace ray { - -namespace rpc { - -template -class MessageWrapper { - public: - explicit MessageWrapper(Message &message) : message_(&message) {} - - explicit MessageWrapper(std::unique_ptr message) - : message_unique_ptr_(std::move(message)), message_(message_unique_ptr_.get()) {} - - MessageWrapper(const MessageWrapper &from) - : MessageWrapper(std::unique_ptr(new Message(from.GetMessage()))) {} - - const Message &GetMessage() const { return *message_; } - - const std::string Serialize() const { - std::string ret; - message_->SerializeToString(&ret); - return ret; - } - - protected: - std::unique_ptr message_unique_ptr_; - Message *message_; -}; - -template -class ConstMessageWrapper { - public: - explicit ConstMessageWrapper(const Message &message) : message_(&message) {} - - explicit ConstMessageWrapper(std::unique_ptr message) - : message_unique_ptr_(std::move(message)), message_(message_unique_ptr_.get()) {} - - ConstMessageWrapper(const ConstMessageWrapper &from) - : ConstMessageWrapper(std::unique_ptr(new Message(from.GetMessage()))) {} - - const Message &GetMessage() const { return *message_; } - - const std::string Serialize() const { - std::string ret; - message_->SerializeToString(&ret); - return ret; - } - - protected: - std::unique_ptr message_unique_ptr_; - const Message *message_; -}; - -} // namespace rpc - -} // namespace ray - -#endif // RAY_RPC_WRAPPER_H From 5ed942632ea617a94cf7bb68124ce9c1ba880a47 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 18:08:31 +0800 Subject: [PATCH 09/34] keep profile table in fbs --- src/ray/gcs/format/gcs.fbs | 27 +++++++++++++++++++++++++++ src/ray/raylet/node_manager.cc | 14 ++++++++++++-- src/ray/raylet/raylet_client.cc | 7 +++---- src/ray/raylet/raylet_client.h | 5 +---- 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 0b35cce81a0d..26fe6772a52b 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -74,3 +74,30 @@ table ResourcePair { // The quantity of the resource. value: double; } + +table ProfileEvent { + // The type of the event. + event_type: string; + // The start time of the event. + start_time: double; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + end_time: double; + // Additional data associated with the event. This data must be serialized + // using JSON. + extra_data: string; +} + +table ProfileTableData { + // The type of the component that generated the event, e.g., worker or + // object_manager, or node_manager. + component_type: string; + // An identifier for the component that generated the event. + component_id: string; + // An identifier for the node that generated the event. + node_ip_address: string; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + profile_events: [ProfileEvent]; +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index ab408cd53d86..e697df520939 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -792,8 +792,18 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + auto fbs_message = flatbuffers::GetRoot(message_data); + ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message->component_type); + profile_table_data.set_component_id(fbs_message->component_id); + for (const auto &fbs_event : fbs_message->profile_events) { + ProfileEvent *event = profile_table_data.add_profile_events(); + event->set_event_type(fbs_event->event_type); + event->set_start_time(fbs_event->start_time); + event->set_end_time(fbs_event->end_time); + event->set_extra_data(fbs_event->extra_data); + } + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 63dbcc3e08db..801bb9112241 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -334,11 +334,10 @@ ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); } -ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { +ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_events) { flatbuffers::FlatBufferBuilder fbb; -// auto message = CreateProfileTableData(fbb, &profile_events); -// fbb.Finish(message); -// XXX + auto message = CreateProfileTableData(fbb, &profile_events); + fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); // Don't be too strict for profile errors. Just create logs and prevent it from crash. diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 5437455bffc6..8b4dfad5b37a 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -8,7 +8,6 @@ #include "ray/common/status.h" #include "ray/raylet/task_spec.h" -#include "ray/protobuf/gcs.pb.h" using ray::ActorCheckpointID; using ray::ActorID; @@ -18,8 +17,6 @@ using ray::ObjectID; using ray::TaskID; using ray::UniqueID; -using ray::rpc::ProfileTableData; - using MessageType = ray::protocol::MessageType; using ResourceMappingType = std::unordered_map>>; @@ -140,7 +137,7 @@ class RayletClient { /// /// \param profile_events A batch of profiling event information. /// \return ray::Status. - ray::Status PushProfileEvents(const ProfileTableData &profile_events); + ray::Status PushProfileEvents(const ProfileTableDataT &profile_events); /// Free a list of objects from object stores. /// From 9c71799f02f3f68a6fbbeb04d16a746ea4f3833a Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 18:23:32 +0800 Subject: [PATCH 10/34] refine lineage cache --- src/ray/raylet/lineage_cache.cc | 28 +--------------------------- src/ray/raylet/lineage_cache.h | 22 ++-------------------- 2 files changed, 3 insertions(+), 47 deletions(-) diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 5959c9b30bb8..68d5aa817c2b 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,15 +63,6 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} -Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { - // Deserialize and set entries for the uncommitted tasks. - auto tasks = task_request.uncommitted_tasks(); - for (auto it = tasks->begin(); it != tasks->end(); it++) { - const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); - } -} - boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -151,20 +142,6 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } -flatbuffers::Offset Lineage::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { - RAY_CHECK(GetEntry(task_id)); - // Serialize the task and object entries. - std::vector> uncommitted_tasks; - for (const auto &entry : entries_) { - uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); - } - - auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), - fbb.CreateVector(uncommitted_tasks)); - return request; -} - const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -295,11 +272,8 @@ void LineageCache::FlushTask(const TaskID &task_id) { const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... - flatbuffers::FlatBufferBuilder fbb; - auto message = task->TaskData().ToFlatbuffer(fbb); - fbb.Finish(message); auto task_data = std::make_shared(); - task_data->set_task(reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + task_data->set_task(task->TaskData().Serialize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 7956022a77ea..37ce5caf6507 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,13 +4,10 @@ #include #include -// clang-format off -#include "ray/common/common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -// clang-format on +#include "ray/gcs/tables.h" +#include "ray/raylet/task.h" namespace ray { @@ -138,12 +135,6 @@ class Lineage { /// Construct an empty Lineage. Lineage(); - /// Construct a Lineage from a ForwardTaskRequest. - /// - /// \param task_request The request to construct the lineage from. All - /// uncommitted tasks in the request will be added to the lineage. - Lineage(const protocol::ForwardTaskRequest &task_request); - /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -174,15 +165,6 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; - /// Serialize this lineage to a ForwardTaskRequest flatbuffer. - /// - /// \param entry_id The task ID to include in the ForwardTaskRequest - /// flatbuffer. - /// \return An offset to the serialized lineage. The serialization includes - /// all task and object entries in the lineage. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; - /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// From 1729841b1edd6914b9c54597a4d6d7e8a83352d8 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 18:29:21 +0800 Subject: [PATCH 11/34] remove bazel gcs fbs --- BUILD.bazel | 24 ------------------------ java/BUILD.bazel | 25 ------------------------- 2 files changed, 49 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 3801cb95bc7e..cd6e44f82da4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -585,36 +585,12 @@ flatbuffer_py_library( ":gcs_fbs_file", ], outs = [ - "ActorCheckpointIdData.py", - "ActorState.py", - "ActorTableData.py", "Arg.py", - "ClassTableData.py", - "ClientTableData.py", - "ConfigTableData.py", - "CustomSerializerData.py", - "DriverTableData.py", - "EntryType.py", - "ErrorTableData.py", - "ErrorType.py", - "FunctionTableData.py", - "GcsEntry.py", - "HeartbeatBatchTableData.py", - "HeartbeatTableData.py", "Language.py", - "ObjectTableData.py", "ProfileEvent.py", "ProfileTableData.py", "RayResource.py", - "ResourcePair.py", - "SchedulingState.py", - "TablePrefix.py", - "TablePubsub.py", "TaskInfo.py", - "TaskLeaseData.py", - "TaskReconstructionData.py", - "TaskTableData.py", - "TaskTableTestAndUpdate.py", ], out_prefix = "python/ray/core/generated/", ) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 80ccabccfc12..9fc48c70feb7 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -149,37 +149,12 @@ java_binary( ) flatbuffers_generated_files = [ - "ActorCheckpointData.java", - "ActorCheckpointIdData.java", - "ActorState.java", - "ActorTableData.java", "Arg.java", - "ClassTableData.java", - "ClientTableData.java", - "ConfigTableData.java", - "CustomSerializerData.java", - "DriverTableData.java", - "EntryType.java", - "ErrorTableData.java", - "ErrorType.java", - "FunctionTableData.java", - "GcsEntry.java", - "HeartbeatBatchTableData.java", - "HeartbeatTableData.java", "Language.java", - "ObjectTableData.java", "ProfileEvent.java", "ProfileTableData.java", "RayResource.java", - "ResourcePair.java", - "SchedulingState.java", - "TablePrefix.java", - "TablePubsub.java", "TaskInfo.java", - "TaskLeaseData.java", - "TaskReconstructionData.java", - "TaskTableData.java", - "TaskTableTestAndUpdate.java", ] flatbuffer_java_library( From 8b9b8caac80ba0d4c01c1f97ac5532080cc7a64e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 18:35:43 +0800 Subject: [PATCH 12/34] refine redis module --- src/ray/gcs/redis_module/ray_redis_module.cc | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 1bb258cd5a14..6d4b0f50ac84 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -170,20 +170,6 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// This is a helper method to convert a redis module string to a flatbuffer -/// string. -/// -/// \param fbb The flatbuffer builder. -/// \param redis_string The redis string. -/// \return The flatbuffer string. -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); -} - inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, const std::vector &entries, GcsEntry *result) { @@ -590,8 +576,6 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, gcs_entry.ParseFromArray(update_data_buf, update_data_len); *change_mode = gcs_entry.change_mode(); - auto data_vec = flatbuffers::GetRoot(update_data_buf); - *change_mode = data_vec->change_mode(); if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. size_t total_size = gcs_entry.entries_size(); @@ -715,7 +699,6 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, return Status::RedisError("Empty list/set/hash or wrong type"); } CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); - std::vector data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; From 8062a045eecb8e0488f14802cffed3c8ee58801d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 19:02:24 +0800 Subject: [PATCH 13/34] small updates --- BUILD.bazel | 2 +- java/BUILD.bazel | 2 +- src/ray/gcs/tables.cc | 96 +++++++++++++++------------------- src/ray/gcs/util.h | 11 ---- src/ray/raylet/node_manager.cc | 1 - 5 files changed, 43 insertions(+), 69 deletions(-) delete mode 100644 src/ray/gcs/util.h diff --git a/BUILD.bazel b/BUILD.bazel index cd6e44f82da4..26341c497a72 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -589,8 +589,8 @@ flatbuffer_py_library( "Language.py", "ProfileEvent.py", "ProfileTableData.py", - "RayResource.py", "TaskInfo.py", + "ResourcePair.py" ], out_prefix = "python/ray/core/generated/", ) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 9fc48c70feb7..5c1bca1580b5 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -153,8 +153,8 @@ flatbuffers_generated_files = [ "Language.java", "ProfileEvent.java", "ProfileTableData.java", - "RayResource.java", "TaskInfo.java", + "ResourcePair.java", ] flatbuffer_java_library( diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index b6cb3cf6514b..b7c19ebfd595 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -51,8 +51,7 @@ Status Log::Append(const DriverID &driver_id, const ID &id, (done)(client_, id, *data); } }; - std::string str; - data->SerializeToString(&str); + std::string str = data->SerializeAsString(); return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), str.length(), prefix_, pubsub_channel_, std::move(callback)); @@ -75,8 +74,7 @@ Status Log::AppendAt(const DriverID &driver_id, const ID &id, } } }; - std::string str; - data->SerializeToString(&str); + std::string str = data->SerializeAsString(); return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), str.length(), prefix_, pubsub_channel_, std::move(callback), log_length); @@ -236,8 +234,7 @@ Status Table::Add(const DriverID &driver_id, const ID &id, (done)(client_, id, *data); } }; - std::string str; - data->SerializeToString(&str); + std::string str = data->SerializeAsString(); return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), str.length(), prefix_, pubsub_channel_, std::move(callback)); @@ -300,8 +297,7 @@ Status Set::Add(const DriverID &driver_id, const ID &id, (done)(client_, id, *data); } }; - std::string str; - data->SerializeToString(&str); + std::string str = data->SerializeAsString(); return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), prefix_, pubsub_channel_, std::move(callback)); } @@ -315,8 +311,7 @@ Status Set::Remove(const DriverID &driver_id, const ID &id, (done)(client_, id, *data); } }; - std::string str; - data->SerializeToString(&str); + std::string str = data->SerializeAsString(); return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), prefix_, pubsub_channel_, std::move(callback)); } @@ -392,18 +387,14 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, gcs_entry.ParseFromString(reply.ReadAsString()); RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); RAY_CHECK(gcs_entry.entries_size() % 2 == 0); - DataMap result; for (int i = 0; i < gcs_entry.entries_size(); i += 2) { const auto &key = gcs_entry.entries(i); const auto value = std::make_shared(); value->ParseFromString(gcs_entry.entries(i + 1)); - result.emplace(key, std::move(value)); + results.emplace(key, std::move(value)); } - - lookup(client_, id, result); - } else { - lookup(client_, id, {}); } + lookup(client_, id, results); } }; std::vector nil; @@ -474,9 +465,9 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { + // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. auto data = std::make_shared(); data->CopyFrom(profile_events); - // XXX return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -507,8 +498,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && - entry.second.entry_type() == ClientTableData::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -552,9 +542,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // was alive and is now dead or resources have been updated. bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); bool is_deleted = (data.entry_type() == ClientTableData::DELETION); - bool is_res_modified = - ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == ClientTableData::RES_DELETE)); + bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check @@ -572,14 +561,14 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". ClientTableData: " << int(data.entry_type()) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". ClientTableData: " << int(data.entry_type()) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; ClientTableData &cache_data = client_cache_[client_id]; @@ -597,8 +586,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, existing_resource_label); // Resource already exists, set capacity if updation call.. if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { - // cache_data.mutable_resources_total_capacity()[index] = capacity; - // XXX:639 + cache_data.set_resources_total_capacity(index, capacity); } // .. delete if deletion call. else if (data.entry_type() == ClientTableData::RES_DELETE) { @@ -646,8 +634,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, - const ClientTableData &data) { +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; @@ -678,32 +665,32 @@ Status ClientTable::Connect(const ClientTableData &local_client) { HandleConnected(client, data); // Callback for a notification from the client table. - auto notification_callback = - [this](AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { - RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; - for (auto ¬ification : notifications) { - // This is temporary fix for Issue 4140 to avoid connect to dead nodes. - // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type() != ClientTableData::DELETION) { - connected_nodes.emplace(notification.client_id(), notification); - } else { - auto iter = connected_nodes.find(notification.client_id()); - if (iter != connected_nodes.end()) { - connected_nodes.erase(iter); - } - disconnected_nodes.emplace(notification.client_id(), notification); - } - } - for (const auto &pair : connected_nodes) { - HandleNotification(client, pair.second); - } - for (const auto &pair : disconnected_nodes) { - HandleNotification(client, pair.second); + auto notification_callback = [this]( + AsyncGcsClient *client, const UniqueID &log_key, + const std::vector ¬ifications) { + RAY_CHECK(log_key == client_log_key_); + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; + for (auto ¬ification : notifications) { + // This is temporary fix for Issue 4140 to avoid connect to dead nodes. + // TODO(yuhguo): remove this temporary fix after GCS entry is removable. + if (notification.entry_type() != ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); + } else { + auto iter = connected_nodes.find(notification.client_id()); + if (iter != connected_nodes.end()) { + connected_nodes.erase(iter); } - }; + disconnected_nodes.emplace(notification.client_id(), notification); + } + } + for (const auto &pair : connected_nodes) { + HandleNotification(client, pair.second); + } + for (const auto &pair : disconnected_nodes) { + HandleNotification(client, pair.second); + } + }; // Callback to request notifications from the client table once we've // successfully subscribed. auto subscription_callback = [this](AsyncGcsClient *c) { @@ -751,8 +738,7 @@ void ClientTable::GetClient(const ClientID &client_id, } } -const std::unordered_map &ClientTable::GetAllClients() - const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } diff --git a/src/ray/gcs/util.h b/src/ray/gcs/util.h deleted file mode 100644 index b536463f1190..000000000000 --- a/src/ray/gcs/util.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef ANT_RAY_SRC_RAY_GCS_UTIL_H_ -#define ANT_RAY_SRC_RAY_GCS_UTIL_H_ - -namespace ray { -namespace gcs { - -}; - -}; // namespace ray - -#endif // ANT_RAY_SRC_RAY_GCS_UTIL_H_ diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e697df520939..ced4b3aac560 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -336,7 +336,6 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); if (profile_info.profile_events_size() > 0) { - flatbuffers::FlatBufferBuilder fbb; RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } From 8fa6c659d1f5ab7c5cbd6b7c20b4a125666f76ce Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 21:13:29 +0800 Subject: [PATCH 14/34] java compilable --- BUILD.bazel | 1 + bazel/ray_deps_build_all.bzl | 2 + bazel/ray_deps_setup.bzl | 7 ++ java/BUILD.bazel | 69 +++++++++++-------- java/dependencies.bzl | 1 + ...modify_generated_java_flatbuffers_files.py | 16 ++--- java/runtime/pom.xml | 5 ++ .../java/org/ray/runtime/gcs/GcsClient.java | 63 ++++++++++------- src/ray/protobuf/gcs.proto | 9 +-- 9 files changed, 103 insertions(+), 70 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 26341c497a72..3696e137d5dc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -11,6 +11,7 @@ COPTS = ["-DRAY_USE_GLOG"] proto_library( name = "gcs_proto", srcs = ["src/ray/protobuf/gcs.proto"], + visibility = ["//java:__subpackages__"], ) cc_proto_library( diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 3e1e1838a59a..d77f62549d53 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,6 +4,7 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") def ray_deps_build_all(): @@ -13,4 +14,5 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() + java_proto_compile() diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 1d0a4e7632f4..aa322654cf9f 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -109,3 +109,10 @@ def ray_deps_setup(): ], strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", ) + + http_archive( + name = "build_stack_rules_proto", + urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], + sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", + strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", + ) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 5c1bca1580b5..114ae144f1bd 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,5 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -48,9 +49,6 @@ define_java_module( define_java_module( name = "runtime", - additional_srcs = [ - ":generate_java_gcs_fbs", - ], additional_resources = [ ":java_native_deps", ], @@ -65,9 +63,11 @@ define_java_module( ], deps = [ ":org_ray_ray_api", + ":copy_auto_generated_files", "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -148,11 +148,14 @@ java_binary( ], ) +java_proto_compile( + name = "gcs_java_proto", + deps = ["@//:gcs_proto"], +) + flatbuffers_generated_files = [ "Arg.java", "Language.java", - "ProfileEvent.java", - "ProfileTableData.java", "TaskInfo.java", "ResourcePair.java", ] @@ -164,22 +167,6 @@ flatbuffer_java_library( out_prefix = "", ) -genrule( - name = "generate_java_gcs_fbs", - srcs = [":java_gcs_fbs"], - outs = [ - "runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files - ], - cmd = """ - for f in $(locations //java:java_gcs_fbs); do - chmod +w $$f - cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated - done - python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. - """, - local = 1, -) - filegroup( name = "java_native_deps", srcs = [ @@ -192,12 +179,41 @@ filegroup( ], ) +genrule( + name = "copy_auto_generated_files", + srcs = [ + ":gcs_java_proto", + ":java_gcs_fbs", + ], + outs = ["copy_auto_generated_files.out"], + cmd = """ + set -x + WORK_DIR=$$(pwd) + # Copy protobuf-generated files. + GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated + rm -rf $$GENERATED_DIR + mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files + for f in $(locations //java:java_gcs_fbs); do + chmod +w $$f + cp $$f $$GENERATED_DIR + done + echo $$(date) > $@ + """, + local = 1, + tags = ["no-cache"], +) + # Generates the depedencies needed by maven. genrule( name = "gen_maven_deps", srcs = [ + ":copy_auto_generated_files", ":java_native_deps", - ":generate_java_gcs_fbs", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -212,19 +228,14 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy flatbuffers-generated files - GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated - rm -rf $$GENERATED_DIR - mkdir -p $$GENERATED_DIR - for f in $(locations //java:generate_java_gcs_fbs); do - cp $$f $$GENERATED_DIR - done + python $$WORK_DIR/java/modify_generated_java_flatbuffers_files.py $$WORK_DIR # Install plasma jar to local maven repo. mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \ -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT echo $$(date) > $@ """, local = 1, + tags = ["no-cache"], ) genrule( diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 7c716166d399..ef667137562b 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,6 +6,7 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", + "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index c1b723f25f8d..fa1867bde8a7 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -21,19 +21,18 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_new_line(file, line_num, text): +def add_package(file): with open(file, "r") as file_handler: lines = file_handler.readlines() - if (line_num <= 0) or (line_num > len(lines) + 1): - return False - lines.insert(line_num - 1, text + os.linesep) + if "FlatBuffers" not in lines[0]: + return + + lines.insert(1, PACKAGE_DECLARATION + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) - return True - def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -41,10 +40,7 @@ def add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - success = add_new_line(full_name, 2, PACKAGE_DECLARATION) - if not success: - raise RuntimeError("Failed to add package declarations, " - "file name is %s" % full_name) + add_package(full_name) if __name__ == "__main__": diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c75e2eeef13f..e13dd95f927f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,6 +41,11 @@ guava 27.0.1-jre + + com.google.protobuf + protobuf-java + 3.8.0 + com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 431b48ded58c..7bfccddf550d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; +import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -15,7 +15,7 @@ import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.EntryType; +import org.ray.runtime.generated.ClientTableData.EntryType; import org.ray.runtime.generated.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.name(TablePrefix.CLIENT); + final String prefix = TablePrefix.CLIENT.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,36 +63,42 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); - final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); + ClientTableData data = null; + try { + data = ClientTableData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invaild protobuf data from GCS."); + } + final UniqueId clientId = UniqueId + .fromByteBuffer(data.getClientIdBytes().asReadOnlyByteBuffer()); - if (data.entryType() == EntryType.INSERTION) { + if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. Preconditions.checkState( - data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.nodeManagerAddress(), true, resources); + clientId, data.getNodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE) { + } else if (data.getEntryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); } } else { // Code path of node deletion. - Preconditions.checkState(data.entryType() == EntryType.DELETION); + Preconditions.checkState(data.getEntryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -107,7 +113,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); + TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. */ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = - ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( - data.checkpointIdsAsByteBuffer()); + ActorCheckpointIdData data = null; + try { + data = ActorCheckpointIdData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invaild protobuf data from GCS."); + } + UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; + for (int i = 0; i < checkpointIds.length; i++) { + checkpointIds[i] = UniqueId + .fromByteBuffer(data.getCheckpointIdsBytes(i).asReadOnlyByteBuffer()); + } for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 66a7e51d5f16..3e5a82b437b8 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -2,6 +2,9 @@ syntax = "proto3"; package ray.rpc; +option java_multiple_files = true; +option java_package = "org.ray.runtime.generated"; + enum Language { PYTHON = 0; CPP = 1; @@ -60,12 +63,6 @@ message GcsEntry { repeated string entries = 3; } -message FunctionTableData { - Language language = 1; - string name = 2; - string data = 3; -} - message ObjectTableData { // The size of the object. uint64 object_size = 1; From 064d5320acea7eb1c337c89b277014015ed6d7d9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 21 Jun 2019 21:24:19 +0800 Subject: [PATCH 15/34] change string to bytes --- src/ray/protobuf/gcs.proto | 42 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 3e5a82b437b8..ff1db2115fbf 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -59,26 +59,26 @@ enum GcsChangeMode { message GcsEntry { GcsChangeMode change_mode = 1; - string id = 2; - repeated string entries = 3; + bytes id = 2; + repeated bytes entries = 3; } message ObjectTableData { // The size of the object. uint64 object_size = 1; // The node manager ID that this object appeared on or was evicted by. - string manager = 2; + bytes manager = 2; } message TaskReconstructionData { // The number of times this task has been reconstructed so far. uint64 num_reconstructions = 1; // The node manager that is trying to reconstruct the task. - string node_manager_id = 2; + bytes node_manager_id = 2; } message TaskTableData { - string task = 1; + bytes task = 1; } message ActorTableData { @@ -92,15 +92,15 @@ message ActorTableData { DEAD = 2; } // The ID of the actor that was created. - string actor_id = 1; + bytes actor_id = 1; // The dummy object ID returned by the actor creation task. If the actor // dies, then this is the object that should be reconstructed for the actor // to be recreated. - string actor_creation_dummy_object_id = 2; + bytes actor_creation_dummy_object_id = 2; // The ID of the driver that created the actor. - string driver_id = 3; + bytes driver_id = 3; // The ID of the node manager that created the actor. - string node_manager_id = 4; + bytes node_manager_id = 4; // Current state of this actor. ActorState state = 5; // Max number of times this actor should be reconstructed. @@ -111,7 +111,7 @@ message ActorTableData { message ErrorTableData { // The ID of the driver that the error is for. - string driver_id = 1; + bytes driver_id = 1; // The type of the error. string type = 2; // The error message. @@ -164,7 +164,7 @@ message ClientTableData { } // The client ID of the client that the message is about. - string client_id = 1; + bytes client_id = 1; // The IP address of the client's node manager. string node_manager_address = 2; // The IPC socket name of the client's raylet. @@ -186,7 +186,7 @@ message ClientTableData { message HeartbeatTableData { // Node manager client id - string client_id = 1; + bytes client_id = 1; // Resource capacity currently available on this node manager. repeated string resources_available_label = 2; repeated double resources_available_capacity = 3; @@ -205,7 +205,7 @@ message HeartbeatBatchTableData { // Data for a lease on task execution. message TaskLeaseData { // Node manager client ID. - string node_manager_id = 1; + bytes node_manager_id = 1; // The time that the lease was last acquired at. NOTE(swang): This is the // system clock time according to the node that added the entry and is not // synchronized with other nodes. @@ -216,7 +216,7 @@ message TaskLeaseData { message DriverTableData { // The driver ID. - string driver_id = 1; + bytes driver_id = 1; // Whether it's dead. bool is_dead = 2; } @@ -226,17 +226,17 @@ message DriverTableData { // See `actor_registration.h` for more detailed explanation of these fields. message ActorCheckpointData { // ID of this actor. - string actor_id = 1; + bytes actor_id = 1; // The dummy object ID of actor's most recently executed task. - string execution_dependency = 2; + bytes execution_dependency = 2; // A list of IDs of this actor's handles. - repeated string handle_ids = 3; + repeated bytes handle_ids = 3; // The task counters of the above handles. repeated uint64 task_counters = 4; // The frontier dependencies of the above handles. - repeated string frontier_dependencies = 5; + repeated bytes frontier_dependencies = 5; // A list of unreleased dummy objects from this actor. - repeated string unreleased_dummy_objects = 6; + repeated bytes unreleased_dummy_objects = 6; // The numbers of dependencies for the above unreleased dummy objects. repeated uint32 num_dummy_object_dependencies = 7; } @@ -244,9 +244,9 @@ message ActorCheckpointData { // This table stores the actor-to-available-checkpoint-ids mapping. message ActorCheckpointIdData { // ID of this actor. - string actor_id = 1; + bytes actor_id = 1; // IDs of this actor's available checkpoints. - repeated string checkpoint_ids = 2; + repeated bytes checkpoint_ids = 2; // A list of the timestamps for each of the above `checkpoint_ids`. repeated uint64 timestamps = 3; } From 0dd6a1aafd08aab2803aad119a5bb68c1eec843d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Sat, 22 Jun 2019 17:37:38 +0800 Subject: [PATCH 16/34] fix --- java/BUILD.bazel | 3 ++- .../src/main/java/org/ray/runtime/gcs/GcsClient.java | 4 ++-- .../org/ray/runtime/objectstore/ObjectStoreProxy.java | 10 +++++----- src/ray/raylet/node_manager.cc | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 114ae144f1bd..5034171db93a 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -193,6 +193,7 @@ genrule( GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + # TODO(hchen): Only copy files needed by Java. for f in $(locations //java:gcs_java_proto); do unzip $$f mv org/ray/runtime/generated/* $$GENERATED_DIR @@ -202,6 +203,7 @@ genrule( chmod +w $$f cp $$f $$GENERATED_DIR done + python $$WORK_DIR/java/modify_generated_java_flatbuffers_files.py $$WORK_DIR echo $$(date) > $@ """, local = 1, @@ -228,7 +230,6 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - python $$WORK_DIR/java/modify_generated_java_flatbuffers_files.py $$WORK_DIR # Install plasma jar to local maven repo. mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \ -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 7bfccddf550d..3830a05bc26d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -70,7 +70,7 @@ public List getAllNodeInfo() { throw new RuntimeException("Received invaild protobuf data from GCS."); } final UniqueId clientId = UniqueId - .fromByteBuffer(data.getClientIdBytes().asReadOnlyByteBuffer()); + .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. @@ -147,7 +147,7 @@ public List getCheckpointsForActor(UniqueId actorId) { UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; for (int i = 0; i < checkpointIds.length; i++) { checkpointIds[i] = UniqueId - .fromByteBuffer(data.getCheckpointIdsBytes(i).asReadOnlyByteBuffer()); + .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); } for (int i = 0; i < checkpointIds.length; i++) { diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index f9e310249a35..74e940674eda 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) - .getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) - .getBytes(); + private static final byte[] WORKER_EXCEPTION_META = String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index ced4b3aac560..48aea5ae7422 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2010,7 +2010,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - auto message = flatbuffers::GetRoot(task_data.task().data()); + auto message = flatbuffers::GetRoot(task_data.task().data()); const Task task(*message); ResubmitTask(task); }, From 9c56a34f432aba40c5125ee8316db63efb2f8126 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 12:48:44 +0800 Subject: [PATCH 17/34] python gcs proto --- BUILD.bazel | 10 ++++++ bazel/ray_deps_build_all.bzl | 2 ++ python/ray/gcs_utils.py | 55 +++++++++++++------------------ python/ray/tests/cluster_utils.py | 4 +-- python/ray/worker.py | 2 +- 5 files changed, 37 insertions(+), 36 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 3696e137d5dc..8ebd56d1ec8c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -2,6 +2,7 @@ # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") @@ -19,6 +20,11 @@ cc_proto_library( deps = [":gcs_proto"], ) +python_proto_compile( + name = "gcs_py_proto", + deps = [":gcs_proto"], +) + proto_library( name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], @@ -691,6 +697,7 @@ genrule( "python/ray/_raylet.so", "//:python_sources", "//:python_gcs_fbs", + "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -717,6 +724,9 @@ genrule( for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && + for f in $(locations //:gcs_py_proto); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/; + done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index d77f62549d53..eda88bece7d2 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -5,6 +5,7 @@ load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configur load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") +load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -15,4 +16,5 @@ def ray_deps_build_all(): python_configure(name = "local_config_python") grpc_deps() java_proto_compile() + python_proto_compile() diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index cadd197ec73f..2aff1ba2ba8c 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -3,23 +3,24 @@ from __future__ import print_function import flatbuffers -import ray.core.generated.ErrorTableData -from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData -from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.DriverTableData import DriverTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsEntry import GcsEntry -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData from ray.core.generated.Language import Language -from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.TablePrefix import TablePrefix -from ray.core.generated.TablePubsub import TablePubsub - from ray.core.generated.ray.protocol.Task import Task +from ray.core.generated.gcs_pb2 import ( + ActorCheckpointIdData, + ClientTableData, + DriverTableData, + ErrorTableData, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, +) + __all__ = [ "ActorCheckpointIdData", "ClientTableData", @@ -48,7 +49,8 @@ # xray driver updates XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# These prefixes must be kept up-to-date with the TablePrefix enum in +# gcs.proto. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -70,22 +72,9 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. """ - builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id.binary()) - error_type_offset = builder.CreateString(error_type) - message_offset = builder.CreateString(message) - - ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) - ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( - builder, driver_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddType( - builder, error_type_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( - builder, message_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( - builder, timestamp) - error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( - builder) - builder.Finish(error_data_offset) - - return bytes(builder.Output()) + data = ErrorTableData() + data.driver_id = driver_id.binary() + data.error_type = error_type + data.message = message + data.timestamp = timestamp + return data.SerializeToString() diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 703c3a1420ed..ea1b3164d457 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.core.generated.EntryType import EntryType +from ray.gcs_utils import ClientTableData logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == EntryType.INSERTION + if client["EntryType"] == ClientTableData.EntryType.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/worker.py b/python/ray/worker.py index 7505120574a6..cdb28a7395c7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.core.generated.ErrorType import ErrorType +from ray.gcs_utils import ErrorType from ray.exceptions import ( RayActorError, RayError, From e3efe99896cd4fd52a9ec86fdb6c26da2054677f Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 14:49:11 +0800 Subject: [PATCH 18/34] fix profile table data --- python/ray/gcs_utils.py | 8 ++-- python/ray/state.py | 14 +++--- python/ray/tests/test_basic.py | 8 ++-- python/ray/tests/test_failure.py | 2 +- python/ray/utils.py | 4 +- python/ray/worker.py | 6 +-- src/ray/object_manager/object_manager.cc | 10 ++--- src/ray/object_manager/object_manager.h | 7 +-- src/ray/raylet/node_manager.cc | 57 ++++++++++++++---------- 9 files changed, 63 insertions(+), 53 deletions(-) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 2aff1ba2ba8c..7aa4577200d1 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -12,6 +12,7 @@ ClientTableData, DriverTableData, ErrorTableData, + ErrorType, GcsEntry, HeartbeatBatchTableData, HeartbeatTableData, @@ -26,6 +27,7 @@ "ClientTableData", "DriverTableData", "ErrorTableData", + "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", @@ -43,11 +45,11 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.Value('HEARTBEAT_PUBSUB')).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB')).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value('DRIVER_PUBSUB')).encode("ascii") # These prefixes must be kept up-to-date with the TablePrefix enum in # gcs.proto. diff --git a/python/ray/state.py b/python/ray/state.py index 14ba49987ec4..af776d9017e5 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -32,7 +32,7 @@ def _parse_client_table(redis_client): """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.CLIENT, + ray.gcs_utils.TablePrefix.Value('CLIENT'), "", NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only @@ -41,7 +41,7 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.ParseFromStrng(message) ordered_client_ids = [] @@ -244,7 +244,7 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, "", + ray.gcs_utils.TablePrefix.Value('OBJECT'), "", object_id.binary()) if message is None: return {} @@ -302,7 +302,7 @@ def _task_table(self, task_id): """ assert isinstance(task_id, ray.TaskID) message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePrefix.Value('RAYLET_TASK'), "", task_id.binary()) if message is None: return {} @@ -423,7 +423,7 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.PROFILE, "", + ray.gcs_utils.TablePrefix.Value('PROFILE'), "", batch_id.binary()) if message is None: @@ -860,7 +860,7 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", + "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), "", driver_id.binary()) # If there are no errors, return early. @@ -923,7 +923,7 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + ray.gcs_utils.TablePrefix.Value('ACTOR_CHECKPOINT_ID'), "", actor_id.binary(), ) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7f1f78d1b5c4..eb5920ebfeac 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,14 +2736,14 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), + r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), + ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), + r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), + ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), driver_id.binary(), error_data) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 51b906695c2d..9f25f97c577d 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -494,7 +494,7 @@ def test_warning_monitor_died(shutdown_only): redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) + ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH_PUBSUB, fake_id, malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/utils.py b/python/ray/utils.py index 7b87486e325e..27ea4424732b 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -94,8 +94,8 @@ def push_error_to_driver_through_redis(redis_client, error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, + ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), + ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), driver_id.binary(), error_data) diff --git a/python/ray/worker.py b/python/ray/worker.py index cdb28a7395c7..01748d886b11 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB')).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1882,8 +1882,8 @@ def connect(node, # Add the driver task to the task table. ray.state.state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, + ray.gcs_utils.TablePrefix.Value('RAYLET_TASK'), + ray.gcs_utils.TablePubsub.Value('RAYLET_TASK_PUBSUB'), driver_task.task_id().binary(), driver_task._serialized_raylet_task()) diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index a8b6fbaa70a7..39a4326eac99 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,7 +309,7 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEvent profile_event; + rpc::ProfileTableData::ProfileEvent profile_event; profile_event.set_event_type("transfer_send"); profile_event.set_start_time(start_time); profile_event.set_end_time(end_time); @@ -329,7 +329,7 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEvent profile_event; + rpc::ProfileTableData::ProfileEvent profile_event; profile_event.set_event_type("transfer_receive"); profile_event.set_start_time(start_time); profile_event.set_end_time(end_time); @@ -801,7 +801,7 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEvent profile_event; + rpc::ProfileTableData::ProfileEvent profile_event; profile_event.set_event_type("receive_pull_request"); profile_event.set_start_time(current_sys_time_seconds()); profile_event.set_end_time(profile_event.start_time()); @@ -938,8 +938,8 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableData ObjectManager::GetAndResetProfilingInfo() { - ProfileTableData profile_info; +rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + rpc::ProfileTableData profile_info; profile_info.set_component_type("object_manager"); profile_info.set_component_id(client_id_.Binary()); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 1f2c2c8f78f0..6664dd0a93bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -28,9 +28,6 @@ namespace ray { -using rpc::ProfileTableData; -using ProfileEvent = rpc::ProfileTableData::ProfileEvent; - struct ObjectManagerConfig { /// The port that the object manager should use to listen for connections /// from other object managers. If this is 0, the object manager will choose @@ -183,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableData GetAndResetProfilingInfo(); + rpc::ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. /// @@ -415,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 48aea5ae7422..808eeb6fd211 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -253,8 +253,8 @@ void NodeManager::KillWorker(std::shared_ptr worker) { void NodeManager::HandleDriverTableUpdate( const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id()) - << " " << entry.is_dead(); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " + << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); if (entry.is_dead()) { auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); @@ -377,8 +377,9 @@ void NodeManager::ClientAdded(const ClientTableData &client_data) { client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet resources_total( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } @@ -437,8 +438,9 @@ void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". Updating resource map."; - ResourceSet new_res_set(rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -471,8 +473,9 @@ void NodeManager::ResourceDeleted(const ClientTableData &client_data) { const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set(rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -593,7 +596,8 @@ void NodeManager::PublishActorStateTransition( auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state() == ActorTableData::DEAD || data.state() == ActorTableData::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -630,7 +634,8 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << ActorTableData::ActorState_Name(actor_registration.GetState()) + << ", state = " + << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); @@ -791,12 +796,14 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto fbs_message = flatbuffers::GetRoot(message_data); - ProfileTableData profile_table_data; - profile_table_data.set_component_type(fbs_message->component_type); - profile_table_data.set_component_id(fbs_message->component_id); - for (const auto &fbs_event : fbs_message->profile_events) { - ProfileEvent *event = profile_table_data.add_profile_events(); + ProfileTableDataT fbs_message; + flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + rpc::ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message.component_type); + profile_table_data.set_component_id(fbs_message.component_id); + for (const auto &fbs_event : fbs_message.profile_events) { + rpc::ProfileTableData::ProfileEvent *event = + profile_table_data.add_profile_events(); event->set_event_type(fbs_event->event_type); event->set_start_time(fbs_event->start_time); event->set_end_time(fbs_event->end_time); @@ -1881,12 +1888,15 @@ ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &tas // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. new_actor_data.set_actor_id(actor_id.Binary()); - new_actor_data.set_actor_creation_dummy_object_id(task.GetTaskSpecification().ActorDummyObject().Binary()); + new_actor_data.set_actor_creation_dummy_object_id( + task.GetTaskSpecification().ActorDummyObject().Binary()); new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); - new_actor_data.set_max_reconstructions(task.GetTaskSpecification().MaxActorReconstructions()); + new_actor_data.set_max_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.set_remaining_reconstructions(task.GetTaskSpecification().MaxActorReconstructions()); + new_actor_data.set_remaining_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. @@ -1895,12 +1905,14 @@ ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &tas new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.set_remaining_reconstructions(new_actor_data.remaining_reconstructions() - 1); + new_actor_data.set_remaining_reconstructions( + new_actor_data.remaining_reconstructions() - 1); } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. - new_actor_data.set_node_manager_id(gcs_client_->client_table().GetLocalClientId().Binary()); + new_actor_data.set_node_manager_id( + gcs_client_->client_table().GetLocalClientId().Binary()); new_actor_data.set_state(ActorTableData::ALIVE); return new_actor_data; } @@ -1967,8 +1979,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { PublishActorStateTransition( actor_id, new_actor_data, /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { + [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); From 0dc661d62659ff4a8ea681b6c543cbe1e74c3888 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 16:36:27 +0800 Subject: [PATCH 19/34] fix state.py --- python/ray/gcs_utils.py | 8 +- python/ray/state.py | 223 ++++++++++++++++--------------------- python/ray/worker.py | 20 ++-- src/ray/protobuf/gcs.proto | 2 +- 4 files changed, 111 insertions(+), 142 deletions(-) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 7aa4577200d1..1475962714fe 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -20,6 +20,7 @@ ProfileTableData, TablePrefix, TablePubsub, + TaskTableData, ) __all__ = [ @@ -37,6 +38,7 @@ "TablePrefix", "TablePubsub", "Task", + "TaskTableData", "construct_error_message", ] @@ -45,8 +47,10 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.Value('HEARTBEAT_PUBSUB')).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB')).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str( + TablePubsub.Value('HEARTBEAT_PUBSUB')).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str( + TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB')).encode("ascii") # xray driver updates XRAY_DRIVER_CHANNEL = str(TablePubsub.Value('DRIVER_PUBSUB')).encode("ascii") diff --git a/python/ray/state.py b/python/ray/state.py index af776d9017e5..c0347e290b71 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,12 @@ import ray from ray.function_manager import FunctionDescriptor -import ray.gcs_utils from ray.ray_constants import ID_SIZE -from ray import services -from ray.core.generated.EntryType import EntryType +from ray import ( + gcs_utils, + services, +) from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +32,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.Value('CLIENT'), - "", NIL_CLIENT_ID) + message = redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value('CLIENT'), "", + NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,36 +42,31 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.ParseFromStrng(message) + gcs_entry = gcs_utils.GcsEntry.FromString(message) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) + for entry in gcs_entry.entries: + client = gcs_utils.ClientTableData.FromString(entry) resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) + client.resources_total_label[i]: client.resources_total_capacity[i] + for i in range(len(client.resources_total_label)) } - client_id = ray.utils.binary_to_hex(client.ClientId()) + client_id = ray.utils.binary_to_hex(client.client_id) - if client.EntryType() == EntryType.INSERTION: + if client.entry_type == gcs_utils.ClientTableData.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.EntryType(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), + "EntryType": client.entry_type, + "NodeManagerAddress": client.node_manager_address, + "NodeManagerPort": client.node_manager_port, + "ObjectManagerPort": client.object_manager_port, + "ObjectStoreSocketName": client.object_store_socket_name, + "RayletSocketName": client.raylet_socket_name, "Resources": resources } @@ -79,22 +75,22 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( + assert node_info[client_id]["EntryType"] != gcs_utils.ClientTableData.DELETION, ( "Unexpected updation of deleted client.") res_map = node_info[client_id]["Resources"] - if client.EntryType() == EntryType.RES_CREATEUPDATE: + if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.EntryType() == EntryType.RES_DELETE: + elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.EntryType() == EntryType.DELETION: + elif client.entry_type == gcs_utils.ClientTableData.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.EntryType())) + client.entry_type)) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.EntryType() + node_info[client_id]["EntryType"] = client.entry_type # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -244,20 +240,19 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.Value('OBJECT'), "", + gcs_utils.TablePrefix.Value('OBJECT'), "", object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) - assert gcs_entry.EntriesLength() > 0 + assert len(gcs_entry.entries) > 0 - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(0), 0) + entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), + "DataSize": entry.object_size, + "Manager": entry.manager, } return object_info @@ -278,10 +273,10 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -301,17 +296,21 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. """ assert isinstance(task_id, ray.TaskID) - message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.Value('RAYLET_TASK'), - "", task_id.binary()) + message = self._execute_command( + task_id, "RAY.TABLE_LOOKUP", + gcs_utils.TablePrefix.Value('RAYLET_TASK'), "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - - assert gcs_entries.EntriesLength() == 1 + gcs_entries = gcs_utils.GcsEntry.FromString(message) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) + assert len(gcs_entries.entries) == 1 + print(task_id) + print(len(gcs_entries.entries[0])) + task_table_data = gcs_utils.TaskTableData.FromString( + gcs_entries.entries[0]) + print(task_table_data.task) + task_table_message = gcs_utils.Task.GetRootAsTask( + task_table_data.task, 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -368,11 +367,12 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] + print([ray.TaskID(id) for id in task_ids_binary]) results = {} for task_id_binary in task_ids_binary: @@ -380,27 +380,6 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results - def function_table(self, function_id=None): - """Fetch and parse the function table. - - Returns: - A dictionary that maps function IDs to information about the - function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys( - ray.gcs_utils.FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -423,37 +402,32 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.Value('PROFILE'), "", + gcs_utils.TablePrefix.Value('PROFILE'), "", batch_id.binary()) if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) profile_events = [] - for i in range(gcs_entries.EntriesLength()): - profile_table_message = ( - ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( - gcs_entries.Entries(i), 0)) - - component_type = decode(profile_table_message.ComponentType()) - component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode( - profile_table_message.NodeIpAddress(), allow_none=True) + for entry in gcs_entries.entries: + profile_table_message = gcs_utils.ProfileTableData.FromString( + entry) - for j in range(profile_table_message.ProfileEventsLength()): - profile_event_message = profile_table_message.ProfileEvents(j) + component_type = profile_table_message.component_type + component_id = binary_to_hex(profile_table_message.component_id) + node_ip_address = profile_table_message.node_ip_address + for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": decode(profile_event_message.EventType()), + "event_type": profile_event_message.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.StartTime(), - "end_time": profile_event_message.EndTime(), - "extra_data": json.loads( - decode(profile_event_message.ExtraData())), + "start_time": profile_event_message.start_time, + "end_time": profile_event_message.end_time, + "extra_data": json.loads(profile_event_message.extra_data), } profile_events.append(profile_event) @@ -463,9 +437,9 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() profile_table_keys = self._keys( - ray.gcs_utils.TablePrefix_PROFILE_string + "*") + gcs_utils.TablePrefix_PROFILE_string + "*") batch_identifiers_binary = [ - key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -766,7 +740,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. - if client["EntryType"] != EntryType.DELETION: + if client["EntryType"] != gcs_utils.ClientTableData.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -776,7 +750,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != EntryType.DELETION) + if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) } def available_resources(self): @@ -800,7 +774,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -809,24 +783,23 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. - GetRootAsHeartbeatTableData(heartbeat_data, 0)) + gcs_entries = gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] + message = gcs_utils.HeartbeatTableData.FromString( + heartbeat_data) # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() + num_resources = len(message.resources_available_label) dynamic_resources = {} for i in range(num_resources): - resource_id = decode(message.ResourcesAvailableLabel(i)) + resource_id = message.resources_available_label[i] dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) + message.resources_available_capacity[i]) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) + client_id = ray.utils.binary_to_hex(message.client_id) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -860,23 +833,22 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value('ERROR_INFO'), "", driver_id.binary()) # If there are no errors, return early. if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) error_messages = [] - for i in range(gcs_entries.EntriesLength()): - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entries.Entries(i), 0) - assert driver_id.binary() == error_data.DriverId() + for entry in gcs_entries.entries: + error_data = gcs_utils.ErrorTableData.FromString(entry) + assert driver_id.binary() == error_data.driver_id error_message = { - "type": decode(error_data.Type()), - "message": decode(error_data.ErrorMessage()), - "timestamp": error_data.Timestamp(), + "type": error_data.type, + "message": error_data.error_message, + "timestamp": error_data.timestamp, } error_messages.append(error_message) return error_messages @@ -899,9 +871,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") + gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -923,30 +895,23 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.Value('ACTOR_CHECKPOINT_ID'), + gcs_utils.TablePrefix.Value('ACTOR_CHECKPOINT_ID'), "", actor_id.binary(), ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - entry = ( - ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( - gcs_entry.Entries(0), 0)) - checkpoint_ids_str = entry.CheckpointIds() - num_checkpoints = len(checkpoint_ids_str) // ID_SIZE - assert len(checkpoint_ids_str) % ID_SIZE == 0 + gcs_entry = gcs_utils.GcsEntry.FromString(message) + entry = gcs_utils.ActorCheckpointIdData.FromString( + gcs_entry.entries[0]) checkpoint_ids = [ - ray.ActorCheckpointID( - checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) - for i in range(num_checkpoints) + ray.ActorCheckpointID(checkpoint_id) + for checkpoint_id in entry.checkpoint_ids ] return { - "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": [ - entry.Timestamps(i) for i in range(num_checkpoints) - ], + "Timestamps": entry.timestamps, } diff --git a/python/ray/worker.py b/python/ray/worker.py index 01748d886b11..8f0e9f24f301 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - msg["data"], 0) - assert gcs_entry.EntriesLength() == 1 - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entry.Entries(0), 0) - driver_id = error_data.DriverId() + gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) + assert len(gcs_entry.entries) == 1 + error_data = ray.gcs_utils.ErrorTableData.FromString( + gcs_entry.entries[0]) + driver_id = error_data.driver_id if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = ray.utils.decode(error_data.ErrorMessage()) - if (ray.utils.decode( - error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + error_message = error_data.error_message + if (error_data.type == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1878,6 +1876,8 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) + task_table_data = ray.gcs_utils.TaskTableData() + task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. ray.state.state._execute_command(driver_task.task_id(), @@ -1885,7 +1885,7 @@ def connect(node, ray.gcs_utils.TablePrefix.Value('RAYLET_TASK'), ray.gcs_utils.TablePubsub.Value('RAYLET_TASK_PUBSUB'), driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index ff1db2115fbf..fbb08ade4415 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -138,7 +138,7 @@ message ProfileTableData { // object_manager, or node_manager. string component_type = 1; // An identifier for the component that generated the event. - string component_id = 2; + bytes component_id = 2; // An identifier for the node that generated the event. string node_ip_address = 3; // This is a batch of profiling events. We batch these together for From a802ea8eafd18964a88cfe0c7cd7c02642ca097d Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 16:55:47 +0800 Subject: [PATCH 20/34] fix --- python/ray/gcs_utils.py | 4 ++-- python/ray/tests/cluster_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 1475962714fe..f1e69b2f2669 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -80,7 +80,7 @@ def construct_error_message(driver_id, error_type, message, timestamp): """ data = ErrorTableData() data.driver_id = driver_id.binary() - data.error_type = error_type - data.message = message + data.type = error_type + data.error_message = message data.timestamp = timestamp return data.SerializeToString() diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index ea1b3164d457..76dfd3000b86 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == ClientTableData.EntryType.INSERTION + if client["EntryType"] == ClientTableData.INSERTION ] expected = len(self.list_all_nodes()) From 47a5ffc622cfbd25e477f278061299dfaa60ad8e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 16:59:09 +0800 Subject: [PATCH 21/34] remove gcs fbs python files --- BUILD.bazel | 18 ------------------ doc/source/conf.py | 15 +-------------- python/ray/core/generated/__init__.py | 0 python/ray/core/generated/ray/__init__.py | 0 .../core/generated/ray/protocol/__init__.py | 0 python/ray/gcs_utils.py | 1 - 6 files changed, 1 insertion(+), 33 deletions(-) delete mode 100644 python/ray/core/generated/__init__.py delete mode 100644 python/ray/core/generated/ray/__init__.py delete mode 100644 python/ray/core/generated/ray/protocol/__init__.py diff --git a/BUILD.bazel b/BUILD.bazel index 8ebd56d1ec8c..71c6484b5619 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -586,22 +586,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -flatbuffer_py_library( - name = "python_gcs_fbs", - srcs = [ - ":gcs_fbs_file", - ], - outs = [ - "Arg.py", - "Language.py", - "ProfileEvent.py", - "ProfileTableData.py", - "TaskInfo.py", - "ResourcePair.py" - ], - out_prefix = "python/ray/core/generated/", -) - flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -696,7 +680,6 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:python_gcs_fbs", "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", @@ -719,7 +702,6 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && - for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; diff --git a/doc/source/conf.py b/doc/source/conf.py index 98fb3e0d02dd..5cf6b01217f9 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,20 +23,7 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.EntryType", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ErrorType", - "ray.core.generated.GcsEntry", - "ray.core.generated.HeartbeatBatchTableData", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.Language", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", + "ray.core.generated.gcs_pb2", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/python/ray/core/generated/__init__.py b/python/ray/core/generated/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/core/generated/ray/__init__.py b/python/ray/core/generated/ray/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/core/generated/ray/protocol/__init__.py b/python/ray/core/generated/ray/protocol/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index f1e69b2f2669..3097fcac4f82 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -4,7 +4,6 @@ import flatbuffers -from ray.core.generated.Language import Language from ray.core.generated.ray.protocol.Task import Task from ray.core.generated.gcs_pb2 import ( From 0d2ca331dee86104cb4925074dbdbb5fe7d95d44 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 17:48:14 +0800 Subject: [PATCH 22/34] fix --- BUILD.bazel | 4 +++ .../java/org/ray/runtime/gcs/GcsClient.java | 4 +-- python/ray/monitor.py | 33 +++++++++---------- python/ray/state.py | 6 +--- python/ray/tests/test_failure.py | 5 +-- python/ray/worker.py | 6 ++-- 6 files changed, 28 insertions(+), 30 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 71c6484b5619..80dc692841bd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -9,6 +9,8 @@ load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] +# === Begin of protobuf definitions === + proto_library( name = "gcs_proto", srcs = ["src/ray/protobuf/gcs.proto"], @@ -35,6 +37,8 @@ cc_proto_library( deps = ["node_manager_proto"], ) +# === End of protobuf definitions === + # Node manager gRPC lib. cc_grpc_library( name = "node_manager_cc_grpc", diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 3830a05bc26d..ea11696e3cc1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -67,7 +67,7 @@ public List getAllNodeInfo() { try { data = ClientTableData.parseFrom(result); } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Received invaild protobuf data from GCS."); + throw new RuntimeException("Received invalid protobuf data from GCS."); } final UniqueId clientId = UniqueId .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); @@ -142,7 +142,7 @@ public List getCheckpointsForActor(UniqueId actorId) { try { data = ActorCheckpointIdData.parseFrom(result); } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Received invaild protobuf data from GCS."); + throw new RuntimeException("Received invalid protobuf data from GCS."); } UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; for (int i = 0; i < checkpointIds.length; i++) { diff --git a/python/ray/monitor.py b/python/ray/monitor.py index c9e0424b3eb8..35597ef231e3 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,28 +101,26 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - heartbeat_data = gcs_entries.Entries(0) + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] - message = (ray.gcs_utils.HeartbeatBatchTableData. - GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + message = ray.gcs_utils.HeartbeatBatchTableData.FromString( + heartbeat_data) - for j in range(message.BatchLength()): - heartbeat_message = message.Batch(j) - - num_resources = heartbeat_message.ResourcesTotalLabelLength() + for heartbeat_message in message.batch: + num_resources = len(heartbeat_message.resources_available_label) static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.ResourcesAvailableLabel(i) - static = heartbeat_message.ResourcesTotalLabel(i) + dyn = heartbeat_message.resources_available_label[i] + static = heartbeat_message.resources_total_label[i] dynamic_resources[dyn] = ( - heartbeat_message.ResourcesAvailableCapacity(i)) + heartbeat_message.resources_available_capacity[i]) static_resources[static] = ( - heartbeat_message.ResourcesTotalCapacity(i)) + heartbeat_message.resources_total_capacity[i]) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. """ - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - driver_data = gcs_entries.Entries(0) - message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( - driver_data, 0) - driver_id = message.DriverId() + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + driver_data = gcs_entries.entries[0] + message = ray.gcs_utils.DriverTableData.FromString(driver_data) + driver_id = message.driver_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/state.py b/python/ray/state.py index c0347e290b71..90326673a211 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -304,11 +304,8 @@ def _task_table(self, task_id): gcs_entries = gcs_utils.GcsEntry.FromString(message) assert len(gcs_entries.entries) == 1 - print(task_id) - print(len(gcs_entries.entries[0])) task_table_data = gcs_utils.TaskTableData.FromString( gcs_entries.entries[0]) - print(task_table_data.task) task_table_message = gcs_utils.Task.GetRootAsTask( task_table_data.task, 0) @@ -372,7 +369,6 @@ def task_table(self, task_id=None): key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] - print([ray.TaskID(id) for id in task_ids_binary]) results = {} for task_id_binary in task_ids_binary: @@ -911,7 +907,7 @@ def actor_checkpoint_info(self, actor_id): return { "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": entry.timestamps, + "Timestamps": list(entry.timestamps), } diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 9f25f97c577d..7ec64d37d616 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH_PUBSUB, fake_id, malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value('HEARTBEAT_BATCH'), + ray.gcs_utils.TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB'), fake_id, + malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/worker.py b/python/ray/worker.py index 8f0e9f24f301..179f54c51fd9 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.WORKER_DIED: + if error_type == ErrorType.Value('WORKER_DIED'): return RayWorkerError() - elif error_type == ErrorType.ACTOR_DIED: + elif error_type == ErrorType.Value('ACTOR_DIED'): return RayActorError() - elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + elif error_type == ErrorType.Value('OBJECT_UNRECONSTRUCTABLE'): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) From dede65066c434a169dd0310af20fe06157b662aa Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 18:22:52 +0800 Subject: [PATCH 23/34] refine and add comments --- src/ray/gcs/format/gcs.fbs | 2 ++ src/ray/gcs/redis_context.cc | 4 ++-- src/ray/gcs/redis_context.h | 15 +++++++----- src/ray/gcs/redis_module/ray_redis_module.cc | 24 ++++++++++++-------- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 26fe6772a52b..c06c79a02928 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,3 +1,5 @@ +// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). + enum Language:int { PYTHON=0, JAVA=1, diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 60606e031905..ae6cb6088cec 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -263,10 +263,10 @@ Status RedisContext::RunArgvAsync(const std::vector &args) { } Status RedisContext::SubscribeAsync(const ClientID &client_id, - const rpc::TablePubsub pubsub_channel, + const TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index) { - RAY_CHECK(pubsub_channel != rpc::TablePubsub::NO_PUBLISH) + RAY_CHECK(pubsub_channel != TablePubsub::NO_PUBLISH) << "Client requested subscribe on a table that does not support pubsub"; int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, true); diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index e1cde567e835..093aab2455d9 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -25,6 +25,9 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; + /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -127,8 +130,8 @@ class RedisContext { /// \return Status. template Status RunAsync(const std::string &command, const ID &id, const void *data, - size_t length, const rpc::TablePrefix prefix, - const rpc::TablePubsub pubsub_channel, RedisCallback redisCallback, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); /// Run an arbitrary Redis command without a callback. @@ -144,7 +147,7 @@ class RedisContext { /// \param redisCallback The callback function that the notification calls. /// \param out_callback_index The output pointer to callback index. /// \return Status. - Status SubscribeAsync(const ClientID &client_id, const rpc::TablePubsub pubsub_channel, + Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index); redisContext *sync_context() { return context_; } redisAsyncContext *async_context() { return async_context_; } @@ -157,9 +160,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, - const void *data, size_t length, - const rpc::TablePrefix prefix, const rpc::TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 6d4b0f50ac84..b0c539f277ff 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -170,6 +170,12 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } +/// A helper function that creates `GcsEntry` protobuf object. +/// +/// \param[in] id Id of the entry. +/// \param[in] change_mode Change mode of the entry. +/// \param[in] entries Vector of entries. +/// \param[out] result The created `GcsEntry` object. inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, const std::vector &entries, GcsEntry *result) { @@ -632,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry flatbugger data used to update this hash table. +/// \param data The GcsEntry protobuf data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -649,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create and finish a GcsEntry, based on the +/// A helper function to create a GcsEntry protobuf, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -659,14 +665,14 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsEntry. -Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, +/// \param[out] gcs_entry The created GcsEntry. +Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *prefix_str, RedisModuleString *entry_id, GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. + // Build the GcsEntry from the string data. CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); @@ -694,7 +700,7 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the flatbuffer from the set of log entries. + // Build the GcsEntry from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } @@ -743,10 +749,10 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a flatbuffer to return to the client. + // Serialize the data to a GcsEntry to return to the client. GcsEntry gcs_entry; REPLY_AND_RETURN_IF_NOT_OK( - TableEntryToFlatbuf(ctx, table_key, prefix_str, id, &gcs_entry)); + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); std::string str = gcs_entry.SerializeAsString(); RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } @@ -864,7 +870,7 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // empty. GcsEntry gcs_entry; REPLY_AND_RETURN_IF_NOT_OK( - TableEntryToFlatbuf(ctx, table_key, prefix_str, id, &gcs_entry)); + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); std::string str = gcs_entry.SerializeAsString(); RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); From 7b9e17ec4063d6b106e2da354b98ce9751482471 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 19:50:02 +0800 Subject: [PATCH 24/34] fix cc tests --- src/ray/gcs/client_test.cc | 349 ++++++++---------- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/reconstruction_policy_test.cc | 42 ++- .../raylet/task_dependency_manager_test.cc | 2 +- 6 files changed, 227 insertions(+), 260 deletions(-) diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c7dc02e50651..d50c77f3f178 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->task_specification = "123"; + auto data = std::make_shared(); + data->set_task("123"); // Check that we added the correct task. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { RAY_CHECK(false); }; + const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); data_log.push_back(data); } // Check that we added the correct task. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id); + appended_managers.push_back(entry.node_manager_id()); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,13 +268,13 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -283,7 +283,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that lookup returns the added object entries. auto lookup_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { ASSERT_TRUE(false); }; + const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->node_manager_id = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_node_manager_id(ObjectID::FromRandom().Hex()); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::FromRandom().Hex(); + auto task_data = std::make_shared(); + task_data->set_task(ObjectID::FromRandom().Hex()); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_manager(ObjectID::FromRandom().Hex()); object_vector.push_back(data); } }; @@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } -// Task table callbacks. -void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); -} - -void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data, bool do_stop) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); - if (do_stop) { - test->Stop(); - } -} -void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/false); -} -void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/true); -} - -void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); -} - -void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); - test->Stop(); -} - -void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - test->Stop(); -} - void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. @@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->driver_id = driver_ids1[0]; + auto data1 = std::make_shared(); + data1->set_driver_id(driver_ids1[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->driver_id = driver_ids2[0]; + auto data2 = std::make_shared(); + data2->set_driver_id(driver_ids2[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->manager = managers1[0]; + auto data1 = std::make_shared(); + data1->set_manager(managers1[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->manager = managers2[0]; + auto data2 = std::make_shared(); + data2->set_manager(managers2[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->task_specification = task_specs[0]; + auto data = std::make_shared(); + data->set_task(task_specs[0]); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task_specification, task_specs.front()); + ASSERT_EQ(data.task(), task_specs.front()); } else { - ASSERT_EQ(data.task_specification, task_specs.back()); + ASSERT_EQ(data.task(), task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->driver_id = driver_ids[0]; + auto data = std::make_shared(); + data->set_driver_id(driver_ids[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->driver_id = remaining_driver_id; + auto data = std::make_shared(); + data->set_driver_id(remaining_driver_id); RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->manager = managers[0]; + auto data = std::make_shared(); + data->set_manager(managers[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager, managers[0]); + ASSERT_EQ(data[0].manager(), managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager); + data_managers_set.insert(entry.manager()); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableDataT &data, bool is_insertion) { + const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); - ClientTableDataT cached_client; + ClientTableData cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); + ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + const ClientTableData &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); }); test->Start(); @@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. - auto cpu_data = std::make_shared(); - cpu_data->resource_name = "CPU"; - cpu_data->resource_capacity = 100; - auto gpu_data = std::make_shared(); - gpu_data->resource_name = "GPU"; - gpu_data->resource_capacity = 2; + auto cpu_data = std::make_shared(); + cpu_data->set_resource_name("CPU"); + cpu_data->set_resource_capacity(100); + auto gpu_data = std::make_shared(); + gpu_data->set_resource_name("GPU"); + gpu_data->set_resource_capacity(2); DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->resource_name = "CPU"; - data_cpu->resource_capacity = 50; - auto data_gpu = std::make_shared(); - data_gpu->resource_name = "GPU"; - data_gpu->resource_capacity = 10; - auto data_custom = std::make_shared(); - data_custom->resource_name = "CUSTOM"; - data_custom->resource_capacity = 2; + auto data_cpu = std::make_shared(); + data_cpu->set_resource_name("CPU"); + data_cpu->set_resource_capacity(50); + auto data_gpu = std::make_shared(); + data_gpu->set_resource_name("GPU"); + data_gpu->set_resource_capacity(10); + auto data_custom = std::make_shared(); + data_custom->set_resource_name("CUSTOM"); + data_custom->set_resource_capacity(2); DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->resource_capacity = 50; + data_map2["CPU"]->set_resource_capacity(50); // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name, data.second->resource_name); - ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); + ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 55aa59124a99..2d5292842acf 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,6 +11,8 @@ namespace ray { +using rpc::ClientTableData; + std::string store_executable; static inline void flushall_redis(void) { @@ -52,10 +54,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -242,8 +244,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -438,16 +440,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" - << "ClientIp=" << data.node_manager_address << "\n" - << "ClientPort=" << data.node_manager_port; - ClientTableDataT data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" + << "ClientIp=" << data.node_manager_address() << "\n" + << "ClientPort=" << data.node_manager_port(); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" - << "ClientIp=" << data2.node_manager_address << "\n" - << "ClientPort=" << data2.node_manager_port; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" + << "ClientIp=" << data2.node_manager_address() << "\n" + << "ClientPort=" << data2.node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index ee6c78d8ed42..45b80a267f2f 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,6 +14,8 @@ int64_t wait_timeout_ms; namespace ray { +using rpc::ClientTableData; + static inline void flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -46,10 +48,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -221,8 +223,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -457,19 +459,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); - ClientTableDataT data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); } }; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 43e64e400292..a6184902f803 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &data) { + const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4ccebd0c0c09..12d9336a382f 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,6 +14,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -83,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -110,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -132,15 +134,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -159,9 +161,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { + const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout); + task_lease.timeout()); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -314,10 +316,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = 2 * test_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(2 * test_period); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -328,7 +330,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout * 1.1); + Run(task_lease_data->timeout() * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -341,10 +343,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = reconstruction_timeout_ms_; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(reconstruction_timeout_ms_); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -393,14 +395,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); - task_reconstruction_data->num_reconstructions = 0; + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index e0f832a12870..f7a60989fcba 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; From a1e3d0869d16f73105e965b6b448c2db75031715 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 20:30:32 +0800 Subject: [PATCH 25/34] remove unused --- src/ray/rpc/util.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 5e928dd1de43..24dd55446d00 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -28,11 +28,6 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } -template -inline std::unordered_map MapFromProtobuf(::google::protobuf::Map pb_map) { - return std::unordered_map(pb_map.begin(), pb_map.end()); -} - template inline std::vector VectorFromProtobuf( const ::google::protobuf::RepeatedPtrField &pb_repeated) { @@ -45,16 +40,6 @@ inline std::vector VectorFromProtobuf( return std::vector(pb_repeated.begin(), pb_repeated.end()); } -template -inline std::vector IdVectorFromProtobuf( - const ::google::protobuf::RepeatedPtrField<::std::string> &pb_repeated) { - auto str_vec = VectorFromProtobuf(pb_repeated); - std::vector ret(str_vec.size()); - std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(ret), - &ID::FromBinary); - return ret; -} - } // namespace rpc } // namespace ray From 4d376b2b966382566949d39b149828c3edcb3bfc Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 20:41:25 +0800 Subject: [PATCH 26/34] comments in gcs.proto --- src/ray/protobuf/gcs.proto | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index fbb08ade4415..e2e19dd19f7f 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -5,6 +5,7 @@ package ray.rpc; option java_multiple_files = true; option java_package = "org.ray.runtime.generated"; +// Language of a worker or task. enum Language { PYTHON = 0; CPP = 1; @@ -77,11 +78,16 @@ message TaskReconstructionData { bytes node_manager_id = 2; } +// TODO(hchen): Task table currently still uses flatbuffers-defined data structure +// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should +// be migrated to protobuf very soon. message TaskTableData { + // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. bytes task = 1; } message ActorTableData { + // State of an actor. enum ActorState { // Actor is alive. ALIVE = 0; @@ -121,6 +127,7 @@ message ErrorTableData { } message ProfileTableData { + // Represents a profile event. message ProfileEvent { // The type of the event. string event_type = 1; @@ -180,6 +187,7 @@ message ClientTableData { // Enum to store the entry type in the log EntryType entry_type = 7; + // TODO(hchen): Define the following resources in map format. repeated string resources_total_label = 8; repeated double resources_total_capacity = 9; } @@ -187,6 +195,7 @@ message ClientTableData { message HeartbeatTableData { // Node manager client id bytes client_id = 1; + // TODO(hchen): Define the following resources in map format. // Resource capacity currently available on this node manager. repeated string resources_available_label = 2; repeated double resources_available_capacity = 3; From 9c1c6258ed95e5803a58f96f34981dcc13e8c081 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 21:05:29 +0800 Subject: [PATCH 27/34] format --- ...modify_generated_java_flatbuffers_files.py | 4 +-- python/ray/gcs_utils.py | 9 ++---- python/ray/state.py | 29 +++++++++---------- python/ray/tests/test_basic.py | 14 +++++---- python/ray/tests/test_failure.py | 4 +-- python/ray/utils.py | 8 ++--- python/ray/worker.py | 20 ++++++------- src/ray/gcs/client_test.cc | 6 ++-- src/ray/gcs/redis_module/ray_redis_module.cc | 4 +-- 9 files changed, 47 insertions(+), 51 deletions(-) diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index fa1867bde8a7..5bf62e56d7e4 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,7 +4,6 @@ import os import sys - """ This script is used for modifying the generated java flatbuffer files for the reason: The package declaration in Java is different @@ -46,6 +45,5 @@ def add_package_declarations(generated_root_path): if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, - "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 3097fcac4f82..ba72e96f41db 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,8 +2,6 @@ from __future__ import division from __future__ import print_function -import flatbuffers - from ray.core.generated.ray.protocol.Task import Task from ray.core.generated.gcs_pb2 import ( @@ -31,7 +29,6 @@ "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", - "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", @@ -47,12 +44,12 @@ # xray heartbeats XRAY_HEARTBEAT_CHANNEL = str( - TablePubsub.Value('HEARTBEAT_PUBSUB')).encode("ascii") + TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") XRAY_HEARTBEAT_BATCH_CHANNEL = str( - TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB')).encode("ascii") + TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.Value('DRIVER_PUBSUB')).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") # These prefixes must be kept up-to-date with the TablePrefix enum in # gcs.proto. diff --git a/python/ray/state.py b/python/ray/state.py index 90326673a211..35f97cd65f5e 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -11,7 +11,6 @@ import ray from ray.function_manager import FunctionDescriptor -from ray.ray_constants import ID_SIZE from ray import ( gcs_utils, services, @@ -33,7 +32,7 @@ def _parse_client_table(redis_client): """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() message = redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value('CLIENT'), "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only @@ -75,8 +74,9 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - assert node_info[client_id]["EntryType"] != gcs_utils.ClientTableData.DELETION, ( - "Unexpected updation of deleted client.") + is_deletion = (node_info[client_id]["EntryType"] != + gcs_utils.ClientTableData.DELETION) + assert is_deletion, "Unexpected updation of deleted client." res_map = node_info[client_id]["Resources"] if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: @@ -240,8 +240,8 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value('OBJECT'), "", - object_id.binary()) + gcs_utils.TablePrefix.Value("OBJECT"), + "", object_id.binary()) if message is None: return {} gcs_entry = gcs_utils.GcsEntry.FromString(message) @@ -273,8 +273,7 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + - "*") + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys @@ -298,7 +297,7 @@ def _task_table(self, task_id): assert isinstance(task_id, ray.TaskID) message = self._execute_command( task_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value('RAYLET_TASK'), "", task_id.binary()) + gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} gcs_entries = gcs_utils.GcsEntry.FromString(message) @@ -398,8 +397,8 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value('PROFILE'), "", - batch_id.binary()) + gcs_utils.TablePrefix.Value("PROFILE"), + "", batch_id.binary()) if message is None: return [] @@ -432,8 +431,8 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys( - gcs_utils.TablePrefix_PROFILE_string + "*") + profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + + "*") batch_identifiers_binary = [ key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys @@ -829,7 +828,7 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value('ERROR_INFO'), "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", driver_id.binary()) # If there are no errors, return early. @@ -891,7 +890,7 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value('ACTOR_CHECKPOINT_ID'), + gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), "", actor_id.binary(), ) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index eb5920ebfeac..6b4bd754cd4d 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), - ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), - ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 7ec64d37d616..a560e461f7a2 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,8 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value('HEARTBEAT_BATCH'), - ray.gcs_utils.TablePubsub.Value('HEARTBEAT_BATCH_PUBSUB'), fake_id, + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), + ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/utils.py b/python/ray/utils.py index 27ea4424732b..0db48e41d025 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.Value('ERROR_INFO'), - ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB'), - driver_id.binary(), error_data) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 179f54c51fd9..710f0db43c6b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.Value('WORKER_DIED'): + if error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() - elif error_type == ErrorType.Value('ACTOR_DIED'): + elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() - elif error_type == ErrorType.Value('OBJECT_UNRECONSTRUCTABLE'): + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.Value('ERROR_INFO_PUBSUB')).encode("ascii") + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1880,12 +1880,12 @@ def connect(node, task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. - ray.state.state._execute_command(driver_task.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.Value('RAYLET_TASK'), - ray.gcs_utils.TablePubsub.Value('RAYLET_TASK_PUBSUB'), - driver_task.task_id().binary(), - task_table_data.SerializeToString()) + ray.state.state._execute_command( + driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), + ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), + driver_task.task_id().binary(), + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index d50c77f3f178..55115b1e2067 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -281,9 +281,9 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index b0c539f277ff..c3a82c320d06 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -667,8 +667,8 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param entry_id The UniqueID associated with the open Redis key. /// \param[out] gcs_entry The created GcsEntry. Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - GcsEntry *gcs_entry) { + RedisModuleString *prefix_str, RedisModuleString *entry_id, + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { From 951f4602ec57041eebe4ecaf74546085c8505fb1 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 24 Jun 2019 21:07:05 +0800 Subject: [PATCH 28/34] format cpp --- src/ray/gcs/tables.h | 46 ++++++++++------------ src/ray/object_manager/object_directory.cc | 2 +- src/ray/object_manager/object_manager.cc | 19 +++++---- src/ray/raylet/actor_registration.h | 2 +- src/ray/raylet/monitor.h | 4 +- src/ray/raylet/node_manager.h | 6 +-- 6 files changed, 38 insertions(+), 41 deletions(-) diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 43b80df0cfcf..2ecc3440839e 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -20,26 +20,24 @@ namespace ray { namespace gcs { -using rpc::TablePrefix; -using rpc::TablePubsub; -using rpc::GcsChangeMode; -using rpc::GcsEntry; -using rpc::ObjectTableData; -using rpc::TaskTableData; +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; using rpc::ActorTableData; -using rpc::TaskReconstructionData; -using rpc::TaskLeaseData; -using rpc::HeartbeatTableData; -using rpc::ErrorTableData; using rpc::ClientTableData; using rpc::DriverTableData; +using rpc::ErrorTableData; +using rpc::GcsChangeMode; +using rpc::GcsEntry; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; +using rpc::ObjectTableData; using rpc::ProfileTableData; -using rpc::ActorCheckpointData; -using rpc::ActorCheckpointIdData; using rpc::RayResource; -using rpc::HeartbeatBatchTableData; -using rpc::HeartbeatBatchTableData; - +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::TaskLeaseData; +using rpc::TaskReconstructionData; +using rpc::TaskTableData; class RedisContext; @@ -92,9 +90,9 @@ class Log : public LogInterface, virtual public PubsubInterface { public: using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + using NotificationCallback = + std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -700,18 +698,15 @@ class TaskLeaseTable : public Table { } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, - const WriteCallback &done) override { - RAY_RETURN_NOT_OK( - (Table::Add(driver_id, id, data, done))); + std::shared_ptr &data, const WriteCallback &done) override { + RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. It's okay if this command fails // since the lease entry itself contains the expiration period. In the // worst case, if the command fails, then a client that looks up the lease // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", - TablePrefix_Name(prefix_) + id.Binary(), + std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); @@ -937,8 +932,7 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, - const ClientTableData ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 6ffda660a300..454379d18302 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,9 +8,9 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::ClientTableData; using ray::rpc::GcsChangeMode; using ray::rpc::ObjectTableData; -using ray::rpc::ClientTableData; /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 39a4326eac99..cb126f23ef47 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -315,9 +315,9 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -335,9 +335,9 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -764,7 +764,9 @@ void ObjectManager::ProcessClientMessage(std::shared_ptr &c DisconnectClient(conn, message); break; } - default: { RAY_LOG(FATAL) << "invalid request " << message_type; } + default: { + RAY_LOG(FATAL) << "invalid request " << message_type; + } } } @@ -805,7 +807,8 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con profile_event.set_event_type("receive_pull_request"); profile_event.set_start_time(current_sys_time_seconds()); profile_event.set_end_time(profile_event.start_time()); - profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index c90636087158..208e4998263f 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -136,7 +136,7 @@ class ActorRegistration { /// \param task The task that just finished on the actor. /// \return A shared pointer to the generated checkpoint data. std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + const Task &task); private: /// Information from the global actor table about this actor, including the diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index 4f05b537a295..5725e52cf495 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,9 +11,9 @@ namespace ray { namespace raylet { -using rpc::HeartbeatTableData; -using rpc::HeartbeatBatchTableData; using rpc::ClientTableData; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; class Monitor { public: diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index efdd5d238bf7..f45c8b035553 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -25,12 +25,12 @@ namespace ray { namespace raylet { +using rpc::ActorTableData; using rpc::ClientTableData; -using rpc::HeartbeatTableData; -using rpc::HeartbeatBatchTableData; using rpc::DriverTableData; -using rpc::ActorTableData; using rpc::ErrorType; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; struct NodeManagerConfig { /// The node's resource configuration. From 73641a058767521ae0010f5b19cb2c55847d9794 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 10:59:53 +0800 Subject: [PATCH 29/34] format --- BUILD.bazel | 6 +++--- src/ray/object_manager/object_manager.cc | 4 +--- src/ray/rpc/util.h | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 80dc692841bd..bc9e6bcd8006 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -43,8 +43,8 @@ cc_proto_library( cc_grpc_library( name = "node_manager_cc_grpc", srcs = [":node_manager_proto"], - deps = [":node_manager_cc_proto"], grpc_only = True, + deps = [":node_manager_cc_proto"], ) # Node manager server and client. @@ -457,9 +457,9 @@ cc_library( deps = [ ":gcs_cc_proto", ":gcs_fbs", - ":node_manager_rpc", ":hiredis", ":node_manager_fbs", + ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -674,8 +674,8 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ - ":ray_common", ":gcs_cc_proto", + ":ray_common", ], ) diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index cb126f23ef47..964cee605ced 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -764,9 +764,7 @@ void ObjectManager::ProcessClientMessage(std::shared_ptr &c DisconnectClient(conn, message); break; } - default: { - RAY_LOG(FATAL) << "invalid request " << message_type; - } + default: { RAY_LOG(FATAL) << "invalid request " << message_type; } } } diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 24dd55446d00..59ae75ae33be 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,7 +1,7 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H -#include +#include #include #include "ray/common/status.h" From 12a301d0ade29439753d32872c52a30324ae9a48 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 11:23:20 +0800 Subject: [PATCH 30/34] add back __init__.py --- python/ray/core/generated/__init__.py | 0 python/ray/core/generated/ray/__init__.py | 0 python/ray/core/generated/ray/protocol/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 python/ray/core/generated/__init__.py create mode 100644 python/ray/core/generated/ray/__init__.py create mode 100644 python/ray/core/generated/ray/protocol/__init__.py diff --git a/python/ray/core/generated/__init__.py b/python/ray/core/generated/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/core/generated/ray/__init__.py b/python/ray/core/generated/ray/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/core/generated/ray/protocol/__init__.py b/python/ray/core/generated/ray/protocol/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From b3853e4ec653de3c9191721cbfef9bd594505393 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 11:37:30 +0800 Subject: [PATCH 31/34] Java proto generates single file --- .../src/main/java/org/ray/runtime/gcs/GcsClient.java | 8 ++++---- .../org/ray/runtime/objectstore/ObjectStoreProxy.java | 2 +- src/ray/protobuf/gcs.proto | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index ea11696e3cc1..17c248ed0a57 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.ActorCheckpointIdData; -import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.ClientTableData.EntryType; -import org.ray.runtime.generated.TablePrefix; +import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; +import org.ray.runtime.generated.Gcs.ClientTableData; +import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; +import org.ray.runtime.generated.Gcs.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 74e940674eda..1a7e4701c22b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.generated.Gcs.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index e2e19dd19f7f..d0b2c5e007fe 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -2,7 +2,6 @@ syntax = "proto3"; package ray.rpc; -option java_multiple_files = true; option java_package = "org.ray.runtime.generated"; // Language of a worker or task. From 1b9f5f9a9f410220475c4146e6137f19c014437e Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 11:38:08 +0800 Subject: [PATCH 32/34] fix warning --- src/ray/raylet/worker_pool.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 719378216fb7..16086565de80 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -48,8 +48,8 @@ WorkerPool::WorkerPool( : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - gcs_client_(std::move(gcs_client)), - last_warning_multiple_(0) { + last_warning_multiple_(0), + gcs_client_(std::move(gcs_client)) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. If we don't do this, then worker processes will From 3dfb7ae514eb40ff08d32dadcaef726c62b51884 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 13:29:43 +0800 Subject: [PATCH 33/34] fix java bazel --- java/BUILD.bazel | 68 +++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 5034171db93a..4960434af180 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -49,6 +49,10 @@ define_java_module( define_java_module( name = "runtime", + additional_srcs = [ + ":generate_java_gcs_fbs", + ":gcs_java_proto", + ], additional_resources = [ ":java_native_deps", ], @@ -63,7 +67,6 @@ define_java_module( ], deps = [ ":org_ray_ray_api", - ":copy_auto_generated_files", "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", @@ -167,6 +170,22 @@ flatbuffer_java_library( out_prefix = "", ) +genrule( + name = "generate_java_gcs_fbs", + srcs = [":java_gcs_fbs"], + outs = [ + "runtime/src/main/java/org/ray/runtime/generated/" + file for file in flatbuffers_generated_files + ], + cmd = """ + for f in $(locations //java:java_gcs_fbs); do + chmod +w $$f + mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + done + python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. + """, + local = 1, +) + filegroup( name = "java_native_deps", srcs = [ @@ -179,43 +198,14 @@ filegroup( ], ) -genrule( - name = "copy_auto_generated_files", - srcs = [ - ":gcs_java_proto", - ":java_gcs_fbs", - ], - outs = ["copy_auto_generated_files.out"], - cmd = """ - set -x - WORK_DIR=$$(pwd) - # Copy protobuf-generated files. - GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated - rm -rf $$GENERATED_DIR - mkdir -p $$GENERATED_DIR - # TODO(hchen): Only copy files needed by Java. - for f in $(locations //java:gcs_java_proto); do - unzip $$f - mv org/ray/runtime/generated/* $$GENERATED_DIR - done - # Copy flatbuffers-generated files - for f in $(locations //java:java_gcs_fbs); do - chmod +w $$f - cp $$f $$GENERATED_DIR - done - python $$WORK_DIR/java/modify_generated_java_flatbuffers_files.py $$WORK_DIR - echo $$(date) > $@ - """, - local = 1, - tags = ["no-cache"], -) - # Generates the depedencies needed by maven. genrule( name = "gen_maven_deps", srcs = [ - ":copy_auto_generated_files", + ":gcs_java_proto", + ":generate_java_gcs_fbs", ":java_native_deps", + ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -230,6 +220,18 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done + # Copy protobuf-generated files. + GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated + rm -rf $$GENERATED_DIR + mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files + for f in $(locations //java:generate_java_gcs_fbs); do + cp $$f $$GENERATED_DIR + done # Install plasma jar to local maven repo. mvn install:install-file -Dfile=$(locations @plasma//:org_apache_arrow_arrow_plasma) -Dpackaging=jar \ -DgroupId=org.apache.arrow -DartifactId=arrow-plasma -Dversion=0.13.0-SNAPSHOT From e15925876e226476576a82af9e87edf45f32a592 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Tue, 25 Jun 2019 16:54:30 +0800 Subject: [PATCH 34/34] add python protobuf dep --- python/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/setup.py b/python/setup.py index db8676042de9..e7cf14737ee2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -150,6 +150,7 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", + "protobuf", ] setup(