diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py
index 5f1020c0932f..e9697eee48d1 100644
--- a/python/ray/experimental/__init__.py
+++ b/python/ray/experimental/__init__.py
@@ -3,9 +3,14 @@
 from __future__ import print_function
 
 from .tfutils import TensorFlowVariables
-from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
+from .features import (
+    flush_redis_unsafe, flush_task_and_object_metadata_unsafe,
+    flush_finished_tasks_unsafe, flush_evicted_objects_unsafe,
+    _flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard)
 
 __all__ = [
     "TensorFlowVariables", "flush_redis_unsafe",
-    "flush_task_and_object_metadata_unsafe"
+    "flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe",
+    "flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard",
+    "_flush_evicted_objects_unsafe_shard"
 ]
diff --git a/python/ray/experimental/features.py b/python/ray/experimental/features.py
index 7db9d611b9b0..304b275baebf 100644
--- a/python/ray/experimental/features.py
+++ b/python/ray/experimental/features.py
@@ -3,10 +3,11 @@
 from __future__ import print_function
 
 import ray
+from ray.utils import binary_to_hex
 
 OBJECT_INFO_PREFIX = b"OI:"
 OBJECT_LOCATION_PREFIX = b"OL:"
-TASK_TABLE_PREFIX = b"TT:"
+TASK_PREFIX = b"TT:"
 
 
 def flush_redis_unsafe():
@@ -18,9 +19,7 @@ def flush_redis_unsafe():
     much of the data is in the task table (and object table), which are not
     flushed.
     """
-    if not hasattr(ray.worker.global_worker, "redis_client"):
-        raise Exception("ray.experimental.flush_redis_unsafe cannot be called "
-                        "before ray.init() has been called.")
+    ray.worker.global_worker.check_connected()
 
     redis_client = ray.worker.global_worker.redis_client
 
@@ -52,15 +51,13 @@ def flush_task_and_object_metadata_unsafe():
     Redis. However, after running this command, fault tolerance will most
     likely not work.
     """
-    if not hasattr(ray.worker.global_worker, "redis_client"):
-        raise Exception("ray.experimental.flush_redis_unsafe cannot be called "
-                        "before ray.init() has been called.")
+    ray.worker.global_worker.check_connected()
 
     def flush_shard(redis_client):
         # Flush the task table. Note that this also flushes the driver tasks
         # which may be undesirable.
         num_task_keys_deleted = 0
-        for key in redis_client.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
+        for key in redis_client.scan_iter(match=TASK_PREFIX + b"*"):
             num_task_keys_deleted += redis_client.delete(key)
         print("Deleted {} task keys from Redis.".format(num_task_keys_deleted))
 
@@ -81,3 +78,105 @@ def flush_shard(redis_client):
     # Loop over the shards and flush all of them.
     for redis_client in ray.worker.global_state.redis_clients:
         flush_shard(redis_client)
+
+
+def _task_table_shard(shard_index):
+    redis_client = ray.global_state.redis_clients[shard_index]
+    task_table_keys = redis_client.keys(TASK_PREFIX + b"*")
+    results = {}
+    for key in task_table_keys:
+        task_id_binary = key[len(TASK_PREFIX):]
+        results[binary_to_hex(task_id_binary)] = ray.global_state._task_table(
+            ray.ObjectID(task_id_binary))
+
+    return results
+
+
+def _object_table_shard(shard_index):
+    redis_client = ray.global_state.redis_clients[shard_index]
+    object_table_keys = redis_client.keys(OBJECT_LOCATION_PREFIX + b"*")
+    results = {}
+    for key in object_table_keys:
+        object_id_binary = key[len(OBJECT_LOCATION_PREFIX):]
+        results[binary_to_hex(object_id_binary)] = (
+            ray.global_state._object_table(ray.ObjectID(object_id_binary)))
+
+    return results
+
+
+def _flush_finished_tasks_unsafe_shard(shard_index):
+    ray.worker.global_worker.check_connected()
+
+    redis_client = ray.global_state.redis_clients[shard_index]
+    tasks = _task_table_shard(shard_index)
+
+    keys_to_delete = []
+    for task_id, task_info in tasks.items():
+        if task_info["State"] == ray.experimental.state.TASK_STATUS_DONE:
+            keys_to_delete.append(TASK_PREFIX +
+                                  ray.utils.hex_to_binary(task_id))
+
+    num_task_keys_deleted = 0
+    if len(keys_to_delete) > 0:
+        num_task_keys_deleted = redis_client.execute_command(
+            "del", *keys_to_delete)
+
+    print("Deleted {} finished tasks from Redis shard."
+          .format(num_task_keys_deleted))
+
+
+def _flush_evicted_objects_unsafe_shard(shard_index):
+    ray.worker.global_worker.check_connected()
+
+    redis_client = ray.global_state.redis_clients[shard_index]
+    objects = _object_table_shard(shard_index)
+
+    keys_to_delete = []
+    for object_id, object_info in objects.items():
+        if object_info["ManagerIDs"] == []:
+            keys_to_delete.append(OBJECT_LOCATION_PREFIX +
+                                  ray.utils.hex_to_binary(object_id))
+            keys_to_delete.append(OBJECT_INFO_PREFIX +
+                                  ray.utils.hex_to_binary(object_id))
+
+    num_object_keys_deleted = 0
+    if len(keys_to_delete) > 0:
+        num_object_keys_deleted = redis_client.execute_command(
+            "del", *keys_to_delete)
+
+    print("Deleted {} keys for evicted objects from Redis."
+          .format(num_object_keys_deleted))
+
+
+def flush_finished_tasks_unsafe():
+    """This removes some critical state from the Redis shards.
+
+    In a multitenant environment, this will flush metadata for all jobs, which
+    may be undesirable.
+
+    This removes all of the metadata for finished tasks. This can be used to
+    try to address out-of-memory errors caused by the accumulation of metadata
+    in Redis. However, after running this command, fault tolerance will most
+    likely not work.
+    """
+    ray.worker.global_worker.check_connected()
+
+    for shard_index in range(len(ray.global_state.redis_clients)):
+        _flush_finished_tasks_unsafe_shard(shard_index)
+
+
+def flush_evicted_objects_unsafe():
+    """This removes some critical state from the Redis shards.
+
+    In a multitenant environment, this will flush metadata for all jobs, which
+    may be undesirable.
+
+    This removes all of the metadata for objects that have been evicted. This
+    can be used to try to address out-of-memory errors caused by the
+    accumulation of metadata in Redis. However, after running this command,
+    fault tolerance will most likely not work.
+    """
+    ray.worker.global_worker.check_connected()
+
+    for shard_index in range(len(ray.global_state.redis_clients)):
+        _flush_evicted_objects_unsafe_shard(shard_index)
diff --git a/test/runtest.py b/test/runtest.py
index 694574b003d7..b5923c7e9609 100644
--- a/test/runtest.py
+++ b/test/runtest.py
@@ -2218,6 +2218,16 @@ def f():
         assert len(ray.global_state.object_table()) == 0
         assert len(ray.global_state.task_table()) == 0
 
+        # Run some more tasks.
+        ray.get([f.remote() for _ in range(10)])
+
+        while len(ray.global_state.task_table()) != 0:
+            ray.experimental.flush_finished_tasks_unsafe()
+
+        # Make sure that we can call this method (but it won't do anything in
+        # this test case).
+        ray.experimental.flush_evicted_objects_unsafe()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)