diff --git a/python/ray/worker.py b/python/ray/worker.py
index 181b58147fd4..70d48f5a3d13 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -237,6 +237,20 @@ def __init__(self):
         # When the worker is constructed. Record the original value of the
         # CUDA_VISIBLE_DEVICES environment variable.
         self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
+        # A dictionary that maps from driver ID to SerializationContext.
+        # TODO: clean up the SerializationContext once the job has finished.
+        self.serialization_context_map = dict()
+        # ID of the driver whose task this worker is currently processing.
+        self.task_driver_id = None
+
+    def get_serialization_context(self):
+        """Get the SerializationContext of the driver that this worker
+        is currently processing."""
+        if self.task_driver_id is None:
+            return None
+        if self.task_driver_id not in self.serialization_context_map:
+            _initialize_serialization(worker=self)
+        return self.serialization_context_map[self.task_driver_id]
 
     def check_connected(self):
         """Check if the worker is connected.
@@ -300,7 +314,7 @@ def store_and_register(self, object_id, value, depth=100):
                     value,
                     object_id=pyarrow.plasma.ObjectID(object_id.id()),
                     memcopy_threads=self.memcopy_threads,
-                    serialization_context=self.serialization_context)
+                    serialization_context=self.get_serialization_context())
                 break
             except pyarrow.SerializationCallbackError as e:
                 try:
@@ -392,7 +406,7 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
                     results += self.plasma_client.get(
                         object_ids[i:(
                             i + ray._config.worker_get_request_size())],
                        timeout, self.get_serialization_context())
                 return results
             except pyarrow.lib.ArrowInvalid as e:
                 # TODO(ekl): the local scheduler could include relevant
@@ -674,6 +688,50 @@ def export_remote_function(self, function_id, function_name, function,
             })
         self.redis_client.rpush("Exports", key)
 
+    def register_class_on_all_workers(self, function):
+        """Register a type on all of the workers of the same driver."""
+        check_main_thread()
+
+        # Attempt to pickle the function before we need it. This could
+        # fail, and it is more convenient if the failure happens before we
+        # actually run the function locally.
+        pickled_function = pickle.dumps(function)
+        function_to_run_id = hashlib.sha1(pickled_function).digest()
+        key = b"RegisterType:" + self.task_driver_id.id() + function_to_run_id
+        # First register the type on the driver.
+        # We always run the function locally.
+        function({"worker": self})
+        # Check if the type has already been put into Redis.
+        type_exported = self.redis_client.setnx(b"Lock:" + key, 1)
+        if not type_exported:
+            # In this case, the type has already been exported, so
+            # we don't need to export it again.
+            return
+
+        if (len(pickled_function) >
+                ray_constants.PICKLE_OBJECT_WARNING_SIZE):
+            warning_message = ("Warning: The function {} has size {} when "
+                               "pickled. It will be stored in Redis, "
+                               "which could cause memory issues. This may "
+                               "mean that its definition uses a large "
+                               "array or other object."
+                               .format(function.__name__,
+                                       len(pickled_function)))
+            ray.utils.push_error_to_driver(
+                self,
+                ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
+                warning_message,
+                driver_id=self.task_driver_id.id())
+
+        # Run the function on all workers.
+        self.redis_client.hmset(
+            key, {
+                "driver_id": self.task_driver_id.id(),
+                "function": pickled_function
+            })
+        self.redis_client.rpush("Exports", key)
+
+
     def run_function_on_all_workers(self, function):
         """Run arbitrary code on all of the workers.
 
@@ -688,6 +746,7 @@ def run_function_on_all_workers(self, function):
         return values will not be used.
         """
         check_main_thread()
+
         # If ray.init has not been called yet, then cache the function and
         # export it when connect is called. Otherwise, run the function on all
         # workers.
@@ -865,6 +924,10 @@ def _process_task(self, task):
         function_name = self.function_execution_info[self.task_driver_id.id()][
             function_id.id()].function_name
 
+        if self.task_driver_id not in self.serialization_context_map:
+            _initialize_serialization(worker=self)
+            register_existing_class(worker=self)
+
         # Get task arguments from the object store.
         try:
             with profile("task:deserialize_arguments", worker=self):
@@ -1213,11 +1276,11 @@ def _initialize_serialization(worker=global_worker):
     This defines a custom serializer for object IDs and also tells ray to
     serialize several exception classes that we define for error handling.
     """
-    worker.serialization_context = pyarrow.default_serialization_context()
+    serialization_context = pyarrow.default_serialization_context()
     # Tell the serialization context to use the cloudpickle version that we
     # ship with Ray.
-    worker.serialization_context.set_pickle(pickle.dumps, pickle.loads)
-    pyarrow.register_torch_serialization_handlers(worker.serialization_context)
+    serialization_context.set_pickle(pickle.dumps, pickle.loads)
+    pyarrow.register_torch_serialization_handlers(serialization_context)
 
     # Define a custom serializer and deserializer for handling Object IDs.
     def object_id_custom_serializer(obj):
@@ -1229,7 +1292,7 @@ def object_id_custom_deserializer(serialized_obj):
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
-    worker.serialization_context.register_type(
+    serialization_context.register_type(
         ray.ObjectID,
         "ray.ObjectID",
         pickle=False,
@@ -1247,13 +1310,15 @@ def actor_handle_deserializer(serialized_obj):
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
-    worker.serialization_context.register_type(
+    serialization_context.register_type(
         ray.actor.ActorHandle,
         "ray.ActorHandle",
         pickle=False,
         custom_serializer=actor_handle_serializer,
         custom_deserializer=actor_handle_deserializer)
 
+    worker.serialization_context_map[worker.task_driver_id] = serialization_context
+
     if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
         # These should only be called on the driver because
         # register_custom_serializer will export the class to all of the
@@ -2002,6 +2067,35 @@ def f():
         worker.worker_id)
 
 
+def fetch_and_register_type(key, worker=global_worker):
+    """Run the function to register a class on the worker."""
+    driver_id, serialized_function = worker.redis_client.hmget(
+        key, ["driver_id", "function"])
+
+    if worker.task_driver_id is None or driver_id != worker.task_driver_id.id():
+        # This worker is not yet bound to a driver, or the export came from a
+        # different driver, so there is no need to import it.
+        return
+
+    try:
+        # Deserialize the function.
+        function = pickle.loads(serialized_function)
+        # Run the function.
+        function({"worker": worker})
+    except Exception:
+        # If an exception was thrown when the function was run, we record the
+        # traceback and notify the scheduler of the failure.
+        traceback_str = traceback.format_exc()
+        # Log the error message.
+        name = function.__name__ if ("function" in locals()
+                                     and hasattr(function, "__name__")) else ""
+        ray.utils.push_error_to_driver(
+            worker,
+            ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
+            traceback_str,
+            driver_id=driver_id,
+            data={"name": name})
+
 def fetch_and_execute_function_to_run(key, worker=global_worker):
     """Run on arbitrary function on the worker."""
     driver_id, serialized_function = worker.redis_client.hmget(
@@ -2033,6 +2127,13 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
             data={"name": name})
 
 
+def register_existing_class(worker=global_worker):
+    export_keys = worker.redis_client.lrange("Exports", 0, -1)
+    for key in export_keys:
+        if key.startswith(b"RegisterType"):
+            fetch_and_register_type(key, worker=worker)
+
+
 def import_thread(worker, mode):
     worker.import_pubsub_client = worker.redis_client.pubsub()
     # Exports that are published after the call to
@@ -2041,7 +2142,6 @@ def import_thread(worker, mode):
     worker.import_pubsub_client.subscribe("__keyspace@0__:Exports")
     # Keep track of the number of imports that we've imported.
     num_imported = 0
-
     # Get the exports that occurred before the call to subscribe.
     with worker.lock:
         export_keys = worker.redis_client.lrange("Exports", 0, -1)
@@ -2053,8 +2153,9 @@
                 if key.startswith(b"FunctionsToRun"):
                     with profile("fetch_and_run_function", worker=worker):
                         fetch_and_execute_function_to_run(key, worker=worker)
-                # Continue because FunctionsToRun are the only things that the
-                # driver should import.
+                elif key.startswith(b"RegisterType"):
+                    with profile("fetch_and_register_type", worker=worker):
+                        fetch_and_register_type(key, worker=worker)
                 continue
 
             if key.startswith(b"RemoteFunction"):
@@ -2063,6 +2164,9 @@ def import_thread(worker, mode):
             elif key.startswith(b"FunctionsToRun"):
                 with profile("fetch_and_run_function", worker=worker):
                     fetch_and_execute_function_to_run(key, worker=worker)
+            elif key.startswith(b"RegisterType"):
+                with profile("fetch_and_register_type", worker=worker):
+                    fetch_and_register_type(key, worker=worker)
             elif key.startswith(b"ActorClass"):
                 # Keep track of the fact that this actor class has been
                 # exported so that we know it is safe to turn this worker into
@@ -2090,8 +2194,12 @@ def import_thread(worker, mode):
                                     "fetch_and_run_function", worker=worker):
                                 fetch_and_execute_function_to_run(
                                     key, worker=worker)
-                        # Continue because FunctionsToRun are the only things
-                        # that the driver should import.
+                        elif key.startswith(b"RegisterType"):
+                            with profile(
+                                    "fetch_and_register_type",
+                                    worker=worker):
+                                fetch_and_register_type(
+                                    key, worker=worker)
                         continue
 
                     if key.startswith(b"RemoteFunction"):
@@ -2103,6 +2211,12 @@ def import_thread(worker, mode):
                         with profile("fetch_and_run_function", worker=worker):
                             fetch_and_execute_function_to_run(
                                 key, worker=worker)
+                    elif key.startswith(b"RegisterType"):
+                        with profile(
+                                "fetch_and_register_type",
+                                worker=worker):
+                            fetch_and_register_type(
+                                key, worker=worker)
                     elif key.startswith(b"ActorClass"):
                         # Keep track of the fact that this actor class has been
                         # exported so that we know it is safe to turn this
@@ -2151,6 +2265,7 @@ def connect(info,
     # the correct driver.
     if mode != WORKER_MODE:
         worker.task_driver_id = ray.ObjectID(worker.worker_id)
+        _initialize_serialization(worker=worker)
 
     # All workers start out as non-actors. A worker can be turned into an actor
     # after it is created.
@@ -2324,10 +2439,6 @@ def connect(info,
         # driver task.
         worker.current_task_id = driver_task.task_id()
 
-    # Initialize the serialization library. This registers some classes, and so
-    # it must be run before we export all of the cached remote functions.
-    _initialize_serialization()
-
     # Start a thread to import exports from the driver or from other workers.
     # Note that the driver also has an import thread, which is used only to
     # import custom class definitions from calls to register_custom_serializer
@@ -2407,7 +2518,7 @@ def disconnect(worker=global_worker):
     worker.connected = False
     worker.cached_functions_to_run = []
     worker.cached_remote_functions_and_actors = []
-    worker.serialization_context = pyarrow.SerializationContext()
+    worker.serialization_context_map[worker.task_driver_id] = pyarrow.SerializationContext()
 
 
 def _try_to_compute_deterministic_class_id(cls, depth=5):
@@ -2461,7 +2572,7 @@ def register_custom_serializer(cls,
                                worker=global_worker):
     """Enable serialization and deserialization for a particular class.
 
-    This method runs the register_class function defined below on every worker,
+    This method runs register_class_for_serialization below on every worker,
     which will enable ray to properly serialize and deserialize objects of
     this class.
 
@@ -2523,15 +2634,19 @@ def register_class_for_serialization(worker_info):
         # we may want to use the last user-defined serializers and ignore
         # subsequent calls to register_custom_serializer that were made by the
         # system.
-        worker_info["worker"].serialization_context.register_type(
-            cls,
-            class_id,
-            pickle=use_pickle,
-            custom_serializer=serializer,
-            custom_deserializer=deserializer)
+        serialization_context = worker_info["worker"].get_serialization_context()
+        # Don't register if the serialization context is None, which happens
+        # when a worker has started but has not yet received a task.
+        if serialization_context is not None:
+            serialization_context.register_type(
+                cls,
+                class_id,
+                pickle=use_pickle,
+                custom_serializer=serializer,
+                custom_deserializer=deserializer)
 
     if not local:
-        worker.run_function_on_all_workers(register_class_for_serialization)
+        worker.register_class_on_all_workers(register_class_for_serialization)
     else:
        # Since we are pickling objects of this class, we don't actually need
        # to ship the class definition.
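
For reviewers, the core idea of the change (one SerializationContext per driver, created lazily on first use and looked up by the driver currently being processed) can be sketched in isolation. The snippet below is a minimal illustration only; FakeSerializationContext, MiniWorker, and the driver-ID strings are hypothetical stand-ins invented for this example and are not Ray or pyarrow APIs.

# Illustrative sketch (not part of the patch): per-driver serialization
# contexts, created lazily and kept in a map keyed by driver ID.


class FakeSerializationContext(object):
    """Stand-in for pyarrow.SerializationContext; records registered types."""

    def __init__(self):
        self.registered = {}

    def register_type(self, cls, class_id, **kwargs):
        self.registered[class_id] = cls


class MiniWorker(object):
    """Stand-in for Worker with only the serialization-context bookkeeping."""

    def __init__(self):
        # Maps driver ID -> serialization context (mirrors
        # Worker.serialization_context_map in the patch).
        self.serialization_context_map = {}
        # Driver whose task this worker is currently processing.
        self.task_driver_id = None

    def get_serialization_context(self):
        # No driver yet, so there is no context to return (the patched
        # Worker.get_serialization_context also returns None here).
        if self.task_driver_id is None:
            return None
        # Lazily create a context the first time a driver is seen.
        if self.task_driver_id not in self.serialization_context_map:
            self.serialization_context_map[
                self.task_driver_id] = FakeSerializationContext()
        return self.serialization_context_map[self.task_driver_id]


if __name__ == "__main__":
    worker = MiniWorker()
    assert worker.get_serialization_context() is None

    worker.task_driver_id = "driver-1"
    ctx1 = worker.get_serialization_context()
    ctx1.register_type(dict, "my.CustomClass")

    # A second driver gets its own independent context, so classes
    # registered for one driver do not leak into another.
    worker.task_driver_id = "driver-2"
    ctx2 = worker.get_serialization_context()
    assert ctx1 is not ctx2
    assert "my.CustomClass" not in ctx2.registered
    print("per-driver contexts:", worker.serialization_context_map)

Keeping one context per driver is what allows register_custom_serializer to register a class for the driver that exported it without clobbering classes registered by other drivers that share the same worker.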