diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py
index 6d41f78ebb7f..54c44db8f3ea 100644
--- a/python/ray/import_thread.py
+++ b/python/ray/import_thread.py
@@ -157,13 +157,13 @@ def f():
 
     def fetch_and_execute_function_to_run(self, key):
         """Run on arbitrary function on the worker."""
-        driver_id, serialized_function = self.redis_client.hmget(
-            key, ["driver_id", "function"])
+        (driver_id, serialized_function,
+         run_on_other_drivers) = self.redis_client.hmget(
+             key, ["driver_id", "function", "run_on_other_drivers"])
 
-        if (self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
+        if (run_on_other_drivers == "False"
+                and self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
                 and driver_id != self.worker.task_driver_id.id()):
-            # This export was from a different driver and there's no need for
-            # this driver to import it.
             return
 
         try:
diff --git a/python/ray/worker.py b/python/ray/worker.py
index e40ec7f91125..1c4c5e5600b7 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -245,6 +245,25 @@ def __init__(self):
         self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
         self.profiler = profiling.Profiler(self)
         self.state_lock = threading.Lock()
+        # A dictionary that maps from driver id to SerializationContext.
+        # TODO: clean up the SerializationContext once the job finishes.
+        self.serialization_context_map = {}
+        # Identity of the driver that this worker is processing.
+        self.task_driver_id = None
+
+    def get_serialization_context(self, driver_id):
+        """Get the SerializationContext of the given driver.
+
+        Args:
+            driver_id: The ID of the driver whose serialization context
+                should be returned.
+
+        Returns:
+            The serialization context of the given driver.
+        """
+        if driver_id not in self.serialization_context_map:
+            _initialize_serialization(driver_id)
+        return self.serialization_context_map[driver_id]
 
     def check_connected(self):
         """Check if the worker is connected.
@@ -308,7 +327,8 @@ def store_and_register(self, object_id, value, depth=100):
                     value,
                     object_id=pyarrow.plasma.ObjectID(object_id.id()),
                     memcopy_threads=self.memcopy_threads,
-                    serialization_context=self.serialization_context)
+                    serialization_context=self.get_serialization_context(
+                        self.task_driver_id))
                 break
             except pyarrow.SerializationCallbackError as e:
                 try:
@@ -400,7 +420,8 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
                     results += self.plasma_client.get(
                         object_ids[i:(
                             i + ray._config.worker_get_request_size())],
-                        timeout, self.serialization_context)
+                        timeout,
+                        self.get_serialization_context(self.task_driver_id))
                 return results
             except pyarrow.lib.ArrowInvalid:
                 # TODO(ekl): the local scheduler could include relevant
@@ -690,7 +711,8 @@ def export_remote_function(self, function_id, function_name, function,
             })
         self.redis_client.rpush("Exports", key)
 
-    def run_function_on_all_workers(self, function):
+    def run_function_on_all_workers(self, function,
+                                    run_on_other_drivers=False):
         """Run arbitrary code on all of the workers.
 
         This function will first be run on the driver, and then it will be
@@ -702,6 +724,9 @@ def run_function_on_all_workers(self, function):
             function (Callable): The function to run on all of the workers.
                 It should not take any arguments. If it returns anything,
                 its return values will not be used.
+            run_on_other_drivers: Whether the function should also be run on
+                drivers other than the one that exported it, e.g. when
+                objects need to be shared across drivers.
         """
         # If ray.init has not been called yet, then cache the function and
         # export it when connect is called. Otherwise, run the function on all
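For context, here is a minimal usage sketch (not part of the patch) of the new `run_on_other_drivers` flag. The helper name `warm_up_worker` is hypothetical; it takes the optional worker_info dictionary that Ray passes to exported functions, mirroring `register_class_for_serialization` later in this diff.

```python
# Hypothetical usage sketch of run_on_other_drivers, assuming a connected
# driver. With run_on_other_drivers=True, the import thread on workers that
# belong to *other* drivers no longer skips this export.
import ray


def warm_up_worker(worker_info=None):
    # worker_info is the dictionary Ray passes to exported functions; it is
    # optional here so the sketch also works if the function is invoked
    # without arguments.
    import numpy  # noqa: F401  # e.g. pre-import a heavy module everywhere


ray.init()
ray.worker.global_worker.run_function_on_all_workers(
    warm_up_worker, run_on_other_drivers=True)
```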
""" # If ray.init has not been called yet, then cache the function and # export it when connect is called. Otherwise, run the function on all @@ -734,7 +759,8 @@ def run_function_on_all_workers(self, function): key, { "driver_id": self.task_driver_id.id(), "function_id": function_to_run_id, - "function": pickled_function + "function": pickled_function, + "run_on_other_drivers": run_on_other_drivers }) self.redis_client.rpush("Exports", key) # TODO(rkn): If the worker fails after it calls setnx and before it @@ -1209,17 +1235,17 @@ def error_info(worker=global_worker): return errors -def _initialize_serialization(worker=global_worker): +def _initialize_serialization(driver_id, worker=global_worker): """Initialize the serialization library. This defines a custom serializer for object IDs and also tells ray to serialize several exception classes that we define for error handling. """ - worker.serialization_context = pyarrow.default_serialization_context() + serialization_context = pyarrow.default_serialization_context() # Tell the serialization context to use the cloudpickle version that we # ship with Ray. - worker.serialization_context.set_pickle(pickle.dumps, pickle.loads) - pyarrow.register_torch_serialization_handlers(worker.serialization_context) + serialization_context.set_pickle(pickle.dumps, pickle.loads) + pyarrow.register_torch_serialization_handlers(serialization_context) # Define a custom serializer and deserializer for handling Object IDs. def object_id_custom_serializer(obj): @@ -1231,7 +1257,7 @@ def object_id_custom_deserializer(serialized_obj): # We register this serializer on each worker instead of calling # register_custom_serializer from the driver so that isinstance still # works. - worker.serialization_context.register_type( + serialization_context.register_type( ray.ObjectID, "ray.ObjectID", pickle=False, @@ -1249,28 +1275,55 @@ def actor_handle_deserializer(serialized_obj): # We register this serializer on each worker instead of calling # register_custom_serializer from the driver so that isinstance still # works. - worker.serialization_context.register_type( + serialization_context.register_type( ray.actor.ActorHandle, "ray.ActorHandle", pickle=False, custom_serializer=actor_handle_serializer, custom_deserializer=actor_handle_deserializer) - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - # These should only be called on the driver because - # register_custom_serializer will export the class to all of the - # workers. - register_custom_serializer(RayTaskError, use_dict=True) - register_custom_serializer(RayGetError, use_dict=True) - register_custom_serializer(RayGetArgumentError, use_dict=True) - # Tell Ray to serialize lambdas with pickle. - register_custom_serializer(type(lambda: 0), use_pickle=True) - # Tell Ray to serialize types with pickle. - register_custom_serializer(type(int), use_pickle=True) - # Tell Ray to serialize FunctionSignatures as dictionaries. This is - # used when passing around actor handles. 
@@ -1249,28 +1275,55 @@ def actor_handle_deserializer(serialized_obj):
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
-    worker.serialization_context.register_type(
+    serialization_context.register_type(
         ray.actor.ActorHandle,
         "ray.ActorHandle",
         pickle=False,
         custom_serializer=actor_handle_serializer,
         custom_deserializer=actor_handle_deserializer)
 
-    if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
-        # These should only be called on the driver because
-        # register_custom_serializer will export the class to all of the
-        # workers.
-        register_custom_serializer(RayTaskError, use_dict=True)
-        register_custom_serializer(RayGetError, use_dict=True)
-        register_custom_serializer(RayGetArgumentError, use_dict=True)
-        # Tell Ray to serialize lambdas with pickle.
-        register_custom_serializer(type(lambda: 0), use_pickle=True)
-        # Tell Ray to serialize types with pickle.
-        register_custom_serializer(type(int), use_pickle=True)
-        # Tell Ray to serialize FunctionSignatures as dictionaries. This is
-        # used when passing around actor handles.
-        register_custom_serializer(
-            ray.signature.FunctionSignature, use_dict=True)
+    worker.serialization_context_map[driver_id] = serialization_context
+
+    register_custom_serializer(
+        RayTaskError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayTaskError")
+    register_custom_serializer(
+        RayGetError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayGetError")
+    register_custom_serializer(
+        RayGetArgumentError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayGetArgumentError")
+    # Tell Ray to serialize lambdas with pickle.
+    register_custom_serializer(
+        type(lambda: 0),
+        use_pickle=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="lambda")
+    # Tell Ray to serialize types with pickle.
+    register_custom_serializer(
+        type(int),
+        use_pickle=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="type")
+    # Tell Ray to serialize FunctionSignatures as dictionaries. This is
+    # used when passing around actor handles.
+    register_custom_serializer(
+        ray.signature.FunctionSignature,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.signature.FunctionSignature")
 
 
 def get_address_info_from_redis_helper(redis_address,
@@ -2167,10 +2220,6 @@ def connect(info,
         # driver task.
         worker.current_task_id = driver_task.task_id()
 
-    # Initialize the serialization library. This registers some classes, and so
-    # it must be run before we export all of the cached remote functions.
-    _initialize_serialization()
-
     # Start the import thread
     import_thread.ImportThread(worker, mode).start()
@@ -2242,7 +2291,7 @@ def disconnect(worker=global_worker):
     worker.connected = False
     worker.cached_functions_to_run = []
     worker.cached_remote_functions_and_actors = []
-    worker.serialization_context = pyarrow.SerializationContext()
+    worker.serialization_context_map.clear()
 
 
 def _try_to_compute_deterministic_class_id(cls, depth=5):
@@ -2293,6 +2342,8 @@ def register_custom_serializer(cls,
                                serializer=None,
                                deserializer=None,
                                local=False,
+                               driver_id=None,
+                               class_id=None,
                                worker=global_worker):
     """Enable serialization and deserialization for a particular class.
 
@@ -2313,6 +2364,9 @@ def register_custom_serializer(cls,
             if and only if use_pickle and use_dict are False.
         local: True if the serializers should only be registered on the
             current worker. This should usually be False.
+        driver_id: ID of the driver that we want to register the class for.
+        class_id: ID of the class that we are registering. If this is not
+            specified, we will calculate a new one inside the function.
 
     Raises:
         Exception: An exception is raised if pickle=False and the class cannot
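A hedged sketch of how user code might call `register_custom_serializer` with the new parameters. `MyError` and the `"myapp.MyError"` class ID are made-up names; passing an explicit `class_id` sidesteps the nondeterministic cloudpickle-derived ID discussed in the next hunk.

```python
# Hypothetical example: register a custom class for the current driver only,
# with an explicit class_id, after connecting to Ray.
import ray
from ray.worker import global_worker, register_custom_serializer


class MyError(Exception):
    pass


ray.init()
register_custom_serializer(
    MyError,
    use_dict=True,
    local=True,
    driver_id=global_worker.task_driver_id,
    class_id="myapp.MyError")
```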
@@ -2332,25 +2386,32 @@ def register_custom_serializer(cls,
         # Raise an exception if cls cannot be serialized efficiently by Ray.
         serialization.check_serializable(cls)
 
-    if not local:
-        # In this case, the class ID will be used to deduplicate the class
-        # across workers. Note that cloudpickle unfortunately does not produce
-        # deterministic strings, so these IDs could be different on different
-        # workers. We could use something weaker like cls.__name__, however
-        # that would run the risk of having collisions. TODO(rkn): We should
-        # improve this.
-        try:
-            # Attempt to produce a class ID that will be the same on each
-            # worker. However, determinism is not guaranteed, and the result
-            # may be different on different workers.
-            class_id = _try_to_compute_deterministic_class_id(cls)
-        except Exception:
-            raise serialization.CloudPickleError("Failed to pickle class "
-                                                 "'{}'".format(cls))
+    if class_id is None:
+        if not local:
+            # In this case, the class ID will be used to deduplicate the class
+            # across workers. Note that cloudpickle unfortunately does not
+            # produce deterministic strings, so these IDs could be different
+            # on different workers. We could use something weaker like
+            # cls.__name__, however that would run the risk of having
+            # collisions.
+            # TODO(rkn): We should improve this.
+            try:
+                # Attempt to produce a class ID that will be the same on each
+                # worker. However, determinism is not guaranteed, and the
+                # result may be different on different workers.
+                class_id = _try_to_compute_deterministic_class_id(cls)
+            except Exception as e:
+                raise serialization.CloudPickleError("Failed to pickle class "
+                                                     "'{}'".format(cls))
+        else:
+            # In this case, the class ID only needs to be meaningful on this
+            # worker and not across workers.
+            class_id = random_string()
+
+    if driver_id is None:
+        driver_id_bytes = worker.task_driver_id.id()
     else:
-        # In this case, the class ID only needs to be meaningful on this worker
-        # and not across workers.
-        class_id = random_string()
+        driver_id_bytes = driver_id.id()
 
     def register_class_for_serialization(worker_info):
         # TODO(rkn): We need to be more thoughtful about what to do if custom
@@ -2358,7 +2419,10 @@ def register_class_for_serialization(worker_info):
         # we may want to use the last user-defined serializers and ignore
         # subsequent calls to register_custom_serializer that were made by the
         # system.
-        worker_info["worker"].serialization_context.register_type(
+
+        serialization_context = worker_info[
+            "worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
+        serialization_context.register_type(
             cls,
             class_id,
             pickle=use_pickle,
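Finally, a small sanity sketch (not a test from the patch) of the lazy per-driver caching that `register_class_for_serialization` above relies on; it assumes `ObjectID` hashes by value, so repeated lookups for the same driver ID return the same cached context.

```python
# Hypothetical check of the per-driver cache behaviour introduced in this
# patch: the first lookup creates the context, later lookups reuse it.
import ray

ray.init()
worker = ray.worker.global_worker
ctx_first = worker.get_serialization_context(worker.task_driver_id)
ctx_again = worker.get_serialization_context(worker.task_driver_id)
assert ctx_first is ctx_again  # one SerializationContext per driver ID
```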