diff --git a/python/ray/dashboard/agent.py b/python/ray/dashboard/agent.py
index 355ac03d9aa5..32e910580a04 100644
--- a/python/ray/dashboard/agent.py
+++ b/python/ray/dashboard/agent.py
@@ -10,12 +10,13 @@
 import ray.dashboard.consts as dashboard_consts
 import ray.dashboard.utils as dashboard_utils
 from ray._common.network_utils import build_address, is_localhost
+from ray._common.retry import call_with_retry
 from ray._common.utils import get_or_create_event_loop
 from ray._private import logging_utils
 from ray._private.process_watcher import create_check_raylet_task
 from ray._private.ray_constants import AGENT_GRPC_MAX_MESSAGE_LENGTH
 from ray._private.ray_logging import setup_component_logger
-from ray._raylet import GcsClient
+from ray._raylet import GcsClient, NodeID
 
 logger = logging.getLogger(__name__)
 
@@ -76,6 +77,15 @@ def __init__(
             cluster_id=self.cluster_id_hex,
         )
 
+        # Fetch node info and save is_head.
+        node_info = call_with_retry(
+            lambda: self.gcs_client.get_all_node_info()[NodeID.from_hex(self.node_id)],
+            description="get self node info",
+            max_attempts=30,
+            max_backoff_s=1,
+        )
+        self.is_head = node_info.is_head_node
+
         if not self.minimal:
             self._init_non_minimal()
 
diff --git a/python/ray/dashboard/modules/reporter/healthz_agent.py b/python/ray/dashboard/modules/reporter/healthz_agent.py
index cff4edaf33d1..b654740ecbcc 100644
--- a/python/ray/dashboard/modules/reporter/healthz_agent.py
+++ b/python/ray/dashboard/modules/reporter/healthz_agent.py
@@ -1,3 +1,6 @@
+import asyncio
+import logging
+
 from aiohttp.web import Request, Response
 
 import ray.dashboard.optional_utils as optional_utils
@@ -8,6 +11,8 @@
 
 routes = optional_utils.DashboardAgentRouteTable
 
