From 6889b951ba5c3c65f810e69330ea30b7595a4273 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Tue, 26 Feb 2019 18:51:04 +0800 Subject: [PATCH 1/7] Use strongly typed IDs for C++. --- python/ray/_raylet.pyx | 25 ++--- python/ray/includes/libraylet.pxd | 2 +- python/ray/includes/task.pxi | 18 ++-- python/ray/includes/unique_ids.pxd | 84 ++++++++++++--- python/ray/includes/unique_ids.pxi | 99 ++++++++++++++--- src/ray/common/common_protocol.cc | 68 ------------ src/ray/common/common_protocol.h | 97 +++++++++++++++-- src/ray/gcs/client_test.cc | 10 +- src/ray/gcs/tables.cc | 28 ++--- src/ray/gcs/tables.h | 14 +-- src/ray/id.cc | 12 ++- src/ray/id.h | 50 ++++++--- src/ray/id_def.h | 13 +++ src/ray/object_manager/object_directory.cc | 2 +- src/ray/object_manager/object_manager.cc | 4 +- .../object_store_notification_manager.cc | 2 +- src/ray/raylet/format/node_manager.fbs | 2 +- src/ray/raylet/lineage_cache.cc | 12 +-- src/ray/raylet/lineage_cache_test.cc | 6 +- src/ray/raylet/monitor.cc | 6 +- src/ray/raylet/node_manager.cc | 100 +++++++++--------- src/ray/raylet/node_manager.h | 2 +- src/ray/raylet/raylet_client.cc | 8 +- src/ray/raylet/raylet_client.h | 9 +- src/ray/raylet/reconstruction_policy_test.cc | 6 +- src/ray/raylet/task_dependency_manager.cc | 2 +- .../raylet/task_dependency_manager_test.cc | 6 +- src/ray/raylet/task_spec.cc | 26 ++--- src/ray/raylet/task_spec.h | 6 +- src/ray/raylet/worker_pool_test.cc | 2 +- 30 files changed, 449 insertions(+), 272 deletions(-) create mode 100644 src/ray/id_def.h diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8b12a04fd979..94ebd625332b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -72,7 +72,7 @@ cdef c_vector[CObjectID] ObjectIDsToVector(object_ids): ObjectID object_id c_vector[CObjectID] result for object_id in object_ids: - result.push_back(object_id.data) + result.push_back(object_id.native()) return result @@ -87,11 +87,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(ComputePutId(task_id.data, put_index).binary()) + return ObjectID(ComputePutId(task_id.native(), put_index).binary()) def compute_task_id(ObjectID object_id): - return TaskID(ComputeTaskId(object_id.data).binary()) + return TaskID(ComputeTaskId(object_id.native()).binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): @@ -225,8 +225,8 @@ cdef class RayletClient: # parameter. # TODO(suquark): Should we allow unicode chars in "raylet_socket"? self.client.reset(new CRayletClient( - raylet_socket.encode("ascii"), client_id.data, is_worker, - driver_id.data, LANGUAGE_PYTHON)) + raylet_socket.encode("ascii"), client_id.native(), is_worker, + driver_id.native(), LANGUAGE_PYTHON)) def disconnect(self): check_status(self.client.get().Disconnect()) @@ -252,10 +252,10 @@ cdef class RayletClient: TaskID current_task_id=TaskID.nil()): cdef c_vector[CObjectID] fetch_ids = ObjectIDsToVector(object_ids) check_status(self.client.get().FetchOrReconstruct( - fetch_ids, fetch_only, current_task_id.data)) + fetch_ids, fetch_only, current_task_id.native())) def notify_unblocked(self, TaskID current_task_id): - check_status(self.client.get().NotifyUnblocked(current_task_id.data)) + check_status(self.client.get().NotifyUnblocked(current_task_id.native())) def wait(self, object_ids, int num_returns, int64_t timeout_milliseconds, c_bool wait_local, TaskID current_task_id): @@ -263,11 +263,12 @@ cdef class RayletClient: WaitResultPair result c_vector[CObjectID] wait_ids wait_ids = ObjectIDsToVector(object_ids) + c_task_id = current_task_id.native() with nogil: check_status(self.client.get().Wait(wait_ids, num_returns, timeout_milliseconds, wait_local, - current_task_id.data, &result)) + c_task_id, &result)) return (VectorToObjectIDs(result.first), VectorToObjectIDs(result.second)) @@ -291,9 +292,9 @@ cdef class RayletClient: postincrement(iterator) return resources_dict - def push_error(self, DriverID job_id, error_type, error_message, + def push_error(self, DriverID driver_id, error_type, error_message, double timestamp): - check_status(self.client.get().PushError(job_id.data, + check_status(self.client.get().PushError(driver_id.native(), error_type.encode("ascii"), error_message.encode("ascii"), timestamp)) @@ -354,7 +355,7 @@ cdef class RayletClient: def prepare_actor_checkpoint(self, ActorID actor_id): cdef CActorCheckpointID checkpoint_id - cdef CActorID c_actor_id = actor_id.data + cdef CActorID c_actor_id = actor_id.native() # PrepareActorCheckpoint will wait for raylet's reply, release # the GIL so other Python threads can run. with nogil: @@ -365,7 +366,7 @@ cdef class RayletClient: def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id): check_status(self.client.get().NotifyActorResumedFromCheckpoint( - actor_id.data, checkpoint_id.data)) + actor_id.native(), checkpoint_id.native())) @property def language(self): diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 1a4ffb250235..a496c5b83783 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -62,7 +62,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: int num_returns, int64_t timeout_milliseconds, c_bool wait_local, const CTaskID ¤t_task_id, WaitResultPair *result) - CRayStatus PushError(const CDriverID &job_id, const c_string &type, + CRayStatus PushError(const CDriverID &driver_id, const c_string &type, const c_string &error_message, double timestamp) CRayStatus PushProfileEvents( const GCSProfileTableDataT &profile_events) diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index a7cfc684b9d0..872b93d22269 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -54,7 +54,7 @@ cdef class Task: for arg in arguments: if isinstance(arg, ObjectID): references = c_vector[CObjectID]() - references.push_back((arg).data) + references.push_back((arg).native()) task_args.push_back( static_pointer_cast[CTaskArgument, CTaskArgumentByReference]( @@ -71,23 +71,21 @@ cdef class Task: for new_actor_handle in new_actor_handles: task_new_actor_handles.push_back( - (new_actor_handle).data) + (new_actor_handle).native()) self.task_spec.reset(new CTaskSpecification( - CUniqueID(driver_id.data), parent_task_id.data, parent_counter, - actor_creation_id.data, actor_creation_dummy_object_id.data, - max_actor_reconstructions, CUniqueID(actor_id.data), - CUniqueID(actor_handle_id.data), actor_counter, - task_new_actor_handles, task_args, num_returns, - required_resources, required_placement_resources, - LANGUAGE_PYTHON, c_function_descriptor)) + driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(), + actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(), + actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns, + required_resources, required_placement_resources, LANGUAGE_PYTHON, + c_function_descriptor)) # Set the task's execution dependencies. self.execution_dependencies.reset(new c_vector[CObjectID]()) if execution_arguments is not None: for execution_arg in execution_arguments: self.execution_dependencies.get().push_back( - (execution_arg).data) + (execution_arg).native()) @staticmethod cdef make(unique_ptr[CTaskSpecification]& task_spec): diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index fc36f97766c1..a7b77300ba21 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -5,13 +5,14 @@ from libc.stdint cimport uint8_t cdef extern from "ray/id.h" namespace "ray" nogil: cdef cppclass CUniqueID "ray::UniqueID": CUniqueID() + CUniqueID(const c_string &binary) CUniqueID(const CUniqueID &from_id) @staticmethod CUniqueID from_random() @staticmethod - CUniqueID from_binary(const c_string & binary) + CUniqueID from_binary(const c_string &binary) @staticmethod const CUniqueID nil() @@ -26,14 +27,73 @@ cdef extern from "ray/id.h" namespace "ray" nogil: c_string binary() const c_string hex() const -ctypedef CUniqueID CActorCheckpointID -ctypedef CUniqueID CActorClassID -ctypedef CUniqueID CActorHandleID -ctypedef CUniqueID CActorID -ctypedef CUniqueID CClientID -ctypedef CUniqueID CConfigID -ctypedef CUniqueID CDriverID -ctypedef CUniqueID CFunctionID -ctypedef CUniqueID CObjectID -ctypedef CUniqueID CTaskID -ctypedef CUniqueID CWorkerID + cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): + CActorCheckpointID() + CActorCheckpointID(const c_string &binary) + CActorCheckpointID(const CUniqueID &from_id) + + + cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): + CActorClassID() + CActorClassID(const c_string &binary) + CActorClassID(const CUniqueID &from_id) + + + cdef cppclass CActorID "ray::ActorID"(CUniqueID): + CActorID() + CActorID(const c_string &binary) + CActorID(const CUniqueID &from_id) + + + cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): + CActorHandleID() + CActorHandleID(const c_string &binary) + CActorHandleID(const CUniqueID &from_id) + + + cdef cppclass CClientID "ray::ClientID"(CUniqueID): + CClientID() + CClientID(const c_string &binary) + CClientID(const CUniqueID &from_id) + + + cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): + CConfigID() + CConfigID(const c_string &binary) + CConfigID(const CUniqueID &from_id) + + + cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): + CFunctionID() + CFunctionID(const c_string &binary) + CFunctionID(const CUniqueID &from_id) + + + cdef cppclass CDriverID "ray::DriverID"(CUniqueID): + CDriverID() + CDriverID(const c_string &binary) + CDriverID(const CUniqueID &from_id) + + + cdef cppclass CJobID "ray::JobID"(CUniqueID): + CJobID() + CJobID(const c_string &binary) + CJobID(const CUniqueID &from_id) + + + cdef cppclass CTaskID "ray::TaskID"(CUniqueID): + CTaskID() + CTaskID(const c_string &binary) + CTaskID(const CUniqueID &from_id) + + + cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): + CObjectID() + CObjectID(const c_string &binary) + CObjectID(const CUniqueID &from_id) + + + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): + CWorkerID() + CWorkerID(const c_string &binary) + CWorkerID(const CUniqueID &from_id) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 670579737d7c..1badc75b2d22 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport ( CConfigID, CDriverID, CFunctionID, + CJobID, CObjectID, CTaskID, CUniqueID, @@ -42,14 +43,18 @@ cdef extern from "ray/constants.h" nogil: cdef class UniqueID: - cdef CUniqueID data + cdef CUniqueID *data - def __init__(self, id): - if not id: - self.data = CUniqueID() - else: + def __cinit__(self, id): + # The derived class should also check self type and fill self.data. + if type(self) is UniqueID: check_id(id) - self.data = CUniqueID.from_binary(id) + self.data = new CUniqueID(id) + + def __dealloc__(self): + # The derived classes do not need to define __dealloc__, + # the base class will do it. + del self.data @classmethod def from_binary(cls, id_bytes): @@ -59,7 +64,7 @@ cdef class UniqueID: @classmethod def nil(cls): - return cls(b"") + return cls(CUniqueID.nil().binary()) def __hash__(self): return self.data.hash() @@ -106,40 +111,102 @@ cdef class UniqueID: cdef class ObjectID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is ObjectID: + check_id(id) + self.data = new CObjectID(id) + + cdef CObjectID native(self): + return (self.data)[0] cdef class TaskID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is TaskID: + check_id(id) + self.data = new CTaskID(id) + + cdef CTaskID native(self): + return (self.data)[0] cdef class ClientID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is ClientID: + check_id(id) + self.data = new CClientID(id) + + cdef CClientID native(self): + return (self.data)[0] cdef class DriverID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is DriverID: + check_id(id) + self.data = new CDriverID(id) + + cdef CDriverID native(self): + return (self.data)[0] cdef class ActorID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is ActorID: + check_id(id) + self.data = new CActorID(id) + + cdef CActorID native(self): + return (self.data)[0] cdef class ActorHandleID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is ActorHandleID: + check_id(id) + self.data = new CActorHandleID(id) + + cdef CActorHandleID native(self): + return (self.data)[0] cdef class ActorCheckpointID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is ActorCheckpointID: + check_id(id) + self.data = new CActorCheckpointID(id) + + cdef CActorCheckpointID native(self): + return (self.data)[0] cdef class FunctionID(UniqueID): - pass + + def __cinit__(self, id): + if type(self) is FunctionID: + check_id(id) + self.data = new CFunctionID(id) + + cdef CFunctionID native(self): + return (self.data)[0] cdef class ActorClassID(UniqueID): - pass + def __cinit__(self, id): + if type(self) is ActorClassID: + check_id(id) + self.data = new CActorClassID(id) + + cdef CActorClassID native(self): + return (self.data)[0] _ID_TYPES = [ ActorCheckpointID, diff --git a/src/ray/common/common_protocol.cc b/src/ray/common/common_protocol.cc index f5ed40af570c..adce684fc299 100644 --- a/src/ray/common/common_protocol.cc +++ b/src/ray/common/common_protocol.cc @@ -2,74 +2,6 @@ #include "ray/util/logging.h" -flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ray::ObjectID object_id) { - return fbb.CreateString(reinterpret_cast(object_id.data()), - sizeof(ray::ObjectID)); -} - -ray::ObjectID from_flatbuf(const flatbuffers::String &string) { - ray::ObjectID object_id; - RAY_CHECK(string.size() == sizeof(ray::ObjectID)); - memcpy(object_id.mutable_data(), string.data(), sizeof(ray::ObjectID)); - return object_id; -} - -const std::vector from_flatbuf( - const flatbuffers::Vector> &vector) { - std::vector object_ids; - for (int64_t i = 0; i < vector.Length(); i++) { - object_ids.push_back(from_flatbuf(*vector.Get(i))); - } - return object_ids; -} - -const std::vector object_ids_from_flatbuf( - const flatbuffers::String &string) { - const auto &object_ids = string_from_flatbuf(string); - std::vector ret; - RAY_CHECK(object_ids.size() % kUniqueIDSize == 0); - auto count = object_ids.size() / kUniqueIDSize; - - for (size_t i = 0; i < count; ++i) { - auto pos = static_cast(kUniqueIDSize * i); - const auto &id = object_ids.substr(pos, kUniqueIDSize); - ret.push_back(ray::ObjectID::from_binary(id)); - } - - return ret; -} - -flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { - std::string result; - for (const auto &id : object_ids) { - result += id.binary(); - } - - return fbb.CreateString(result); -} - -flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[], - int64_t num_objects) { - std::vector> results; - for (int64_t i = 0; i < num_objects; i++) { - results.push_back(to_flatbuf(fbb, object_ids[i])); - } - return fbb.CreateVector(results); -} - -flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::vector &object_ids) { - std::vector> results; - for (auto object_id : object_ids) { - results.push_back(to_flatbuf(fbb, object_id)); - } - return fbb.CreateVector(results); -} - std::string string_from_flatbuf(const flatbuffers::String &string) { return std::string(string.data(), string.size()); } diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index bea4a5b92542..aecd84197f98 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -6,26 +6,30 @@ #include #include "ray/id.h" +#include "ray/util/logging.h" /// Convert an object ID to a flatbuffer string. /// /// @param fbb Reference to the flatbuffer builder. /// @param object_id The object ID to be converted. /// @return The flatbuffer string contining the object ID. +template flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ray::ObjectID object_id); + ID object_id); /// Convert a flatbuffer string to an object ID. /// /// @param string The flatbuffer string. /// @return The object ID. -ray::ObjectID from_flatbuf(const flatbuffers::String &string); +template +ID from_flatbuf(const flatbuffers::String &string); /// Convert a flatbuffer vector of strings to a vector of object IDs. /// /// @param vector The flatbuffer vector. /// @return The vector of object IDs. -const std::vector from_flatbuf( +template +const std::vector from_flatbuf( const flatbuffers::Vector> &vector); /// Convert a flatbuffer of string that concatenated @@ -33,8 +37,8 @@ const std::vector from_flatbuf( /// /// @param vector The flatbuffer vector. /// @return The vector of object IDs. -const std::vector object_ids_from_flatbuf( - const flatbuffers::String &string); +template +const std::vector object_ids_from_flatbuf(const flatbuffers::String &string); /// Convert a vector of object IDs to a flatbuffer string. /// The IDs are concatenated to a string with binary. @@ -42,8 +46,9 @@ const std::vector object_ids_from_flatbuf( /// @param fbb Reference to the flatbuffer builder. /// @param object_ids The vector of object IDs. /// @return Flatbuffer string of concatenated IDs. +template flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); + flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); /// Convert an array of object IDs to a flatbuffer vector of strings. /// @@ -51,18 +56,18 @@ flatbuffers::Offset object_ids_to_flatbuf( /// @param object_ids Array of object IDs. /// @param num_objects Number of elements in the array. /// @return Flatbuffer vector of strings. +template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ray::ObjectID object_ids[], - int64_t num_objects); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID object_ids[], int64_t num_objects); /// Convert a vector of object IDs to a flatbuffer vector of strings. /// /// @param fbb Reference to the flatbuffer builder. /// @param object_ids Vector of object IDs. /// @return Flatbuffer vector of strings. +template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - const std::vector &object_ids); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); /// Convert a flatbuffer string to a std::string. /// @@ -95,4 +100,76 @@ std::vector string_vec_from_flatbuf( flatbuffers::Offset>> string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &string_vector); + +template +flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, + ID object_id) { + return fbb.CreateString(reinterpret_cast(object_id.data()), sizeof(ID)); +} + +template +ID from_flatbuf(const flatbuffers::String &string) { + ID object_id; + RAY_CHECK(string.size() == sizeof(ID)); + memcpy(object_id.mutable_data(), string.data(), sizeof(ID)); + return object_id; +} + +template +const std::vector from_flatbuf( + const flatbuffers::Vector> &vector) { + std::vector object_ids; + for (int64_t i = 0; i < vector.Length(); i++) { + object_ids.push_back(from_flatbuf(*vector.Get(i))); + } + return object_ids; +} + +template +const std::vector object_ids_from_flatbuf(const flatbuffers::String &string) { + const auto &object_ids = string_from_flatbuf(string); + std::vector ret; + RAY_CHECK(object_ids.size() % kUniqueIDSize == 0); + auto count = object_ids.size() / kUniqueIDSize; + + for (size_t i = 0; i < count; ++i) { + auto pos = static_cast(kUniqueIDSize * i); + const auto &id = object_ids.substr(pos, kUniqueIDSize); + ret.push_back(ID::from_binary(id)); + } + + return ret; +} + +template +flatbuffers::Offset object_ids_to_flatbuf( + flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { + std::string result; + for (const auto &id : object_ids) { + result += id.binary(); + } + + return fbb.CreateString(result); +} + +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID object_ids[], int64_t num_objects) { + std::vector> results; + for (int64_t i = 0; i < num_objects; i++) { + results.push_back(to_flatbuf(fbb, object_ids[i])); + } + return fbb.CreateVector(results); +} + +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { + std::vector> results; + for (auto object_id : object_ids) { + results.push_back(to_flatbuf(fbb, object_id)); + } + return fbb.CreateVector(results); +} + #endif diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 6bf2a53156be..b7aab1582ef0 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -814,7 +814,7 @@ void TestClientTableConnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); @@ -839,14 +839,14 @@ void TestClientTableDisconnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); @@ -870,11 +870,11 @@ void TestClientTableImmediateDisconnect(const JobID &job_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 87a72258ba2a..8e60c3a0d96f 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -91,7 +91,7 @@ Status Log::Lookup(const JobID &job_id, const ID &id, const Callback & std::vector results; if (!data.empty()) { auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { DataT result; auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); @@ -128,7 +128,7 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, auto root = flatbuffers::GetRoot(data.data()); ID id; if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); + id = from_flatbuf(*root->id()); } std::vector results; for (size_t i = 0; i < root->entries()->size(); i++) { @@ -274,18 +274,18 @@ std::string Table::DebugString() const { return result.str(); } -Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type, +Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); - data->job_id = job_id.binary(); + data->job_id = driver_id.binary(); data->type = type; data->error_message = error_message; data->timestamp = timestamp; - return Append(job_id, job_id, data, /*done_callback=*/nullptr); + return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr); } std::string ErrorTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { @@ -302,11 +302,11 @@ std::string ProfileTable::DebugString() const { return Log::DebugString(); } -Status DriverTable::AppendDriverData(const JobID &driver_id, bool is_dead) { +Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { auto data = std::make_shared(); data->driver_id = driver_id.binary(); data->is_dead = is_dead; - return Append(driver_id, driver_id, data, /*done_callback=*/nullptr); + return Append(JobID(driver_id), driver_id, data, /*done_callback=*/nullptr); } void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { @@ -492,7 +492,7 @@ Status ClientTable::Lookup(const Callback &lookup) { std::string ClientTable::DebugString() const { std::stringstream result; - result << Log::DebugString(); + result << Log::DebugString(); result << ", cache size: " << client_cache_.size() << ", num removed: " << removed_clients_.size(); return result.str(); @@ -500,7 +500,7 @@ std::string ClientTable::DebugString() const { Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, const ActorID &actor_id, - const UniqueID &checkpoint_id) { + const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, const ActorCheckpointIdDataT &data) { @@ -512,7 +512,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, while (copy->timestamps.size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. const auto &checkpoint_id = - UniqueID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); + ActorCheckpointID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " << actor_id; copy->timestamps.erase(copy->timestamps.begin()); @@ -542,9 +542,9 @@ template class Log; template class Table; template class Table; template class Table; -template class Log; -template class Log; -template class Log; +template class Log; +template class Log; +template class Log; template class Log; template class Table; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 71e1c39d6da7..2aabf2ae3e01 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -382,7 +382,7 @@ class HeartbeatBatchTable : public Table { virtual ~HeartbeatBatchTable() {} }; -class DriverTable : public Log { +class DriverTable : public Log { public: DriverTable(const std::vector> &contexts, AsyncGcsClient *client) @@ -398,7 +398,7 @@ class DriverTable : public Log { /// \param driver_id The driver id. /// \param is_dead Whether the driver is dead. /// \return The return status. - Status AppendDriverData(const JobID &driver_id, bool is_dead); + Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; class FunctionTable : public Table { @@ -488,7 +488,7 @@ class ActorCheckpointIdTable : public Table { /// \param checkpoint_id ID of the checkpoint. /// \return Status. Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id, - const UniqueID &checkpoint_id); + const ActorCheckpointID &checkpoint_id); }; namespace raylet { @@ -511,7 +511,7 @@ class TaskTable : public Table { } // namespace raylet -class ErrorTable : private Log { +class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) @@ -532,7 +532,7 @@ class ErrorTable : private Log { /// \param error_message The error message to push. /// \param timestamp The timestamp of the error. /// \return Status. - Status PushErrorToDriver(const JobID &job_id, const std::string &type, + Status PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp); /// Returns debug string for class. @@ -574,7 +574,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : private Log { +class ClientTable : private Log { public: using ClientTableCallback = std::function; @@ -678,7 +678,7 @@ class ClientTable : private Log { /// The key at which the log of client information is stored. This key must /// be kept the same across all instances of the ClientTable, so that all /// clients append and read from the same key. - UniqueID client_log_key_; + ClientID client_log_key_; /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. diff --git a/src/ray/id.cc b/src/ray/id.cc index 70454bbdfb0d..15e86910e0c1 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -33,6 +33,10 @@ UniqueID::UniqueID(const plasma::UniqueID &from) { std::memcpy(&id_, from.data(), kUniqueIDSize); } +UniqueID::UniqueID(const std::string &binary) { + std::memcpy(id_, binary.data(), kUniqueIDSize); +} + UniqueID UniqueID::from_random() { UniqueID id; uint8_t *data = id.mutable_data(); @@ -165,7 +169,7 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) { const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); - ObjectID return_id = task_id; + ObjectID return_id = ObjectID(task_id); int64_t *first_bytes = reinterpret_cast(&return_id); // Zero out the lowest kObjectIdIndexSize bits of the first byte of the // object ID. @@ -176,7 +180,9 @@ const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { return return_id; } -const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); } +const TaskID FinishTaskId(const TaskID &task_id) { + return TaskID(ComputeObjectId(task_id, 0)); +} const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); @@ -190,7 +196,7 @@ const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { } const TaskID ComputeTaskId(const ObjectID &object_id) { - TaskID task_id = object_id; + TaskID task_id = TaskID(object_id); int64_t *first_bytes = reinterpret_cast(&task_id); // Zero out the lowest kObjectIdIndexSize bits of the first byte of the // object ID. diff --git a/src/ray/id.h b/src/ray/id.h index 562365951fc2..94837dc271c3 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -15,6 +15,7 @@ namespace ray { class RAY_EXPORT UniqueID { public: UniqueID(); + UniqueID(const std::string &binary); UniqueID(const plasma::UniqueID &from); static UniqueID from_random(); static UniqueID from_binary(const std::string &binary); @@ -30,7 +31,7 @@ class RAY_EXPORT UniqueID { std::string hex() const; plasma::UniqueID to_plasma_id() const; - private: + protected: uint8_t id_[kUniqueIDSize]; }; @@ -38,18 +39,25 @@ static_assert(std::is_standard_layout::value, "UniqueID must be standa std::ostream &operator<<(std::ostream &os, const UniqueID &id); -typedef UniqueID TaskID; -typedef UniqueID JobID; -typedef UniqueID ObjectID; -typedef UniqueID FunctionID; -typedef UniqueID ActorClassID; -typedef UniqueID ActorID; -typedef UniqueID ActorHandleID; -typedef UniqueID ActorCheckpointID; -typedef UniqueID WorkerID; -typedef UniqueID DriverID; -typedef UniqueID ConfigID; -typedef UniqueID ClientID; +#define DEFINE_UNIQUE_ID(type) \ + class RAY_EXPORT type : public UniqueID { \ + public: \ + explicit type(const UniqueID &from) { \ + std::memcpy(&id_, from.data(), kUniqueIDSize); \ + } \ + type() : UniqueID() {} \ + type(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } \ + const UniqueID &get() { return *this; } \ + static type from_random() { return type(UniqueID::from_random()); } \ + static type from_binary(const std::string &binary) { \ + return type(UniqueID::from_binary(binary)); \ + } \ + static type nil() { return type(UniqueID::nil()); } \ + }; + +#include "id_def.h" + +#undef DEFINE_UNIQUE_ID // TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we // can make these methods of the derived classes. @@ -110,5 +118,19 @@ template <> struct hash { size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); } }; -} + +#define DEFINE_UNIQUE_ID(type) \ + template <> \ + struct hash<::ray::type> { \ + size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + }; \ + template <> \ + struct hash { \ + size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + }; + +#include "id_def.h" + +#undef DEFINE_UNIQUE_ID +} // namespace std #endif // RAY_ID_H_ diff --git a/src/ray/id_def.h b/src/ray/id_def.h new file mode 100644 index 000000000000..4f115298e5fb --- /dev/null +++ b/src/ray/id_def.h @@ -0,0 +1,13 @@ + +DEFINE_UNIQUE_ID(TaskID); +DEFINE_UNIQUE_ID(JobID); +DEFINE_UNIQUE_ID(ObjectID); +DEFINE_UNIQUE_ID(FunctionID); +DEFINE_UNIQUE_ID(ActorClassID); +DEFINE_UNIQUE_ID(ActorID); +DEFINE_UNIQUE_ID(ActorHandleID); +DEFINE_UNIQUE_ID(ActorCheckpointID); +DEFINE_UNIQUE_ID(WorkerID); +DEFINE_UNIQUE_ID(DriverID); +DEFINE_UNIQUE_ID(ConfigID); +DEFINE_UNIQUE_ID(ClientID); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 51cb2600beb3..d9f7b87a700e 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -78,7 +78,7 @@ void ObjectDirectory::RegisterBackend() { } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( - UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), + JobID::nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); } diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 5459985e5b61..7c949be311a6 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -767,7 +767,7 @@ void ObjectManager::ConnectClient(std::shared_ptr &conn, // TODO: trash connection on failure. auto info = flatbuffers::GetRoot(message); - ClientID client_id = ObjectID::from_binary(info->client_id()->str()); + ClientID client_id = ClientID::from_binary(info->client_id()->str()); bool is_transfer = info->is_transfer(); conn->SetClientID(client_id); if (is_transfer) { @@ -885,7 +885,7 @@ void ObjectManager::ReceiveFreeRequest(std::shared_ptr &con const uint8_t *message) { auto free_request = flatbuffers::GetRoot(message); - std::vector object_ids = from_flatbuf(*free_request->object_ids()); + std::vector object_ids = from_flatbuf(*free_request->object_ids()); // This RPC should come from another Object Manager. // Keep this request local. bool local_only = true; diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index aa19787f3c37..746f4d622d5a 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -58,7 +58,7 @@ void ObjectStoreNotificationManager::ProcessStoreNotification( const auto &object_info = flatbuffers::GetRoot(notification_.data()); - const auto &object_id = from_flatbuf(*object_info->object_id()); + const auto &object_id = from_flatbuf(*object_info->object_id()); if (object_info->is_deletion()) { ProcessStoreRemove(object_id); } else { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 710928cdbd88..20bb1c735c1c 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -196,7 +196,7 @@ table WaitReply { // This struct is the same as ErrorTableData. table PushErrorRequest { // The ID of the job that the error is for. - job_id: string; + driver_id: string; // The type of the error. type: string; // The error message. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 93e56a93a81b..49f39487eda1 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -346,10 +346,9 @@ void LineageCache::FlushTask(const TaskID &task_id) { RAY_CHECK(entry); RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_READY); - gcs::raylet::TaskTable::WriteCallback task_callback = [this]( - ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { - HandleEntryCommitted(id); - }; + gcs::raylet::TaskTable::WriteCallback task_callback = + [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, + const protocol::TaskT &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... flatbuffers::FlatBufferBuilder fbb; @@ -358,8 +357,9 @@ void LineageCache::FlushTask(const TaskID &task_id) { auto task_data = std::make_shared(); auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); root->UnPackTo(task_data.get()); - RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(), - task_id, task_data, task_callback)); + RAY_CHECK_OK( + task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().DriverId()), + task_id, task_data, task_callback)); // We successfully wrote the task, so mark it as committing. // TODO(swang): Use a batched interface and write with all object entries. diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 973483759e4b..1ed0dcc84f39 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -113,9 +113,9 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - task_arguments, num_returns, required_resources, - Language::PYTHON, function_descriptor); + auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + num_returns, required_resources, Language::PYTHON, + function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 30f05de226c4..d18edbad8238 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -35,7 +35,7 @@ void Monitor::Start() { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), heartbeat_callback, nullptr, nullptr)); + JobID::nil(), ClientID::nil(), heartbeat_callback, nullptr, nullptr)); Tick(); } @@ -69,7 +69,7 @@ void Monitor::Tick() { << " has missed too many heartbeats from it."; // We use the nil JobID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver( - JobID::nil(), type, error_message.str(), current_time_ms())); + DriverID::nil(), type, error_message.str(), current_time_ms())); } }; RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback)); @@ -88,7 +88,7 @@ void Monitor::Tick() { batch->batch.push_back(std::unique_ptr( new HeartbeatTableDataT(heartbeat.second))); } - RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(UniqueID::nil(), UniqueID::nil(), + RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(JobID::nil(), ClientID::nil(), batch, nullptr)); heartbeat_buffer_.clear(); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 684cad003b87..c9a06cecc106 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -145,7 +145,7 @@ ray::Status NodeManager::RegisterGcs() { }; RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), actor_notification_callback, nullptr)); + JobID::nil(), ClientID::nil(), actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, @@ -154,30 +154,29 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. - auto node_manager_client_removed = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ClientRemoved(data); - }; + auto node_manager_client_removed = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Subscribe to heartbeat batches from the monitor. - const auto &heartbeat_batch_added = [this]( - gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { - HeartbeatBatchAdded(heartbeat_batch); - }; + const auto &heartbeat_batch_added = + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatBatchTableDataT &heartbeat_batch) { + HeartbeatBatchAdded(heartbeat_batch); + }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), heartbeat_batch_added, + JobID::nil(), ClientID::nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto driver_table_handler = [this]( - gcs::AsyncGcsClient *client, const ClientID &client_id, - const std::vector &driver_data) { - HandleDriverTableUpdate(client_id, driver_data); - }; - RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), UniqueID::nil(), + const auto driver_table_handler = + [this](gcs::AsyncGcsClient *client, const DriverID &client_id, + const std::vector &driver_data) { + HandleDriverTableUpdate(client_id, driver_data); + }; + RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), ClientID::nil(), driver_table_handler, nullptr)); // Start sending heartbeats to the GCS. @@ -210,12 +209,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const ClientID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id) << " " << entry.is_dead; if (entry.is_dead) { - auto driver_id = UniqueID::from_binary(entry.driver_id); + auto driver_id = DriverID::from_binary(entry.driver_id); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -270,7 +269,7 @@ void NodeManager::Heartbeat() { } ray::Status status = heartbeat_table.Add( - UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, + JobID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, /*success_callback=*/nullptr); RAY_CHECK_OK_PREPEND(status, "Heartbeat failed"); @@ -351,7 +350,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { << ". This may be since the node was recently removed."; // We use the nil JobID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - JobID::nil(), type, error_message.str(), current_time_ms())); + DriverID::nil(), type, error_message.str(), current_time_ms())); return; } @@ -684,7 +683,7 @@ void NodeManager::ProcessClientMessage( } break; case protocol::MessageType::NotifyUnblocked: { auto message = flatbuffers::GetRoot(message_data); - HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); + HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); } break; case protocol::MessageType::WaitRequest: { ProcessWaitRequestMessage(client, message_data); @@ -698,7 +697,7 @@ void NodeManager::ProcessClientMessage( } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); + std::vector object_ids = from_flatbuf(*message->object_ids()); object_manager_.FreeObjects(object_ids, message->local_only()); } break; case protocol::MessageType::PrepareActorCheckpointRequest: { @@ -719,7 +718,7 @@ void NodeManager::ProcessClientMessage( void NodeManager::ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - client->SetClientID(from_flatbuf(*message->client_id())); + client->SetClientID(from_flatbuf(*message->client_id())); auto worker = std::make_shared(message->worker_pid(), message->language(), client); if (message->is_worker()) { @@ -731,11 +730,11 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. - const JobID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(driver_task_id); - worker->AssignDriverId(from_flatbuf(*message->client_id())); + const JobID driver_task_id = from_flatbuf(*message->driver_id()); + worker->AssignTaskId(TaskID(driver_task_id)); + worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(driver_task_id); + local_queues_.AddDriverTaskId(TaskID(driver_task_id)); } } @@ -865,14 +864,14 @@ void NodeManager::ProcessDisconnectClientMessage( if (!intentional_disconnect) { // Push the error to driver. - const JobID &job_id = worker->GetAssignedDriverId(); + const DriverID &driver_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; error_message << "A worker died or was killed while executing task " << task_id << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - job_id, type, error_message.str(), current_time_ms())); + driver_id, type, error_message.str(), current_time_ms())); } } @@ -899,8 +898,9 @@ void NodeManager::ProcessDisconnectClientMessage( DispatchTasks(local_queues_.GetReadyTasksWithResources()); } else if (is_driver) { // The client is a driver. - RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientId(), - /*is_dead=*/true)); + RAY_CHECK_OK( + gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()), + /*is_dead=*/true)); auto driver_id = worker->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); local_queues_.RemoveDriverTaskId(driver_id); @@ -919,7 +919,7 @@ void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { // Read the task submitted by the client. auto message = flatbuffers::GetRoot(message_data); TaskExecutionSpecification task_execution_spec( - from_flatbuf(*message->execution_dependencies())); + from_flatbuf(*message->execution_dependencies())); TaskSpecification task_spec(*message->task_spec()); Task task(task_execution_spec, task_spec); // Submit the task to the local scheduler. Since the task was submitted @@ -932,7 +932,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( auto message = flatbuffers::GetRoot(message_data); std::vector required_object_ids; for (size_t i = 0; i < message->object_ids()->size(); ++i) { - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); + ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); if (message->fetch_only()) { // If only a fetch is required, then do not subscribe to the // dependencies to the task dependency manager. @@ -950,7 +950,7 @@ void NodeManager::ProcessFetchOrReconstructMessage( } if (!required_object_ids.empty()) { - const TaskID task_id = from_flatbuf(*message->task_id()); + const TaskID task_id = from_flatbuf(*message->task_id()); HandleTaskBlocked(client, required_object_ids, task_id); } } @@ -959,7 +959,7 @@ void NodeManager::ProcessWaitRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); + std::vector object_ids = from_flatbuf(*message->object_ids()); int64_t wait_ms = message->timeout(); uint64_t num_required_objects = static_cast(message->num_ready_objects()); bool wait_local = message->wait_local(); @@ -974,7 +974,7 @@ void NodeManager::ProcessWaitRequestMessage( } } - const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); + const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { HandleTaskBlocked(client, required_object_ids, current_task_id); @@ -1012,20 +1012,20 @@ void NodeManager::ProcessWaitRequestMessage( void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - JobID job_id = from_flatbuf(*message->job_id()); + DriverID driver_id = from_flatbuf(*message->driver_id()); auto const &type = string_from_flatbuf(*message->type()); auto const &error_message = string_from_flatbuf(*message->error_message()); double timestamp = message->timestamp(); - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, - timestamp)); + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(driver_id, type, + error_message, timestamp)); } void NodeManager::ProcessPrepareActorCheckpointRequest( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); + ActorID actor_id = from_flatbuf(*message->actor_id()); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); @@ -1037,15 +1037,15 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( const auto task_id = worker->GetAssignedTaskId(); const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); // Generate checkpoint id and data. - ActorCheckpointID checkpoint_id = UniqueID::from_random(); + ActorCheckpointID checkpoint_id = ActorCheckpointID::from_random(); auto checkpoint_data = actor_entry->second.GenerateCheckpointData(actor_entry->first, task); // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( - UniqueID::nil(), checkpoint_id, checkpoint_data, + JobID::nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, - const UniqueID &checkpoint_id, + const ActorCheckpointID &checkpoint_id, const ActorCheckpointDataT &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); @@ -1072,8 +1072,9 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); - ActorCheckpointID checkpoint_id = from_flatbuf(*message->checkpoint_id()); + ActorID actor_id = from_flatbuf(*message->actor_id()); + ActorCheckpointID checkpoint_id = + from_flatbuf(*message->checkpoint_id()); RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint " << checkpoint_id; checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); @@ -1093,12 +1094,12 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl switch (message_type_value) { case protocol::MessageType::ConnectClient: { auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->client_id()); + auto client_id = from_flatbuf(*message->client_id()); node_manager_client.SetClientID(client_id); } break; case protocol::MessageType::ForwardTaskRequest: { auto message = flatbuffers::GetRoot(message_data); - TaskID task_id = from_flatbuf(*message->task_id()); + TaskID task_id = from_flatbuf(*message->task_id()); Lineage uncommitted_lineage(*message); const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); @@ -1589,7 +1590,7 @@ bool NodeManager::AssignTask(const Task &task) { const std::string warning_message = worker_pool_.WarningAboutSize(); if (warning_message != "") { RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - JobID::nil(), "worker_pool_large", warning_message, current_time_ms())); + DriverID::nil(), "worker_pool_large", warning_message, current_time_ms())); } } // We couldn't assign this task, as no worker available. @@ -1902,7 +1903,6 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { // Use a copy of the cached task spec to re-execute the task. const Task task = lineage_cache_.GetTaskOrDie(task_id); ResubmitTask(task); - })); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 061ef5ef8969..1e97c380b1f5 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -326,7 +326,7 @@ class NodeManager { /// \param id An unused value. TODO(rkn): Should this be removed? /// \param driver_data Data associated with a driver table event. /// \return Void. - void HandleDriverTableUpdate(const ClientID &id, + void HandleDriverTableUpdate(const DriverID &id, const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 13e92d0c4ccc..9299cd599d0d 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -202,7 +202,7 @@ ray::Status RayletConnection::AtomicRequestReply( } RayletClient::RayletClient(const std::string &raylet_socket, const UniqueID &client_id, - bool is_worker, const JobID &driver_id, + bool is_worker, const DriverID &driver_id, const Language &language) : client_id_(client_id), is_worker_(is_worker), @@ -323,11 +323,11 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ return ray::Status::OK(); } -ray::Status RayletClient::PushError(const JobID &job_id, const std::string &type, +ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), + fbb, to_flatbuf(fbb, driver_id), fbb.CreateString(type), fbb.CreateString(error_message), timestamp); fbb.Finish(message); @@ -373,7 +373,7 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, if (!status.ok()) return status; auto reply_message = flatbuffers::GetRoot(reply.get()); - checkpoint_id = ObjectID::from_binary(reply_message->checkpoint_id()->str()); + checkpoint_id = ActorCheckpointID::from_binary(reply_message->checkpoint_id()->str()); return ray::Status::OK(); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index d3ea765df65c..39b3092ca42d 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -9,13 +9,14 @@ #include "ray/raylet/task_spec.h" #include "ray/status.h" -using ray::ActorID; using ray::ActorCheckpointID; +using ray::ActorID; +using ray::ClientID; +using ray::DriverID; using ray::JobID; using ray::ObjectID; using ray::TaskID; using ray::UniqueID; -using ray::ClientID; using MessageType = ray::protocol::MessageType; using ResourceMappingType = @@ -69,7 +70,7 @@ class RayletClient { /// \param driver_id The ID of the driver. This is non-nil if the client is a driver. /// \return The connection information. RayletClient(const std::string &raylet_socket, const UniqueID &client_id, - bool is_worker, const JobID &driver_id, const Language &language); + bool is_worker, const DriverID &driver_id, const Language &language); ray::Status Disconnect() { return conn_->Disconnect(); }; @@ -130,7 +131,7 @@ class RayletClient { /// \param The error message. /// \param The timestamp of the error. /// \return ray::Status. - ray::Status PushError(const JobID &job_id, const std::string &type, + ray::Status PushError(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp); /// Store some profile events in the GCS. diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 5e9ae6d7e521..093f5c236261 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -322,7 +322,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = 2 * test_period; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::nil(), task_id, task_lease_data); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -350,7 +350,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = reconstruction_timeout_ms_; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. Run(reconstruction_timeout_ms_ * 2); @@ -404,7 +404,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { task_reconstruction_data->node_manager_id = ClientID::from_random().binary(); task_reconstruction_data->num_reconstructions = 0; RAY_CHECK_OK( - mock_gcs_.AppendAt(DriverID::nil(), task_id, task_reconstruction_data, nullptr, + mock_gcs_.AppendAt(JobID::nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index fe4364c4491f..2f1b64a87480 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -263,7 +263,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { task_lease_data->node_manager_id = client_id_.hex(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = it->second.lease_period; - RAY_CHECK_OK(task_lease_table_.Add(DriverID::nil(), task_id, task_lease_data, nullptr)); + RAY_CHECK_OK(task_lease_table_.Add(JobID::nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); it->second.lease_timer->expires_from_now(period); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index f414d7469565..e0d30bf9ebd6 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -75,9 +75,9 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, - task_arguments, num_returns, required_resources, - Language::PYTHON, function_descriptor); + auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + num_returns, required_resources, Language::PYTHON, + function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); execution_spec.IncrementNumForwards(); Task task = Task(execution_spec, spec); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index a8c0f40fed60..ea5fe5248041 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -57,7 +57,7 @@ TaskSpecification::TaskSpecification(const std::string &string) { } TaskSpecification::TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const Language &language, const std::vector &function_descriptor) @@ -68,7 +68,7 @@ TaskSpecification::TaskSpecification( function_descriptor) {} TaskSpecification::TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, const int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -122,15 +122,15 @@ size_t TaskSpecification::size() const { return spec_.size(); } // Task specification getter methods. TaskID TaskSpecification::TaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->task_id()); + return from_flatbuf(*message->task_id()); } -UniqueID TaskSpecification::DriverId() const { +DriverID TaskSpecification::DriverId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->driver_id()); + return from_flatbuf(*message->driver_id()); } TaskID TaskSpecification::ParentTaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->parent_task_id()); + return from_flatbuf(*message->parent_task_id()); } int64_t TaskSpecification::ParentCounter() const { auto message = flatbuffers::GetRoot(spec_.data()); @@ -168,7 +168,7 @@ int64_t TaskSpecification::NumReturns() const { ObjectID TaskSpecification::ReturnId(int64_t return_index) const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->returns())[return_index]; + return object_ids_from_flatbuf(*message->returns())[return_index]; } bool TaskSpecification::ArgByRef(int64_t arg_index) const { @@ -184,7 +184,7 @@ int TaskSpecification::ArgIdCount(int64_t arg_index) const { ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const { auto message = flatbuffers::GetRoot(spec_.data()); const auto &object_ids = - object_ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); + object_ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); return object_ids[id_index]; } @@ -232,12 +232,12 @@ bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); } ActorID TaskSpecification::ActorCreationId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_id()); + return from_flatbuf(*message->actor_creation_id()); } ObjectID TaskSpecification::ActorCreationDummyObjectId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_creation_dummy_object_id()); + return from_flatbuf(*message->actor_creation_dummy_object_id()); } int64_t TaskSpecification::MaxActorReconstructions() const { @@ -247,12 +247,12 @@ int64_t TaskSpecification::MaxActorReconstructions() const { ActorID TaskSpecification::ActorId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_id()); + return from_flatbuf(*message->actor_id()); } ActorHandleID TaskSpecification::ActorHandleId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->actor_handle_id()); + return from_flatbuf(*message->actor_handle_id()); } int64_t TaskSpecification::ActorCounter() const { @@ -267,7 +267,7 @@ ObjectID TaskSpecification::ActorDummyObject() const { std::vector TaskSpecification::NewActorHandles() const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->new_actor_handles()); + return object_ids_from_flatbuf(*message->new_actor_handles()); } } // namespace raylet diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 11e93050b9d1..baa6165c9ede 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -96,7 +96,7 @@ class TaskSpecification { /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. /// \param language The language of the worker that must execute the function. - TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, + TaskSpecification(const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, @@ -129,7 +129,7 @@ class TaskSpecification { /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. TaskSpecification( - const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -164,7 +164,7 @@ class TaskSpecification { // TODO(swang): Finalize and document these methods. TaskID TaskId() const; - UniqueID DriverId() const; + DriverID DriverId() const; TaskID ParentTaskId() const; int64_t ParentCounter() const; std::vector FunctionDescriptor() const; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 4a7f71ea81ea..c548fc924d67 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -75,7 +75,7 @@ static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::nil(), const Language &language = Language::PYTHON) { std::vector function_descriptor(3); - return TaskSpecification(UniqueID::nil(), TaskID::nil(), 0, ActorID::nil(), + return TaskSpecification(DriverID::nil(), TaskID::nil(), 0, ActorID::nil(), ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, {}, 0, {{}}, {{}}, language, function_descriptor); } From fbf26cc43599b95cff4972ab814f05fd6a104b20 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 28 Feb 2019 15:29:46 +0800 Subject: [PATCH 2/7] Avoid heap allocation in cython. --- python/ray/includes/unique_ids.pxd | 48 ++++++++++++++++------ python/ray/includes/unique_ids.pxi | 65 ++++++++++++++---------------- src/ray/id.cc | 4 -- src/ray/id.h | 20 +++------ src/ray/id_def.h | 5 +++ src/ray/raylet/node_manager.cc | 4 +- 6 files changed, 78 insertions(+), 68 deletions(-) diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index a7b77300ba21..c5389267f0df 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -29,71 +29,95 @@ cdef extern from "ray/id.h" namespace "ray" nogil: cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): CActorCheckpointID() - CActorCheckpointID(const c_string &binary) CActorCheckpointID(const CUniqueID &from_id) + @staticmethod + CActorCheckpointID from_binary(const c_string &binary) + cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): CActorClassID() - CActorClassID(const c_string &binary) CActorClassID(const CUniqueID &from_id) + @staticmethod + CActorClassID from_binary(const c_string &binary) + cdef cppclass CActorID "ray::ActorID"(CUniqueID): CActorID() - CActorID(const c_string &binary) CActorID(const CUniqueID &from_id) + @staticmethod + CActorID from_binary(const c_string &binary) + cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): CActorHandleID() - CActorHandleID(const c_string &binary) CActorHandleID(const CUniqueID &from_id) + @staticmethod + CActorHandleID from_binary(const c_string &binary) + cdef cppclass CClientID "ray::ClientID"(CUniqueID): CClientID() - CClientID(const c_string &binary) CClientID(const CUniqueID &from_id) + @staticmethod + CClientID from_binary(const c_string &binary) + cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): CConfigID() - CConfigID(const c_string &binary) CConfigID(const CUniqueID &from_id) + @staticmethod + CConfigID from_binary(const c_string &binary) + cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): CFunctionID() - CFunctionID(const c_string &binary) CFunctionID(const CUniqueID &from_id) + @staticmethod + CFunctionID from_binary(const c_string &binary) + cdef cppclass CDriverID "ray::DriverID"(CUniqueID): CDriverID() - CDriverID(const c_string &binary) CDriverID(const CUniqueID &from_id) + @staticmethod + CDriverID from_binary(const c_string &binary) + cdef cppclass CJobID "ray::JobID"(CUniqueID): CJobID() - CJobID(const c_string &binary) CJobID(const CUniqueID &from_id) + @staticmethod + CJobID from_binary(const c_string &binary) + cdef cppclass CTaskID "ray::TaskID"(CUniqueID): CTaskID() - CTaskID(const c_string &binary) CTaskID(const CUniqueID &from_id) + @staticmethod + CTaskID from_binary(const c_string &binary) + cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): CObjectID() - CObjectID(const c_string &binary) CObjectID(const CUniqueID &from_id) + @staticmethod + CObjectID from_binary(const c_string &binary) + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): CWorkerID() - CWorkerID(const c_string &binary) CWorkerID(const CUniqueID &from_id) + + @staticmethod + CWorkerID from_binary(const c_string &binary) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 1badc75b2d22..70e551780b72 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -43,18 +43,13 @@ cdef extern from "ray/constants.h" nogil: cdef class UniqueID: - cdef CUniqueID *data + cdef CUniqueID data - def __cinit__(self, id): + def __init__(self, id): # The derived class should also check self type and fill self.data. if type(self) is UniqueID: check_id(id) - self.data = new CUniqueID(id) - - def __dealloc__(self): - # The derived classes do not need to define __dealloc__, - # the base class will do it. - del self.data + self.data = CUniqueID.from_binary(id) @classmethod def from_binary(cls, id_bytes): @@ -112,101 +107,101 @@ cdef class UniqueID: cdef class ObjectID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ObjectID: check_id(id) - self.data = new CObjectID(id) + self.data = CObjectID.from_binary(id) cdef CObjectID native(self): - return (self.data)[0] + return self.data cdef class TaskID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is TaskID: check_id(id) - self.data = new CTaskID(id) + self.data = CTaskID.from_binary(id) cdef CTaskID native(self): - return (self.data)[0] + return self.data cdef class ClientID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ClientID: check_id(id) - self.data = new CClientID(id) + self.data = CClientID.from_binary(id) cdef CClientID native(self): - return (self.data)[0] + return self.data cdef class DriverID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is DriverID: check_id(id) - self.data = new CDriverID(id) + self.data = CDriverID.from_binary(id) cdef CDriverID native(self): - return (self.data)[0] + return self.data cdef class ActorID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ActorID: check_id(id) - self.data = new CActorID(id) + self.data = CActorID.from_binary(id) cdef CActorID native(self): - return (self.data)[0] + return self.data cdef class ActorHandleID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ActorHandleID: check_id(id) - self.data = new CActorHandleID(id) + self.data = CActorHandleID.from_binary(id) cdef CActorHandleID native(self): - return (self.data)[0] + return self.data cdef class ActorCheckpointID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ActorCheckpointID: check_id(id) - self.data = new CActorCheckpointID(id) + self.data = CActorCheckpointID.from_binary(id) cdef CActorCheckpointID native(self): - return (self.data)[0] + return self.data cdef class FunctionID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is FunctionID: check_id(id) - self.data = new CFunctionID(id) + self.data = CFunctionID.from_binary(id) cdef CFunctionID native(self): - return (self.data)[0] + return self.data cdef class ActorClassID(UniqueID): - def __cinit__(self, id): + def __init__(self, id): if type(self) is ActorClassID: check_id(id) - self.data = new CActorClassID(id) + self.data = CActorClassID.from_binary(id) cdef CActorClassID native(self): - return (self.data)[0] + return self.data _ID_TYPES = [ ActorCheckpointID, diff --git a/src/ray/id.cc b/src/ray/id.cc index 15e86910e0c1..c707c209b053 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -29,10 +29,6 @@ UniqueID::UniqueID() { std::fill_n(id_, kUniqueIDSize, 255); } -UniqueID::UniqueID(const plasma::UniqueID &from) { - std::memcpy(&id_, from.data(), kUniqueIDSize); -} - UniqueID::UniqueID(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } diff --git a/src/ray/id.h b/src/ray/id.h index 94837dc271c3..04fb8be49aa6 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -16,7 +16,6 @@ class RAY_EXPORT UniqueID { public: UniqueID(); UniqueID(const std::string &binary); - UniqueID(const plasma::UniqueID &from); static UniqueID from_random(); static UniqueID from_binary(const std::string &binary); static const UniqueID &nil(); @@ -46,13 +45,12 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id); std::memcpy(&id_, from.data(), kUniqueIDSize); \ } \ type() : UniqueID() {} \ - type(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } \ - const UniqueID &get() { return *this; } \ static type from_random() { return type(UniqueID::from_random()); } \ - static type from_binary(const std::string &binary) { \ - return type(UniqueID::from_binary(binary)); \ - } \ + static type from_binary(const std::string &binary) { return type(binary); } \ static type nil() { return type(UniqueID::nil()); } \ + \ + private: \ + type(const std::string &binary) { std::memcpy(id_, binary.data(), kUniqueIDSize); } \ }; #include "id_def.h" @@ -109,15 +107,6 @@ int64_t ComputeObjectIndex(const ObjectID &object_id); } // namespace ray namespace std { -template <> -struct hash<::ray::UniqueID> { - size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); } -}; - -template <> -struct hash { - size_t operator()(const ::ray::UniqueID &id) const { return id.hash(); } -}; #define DEFINE_UNIQUE_ID(type) \ template <> \ @@ -129,6 +118,7 @@ struct hash { size_t operator()(const ::ray::type &id) const { return id.hash(); } \ }; +DEFINE_UNIQUE_ID(UniqueID); #include "id_def.h" #undef DEFINE_UNIQUE_ID diff --git a/src/ray/id_def.h b/src/ray/id_def.h index 4f115298e5fb..8e8b2b3fb717 100644 --- a/src/ray/id_def.h +++ b/src/ray/id_def.h @@ -1,3 +1,8 @@ +// This header file is used to avoid code duplication. +// It can be included multiple times in id.h, and each inclusion +// could use a different definition of the DEFINE_UNIQUE_ID macro. +// Macro definition format: DEFINE_UNIQUE_ID(id_type). +// NOTE: This file should NOT be included in any file other than id.h. DEFINE_UNIQUE_ID(TaskID); DEFINE_UNIQUE_ID(JobID); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c9a06cecc106..a28853595788 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1244,8 +1244,8 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ // to the driver. std::ostringstream stream; stream << "An plasma error (" << status.ToString() << ") occurred while saving" - << " error code to object " << object_id << ". Anyone who's getting this" - << " object may hang forever."; + << " error code to object " << object_id.hex() << ". Anyone who's getting" + << " this object may hang forever."; std::string error_message = stream.str(); RAY_LOG(WARNING) << error_message; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( From aa197392fec7acc20b90da9bc8d513fcd647fdc7 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 28 Feb 2019 17:30:19 +0800 Subject: [PATCH 3/7] Fix JNI part --- ...org_ray_runtime_raylet_RayletClientImpl.cc | 96 +++++++++---------- src/ray/raylet/lineage_cache.cc | 7 +- src/ray/raylet/node_manager.cc | 27 +++--- src/ray/raylet/raylet_client.cc | 2 +- src/ray/raylet/raylet_client.h | 2 +- 5 files changed, 66 insertions(+), 68 deletions(-) diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 68004a37bf21..c55b2608b2fd 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -6,31 +6,30 @@ #include "ray/raylet/raylet_client.h" #include "ray/util/logging.h" -#ifdef __cplusplus -extern "C" { -#endif - +template class UniqueIdFromJByteArray { - private: - JNIEnv *_env; - jbyteArray _bytes; - public: - UniqueID *PID; + const ID &GetId() const { return *id_pointer_; } - UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) { - _env = env; - _bytes = wid; - - jbyte *b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); - PID = reinterpret_cast(b); + UniqueIdFromJByteArray(JNIEnv *env, jbyteArray bytes) : env_(env), bytes_(bytes) { + jbyte *b = reinterpret_cast(env_->GetByteArrayElements(bytes_, nullptr)); + id_pointer_ = reinterpret_cast(b); } ~UniqueIdFromJByteArray() { - _env->ReleaseByteArrayElements(_bytes, reinterpret_cast(PID), 0); + env_->ReleaseByteArrayElements(bytes_, reinterpret_cast(id_pointer_), 0); } + + private: + JNIEnv *env_; + jbyteArray bytes_; + ID *id_pointer_; }; +#ifdef __cplusplus +extern "C" { +#endif + inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { if (!status.ok()) { jclass exception_class = env->FindClass("org/ray/api/exception/RayException"); @@ -49,11 +48,11 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, jbyteArray driverId) { - UniqueIdFromJByteArray worker_id(env, workerId); - UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray worker_id(env, workerId); + UniqueIdFromJByteArray driver_id(env, driverId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); - auto raylet_client = new RayletClient(nativeString, *worker_id.PID, isWorker, - *driver_id.PID, Language::JAVA); + auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker, + driver_id.GetId(), Language::JAVA); env->ReleaseStringUTFChars(sockName, nativeString); return reinterpret_cast(raylet_client); } @@ -70,8 +69,8 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmit std::vector execution_dependencies; if (cursorId != nullptr) { - UniqueIdFromJByteArray cursor_id(env, cursorId); - execution_dependencies.push_back(*cursor_id.PID); + UniqueIdFromJByteArray cursor_id(env, cursorId); + execution_dependencies.push_back(cursor_id.GetId()); } auto data = reinterpret_cast(env->GetDirectBufferAddress(taskBuff)) + pos; @@ -143,14 +142,14 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); auto status = - raylet_client->FetchOrReconstruct(object_ids, fetchOnly, *current_task_id.PID); + raylet_client->FetchOrReconstruct(object_ids, fetchOnly, current_task_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } @@ -161,9 +160,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFetchOrReconstruct( */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyUnblocked( JNIEnv *env, jclass, jlong client, jbyteArray currentTaskId) { - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); - auto status = raylet_client->NotifyUnblocked(*current_task_id.PID); + auto status = raylet_client->NotifyUnblocked(current_task_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } @@ -181,19 +180,19 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } - UniqueIdFromJByteArray current_task_id(env, currentTaskId); + UniqueIdFromJByteArray current_task_id(env, currentTaskId); auto raylet_client = reinterpret_cast(client); // Invoke wait. WaitResultPair result; - auto status = - raylet_client->Wait(object_ids, numReturns, timeoutMillis, - static_cast(isWaitLocal), *current_task_id.PID, &result); + auto status = raylet_client->Wait(object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), + current_task_id.GetId(), &result); if (ThrowRayExceptionIfNotOK(env, status)) { return nullptr; } @@ -231,15 +230,12 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( JNIEnv *env, jclass, jbyteArray driverId, jbyteArray parentTaskId, jint parent_task_counter) { - UniqueIdFromJByteArray object_id1(env, driverId); - ray::DriverID driver_id = *object_id1.PID; + UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray parent_task_id(env, parentTaskId); - UniqueIdFromJByteArray object_id2(env, parentTaskId); - ray::TaskID parent_task_id = *object_id2.PID; - - ray::TaskID task_id = - ray::GenerateTaskId(driver_id, parent_task_id, parent_task_counter); - jbyteArray result = env->NewByteArray(sizeof(ray::TaskID)); + TaskID task_id = + ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter); + jbyteArray result = env->NewByteArray(sizeof(TaskID)); if (nullptr == result) { return nullptr; } @@ -261,8 +257,8 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( for (int i = 0; i < len; i++) { jbyteArray object_id_bytes = static_cast(env->GetObjectArrayElement(objectIds, i)); - UniqueIdFromJByteArray object_id(env, object_id_bytes); - object_ids.push_back(*object_id.PID); + UniqueIdFromJByteArray object_id(env, object_id_bytes); + object_ids.push_back(object_id.GetId()); env->DeleteLocalRef(object_id_bytes); } auto raylet_client = reinterpret_cast(client); @@ -280,9 +276,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env jlong client, jbyteArray actorId) { auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); + UniqueIdFromJByteArray actor_id(env, actorId); ActorCheckpointID checkpoint_id; - auto status = raylet_client->PrepareActorCheckpoint(*actor_id.PID, checkpoint_id); + auto status = raylet_client->PrepareActorCheckpoint(actor_id.GetId(), checkpoint_id); if (ThrowRayExceptionIfNotOK(env, status)) { return nullptr; } @@ -301,10 +297,10 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { auto raylet_client = reinterpret_cast(client); - UniqueIdFromJByteArray actor_id(env, actorId); - UniqueIdFromJByteArray checkpoint_id(env, checkpointId); - auto status = - raylet_client->NotifyActorResumedFromCheckpoint(*actor_id.PID, *checkpoint_id.PID); + UniqueIdFromJByteArray actor_id(env, actorId); + UniqueIdFromJByteArray checkpoint_id(env, checkpointId); + auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id.GetId(), + checkpoint_id.GetId()); ThrowRayExceptionIfNotOK(env, status); } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 49f39487eda1..949dc9eca1c2 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -346,9 +346,10 @@ void LineageCache::FlushTask(const TaskID &task_id) { RAY_CHECK(entry); RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_READY); - gcs::raylet::TaskTable::WriteCallback task_callback = - [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + gcs::raylet::TaskTable::WriteCallback task_callback = [this]( + ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + HandleEntryCommitted(id); + }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... flatbuffers::FlatBufferBuilder fbb; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a28853595788..a317d1ff0947 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -154,28 +154,29 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. - auto node_manager_client_removed = - [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + auto node_manager_client_removed = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ClientRemoved(data); + }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Subscribe to heartbeat batches from the monitor. - const auto &heartbeat_batch_added = - [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { - HeartbeatBatchAdded(heartbeat_batch); - }; + const auto &heartbeat_batch_added = [this]( + gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatBatchTableDataT &heartbeat_batch) { + HeartbeatBatchAdded(heartbeat_batch); + }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( JobID::nil(), ClientID::nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto driver_table_handler = - [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { - HandleDriverTableUpdate(client_id, driver_data); - }; + const auto driver_table_handler = [this]( + gcs::AsyncGcsClient *client, const DriverID &client_id, + const std::vector &driver_data) { + HandleDriverTableUpdate(client_id, driver_data); + }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(JobID::nil(), ClientID::nil(), driver_table_handler, nullptr)); diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 9299cd599d0d..28a51c7e10fd 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -201,7 +201,7 @@ ray::Status RayletConnection::AtomicRequestReply( return ReadMessage(reply_type, reply_message); } -RayletClient::RayletClient(const std::string &raylet_socket, const UniqueID &client_id, +RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id, bool is_worker, const DriverID &driver_id, const Language &language) : client_id_(client_id), diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 39b3092ca42d..2e07becfc245 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -69,7 +69,7 @@ class RayletClient { /// additional message will be sent to register as one. /// \param driver_id The ID of the driver. This is non-nil if the client is a driver. /// \return The connection information. - RayletClient(const std::string &raylet_socket, const UniqueID &client_id, + RayletClient(const std::string &raylet_socket, const ClientID &client_id, bool is_worker, const DriverID &driver_id, const Language &language); ray::Status Disconnect() { return conn_->Disconnect(); }; From d3c1eae03aa035a299fa090075a10eb9eeba710a Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 28 Feb 2019 17:46:22 +0800 Subject: [PATCH 4/7] Fix rebase conflict --- python/ray/_raylet.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 94ebd625332b..453f66bb4009 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -262,8 +262,8 @@ cdef class RayletClient: cdef: WaitResultPair result c_vector[CObjectID] wait_ids + CTaskID c_task_id = current_task_id.native() wait_ids = ObjectIDsToVector(object_ids) - c_task_id = current_task_id.native() with nogil: check_status(self.client.get().Wait(wait_ids, num_returns, timeout_milliseconds, From b79d4e2013ecb49fc3a366c79bde8a0ab5fad75c Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 28 Feb 2019 18:45:03 +0800 Subject: [PATCH 5/7] Refine --- python/ray/includes/unique_ids.pxi | 2 +- src/ray/common/common_protocol.h | 88 +++++++++++++++--------------- src/ray/id.cc | 4 +- src/ray/id.h | 2 +- src/ray/raylet/node_manager.cc | 4 +- src/ray/raylet/task_spec.cc | 12 ++-- 6 files changed, 56 insertions(+), 56 deletions(-) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 70e551780b72..428fa7c66060 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -49,7 +49,7 @@ cdef class UniqueID: # The derived class should also check self type and fill self.data. if type(self) is UniqueID: check_id(id) - self.data = CUniqueID.from_binary(id) + self.data = CUniqueID.from_binary(id) @classmethod def from_binary(cls, id_bytes): diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index aecd84197f98..bc3d9b646a4b 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -8,66 +8,66 @@ #include "ray/id.h" #include "ray/util/logging.h" -/// Convert an object ID to a flatbuffer string. +/// Convert an unique ID to a flatbuffer string. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_id The object ID to be converted. -/// @return The flatbuffer string contining the object ID. +/// @param id The ID to be converted. +/// @return The flatbuffer string containing the ID. template flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ID object_id); + ID id); -/// Convert a flatbuffer string to an object ID. +/// Convert a flatbuffer string to an unique ID. /// /// @param string The flatbuffer string. -/// @return The object ID. +/// @return The ID. template ID from_flatbuf(const flatbuffers::String &string); -/// Convert a flatbuffer vector of strings to a vector of object IDs. +/// Convert a flatbuffer vector of strings to a vector of unique IDs. /// /// @param vector The flatbuffer vector. -/// @return The vector of object IDs. +/// @return The vector of IDs. template const std::vector from_flatbuf( const flatbuffers::Vector> &vector); /// Convert a flatbuffer of string that concatenated -/// object IDs to a vector of object IDs. +/// unique IDs to a vector of unique IDs. /// /// @param vector The flatbuffer vector. -/// @return The vector of object IDs. +/// @return The vector of IDs. template -const std::vector object_ids_from_flatbuf(const flatbuffers::String &string); +const std::vector ids_from_flatbuf(const flatbuffers::String &string); -/// Convert a vector of object IDs to a flatbuffer string. +/// Convert a vector of unique IDs to a flatbuffer string. /// The IDs are concatenated to a string with binary. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids The vector of object IDs. +/// @param ids The vector of IDs. /// @return Flatbuffer string of concatenated IDs. template -flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); +flatbuffers::Offset ids_to_flatbuf( + flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); -/// Convert an array of object IDs to a flatbuffer vector of strings. +/// Convert an array of unique IDs to a flatbuffer vector of strings. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids Array of object IDs. -/// @param num_objects Number of elements in the array. +/// @param ids Array of unique IDs. +/// @param num_ids Number of elements in the array. /// @return Flatbuffer vector of strings. template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID object_ids[], int64_t num_objects); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids); -/// Convert a vector of object IDs to a flatbuffer vector of strings. +/// Convert a vector of unique IDs to a flatbuffer vector of strings. /// /// @param fbb Reference to the flatbuffer builder. -/// @param object_ids Vector of object IDs. +/// @param ids Vector of IDs. /// @return Flatbuffer vector of strings. template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids); +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); /// Convert a flatbuffer string to a std::string. /// @@ -103,38 +103,38 @@ string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, template flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, - ID object_id) { - return fbb.CreateString(reinterpret_cast(object_id.data()), sizeof(ID)); + ID id) { + return fbb.CreateString(reinterpret_cast(id.data()), sizeof(ID)); } template ID from_flatbuf(const flatbuffers::String &string) { - ID object_id; + ID id; RAY_CHECK(string.size() == sizeof(ID)); - memcpy(object_id.mutable_data(), string.data(), sizeof(ID)); - return object_id; + memcpy(id.mutable_data(), string.data(), sizeof(ID)); + return id; } template const std::vector from_flatbuf( const flatbuffers::Vector> &vector) { - std::vector object_ids; + std::vector ids; for (int64_t i = 0; i < vector.Length(); i++) { - object_ids.push_back(from_flatbuf(*vector.Get(i))); + ids.push_back(from_flatbuf(*vector.Get(i))); } - return object_ids; + return ids; } template -const std::vector object_ids_from_flatbuf(const flatbuffers::String &string) { - const auto &object_ids = string_from_flatbuf(string); +const std::vector ids_from_flatbuf(const flatbuffers::String &string) { + const auto &ids = string_from_flatbuf(string); std::vector ret; - RAY_CHECK(object_ids.size() % kUniqueIDSize == 0); - auto count = object_ids.size() / kUniqueIDSize; + RAY_CHECK(ids.size() % kUniqueIDSize == 0); + auto count = ids.size() / kUniqueIDSize; for (size_t i = 0; i < count; ++i) { auto pos = static_cast(kUniqueIDSize * i); - const auto &id = object_ids.substr(pos, kUniqueIDSize); + const auto &id = ids.substr(pos, kUniqueIDSize); ret.push_back(ID::from_binary(id)); } @@ -142,10 +142,10 @@ const std::vector object_ids_from_flatbuf(const flatbuffers::String &string) } template -flatbuffers::Offset object_ids_to_flatbuf( - flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { +flatbuffers::Offset ids_to_flatbuf( + flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { std::string result; - for (const auto &id : object_ids) { + for (const auto &id : ids) { result += id.binary(); } @@ -154,20 +154,20 @@ flatbuffers::Offset object_ids_to_flatbuf( template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID object_ids[], int64_t num_objects) { +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID ids[], int64_t num_ids) { std::vector> results; - for (int64_t i = 0; i < num_objects; i++) { - results.push_back(to_flatbuf(fbb, object_ids[i])); + for (int64_t i = 0; i < num_ids; i++) { + results.push_back(to_flatbuf(fbb, ids[i])); } return fbb.CreateVector(results); } template flatbuffers::Offset>> -to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &object_ids) { +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { std::vector> results; - for (auto object_id : object_ids) { - results.push_back(to_flatbuf(fbb, object_id)); + for (auto id : ids) { + results.push_back(to_flatbuf(fbb, id)); } return fbb.CreateVector(results); } diff --git a/src/ray/id.cc b/src/ray/id.cc index c707c209b053..a9d9c5a7e765 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -29,8 +29,8 @@ UniqueID::UniqueID() { std::fill_n(id_, kUniqueIDSize, 255); } -UniqueID::UniqueID(const std::string &binary) { - std::memcpy(id_, binary.data(), kUniqueIDSize); +UniqueID::UniqueID(const plasma::UniqueID &from) { + std::memcpy(&id_, from.data(), kUniqueIDSize); } UniqueID UniqueID::from_random() { diff --git a/src/ray/id.h b/src/ray/id.h index 04fb8be49aa6..35c67b220faf 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -15,7 +15,7 @@ namespace ray { class RAY_EXPORT UniqueID { public: UniqueID(); - UniqueID(const std::string &binary); + UniqueID(const plasma::UniqueID &from); static UniqueID from_random(); static UniqueID from_binary(const std::string &binary); static const UniqueID &nil(); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a317d1ff0947..a49b6268cf4f 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1245,8 +1245,8 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ // to the driver. std::ostringstream stream; stream << "An plasma error (" << status.ToString() << ") occurred while saving" - << " error code to object " << object_id.hex() << ". Anyone who's getting" - << " this object may hang forever."; + << " error code to object " << object_id << ". Anyone who's getting this" + << " object may hang forever."; std::string error_message = stream.str(); RAY_LOG(WARNING) << error_message; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index ea5fe5248041..da8bafc60fd4 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -17,7 +17,7 @@ TaskArgumentByReference::TaskArgumentByReference(const std::vector &re flatbuffers::Offset TaskArgumentByReference::ToFlatbuffer( flatbuffers::FlatBufferBuilder &fbb) const { - return CreateArg(fbb, object_ids_to_flatbuf(fbb, references_)); + return CreateArg(fbb, ids_to_flatbuf(fbb, references_)); } TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) { @@ -100,8 +100,8 @@ TaskSpecification::TaskSpecification( to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, - object_ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), - object_ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), + ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), + ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); @@ -168,7 +168,7 @@ int64_t TaskSpecification::NumReturns() const { ObjectID TaskSpecification::ReturnId(int64_t return_index) const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->returns())[return_index]; + return ids_from_flatbuf(*message->returns())[return_index]; } bool TaskSpecification::ArgByRef(int64_t arg_index) const { @@ -184,7 +184,7 @@ int TaskSpecification::ArgIdCount(int64_t arg_index) const { ObjectID TaskSpecification::ArgId(int64_t arg_index, int64_t id_index) const { auto message = flatbuffers::GetRoot(spec_.data()); const auto &object_ids = - object_ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); + ids_from_flatbuf(*message->args()->Get(arg_index)->object_ids()); return object_ids[id_index]; } @@ -267,7 +267,7 @@ ObjectID TaskSpecification::ActorDummyObject() const { std::vector TaskSpecification::NewActorHandles() const { auto message = flatbuffers::GetRoot(spec_.data()); - return object_ids_from_flatbuf(*message->new_actor_handles()); + return ids_from_flatbuf(*message->new_actor_handles()); } } // namespace raylet From 204b6b1348842920bd82022366aeafc3846d7519 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Fri, 1 Mar 2019 18:18:16 +0800 Subject: [PATCH 6/7] Remove type check from __init__ --- python/ray/includes/unique_ids.pxi | 51 ++++++++++++------------------ 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 428fa7c66060..0086f76b51b0 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -46,10 +46,8 @@ cdef class UniqueID: cdef CUniqueID data def __init__(self, id): - # The derived class should also check self type and fill self.data. - if type(self) is UniqueID: - check_id(id) - self.data = CUniqueID.from_binary(id) + check_id(id) + self.data = CUniqueID.from_binary(id) @classmethod def from_binary(cls, id_bytes): @@ -108,9 +106,8 @@ cdef class UniqueID: cdef class ObjectID(UniqueID): def __init__(self, id): - if type(self) is ObjectID: - check_id(id) - self.data = CObjectID.from_binary(id) + check_id(id) + self.data = CObjectID.from_binary(id) cdef CObjectID native(self): return self.data @@ -119,9 +116,8 @@ cdef class ObjectID(UniqueID): cdef class TaskID(UniqueID): def __init__(self, id): - if type(self) is TaskID: - check_id(id) - self.data = CTaskID.from_binary(id) + check_id(id) + self.data = CTaskID.from_binary(id) cdef CTaskID native(self): return self.data @@ -130,9 +126,8 @@ cdef class TaskID(UniqueID): cdef class ClientID(UniqueID): def __init__(self, id): - if type(self) is ClientID: - check_id(id) - self.data = CClientID.from_binary(id) + check_id(id) + self.data = CClientID.from_binary(id) cdef CClientID native(self): return self.data @@ -141,9 +136,8 @@ cdef class ClientID(UniqueID): cdef class DriverID(UniqueID): def __init__(self, id): - if type(self) is DriverID: - check_id(id) - self.data = CDriverID.from_binary(id) + check_id(id) + self.data = CDriverID.from_binary(id) cdef CDriverID native(self): return self.data @@ -152,9 +146,8 @@ cdef class DriverID(UniqueID): cdef class ActorID(UniqueID): def __init__(self, id): - if type(self) is ActorID: - check_id(id) - self.data = CActorID.from_binary(id) + check_id(id) + self.data = CActorID.from_binary(id) cdef CActorID native(self): return self.data @@ -163,9 +156,8 @@ cdef class ActorID(UniqueID): cdef class ActorHandleID(UniqueID): def __init__(self, id): - if type(self) is ActorHandleID: - check_id(id) - self.data = CActorHandleID.from_binary(id) + check_id(id) + self.data = CActorHandleID.from_binary(id) cdef CActorHandleID native(self): return self.data @@ -174,9 +166,8 @@ cdef class ActorHandleID(UniqueID): cdef class ActorCheckpointID(UniqueID): def __init__(self, id): - if type(self) is ActorCheckpointID: - check_id(id) - self.data = CActorCheckpointID.from_binary(id) + check_id(id) + self.data = CActorCheckpointID.from_binary(id) cdef CActorCheckpointID native(self): return self.data @@ -185,9 +176,8 @@ cdef class ActorCheckpointID(UniqueID): cdef class FunctionID(UniqueID): def __init__(self, id): - if type(self) is FunctionID: - check_id(id) - self.data = CFunctionID.from_binary(id) + check_id(id) + self.data = CFunctionID.from_binary(id) cdef CFunctionID native(self): return self.data @@ -196,9 +186,8 @@ cdef class FunctionID(UniqueID): cdef class ActorClassID(UniqueID): def __init__(self, id): - if type(self) is ActorClassID: - check_id(id) - self.data = CActorClassID.from_binary(id) + check_id(id) + self.data = CActorClassID.from_binary(id) cdef CActorClassID native(self): return self.data From c86c041ec0be8dfff7067ecac18fe36c23773692 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Wed, 6 Mar 2019 17:30:27 +0800 Subject: [PATCH 7/7] Remove unused constructor declarations. --- python/ray/includes/unique_ids.pxd | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index c5389267f0df..cadbdfea2827 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -28,96 +28,72 @@ cdef extern from "ray/id.h" namespace "ray" nogil: c_string hex() const cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): - CActorCheckpointID() - CActorCheckpointID(const CUniqueID &from_id) @staticmethod CActorCheckpointID from_binary(const c_string &binary) cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): - CActorClassID() - CActorClassID(const CUniqueID &from_id) @staticmethod CActorClassID from_binary(const c_string &binary) cdef cppclass CActorID "ray::ActorID"(CUniqueID): - CActorID() - CActorID(const CUniqueID &from_id) @staticmethod CActorID from_binary(const c_string &binary) cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): - CActorHandleID() - CActorHandleID(const CUniqueID &from_id) @staticmethod CActorHandleID from_binary(const c_string &binary) cdef cppclass CClientID "ray::ClientID"(CUniqueID): - CClientID() - CClientID(const CUniqueID &from_id) @staticmethod CClientID from_binary(const c_string &binary) cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): - CConfigID() - CConfigID(const CUniqueID &from_id) @staticmethod CConfigID from_binary(const c_string &binary) cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): - CFunctionID() - CFunctionID(const CUniqueID &from_id) @staticmethod CFunctionID from_binary(const c_string &binary) cdef cppclass CDriverID "ray::DriverID"(CUniqueID): - CDriverID() - CDriverID(const CUniqueID &from_id) @staticmethod CDriverID from_binary(const c_string &binary) cdef cppclass CJobID "ray::JobID"(CUniqueID): - CJobID() - CJobID(const CUniqueID &from_id) @staticmethod CJobID from_binary(const c_string &binary) cdef cppclass CTaskID "ray::TaskID"(CUniqueID): - CTaskID() - CTaskID(const CUniqueID &from_id) @staticmethod CTaskID from_binary(const c_string &binary) cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): - CObjectID() - CObjectID(const CUniqueID &from_id) @staticmethod CObjectID from_binary(const c_string &binary) cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): - CWorkerID() - CWorkerID(const CUniqueID &from_id) @staticmethod CWorkerID from_binary(const c_string &binary)