Merged
18 changes: 12 additions & 6 deletions python/ray/autoscaler/autoscaler.py
@@ -209,15 +209,10 @@ def approx_workers_used(self):
def num_workers_connected(self):
return self._info()["NumNodesConnected"]

def info_string(self):
return ", ".join(
["{}={}".format(k, v) for k, v in sorted(self._info().items())])

def _info(self):
def get_resource_usage(self):
nodes_used = 0.0
resources_used = {}
resources_total = {}
now = time.time()
for ip, max_resources in self.static_resources_by_ip.items():
avail_resources = self.dynamic_resources_by_ip[ip]
max_frac = 0.0
@@ -234,6 +229,17 @@ def _info(self):
if frac > max_frac:
max_frac = frac
nodes_used += max_frac

return nodes_used, resources_used, resources_total

def info_string(self):
return ", ".join(
["{}={}".format(k, v) for k, v in sorted(self._info().items())])

def _info(self):
nodes_used, resources_used, resources_total = self.get_resource_usage()

now = time.time()
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
heartbeat_times = [
now - t for t in self.last_heartbeat_time_by_ip.values()
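For readers following the refactor: a minimal sketch of the aggregation that `get_resource_usage()` now factors out of `_info()`. This is a hypothetical, simplified reconstruction (a free function over plain dicts instead of the actual `LoadMetrics` method, with the loop body elided by the diff paraphrased), not Ray's real code:

# Hypothetical, simplified sketch of get_resource_usage(); the real
# method lives on LoadMetrics and reads self.* fields. A node's usage
# is the largest used fraction across its resource types.
def get_resource_usage(static_resources_by_ip, dynamic_resources_by_ip):
    nodes_used = 0.0
    resources_used = {}
    resources_total = {}
    for ip, max_resources in static_resources_by_ip.items():
        avail_resources = dynamic_resources_by_ip[ip]
        max_frac = 0.0
        for resource_id, amount in max_resources.items():
            used = amount - avail_resources.get(resource_id, 0.0)
            resources_used[resource_id] = (
                resources_used.get(resource_id, 0.0) + used)
            resources_total[resource_id] = (
                resources_total.get(resource_id, 0.0) + amount)
            if amount > 0 and used / amount > max_frac:
                max_frac = used / amount
        nodes_used += max_frac
    return nodes_used, resources_used, resources_total

# Example: one node with a single CPU, fully in use.
print(get_resource_usage({"10.0.0.1": {"CPU": 1.0}},
                         {"10.0.0.1": {"CPU": 0.0}}))
# -> (1.0, {'CPU': 1.0}, {'CPU': 1.0})
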
3 changes: 2 additions & 1 deletion python/ray/monitor.py
@@ -122,8 +122,9 @@ def xray_heartbeat_batch_handler(self, unused_channel, data):
# Update the load metrics for this raylet.
client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
ip = self.raylet_id_to_ip_map.get(client_id)
load_metrics_id = ip + "-" + client_id
Author comment: This is questionable. Not sure what the correct way of handling this is.

if ip:
self.load_metrics.update(ip, static_resources,
self.load_metrics.update(load_metrics_id, static_resources,
dynamic_resources)
else:
logger.warning(
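On the author's question about `load_metrics_id`: keying updates by `ip + "-" + client_id` plausibly keeps two raylets that share one IP (e.g. a multi-node test cluster on localhost) from overwriting each other's metrics. A toy sketch of the collision, with plain dicts standing in for `LoadMetrics` and made-up raylet IDs:

# Two raylets report heartbeats from the same host.
heartbeats = [("raylet-a", "127.0.0.1"), ("raylet-b", "127.0.0.1")]

by_ip = {}
for client_id, ip in heartbeats:
    by_ip[ip] = client_id  # the second heartbeat clobbers the first
print(len(by_ip))  # 1

by_id = {}
for client_id, ip in heartbeats:
    by_id[ip + "-" + client_id] = client_id
print(len(by_id))  # 2, one entry per raylet
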
71 changes: 59 additions & 12 deletions python/ray/tests/test_multi_node_2.py
@@ -59,6 +59,26 @@ def test_internal_config(ray_start_cluster_head):
assert ray.cluster_resources()["CPU"] == 1


def verify_load_metrics(monitor, expected_resource_usage=None, timeout=10):
    while True:
        monitor.process_messages()
        resource_usage = monitor.load_metrics.get_resource_usage()

        if expected_resource_usage is None:
            # No target given: wait until any usage has been reported.
            if all(x for x in resource_usage[1:]):
                break
        elif all(x == y
                 for x, y in zip(resource_usage, expected_resource_usage)):
            break

        timeout -= 1
        if timeout <= 0:
            raise ValueError("Timed out waiting for the expected "
                             "resource usage {}.".format(
                                 expected_resource_usage))
        time.sleep(1)

    return resource_usage


def test_heartbeats(ray_start_cluster_head):
"""Unit test for `Cluster.wait_for_nodes`.

@@ -67,28 +87,55 @@ def test_heartbeats(ray_start_cluster_head):
cluster = ray_start_cluster_head
monitor = Monitor(cluster.redis_address, None)

work_handles = []

@ray.remote
class Actor:
pass
def work(self, timeout=10):
time.sleep(timeout)
return True

    test_actors = [Actor.remote()]

    # The monitor subscribes to these channels only to update the load
    # metrics for the autoscaler.
    monitor.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
    monitor.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)

monitor.update_raylet_map()
monitor._maybe_flush_gcs()
    # Process a round of messages so the load metrics are populated.
    monitor.process_messages()

timeout = 5

verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': 1.0}))

work_handles += [test_actors[0].work.remote(timeout=timeout * 2)]

verify_load_metrics(monitor, (1.0, {'CPU': 1.0}, {'CPU': 1.0}))

ray.get(work_handles)

num_workers = 4
num_nodes_total = float(num_workers + 1)
    worker_nodes = [cluster.add_node() for _ in range(num_workers)]

cluster.wait_for_nodes()
monitor.update_raylet_map()
monitor._maybe_flush_gcs()

verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': num_nodes_total}))

work_handles = [test_actors[0].work.remote(timeout=timeout * 2)]
    for _ in range(num_workers):
        new_actor = Actor.remote()
        work_handles += [new_actor.work.remote(timeout=timeout * 2)]
        test_actors += [new_actor]

verify_load_metrics(
monitor,
(num_nodes_total, {'CPU': num_nodes_total}, {'CPU': num_nodes_total}))

ray.get(work_handles)

verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': num_nodes_total}))


def test_wait_for_nodes(ray_start_cluster_head):
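Closing note: the polling loop in `verify_load_metrics` is a reusable pattern for tests that wait on asynchronously reported state. A generic, illustrative sketch (the helper name and signature below are not part of Ray):

import time

def poll_until(predicate, timeout=10, interval=1.0):
    # Re-evaluate predicate() until it holds or the timeout elapses;
    # this mirrors the retry/timeout structure of verify_load_metrics.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if predicate():
            return
        time.sleep(interval)
    raise ValueError("Condition not met within {} seconds.".format(timeout))

# Example: a condition that becomes true on the third check.
state = {"calls": 0}

def ready():
    state["calls"] += 1
    return state["calls"] >= 3

poll_until(ready, timeout=5, interval=0.01)
print(state["calls"])  # 3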