Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/ray/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
from __future__ import print_function

from .tfutils import TensorFlowVariables
from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
from .features import (
flush_redis_unsafe, flush_task_and_object_metadata_unsafe,
flush_finished_tasks_unsafe, flush_evicted_objects_unsafe,
_flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard)

__all__ = [
"TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe"
"flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe",
"flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard",
"_flush_evicted_objects_unsafe_shard"
]
115 changes: 107 additions & 8 deletions python/ray/experimental/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from __future__ import print_function

import ray
from ray.utils import binary_to_hex

OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
TASK_TABLE_PREFIX = b"TT:"
TASK_PREFIX = b"TT:"


def flush_redis_unsafe():
Expand All @@ -18,9 +19,7 @@ def flush_redis_unsafe():
much of the data is in the task table (and object table), which are not
flushed.
"""
if not hasattr(ray.worker.global_worker, "redis_client"):
raise Exception("ray.experimental.flush_redis_unsafe cannot be called "
"before ray.init() has been called.")
ray.worker.global_worker.check_connected()

redis_client = ray.worker.global_worker.redis_client

Expand Down Expand Up @@ -52,15 +51,13 @@ def flush_task_and_object_metadata_unsafe():
Redis. However, after running this command, fault tolerance will most
likely not work.
"""
if not hasattr(ray.worker.global_worker, "redis_client"):
raise Exception("ray.experimental.flush_redis_unsafe cannot be called "
"before ray.init() has been called.")
ray.worker.global_worker.check_connected()

def flush_shard(redis_client):
# Flush the task table. Note that this also flushes the driver tasks
# which may be undesirable.
num_task_keys_deleted = 0
for key in redis_client.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
for key in redis_client.scan_iter(match=TASK_PREFIX + b"*"):
num_task_keys_deleted += redis_client.delete(key)
print("Deleted {} task keys from Redis.".format(num_task_keys_deleted))

Expand All @@ -81,3 +78,105 @@ def flush_shard(redis_client):
# Loop over the shards and flush all of them.
for redis_client in ray.worker.global_state.redis_clients:
flush_shard(redis_client)


def _task_table_shard(shard_index):
    """Collect the task table entries whose keys are stored on one shard.

    Args:
        shard_index: Index into ray.global_state.redis_clients selecting the
            Redis shard to inspect.

    Returns:
        A dictionary keyed by hex-encoded task ID. Each value is whatever
        ray.global_state._task_table returns for that task ID.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(TASK_PREFIX)

    task_entries = {}
    for redis_key in shard_client.keys(TASK_PREFIX + b"*"):
        # Everything after the table prefix is the raw (binary) task ID.
        raw_task_id = redis_key[prefix_length:]
        task_info = ray.global_state._task_table(ray.ObjectID(raw_task_id))
        task_entries[binary_to_hex(raw_task_id)] = task_info

    return task_entries


def _object_table_shard(shard_index):
    """Collect the object table entries whose keys are stored on one shard.

    Args:
        shard_index: Index into ray.global_state.redis_clients selecting the
            Redis shard to inspect.

    Returns:
        A dictionary keyed by hex-encoded object ID. Each value is whatever
        ray.global_state._object_table returns for that object ID.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(OBJECT_LOCATION_PREFIX)

    object_entries = {}
    for redis_key in shard_client.keys(OBJECT_LOCATION_PREFIX + b"*"):
        # Everything after the location prefix is the raw (binary) object ID.
        raw_object_id = redis_key[prefix_length:]
        object_info = ray.global_state._object_table(
            ray.ObjectID(raw_object_id))
        object_entries[binary_to_hex(raw_object_id)] = object_info

    return object_entries


