Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ray.signature as signature
import ray.worker
from ray.utils import _random_string
from ray import ObjectID

DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1

Expand All @@ -41,8 +42,7 @@ def compute_actor_handle_id(actor_handle_id, num_forks):
handle_id_hash.update(actor_handle_id.id())
handle_id_hash.update(str(num_forks).encode("ascii"))
handle_id = handle_id_hash.digest()
assert len(handle_id) == ray_constants.ID_SIZE
return ray.ObjectID(handle_id)
return ObjectID(handle_id)


def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
Expand All @@ -69,8 +69,7 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
handle_id_hash.update(actor_handle_id.id())
handle_id_hash.update(current_task_id.id())
handle_id = handle_id_hash.digest()
assert len(handle_id) == ray_constants.ID_SIZE
return ray.ObjectID(handle_id)
return ObjectID(handle_id)


def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
Expand All @@ -84,7 +83,7 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
actor_key = b"Actor:" + actor_id
actor_key = b"Actor:" + actor_id.id()
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
Expand All @@ -110,7 +109,7 @@ def save_and_log_checkpoint(worker, actor):
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
driver_id=worker.task_driver_id,
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__
Expand All @@ -134,7 +133,7 @@ def restore_and_log_checkpoint(worker, actor):
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
driver_id=worker.task_driver_id,
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint_restore__.__name__
Expand All @@ -156,7 +155,7 @@ def get_actor_checkpoint(worker, actor_id):
exists, all objects are set to None. The checkpoint index is the number
of tasks executed on the actor before the checkpoint was made.
"""
actor_key = b"Actor:" + actor_id
actor_key = b"Actor:" + actor_id.id()
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
Expand Down Expand Up @@ -371,7 +370,7 @@ def _remote(self,
raise Exception("Actors cannot be created before ray.init() "
"has been called.")

actor_id = ray.ObjectID(_random_string())
actor_id = ObjectID(_random_string())
# The actor cursor is a dummy object representing the most recent
# actor method invocation. For each subsequent method invocation,
# the current cursor should be added as a dependency, and then
Expand Down Expand Up @@ -509,8 +508,7 @@ def __init__(self,
# if it was created by the _serialization_helper function.
self._ray_original_handle = actor_handle_id is None
if self._ray_original_handle:
self._ray_actor_handle_id = ray.ObjectID(
ray.worker.NIL_ACTOR_HANDLE_ID)
self._ray_actor_handle_id = ObjectID.nil_id()
else:
self._ray_actor_handle_id = actor_handle_id
self._ray_actor_cursor = actor_cursor
Expand Down Expand Up @@ -713,7 +711,7 @@ def _serialization_helper(self, ray_forking):
# to release, since it could be unpickled and submit another
# dependent task at any time. Therefore, we notify the backend of a
# random handle ID that will never actually be used.
new_actor_handle_id = ray.ObjectID(_random_string())
new_actor_handle_id = ObjectID(_random_string())
# Notify the backend to expect this new actor handle. The backend will
# not release the cursor for any new handles until the first task for
# each of the new handles is submitted.
Expand All @@ -735,7 +733,7 @@ def _deserialization_helper(self, state, ray_forking):
worker.check_connected()

if state["ray_forking"]:
actor_handle_id = ray.ObjectID(state["actor_handle_id"])
actor_handle_id = ObjectID(state["actor_handle_id"])
else:
# Right now, if the actor handle has been pickled, we create a
# temporary actor handle id for invocations.
Expand All @@ -749,22 +747,22 @@ def _deserialization_helper(self, state, ray_forking):
# same actor is likely a performance bug. We should consider
# logging a warning in these cases.
actor_handle_id = compute_actor_handle_id_non_forked(
ray.ObjectID(state["actor_handle_id"]), worker.current_task_id)
ObjectID(state["actor_handle_id"]), worker.current_task_id)

# This is the driver ID of the driver that owns the actor, not
# necessarily the driver that owns this actor handle.
actor_driver_id = ray.ObjectID(state["actor_driver_id"])
actor_driver_id = ObjectID(state["actor_driver_id"])

self.__init__(
ray.ObjectID(state["actor_id"]),
ObjectID(state["actor_id"]),
state["module_name"],
state["class_name"],
ray.ObjectID(state["actor_cursor"])
ObjectID(state["actor_cursor"])
if state["actor_cursor"] is not None else None,
state["actor_method_names"],
state["method_signatures"],
state["method_num_return_vals"],
ray.ObjectID(state["actor_creation_dummy_object_id"])
ObjectID(state["actor_creation_dummy_object_id"])
if state["actor_creation_dummy_object_id"] is not None else None,
state["actor_method_cpus"],
actor_driver_id,
Expand Down Expand Up @@ -843,7 +841,7 @@ def __ray_checkpoint__(self):
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = ray.ObjectID(worker.actor_id)
actor_id = worker.actor_id
frontier = worker.raylet_client.get_actor_frontier(actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
Expand Down
19 changes: 9 additions & 10 deletions python/ray/experimental/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parse_client_table(redis_client):
Returns:
A list of information about the nodes in the cluster.
"""
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
NIL_CLIENT_ID = ray.ObjectID.nil_id().id()
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.CLIENT,
"", NIL_CLIENT_ID)
Expand Down Expand Up @@ -308,20 +308,19 @@ def _task_table(self, task_id):
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
"DriverID": task_spec.driver_id().hex(),
"TaskID": task_spec.task_id().hex(),
"ParentTaskID": task_spec.parent_task_id().hex(),
"ParentCounter": task_spec.parent_counter(),
"ActorID": binary_to_hex(task_spec.actor_id().id()),
"ActorCreationID": binary_to_hex(
task_spec.actor_creation_id().id()),
"ActorCreationDummyObjectID": binary_to_hex(
task_spec.actor_creation_dummy_object_id().id()),
"ActorID": (task_spec.actor_id().hex()),
"ActorCreationID": task_spec.actor_creation_id().hex(),
"ActorCreationDummyObjectID": (
task_spec.actor_creation_dummy_object_id().hex()),
"ActorCounter": task_spec.actor_counter(),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources(),
"FunctionID": binary_to_hex(function_descriptor.function_id.id()),
"FunctionID": function_descriptor.function_id.hex(),
"FunctionHash": binary_to_hex(function_descriptor.function_hash),
"ModuleName": function_descriptor.module_name,
"ClassName": function_descriptor.class_name,
Expand Down
49 changes: 25 additions & 24 deletions python/ray/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def function_id(self):
Returns:
The value of ray.ObjectID that represents the function id.
"""
return ray.ObjectID(self._function_id)
return self._function_id

def _get_function_id(self):
"""Calculate the function id of current function descriptor.
Expand All @@ -220,10 +220,10 @@ def _get_function_id(self):
descriptor.

