diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 4172f027f84a..c9cf46ba9a62 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -209,15 +209,10 @@ def approx_workers_used(self): def num_workers_connected(self): return self._info()["NumNodesConnected"] - def info_string(self): - return ", ".join( - ["{}={}".format(k, v) for k, v in sorted(self._info().items())]) - - def _info(self): + def get_resource_usage(self): nodes_used = 0.0 resources_used = {} resources_total = {} - now = time.time() for ip, max_resources in self.static_resources_by_ip.items(): avail_resources = self.dynamic_resources_by_ip[ip] max_frac = 0.0 @@ -234,6 +229,17 @@ def _info(self): if frac > max_frac: max_frac = frac nodes_used += max_frac + + return nodes_used, resources_used, resources_total + + def info_string(self): + return ", ".join( + ["{}={}".format(k, v) for k, v in sorted(self._info().items())]) + + def _info(self): + nodes_used, resources_used, resources_total = self.get_resource_usage() + + now = time.time() idle_times = [now - t for t in self.last_used_time_by_ip.values()] heartbeat_times = [ now - t for t in self.last_heartbeat_time_by_ip.values() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 4f512d6db68f..9bcb7dfd59db 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -122,8 +122,9 @@ def xray_heartbeat_batch_handler(self, unused_channel, data): # Update the load metrics for this raylet. 
        client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
        ip = self.raylet_id_to_ip_map.get(client_id)
        if ip:
-            self.load_metrics.update(ip, static_resources,
+            load_metrics_id = ip + "-" + client_id
+            self.load_metrics.update(load_metrics_id, static_resources,
                                     dynamic_resources)
        else:
            logger.warning(
diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py
index dc8853c5e917..57948034ca53 100644
--- a/python/ray/tests/test_multi_node_2.py
+++ b/python/ray/tests/test_multi_node_2.py
@@ -59,6 +59,26 @@ def test_internal_config(ray_start_cluster_head):
     assert ray.cluster_resources()["CPU"] == 1
 
 
+def verify_load_metrics(monitor, expected_resource_usage=None, timeout=10):
+    while True:
+        monitor.process_messages()
+        resource_usage = monitor.load_metrics.get_resource_usage()
+
+        if expected_resource_usage is None:
+            if all(x for x in resource_usage[1:]):
+                break
+        elif all(x == y
+                 for x, y in zip(resource_usage, expected_resource_usage)):
+            break
+
+        # Decrement/sleep on every failed check so no path can spin forever.
+        timeout -= 1
+        time.sleep(1)
+        if timeout <= 0:
+            raise ValueError("Timed out waiting for load metrics.")
+
+    return resource_usage
+
+
 def test_heartbeats(ray_start_cluster_head):
     """Unit test for `Cluster.wait_for_nodes`.
 
@@ -67,28 +87,55 @@ def test_heartbeats(ray_start_cluster_head):
     cluster = ray_start_cluster_head
     monitor = Monitor(cluster.redis_address, None)
 
+    work_handles = []
+
     @ray.remote
     class Actor():
-        pass
+        def work(self, timeout=10):
+            time.sleep(timeout)
+            return True
 
     test_actors = [Actor.remote()]
 
-    # This is only used to update the load metrics for the autoscaler.
     monitor.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
     monitor.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)
     monitor.update_raylet_map()
     monitor._maybe_flush_gcs()
-    # Process a round of messages. 
- monitor.process_messages() - from pprint import pprint; import ipdb; ipdb.set_trace(context=30) - pprint(vars(monitor.load_metrics)) - - # worker_nodes = [cluster.add_node() for i in range(4)] - # for i in range(3): - # test_actors += [Actor.remote()] - # check_resource_usage(monitor.get_heartbeat()) - # cluster.wait_for_nodes() + + timeout = 5 + + verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': 1.0})) + + work_handles += [test_actors[0].work.remote(timeout=timeout * 2)] + + verify_load_metrics(monitor, (1.0, {'CPU': 1.0}, {'CPU': 1.0})) + + ray.get(work_handles) + + num_workers = 4 + num_nodes_total = float(num_workers + 1) + worker_nodes = [cluster.add_node() for i in range(num_workers)] + + cluster.wait_for_nodes() + monitor.update_raylet_map() + monitor._maybe_flush_gcs() + + verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': num_nodes_total})) + + work_handles = [test_actors[0].work.remote(timeout=timeout * 2)] + for i in range(num_workers): + new_actor = Actor.remote() + work_handles += [new_actor.work.remote(timeout=timeout * 2)] + test_actors += [new_actor] + + verify_load_metrics( + monitor, + (num_nodes_total, {'CPU': num_nodes_total}, {'CPU': num_nodes_total})) + + ray.get(work_handles) + + verify_load_metrics(monitor, (0.0, {'CPU': 0.0}, {'CPU': num_nodes_total})) def test_wait_for_nodes(ray_start_cluster_head):