diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 4d99dead0f6e..18381e3c4871 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -51,7 +51,7 @@ from ray._raylet import (UniqueID, ObjectID, DriverID, ClientID, ActorID, ActorHandleID, FunctionID, ActorClassID, TaskID, - Config as _Config) # noqa: E402 + _ID_TYPES, Config as _Config) # noqa: E402 _config = _Config() @@ -77,7 +77,8 @@ "remote", "profile", "actor", "method", "get_gpu_ids", "get_resource_ids", "get_webui_url", "register_custom_serializer", "shutdown", "is_initialized", "SCRIPT_MODE", "WORKER_MODE", "LOCAL_MODE", - "PYTHON_MODE", "global_state", "_config", "__version__", "internal" + "PYTHON_MODE", "global_state", "_config", "__version__", "internal", + "_ID_TYPES" ] __all__ += [ diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index a29476742163..1f31179e4fce 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -4,6 +4,8 @@ We define different types for different IDs for type safety. See https://github.com/ray-project/ray/issues/3721. """ +# WARNING: Any additional ID types defined in this file must be added to the +# _ID_TYPES list at the bottom of this file. from ray.includes.common cimport ( CUniqueID, CTaskID, CObjectID, CFunctionID, CActorClassID, CActorID, CActorHandleID, CWorkerID, CDriverID, CConfigID, CClientID, @@ -278,3 +280,7 @@ cdef class ActorClassID(UniqueID): def __repr__(self): return "ActorClassID(" + self.hex() + ")" + + +_ID_TYPES = [UniqueID, ObjectID, TaskID, ClientID, DriverID, ActorID, + ActorHandleID, FunctionID, ActorClassID] diff --git a/python/ray/worker.py b/python/ray/worker.py index 72198bcaf27c..afb3010e2738 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1099,22 +1099,11 @@ def _initialize_serialization(driver_id, worker=global_worker): serialization_context.set_pickle(pickle.dumps, pickle.loads) pyarrow.register_torch_serialization_handlers(serialization_context) - # Define a custom serializer and deserializer for handling Object IDs. - def object_id_custom_serializer(obj): - return obj.binary() - - def object_id_custom_deserializer(serialized_obj): - return ObjectID(serialized_obj) - - # We register this serializer on each worker instead of calling - # register_custom_serializer from the driver so that isinstance still - # works. - serialization_context.register_type( - ObjectID, - "ray.ObjectID", - pickle=False, - custom_serializer=object_id_custom_serializer, - custom_deserializer=object_id_custom_deserializer) + for id_type in ray._ID_TYPES: + serialization_context.register_type( + id_type, + "{}.{}".format(id_type.__module__, id_type.__name__), + pickle=True) def actor_handle_serializer(obj): return obj._serialization_helper(True)