def _flush_finished_tasks_unsafe_shard(shard_index):
    """Delete the task table keys of finished tasks from one Redis shard.

    A task counts as finished when its "State" field equals
    ray.experimental.state.TASK_STATUS_DONE. The number of deleted keys is
    printed.

    Args:
        shard_index: Index into ray.global_state.redis_clients selecting the
            Redis shard to flush.
    """
    ray.worker.global_worker.check_connected()

    shard_client = ray.global_state.redis_clients[shard_index]

    # Rebuild the Redis key (prefix + binary task ID) of every finished task.
    finished_keys = [
        TASK_PREFIX + ray.utils.hex_to_binary(hex_task_id)
        for hex_task_id, info in _task_table_shard(shard_index).items()
        if info["State"] == ray.experimental.state.TASK_STATUS_DONE
    ]

    # Only issue the delete command when there is at least one key to remove.
    if finished_keys:
        deleted_count = shard_client.execute_command("del", *finished_keys)
    else:
        deleted_count = 0

    message = "Deleted {} finished tasks from Redis shard."
    print(message.format(deleted_count))


def _flush_evicted_objects_unsafe_shard(shard_index):
    """Delete the metadata keys of evicted objects from one Redis shard.

    An object is treated as evicted when its "ManagerIDs" entry is an empty
    list. For each such object both the location key and the info key are
    deleted, and the number of deleted keys is printed.

    Args:
        shard_index: Index into ray.global_state.redis_clients selecting the
            Redis shard to flush.
    """
    ray.worker.global_worker.check_connected()

    shard_client = ray.global_state.redis_clients[shard_index]

    evicted_keys = []
    for hex_object_id, info in _object_table_shard(shard_index).items():
        if info["ManagerIDs"] != []:
            # Some manager still holds the object, so keep its metadata.
            continue
        raw_object_id = ray.utils.hex_to_binary(hex_object_id)
        evicted_keys.extend((OBJECT_LOCATION_PREFIX + raw_object_id,
                             OBJECT_INFO_PREFIX + raw_object_id))

    # Only issue the delete command when there is at least one key to remove.
    if evicted_keys:
        deleted_count = shard_client.execute_command("del", *evicted_keys)
    else:
        deleted_count = 0

    message = "Deleted {} keys for evicted objects from Redis."
    print(message.format(deleted_count))


def flush_finished_tasks_unsafe():
    """Remove the metadata of every finished task from all Redis shards.

    This deletes critical state. In a multitenant environment the metadata of
    all jobs is flushed, which may be undesirable.

    The intended use is to relieve out-of-memory errors caused by metadata
    accumulating in Redis. After running this command, fault tolerance will
    most likely not work.
    """
    ray.worker.global_worker.check_connected()

    # Flush the shards one at a time, addressing each by its index.
    num_shards = len(ray.global_state.redis_clients)
    for shard_index in range(num_shards):
        _flush_finished_tasks_unsafe_shard(shard_index)


def flush_evicted_objects_unsafe():
    """Remove the metadata of every evicted object from all Redis shards.

    This deletes critical state. In a multitenant environment the metadata of
    all jobs is flushed, which may be undesirable.

    The intended use is to relieve out-of-memory errors caused by metadata
    accumulating in Redis. After running this command, fault tolerance will
    most likely not work.
    """
    ray.worker.global_worker.check_connected()

    # Flush the shards one at a time, addressing each by its index.
    num_shards = len(ray.global_state.redis_clients)
    for shard_index in range(num_shards):
        _flush_evicted_objects_unsafe_shard(shard_index)
10 changes: 10 additions & 0 deletions test/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2218,6 +2218,16 @@ def f():
assert len(ray.global_state.object_table()) == 0
assert len(ray.global_state.task_table()) == 0

# Run some more tasks.
ray.get([f.remote() for _ in range(10)])

while len(ray.global_state.task_table()) != 0:
ray.experimental.flush_finished_tasks_unsafe()

# Make sure that we can call this method (but it won't do anything in
# this test case).
ray.experimental.flush_evicted_objects_unsafe()


if __name__ == "__main__":
unittest.main(verbosity=2)