diff --git a/doc/requirements-doc.txt b/doc/requirements-doc.txt
index 8bd4c51cb9f8..edf90aa48d85 100644
--- a/doc/requirements-doc.txt
+++ b/doc/requirements-doc.txt
@@ -1,5 +1,6 @@
 colorama
 click
+flatbuffers
 funcsigs
 mock
 numpy
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 3b82a3819d27..68551d2fe143 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -34,14 +34,24 @@
     "tensorflow.python.util",
     "ray.local_scheduler",
     "ray.plasma",
+    "ray.core",
+    "ray.core.generated",
+    "ray.core.generated.DriverTableMessage",
+    "ray.core.generated.LocalSchedulerInfoMessage",
+    "ray.core.generated.ResultTableReply",
+    "ray.core.generated.SubscribeToDBClientTableReply",
+    "ray.core.generated.SubscribeToNotificationsReply",
     "ray.core.generated.TaskInfo",
     "ray.core.generated.TaskReply",
-    "ray.core.generated.ResultTableReply",
     "ray.core.generated.TaskExecutionDependencies",
     "ray.core.generated.ClientTableData",
     "ray.core.generated.GcsTableEntry",
+    "ray.core.generated.HeartbeatTableData",
+    "ray.core.generated.ErrorTableData",
     "ray.core.generated.ObjectTableData",
-    "ray.core.generated.ray.protocol.Task"]
+    "ray.core.generated.ray.protocol.Task",
+    "ray.core.generated.TablePrefix",
+    "ray.core.generated.TablePubsub",]

 for mod_name in MOCK_MODULES:
     sys.modules[mod_name] = mock.Mock()
diff --git a/python/ray/actor.py b/python/ray/actor.py
index 51756b1a00fe..fd94d5e96a56 100644
--- a/python/ray/actor.py
+++ b/python/ray/actor.py
@@ -164,7 +164,7 @@ def save_and_log_checkpoint(worker, actor):
         traceback_str = ray.utils.format_error_message(traceback.format_exc())
         # Log the error message.
         ray.utils.push_error_to_driver(
-            worker.redis_client,
+            worker,
             ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
             driver_id=worker.task_driver_id.id(),
@@ -188,7 +188,7 @@ def restore_and_log_checkpoint(worker, actor):
         traceback_str = ray.utils.format_error_message(traceback.format_exc())
         # Log the error message.
         ray.utils.push_error_to_driver(
-            worker.redis_client,
+            worker,
             ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
             driver_id=worker.task_driver_id.id(),
@@ -330,7 +330,7 @@ def temporary_actor_method(*xs):
         traceback_str = ray.utils.format_error_message(traceback.format_exc())
         # Log the error message.
         push_error_to_driver(
-            worker.redis_client,
+            worker,
             ray_constants.REGISTER_ACTOR_PUSH_ERROR,
             traceback_str,
             driver_id,
@@ -402,7 +402,7 @@ def export_actor_class(class_id, Class, actor_method_names,
                    .format(actor_class_info["class_name"],
                            len(actor_class_info["class"])))
         ray.utils.push_error_to_driver(
-            worker.redis_client,
+            worker,
             ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
             warning_message,
             driver_id=worker.task_driver_id.id())
diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py
index fd6f55c03855..7a7d25c6bedc 100644
--- a/python/ray/common/redis_module/runtest.py
+++ b/python/ray/common/redis_module/runtest.py
@@ -8,20 +8,9 @@
 import time
 import unittest

+import ray.gcs_utils
 import ray.services

-# Import flatbuffer bindings.
-from ray.core.generated.SubscribeToNotificationsReply \
-    import SubscribeToNotificationsReply
-from ray.core.generated.TaskReply import TaskReply
-from ray.core.generated.ResultTableReply import ResultTableReply
-
-OBJECT_INFO_PREFIX = "OI:"
-OBJECT_LOCATION_PREFIX = "OL:"
-OBJECT_SUBSCRIBE_PREFIX = "OS:"
-TASK_PREFIX = "TT:"
-OBJECT_CHANNEL_PREFIX = "OC:"
-

 def integerToAsciiHex(num, numbytes):
     retstr = b""
@@ -194,7 +183,7 @@ def testObjectTableSubscribeToNotifications(self):
         # notifications.
         def check_object_notification(notification_message, object_id,
                                       object_size, manager_ids):
-            notification_object = (SubscribeToNotificationsReply.
+            notification_object = (ray.gcs_utils.SubscribeToNotificationsReply.
                                    GetRootAsSubscribeToNotificationsReply(
                                        notification_message, 0))
             self.assertEqual(notification_object.ObjectId(), object_id)
@@ -208,7 +197,8 @@ def check_object_notification(notification_message, object_id,
         data_size = 0xf1f0
         p = self.redis.pubsub()
         # Subscribe to an object ID.
-        p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
+        p.psubscribe("{}manager_id1".format(
+            ray.gcs_utils.OBJECT_CHANNEL_PREFIX))
         self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
                                    data_size, "hash1", "manager_id2")
         # Receive the acknowledgement message.
@@ -252,8 +242,9 @@ def check_object_notification(notification_message, object_id,

     def testResultTableAddAndLookup(self):
         def check_result_table_entry(message, task_id, is_put):
-            result_table_reply = ResultTableReply.GetRootAsResultTableReply(
-                message, 0)
+            result_table_reply = (
+                ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
+                    message, 0))
             self.assertEqual(result_table_reply.TaskId(), task_id)
             self.assertEqual(result_table_reply.IsPut(), is_put)

@@ -315,12 +306,13 @@ def testTaskTableAddAndLookup(self):
         # make sure somebody will get a notification (checked in the redis
         # module)
         p = self.redis.pubsub()
-        p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
+        p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))

         def check_task_reply(message, task_args, updated=False):
             (task_status, local_scheduler_id, execution_dependencies_string,
              spillback_count, task_spec) = task_args
-            task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
+            task_reply_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
+                message, 0)
             self.assertEqual(task_reply_object.State(), task_status)
             self.assertEqual(task_reply_object.LocalSchedulerId(),
                              local_scheduler_id)
@@ -409,7 +401,8 @@ def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
         # Receive the data.
         message = get_next_message(p)["data"]
         # Check that the notification object is correct.
-        notification_object = TaskReply.GetRootAsTaskReply(message, 0)
+        notification_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
+            message, 0)
         self.assertEqual(notification_object.TaskId(), task_args[0])
         self.assertEqual(notification_object.State(), task_args[1])
         self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
@@ -422,32 +415,34 @@ def testTaskTableSubscribe(self):
         local_scheduler_id = "local_scheduler_id"
         # Subscribe to the task table.
         p = self.redis.pubsub()
-        p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
+        p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 1)
         self.check_task_subscription(p, scheduling_state, local_scheduler_id)
         # unsubscribe to make sure there is only one subscriber at a given time
-        p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
+        p.punsubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 0)

         p.psubscribe("{prefix}*:{state}".format(
-            prefix=TASK_PREFIX, state=scheduling_state))
+            prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 1)
         self.check_task_subscription(p, scheduling_state, local_scheduler_id)
         p.punsubscribe("{prefix}*:{state}".format(
-            prefix=TASK_PREFIX, state=scheduling_state))
+            prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 0)

         p.psubscribe("{prefix}{local_scheduler_id}:*".format(
-            prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
+            prefix=ray.gcs_utils.TASK_PREFIX,
+            local_scheduler_id=local_scheduler_id))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 1)
         self.check_task_subscription(p, scheduling_state, local_scheduler_id)
         p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
-            prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
+            prefix=ray.gcs_utils.TASK_PREFIX,
+            local_scheduler_id=local_scheduler_id))
         # Receive acknowledgment.
         self.assertEqual(get_next_message(p)["data"], 0)
diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py
index cc23a4e58d15..aceb672fc932 100644
--- a/python/ray/experimental/state.py
+++ b/python/ray/experimental/state.py
@@ -12,41 +12,10 @@
 import time

 import ray
+import ray.gcs_utils
 from ray.utils import (decode, binary_to_object_id, binary_to_hex,
                        hex_to_binary)

-# Import flatbuffer bindings.
-from ray.core.generated.TaskReply import TaskReply
-from ray.core.generated.ResultTableReply import ResultTableReply
-from ray.core.generated.TaskExecutionDependencies import \
-    TaskExecutionDependencies
-
-from ray.core.generated.ClientTableData import ClientTableData
-from ray.core.generated.GcsTableEntry import GcsTableEntry
-from ray.core.generated.ObjectTableData import ObjectTableData
-
-from ray.core.generated.ray.protocol.Task import Task
-
-# These prefixes must be kept up-to-date with the definitions in
-# ray_redis_module.cc.
-DB_CLIENT_PREFIX = "CL:"
-OBJECT_INFO_PREFIX = "OI:"
-OBJECT_LOCATION_PREFIX = "OL:"
-OBJECT_SUBSCRIBE_PREFIX = "OS:"
-TASK_PREFIX = "TT:"
-FUNCTION_PREFIX = "RemoteFunction:"
-OBJECT_CHANNEL_PREFIX = "OC:"
-
-# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
-# TODO(rkn): We should use scoped enums, in which case we should be able to
-# just access the flatbuffer generated values.
-TablePrefix_RAYLET_TASK = 2
-TablePrefix_RAYLET_TASK_string = "TASK"
-TablePrefix_CLIENT = 3
-TablePrefix_CLIENT_string = "CLIENT"
-TablePrefix_OBJECT = 4
-TablePrefix_OBJECT_string = "OBJECT"
-
 # This mapping from integer to task state string must be kept up-to-date with
 # the scheduling_state enum in task.h.
 TASK_STATUS_WAITING = 1
@@ -231,8 +200,9 @@ def _object_table(self, object_id):
             result_table_response = self._execute_command(
                 object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
-            result_table_message = ResultTableReply.GetRootAsResultTableReply(
-                result_table_response, 0)
+            result_table_message = (
+                ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
+                    result_table_response, 0))

             result = {
                 "ManagerIDs": manager_ids,
@@ -245,12 +215,14 @@ def _object_table(self, object_id):
         else:
             # Use the raylet code path.
             message = self.redis_client.execute_command(
-                "RAY.TABLE_LOOKUP", TablePrefix_OBJECT, "", object_id.id())
+                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "",
+                object_id.id())

             result = []
-            gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
+            gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+                message, 0)

             for i in range(gcs_entry.EntriesLength()):
-                entry = ObjectTableData.GetRootAsObjectTableData(
+                entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
                     gcs_entry.Entries(i), 0)
                 object_info = {
                     "DataSize": entry.ObjectSize(),
@@ -279,19 +251,22 @@ def object_table(self, object_id=None):
         else:
             # Return the entire object table.
             if not self.use_raylet:
-                object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
-                object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
+                object_info_keys = self._keys(
+                    ray.gcs_utils.OBJECT_INFO_PREFIX + "*")
+                object_location_keys = self._keys(
+                    ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*")
                 object_ids_binary = set([
-                    key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys
+                    key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):]
+                    for key in object_info_keys
                 ] + [
-                    key[len(OBJECT_LOCATION_PREFIX):]
+                    key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):]
                     for key in object_location_keys
                 ])
             else:
                 object_keys = self.redis_client.keys(
-                    TablePrefix_OBJECT_string + ":*")
+                    ray.gcs_utils.TablePrefix_OBJECT_string + "*")
                 object_ids_binary = {
-                    key[len(TablePrefix_OBJECT_string + ":"):]
+                    key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
                     for key in object_keys
                 }

@@ -320,7 +295,7 @@ def _task_table(self, task_id):
         if task_table_response is None:
             raise Exception("There is no entry for task ID {} in the task "
                             "table.".format(binary_to_hex(task_id.id())))
-        task_table_message = TaskReply.GetRootAsTaskReply(
+        task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
             task_table_response, 0)
         task_spec = task_table_message.TaskSpec()
         task_spec = ray.local_scheduler.task_from_string(task_spec)
@@ -343,7 +318,8 @@ def _task_table(self, task_id):
         }

         execution_dependencies_message = (
-            TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
+            ray.gcs_utils.TaskExecutionDependencies.
+            GetRootAsTaskExecutionDependencies(
                 task_table_message.ExecutionDependencies(), 0))
         execution_dependencies = [
             ray.ObjectID(
@@ -371,15 +347,17 @@ def _task_table(self, task_id):
         else:
             # Use the raylet code path.
             message = self.redis_client.execute_command(
-                "RAY.TABLE_LOOKUP", TablePrefix_RAYLET_TASK, "", task_id.id())
-            gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
+                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "",
+                task_id.id())
+            gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+                message, 0)

             info = []
             for i in range(gcs_entries.EntriesLength()):
-                task_table_message = Task.GetRootAsTask(
+                task_table_message = ray.gcs_utils.Task.GetRootAsTask(
                     gcs_entries.Entries(i), 0)

-            task_table_message = Task.GetRootAsTask(
+            task_table_message = ray.gcs_utils.Task.GetRootAsTask(
                 gcs_entries.Entries(0), 0)
             execution_spec = task_table_message.TaskExecutionSpec()
             task_spec = task_table_message.TaskSpecification()
@@ -432,15 +410,16 @@ def task_table(self, task_id=None):
             return self._task_table(task_id)
         else:
             if not self.use_raylet:
-                task_table_keys = self._keys(TASK_PREFIX + "*")
+                task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*")
                 task_ids_binary = [
-                    key[len(TASK_PREFIX):] for key in task_table_keys
+                    key[len(ray.gcs_utils.TASK_PREFIX):]
+                    for key in task_table_keys
                 ]
             else:
                 task_table_keys = self.redis_client.keys(
-                    TablePrefix_RAYLET_TASK_string + ":*")
+                    ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
                 task_ids_binary = [
-                    key[len(TablePrefix_RAYLET_TASK_string + ":"):]
+                    key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
                     for key in task_table_keys
                 ]

@@ -458,7 +437,8 @@ def function_table(self, function_id=None):
             function.
         """
         self._check_connected()
-        function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
+        function_table_keys = self.redis_client.keys(
+            ray.gcs_utils.FUNCTION_PREFIX + "*")
         results = {}
         for key in function_table_keys:
             info = self.redis_client.hgetall(key)
@@ -478,7 +458,8 @@ def client_table(self):
         """
         self._check_connected()
         if not self.use_raylet:
-            db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
+            db_client_keys = self.redis_client.keys(
+                ray.gcs_utils.DB_CLIENT_PREFIX + "*")
             node_info = {}
             for key in db_client_keys:
                 client_info = self.redis_client.hgetall(key)
@@ -520,13 +501,16 @@ def client_table(self):
             # This is the raylet code path.
             NIL_CLIENT_ID = 20 * b"\xff"
             message = self.redis_client.execute_command(
-                "RAY.TABLE_LOOKUP", TablePrefix_CLIENT, "", NIL_CLIENT_ID)
+                "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
+                NIL_CLIENT_ID)
             node_info = []
-            gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
+            gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+                message, 0)

             for i in range(gcs_entry.EntriesLength()):
-                client = ClientTableData.GetRootAsClientTableData(
-                    gcs_entry.Entries(i), 0)
+                client = (
+                    ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
+                        gcs_entry.Entries(i), 0))

                 resources = {
                     client.ResourcesTotalLabel(i).decode("ascii"):
@@ -1146,3 +1130,64 @@ def cluster_resources(self):
                 resources[key] += value

         return dict(resources)
+
+    def _error_messages(self, job_id):
+        """Get the error messages for a specific job.
+
+        Args:
+            job_id: The ID of the job to get the errors for.
+
+        Returns:
+            A list of the error messages for this job.
+        """
+        message = self.redis_client.execute_command(
+            "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
+            job_id.id())
+
+        # If there are no errors, return early.
+        if message is None:
+            return []
+
+        gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+            message, 0)
+        error_messages = []
+        for i in range(gcs_entries.EntriesLength()):
+            error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
+                gcs_entries.Entries(i), 0)
+            error_message = {
+                "type": error_data.Type().decode("ascii"),
+                "message": error_data.ErrorMessage().decode("ascii"),
+                "timestamp": error_data.Timestamp(),
+            }
+            error_messages.append(error_message)
+        return error_messages
+
+    def error_messages(self, job_id=None):
+        """Get the error messages for all jobs or a specific job.
+
+        Args:
+            job_id: The specific job to get the errors for. If this is None,
+                then this method retrieves the errors for all jobs.
+
+        Returns:
+            A dictionary mapping job ID to a list of the error messages for
+            that job.
+        """
+        if not self.use_raylet:
+            raise Exception("The error_messages method is only supported in "
+                            "the raylet code path.")
+
+        if job_id is not None:
+            return self._error_messages(job_id)
+
+        error_table_keys = self.redis_client.keys(
+            ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
+        job_ids = [
+            key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
+            for key in error_table_keys
+        ]
+
+        return {
+            binary_to_hex(job_id): self._error_messages(ray.ObjectID(job_id))
+            for job_id in job_ids
+        }
diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py
new file mode 100644
index 000000000000..708093f212eb
--- /dev/null
+++ b/python/ray/gcs_utils.py
@@ -0,0 +1,84 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import flatbuffers
+
+from ray.core.generated.ResultTableReply import ResultTableReply
+from ray.core.generated.SubscribeToNotificationsReply \
+    import SubscribeToNotificationsReply
+from ray.core.generated.TaskExecutionDependencies import \
+    TaskExecutionDependencies
+from ray.core.generated.TaskReply import TaskReply
+from ray.core.generated.DriverTableMessage import DriverTableMessage
+from ray.core.generated.LocalSchedulerInfoMessage import \
+    LocalSchedulerInfoMessage
+from ray.core.generated.SubscribeToDBClientTableReply import \
+    SubscribeToDBClientTableReply
+from ray.core.generated.TaskInfo import TaskInfo
+
+import ray.core.generated.ErrorTableData
+
+from ray.core.generated.GcsTableEntry import GcsTableEntry
+from ray.core.generated.ClientTableData import ClientTableData
+from ray.core.generated.ErrorTableData import ErrorTableData
+from ray.core.generated.HeartbeatTableData import HeartbeatTableData
+from ray.core.generated.ObjectTableData import ObjectTableData
+from ray.core.generated.ray.protocol.Task import Task
+
+from ray.core.generated.TablePrefix import TablePrefix
+from ray.core.generated.TablePubsub import TablePubsub
+
+__all__ = [
+    "SubscribeToNotificationsReply", "ResultTableReply",
+    "TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
+    "LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
+    "GcsTableEntry", "ClientTableData", "ErrorTableData",
+    "HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix",
+    "TablePubsub", "construct_error_message"
+]
+
+# These prefixes must be kept up-to-date with the definitions in
+# ray_redis_module.cc.
+DB_CLIENT_PREFIX = "CL:"
+TASK_PREFIX = "TT:"
+OBJECT_CHANNEL_PREFIX = "OC:"
+OBJECT_INFO_PREFIX = "OI:"
+OBJECT_LOCATION_PREFIX = "OL:"
+FUNCTION_PREFIX = "RemoteFunction:"
+
+# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
+# TODO(rkn): We should use scoped enums, in which case we should be able to
+# just access the flatbuffer generated values.
+TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
+TablePrefix_OBJECT_string = "OBJECT"
+TablePrefix_ERROR_INFO_string = "ERROR_INFO"
+
+
+def construct_error_message(error_type, message, timestamp):
+    """Construct a serialized ErrorTableData object.
+
+    Args:
+        error_type: The type of the error.
+        message: The error message.
+        timestamp: The time of the error.
+
+    Returns:
+        The serialized object.
+    """
+    builder = flatbuffers.Builder(0)
+    error_type_offset = builder.CreateString(error_type)
+    message_offset = builder.CreateString(message)
+
+    ray.core.generated.ErrorTableData.ErrorTableDataStart(builder)
+    ray.core.generated.ErrorTableData.ErrorTableDataAddType(
+        builder, error_type_offset)
+    ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage(
+        builder, message_offset)
+    ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp(
+        builder, timestamp)
+    error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd(
+        builder)
+    builder.Finish(error_data_offset)
+
+    return bytes(builder.Output())
diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py
index 64d8e4047091..8bfd8f27c1aa 100644
--- a/python/ray/global_scheduler/test/test.py
+++ b/python/ray/global_scheduler/test/test.py
@@ -29,11 +29,6 @@
 NIL_OBJECT_ID = 20 * b"\xff"
 NIL_ACTOR_ID = 20 * b"\xff"

-# These constants are an implementation detail of ray_redis_module.cc, so this
-# must be kept in sync with that file.
-DB_CLIENT_PREFIX = "CL:"
-TASK_PREFIX = "TT:"
-

 def random_driver_id():
     return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
diff --git a/python/ray/monitor.py b/python/ray/monitor.py
index 13729d6c7316..a4705314a4a2 100644
--- a/python/ray/monitor.py
+++ b/python/ray/monitor.py
@@ -9,20 +9,13 @@
 import time
 from collections import Counter, defaultdict

+import redis
+
 import ray
+from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
 import ray.cloudpickle as pickle
+import ray.gcs_utils
 import ray.utils
-import redis
-# Import flatbuffer bindings.
-from ray.core.generated.DriverTableMessage import DriverTableMessage
-from ray.core.generated.GcsTableEntry import GcsTableEntry
-from ray.core.generated.HeartbeatTableData import HeartbeatTableData
-from ray.core.generated.LocalSchedulerInfoMessage import \
-    LocalSchedulerInfoMessage
-from ray.core.generated.SubscribeToDBClientTableReply import \
-    SubscribeToDBClientTableReply
-from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
-from ray.core.generated.TaskInfo import TaskInfo
 from ray.services import get_ip_address, get_port
 from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
 from ray.worker import NIL_ACTOR_ID
@@ -259,7 +252,7 @@ def db_client_notification_handler(self, unused_channel, data):
         the associated state in the state tables should be handled by the
         caller.
         """
-        notification_object = (SubscribeToDBClientTableReply.
+        notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply.
                                GetRootAsSubscribeToDBClientTableReply(data, 0))
         db_client_id = binary_to_hex(notification_object.DbClientId())
         client_type = notification_object.ClientType()
@@ -285,8 +278,8 @@ def db_client_notification_handler(self, unused_channel, data):

     def local_scheduler_info_handler(self, unused_channel, data):
         """Handle a local scheduler heartbeat from Redis."""
-        message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
-            data, 0)
+        message = (ray.gcs_utils.LocalSchedulerInfoMessage.
+                   GetRootAsLocalSchedulerInfoMessage(data, 0))
         num_resources = message.DynamicResourcesLength()
         static_resources = {}
         dynamic_resources = {}
@@ -308,9 +301,10 @@ def local_scheduler_info_handler(self, unused_channel, data):

     def xray_heartbeat_handler(self, unused_channel, data):
         """Handle an xray heartbeat message from Redis."""
-        gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
+        gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+            data, 0)
         heartbeat_data = gcs_entries.Entries(0)
-        message = HeartbeatTableData.GetRootAsHeartbeatTableData(
+        message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData(
             heartbeat_data, 0)
         num_resources = message.ResourcesAvailableLabelLength()
         static_resources = {}
@@ -363,7 +357,8 @@ def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
         # driver. Use a cursor in order not to block the redis shards.
         for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
             entry = redis.hgetall(key)
-            task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
+            task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo(
+                entry[b"TaskSpec"], 0)
             if driver_id != task_info.DriverId():
                 # Ignore tasks that aren't from this driver.
                 continue
@@ -475,7 +470,8 @@ def driver_removed_handler(self, unused_channel, data):
         This releases any GPU resources that were reserved for that driver in
         Redis.
         """
-        message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
+        message = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage(
+            data, 0)
         driver_id = message.DriverId()
         log.info("Driver {} has been removed.".format(
             binary_to_hex(driver_id)))
diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py
index ada9a76aafb4..9310757876b9 100644
--- a/python/ray/ray_constants.py
+++ b/python/ray/ray_constants.py
@@ -5,6 +5,8 @@

 import os

+import ray
+

 def env_integer(key, default):
     if key in os.environ:
@@ -12,6 +14,9 @@ def env_integer(key, default):
     return default


+ID_SIZE = 20
+NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\x00")
+
 # If a remote function or actor (or some other export) has serialized size
 # greater than this quantity, print an warning.
 PICKLE_OBJECT_WARNING_SIZE = 10**7
diff --git a/python/ray/utils.py b/python/ray/utils.py
index f4d669ac66bd..5f51b6f24325 100644
--- a/python/ray/utils.py
+++ b/python/ray/utils.py
@@ -7,9 +7,12 @@
 import numpy as np
 import os
 import sys
+import time
 import uuid

+import ray.gcs_utils
 import ray.local_scheduler
+import ray.ray_constants as ray_constants

 ERROR_KEY_PREFIX = b"Error:"
 DRIVER_ID_LENGTH = 20
@@ -45,15 +48,56 @@ def format_error_message(exception_message, task_exception=False):
     return "\n".join(lines)


-def push_error_to_driver(redis_client,
+def push_error_to_driver(worker,
                          error_type,
                          message,
                          driver_id=None,
                          data=None):
     """Push an error message to the driver to be printed in the background.

+    Args:
+        worker: The worker to use.
+        error_type (str): The type of the error.
+        message (str): The message that will be printed in the background
+            on the driver.
+        driver_id: The ID of the driver to push the error message to. If this
+            is None, then the message will be pushed to all drivers.
+        data: This should be a dictionary mapping strings to strings. It
+            will be serialized with json and stored in Redis.
+    """
+    if driver_id is None:
+        driver_id = ray_constants.NIL_JOB_ID.id()
+    error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
+    data = {} if data is None else data
+    if not worker.use_raylet:
+        worker.redis_client.hmset(error_key, {
+            "type": error_type,
+            "message": message,
+            "data": data
+        })
+        worker.redis_client.rpush("ErrorKeys", error_key)
+    else:
+        worker.local_scheduler_client.push_error(
+            ray.ObjectID(driver_id), error_type, message, time.time())
+
+
+def push_error_to_driver_through_redis(redis_client,
+                                       use_raylet,
+                                       error_type,
+                                       message,
+                                       driver_id=None,
+                                       data=None):
+    """Push an error message to the driver to be printed in the background.
+
+    Normally the push_error_to_driver function should be used. However, in
+    some instances, the local scheduler client is not available, e.g.,
+    because the error happens in Python before the driver or worker has
+    connected to the backend processes.
+
     Args:
         redis_client: The redis client to use.
+        use_raylet: True if we are using the Raylet code path and false
+            otherwise.
         error_type (str): The type of the error.
         message (str): The message that will be printed in the background
             on the driver.
@@ -63,15 +107,24 @@ def push_error_to_driver(redis_client,
             will be serialized with json and stored in Redis.
     """
     if driver_id is None:
-        driver_id = DRIVER_ID_LENGTH * b"\x00"
+        driver_id = ray_constants.NIL_JOB_ID.id()
     error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
     data = {} if data is None else data
-    redis_client.hmset(error_key, {
-        "type": error_type,
-        "message": message,
-        "data": data
-    })
-    redis_client.rpush("ErrorKeys", error_key)
+    if not use_raylet:
+        redis_client.hmset(error_key, {
+            "type": error_type,
+            "message": message,
+            "data": data
+        })
+        redis_client.rpush("ErrorKeys", error_key)
+    else:
+        # Do everything in Python and through the Python Redis client instead
+        # of through the raylet.
+        error_data = ray.gcs_utils.construct_error_message(
+            error_type, message, time.time())
+        redis_client.execute_command(
+            "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
+            ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)


 def is_cython(obj):
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 6567d327575e..84e88e18d493 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -22,6 +22,7 @@
 import pyarrow.plasma as plasma
 import ray.cloudpickle as pickle
 import ray.experimental.state as state
+import ray.gcs_utils
 import ray.remote_function
 import ray.serialization as serialization
 import ray.services as services
@@ -31,9 +32,6 @@
 import ray.ray_constants as ray_constants
 from ray.utils import random_string, binary_to_hex, is_cython

-# Import flatbuffer bindings.
-from ray.core.generated.ClientTableData import ClientTableData
-
 SCRIPT_MODE = 0
 WORKER_MODE = 1
 PYTHON_MODE = 2
@@ -415,7 +413,7 @@ def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
                         "may be a bug.")
                     if not warning_sent:
                         ray.utils.push_error_to_driver(
-                            self.redis_client,
+                            self,
                             ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
                             warning_message,
                             driver_id=self.task_driver_id.id())
@@ -663,7 +661,7 @@ def export_remote_function(self, function_id, function_name, function,
                 "large array or other object.".format(
                     function_name, len(pickled_function)))
             ray.utils.push_error_to_driver(
-                self.redis_client,
+                self,
                 ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
                 warning_message,
                 driver_id=self.task_driver_id.id())
@@ -726,7 +724,7 @@ def run_function_on_all_workers(self, function):
                     .format(function.__name__, len(pickled_function)))
                 ray.utils.push_error_to_driver(
-                    self.redis_client,
+                    self,
                     ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
                     warning_message,
                     driver_id=self.task_driver_id.id())
@@ -781,7 +779,7 @@ def _wait_for_function(self, function_id, driver_id, timeout=10):
                     "Ray.")
                 if not warning_sent:
                     ray.utils.push_error_to_driver(
-                        self.redis_client,
+                        self,
                         ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                         warning_message,
                         driver_id=driver_id)
@@ -942,7 +940,7 @@ def _handle_process_task_failure(self, function_id, return_object_ids,
         self._store_outputs_in_objstore(return_object_ids, failure_objects)
         # Log the error message.
         ray.utils.push_error_to_driver(
-            self.redis_client,
+            self,
             ray_constants.TASK_PUSH_ERROR,
             str(failure_object),
             driver_id=self.task_driver_id.id(),
@@ -1200,6 +1198,11 @@ def error_info(worker=global_worker):
     """Return information about failed tasks."""
     worker.check_connected()
     check_main_thread()
+
+    if worker.use_raylet:
+        return (global_state.error_messages(job_id=worker.task_driver_id) +
+                global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
+
     error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
     errors = []
     for error_key in error_keys:
@@ -1291,9 +1294,8 @@ def get_address_info_from_redis_helper(redis_address,
     if not use_raylet:
         # The client table prefix must be kept in sync with the file
         # "src/common/redis_module/ray_redis_module.cc" where it is defined.
-        REDIS_CLIENT_TABLE_PREFIX = "CL:"
-        client_keys = redis_client.keys(
-            "{}*".format(REDIS_CLIENT_TABLE_PREFIX))
+        client_keys = redis_client.keys("{}*".format(
+            ray.gcs_utils.DB_CLIENT_PREFIX))
         # Filter to live clients on the same node and do some basic checking.
         plasma_managers = []
         local_schedulers = []
@@ -1350,11 +1352,11 @@ def get_address_info_from_redis_helper(redis_address,
     else:
         # In the raylet code path, all client data is stored in a zset at the
         # key for the nil client.
-        client_key = b"CLIENT:" + NIL_CLIENT_ID
+        client_key = b"CLIENT" + NIL_CLIENT_ID
         clients = redis_client.zrange(client_key, 0, -1)
         raylets = []
         for client_message in clients:
-            client = ClientTableData.GetRootAsClientTableData(
+            client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
                 client_message, 0)
             client_node_ip_address = client.NodeManagerAddress().decode(
                 "ascii")
@@ -1819,6 +1821,71 @@ def custom_excepthook(type, value, tb):

 sys.excepthook = custom_excepthook


+def print_error_messages_raylet(worker):
+    """Print error messages in the background on the driver.
+
+    This runs in a separate thread on the driver and prints error messages in
+    the background.
+ """ + if not worker.use_raylet: + raise Exception("This function is specific to the raylet code path.") + + worker.error_message_pubsub_client = worker.redis_client.pubsub( + ignore_subscribe_messages=True) + # Exports that are published after the call to + # error_message_pubsub_client.subscribe and before the call to + # error_message_pubsub_client.listen will still be processed in the loop. + + # Really we should just subscribe to the errors for this specific job. + # However, currently all errors seem to be published on the same channel. + error_pubsub_channel = str( + ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + worker.error_message_pubsub_client.subscribe(error_pubsub_channel) + # worker.error_message_pubsub_client.psubscribe("*") + + # Keep a set of all the error messages that we've seen so far in order to + # avoid printing the same error message repeatedly. This is especially + # important when running a script inside of a tool like screen where + # scrolling is difficult. + old_error_messages = set() + + # Get the exports that occurred before the call to subscribe. + with worker.lock: + error_messages = global_state.error_messages(worker.task_driver_id) + for error_message in error_messages: + if error_message not in old_error_messages: + print(error_message) + old_error_messages.add(error_message) + else: + print("Suppressing duplicate error message.") + + try: + for msg in worker.error_message_pubsub_client.listen(): + + gcs_entry = state.GcsTableEntry.GetRootAsGcsTableEntry( + msg["data"], 0) + assert gcs_entry.EntriesLength() == 1 + error_data = state.ErrorTableData.GetRootAsErrorTableData( + gcs_entry.Entries(0), 0) + NIL_JOB_ID = 20 * b"\x00" + job_id = error_data.JobId() + if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]: + continue + + error_message = error_data.ErrorMessage().decode("ascii") + + if error_message not in old_error_messages: + print(error_message) + old_error_messages.add(error_message) + else: + print("Suppressing duplicate error message.") + + except redis.ConnectionError: + # When Redis terminates the listen call will throw a ConnectionError, + # which we catch here. + pass + + def print_error_messages(worker): """Print error messages in the background on the driver. @@ -1907,7 +1974,7 @@ def f(): traceback_str = ray.utils.format_error_message(traceback.format_exc()) # Log the error message. ray.utils.push_error_to_driver( - worker.redis_client, + worker, ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, traceback_str, driver_id=driver_id, @@ -1952,7 +2019,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker): name = function.__name__ if ("function" in locals() and hasattr(function, "__name__")) else "" ray.utils.push_error_to_driver( - worker.redis_client, + worker, ray_constants.FUNCTION_TO_RUN_PUSH_ERROR, traceback_str, driver_id=driver_id, @@ -2111,8 +2178,9 @@ def connect(info, raise e elif mode == WORKER_MODE: traceback_str = traceback.format_exc() - ray.utils.push_error_to_driver( + ray.utils.push_error_to_driver_through_redis( worker.redis_client, + worker.use_raylet, ray_constants.VERSION_MISMATCH_PUSH_ERROR, traceback_str, driver_id=None) @@ -2237,13 +2305,11 @@ def connect(info, driver_task.execution_dependencies_string(), 0, ray.local_scheduler.task_to_string(driver_task)) else: - TablePubsub_RAYLET_TASK = 2 - # TODO(rkn): When we shard the GCS in xray, we will need to change # this to use _execute_command. 
         global_state.redis_client.execute_command(
-            "RAY.TABLE_ADD", state.TablePrefix_RAYLET_TASK,
-            TablePubsub_RAYLET_TASK,
+            "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.RAYLET_TASK,
+            ray.gcs_utils.TablePubsub.RAYLET_TASK,
             driver_task.task_id().id(),
             driver_task._serialized_raylet_task())
@@ -2271,7 +2337,11 @@ def connect(info,
     # temporarily using this implementation which constantly queries the
     # scheduler for new error messages.
     if mode == SCRIPT_MODE:
-        t = threading.Thread(target=print_error_messages, args=(worker, ))
+        if not worker.use_raylet:
+            t = threading.Thread(target=print_error_messages, args=(worker, ))
+        else:
+            t = threading.Thread(
+                target=print_error_messages_raylet, args=(worker, ))
         # Making the thread a daemon causes it to exit when the main thread
         # exits.
         t.daemon = True
diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py
index 3e761a9d4c77..0bc21f2b2bda 100644
--- a/python/ray/workers/default_worker.py
+++ b/python/ray/workers/default_worker.py
@@ -69,10 +69,11 @@
         ray.worker.global_worker.main_loop()
     except Exception as e:
         traceback_str = traceback.format_exc() + error_explanation
-        # Create a Redis client.
-        redis_client = ray.services.create_redis_client(args.redis_address)
         ray.utils.push_error_to_driver(
-            redis_client, "worker_crash", traceback_str, driver_id=None)
+            ray.worker.global_worker,
+            "worker_crash",
+            traceback_str,
+            driver_id=None)
         # TODO(rkn): Note that if the worker was in the middle of executing
         # a task, then any worker or driver that is blocking in a get call
         # and waiting for the output of that task will hang. We need to
diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc
index 7bcd0c60d9f2..3314665b55b8 100644
--- a/src/common/redis_module/ray_redis_module.cc
+++ b/src/common/redis_module/ray_redis_module.cc
@@ -61,14 +61,6 @@ extern RedisChainModule module;
     return RedisModule_ReplyWithError(ctx, (MESSAGE)); \
   }

-// NOTE(swang): The order of prefixes here must match the TablePrefix enum
-// defined in src/ray/gcs/format/gcs.fbs.
-static const char *table_prefixes[] = {
-    NULL, "TASK:", "TASK:", "CLIENT:",
-    "OBJECT:", "ACTOR:", "FUNCTION:", "TASK_RECONSTRUCTION:",
-    "HEARTBEAT:",
-};
-
 /// Parse a Redis string into a TablePubsub channel.
 TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
   long long pubsub_channel_long;
@@ -128,8 +120,8 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
       << "This table has no prefix registered";
   RAY_CHECK(prefix >= TablePrefix::MIN && prefix <= TablePrefix::MAX)
       << "Prefix must be a valid TablePrefix";
-  return OpenPrefixedKey(ctx, table_prefixes[static_cast<long long>(prefix)],
-                         keyname, mode, mutated_key_str);
+  return OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode,
+                         mutated_key_str);
 }

 RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
diff --git a/src/local_scheduler/lib/python/local_scheduler_extension.cc b/src/local_scheduler/lib/python/local_scheduler_extension.cc
index ce68c6bfe250..bd9dc6f0fc09 100644
--- a/src/local_scheduler/lib/python/local_scheduler_extension.cc
+++ b/src/local_scheduler/lib/python/local_scheduler_extension.cc
@@ -286,6 +286,29 @@ static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
   return Py_BuildValue("(OO)", py_found, py_remaining);
 }

+static PyObject *PyLocalSchedulerClient_push_error(PyObject *self,
+                                                   PyObject *args) {
+  JobID job_id;
+  const char *type;
+  int type_length;
+  const char *error_message;
+  int error_message_length;
+  double timestamp;
+  if (!PyArg_ParseTuple(args, "O&s#s#d", &PyObjectToUniqueID, &job_id, &type,
+                        &type_length, &error_message, &error_message_length,
+                        &timestamp)) {
+    return NULL;
+  }
+
+  local_scheduler_push_error(reinterpret_cast<PyLocalSchedulerClient *>(self)
+                                 ->local_scheduler_connection,
+                             job_id, std::string(type, type_length),
+                             std::string(error_message, error_message_length),
+                             timestamp);
+
+  Py_RETURN_NONE;
+}
+
 static PyMethodDef PyLocalSchedulerClient_methods[] = {
     {"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
      "Notify the local scheduler that this client is exiting gracefully."},
@@ -313,6 +336,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
      (PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
     {"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
      "Wait for a list of objects to be created."},
+    {"push_error", (PyCFunction) PyLocalSchedulerClient_push_error,
+     METH_VARARGS, "Push an error message to the relevant driver."},
     {NULL} /* Sentinel */
 };
diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc
index 68642d813ad1..f8ac3026bd7d 100644
--- a/src/local_scheduler/local_scheduler_client.cc
+++ b/src/local_scheduler/local_scheduler_client.cc
@@ -306,3 +306,19 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
   free(reply);
   return result;
 }
+
+void local_scheduler_push_error(LocalSchedulerConnection *conn,
+                                const JobID &job_id,
+                                const std::string &type,
+                                const std::string &error_message,
+                                double timestamp) {
+  flatbuffers::FlatBufferBuilder fbb;
+  auto message = ray::protocol::CreatePushErrorRequest(
+      fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
+      fbb.CreateString(error_message), timestamp);
+  fbb.Finish(message);
+
+  write_message(conn->conn, static_cast<int64_t>(
+                                ray::protocol::MessageType::PushErrorRequest),
+                fbb.GetSize(), fbb.GetBufferPointer());
+}
diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h
index 1a8cbd240f29..95a3de0c073c 100644
--- a/src/local_scheduler/local_scheduler_client.h
+++ b/src/local_scheduler/local_scheduler_client.h
@@ -211,4 +211,18 @@ std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
     int64_t timeout_milliseconds,
     bool wait_local);

+/// Push an error to the relevant driver.
+///
+/// \param conn The connection information.
+/// \param job_id The ID of the job that the error is for.
+/// \param type The type of the error.
+/// \param error_message The error message.
+/// \param timestamp The timestamp of the error.
+/// \return Void.
+void local_scheduler_push_error(LocalSchedulerConnection *conn,
+                                const JobID &job_id,
+                                const std::string &type,
+                                const std::string &error_message,
+                                double timestamp);
+
 #endif
diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc
index 8e7fef935fb8..875986b4d0a6 100644
--- a/src/ray/gcs/client.cc
+++ b/src/ray/gcs/client.cc
@@ -15,6 +15,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_type)
   raylet_task_table_.reset(new raylet::TaskTable(context_, this, command_type));
   task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
   heartbeat_table_.reset(new HeartbeatTable(context_, this));
+  error_table_.reset(new ErrorTable(context_, this));
   command_type_ = command_type;
 }

@@ -74,6 +75,9 @@ FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
 ClassTable &AsyncGcsClient::class_table() { return *class_table_; }

 HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
+
+ErrorTable &AsyncGcsClient::error_table() { return *error_table_; }
+
 }  // namespace gcs

 }  // namespace ray
diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h
index a5c07f70ed6d..cf054c4c93c2 100644
--- a/src/ray/gcs/client.h
+++ b/src/ray/gcs/client.h
@@ -57,7 +57,7 @@ class RAY_EXPORT AsyncGcsClient {
   TaskReconstructionLog &task_reconstruction_log();
   ClientTable &client_table();
   HeartbeatTable &heartbeat_table();
-  inline ErrorTable &error_table();
+  ErrorTable &error_table();

   // We also need something to export generic code to run on workers from the
   // driver (to set the PYTHONPATH)
@@ -78,6 +78,7 @@ class RAY_EXPORT AsyncGcsClient {
   std::unique_ptr<ActorTable> actor_table_;
   std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
   std::unique_ptr<HeartbeatTable> heartbeat_table_;
+  std::unique_ptr<ErrorTable> error_table_;
   std::unique_ptr<ClientTable> client_table_;
   std::shared_ptr<RedisContext> context_;
   std::unique_ptr<RedisAsioClient> asio_async_client_;
diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs
index f336aacac873..8f343437bbfa 100644
--- a/src/ray/gcs/format/gcs.fbs
+++ b/src/ray/gcs/format/gcs.fbs
@@ -14,6 +14,7 @@ enum TablePrefix:int {
   FUNCTION,
   TASK_RECONSTRUCTION,
   HEARTBEAT,
+  ERROR_INFO,
 }

 // The channel that Add operations to the Table should be published on, if any.
@@ -24,7 +25,8 @@ enum TablePubsub:int {
   CLIENT,
   OBJECT,
   ACTOR,
-  HEARTBEAT
+  HEARTBEAT,
+  ERROR_INFO,
 }

 table GcsTableEntry {
@@ -103,6 +105,14 @@ table ActorTableData {
 }

 table ErrorTableData {
+  // The ID of the job that the error is for.
+  job_id: string;
+  // The type of the error.
+  type: string;
+  // The error message.
+  error_message: string;
+  // The timestamp of the error message.
+  timestamp: double;
 }

 table CustomSerializerData {
diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc
index ef766985022f..fa2e9b7b25b1 100644
--- a/src/ray/gcs/tables.cc
+++ b/src/ray/gcs/tables.cc
@@ -183,6 +183,19 @@ Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id
                          done);
 }

+Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type,
+                                     const std::string &error_message,
+                                     double timestamp) {
+  auto data = std::make_shared<ErrorTableDataT>();
+  data->job_id = job_id.binary();
+  data->type = type;
+  data->error_message = error_message;
+  data->timestamp = timestamp;
+  return Append(job_id, job_id, data, [](ray::gcs::AsyncGcsClient *client,
+                                         const JobID &id, const ErrorTableDataT &data) {
+    RAY_LOG(DEBUG) << "Error message pushed callback";
+  });
+}
+
 void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
   client_added_callback_ = callback;
   // Call the callback for any added clients that are cached.
@@ -333,6 +346,7 @@ template class Table<TaskID, ray::protocol::Task>;
 template class Log<ActorID, ActorTableData>;
 template class Log<TaskID, TaskReconstructionData>;
 template class Table<ClientID, HeartbeatTableData>;
+template class Log<JobID, ErrorTableData>;
 template class Log<UniqueID, ClientTableData>;

 }  // namespace gcs
diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h
index 89b68a87b6ac..e68ecf9a9c36 100644
--- a/src/ray/gcs/tables.h
+++ b/src/ray/gcs/tables.h
@@ -95,7 +95,7 @@ class Log : virtual public PubsubInterface<ID> {
   /// \param id The ID of the data that is added to the GCS.
   /// \param data Data to append to the log.
   /// \param done Callback that is called once the data has been written to the
-  ///        GCS.
+  /// GCS.
   /// \return Status
   Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
                 const WriteCallback &done);
@@ -108,10 +108,9 @@ class Log : virtual public PubsubInterface<ID> {
   /// \param data Data to append to the log.
   /// \param done Callback that is called if the data was appended to the log.
   /// \param failure Callback that is called if the data was not appended to
-  ///        the log because the log length did not match the given
-  ///        `log_length`.
+  /// the log because the log length did not match the given `log_length`.
   /// \param log_length The number of entries that the log must have for the
-  ///        append to succeed.
+  /// append to succeed.
   /// \return Status
   Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
                   const WriteCallback &done, const WriteCallback &failure,
@@ -122,7 +121,7 @@ class Log : virtual public PubsubInterface<ID> {
   /// \param job_id The ID of the job (= driver).
   /// \param id The ID of the data that is looked up in the GCS.
   /// \param lookup Callback that is called after lookup. If the callback is
-  ///        called with an empty vector, then there was no data at the key.
+  /// called with an empty vector, then there was no data at the key.
   /// \return Status
   Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup);

@@ -133,15 +132,14 @@ class Log : virtual public PubsubInterface<ID> {
   ///
   /// \param job_id The ID of the job (= driver).
   /// \param client_id The type of update to listen to. If this is nil, then a
-  ///        message for each Add to the table will be received. Else, only
-  ///        messages for the given client will be received. In the latter
-  ///        case, the client may request notifications on specific keys in the
-  ///        table via `RequestNotifications`.
+  /// message for each Add to the table will be received. Else, only
+  /// messages for the given client will be received. In the latter
+  /// case, the client may request notifications on specific keys in the
+  /// table via `RequestNotifications`.
   /// \param subscribe Callback that is called on each received message. If the
-  ///        callback is called with an empty vector, then there was no data at
-  ///        the key.
+  /// callback is called with an empty vector, then there was no data at the key.
   /// \param done Callback that is called when subscription is complete and we
-  ///        are ready to receive messages.
+  /// are ready to receive messages.
   /// \return Status
   Status Subscribe(const JobID &job_id, const ClientID &client_id,
                    const Callback &subscribe, const SubscriptionCallback &done);
@@ -158,8 +156,8 @@ class Log : virtual public PubsubInterface<ID> {
   /// \param job_id The ID of the job (= driver).
   /// \param id The ID of the key to request notifications for.
   /// \param client_id The client who is requesting notifications. Before
-  ///        notifications can be requested, a call to `Subscribe` to this
-  ///        table with the same `client_id` must complete successfully.
+  /// notifications can be requested, a call to `Subscribe` to this
+  /// table with the same `client_id` must complete successfully.
   /// \return Status
   Status RequestNotifications(const JobID &job_id, const ID &id,
                               const ClientID &client_id);
@@ -241,7 +239,7 @@ class Table : private Log<ID, Data>,
   /// \param id The ID of the data that is added to the GCS.
   /// \param data Data that is added to the GCS.
   /// \param done Callback that is called once the data has been written to the
-  ///        GCS.
+  /// GCS.
   /// \return Status
   Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
              const WriteCallback &done);
@@ -251,9 +249,9 @@ class Table : private Log<ID, Data>,
   /// \param job_id The ID of the job (= driver).
   /// \param id The ID of the data that is looked up in the GCS.
   /// \param lookup Callback that is called after lookup if there was data the
-  ///        key.
+  /// key.
   /// \param failure Callback that is called after lookup if there was no data
-  ///        at the key.
+  /// at the key.
   /// \return Status
   Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
                 const FailureCallback &failure);
@@ -366,10 +364,10 @@ class TaskTable : public Table<TaskID, TaskTableData> {
   ///
   /// \param task_id The task ID of the task entry to update.
   /// \param test_state_bitmask The bitmask to apply to the task entry's current
-  ///        scheduling state. The update happens if and only if the current
-  ///        scheduling state AND-ed with the bitmask is greater than 0.
+  /// scheduling state. The update happens if and only if the current
+  /// scheduling state AND-ed with the bitmask is greater than 0.
   /// \param update_state The value to update the task entry's scheduling state
-  ///        with, if the current state matches test_state_bitmask.
+  /// with, if the current state matches test_state_bitmask.
   /// \param callback Function to be called when database returns result.
   /// \return Status
   Status TestAndUpdate(const JobID &job_id, const TaskID &id,
@@ -397,16 +395,14 @@ class TaskTable : public Table<TaskID, TaskTableData> {
   /// task's local scheduler ID.
   ///
   /// \param local_scheduler_id The db_client_id of the local scheduler whose
-  ///        events we want to listen to. If you want to subscribe to updates
-  ///        from
-  ///        all local schedulers, pass in NIL_ID.
+  /// events we want to listen to. If you want to subscribe to updates from
+  /// all local schedulers, pass in NIL_ID.
   /// \param subscribe_callback Callback that will be called when the task table
-  ///        is
-  ///        updated.
+  /// is updated.
   /// \param state_filter Events we want to listen to. Can have values from the
-  ///        enum "scheduling_state" in task.h.
-  ///        TODO(pcm): Make it possible to combine these using flags like
-  ///        TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
+  /// enum "scheduling_state" in task.h.
+  /// TODO(pcm): Make it possible to combine these using flags like
+  /// TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
   /// \param callback Function to be called when database returns result.
   /// \return Status
   Status SubscribeToTask(const JobID &job_id, const ClientID &local_scheduler_id,
@@ -422,7 +418,28 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
                               SchedulingState update_state,
                               const TaskTable::TestAndUpdateCallback &callback);

-using ErrorTable = Table<JobID, ErrorTableData>;
+class ErrorTable : private Log<JobID, ErrorTableData> {
+ public:
+  ErrorTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
+      : Log(context, client) {
+    pubsub_channel_ = TablePubsub::ERROR_INFO;
+    prefix_ = TablePrefix::ERROR_INFO;
+  };
+
+  /// Push an error message for a specific job.
+  ///
+  /// TODO(rkn): We need to make sure that the errors are unique because
+  /// duplicate messages currently cause failures (the GCS doesn't allow it).
+  ///
+  /// \param job_id The ID of the job that generated the error. If the error
+  /// should be pushed to all jobs, then this should be nil.
+  /// \param type The type of the error.
+  /// \param error_message The error message to push.
+  /// \param timestamp The timestamp of the error.
+  /// \return Status.
+  Status PushErrorToDriver(const JobID &job_id, const std::string &type,
+                           const std::string &error_message, double timestamp);
+};

 using CustomSerializerTable = Table<ClassID, CustomSerializerData>;

@@ -467,7 +484,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
   /// and begins subscription to client table notifications.
   ///
   /// \param Information about the connecting client. This must have the
-  ///        same client_id as the one set in the client table.
+  /// same client_id as the one set in the client table.
   /// \return Status
   ray::Status Connect(const ClientTableDataT &local_client);

@@ -499,7 +516,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
   ///
   /// \param client The client to get information about.
   /// \return A reference to the requested client. If the client is not in the
-  ///        cache, then an entry with a nil ClientID will be returned.
+  /// cache, then an entry with a nil ClientID will be returned.
   const ClientTableDataT &GetClient(const ClientID &client);

   /// Get the local client's ID.
diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs
index b56243cfcb2b..29c635ad5b4b 100644
--- a/src/ray/raylet/format/node_manager.fbs
+++ b/src/ray/raylet/format/node_manager.fbs
@@ -58,7 +58,10 @@ enum MessageType:int {
   WaitRequest,
   // The response message to WaitRequest; replies with the objects found and objects
   // remaining.
-  WaitReply
+  WaitReply,
+  // Push an error to the relevant driver. This is sent from a worker to the
+  // node manager.
+  PushErrorRequest,
 }

 table TaskExecutionSpecification {
@@ -154,3 +157,15 @@ table WaitReply {
   // List of object ids not found.
   remaining: [string];
 }
+
+// This struct is the same as ErrorTableData.
+table PushErrorRequest {
+  // The ID of the job that the error is for.
+  job_id: string;
+  // The type of the error.
+  type: string;
+  // The error message.
+  error_message: string;
+  // The timestamp of the error message.
+  timestamp: double;
+}
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 59c28c020861..e90e2f7d46fd 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -3,6 +3,7 @@
 #include "common_protocol.h"
 #include "local_scheduler/format/local_scheduler_generated.h"
 #include "ray/raylet/format/node_manager_generated.h"
+#include "ray/util/util.h"

 namespace {

@@ -372,11 +373,28 @@ void NodeManager::ProcessClientMessage(

     // This if statement distinguishes workers from drivers.
     if (worker) {
-      // TODO(swang): Handle the case where the worker is killed while
-      // executing a task. Clean up the assigned task's resources, return an
-      // error to the driver.
-      // RAY_CHECK(worker->GetAssignedTaskId().is_nil())
-      //     << "Worker died while executing task: " << worker->GetAssignedTaskId();
+      // Handle the case where the worker is killed while executing a task.
+      // Clean up the assigned task's resources, push an error to the driver.
+      const TaskID &task_id = worker->GetAssignedTaskId();
+      if (!task_id.is_nil()) {
+        auto const &running_tasks = local_queues_.GetRunningTasks();
+        // TODO(rkn): This is too heavyweight just to get the task's driver ID.
+        auto const it = std::find_if(
+            running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) {
+              return task.GetTaskSpecification().TaskId() == task_id;
+            });
+        RAY_CHECK(running_tasks.size() != 0);
+        RAY_CHECK(it != running_tasks.end());
+        JobID job_id = it->GetTaskSpecification().DriverId();
+        // TODO(rkn): Define this constant somewhere else.
+        std::string type = "worker_died";
+        std::ostringstream error_message;
+        error_message << "A worker died or was killed while executing task " << task_id
+                      << ".";
+        RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
+            job_id, type, error_message.str(), current_time_ms()));
+      }
+
       worker_pool_.DisconnectWorker(worker);

       const ClientID &client_id = gcs_client_->client_table().GetLocalClientId();
@@ -521,6 +539,17 @@ void NodeManager::ProcessClientMessage(
         });
     RAY_CHECK_OK(status);
  } break;
+  case protocol::MessageType::PushErrorRequest: {
+    auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
+
+    JobID job_id = from_flatbuf(*message->job_id());
+    auto const &type = string_from_flatbuf(*message->type());
+    auto const &error_message = string_from_flatbuf(*message->error_message());
+    double timestamp = message->timestamp();
+
+    RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
+                                                              timestamp));
+  } break;
  default:
    RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
diff --git a/src/ray/util/CMakeLists.txt b/src/ray/util/CMakeLists.txt
index c4fd15c0dca4..a69896fff177 100644
--- a/src/ray/util/CMakeLists.txt
+++ b/src/ray/util/CMakeLists.txt
@@ -1,6 +1,7 @@
 install(FILES
   logging.h
   macros.h
+  util.h
   visibility.h
   DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/ray/util"
 )
diff --git a/src/ray/util/util.h b/src/ray/util/util.h
new file mode 100644
index 000000000000..42632fa6b7a3
--- /dev/null
+++ b/src/ray/util/util.h
@@ -0,0 +1,19 @@
+#ifndef RAY_UTIL_UTIL_H
+#define RAY_UTIL_UTIL_H
+
+#include <chrono>
+
+/// Return the number of milliseconds since the Unix epoch.
+///
+/// TODO(rkn): This function appears in multiple places. It should be
+/// deduplicated.
+///
+/// \return The number of milliseconds since the Unix epoch.
+inline int64_t current_time_ms() {
+  std::chrono::milliseconds ms_since_epoch =
+      std::chrono::duration_cast<std::chrono::milliseconds>(
+          std::chrono::system_clock::now().time_since_epoch());
+  return ms_since_epoch.count();
+}
+
+#endif  // RAY_UTIL_UTIL_H
diff --git a/test/failure_test.py b/test/failure_test.py
index 7d1511f300f2..24818bffa5cb 100644
--- a/test/failure_test.py
+++ b/test/failure_test.py
@@ -269,9 +269,6 @@ class WorkerDeath(unittest.TestCase):
     def tearDown(self):
         ray.worker.cleanup()

-    @unittest.skipIf(
-        os.environ.get("RAY_USE_XRAY") == "1",
-        "This test does not work with xray yet.")
     def testWorkerRaisingException(self):
         ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)

@@ -287,9 +284,6 @@ def f():
         wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)
         self.assertEqual(len(ray.error_info()), 2)

-    @unittest.skipIf(
-        os.environ.get("RAY_USE_XRAY") == "1",
-        "This test does not work with xray yet.")
     def testWorkerDying(self):
         ray.init(num_workers=0, driver_mode=ray.SILENT_MODE)

@@ -303,7 +297,7 @@ def f():
         wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)

         self.assertEqual(len(ray.error_info()), 1)
-        self.assertIn("died or was killed while executing the task",
+        self.assertIn("died or was killed while executing",
                       ray.error_info()[0]["message"])

     @unittest.skipIf(
diff --git a/test/runtest.py b/test/runtest.py
index 0a20dae380df..e544344e0cfc 100644
--- a/test/runtest.py
+++ b/test/runtest.py
@@ -2243,7 +2243,7 @@ def f():
         worker_ids = set(ray.get([f.remote() for _ in range(10)]))

         worker_info = ray.global_state.workers()
-        self.assertEqual(len(worker_info), num_workers)
+        assert len(worker_info) >= num_workers
         for worker_id, info in worker_info.items():
             self.assertIn("node_ip_address", info)
             self.assertIn("local_scheduler_socket", info)
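
Note: below is a minimal usage sketch (not part of the patch) of the error path
this diff introduces, assuming a Ray build that includes these changes. The
"example_error" type string is illustrative only; the API calls match the new
push_error_to_driver signature, which takes the worker itself rather than a
Redis client.

    import time

    import ray
    import ray.utils

    ray.init()

    # Push an error against the current driver. In the raylet code path this
    # is forwarded to the local scheduler as a PushErrorRequest and appended
    # to the GCS ErrorTable; in the legacy path it is written directly to
    # Redis under an "Error:" key.
    worker = ray.worker.global_worker
    ray.utils.push_error_to_driver(
        worker,
        "example_error",
        "Example error message pushed at {}.".format(time.time()),
        driver_id=worker.task_driver_id.id())

    # With use_raylet enabled, ray.error_info() reads the errors back through
    # global_state.error_messages() for this driver and for the nil job ID.
    print(ray.error_info())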