Returns:
bytes with length of ray_constants.ID_SIZE.
ray.ObjectID to represent the function descriptor.
"""
if self.is_for_driver_task:
return ray_constants.NIL_FUNCTION_ID.id()
return ray.ObjectID.nil_id()
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(self.module_name.encode("ascii"))
Expand All @@ -232,8 +232,7 @@ def _get_function_id(self):
function_id_hash.update(self._function_source_hash)
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
return ray.ObjectID(function_id)

def get_function_descriptor_list(self):
"""Return a list of bytes representing the function descriptor.
Expand Down Expand Up @@ -290,11 +289,11 @@ def __init__(self, worker):
self.imported_actor_classes = set()

def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
self._num_task_executions[driver_id][function_id] += 1

def get_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
return self._num_task_executions[driver_id][function_id]

def export_cached(self):
Expand Down Expand Up @@ -372,13 +371,14 @@ def _do_export(self, remote_function):

def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
(driver_id, function_id_str, function_name, serialized_function,
(driver_id_str, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self._worker.redis_client.hmget(key, [
"driver_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
driver_id = ray.ObjectID(driver_id_str)
function_name = decode(function_name)
max_calls = int(max_calls)
module = decode(module)
Expand All @@ -388,10 +388,10 @@ def fetch_and_register_remote_function(self, key):
def f():
raise Exception("This function was not imported properly.")

self._function_execution_info[driver_id][function_id.id()] = (
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
self._num_task_executions[driver_id][function_id.id()] = 0
self._num_task_executions[driver_id][function_id] = 0

try:
function = pickle.loads(serialized_function)
Expand All @@ -416,7 +416,7 @@ def f():
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
function.__module__ = module
self._function_execution_info[driver_id][function_id.id()] = (
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
Expand All @@ -435,7 +435,7 @@ def get_execution_info(self, driver_id, function_descriptor):
Returns:
A FunctionExecutionInfo object.
"""
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id

# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
Expand All @@ -449,7 +449,7 @@ def get_execution_info(self, driver_id, function_descriptor):
except KeyError as e:
message = ("Error occurs in get_execution_info: "
"driver_id: %s, function_descriptor: %s. Message: %s" %
(binary_to_hex(driver_id), function_descriptor, e))
driver_id, function_descriptor, e)
raise KeyError(message)
return info

Expand All @@ -474,11 +474,11 @@ def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
warning_sent = False
while True:
with self._worker.lock:
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
and (function_descriptor.function_id.id() in
if (self._worker.actor_id.is_nil()
and (function_descriptor.function_id in
self._function_execution_info[driver_id])):
break
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
elif not self._worker.actor_id.is_nil() and (
self._worker.actor_id in self._worker.actors):
break
if time.time() - start_time > timeout:
Expand Down Expand Up @@ -556,7 +556,7 @@ def export_actor_class(self, Class, actor_method_names,
# because of https://github.com/ray-project/ray/issues/1146.

def load_actor(self, driver_id, function_descriptor):
key = (b"ActorClass:" + driver_id + b":" +
key = (b"ActorClass:" + driver_id.id() + b":" +
function_descriptor.function_id.id())
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
Expand All @@ -578,8 +578,8 @@ def fetch_and_register_actor(self, actor_class_key):
actor_class_key: The key in Redis to use to fetch the actor.
worker: The worker to use.
"""
actor_id_str = self._worker.actor_id
(driver_id, class_name, module, pickled_class, checkpoint_interval,
actor_id = self._worker.actor_id
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_name", "module", "class",
Expand All @@ -588,6 +588,7 @@ def fetch_and_register_actor(self, actor_class_key):

class_name = decode(class_name)
module = decode(module)
driver_id = ray.ObjectID(driver_id_str)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(decode(actor_method_names))

Expand All @@ -606,7 +607,7 @@ def fetch_and_register_actor(self, actor_class_key):
class TemporaryActor(object):
pass

self._worker.actors[actor_id_str] = TemporaryActor()
self._worker.actors[actor_id] = TemporaryActor()
self._worker.actor_checkpoint_interval = checkpoint_interval

def temporary_actor_method(*xs):
Expand All @@ -618,7 +619,7 @@ def temporary_actor_method(*xs):
for actor_method_name in actor_method_names:
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
Expand All @@ -644,22 +645,22 @@ def temporary_actor_method(*xs):
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
traceback_str,
driver_id,
data={"actor_id": actor_id_str})
data={"actor_id": actor_id.id()})
# TODO(rkn): In the future, it might make sense to have the worker
# exit here. However, currently that would lead to hanging if
# someone calls ray.get on a method invoked on the actor.
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
self._worker.actors[actor_id_str] = unpickled_class.__new__(
self._worker.actors[actor_id] = unpickled_class.__new__(
unpickled_class)

actor_methods = inspect.getmembers(
unpickled_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
Expand Down
2 changes: 1 addition & 1 deletion python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
The serialized object.
"""
builder = flatbuffers.Builder(0)
driver_offset = builder.CreateString(driver_id)
driver_offset = builder.CreateString(driver_id.id())
error_type_offset = builder.CreateString(error_type)
message_offset = builder.CreateString(message)

Expand Down
2 changes: 1 addition & 1 deletion python/ray/import_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ def fetch_and_execute_function_to_run(self, key):
self.worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
driver_id=ray.ObjectID(driver_id),
data={"name": name})
4 changes: 0 additions & 4 deletions python/ray/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import os

from ray.raylet import ObjectID


def env_integer(key, default):
if key in os.environ:
Expand All @@ -15,8 +13,6 @@ def env_integer(key, default):


ID_SIZE = 20
NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff")
NIL_FUNCTION_ID = NIL_JOB_ID

# The default maximum number of bytes to allocate to the object store unless
# overridden by the user.
Expand Down
Loading