diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 45314de071a8..f47dfa85c178 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -6,11 +6,12 @@ ClusterNodeInfoCache, DefaultClusterNodeInfoCache, ) -from ray.serve._private.common import DeploymentID +from ray.serve._private.common import DeploymentHandleSource, DeploymentID, EndpointInfo from ray.serve._private.constants import ( RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE, RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS, RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING, + RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, ) from ray.serve._private.deployment_scheduler import ( DefaultDeploymentScheduler, @@ -131,3 +132,26 @@ def create_router( def add_grpc_address(grpc_server: gRPCServer, server_address: str): """Helper function to add a address to gRPC server.""" grpc_server.add_insecure_port(server_address) + + +def get_proxy_handle(endpoint: DeploymentID, info: EndpointInfo): + # NOTE(zcin): needs to be lazy import due to a circular dependency. + # We should not be importing from application_state in context. + from ray.serve.context import _get_global_client + + client = _get_global_client() + handle = client.get_handle(endpoint.name, endpoint.app_name, check_exists=True) + + # NOTE(zcin): It's possible that a handle is already initialized + # if a deployment with the same name and application name was + # deleted, then redeployed later. However this is not an issue since + # we initialize all handles with the same init options. + if not handle.is_initialized: + # NOTE(zcin): since the router is eagerly initialized here, the + # proxy will receive the replica set from the controller early. + handle._init( + _prefer_local_routing=RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, + _source=DeploymentHandleSource.PROXY, + ) + + return handle.options(stream=not info.app_is_cross_language) diff --git a/python/ray/serve/_private/handle_options.py b/python/ray/serve/_private/handle_options.py index 7354288b07ce..fe6e942e77ee 100644 --- a/python/ray/serve/_private/handle_options.py +++ b/python/ray/serve/_private/handle_options.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, fields import ray -from ray.serve._private.common import DeploymentHandleSource, RequestProtocol +from ray.serve._private.common import DeploymentHandleSource from ray.serve._private.utils import DEFAULT @@ -52,7 +52,6 @@ class DynamicHandleOptionsBase(ABC): method_name: str = "__call__" multiplexed_model_id: str = "" stream: bool = False - _request_protocol: str = RequestProtocol.UNDEFINED def copy_and_update(self, **kwargs) -> "DynamicHandleOptionsBase": new_kwargs = {} diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index fc4d51b9b181..614a8dc39508 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -6,8 +6,7 @@ import socket import time from abc import ABC, abstractmethod -from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple import grpc import starlette @@ -19,7 +18,6 @@ from starlette.types import Receive import ray -from ray import serve from ray._private.utils import get_or_create_event_loop from ray.actor import ActorHandle from ray.exceptions import RayActorError, RayTaskError @@ -41,7 +39,7 @@ SERVE_MULTIPLEXED_MODEL_ID, SERVE_NAMESPACE, ) -from ray.serve._private.default_impl import add_grpc_address +from ray.serve._private.default_impl import add_grpc_address, get_proxy_handle from ray.serve._private.grpc_util import DummyServicer, create_serve_grpc_server from ray.serve._private.http_util import ( MessageQueue, @@ -68,11 +66,7 @@ gRPCProxyRequest, ) from ray.serve._private.proxy_response_generator import ProxyResponseGenerator -from ray.serve._private.proxy_router import ( - EndpointRouter, - LongestPrefixRouter, - ProxyRouter, -) +from ray.serve._private.proxy_router import ProxyRouter from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( call_function_from_import_path, @@ -151,9 +145,8 @@ def __init__( node_id: NodeId, node_ip_address: str, is_head: bool, - proxy_router_class: Type[ProxyRouter], + proxy_router: ProxyRouter, request_timeout_s: Optional[float] = None, - get_handle_override: Optional[Callable] = None, ): self.request_timeout_s = request_timeout_s if self.request_timeout_s is not None and self.request_timeout_s < 0: @@ -162,14 +155,7 @@ def __init__( self._node_id = node_id self._is_head = is_head - # Used only for displaying the route table. - self.route_info: Dict[str, DeploymentID] = dict() - - self.proxy_router = proxy_router_class( - get_handle_override - or partial(serve.get_deployment_handle, _record_telemetry=False), - self.protocol, - ) + self.proxy_router = proxy_router self.request_counter = metrics.Counter( f"serve_num_{self.protocol.lower()}_requests", description=f"The number of {self.protocol} requests processed.", @@ -250,14 +236,6 @@ def _is_draining(self) -> bool: """Whether is proxy actor is in the draining status or not.""" return self._draining_start_time is not None - def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]): - self.route_info: Dict[str, DeploymentID] = dict() - for deployment_id, info in endpoints.items(): - route = info.route - self.route_info[route] = deployment_id - - self.proxy_router.update_routes(endpoints) - def is_drained(self): """Check whether the proxy actor is drained or not. @@ -590,7 +568,7 @@ async def routes_response( ) -> ResponseGenerator: yield ListApplicationsResponse( application_names=[ - endpoint.app_name for endpoint in self.route_info.values() + endpoint.app_name for endpoint in self.proxy_router.endpoints ], ).SerializeToString() @@ -784,18 +762,16 @@ def __init__( node_id: NodeId, node_ip_address: str, is_head: bool, - proxy_router_class: Type[ProxyRouter], + proxy_router: ProxyRouter, request_timeout_s: Optional[float] = None, proxy_actor: Optional[ActorHandle] = None, - get_handle_override: Optional[Callable] = None, ): super().__init__( node_id, node_ip_address, is_head, - proxy_router_class, + proxy_router, request_timeout_s=request_timeout_s, - get_handle_override=get_handle_override, ) self.self_actor_handle = proxy_actor or ray.get_runtime_context().current_actor self.asgi_receive_queues: Dict[str, MessageQueue] = dict() @@ -823,13 +799,13 @@ async def routes_response( status_code = 200 if healthy else 503 if healthy: response = dict() - for route, endpoint in self.route_info.items(): + for endpoint, info in self.proxy_router.endpoints.items(): # For 2.x deployments, return {route -> app name} if endpoint.app_name: - response[route] = endpoint.app_name + response[info.route] = endpoint.app_name # Keep compatibility with 1.x deployments. else: - response[route] = endpoint.name + response[info.route] = endpoint.name else: response = message @@ -1231,11 +1207,12 @@ def __init__( http_middlewares.extend(middlewares) is_head = node_id == get_head_node_id() + self.proxy_router = ProxyRouter(get_proxy_handle) self.http_proxy = HTTPProxy( node_id=node_id, node_ip_address=node_ip_address, is_head=is_head, - proxy_router_class=LongestPrefixRouter, + proxy_router=self.proxy_router, request_timeout_s=( request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S ), @@ -1245,7 +1222,7 @@ def __init__( node_id=node_id, node_ip_address=node_ip_address, is_head=is_head, - proxy_router_class=EndpointRouter, + proxy_router=self.proxy_router, request_timeout_s=( request_timeout_s or RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S ), @@ -1283,9 +1260,7 @@ def __init__( ) def _update_routes_in_proxies(self, endpoints: Dict[DeploymentID, EndpointInfo]): - self.http_proxy.update_routes(endpoints) - if self.grpc_proxy is not None: - self.grpc_proxy.update_routes(endpoints) + self.proxy_router.update_routes(endpoints) def _update_logging_config(self, logging_config: LoggingConfig): configure_component_logger( diff --git a/python/ray/serve/_private/proxy_router.py b/python/ray/serve/_private/proxy_router.py index ead410198173..7c40cee8b418 100644 --- a/python/ray/serve/_private/proxy_router.py +++ b/python/ray/serve/_private/proxy_router.py @@ -1,18 +1,8 @@ import logging -from abc import ABC, abstractmethod from typing import Callable, Dict, List, Optional, Tuple -from ray.serve._private.common import ( - ApplicationName, - DeploymentHandleSource, - DeploymentID, - EndpointInfo, - RequestProtocol, -) -from ray.serve._private.constants import ( - RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, - SERVE_LOGGER_NAME, -) +from ray.serve._private.common import ApplicationName, DeploymentID, EndpointInfo +from ray.serve._private.constants import SERVE_LOGGER_NAME from ray.serve.handle import DeploymentHandle logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -21,27 +11,32 @@ NO_REPLICAS_MESSAGE = "No replicas are available yet." -class ProxyRouter(ABC): +class ProxyRouter: """Router interface for the proxy to use.""" def __init__( self, get_handle: Callable[[str, str], DeploymentHandle], - protocol: RequestProtocol, ): # Function to get a handle given a name. Used to mock for testing. self._get_handle = get_handle - # Protocol to config handle - self._protocol = protocol # Contains a ServeHandle for each endpoint. self.handles: Dict[DeploymentID, DeploymentHandle] = dict() # Flipped to `True` once the route table has been updated at least once. # The proxy router is not ready for traffic until the route table is populated self._route_table_populated = False - @abstractmethod - def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]): - raise NotImplementedError + # Info used for HTTP proxy + # Routes sorted in order of decreasing length. + self.sorted_routes: List[str] = list() + # Endpoints associated with the routes. + self.route_info: Dict[str, DeploymentID] = dict() + # Map of application name to is_cross_language. + self.app_to_is_cross_language: Dict[ApplicationName, bool] = dict() + + # Info used for gRPC proxy + # Endpoints info associated with endpoints. + self.endpoints: Dict[DeploymentID, EndpointInfo] = dict() def ready_for_traffic(self, is_head: bool) -> Tuple[bool, str]: """Whether the proxy router is ready to serve traffic. @@ -72,31 +67,15 @@ def ready_for_traffic(self, is_head: bool) -> Tuple[bool, str]: return False, NO_REPLICAS_MESSAGE - -class LongestPrefixRouter(ProxyRouter): - """Router that performs longest prefix matches on incoming routes.""" - - def __init__( - self, - get_handle: Callable[[str, str], DeploymentHandle], - protocol: RequestProtocol, - ): - super().__init__(get_handle, protocol) - - # Routes sorted in order of decreasing length. - self.sorted_routes: List[str] = list() - # Endpoints associated with the routes. - self.route_info: Dict[str, DeploymentID] = dict() - # Map of application name to is_cross_language. - self.app_to_is_cross_language: Dict[ApplicationName, bool] = dict() - - def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]) -> None: + def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]): logger.info( - f"Got updated endpoints: {endpoints}.", extra={"log_to_stderr": False} + f"Got updated endpoints: {endpoints}.", extra={"log_to_stderr": True} ) if endpoints: self._route_table_populated = True + self.endpoints = endpoints + existing_handles = set(self.handles.keys()) routes = [] route_info = {} @@ -108,19 +87,7 @@ def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]) -> None: if endpoint in self.handles: existing_handles.remove(endpoint) else: - handle = self._get_handle(endpoint.name, endpoint.app_name) - # NOTE(zcin): since the router is eagerly initialized here, - # it will receive the replica set from the controller early. - if not handle.is_initialized: - handle._init( - _prefer_local_routing=RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, - _source=DeploymentHandleSource.PROXY, - ) - handle._set_request_protocol(self._protocol) - # Streaming codepath isn't supported for Java. - self.handles[endpoint] = handle.options( - stream=not info.app_is_cross_language - ) + self.handles[endpoint] = self._get_handle(endpoint, info) # Clean up any handles that are no longer used. if len(existing_handles) > 0: @@ -173,53 +140,6 @@ def match_route( return None - -class EndpointRouter(ProxyRouter): - """Router that matches endpoint to return the handle.""" - - def __init__(self, get_handle: Callable, protocol: RequestProtocol): - super().__init__(get_handle, protocol) - - # Endpoints info associated with endpoints. - self.endpoints: Dict[DeploymentID, EndpointInfo] = dict() - - def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]): - logger.info( - f"Got updated endpoints: {endpoints}.", extra={"log_to_stderr": False} - ) - if endpoints: - self._route_table_populated = True - - self.endpoints = endpoints - - existing_handles = set(self.handles.keys()) - for endpoint, info in endpoints.items(): - if endpoint in self.handles: - existing_handles.remove(endpoint) - else: - handle = self._get_handle(endpoint.name, endpoint.app_name) - # NOTE(zcin): since the router is eagerly initialized here, - # it will receive the replica set from the controller early. - if not handle.is_initialized: - handle._init( - _prefer_local_routing=RAY_SERVE_PROXY_PREFER_LOCAL_NODE_ROUTING, - _source=DeploymentHandleSource.PROXY, - ) - handle._set_request_protocol(self._protocol) - # Streaming codepath isn't supported for Java. - self.handles[endpoint] = handle.options( - stream=not info.app_is_cross_language - ) - - # Clean up any handles that are no longer used. - if len(existing_handles) > 0: - logger.info( - f"Deleting {len(existing_handles)} unused handles.", - extra={"log_to_stderr": False}, - ) - for endpoint in existing_handles: - del self.handles[endpoint] - def get_handle_for_endpoint( self, target_app_name: str ) -> Optional[Tuple[str, DeploymentHandle, bool]]: @@ -228,7 +148,7 @@ def get_handle_for_endpoint( Args: target_app_name: app_name to match against. Returns: - (route, handle, app_name, is_cross_language) for the single app if there + (route, handle, is_cross_language) for the single app if there is only one, else find the app and handle for exact match. Else return None. """ for endpoint_tag, handle in self.handles.items(): diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index e2d5366324b7..0eba1c5dc5ee 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -7,7 +7,12 @@ import ray from ray._raylet import ObjectRefGenerator -from ray.serve._private.common import DeploymentID, RequestMetadata, RequestProtocol +from ray.serve._private.common import ( + DeploymentHandleSource, + DeploymentID, + RequestMetadata, + RequestProtocol, +) from ray.serve._private.constants import SERVE_LOGGER_NAME from ray.serve._private.default_impl import ( CreateRouterCallable, @@ -43,23 +48,22 @@ def __init__( deployment_name: str, app_name: str, *, + init_options: Optional[InitHandleOptionsBase] = None, handle_options: Optional[DynamicHandleOptionsBase] = None, _router: Optional[Router] = None, _create_router: Optional[CreateRouterCallable] = None, _request_counter: Optional[metrics.Counter] = None, - _recorded_telemetry: bool = False, ): self.deployment_id = DeploymentID(name=deployment_name, app_name=app_name) + self.init_options: Optional[InitHandleOptionsBase] = init_options self.handle_options: DynamicHandleOptionsBase = ( handle_options or create_dynamic_handle_options() ) - self.init_options: Optional[InitHandleOptionsBase] = None self.handle_id = get_random_string() self.request_counter = _request_counter or self._create_request_counter( app_name, deployment_name, self.handle_id ) - self._recorded_telemetry = _recorded_telemetry self._router: Optional[Router] = _router if _create_router is None: @@ -72,23 +76,6 @@ def __init__( extra={"log_to_stderr": False}, ) - def _record_telemetry_if_needed(self): - # Record telemetry once per handle and not when used from the proxy - # (detected via request protocol). - if ( - not self._recorded_telemetry - and self.handle_options._request_protocol == RequestProtocol.UNDEFINED - ): - if self.__class__ == DeploymentHandle: - ServeUsageTag.DEPLOYMENT_HANDLE_API_USED.record("1") - - self._recorded_telemetry = True - - def _set_request_protocol(self, request_protocol: RequestProtocol): - self.handle_options = self.handle_options.copy_and_update( - _request_protocol=request_protocol - ) - def _get_or_create_router(self) -> Router: if self._router is None: self._router = self._create_router( @@ -166,6 +153,13 @@ def _init(self, **kwargs): self.init_options = create_init_handle_options(**kwargs) self._get_or_create_router() + # Record handle api telemetry when not in the proxy + if ( + self.init_options._source != DeploymentHandleSource.PROXY + and self.__class__ == DeploymentHandle + ): + ServeUsageTag.DEPLOYMENT_HANDLE_API_USED.record("1") + def _options(self, _prefer_local_routing=DEFAULT.VALUE, **kwargs): if kwargs.get("stream") is True and inside_ray_client_context(): raise RuntimeError( @@ -185,11 +179,11 @@ def _options(self, _prefer_local_routing=DEFAULT.VALUE, **kwargs): return DeploymentHandle( self.deployment_name, self.app_name, + init_options=self.init_options, handle_options=new_handle_options, _router=self._router, _create_router=self._create_router, _request_counter=self.request_counter, - _recorded_telemetry=self._recorded_telemetry, ) def _remote( @@ -198,7 +192,6 @@ def _remote( args: Tuple[Any], kwargs: Dict[str, Any], ) -> concurrent.futures.Future: - self._record_telemetry_if_needed() self.request_counter.inc( tags={ "route": request_metadata.route, @@ -729,6 +722,17 @@ def remote( remote method call. """ _request_context = ray.serve.context._serve_request_context.get() + + request_protocol = RequestProtocol.UNDEFINED + if ( + self.init_options + and self.init_options._source == DeploymentHandleSource.PROXY + ): + if _request_context.is_http_request: + request_protocol = RequestProtocol.HTTP + elif _request_context.grpc_context: + request_protocol = RequestProtocol.GRPC + request_metadata = RequestMetadata( request_id=_request_context.request_id if _request_context.request_id @@ -741,7 +745,7 @@ def remote( app_name=self.app_name, multiplexed_model_id=self.handle_options.multiplexed_model_id, is_streaming=self.handle_options.stream, - _request_protocol=self.handle_options._request_protocol, + _request_protocol=request_protocol, grpc_context=_request_context.grpc_context, ) diff --git a/python/ray/serve/tests/test_handle_1.py b/python/ray/serve/tests/test_handle_1.py index 922e67c816e7..ca738d24401b 100644 --- a/python/ray/serve/tests/test_handle_1.py +++ b/python/ray/serve/tests/test_handle_1.py @@ -7,7 +7,7 @@ import ray from ray import serve -from ray.serve._private.common import DeploymentHandleSource, RequestProtocol +from ray.serve._private.common import DeploymentHandleSource from ray.serve._private.constants import ( RAY_SERVE_FORCE_LOCAL_TESTING_MODE, SERVE_DEFAULT_APP_NAME, @@ -265,37 +265,6 @@ def echo(name: str): assert handle2._router is handle._router -def test_set_request_protocol(serve_instance): - """Test setting request protocol for a handle. - - When a handle is created, it's _request_protocol is undefined. When calling - `_set_request_protocol()`, _request_protocol is set to the specified protocol. - When chaining options, the _request_protocol on the new handle is copied over. - When calling `_set_request_protocol()` on the new handle, _request_protocol - on the new handle is changed accordingly, while _request_protocol on the - original handle remains unchanged. - """ - - @serve.deployment - def echo(name: str): - return f"Hi {name}" - - handle = serve.run(echo.bind()) - assert handle.handle_options._request_protocol == RequestProtocol.UNDEFINED - - handle._set_request_protocol(RequestProtocol.HTTP) - assert handle.handle_options._request_protocol == RequestProtocol.HTTP - - multiplexed_model_id = "fake-multiplexed_model_id" - new_handle = handle.options(multiplexed_model_id=multiplexed_model_id) - assert new_handle.handle_options.multiplexed_model_id == multiplexed_model_id - assert new_handle.handle_options._request_protocol == RequestProtocol.HTTP - - new_handle._set_request_protocol(RequestProtocol.GRPC) - assert new_handle.handle_options._request_protocol == RequestProtocol.GRPC - assert handle.handle_options._request_protocol == RequestProtocol.HTTP - - def test_init(serve_instance): @serve.deployment def f(): diff --git a/python/ray/serve/tests/unit/test_handle_options.py b/python/ray/serve/tests/unit/test_handle_options.py index aa2e22ac14ea..6acb811ca142 100644 --- a/python/ray/serve/tests/unit/test_handle_options.py +++ b/python/ray/serve/tests/unit/test_handle_options.py @@ -2,7 +2,7 @@ import pytest -from ray.serve._private.common import DeploymentHandleSource, RequestProtocol +from ray.serve._private.common import DeploymentHandleSource from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions from ray.serve._private.utils import DEFAULT @@ -12,53 +12,45 @@ def test_dynamic_handle_options(): assert default_options.method_name == "__call__" assert default_options.multiplexed_model_id == "" assert default_options.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Test setting method name. only_set_method = default_options.copy_and_update(method_name="hi") assert only_set_method.method_name == "hi" assert only_set_method.multiplexed_model_id == "" assert only_set_method.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Existing options should be unmodified. assert default_options.method_name == "__call__" assert default_options.multiplexed_model_id == "" assert default_options.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Test setting model ID. only_set_model_id = default_options.copy_and_update(multiplexed_model_id="hi") assert only_set_model_id.method_name == "__call__" assert only_set_model_id.multiplexed_model_id == "hi" assert only_set_model_id.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Existing options should be unmodified. assert default_options.method_name == "__call__" assert default_options.multiplexed_model_id == "" assert default_options.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Test setting stream. only_set_stream = default_options.copy_and_update(stream=True) assert only_set_stream.method_name == "__call__" assert only_set_stream.multiplexed_model_id == "" assert only_set_stream.stream is True - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Existing options should be unmodified. assert default_options.method_name == "__call__" assert default_options.multiplexed_model_id == "" assert default_options.stream is False - assert default_options._request_protocol == RequestProtocol.UNDEFINED # Test setting multiple. set_multiple = default_options.copy_and_update(method_name="hi", stream=True) assert set_multiple.method_name == "hi" assert set_multiple.multiplexed_model_id == "" assert set_multiple.stream is True - assert default_options._request_protocol == RequestProtocol.UNDEFINED def test_init_handle_options(): diff --git a/python/ray/serve/tests/unit/test_proxy.py b/python/ray/serve/tests/unit/test_proxy.py index 4bf4f7a8e59a..8cef61671b4d 100644 --- a/python/ray/serve/tests/unit/test_proxy.py +++ b/python/ray/serve/tests/unit/test_proxy.py @@ -20,7 +20,6 @@ from ray.serve._private.proxy_router import ( NO_REPLICAS_MESSAGE, NO_ROUTES_MESSAGE, - EndpointRouter, ProxyRouter, ) from ray.serve._private.test_utils import FakeGrpcContext, MockDeploymentHandle @@ -108,7 +107,7 @@ def __init__(self, *args, **kwargs): self._ready_for_traffic = False def update_routes(self, endpoints: Dict[DeploymentID, EndpointInfo]): - pass + self.endpoints = endpoints def get_handle_for_endpoint(self, *args, **kwargs): if ( @@ -254,14 +253,14 @@ def create_grpc_proxy(self, is_head: bool = False): node_id=node_id, node_ip_address=node_ip_address, is_head=is_head, - proxy_router_class=FakeProxyRouter, + proxy_router=FakeProxyRouter(), ) @pytest.mark.asyncio async def test_not_found_response(self): """Test gRPCProxy returns the correct not found response.""" grpc_proxy = self.create_grpc_proxy() - grpc_proxy.update_routes({}) + grpc_proxy.proxy_router.update_routes({}) # Application name not provided. status, _ = await _consume_proxy_generator( @@ -298,7 +297,7 @@ async def test_routes_response( - the proxy is draining. """ grpc_proxy = self.create_grpc_proxy() - grpc_proxy.update_routes( + grpc_proxy.proxy_router.update_routes( {DeploymentID(name="deployment", app_name="app"): EndpointInfo("/route")}, ) if is_draining: @@ -447,7 +446,7 @@ def create_http_proxy(self, is_head: bool = False): node_id=node_id, node_ip_address=node_ip_address, is_head=is_head, - proxy_router_class=FakeProxyRouter, + proxy_router=FakeProxyRouter(), proxy_actor=FakeActorHandle(), ) @@ -455,7 +454,7 @@ def create_http_proxy(self, is_head: bool = False): async def test_not_found_response(self): """Test the response returned when a route is not found.""" http_proxy = self.create_http_proxy() - http_proxy.update_routes({}) + http_proxy.proxy_router.update_routes({}) status, messages = await _consume_proxy_generator( http_proxy.proxy_request( @@ -486,7 +485,7 @@ async def test_routes_response( - the proxy is draining. """ http_proxy = self.create_http_proxy() - http_proxy.update_routes( + http_proxy.proxy_router.update_routes( {DeploymentID(name="deployment", app_name="app"): EndpointInfo("/route")}, ) if is_draining: @@ -705,14 +704,16 @@ async def test_websocket_call(self, disconnect: str): async def test_head_http_unhealthy_until_route_table_updated(): """Health endpoint should error until `update_routes` has been called.""" + def get_handle_override(endpoint, info): + return MockDeploymentHandle(endpoint.name, endpoint.app_name) + http_proxy = HTTPProxy( node_id="fake-node-id", node_ip_address="fake-node-ip-address", # proxy is on head node is_head=True, - proxy_router_class=EndpointRouter, + proxy_router=ProxyRouter(get_handle_override), proxy_actor=FakeActorHandle(), - get_handle_override=lambda d, a: MockDeploymentHandle(d, a), ) proxy_request = FakeProxyRequest( request_type="http", @@ -730,7 +731,9 @@ async def test_head_http_unhealthy_until_route_table_updated(): assert messages[1]["body"].decode("utf-8") == NO_ROUTES_MESSAGE # Update route table, response should no longer error - http_proxy.update_routes({DeploymentID("a", "b"): EndpointInfo("/route")}) + http_proxy.proxy_router.update_routes( + {DeploymentID("a", "b"): EndpointInfo("/route")} + ) status, messages = await _consume_proxy_generator( http_proxy.proxy_request(proxy_request) @@ -748,15 +751,13 @@ async def test_worker_http_unhealthy_until_replicas_populated(): """Health endpoint should error until handle's running replicas is populated.""" handle = MockDeploymentHandle("a", "b") - http_proxy = HTTPProxy( node_id="fake-node-id", node_ip_address="fake-node-ip-address", # proxy is on worker node is_head=False, - proxy_router_class=EndpointRouter, + proxy_router=ProxyRouter(lambda *args: handle), proxy_actor=FakeActorHandle(), - get_handle_override=lambda *args: handle, ) proxy_request = FakeProxyRequest( request_type="http", @@ -775,7 +776,9 @@ async def test_worker_http_unhealthy_until_replicas_populated(): # Update route table, response should still error because running # replicas is not yet populated. - http_proxy.update_routes({DeploymentID("a", "b"): EndpointInfo("/route")}) + http_proxy.proxy_router.update_routes( + {DeploymentID("a", "b"): EndpointInfo("/route")} + ) status, messages = await _consume_proxy_generator( http_proxy.proxy_request(proxy_request) diff --git a/python/ray/serve/tests/unit/test_proxy_router.py b/python/ray/serve/tests/unit/test_proxy_router.py index 42d0e12c0f9d..382024d11182 100644 --- a/python/ray/serve/tests/unit/test_proxy_router.py +++ b/python/ray/serve/tests/unit/test_proxy_router.py @@ -1,46 +1,24 @@ -from typing import Callable - import pytest -from ray.serve._private.common import DeploymentID, EndpointInfo, RequestProtocol +from ray.serve._private.common import DeploymentID, EndpointInfo from ray.serve._private.proxy_router import ( NO_REPLICAS_MESSAGE, NO_ROUTES_MESSAGE, - EndpointRouter, - LongestPrefixRouter, ProxyRouter, ) from ray.serve._private.test_utils import MockDeploymentHandle -def get_handle_function(router: ProxyRouter) -> Callable: - if isinstance(router, LongestPrefixRouter): - return router.match_route - else: - return router.get_handle_for_endpoint - - -@pytest.fixture -def mock_longest_prefix_router() -> LongestPrefixRouter: - def mock_get_handle(deployment_name, app_name, *args, **kwargs): - return MockDeploymentHandle(deployment_name, app_name) - - yield LongestPrefixRouter(mock_get_handle, RequestProtocol.HTTP) - - @pytest.fixture -def mock_endpoint_router() -> EndpointRouter: - def mock_get_handle(deployment_name, app_name, *args, **kwargs): - return MockDeploymentHandle(deployment_name, app_name) +def mock_router(): + def mock_get_handle(endpoint, info): + return MockDeploymentHandle(endpoint.name, endpoint.app_name) - yield EndpointRouter(mock_get_handle, RequestProtocol.GRPC) + yield ProxyRouter(mock_get_handle) -@pytest.mark.parametrize( - "mocked_router", ["mock_longest_prefix_router", "mock_endpoint_router"] -) -def test_no_match(mocked_router, request): - router = request.getfixturevalue(mocked_router) +def test_no_match(mock_router: ProxyRouter): + router = mock_router router.update_routes( { DeploymentID(name="endpoint", app_name="default"): EndpointInfo( @@ -49,20 +27,14 @@ def test_no_match(mocked_router, request): DeploymentID(name="endpoint2", app_name="app2"): EndpointInfo( route="/hello2" ), - } + }, ) - assert get_handle_function(router)("/nonexistent") is None + assert router.match_route("/nonexistent") is None + assert router.get_handle_for_endpoint("/nonexistent") is None -@pytest.mark.parametrize( - "mocked_router, target_route", - [ - ("mock_longest_prefix_router", "/endpoint"), - ("mock_endpoint_router", "default"), - ], -) -def test_default_route(mocked_router, target_route, request): - router = request.getfixturevalue(mocked_router) +def test_default_route(mock_router: ProxyRouter): + router = mock_router router.update_routes( { DeploymentID(name="endpoint", app_name="default"): EndpointInfo( @@ -71,28 +43,37 @@ def test_default_route(mocked_router, target_route, request): DeploymentID(name="endpoint2", app_name="app2"): EndpointInfo( route="/endpoint2" ), - } + }, ) - assert get_handle_function(router)("/nonexistent") is None + # Route based matching + assert router.match_route("/nonexistent") is None - route, handle, app_is_cross_language = get_handle_function(router)(target_route) - assert all( - [ - route == "/endpoint", - handle == ("endpoint", "default"), - not app_is_cross_language, - ] - ) + route, handle, app_is_cross_language = router.match_route("/endpoint") + assert route == "/endpoint" + assert handle == ("endpoint", "default") + assert not app_is_cross_language + + # Endpoint based matching + assert router.get_handle_for_endpoint("/nonexistent") is None + + route, handle, app_is_cross_language = router.get_handle_for_endpoint("default") + assert route == "/endpoint" + assert handle == ("endpoint", "default") + assert not app_is_cross_language -def test_trailing_slash(mock_longest_prefix_router): - router = mock_longest_prefix_router +def test_trailing_slash(mock_router: ProxyRouter): + router = mock_router router.update_routes( - {DeploymentID(name="endpoint", app_name="default"): EndpointInfo(route="/test")} + { + DeploymentID(name="endpoint", app_name="default"): EndpointInfo( + route="/test" + ) + }, ) - route, handle, _ = get_handle_function(router)("/test/") + route, handle, _ = router.match_route("/test/") assert route == "/test" and handle == ("endpoint", "default") router.update_routes( @@ -100,14 +81,14 @@ def test_trailing_slash(mock_longest_prefix_router): DeploymentID(name="endpoint", app_name="default"): EndpointInfo( route="/test/" ) - } + }, ) - assert get_handle_function(router)("/test") is None + assert router.match_route("/test") is None -def test_prefix_match(mock_longest_prefix_router): - router = mock_longest_prefix_router +def test_prefix_match(mock_router): + router = mock_router router.update_routes( { DeploymentID(name="endpoint1", app_name="default"): EndpointInfo( @@ -117,54 +98,44 @@ def test_prefix_match(mock_longest_prefix_router): route="/test" ), DeploymentID(name="endpoint3", app_name="default"): EndpointInfo(route="/"), - } + }, ) - route, handle, _ = get_handle_function(router)("/test/test2/subpath") + route, handle, _ = router.match_route("/test/test2/subpath") assert route == "/test/test2" and handle == ("endpoint1", "default") - route, handle, _ = get_handle_function(router)("/test/test2/") + route, handle, _ = router.match_route("/test/test2/") assert route == "/test/test2" and handle == ("endpoint1", "default") - route, handle, _ = get_handle_function(router)("/test/test2") + route, handle, _ = router.match_route("/test/test2") assert route == "/test/test2" and handle == ("endpoint1", "default") - route, handle, _ = get_handle_function(router)("/test/subpath") + route, handle, _ = router.match_route("/test/subpath") assert route == "/test" and handle == ("endpoint2", "default") - route, handle, _ = get_handle_function(router)("/test/") + route, handle, _ = router.match_route("/test/") assert route == "/test" and handle == ("endpoint2", "default") - route, handle, _ = get_handle_function(router)("/test") + route, handle, _ = router.match_route("/test") assert route == "/test" and handle == ("endpoint2", "default") - route, handle, _ = get_handle_function(router)("/test2") + route, handle, _ = router.match_route("/test2") assert route == "/" and handle == ("endpoint3", "default") - route, handle, _ = get_handle_function(router)("/") + route, handle, _ = router.match_route("/") assert route == "/" and handle == ("endpoint3", "default") -@pytest.mark.parametrize( - "mocked_router, target_route1, target_route2", - [ - ("mock_longest_prefix_router", "/endpoint", "/endpoint2"), - ("mock_endpoint_router", "app1_endpoint", "app2"), - ], -) -def test_update_routes(mocked_router, target_route1, target_route2, request): - router = request.getfixturevalue(mocked_router) +def test_update_routes(mock_router): + router = mock_router router.update_routes( - { - DeploymentID(name="endpoint", app_name="app1"): EndpointInfo( - route="/endpoint" - ) - } + {DeploymentID("endpoint", "app1"): EndpointInfo(route="/endpoint")}, ) - route, handle, app_is_cross_language = get_handle_function(router)(target_route1) - assert all( - [ - route == "/endpoint", - handle == ("endpoint", "app1"), - not app_is_cross_language, - ] - ) + route, handle, app_is_cross_language = router.match_route("/endpoint") + assert route == "/endpoint" + assert handle == ("endpoint", "app1") + assert not app_is_cross_language + + route, handle, app_is_cross_language = router.get_handle_for_endpoint("app1") + assert route == "/endpoint" + assert handle == ("endpoint", "app1") + assert not app_is_cross_language router.update_routes( { @@ -176,33 +147,35 @@ def test_update_routes(mocked_router, target_route1, target_route2, request): route="/endpoint3", app_is_cross_language=True, ), - } + }, ) - assert get_handle_function(router)(target_route1) is None + assert router.match_route("/endpoint") is None + assert router.match_route("app1") is None - route, handle, app_is_cross_language = get_handle_function(router)(target_route2) - assert all( - [route == "/endpoint2", handle == ("endpoint2", "app2"), app_is_cross_language] - ) + route, handle, app_is_cross_language = router.match_route("/endpoint2") + assert route == "/endpoint2" + assert handle == ("endpoint2", "app2") + assert app_is_cross_language + + route, handle, app_is_cross_language = router.get_handle_for_endpoint("app2") + assert route == "/endpoint2" + assert handle == ("endpoint2", "app2") + assert app_is_cross_language class TestReadyForTraffic: @pytest.mark.parametrize("is_head", [False, True]) - def test_route_table_not_populated( - self, mock_endpoint_router: EndpointRouter, is_head: bool - ): + def test_route_table_not_populated(self, mock_router, is_head: bool): """Proxy router should NOT be ready for traffic if: - it has not received route table from controller """ - ready_for_traffic, msg = mock_endpoint_router.ready_for_traffic(is_head=is_head) + ready_for_traffic, msg = mock_router.ready_for_traffic(is_head=is_head) assert not ready_for_traffic assert msg == NO_ROUTES_MESSAGE - def test_head_route_table_populated_no_replicas( - self, mock_endpoint_router: EndpointRouter - ): + def test_head_route_table_populated_no_replicas(self, mock_router): """Proxy router should be ready for traffic if: - it has received route table from controller - it hasn't received any replicas yet @@ -210,16 +183,14 @@ def test_head_route_table_populated_no_replicas( """ d_id = DeploymentID(name="A", app_name="B") - mock_endpoint_router.update_routes({d_id: EndpointInfo(route="/")}) - mock_endpoint_router.handles[d_id].set_running_replicas_populated(False) + mock_router.update_routes({d_id: EndpointInfo(route="/")}) + mock_router.handles[d_id].set_running_replicas_populated(False) - ready_for_traffic, msg = mock_endpoint_router.ready_for_traffic(is_head=True) + ready_for_traffic, msg = mock_router.ready_for_traffic(is_head=True) assert ready_for_traffic assert not msg - def test_worker_route_table_populated_no_replicas( - self, mock_endpoint_router: EndpointRouter - ): + def test_worker_route_table_populated_no_replicas(self, mock_router): """Proxy router should NOT be ready for traffic if: - it has received route table from controller - it hasn't received any replicas yet @@ -227,27 +198,25 @@ def test_worker_route_table_populated_no_replicas( """ d_id = DeploymentID(name="A", app_name="B") - mock_endpoint_router.update_routes({d_id: EndpointInfo(route="/")}) - mock_endpoint_router.handles[d_id].set_running_replicas_populated(False) + mock_router.update_routes({d_id: EndpointInfo(route="/")}) + mock_router.handles[d_id].set_running_replicas_populated(False) - ready_for_traffic, msg = mock_endpoint_router.ready_for_traffic(is_head=False) + ready_for_traffic, msg = mock_router.ready_for_traffic(is_head=False) assert not ready_for_traffic assert msg == NO_REPLICAS_MESSAGE @pytest.mark.parametrize("is_head", [False, True]) - def test_route_table_populated_with_replicas( - self, mock_endpoint_router: EndpointRouter, is_head: bool - ): + def test_route_table_populated_with_replicas(self, mock_router, is_head: bool): """Proxy router should be ready for traffic if: - it has received route table from controller - it has received replicas from controller """ d_id = DeploymentID(name="A", app_name="B") - mock_endpoint_router.update_routes({d_id: EndpointInfo(route="/")}) - mock_endpoint_router.handles[d_id].set_running_replicas_populated(True) + mock_router.update_routes({d_id: EndpointInfo(route="/")}) + mock_router.handles[d_id].set_running_replicas_populated(True) - ready_for_traffic, msg = mock_endpoint_router.ready_for_traffic(is_head=is_head) + ready_for_traffic, msg = mock_router.ready_for_traffic(is_head=is_head) assert ready_for_traffic assert not msg