Merged
10 changes: 5 additions & 5 deletions python/ray/import_thread.py
@@ -157,13 +157,13 @@ def f():

     def fetch_and_execute_function_to_run(self, key):
         """Run an arbitrary function on the worker."""
-        driver_id, serialized_function = self.redis_client.hmget(
-            key, ["driver_id", "function"])
+        (driver_id, serialized_function,
+         run_on_other_drivers) = self.redis_client.hmget(
+             key, ["driver_id", "function", "run_on_other_drivers"])

-        if (self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
+        if (run_on_other_drivers == "False"
+                and self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
                 and driver_id != self.worker.task_driver_id.id()):
             # This export was from a different driver and there's no need for
             # this driver to import it.
             return

         try:

Reviewer comment (Collaborator) on the new condition: There are some issues with this if statement. In particular, run_on_other_drivers can be b"False". I think there are other issues as well. See #2769
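To make the reviewer's concern concrete: redis-py returns bytes unless the client is created with decode_responses=True, so the flag can arrive as b"False", the string comparison never matches, and the guard silently never fires. A minimal hardening sketch (a hypothetical helper, not part of this PR):

    def _coerce_run_on_other_drivers(raw_value):
        # redis-py hmget replies are bytes by default, so accept both
        # b"True"/b"False" and their str forms.
        if isinstance(raw_value, bytes):
            raw_value = raw_value.decode("ascii")
        return raw_value == "True"

With a helper like this, the condition could test "not _coerce_run_on_other_drivers(run_on_other_drivers)" instead of comparing against a single string literal.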
162 changes: 113 additions & 49 deletions python/ray/worker.py
@@ -245,6 +245,25 @@ def __init__(self):
         self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
         self.profiler = profiling.Profiler(self)
         self.state_lock = threading.Lock()
+        # A dictionary that maps from driver ID to SerializationContext.
+        # TODO: Clean up each SerializationContext once its job has finished.
+        self.serialization_context_map = {}
+        # Identity of the driver that this worker is processing.
+        self.task_driver_id = None
+
+    def get_serialization_context(self, driver_id):
+        """Get the SerializationContext of the given driver.
+
+        Args:
+            driver_id: The ID of the driver whose serialization context
+                should be returned.
+
+        Returns:
+            The serialization context of the given driver.
+        """
+        if driver_id not in self.serialization_context_map:
+            _initialize_serialization(driver_id)
+        return self.serialization_context_map[driver_id]

     def check_connected(self):
         """Check if the worker is connected.
@@ -308,7 +327,8 @@ def store_and_register(self, object_id, value, depth=100):
                     value,
                     object_id=pyarrow.plasma.ObjectID(object_id.id()),
                     memcopy_threads=self.memcopy_threads,
-                    serialization_context=self.serialization_context)
+                    serialization_context=self.get_serialization_context(
+                        self.task_driver_id))
                 break
             except pyarrow.SerializationCallbackError as e:
                 try:
@@ -400,7 +420,8 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
                     results += self.plasma_client.get(
                         object_ids[i:(
                             i + ray._config.worker_get_request_size())],
-                        timeout, self.serialization_context)
+                        timeout,
+                        self.get_serialization_context(self.task_driver_id))
                 return results
             except pyarrow.lib.ArrowInvalid:
                 # TODO(ekl): the local scheduler could include relevant
@@ -690,7 +711,8 @@ def export_remote_function(self, function_id, function_name, function,
             })
         self.redis_client.rpush("Exports", key)

-    def run_function_on_all_workers(self, function):
+    def run_function_on_all_workers(self, function,
+                                    run_on_other_drivers=False):
         """Run arbitrary code on all of the workers.

         This function will first be run on the driver, and then it will be
@@ -702,6 +724,9 @@ def run_function_on_all_workers(self, function):
             function (Callable): The function to run on all of the workers. It
                 should not take any arguments. If it returns anything, its
                 return values will not be used.
+            run_on_other_drivers: A boolean indicating whether this function
+                should also be run on workers belonging to other drivers,
+                e.g. when objects need to be shared across drivers.
         """
         # If ray.init has not been called yet, then cache the function and
         # export it when connect is called. Otherwise, run the function on all
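A hedged usage sketch of the new keyword (warm_up is an illustrative callback; its worker_info argument mirrors the register_class_for_serialization callback further down this diff):

    import ray

    ray.init()

    def warm_up(worker_info):
        # Runs once per worker process; with run_on_other_drivers=True the
        # export is also imported by workers belonging to other drivers.
        import numpy  # noqa: F401

    ray.worker.global_worker.run_function_on_all_workers(
        warm_up, run_on_other_drivers=True)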
@@ -734,7 +759,8 @@ def run_function_on_all_workers(self, function):
                 key, {
                     "driver_id": self.task_driver_id.id(),
                     "function_id": function_to_run_id,
-                    "function": pickled_function
+                    "function": pickled_function,
+                    "run_on_other_drivers": run_on_other_drivers
                 })
             self.redis_client.rpush("Exports", key)
             # TODO(rkn): If the worker fails after it calls setnx and before it
@@ -1209,17 +1235,17 @@ def error_info(worker=global_worker):
     return errors


-def _initialize_serialization(worker=global_worker):
+def _initialize_serialization(driver_id, worker=global_worker):
     """Initialize the serialization library.

     This defines a custom serializer for object IDs and also tells ray to
     serialize several exception classes that we define for error handling.
     """
-    worker.serialization_context = pyarrow.default_serialization_context()
+    serialization_context = pyarrow.default_serialization_context()
     # Tell the serialization context to use the cloudpickle version that we
     # ship with Ray.
-    worker.serialization_context.set_pickle(pickle.dumps, pickle.loads)
-    pyarrow.register_torch_serialization_handlers(worker.serialization_context)
+    serialization_context.set_pickle(pickle.dumps, pickle.loads)
+    pyarrow.register_torch_serialization_handlers(serialization_context)

     # Define a custom serializer and deserializer for handling Object IDs.
     def object_id_custom_serializer(obj):
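For orientation, this is how a pyarrow SerializationContext of this era is exercised end to end (a standalone sketch using pyarrow's since-deprecated serialize/deserialize API):

    import pyarrow

    context = pyarrow.default_serialization_context()
    # Serialize into a buffer, then read it back through the same context.
    buf = pyarrow.serialize([1, 2, 3], context=context).to_buffer()
    assert pyarrow.deserialize(buf, context=context) == [1, 2, 3]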
@@ -1231,7 +1257,7 @@ def object_id_custom_deserializer(serialized_obj):
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
-    worker.serialization_context.register_type(
+    serialization_context.register_type(
         ray.ObjectID,
         "ray.ObjectID",
         pickle=False,
@@ -1249,28 +1275,55 @@ def actor_handle_deserializer(serialized_obj):
     # We register this serializer on each worker instead of calling
     # register_custom_serializer from the driver so that isinstance still
     # works.
-    worker.serialization_context.register_type(
+    serialization_context.register_type(
         ray.actor.ActorHandle,
         "ray.ActorHandle",
         pickle=False,
         custom_serializer=actor_handle_serializer,
         custom_deserializer=actor_handle_deserializer)

-    if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
-        # These should only be called on the driver because
-        # register_custom_serializer will export the class to all of the
-        # workers.
-        register_custom_serializer(RayTaskError, use_dict=True)
-        register_custom_serializer(RayGetError, use_dict=True)
-        register_custom_serializer(RayGetArgumentError, use_dict=True)
-        # Tell Ray to serialize lambdas with pickle.
-        register_custom_serializer(type(lambda: 0), use_pickle=True)
-        # Tell Ray to serialize types with pickle.
-        register_custom_serializer(type(int), use_pickle=True)
-        # Tell Ray to serialize FunctionSignatures as dictionaries. This is
-        # used when passing around actor handles.
-        register_custom_serializer(
-            ray.signature.FunctionSignature, use_dict=True)
+    worker.serialization_context_map[driver_id] = serialization_context
+
+    register_custom_serializer(
+        RayTaskError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayTaskError")
+    register_custom_serializer(
+        RayGetError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayGetError")
+    register_custom_serializer(
+        RayGetArgumentError,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.RayGetArgumentError")
+    # Tell Ray to serialize lambdas with pickle.
+    register_custom_serializer(
+        type(lambda: 0),
+        use_pickle=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="lambda")
+    # Tell Ray to serialize types with pickle.
+    register_custom_serializer(
+        type(int),
+        use_pickle=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="type")
+    # Tell Ray to serialize FunctionSignatures as dictionaries. This is
+    # used when passing around actor handles.
+    register_custom_serializer(
+        ray.signature.FunctionSignature,
+        use_dict=True,
+        local=True,
+        driver_id=driver_id,
+        class_id="ray.signature.FunctionSignature")


 def get_address_info_from_redis_helper(redis_address,
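Because every registration above is now local=True and keyed by driver_id, each driver ends up with an isolated context, so re-registering a class for one driver cannot clobber another driver's serializers. A toy check of that isolation (hypothetical 20-byte driver IDs; assumes a connected worker):

    import ray

    worker = ray.worker.global_worker
    ctx_a = worker.get_serialization_context(ray.ObjectID(20 * b"\x01"))
    ctx_b = worker.get_serialization_context(ray.ObjectID(20 * b"\x02"))
    assert ctx_a is not ctx_b  # one context per driver, created on demand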
@@ -2167,10 +2220,6 @@ def connect(info,
     # driver task.
     worker.current_task_id = driver_task.task_id()

-    # Initialize the serialization library. This registers some classes, and so
-    # it must be run before we export all of the cached remote functions.
-    _initialize_serialization()
-
     # Start the import thread
     import_thread.ImportThread(worker, mode).start()

@@ -2242,7 +2291,7 @@ def disconnect(worker=global_worker):
     worker.connected = False
     worker.cached_functions_to_run = []
     worker.cached_remote_functions_and_actors = []
-    worker.serialization_context = pyarrow.SerializationContext()
+    worker.serialization_context_map.clear()


 def _try_to_compute_deterministic_class_id(cls, depth=5):
@@ -2293,6 +2342,8 @@ def register_custom_serializer(cls,
                                serializer=None,
                                deserializer=None,
                                local=False,
+                               driver_id=None,
+                               class_id=None,
                                worker=global_worker):
     """Enable serialization and deserialization for a particular class.
@@ -2313,6 +2364,9 @@ def register_custom_serializer(cls,
             if and only if use_pickle and use_dict are False.
         local: True if the serializers should only be registered on the current
             worker. This should usually be False.
+        driver_id: ID of the driver for which the class should be registered.
+        class_id: ID of the class being registered. If this is not specified,
+            a new one will be computed inside this function.

     Raises:
         Exception: An exception is raised if pickle=False and the class cannot
@@ -2332,33 +2386,43 @@ def register_custom_serializer(cls,
     # Raise an exception if cls cannot be serialized efficiently by Ray.
     serialization.check_serializable(cls)

-    if not local:
-        # In this case, the class ID will be used to deduplicate the class
-        # across workers. Note that cloudpickle unfortunately does not produce
-        # deterministic strings, so these IDs could be different on different
-        # workers. We could use something weaker like cls.__name__, however
-        # that would run the risk of having collisions. TODO(rkn): We should
-        # improve this.
-        try:
-            # Attempt to produce a class ID that will be the same on each
-            # worker. However, determinism is not guaranteed, and the result
-            # may be different on different workers.
-            class_id = _try_to_compute_deterministic_class_id(cls)
-        except Exception:
-            raise serialization.CloudPickleError("Failed to pickle class "
-                                                 "'{}'".format(cls))
-    else:
-        # In this case, the class ID only needs to be meaningful on this worker
-        # and not across workers.
-        class_id = random_string()
+    if class_id is None:
+        if not local:
+            # In this case, the class ID will be used to deduplicate the class
+            # across workers. Note that cloudpickle unfortunately does not
+            # produce deterministic strings, so these IDs could be different
+            # on different workers. We could use something weaker like
+            # cls.__name__, however that would run the risk of having
+            # collisions.
+            # TODO(rkn): We should improve this.
+            try:
+                # Attempt to produce a class ID that will be the same on each
+                # worker. However, determinism is not guaranteed, and the
+                # result may be different on different workers.
+                class_id = _try_to_compute_deterministic_class_id(cls)
+            except Exception as e:
+                raise serialization.CloudPickleError("Failed to pickle class "
+                                                     "'{}'".format(cls))
+        else:
+            # In this case, the class ID only needs to be meaningful on this
+            # worker and not across workers.
+            class_id = random_string()
+
+    if driver_id is None:
+        driver_id_bytes = worker.task_driver_id.id()
+    else:
+        driver_id_bytes = driver_id.id()

     def register_class_for_serialization(worker_info):
         # TODO(rkn): We need to be more thoughtful about what to do if custom
         # serializers have already been registered for class_id. In some cases,
         # we may want to use the last user-defined serializers and ignore
         # subsequent calls to register_custom_serializer that were made by the
         # system.
-        worker_info["worker"].serialization_context.register_type(
+
+        serialization_context = worker_info[
+            "worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
+        serialization_context.register_type(
             cls,
             class_id,
             pickle=use_pickle,
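Pulling the new keyword arguments together, a hedged usage sketch (Foo is an illustrative class; assumes ray.init() has been called):

    import ray

    class Foo:
        def __init__(self, x):
            self.x = x

    ray.init()
    worker = ray.worker.global_worker
    # Register Foo on this worker only, scoped to the current driver's
    # serialization context, under an explicit class ID instead of the
    # nondeterministic cloudpickle-derived one.
    ray.worker.register_custom_serializer(
        Foo,
        use_dict=True,
        local=True,
        driver_id=worker.task_driver_id,
        class_id="Foo")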