diff --git a/python/ray/actor.py b/python/ray/actor.py
index dee308f0d54e..51c36484f9e0 100644
--- a/python/ray/actor.py
+++ b/python/ray/actor.py
@@ -17,6 +17,7 @@
 import ray.signature as signature
 import ray.worker
 from ray.utils import _random_string
+from ray import ObjectID
 
 DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
 
@@ -41,8 +42,7 @@ def compute_actor_handle_id(actor_handle_id, num_forks):
     handle_id_hash.update(actor_handle_id.id())
     handle_id_hash.update(str(num_forks).encode("ascii"))
     handle_id = handle_id_hash.digest()
-    assert len(handle_id) == ray_constants.ID_SIZE
-    return ray.ObjectID(handle_id)
+    return ObjectID(handle_id)
 
 
 def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
@@ -69,8 +69,7 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
     handle_id_hash.update(actor_handle_id.id())
     handle_id_hash.update(current_task_id.id())
     handle_id = handle_id_hash.digest()
-    assert len(handle_id) == ray_constants.ID_SIZE
-    return ray.ObjectID(handle_id)
+    return ObjectID(handle_id)
 
 
 def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
@@ -84,7 +83,7 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
         checkpoint: The state object to save.
         frontier: The task frontier at the time of the checkpoint.
     """
-    actor_key = b"Actor:" + actor_id
+    actor_key = b"Actor:" + actor_id.id()
     worker.redis_client.hmset(
         actor_key, {
             "checkpoint_index": checkpoint_index,
@@ -110,7 +109,7 @@ def save_and_log_checkpoint(worker, actor):
             worker,
             ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
-            driver_id=worker.task_driver_id.id(),
+            driver_id=worker.task_driver_id,
             data={
                 "actor_class": actor.__class__.__name__,
                 "function_name": actor.__ray_checkpoint__.__name__
@@ -134,7 +133,7 @@ def restore_and_log_checkpoint(worker, actor):
             worker,
             ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
-            driver_id=worker.task_driver_id.id(),
+            driver_id=worker.task_driver_id,
             data={
                 "actor_class": actor.__class__.__name__,
                 "function_name": actor.__ray_checkpoint_restore__.__name__
@@ -156,7 +155,7 @@ def get_actor_checkpoint(worker, actor_id):
         exists, all objects are set to None. The checkpoint index is the
         number of tasks executed on the actor before the checkpoint was made.
     """
-    actor_key = b"Actor:" + actor_id
+    actor_key = b"Actor:" + actor_id.id()
    checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
         actor_key, ["checkpoint_index", "checkpoint", "frontier"])
     if checkpoint_index is not None:
@@ -371,7 +370,7 @@ def _remote(self,
             raise Exception("Actors cannot be created before ray.init() "
                             "has been called.")
 
-        actor_id = ray.ObjectID(_random_string())
+        actor_id = ObjectID(_random_string())
         # The actor cursor is a dummy object representing the most recent
         # actor method invocation. For each subsequent method invocation,
         # the current cursor should be added as a dependency, and then
@@ -509,8 +508,7 @@ def __init__(self,
         # if it was created by the _serialization_helper function.
         self._ray_original_handle = actor_handle_id is None
         if self._ray_original_handle:
-            self._ray_actor_handle_id = ray.ObjectID(
-                ray.worker.NIL_ACTOR_HANDLE_ID)
+            self._ray_actor_handle_id = ObjectID.nil_id()
         else:
             self._ray_actor_handle_id = actor_handle_id
         self._ray_actor_cursor = actor_cursor
@@ -713,7 +711,7 @@ def _serialization_helper(self, ray_forking):
         # to release, since it could be unpickled and submit another
         # dependent task at any time. Therefore, we notify the backend of a
         # random handle ID that will never actually be used.
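The two handle-ID helpers above reduce to the same recipe: SHA-1 the parent handle ID together with a salt (the fork count, or the current task ID) and wrap the 20-byte digest in an ObjectID, whose constructor now enforces the length check that the deleted asserts used to perform. A minimal standalone sketch of the forked case, using plain bytes in place of ray.ObjectID (names here are illustrative, not from the patch):

import hashlib

ID_SIZE = 20  # ray_constants.ID_SIZE; a SHA-1 digest is exactly 20 bytes


def derive_forked_handle_id(parent_handle_id: bytes, num_forks: int) -> bytes:
    # Mirrors compute_actor_handle_id: hashing the parent handle ID with the
    # fork index gives each fork a deterministic, collision-resistant ID.
    handle_id_hash = hashlib.sha1()
    handle_id_hash.update(parent_handle_id)
    handle_id_hash.update(str(num_forks).encode("ascii"))
    digest = handle_id_hash.digest()
    assert len(digest) == ID_SIZE
    return digest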
- new_actor_handle_id = ray.ObjectID(_random_string()) + new_actor_handle_id = ObjectID(_random_string()) # Notify the backend to expect this new actor handle. The backend will # not release the cursor for any new handles until the first task for # each of the new handles is submitted. @@ -735,7 +733,7 @@ def _deserialization_helper(self, state, ray_forking): worker.check_connected() if state["ray_forking"]: - actor_handle_id = ray.ObjectID(state["actor_handle_id"]) + actor_handle_id = ObjectID(state["actor_handle_id"]) else: # Right now, if the actor handle has been pickled, we create a # temporary actor handle id for invocations. @@ -749,22 +747,22 @@ def _deserialization_helper(self, state, ray_forking): # same actor is likely a performance bug. We should consider # logging a warning in these cases. actor_handle_id = compute_actor_handle_id_non_forked( - ray.ObjectID(state["actor_handle_id"]), worker.current_task_id) + ObjectID(state["actor_handle_id"]), worker.current_task_id) # This is the driver ID of the driver that owns the actor, not # necessarily the driver that owns this actor handle. - actor_driver_id = ray.ObjectID(state["actor_driver_id"]) + actor_driver_id = ObjectID(state["actor_driver_id"]) self.__init__( - ray.ObjectID(state["actor_id"]), + ObjectID(state["actor_id"]), state["module_name"], state["class_name"], - ray.ObjectID(state["actor_cursor"]) + ObjectID(state["actor_cursor"]) if state["actor_cursor"] is not None else None, state["actor_method_names"], state["method_signatures"], state["method_num_return_vals"], - ray.ObjectID(state["actor_creation_dummy_object_id"]) + ObjectID(state["actor_creation_dummy_object_id"]) if state["actor_creation_dummy_object_id"] is not None else None, state["actor_method_cpus"], actor_driver_id, @@ -843,7 +841,7 @@ def __ray_checkpoint__(self): # scheduler has seen. Handle IDs for which no task has yet reached # the local scheduler will not be included, and may not be runnable # on checkpoint resumption. - actor_id = ray.ObjectID(worker.actor_id) + actor_id = worker.actor_id frontier = worker.raylet_client.get_actor_frontier(actor_id) # Save the checkpoint in Redis. TODO(rkn): Checkpoints # should not be stored in Redis. Fix this. diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 8fb91d3a3466..20a3bc6137fa 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -25,7 +25,7 @@ def parse_client_table(redis_client): Returns: A list of information about the nodes in the cluster. 
""" - NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" + NIL_CLIENT_ID = ray.ObjectID.nil_id().id() message = redis_client.execute_command("RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", NIL_CLIENT_ID) @@ -308,20 +308,19 @@ def _task_table(self, task_id): function_descriptor = FunctionDescriptor.from_bytes_list( function_descriptor_list) task_spec_info = { - "DriverID": binary_to_hex(task_spec.driver_id().id()), - "TaskID": binary_to_hex(task_spec.task_id().id()), - "ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()), + "DriverID": task_spec.driver_id().hex(), + "TaskID": task_spec.task_id().hex(), + "ParentTaskID": task_spec.parent_task_id().hex(), "ParentCounter": task_spec.parent_counter(), - "ActorID": binary_to_hex(task_spec.actor_id().id()), - "ActorCreationID": binary_to_hex( - task_spec.actor_creation_id().id()), - "ActorCreationDummyObjectID": binary_to_hex( - task_spec.actor_creation_dummy_object_id().id()), + "ActorID": (task_spec.actor_id().hex()), + "ActorCreationID": task_spec.actor_creation_id().hex(), + "ActorCreationDummyObjectID": ( + task_spec.actor_creation_dummy_object_id().hex()), "ActorCounter": task_spec.actor_counter(), "Args": task_spec.arguments(), "ReturnObjectIDs": task_spec.returns(), "RequiredResources": task_spec.required_resources(), - "FunctionID": binary_to_hex(function_descriptor.function_id.id()), + "FunctionID": function_descriptor.function_id.hex(), "FunctionHash": binary_to_hex(function_descriptor.function_hash), "ModuleName": function_descriptor.module_name, "ClassName": function_descriptor.class_name, diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index f89bfcf3e36d..f530aeabde69 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -211,7 +211,7 @@ def function_id(self): Returns: The value of ray.ObjectID that represents the function id. """ - return ray.ObjectID(self._function_id) + return self._function_id def _get_function_id(self): """Calculate the function id of current function descriptor. @@ -220,10 +220,10 @@ def _get_function_id(self): descriptor. Returns: - bytes with length of ray_constants.ID_SIZE. + ray.ObjectID to represent the function descriptor. """ if self.is_for_driver_task: - return ray_constants.NIL_FUNCTION_ID.id() + return ray.ObjectID.nil_id() function_id_hash = hashlib.sha1() # Include the function module and name in the hash. function_id_hash.update(self.module_name.encode("ascii")) @@ -232,8 +232,7 @@ def _get_function_id(self): function_id_hash.update(self._function_source_hash) # Compute the function ID. function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return function_id + return ray.ObjectID(function_id) def get_function_descriptor_list(self): """Return a list of bytes representing the function descriptor. 
@@ -290,11 +289,11 @@ def __init__(self, worker):
         self.imported_actor_classes = set()
 
     def increase_task_counter(self, driver_id, function_descriptor):
-        function_id = function_descriptor.function_id.id()
+        function_id = function_descriptor.function_id
         self._num_task_executions[driver_id][function_id] += 1
 
     def get_task_counter(self, driver_id, function_descriptor):
-        function_id = function_descriptor.function_id.id()
+        function_id = function_descriptor.function_id
         return self._num_task_executions[driver_id][function_id]
 
     def export_cached(self):
@@ -372,13 +371,14 @@ def _do_export(self, remote_function):
 
     def fetch_and_register_remote_function(self, key):
         """Import a remote function."""
-        (driver_id, function_id_str, function_name, serialized_function,
+        (driver_id_str, function_id_str, function_name, serialized_function,
          num_return_vals, module, resources,
         max_calls) = self._worker.redis_client.hmget(key, [
              "driver_id", "function_id", "name", "function", "num_return_vals",
              "module", "resources", "max_calls"
         ])
         function_id = ray.ObjectID(function_id_str)
+        driver_id = ray.ObjectID(driver_id_str)
         function_name = decode(function_name)
         max_calls = int(max_calls)
         module = decode(module)
@@ -388,10 +388,10 @@ def fetch_and_register_remote_function(self, key):
             def f():
                 raise Exception("This function was not imported properly.")
 
-            self._function_execution_info[driver_id][function_id.id()] = (
+            self._function_execution_info[driver_id][function_id] = (
                 FunctionExecutionInfo(
                     function=f, function_name=function_name,
                     max_calls=max_calls))
-            self._num_task_executions[driver_id][function_id.id()] = 0
+            self._num_task_executions[driver_id][function_id] = 0
 
         try:
             function = pickle.loads(serialized_function)
@@ -416,7 +416,7 @@ def f():
             # However in the worker process, the `__main__` module is a
             # different module, which is `default_worker.py`
             function.__module__ = module
-            self._function_execution_info[driver_id][function_id.id()] = (
+            self._function_execution_info[driver_id][function_id] = (
                 FunctionExecutionInfo(
                     function=function,
                     function_name=function_name,
@@ -435,7 +435,7 @@ def get_execution_info(self, driver_id, function_descriptor):
         Returns:
             A FunctionExecutionInfo object.
         """
-        function_id = function_descriptor.function_id.id()
+        function_id = function_descriptor.function_id
 
         # Wait until the function to be executed has actually been
         # registered on this worker. We will push warnings to the user if
@@ -449,7 +449,7 @@ def get_execution_info(self, driver_id, function_descriptor):
         except KeyError as e:
             message = ("Error occurs in get_execution_info: "
                        "driver_id: %s, function_descriptor: %s. Message: %s" %
-                       (binary_to_hex(driver_id), function_descriptor, e))
+                       (driver_id, function_descriptor, e))
             raise KeyError(message)
         return info
 
@@ -474,11 +474,11 @@ def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
         warning_sent = False
         while True:
             with self._worker.lock:
-                if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
-                        and (function_descriptor.function_id.id() in
+                if (self._worker.actor_id.is_nil()
+                        and (function_descriptor.function_id in
                              self._function_execution_info[driver_id])):
                     break
-                elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
+                elif not self._worker.actor_id.is_nil() and (
                         self._worker.actor_id in self._worker.actors):
                     break
             if time.time() - start_time > timeout:
@@ -556,7 +556,7 @@ def export_actor_class(self, Class, actor_method_names,
         # because of https://github.com/ray-project/ray/issues/1146.
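Dropping the `.id()` calls means _function_execution_info and _num_task_executions are now keyed by ObjectID values rather than by raw bytes, which only works because the extension type hashes and compares by its underlying 20 bytes. A rough pure-Python model of that dict-key contract (a stand-in class, not the real extension type):

class FakeObjectID:
    # Value-semantics stand-in for ray.ObjectID.
    def __init__(self, id_bytes):
        self._id = id_bytes

    def __hash__(self):
        return hash(self._id)

    def __eq__(self, other):
        return isinstance(other, FakeObjectID) and self._id == other._id


counters = {}
fid = FakeObjectID(b"\x01" * 20)
counters[fid] = 0
counters[FakeObjectID(b"\x01" * 20)] += 1  # equal bytes hit the same key
assert counters[fid] == 1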
def load_actor(self, driver_id, function_descriptor): - key = (b"ActorClass:" + driver_id + b":" + + key = (b"ActorClass:" + driver_id.id() + b":" + function_descriptor.function_id.id()) # Wait for the actor class key to have been imported by the # import thread. TODO(rkn): It shouldn't be possible to end @@ -578,8 +578,8 @@ def fetch_and_register_actor(self, actor_class_key): actor_class_key: The key in Redis to use to fetch the actor. worker: The worker to use. """ - actor_id_str = self._worker.actor_id - (driver_id, class_name, module, pickled_class, checkpoint_interval, + actor_id = self._worker.actor_id + (driver_id_str, class_name, module, pickled_class, checkpoint_interval, actor_method_names) = self._worker.redis_client.hmget( actor_class_key, [ "driver_id", "class_name", "module", "class", @@ -588,6 +588,7 @@ def fetch_and_register_actor(self, actor_class_key): class_name = decode(class_name) module = decode(module) + driver_id = ray.ObjectID(driver_id_str) checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(decode(actor_method_names)) @@ -606,7 +607,7 @@ def fetch_and_register_actor(self, actor_class_key): class TemporaryActor(object): pass - self._worker.actors[actor_id_str] = TemporaryActor() + self._worker.actors[actor_id] = TemporaryActor() self._worker.actor_checkpoint_interval = checkpoint_interval def temporary_actor_method(*xs): @@ -618,7 +619,7 @@ def temporary_actor_method(*xs): for actor_method_name in actor_method_names: function_descriptor = FunctionDescriptor(module, actor_method_name, class_name) - function_id = function_descriptor.function_id.id() + function_id = function_descriptor.function_id temporary_executor = self._make_actor_method_executor( actor_method_name, temporary_actor_method, @@ -644,14 +645,14 @@ def temporary_actor_method(*xs): ray_constants.REGISTER_ACTOR_PUSH_ERROR, traceback_str, driver_id, - data={"actor_id": actor_id_str}) + data={"actor_id": actor_id.id()}) # TODO(rkn): In the future, it might make sense to have the worker # exit here. However, currently that would lead to hanging if # someone calls ray.get on a method invoked on the actor. else: # TODO(pcm): Why is the below line necessary? unpickled_class.__module__ = module - self._worker.actors[actor_id_str] = unpickled_class.__new__( + self._worker.actors[actor_id] = unpickled_class.__new__( unpickled_class) actor_methods = inspect.getmembers( @@ -659,7 +660,7 @@ def temporary_actor_method(*xs): for actor_method_name, actor_method in actor_methods: function_descriptor = FunctionDescriptor( module, actor_method_name, class_name) - function_id = function_descriptor.function_id.id() + function_id = function_descriptor.function_id executor = self._make_actor_method_executor( actor_method_name, actor_method, actor_imported=True) self._function_execution_info[driver_id][function_id] = ( diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index c477f4bcc9e8..511cbeea09f3 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -58,7 +58,7 @@ def construct_error_message(driver_id, error_type, message, timestamp): The serialized object. 
""" builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id) + driver_offset = builder.CreateString(driver_id.id()) error_type_offset = builder.CreateString(error_type) message_offset = builder.CreateString(message) diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 08031c7b603b..e1c8742b9c44 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -131,5 +131,5 @@ def fetch_and_execute_function_to_run(self, key): self.worker, ray_constants.FUNCTION_TO_RUN_PUSH_ERROR, traceback_str, - driver_id=driver_id, + driver_id=ray.ObjectID(driver_id), data={"name": name}) diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 0ca708ec30fd..62b52d82da23 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -5,8 +5,6 @@ import os -from ray.raylet import ObjectID - def env_integer(key, default): if key in os.environ: @@ -15,8 +13,6 @@ def env_integer(key, default): ID_SIZE = 20 -NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff") -NIL_FUNCTION_ID = NIL_JOB_ID # The default maximum number of bytes to allocate to the object store unless # overridden by the user. diff --git a/python/ray/raylet/__init__.py b/python/ray/raylet/__init__.py index 69545f5c6936..67a1976b8d50 100644 --- a/python/ray/raylet/__init__.py +++ b/python/ray/raylet/__init__.py @@ -4,10 +4,10 @@ from ray.core.src.ray.raylet.libraylet_library_python import ( Task, RayletClient, ObjectID, check_simple_value, compute_task_id, - task_from_string, task_to_string, _config, common_error) + task_from_string, task_to_string, _config, RayCommonError) __all__ = [ "Task", "RayletClient", "ObjectID", "check_simple_value", "compute_task_id", "task_from_string", "task_to_string", - "start_local_scheduler", "_config", "common_error" + "start_local_scheduler", "_config", "RayCommonError" ] diff --git a/python/ray/utils.py b/python/ray/utils.py index 44fa5b8dd933..3b660befd0d8 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -67,10 +67,10 @@ def push_error_to_driver(worker, will be serialized with json and stored in Redis. """ if driver_id is None: - driver_id = ray_constants.NIL_JOB_ID.id() + driver_id = ray.ObjectID.nil_id() data = {} if data is None else data - worker.raylet_client.push_error( - ray.ObjectID(driver_id), error_type, message, time.time()) + worker.raylet_client.push_error(driver_id, error_type, message, + time.time()) def push_error_to_driver_through_redis(redis_client, @@ -96,15 +96,16 @@ def push_error_to_driver_through_redis(redis_client, will be serialized with json and stored in Redis. """ if driver_id is None: - driver_id = ray_constants.NIL_JOB_ID.id() + driver_id = ray.ObjectID.nil_id() data = {} if data is None else data # Do everything in Python and through the Python Redis client instead # of through the raylet. 
    error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
                                                       message, time.time())
-    redis_client.execute_command(
-        "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
-        ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
+    redis_client.execute_command("RAY.TABLE_APPEND",
+                                 ray.gcs_utils.TablePrefix.ERROR_INFO,
+                                 ray.gcs_utils.TablePubsub.ERROR_INFO,
+                                 driver_id.id(), error_data)
 
 
 def is_cython(obj):
@@ -400,7 +401,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
             worker,
             ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
             warning_message,
-            driver_id=worker.task_driver_id.id())
+            driver_id=worker.task_driver_id)
 
 
 class _ThreadSafeProxy(object):
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 4c8dde100379..1531f54e6bc3 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -36,6 +36,7 @@
 import ray.plasma
 import ray.ray_constants as ray_constants
 from ray import import_thread
+from ray import ObjectID
 from ray import profiling
 from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
 import ray.parameter
@@ -53,13 +54,6 @@
 
 ERROR_KEY_PREFIX = b"Error:"
 
-# This must match the definition of NIL_ACTOR_ID in task.h.
-NIL_ID = ray_constants.ID_SIZE * b"\xff"
-NIL_LOCAL_SCHEDULER_ID = NIL_ID
-NIL_ACTOR_ID = NIL_ID
-NIL_ACTOR_HANDLE_ID = NIL_ID
-NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
-
 # Default resource requirements for actors when no resource requirements are
 # specified.
 DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
@@ -168,7 +162,7 @@ def __init__(self):
         self.serialization_context_map = {}
         self.function_actor_manager = FunctionActorManager(self)
         # Identity of the driver that this worker is processing.
-        self.task_driver_id = ray.ObjectID(NIL_ID)
+        self.task_driver_id = ObjectID.nil_id()
         self._task_context = threading.local()
 
     @property
@@ -189,14 +183,13 @@ def task_context(self):
                 # If this is running on the main thread, initialize it to
                 # NIL. The actual value will be set when the worker receives
                 # a task from the raylet backend.
-                self._task_context.current_task_id = ray.ObjectID(NIL_ID)
+                self._task_context.current_task_id = ObjectID.nil_id()
             else:
                 # If this is running on a separate thread, then the mapping
                 # to the current task ID may not be correct. Generate a
                 # random task ID so that the backend can differentiate
                 # between different threads.
-                self._task_context.current_task_id = ray.ObjectID(
-                    random_string())
+                self._task_context.current_task_id = ObjectID(random_string())
                 if getattr(self, '_multithreading_warned', False) is not True:
                     logger.warning(
                         "Calling ray.get or ray.wait in a separate thread "
@@ -353,12 +346,13 @@ def put_object(self, object_id, value):
                 full.
         """
         # Make sure that the value is not an object ID.
-        if isinstance(value, ray.ObjectID):
-            raise Exception("Calling 'put' on an ObjectID is not allowed "
-                            "(similarly, returning an ObjectID from a remote "
-                            "function is not allowed). If you really want to "
-                            "do this, you can wrap the ObjectID in a list and "
-                            "call 'put' on it (or return it).")
+        if isinstance(value, ObjectID):
+            raise Exception(
+                "Calling 'put' on a ray.ObjectID is not allowed "
+                "(similarly, returning a ray.ObjectID from a remote "
+                "function is not allowed). If you really want to "
+                "do this, you can wrap the ray.ObjectID in a list and "
+                "call 'put' on it (or return it).")
 
         # Serialize and put the object in the object store.
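The reworded guard above keeps the old escape hatch: a ray.ObjectID cannot be put directly, but a container holding one can. A usage sketch:

import ray

ray.init()
oid = ray.put(1)
# ray.put(oid) would raise the exception constructed above.
wrapped = ray.put([oid])  # wrapping the ObjectID in a list is allowed
assert ray.get(ray.get(wrapped)[0]) == 1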
         try:
@@ -433,7 +427,7 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
                     self,
                     ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
                     warning_message,
-                    driver_id=self.task_driver_id.id())
+                    driver_id=self.task_driver_id)
                 warning_sent = True
 
     def get_object(self, object_ids):
@@ -449,9 +443,10 @@ def get_object(self, object_ids):
         """
         # Make sure that the values are object IDs.
         for object_id in object_ids:
-            if not isinstance(object_id, ray.ObjectID):
-                raise Exception("Attempting to call `get` on the value {}, "
-                                "which is not an ObjectID.".format(object_id))
+            if not isinstance(object_id, ObjectID):
+                raise Exception(
+                    "Attempting to call `get` on the value {}, "
+                    "which is not a ray.ObjectID.".format(object_id))
         # Do an initial fetch for remote objects. We divide the fetch into
         # smaller fetches so as to not block the manager for a prolonged period
         # of time in a single call.
@@ -484,8 +479,7 @@ def get_object(self, object_ids):
                         for unready_id in unready_ids.keys()
                     ]
                     ray_object_ids_to_fetch = [
-                        ray.ObjectID(unready_id)
-                        for unready_id in unready_ids.keys()
+                        ObjectID(unready_id) for unready_id in unready_ids.keys()
                     ]
                     fetch_request_size = ray._config.worker_fetch_request_size()
                     for i in range(0, len(object_ids_to_fetch),
@@ -574,22 +568,22 @@ def submit_task(self,
         with profiling.profile("submit_task", worker=self):
             if actor_id is None:
                 assert actor_handle_id is None
-                actor_id = ray.ObjectID(NIL_ACTOR_ID)
-                actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
+                actor_id = ObjectID.nil_id()
+                actor_handle_id = ObjectID.nil_id()
             else:
                 assert actor_handle_id is not None
 
             if actor_creation_id is None:
-                actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
+                actor_creation_id = ObjectID.nil_id()
 
             if actor_creation_dummy_object_id is None:
-                actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
+                actor_creation_dummy_object_id = ObjectID.nil_id()
 
             # Put large or complex arguments that are passed by value in the
             # object store first.
             args_for_local_scheduler = []
             for arg in args:
-                if isinstance(arg, ray.ObjectID):
+                if isinstance(arg, ObjectID):
                     args_for_local_scheduler.append(arg)
                 elif ray.raylet.check_simple_value(arg):
                     args_for_local_scheduler.append(arg)
@@ -722,7 +716,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args):
                 arguments are being retrieved.
             serialized_args (List): The arguments to the function. These are
                 either strings representing serialized objects passed by value
-                or they are ObjectIDs.
+                or they are ray.ObjectIDs.
Returns: The retrieved arguments in addition to the arguments that were @@ -734,7 +728,7 @@ def _get_arguments_for_execution(self, function_name, serialized_args): """ arguments = [] for (i, arg) in enumerate(serialized_args): - if isinstance(arg, ray.ObjectID): + if isinstance(arg, ObjectID): # get the object from the local object store argument = self.get_object([arg])[0] if isinstance(argument, RayTaskError): @@ -838,9 +832,9 @@ def _process_task(self, task, function_execution_info): outputs = function_executor(*arguments) else: if not task.actor_id().is_nil(): - key = task.actor_id().id() + key = task.actor_id() else: - key = task.actor_creation_id().id() + key = task.actor_creation_id() outputs = function_executor(dummy_return_id, self.actors[key], *arguments) except Exception as e: @@ -882,7 +876,7 @@ def _handle_process_task_failure(self, function_descriptor, self, ray_constants.TASK_PUSH_ERROR, str(failure_object), - driver_id=self.task_driver_id.id(), + driver_id=self.task_driver_id, data={ "function_id": function_id.id(), "function_name": function_name, @@ -890,7 +884,7 @@ def _handle_process_task_failure(self, function_descriptor, "class_name": function_descriptor.class_name }) # Mark the actor init as failed - if self.actor_id != NIL_ACTOR_ID and function_name == "__init__": + if not self.actor_id.is_nil() and function_name == "__init__": self.mark_actor_init_failed(error) def _wait_for_and_process_task(self, task): @@ -901,13 +895,13 @@ def _wait_for_and_process_task(self, task): """ function_descriptor = FunctionDescriptor.from_bytes_list( task.function_descriptor_list()) - driver_id = task.driver_id().id() + driver_id = task.driver_id() # TODO(rkn): It would be preferable for actor creation tasks to share # more of the code path with regular task execution. if not task.actor_creation_id().is_nil(): - assert self.actor_id == NIL_ACTOR_ID - self.actor_id = task.actor_creation_id().id() + assert self.actor_id.is_nil() + self.actor_id = task.actor_creation_id() self.function_actor_manager.load_actor(driver_id, function_descriptor) @@ -930,12 +924,12 @@ def _wait_for_and_process_task(self, task): title = "ray_worker:{}()".format(function_name) next_title = "ray_worker" else: - actor = self.actors[task.actor_creation_id().id()] + actor = self.actors[task.actor_creation_id()] title = "ray_{}:{}()".format(actor.__class__.__name__, function_name) next_title = "ray_{}".format(actor.__class__.__name__) else: - actor = self.actors[task.actor_id().id()] + actor = self.actors[task.actor_id()] title = "ray_{}:{}()".format(actor.__class__.__name__, function_name) next_title = "ray_{}".format(actor.__class__.__name__) @@ -943,14 +937,14 @@ def _wait_for_and_process_task(self, task): with _changeproctitle(title, next_title): self._process_task(task, execution_info) # Reset the state fields so the next task can run. - self.task_context.current_task_id = ray.ObjectID(NIL_ID) + self.task_context.current_task_id = ObjectID.nil_id() self.task_context.task_index = 0 self.task_context.put_index = 1 - if self.actor_id == NIL_ACTOR_ID: + if self.actor_id.is_nil(): # Don't need to reset task_driver_id if the worker is an # actor. Because the following tasks should all have the # same driver id. - self.task_driver_id = ray.ObjectID(NIL_ID) + self.task_driver_id = ObjectID.nil_id() # Increase the task execution counter. 
            self.function_actor_manager.increase_task_counter(
@@ -1104,17 +1098,17 @@ def error_applies_to_driver(error_key, worker=global_worker):
                               + ray_constants.ID_SIZE), error_key
     # If the driver ID in the error message is the nil driver ID, then the
     # message is intended for all drivers.
-    driver_id = error_key[len(ERROR_KEY_PREFIX):(
-        len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)]
-    return (driver_id == worker.task_driver_id.id()
-            or driver_id == ray.ray_constants.NIL_JOB_ID.id())
+    driver_id = ObjectID(error_key[len(ERROR_KEY_PREFIX):(
+        len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)])
+    return (driver_id == worker.task_driver_id
+            or driver_id == ObjectID.nil_id())
 
 
 def error_info(worker=global_worker):
     """Return information about failed tasks."""
     worker.check_connected()
     return (global_state.error_messages(job_id=worker.task_driver_id) +
-            global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
+            global_state.error_messages(job_id=ObjectID.nil_id()))
 
 
 def _initialize_serialization(driver_id, worker=global_worker):
@@ -1134,13 +1128,13 @@ def object_id_custom_serializer(obj):
         return obj.id()
 
     def object_id_custom_deserializer(serialized_obj):
-        return ray.ObjectID(serialized_obj)
+        return ObjectID(serialized_obj)
 
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
     serialization_context.register_type(
-        ray.ObjectID,
+        ObjectID,
         "ray.ObjectID",
         pickle=False,
         custom_serializer=object_id_custom_serializer,
@@ -1661,7 +1655,7 @@ def listen_error_messages_raylet(worker, task_error_queue):
         job_id = error_data.JobId()
         if job_id not in [
                 worker.task_driver_id.id(),
-                ray_constants.NIL_JOB_ID.id()
+                ObjectID.nil_id().id()
         ]:
             continue
 
@@ -1772,11 +1766,10 @@ def connect(info,
     else:
         # This is the code path of driver mode.
         if driver_id is None:
-            driver_id = ray.ObjectID(random_string())
+            driver_id = ObjectID(random_string())
 
-        if not isinstance(driver_id, ray.ObjectID):
-            raise Exception(
-                "The type of given driver id must be ray.ObjectID.")
+        if not isinstance(driver_id, ObjectID):
+            raise Exception("The given driver_id must be a ray.ObjectID.")
 
     worker.worker_id = driver_id.id()
 
@@ -1785,11 +1778,11 @@ def connect(info,
     # responsible for the task so that error messages will be propagated to
     # the correct driver.
     if mode != WORKER_MODE:
-        worker.task_driver_id = ray.ObjectID(worker.worker_id)
+        worker.task_driver_id = ObjectID(worker.worker_id)
 
     # All workers start out as non-actors. A worker can be turned into an actor
     # after it is created.
-    worker.actor_id = NIL_ACTOR_ID
+    worker.actor_id = ObjectID.nil_id()
 
     worker.connected = True
     worker.set_mode(mode)
@@ -1920,13 +1913,13 @@ def connect(info,
             function_descriptor.get_function_descriptor_list(),
             [],  # arguments.
             0,  # num_returns.
-            ray.ObjectID(random_string()),  # parent_task_id.
+            ObjectID(random_string()),  # parent_task_id.
             0,  # parent_counter.
-            ray.ObjectID(NIL_ACTOR_ID),  # actor_creation_id.
-            ray.ObjectID(NIL_ACTOR_ID),  # actor_creation_dummy_object_id.
+            ObjectID.nil_id(),  # actor_creation_id.
+            ObjectID.nil_id(),  # actor_creation_dummy_object_id.
             0,  # max_actor_reconstructions.
-            ray.ObjectID(NIL_ACTOR_ID),  # actor_id.
-            ray.ObjectID(NIL_ACTOR_ID),  # actor_handle_id.
+            ObjectID.nil_id(),  # actor_id.
+            ObjectID.nil_id(),  # actor_handle_id.
             nil_actor_counter,  # actor_counter.
             [],  # new_actor_handles.
             [],  # execution_dependencies.
@@ -2148,9 +2141,7 @@ def register_custom_serializer(cls,
         class_id = ray.utils.binary_to_hex(class_id)
 
     if driver_id is None:
-        driver_id_bytes = worker.task_driver_id.id()
-    else:
-        driver_id_bytes = driver_id.id()
+        driver_id = worker.task_driver_id
 
     def register_class_for_serialization(worker_info):
         # TODO(rkn): We need to be more thoughtful about what to do if custom
@@ -2160,7 +2151,7 @@ def register_class_for_serialization(worker_info):
         # system.
 
         serialization_context = worker_info[
-            "worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
+            "worker"].get_serialization_context(driver_id)
         serialization_context.register_type(
             cls,
             class_id,
@@ -2279,13 +2270,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
             IDs.
     """
 
-    if isinstance(object_ids, ray.ObjectID):
+    if isinstance(object_ids, ObjectID):
         raise TypeError(
-            "wait() expected a list of ObjectID, got a single ObjectID")
+            "wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
+        )
 
     if not isinstance(object_ids, list):
-        raise TypeError("wait() expected a list of ObjectID, got {}".format(
-            type(object_ids)))
+        raise TypeError(
+            "wait() expected a list of ray.ObjectID, got {}".format(
+                type(object_ids)))
 
     if isinstance(timeout, int) and timeout != 0:
         logger.warning("The 'timeout' argument now requires seconds instead "
@@ -2298,8 +2291,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
 
     if worker.mode != LOCAL_MODE:
         for object_id in object_ids:
-            if not isinstance(object_id, ray.ObjectID):
-                raise TypeError("wait() expected a list of ObjectID, "
+            if not isinstance(object_id, ObjectID):
+                raise TypeError("wait() expected a list of ray.ObjectID, "
                                 "got list containing {}".format(
                                     type(object_id)))
diff --git a/src/ray/raylet/lib/python/common_extension.cc b/src/ray/raylet/lib/python/common_extension.cc
index b8a51787e66b..ccfcf5cb3b12 100644
--- a/src/ray/raylet/lib/python/common_extension.cc
+++ b/src/ray/raylet/lib/python/common_extension.cc
@@ -30,8 +30,6 @@ using ray::UniqueID;
 using ray::FunctionID;
 using ray::TaskID;
 
-PyObject *CommonError;
-
 /* Initialize pickle module. */
 PyObject *pickle_module = NULL;
 
@@ -117,13 +115,18 @@ static int PyObjectID_init(PyObjectID *self, PyObject *args, PyObject *kwds) {
     return -1;
   }
   if (size != sizeof(ObjectID)) {
-    PyErr_SetString(CommonError, "ObjectID: object id string needs to have length 20");
+    PyErr_SetString(PyExc_ValueError,
+                    "ObjectID: object id string needs to have length 20");
     return -1;
   }
   std::memcpy(self->object_id.mutable_data(), data, sizeof(self->object_id));
   return 0;
 }
 
+static PyObject *PyObjectID_nil_id(PyObject *cls, PyObject *args) {
+  return PyObjectID_make(ray::UniqueID());
+}
+
 /* Create a PyObjectID from C.
 */
 PyObject *PyObjectID_make(ObjectID object_id) {
   PyObjectID *result = PyObject_New(PyObjectID, &PyObjectIDType);
@@ -269,9 +272,14 @@ static PyObject *PyObjectID_repr(PyObjectID *self) {
   return result;
 }
 
-static PyObject *PyObjectID___reduce__(PyObjectID *self) {
-  PyErr_SetString(CommonError, "ObjectID objects cannot be serialized.");
-  return NULL;
+static PyObject *PyObjectID_getstate(PyObjectID *self) {
+  PyObject *field;
+  field = PyBytes_FromStringAndSize((char *)self->object_id.data(), sizeof(ObjectID));
+  return Py_BuildValue("(N)", field);
+}
+
+static PyObject *PyObjectID___reduce__(PyObjectID *self, PyObject *arg) {
+  return Py_BuildValue("(ON)", Py_TYPE(self), PyObjectID_getstate(self));
 }
 
 static PyMethodDef PyObjectID_methods[] = {
@@ -283,9 +291,10 @@
      "Return the object ID as a string in hex."},
     {"is_nil", (PyCFunction)PyObjectID_is_nil, METH_NOARGS,
      "Return whether the ObjectID is nil"},
-    {"__reduce__", (PyCFunction)PyObjectID___reduce__, METH_NOARGS,
-     "Say how to pickle this ObjectID. This raises an exception to prevent"
-     "object IDs from being serialized."},
+    {"__reduce__", (PyCFunction)PyObjectID___reduce__, METH_VARARGS,
+     "Provide a way to pickle this ObjectID."},
+    {"nil_id", (PyCFunction)PyObjectID_nil_id, METH_NOARGS | METH_CLASS,
+     "Return the nil ObjectID."},
     {NULL} /* Sentinel */
 };
 
@@ -293,9 +302,11 @@ static PyMemberDef PyObjectID_members[] = {
     {NULL} /* Sentinel */
 };
 
+// This Python class is exposed through python/ray/raylet/__init__.py, so
+// tp_name should match that import path ("ray.ObjectID" would also work).
 PyTypeObject PyObjectIDType = {
     PyVarObject_HEAD_INIT(NULL, 0) /* ob_size */
-    "common.ObjectID",             /* tp_name */
+    "ray.raylet.ObjectID",         /* tp_name */
     sizeof(PyObjectID),            /* tp_basicsize */
     0,                             /* tp_itemsize */
     0,                             /* tp_dealloc */
diff --git a/src/ray/raylet/lib/python/common_extension.h b/src/ray/raylet/lib/python/common_extension.h
index 4dc346b96819..a0c26e53babf 100644
--- a/src/ray/raylet/lib/python/common_extension.h
+++ b/src/ray/raylet/lib/python/common_extension.h
@@ -12,7 +12,7 @@ typedef char TaskSpec;
 
 class TaskBuilder;
 
-extern PyObject *CommonError;
+extern PyObject *ray_common_error;
 
 // clang-format off
 typedef struct {
diff --git a/src/ray/raylet/lib/python/raylet_extension.cc b/src/ray/raylet/lib/python/raylet_extension.cc
index a3f02b213fe5..755ddf1ddd1c 100644
--- a/src/ray/raylet/lib/python/raylet_extension.cc
+++ b/src/ray/raylet/lib/python/raylet_extension.cc
@@ -31,6 +31,8 @@ static int PyRayletClient_init(PyRayletClient *self, PyObject *args, PyObject *k
     return 0;
 }
 
+PyObject *ray_common_error = nullptr;
+
 static void PyRayletClient_dealloc(PyRayletClient *self) {
   if (self->raylet_client != NULL) {
     delete self->raylet_client;
@@ -112,7 +114,7 @@ static PyObject *PyRayletClient_FetchOrReconstruct(PyRayletClient *self, PyObjec
     stream << "[RayletClient] FetchOrReconstruct failed: "
            << "raylet client may be closed, check raylet status.
error message: " << status.ToString(); - PyErr_SetString(CommonError, stream.str().c_str()); + PyErr_SetString(ray_common_error, stream.str().c_str()); return NULL; } } @@ -485,9 +487,9 @@ MOD_INIT(libraylet_library_python) { PyModule_AddObject(m, "RayletClient", (PyObject *)&PyRayletClientType); char common_error[] = "common.error"; - CommonError = PyErr_NewException(common_error, NULL, NULL); - Py_INCREF(CommonError); - PyModule_AddObject(m, "common_error", CommonError); + ray_common_error = PyErr_NewException(common_error, NULL, NULL); + Py_INCREF(ray_common_error); + PyModule_AddObject(m, "RayCommonError", ray_common_error); Py_INCREF(&PyRayConfigType); PyModule_AddObject(m, "RayConfig", (PyObject *)&PyRayConfigType); diff --git a/test/failure_test.py b/test/failure_test.py index 93e6808bdd94..03e3dc8a7b94 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -12,10 +12,11 @@ import time import ray.ray_constants as ray_constants -import ray.test.cluster_utils from ray.utils import _random_string import pytest +from ray.test.cluster_utils import Cluster + def relevant_errors(error_type): return [info for info in ray.error_info() if info["type"] == error_type] @@ -620,7 +621,7 @@ def g(): @pytest.fixture def ray_start_two_nodes(): # Start the Ray processes. - cluster = ray.test.cluster_utils.Cluster() + cluster = Cluster() for _ in range(2): cluster.add_node( num_cpus=0, @@ -674,6 +675,8 @@ def sleep_to_kill_raylet(): thread = threading.Thread(target=sleep_to_kill_raylet) thread.start() - with pytest.raises(Exception, match=r".*raylet client may be closed.*"): + with pytest.raises( + ray.raylet.RayCommonError, + match=r".*raylet client may be closed.*"): ray.get(nonexistent_id) thread.join() diff --git a/test/runtest.py b/test/runtest.py index 16a68f98e8a0..5fbe5f20a435 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -17,12 +17,13 @@ from concurrent.futures import ThreadPoolExecutor import numpy as np +import pickle import pytest import ray -import ray.ray_constants as ray_constants import ray.test.cluster_utils import ray.test.test_utils +from ray.utils import _random_string logger = logging.getLogger(__name__) @@ -301,8 +302,7 @@ def method(self): f f = Foo() - with pytest.raises(ray.raylet.common_error): - ray.put(f) + ray.put(f) def test_put_get(shutdown_only): @@ -2301,8 +2301,7 @@ def test_global_state_api(shutdown_only): driver_id = ray.experimental.state.binary_to_hex( ray.worker.global_worker.worker_id) - driver_task_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.current_task_id.id()) + driver_task_id = ray.worker.global_worker.current_task_id.hex() # One task is put in the task table which corresponds to this driver. 
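With the exception now exported as RayCommonError, callers can catch the raylet failure path by type rather than with a bare Exception, which is what the updated failure_test.py does. A hedged sketch of that pattern (the killed raylet and the stale object ID are assumed to be set up elsewhere, as in the test):

import ray
import ray.raylet


def expect_raylet_closed(object_id):
    # object_id refers to an object that can no longer be fetched; with the
    # raylet connection gone, FetchOrReconstruct surfaces RayCommonError.
    try:
        ray.get(object_id)
    except ray.raylet.RayCommonError as e:
        assert "raylet client may be closed" in str(e)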
wait_for_num_tasks(1) @@ -2310,12 +2309,13 @@ def test_global_state_api(shutdown_only): assert len(task_table) == 1 assert driver_task_id == list(task_table.keys())[0] task_spec = task_table[driver_task_id]["TaskSpec"] + nil_id_hex = ray.ObjectID.nil_id().hex() assert task_spec["TaskID"] == driver_task_id - assert task_spec["ActorID"] == ray_constants.ID_SIZE * "ff" + assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [] assert task_spec["DriverID"] == driver_id - assert task_spec["FunctionID"] == ray_constants.ID_SIZE * "ff" + assert task_spec["FunctionID"] == nil_id_hex assert task_spec["ReturnObjectIDs"] == [] client_table = ray.global_state.client_table() @@ -2341,7 +2341,7 @@ def f(*xs): function_table = ray.global_state.function_table() task_spec = task_table[task_id]["TaskSpec"] - assert task_spec["ActorID"] == ray_constants.ID_SIZE * "ff" + assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [1, "hi", x_id] assert task_spec["DriverID"] == driver_id assert task_spec["ReturnObjectIDs"] == [result_id] @@ -2455,6 +2455,24 @@ def f(): ray.shutdown() +def test_object_id_properties(): + id_bytes = b"00112233445566778899" + object_id = ray.ObjectID(id_bytes) + assert object_id.id() == id_bytes + object_id = ray.ObjectID.nil_id() + assert object_id.is_nil() + with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + ray.ObjectID(id_bytes + b"1234") + with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + ray.ObjectID(b"0123456789") + object_id = ray.ObjectID(_random_string()) + assert not object_id.is_nil() + assert object_id.id() != id_bytes + id_dumps = pickle.dumps(object_id) + id_from_dumps = pickle.loads(id_dumps) + assert id_from_dumps == object_id + + @pytest.fixture def shutdown_only_with_initialization_check(): yield None @@ -2514,7 +2532,7 @@ def unique_1(): def test_duplicate_error_messages(shutdown_only): ray.init(num_cpus=0) - driver_id = ray.ray_constants.NIL_JOB_ID.id() + driver_id = ray.ObjectID.nil_id() error_data = ray.gcs_utils.construct_error_message(driver_id, "test", "message", 0) @@ -2524,13 +2542,13 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.id(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.id(), error_data)
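The C-level __reduce__ added in common_extension.cc is equivalent to returning (type(self), (id_bytes,)), so pickle rebuilds the ID by calling the constructor with the raw 20 bytes; that is exactly the round trip asserted in test_object_id_properties. A pure-Python model of the same protocol (a stand-in class, not the extension type):

import pickle


class MiniObjectID:
    def __init__(self, id_bytes):
        if len(id_bytes) != 20:
            raise ValueError("ObjectID: object id string needs to have length 20")
        self._id = id_bytes

    def __eq__(self, other):
        return isinstance(other, MiniObjectID) and self._id == other._id

    def __reduce__(self):
        # Matches PyObjectID___reduce__: rebuild via the constructor.
        return (type(self), (self._id,))


oid = MiniObjectID(b"00112233445566778899")
assert pickle.loads(pickle.dumps(oid)) == oid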