diff --git a/python/ray/dashboard/modules/state/state_head.py b/python/ray/dashboard/modules/state/state_head.py index cf7e8fe92cd7..187a11f91f07 100644 --- a/python/ray/dashboard/modules/state/state_head.py +++ b/python/ray/dashboard/modules/state/state_head.py @@ -30,7 +30,7 @@ from ray.dashboard.subprocesses.module import SubprocessModule from ray.dashboard.subprocesses.routes import SubprocessRouteTable as routes from ray.dashboard.subprocesses.utils import ResponseType -from ray.dashboard.utils import RateLimitedModule +from ray.dashboard.utils import HTTPStatusCode, RateLimitedModule from ray.util.state.common import ( DEFAULT_DOWNLOAD_FILENAME, DEFAULT_LOG_LIMIT, @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs): async def limit_handler_(self): return do_reply( - success=False, + status_code=HTTPStatusCode.TOO_MANY_REQUESTS, error_message=( "Max number of in-progress requests=" f"{self.max_num_call_} reached. " @@ -110,12 +110,16 @@ async def list_jobs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: try: result = await self._state_api.list_jobs(option=options_from_req(req)) return do_reply( - success=True, + status_code=HTTPStatusCode.OK, error_message="", result=asdict(result), ) except DataSourceUnavailable as e: - return do_reply(success=False, error_message=str(e), result=None) + return do_reply( + status_code=HTTPStatusCode.INTERNAL_ERROR, + error_message=str(e), + result=None, + ) @routes.get("/api/v0/nodes") @RateLimitedModule.enforce_max_concurrent_calls @@ -171,7 +175,7 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: if not node_id and not node_ip: return do_reply( - success=False, + status_code=HTTPStatusCode.BAD_REQUEST, error_message=( "Both node id and node ip are not provided. " "Please provide at least one of them." @@ -182,7 +186,7 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: node_id = await self._log_api.ip_to_node_id(node_ip) if not node_id: return do_reply( - success=False, + status_code=HTTPStatusCode.NOT_FOUND, error_message=( f"Cannot find matching node_id for a given node ip {node_ip}" ), @@ -195,12 +199,16 @@ async def list_logs(self, req: aiohttp.web.Request) -> aiohttp.web.Response: ) except DataSourceUnavailable as e: return do_reply( - success=False, + status_code=HTTPStatusCode.INTERNAL_ERROR, error_message=str(e), result=None, ) - return do_reply(success=True, error_message="", result=result) + return do_reply( + status_code=HTTPStatusCode.OK, + error_message="", + result=result, + ) @routes.get("/api/v0/logs/{media_type}", resp_type=ResponseType.STREAM) @RateLimitedModule.enforce_max_concurrent_calls @@ -330,7 +338,7 @@ async def delayed_response(self, req: aiohttp.web.Request): delay = int(req.match_info.get("delay_s", 10)) await asyncio.sleep(delay) return do_reply( - success=True, + status_code=HTTPStatusCode.OK, error_message="", result={}, partial_failure_warning=None, diff --git a/python/ray/dashboard/state_api_utils.py b/python/ray/dashboard/state_api_utils.py index 30240d9aba10..b794bf840d32 100644 --- a/python/ray/dashboard/state_api_utils.py +++ b/python/ray/dashboard/state_api_utils.py @@ -23,9 +23,11 @@ from ray.util.state.util import convert_string_to_type -def do_reply(success: bool, error_message: str, result: ListApiResponse, **kwargs): +def do_reply( + status_code: HTTPStatusCode, error_message: str, result: ListApiResponse, **kwargs +): return rest_response( - status_code=HTTPStatusCode.OK if success else HTTPStatusCode.INTERNAL_ERROR, + status_code=status_code, message=error_message, result=result, convert_google_style=False, @@ -40,14 +42,22 @@ async def handle_list_api( try: result = await list_api_fn(option=options_from_req(req)) return do_reply( - success=True, + status_code=HTTPStatusCode.OK, error_message="", result=asdict(result), ) except ValueError as e: - return do_reply(success=False, error_message=str(e), result=None) + return do_reply( + status_code=HTTPStatusCode.BAD_REQUEST, + error_message=str(e), + result=None, + ) except DataSourceUnavailable as e: - return do_reply(success=False, error_message=str(e), result=None) + return do_reply( + status_code=HTTPStatusCode.INTERNAL_ERROR, + error_message=str(e), + result=None, + ) def _get_filters_from_req( @@ -104,7 +114,7 @@ async def handle_summary_api( ): result = await summary_fn(option=summary_options_from_req(req)) return do_reply( - success=True, + status_code=HTTPStatusCode.OK, error_message="", result=asdict(result), ) diff --git a/python/ray/dashboard/utils.py b/python/ray/dashboard/utils.py index 4cb60681abeb..f146a682c89c 100644 --- a/python/ray/dashboard/utils.py +++ b/python/ray/dashboard/utils.py @@ -48,7 +48,9 @@ class HTTPStatusCode(IntEnum): OK = 200 # 4xx Client Errors + BAD_REQUEST = 400 NOT_FOUND = 404 + TOO_MANY_REQUESTS = 429 # 5xx Server Errors INTERNAL_ERROR = 500 diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 487ba294000b..b65b187b5e9d 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -5,7 +5,7 @@ import time from collections import Counter from concurrent.futures import ThreadPoolExecutor -from typing import List +from typing import List, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -860,6 +860,59 @@ async def test_api_manager_list_workers(state_api_manager): assert exc_info.value.args[0] == GCS_QUERY_FAILURE_WARNING +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("exception", "status_code"), + [ + (None, 200), + (ValueError("Invalid filter parameter"), 400), + (DataSourceUnavailable("GCS connection failed"), 500), + ], +) +async def test_handle_list_api_status_codes( + exception: Optional[Exception], status_code: int +): + """Test that handle_list_api calls do_reply with correct status codes. + + This directly tests the HTTP layer logic that maps exceptions to status codes: + - Success → HTTP 200 OK + - ValueError → HTTP 400 BAD_REQUEST + - DataSourceUnavailable → HTTP 500 INTERNAL_ERROR + """ + from unittest.mock import AsyncMock, MagicMock + + from ray.dashboard.state_api_utils import handle_list_api + from ray.util.state.common import ListApiResponse + + # 1. Mock aiohttp request with proper query interface + mock_request = MagicMock() + + def mock_get(key, default=None): + return default + + mock_request.query = MagicMock() + mock_request.query.get = mock_get + + # 2. Mock response whether success or failure. + if exception is None: + mock_backend = AsyncMock( + return_value=ListApiResponse( + result=[], + total=0, + num_after_truncation=0, + num_filtered=0, + partial_failure_warning="", + ) + ) + else: + mock_backend = AsyncMock(side_effect=exception) + + response = await handle_list_api(mock_backend, mock_request) + + # 3. Assert status_code is correct. + assert response.status == status_code + + @pytest.mark.asyncio async def test_api_manager_list_tasks(state_api_manager): data_source_client = state_api_manager.data_source_client