+logger = logging.getLogger(__name__)
+
 
 class HealthzAgent(dashboard_utils.DashboardAgentModule):
     """Health check in the agent.
@@ -30,10 +35,27 @@ def __init__(self, dashboard_agent):
 
     @routes.get("/api/local_raylet_healthz")
     async def health_check(self, req: Request) -> Response:
+        """Return 200 if the local raylet check passes, otherwise 503."""
+        try:
+            await self.raylet_health()
+        except Exception as e:
+            return Response(status=503, text=str(e), content_type="application/text")
+
+        return Response(
+            text="success",
+            content_type="application/text",
+        )
+
+    async def raylet_health(self) -> str:
+        """Return "success", or raise if the local raylet looks unhealthy.
+
+        RPC errors whose status code is in the tolerated set below are ignored
+        so that an unreachable GCS is not reported as a raylet failure.
+        """
         try:
             alive = await self._health_checker.check_local_raylet_liveness()
             if alive is False:
-                return Response(status=503, text="Local Raylet failed")
+                raise Exception("Local Raylet failed")
         except ray.exceptions.RpcError as e:
             # We only consider the error other than GCS unreachable as raylet failure
             # to avoid false positive.
@@ -45,10 +67,44 @@ async def health_check(self, req: Request) -> Response:
                 ray._raylet.GRPC_STATUS_CODE_UNKNOWN,
                 ray._raylet.GRPC_STATUS_CODE_DEADLINE_EXCEEDED,
             ):
-                return Response(status=503, text=f"Health check failed due to: {e}")
+                raise Exception(f"Health check failed due to: {e}")
+        return "success"
+
+    async def local_gcs_health(self) -> str:
+        """Check GCS liveness, but only on the head node.
+
+        Raises:
+            Exception: if this is the head node and the GCS is not alive.
+        """
+        # If GCS is not local, don't check its health.
+        if not self._dashboard_agent.is_head:
+            return "success (no local gcs)"
+        gcs_alive = await self._health_checker.check_gcs_liveness()
+        if not gcs_alive:
+            raise Exception("GCS health check failed.")
+        return "success"
+
+    @routes.get("/api/healthz")
+    async def unified_health(self, req: Request) -> Response:
+        """Report raylet and local GCS health; 503 if any check fails."""
+        [raylet_check, gcs_check] = await asyncio.gather(
+            self.raylet_health(),
+            self.local_gcs_health(),
+            return_exceptions=True,
+        )
+        checks = {"raylet": raylet_check, "gcs": gcs_check}
+
+        # Log failures.
+        status = 200
+        for name, result in checks.items():
+            # gather() may also hand back CancelledError, a BaseException.
+            if isinstance(result, BaseException):
+                status = 503
+                logger.warning(f"health check {name} failed: {result}")
 
         return Response(
-            text="success",
+            status=status,
+            text="\n".join([f"{name}: {result}" for name, result in checks.items()]),
             content_type="application/text",
         )
 
diff --git a/python/ray/dashboard/modules/reporter/reporter_agent.py b/python/ray/dashboard/modules/reporter/reporter_agent.py
index ee493539a7a6..151b52718639 100644
--- a/python/ray/dashboard/modules/reporter/reporter_agent.py
+++ b/python/ray/dashboard/modules/reporter/reporter_agent.py
@@ -25,7 +25,6 @@
 import ray._private.prometheus_exporter as prometheus_exporter
 import ray.dashboard.modules.reporter.reporter_consts as reporter_consts
 import ray.dashboard.utils as dashboard_utils
-from ray._common.network_utils import parse_address
 from ray._common.utils import (
     get_or_create_event_loop,
     get_user_temp_dir,
@@ -424,7 +423,7 @@ def __init__(self, dashboard_agent, raylet_client=None):
         self._gcs_client = dashboard_agent.gcs_client
         self._ip = dashboard_agent.ip
         self._log_dir = dashboard_agent.log_dir
-        self._is_head_node = self._ip == parse_address(dashboard_agent.gcs_address)[0]
+        self._is_head_node = dashboard_agent.is_head
         self._hostname = socket.gethostname()
         # (pid, created_time) -> psutil.Process
         self._workers = {}
diff --git a/python/ray/dashboard/modules/reporter/tests/test_healthz.py b/python/ray/dashboard/modules/reporter/tests/test_healthz.py
index 4c6fd5c97624..7d4c4a62e065 100644
--- a/python/ray/dashboard/modules/reporter/tests/test_healthz.py
+++ b/python/ray/dashboard/modules/reporter/tests/test_healthz.py
@@ -1,3 +1,4 @@
+import signal
 import sys
 
 import pytest
@@ -47,8 +48,6 @@ def test_healthz_agent_2(monkeypatch, ray_start_cluster):
 
     wait_for_condition(lambda: requests.get(uri).status_code == 200)
 
-    import signal
-
     h.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process.send_signal(
         signal.SIGSTOP
     )
@@ -59,5 +58,64 @@ def test_healthz_agent_2(monkeypatch, ray_start_cluster):
     wait_for_condition(lambda: requests.get(uri).status_code != 200)
 
 
+def test_unified_healthz_head(monkeypatch, ray_start_cluster):
+    agent_port = find_free_port()
+    h = ray_start_cluster.add_node(dashboard_agent_listen_port=agent_port)
+    uri = f"http://{h.node_ip_address}:{agent_port}/api/healthz"
+
+    wait_for_condition(lambda: requests.get(uri).status_code == 200)
+    resp = requests.get(uri)
+    assert "raylet: success" in resp.text
+    assert "gcs: success" in resp.text
+
+    h.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0].process.kill()
+    wait_for_condition(lambda: requests.get(uri).status_code == 503)
+    resp = requests.get(uri)
+    assert "gcs: " in resp.text
+    assert "gcs: success" not in resp.text
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="SIGSTOP only on posix")
+def test_unified_healthz_worker(monkeypatch, ray_start_cluster):
+    monkeypatch.setenv("RAY_health_check_failure_threshold", "3")
+    monkeypatch.setenv("RAY_health_check_timeout_ms", "100")
+    monkeypatch.setenv("RAY_health_check_period_ms", "1000")
+    monkeypatch.setenv("RAY_health_check_initial_delay_ms", "0")
+
+    ray_start_cluster.add_node()
+    agent_port = find_free_port()
+    h = ray_start_cluster.add_node(dashboard_agent_listen_port=agent_port)
+    uri = f"http://{h.node_ip_address}:{agent_port}/api/healthz"
+
+    wait_for_condition(lambda: requests.get(uri).status_code == 200)
+    resp = requests.get(uri)
+    assert "gcs: success (no local gcs)" in resp.text
+
+    # Stop local raylet and verify this makes /healthz fail.
+    h.all_processes[ray_constants.PROCESS_TYPE_RAYLET][0].process.send_signal(
+        signal.SIGSTOP
+    )
+    wait_for_condition(lambda: requests.get(uri).status_code == 503)
+    resp = requests.get(uri)
+    assert "raylet: Local Raylet failed" in resp.text
+
+
+def test_unified_healthz_worker_gcs_down(monkeypatch, ray_start_cluster):
+    h_head = ray_start_cluster.add_node()
+    agent_port = find_free_port()
+    h_worker = ray_start_cluster.add_node(dashboard_agent_listen_port=agent_port)
+    uri = f"http://{h_worker.node_ip_address}:{agent_port}/api/healthz"
+
+    wait_for_condition(lambda: requests.get(uri).status_code == 200)
+    resp = requests.get(uri)
+    assert "gcs: success (no local gcs)" in resp.text
+
+    # Stop the head GCS server.
+    h_head.all_processes[ray_constants.PROCESS_TYPE_GCS_SERVER][0].process.kill()
+
+    # Worker health check should still succeed.
+    assert requests.get(uri).status_code == 200
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))