diff --git a/python/ray/_common/utils.py b/python/ray/_common/utils.py index 28a05a356549..32842113479c 100644 --- a/python/ray/_common/utils.py +++ b/python/ray/_common/utils.py @@ -9,27 +9,26 @@ import sys import tempfile from inspect import signature -from typing import Any, Coroutine, Dict, Optional +from types import ModuleType +from typing import Any, Coroutine, Dict, Optional, Tuple import psutil -def import_attr(full_path: str, *, reload_module: bool = False): - """Given a full import path to a module attr, return the imported attr. +def import_module_and_attr( + full_path: str, *, reload_module: bool = False +) -> Tuple[ModuleType, Any]: + """Given a full import path to a module attr, return the imported module and attr. If `reload_module` is set, the module will be reloaded using `importlib.reload`. - For example, the following are equivalent: - MyClass = import_attr("module.submodule:MyClass") - MyClass = import_attr("module.submodule.MyClass") - from module.submodule import MyClass + Args: + full_path: The full import path to the module and attr. + reload_module: Whether to reload the module. Returns: - Imported attr + A tuple of the imported module and attr. """ - if full_path is None: - raise TypeError("import path cannot be None") - if ":" in full_path: if full_path.count(":") > 1: raise ValueError( @@ -41,11 +40,26 @@ def import_attr(full_path: str, *, reload_module: bool = False): last_period_idx = full_path.rfind(".") module_name = full_path[:last_period_idx] attr_name = full_path[last_period_idx + 1 :] - module = importlib.import_module(module_name) if reload_module: importlib.reload(module) - return getattr(module, attr_name) + return module, getattr(module, attr_name) + + +def import_attr(full_path: str, *, reload_module: bool = False) -> Any: + """Given a full import path to a module attr, return the imported attr. + + If `reload_module` is set, the module will be reloaded using `importlib.reload`. 
+ + For example, the following are equivalent: + MyClass = import_attr("module.submodule:MyClass") + MyClass = import_attr("module.submodule.MyClass") + from module.submodule import MyClass + + Returns: + Imported attr + """ + return import_module_and_attr(full_path, reload_module=reload_module)[1] def get_or_create_event_loop() -> asyncio.AbstractEventLoop: diff --git a/python/ray/serve/_private/application_state.py b/python/ray/serve/_private/application_state.py index 4f95d39a8603..db1bca725412 100644 --- a/python/ray/serve/_private/application_state.py +++ b/python/ray/serve/_private/application_state.py @@ -11,7 +11,7 @@ import ray from ray import cloudpickle -from ray._common.utils import import_attr +from ray._common.utils import import_attr, import_module_and_attr from ray.exceptions import RuntimeEnvSetupError from ray.serve._private.autoscaling_state import AutoscalingStateManager from ray.serve._private.build_app import BuiltApplication, build_app @@ -24,7 +24,12 @@ TargetCapacityDirection, ) from ray.serve._private.config import DeploymentConfig -from ray.serve._private.constants import RAY_SERVE_ENABLE_TASK_EVENTS, SERVE_LOGGER_NAME +from ray.serve._private.constants import ( + DEFAULT_AUTOSCALING_POLICY_NAME, + DEFAULT_REQUEST_ROUTER_PATH, + RAY_SERVE_ENABLE_TASK_EVENTS, + SERVE_LOGGER_NAME, +) from ray.serve._private.deploy_utils import ( deploy_args_to_deployment_info, get_app_code_version, @@ -43,7 +48,7 @@ validate_route_prefix, ) from ray.serve.api import ASGIAppReplicaWrapper -from ray.serve.config import AutoscalingConfig, AutoscalingPolicy +from ray.serve.config import AutoscalingConfig, AutoscalingPolicy, RequestRouterConfig from ray.serve.exceptions import RayServeException from ray.serve.generated.serve_pb2 import ( ApplicationStatus as ApplicationStatusProto, @@ -205,6 +210,7 @@ class ApplicationTargetState: target_capacity_direction: the scale direction to use when running the Serve autoscaler. deleting: whether the application is being deleted. + serialized_application_autoscaling_policy_def: Optional[bytes] """ deployment_infos: Optional[Dict[str, DeploymentInfo]] @@ -214,6 +220,7 @@ class ApplicationTargetState: target_capacity_direction: Optional[TargetCapacityDirection] deleting: bool api_type: APIType + serialized_application_autoscaling_policy_def: Optional[bytes] class ApplicationState: @@ -260,6 +267,7 @@ def __init__( target_capacity_direction=None, deleting=False, api_type=APIType.UNKNOWN, + serialized_application_autoscaling_policy_def=None, ) self._logging_config = logging_config @@ -339,7 +347,10 @@ def recover_target_state_from_checkpoint( ): self._autoscaling_state_manager.register_application( self._name, - AutoscalingPolicy(**checkpoint_data.config.autoscaling_policy), + AutoscalingPolicy( + _serialized_policy_def=checkpoint_data.serialized_application_autoscaling_policy_def, + **checkpoint_data.config.autoscaling_policy, + ), ) def _set_target_state( @@ -352,6 +363,7 @@ def _set_target_state( target_capacity: Optional[float] = None, target_capacity_direction: Optional[TargetCapacityDirection] = None, deleting: bool = False, + serialized_application_autoscaling_policy_def: Optional[bytes] = None, ): """Set application target state. 
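For reference, a minimal sketch of how the new helper relates to the existing import_attr; the "my_pkg.routers:MyRouter" path is hypothetical and only illustrates the call shapes introduced above:

    from ray._common.utils import import_attr, import_module_and_attr

    # Resolves either the "pkg.module:Attr" or the "pkg.module.Attr" form.
    MyRouter = import_attr("my_pkg.routers:MyRouter")

    # Same resolution, but the containing module is returned as well, so callers
    # can register it for pickle-by-value before serializing the attribute.
    module, MyRouter = import_module_and_attr("my_pkg.routers:MyRouter")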
@@ -382,6 +394,7 @@ def _set_target_state( target_capacity_direction, deleting, api_type=api_type, + serialized_application_autoscaling_policy_def=serialized_application_autoscaling_policy_def, ) self._target_state = target_state @@ -637,6 +650,28 @@ def apply_app_config( ) or self._target_state.config.runtime_env.get("image_uri"): ServeUsageTag.APP_CONTAINER_RUNTIME_ENV_USED.record("1") + if isinstance(config.autoscaling_policy, dict): + application_autoscaling_policy_function = config.autoscaling_policy.get( + "policy_function" + ) + else: + application_autoscaling_policy_function = None + + deployment_to_autoscaling_policy_function = { + deployment.name: deployment.autoscaling_config.get("policy", {}).get( + "policy_function", DEFAULT_AUTOSCALING_POLICY_NAME + ) + for deployment in config.deployments + if isinstance(deployment.autoscaling_config, dict) + } + deployment_to_request_router_cls = { + deployment.name: deployment.request_router_config.get( + "request_router_class", DEFAULT_REQUEST_ROUTER_PATH + ) + for deployment in config.deployments + if isinstance(deployment.request_router_config, dict) + } + # Kick off new build app task logger.info(f"Importing and building app '{self._name}'.") build_app_obj_ref = build_serve_application.options( @@ -648,6 +683,9 @@ def apply_app_config( config.name, config.args, self._logging_config, + application_autoscaling_policy_function, + deployment_to_autoscaling_policy_function, + deployment_to_request_router_cls, ) self._build_app_task_info = BuildAppTaskInfo( obj_ref=build_app_obj_ref, @@ -719,10 +757,15 @@ def _determine_app_status(self) -> Tuple[ApplicationStatus, str]: else: return ApplicationStatus.RUNNING, "" - def _reconcile_build_app_task(self) -> Tuple[Optional[Dict], BuildAppStatus, str]: + def _reconcile_build_app_task( + self, + ) -> Tuple[Optional[bytes], Optional[Dict], BuildAppStatus, str]: """If necessary, reconcile the in-progress build task. Returns: + Serialized application autoscaling policy def (bytes): + The serialized application autoscaling policy def returned from the build app task + if it was built successfully, otherwise None. Deploy arguments (Dict[str, DeploymentInfo]): The deploy arguments returned from the build app task and their code version. 
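As context for the new arguments threaded into build_serve_application, a rough sketch of the mappings that apply_app_config derives from the app config; deployment and module names are hypothetical and the schema fields are assumed to arrive as plain dicts:

    # Given per-deployment overrides like these in the Serve config...
    #   Api:    autoscaling_config.policy.policy_function = "my_pkg.policies.scale_up"
    #           request_router_config.request_router_class = "my_pkg.routers.MyRouter"
    #   Worker: no dict-valued autoscaling_config / request_router_config
    #
    # ...the mappings passed to the build task would look like:
    deployment_to_autoscaling_policy_function = {"Api": "my_pkg.policies.scale_up"}
    deployment_to_request_router_cls = {"Api": "my_pkg.routers.MyRouter"}
    # Deployments without dict-valued configs are omitted entirely; a dict-valued
    # config that leaves the policy/class unset falls back to
    # DEFAULT_AUTOSCALING_POLICY_NAME / DEFAULT_REQUEST_ROUTER_PATH.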
@@ -735,19 +778,22 @@ def _reconcile_build_app_task(self) -> Tuple[Optional[Dict], BuildAppStatus, str Non-empty string if status is DEPLOY_FAILED or UNHEALTHY """ if self._build_app_task_info is None or self._build_app_task_info.finished: - return None, BuildAppStatus.NO_TASK_IN_PROGRESS, "" + return None, None, BuildAppStatus.NO_TASK_IN_PROGRESS, "" if not check_obj_ref_ready_nowait(self._build_app_task_info.obj_ref): - return None, BuildAppStatus.IN_PROGRESS, "" + return None, None, BuildAppStatus.IN_PROGRESS, "" # Retrieve build app task result self._build_app_task_info.finished = True try: - args, err = ray.get(self._build_app_task_info.obj_ref) + serialized_application_autoscaling_policy_def, args, err = ray.get( + self._build_app_task_info.obj_ref + ) if err is None: logger.info(f"Imported and built app '{self._name}' successfully.") else: return ( + None, None, BuildAppStatus.FAILED, f"Deploying app '{self._name}' failed with exception:\n{err}", @@ -757,13 +803,13 @@ def _reconcile_build_app_task(self) -> Tuple[Optional[Dict], BuildAppStatus, str f"Runtime env setup for app '{self._name}' failed:\n" + traceback.format_exc() ) - return None, BuildAppStatus.FAILED, error_msg + return None, None, BuildAppStatus.FAILED, error_msg except Exception: error_msg = ( f"Unexpected error occurred while deploying application " f"'{self._name}': \n{traceback.format_exc()}" ) - return None, BuildAppStatus.FAILED, error_msg + return None, None, BuildAppStatus.FAILED, error_msg # Convert serialized deployment args (returned by build app task) # to deployment infos and apply option overrides from config @@ -774,19 +820,37 @@ def _reconcile_build_app_task(self) -> Tuple[Optional[Dict], BuildAppStatus, str ) for params in args } + deployment_to_serialized_autoscaling_policy_def = { + params["deployment_name"]: params["serialized_autoscaling_policy_def"] + for params in args + if params["serialized_autoscaling_policy_def"] is not None + } + deployment_to_serialized_request_router_cls = { + params["deployment_name"]: params["serialized_request_router_cls"] + for params in args + if params["serialized_request_router_cls"] is not None + } overrided_infos = override_deployment_info( - deployment_infos, self._build_app_task_info.config + deployment_infos, + self._build_app_task_info.config, + deployment_to_serialized_autoscaling_policy_def, + deployment_to_serialized_request_router_cls, ) self._route_prefix = self._check_routes(overrided_infos) - return overrided_infos, BuildAppStatus.SUCCEEDED, "" + return ( + serialized_application_autoscaling_policy_def, + overrided_infos, + BuildAppStatus.SUCCEEDED, + "", + ) except (TypeError, ValueError, RayServeException): - return None, BuildAppStatus.FAILED, traceback.format_exc() + return None, None, BuildAppStatus.FAILED, traceback.format_exc() except Exception: error_msg = ( f"Unexpected error occurred while applying config for application " f"'{self._name}': \n{traceback.format_exc()}" ) - return None, BuildAppStatus.FAILED, error_msg + return None, None, BuildAppStatus.FAILED, error_msg def _check_routes( self, deployment_infos: Dict[str, DeploymentInfo] @@ -877,7 +941,12 @@ def update(self) -> Tuple[bool, bool]: # If the application is being deleted, ignore any build task results to # avoid flipping the state back to DEPLOYING/RUNNING. 
if not self._target_state.deleting: - infos, task_status, msg = self._reconcile_build_app_task() + ( + serialized_application_autoscaling_policy_def, + infos, + task_status, + msg, + ) = self._reconcile_build_app_task() if task_status == BuildAppStatus.SUCCEEDED: target_state_changed = True self._set_target_state( @@ -889,6 +958,7 @@ def update(self) -> Tuple[bool, bool]: target_capacity_direction=( self._build_app_task_info.target_capacity_direction ), + serialized_application_autoscaling_policy_def=serialized_application_autoscaling_policy_def, ) # Handling the case where the user turns off/turns on app-level autoscaling policy, # between app deployment. @@ -899,7 +969,8 @@ def update(self) -> Tuple[bool, bool]: self._autoscaling_state_manager.register_application( self._name, AutoscalingPolicy( - **self._target_state.config.autoscaling_policy + _serialized_policy_def=serialized_application_autoscaling_policy_def, + **self._target_state.config.autoscaling_policy, ), ) else: @@ -1269,7 +1340,10 @@ def build_serve_application( name: str, args: Dict, logging_config: LoggingConfig, -) -> Tuple[Optional[List[Dict]], Optional[str]]: + application_autoscaling_policy_function: Optional[str], + deployment_to_autoscaling_policy_function: Dict[str, str], + deployment_to_request_router_cls: Dict[str, str], +) -> Tuple[Optional[bytes], Optional[List[Dict]], Optional[str]]: """Import and build a Serve application. Args: @@ -1280,7 +1354,13 @@ def build_serve_application( without removing existing applications. args: Arguments to be passed to the application builder. logging_config: the logging config for the build app task. + application_autoscaling_policy_function: the application autoscaling policy function name + deployment_to_autoscaling_policy_function: a dictionary mapping deployment names to autoscaling policy function names + deployment_to_request_router_cls: a dictionary mapping deployment names to request router class names + Returns: + Serialized application autoscaling policy def: a serialized autoscaling + policy def for the application if it was built successfully, otherwise None. Deploy arguments: a list of deployment arguments if application was built successfully, otherwise None. Error message: a string if an error was raised, otherwise None. 
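Inside the build task, user-supplied policies and routers are serialized by value so that the process deserializing them does not need the defining module on its import path. A minimal sketch of that pattern with a hypothetical module path (the _get_serialized_def helper in the next hunk does the same, without the try/finally):

    from ray import cloudpickle
    from ray._common.utils import import_module_and_attr

    module, policy_fn = import_module_and_attr("my_pkg.policies.scale_up")
    cloudpickle.register_pickle_by_value(module)
    try:
        # The dump embeds the function definition itself, not just its import path.
        serialized_policy_def = cloudpickle.dumps(policy_fn)
    finally:
        # Registration is process-global, so it is undone right after the dump.
        cloudpickle.unregister_pickle_by_value(module)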
@@ -1309,12 +1389,35 @@ def build_serve_application( default_runtime_env=ray.get_runtime_context().runtime_env, ) num_ingress_deployments = 0 + + def _get_serialized_def(attr_path: str) -> bytes: + module, attr = import_module_and_attr(attr_path) + cloudpickle.register_pickle_by_value(module) + serialized = cloudpickle.dumps(attr) + cloudpickle.unregister_pickle_by_value(module) + return serialized + + application_serialized_autoscaling_policy_def = None + if application_autoscaling_policy_function is not None: + application_serialized_autoscaling_policy_def = _get_serialized_def( + application_autoscaling_policy_function + ) for deployment in built_app.deployments: if inspect.isclass(deployment.func_or_class) and issubclass( deployment.func_or_class, ASGIAppReplicaWrapper ): num_ingress_deployments += 1 is_ingress = deployment.name == built_app.ingress_deployment_name + deployment_to_serialized_autoscaling_policy_def = None + deployment_to_serialized_request_router_cls = None + if deployment.name in deployment_to_autoscaling_policy_function: + deployment_to_serialized_autoscaling_policy_def = _get_serialized_def( + deployment_to_autoscaling_policy_function[deployment.name] + ) + if deployment.name in deployment_to_request_router_cls: + deployment_to_serialized_request_router_cls = _get_serialized_def( + deployment_to_request_router_cls[deployment.name] + ) deploy_args_list.append( get_deploy_args( name=deployment._name, @@ -1323,15 +1426,21 @@ def build_serve_application( deployment_config=deployment._deployment_config, version=code_version, route_prefix="/" if is_ingress else None, + serialized_autoscaling_policy_def=deployment_to_serialized_autoscaling_policy_def, + serialized_request_router_cls=deployment_to_serialized_request_router_cls, ) ) if num_ingress_deployments > 1: - return None, ( - f'Found multiple FastAPI deployments in application "{built_app.name}". ' - "Please only include one deployment with @serve.ingress " - "in your application to avoid this issue." + return ( + None, + None, + ( + f'Found multiple FastAPI deployments in application "{built_app.name}". ' + "Please only include one deployment with @serve.ingress " + "in your application to avoid this issue." + ), ) - return deploy_args_list, None + return application_serialized_autoscaling_policy_def, deploy_args_list, None except KeyboardInterrupt: # Error is raised when this task is canceled with ray.cancel(), which # happens when deploy_apps() is called. @@ -1339,17 +1448,19 @@ def build_serve_application( "Existing config deployment request terminated because of keyboard " "interrupt." ) - return None, None + return None, None, None except Exception: logger.error( f"Exception importing application '{name}'.\n{traceback.format_exc()}" ) - return None, traceback.format_exc() + return None, None, traceback.format_exc() def override_deployment_info( deployment_infos: Dict[str, DeploymentInfo], override_config: Optional[ServeApplicationSchema], + deployment_to_serialized_autoscaling_policy_def: Optional[Dict[str, bytes]] = None, + deployment_to_serialized_request_router_cls: Optional[Dict[str, bytes]] = None, ) -> Dict[str, DeploymentInfo]: """Override deployment infos with options from app config. @@ -1358,6 +1469,8 @@ def override_deployment_info( deployment_infos: deployment info loaded from code override_config: application config deployed by user with options to override those loaded from code. 
+ deployment_to_serialized_autoscaling_policy_def: serialized autoscaling policy def for each deployment + deployment_to_serialized_request_router_cls: serialized request router cls for each deployment Returns: the updated deployment infos. @@ -1404,6 +1517,17 @@ def override_deployment_info( if autoscaling_config: new_config.update(autoscaling_config) + if ( + deployment_to_serialized_autoscaling_policy_def + and deployment_name in deployment_to_serialized_autoscaling_policy_def + ): + # By setting the serialized policy def, AutoscalingConfig constructor will not + # try to import the policy from the string import path + policy_obj = AutoscalingPolicy.from_serialized_policy_def( + new_config["policy"], + deployment_to_serialized_autoscaling_policy_def[deployment_name], + ) + new_config["policy"] = policy_obj options["autoscaling_config"] = AutoscalingConfig(**new_config) ServeUsageTag.AUTO_NUM_REPLICAS_USED.record("1") @@ -1453,6 +1577,26 @@ def override_deployment_info( ) override_options["replica_config"] = replica_config + if "request_router_config" in options: + request_router_config = options.get("request_router_config") + if request_router_config: + if ( + deployment_to_serialized_request_router_cls + and deployment_name in deployment_to_serialized_request_router_cls + ): + # By setting the serialized request router cls, RequestRouterConfig constructor will not + # try to import the request router cls from the string import path + options[ + "request_router_config" + ] = RequestRouterConfig.from_serialized_request_router_cls( + request_router_config, + deployment_to_serialized_request_router_cls[deployment_name], + ) + else: + options["request_router_config"] = RequestRouterConfig( + **request_router_config + ) + # Override deployment config options options.pop("name", None) original_options.update(options) diff --git a/python/ray/serve/_private/config.py b/python/ray/serve/_private/config.py index 2f642a385d64..e91e39b4fc59 100644 --- a/python/ray/serve/_private/config.py +++ b/python/ray/serve/_private/config.py @@ -166,7 +166,7 @@ class DeploymentConfig(BaseModel): ) request_router_config: RequestRouterConfig = Field( - default=RequestRouterConfig(), + default_factory=RequestRouterConfig, update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, ) @@ -248,6 +248,11 @@ def to_proto(self): if self.needs_pickle(): data["user_config"] = cloudpickle.dumps(data["user_config"]) if data.get("autoscaling_config"): + # By setting the serialized policy def, on the protobuf level, AutoscalingConfig constructor will not + # try to import the policy from the string import path when the protobuf is deserialized on the controller side + data["autoscaling_config"]["policy"][ + "_serialized_policy_def" + ] = self.autoscaling_config.policy._serialized_policy_def data["autoscaling_config"] = AutoscalingConfigProto( **data["autoscaling_config"] ) @@ -266,6 +271,11 @@ def to_proto(self): "Non-empty request_router_kwargs not supported" f"for cross-language deployments. 
Got: {router_kwargs}" ) + # By setting the serialized request router cls, on the protobuf level, RequestRouterConfig constructor will not + # try to import the request router cls from the string import path when the protobuf is deserialized on the controller side + data["request_router_config"][ + "_serialized_request_router_cls" + ] = self.request_router_config._serialized_request_router_cls data["request_router_config"] = RequestRouterConfigProto( **data["request_router_config"] ) diff --git a/python/ray/serve/_private/deploy_utils.py b/python/ray/serve/_private/deploy_utils.py index f4750e102825..5f700d01ae8f 100644 --- a/python/ray/serve/_private/deploy_utils.py +++ b/python/ray/serve/_private/deploy_utils.py @@ -22,6 +22,8 @@ def get_deploy_args( deployment_config: Optional[Union[DeploymentConfig, Dict[str, Any]]] = None, version: Optional[str] = None, route_prefix: Optional[str] = None, + serialized_autoscaling_policy_def: Optional[bytes] = None, + serialized_request_router_cls: Optional[bytes] = None, ) -> Dict: """ Takes a deployment's configuration, and returns the arguments needed @@ -44,6 +46,8 @@ def get_deploy_args( "route_prefix": route_prefix, "deployer_job_id": ray.get_runtime_context().get_job_id(), "ingress": ingress, + "serialized_autoscaling_policy_def": serialized_autoscaling_policy_def, + "serialized_request_router_cls": serialized_request_router_cls, } return controller_deploy_args @@ -98,6 +102,11 @@ def get_app_code_version(app_config: ServeApplicationSchema) -> str: Returns: a hash of the import path and (application level) runtime env representing the code version of the application. """ + request_router_configs = [ + deployment.request_router_config + for deployment in app_config.deployments + if isinstance(deployment.request_router_config, dict) + ] deployment_autoscaling_policies = [ deployment_config.autoscaling_config.get("policy", None) for deployment_config in app_config.deployments @@ -113,6 +122,7 @@ def get_app_code_version(app_config: ServeApplicationSchema) -> str: # any one of the deployment level autoscaling policy is changed "autoscaling_policy": app_config.autoscaling_policy, "deployment_autoscaling_policies": deployment_autoscaling_policies, + "request_router_configs": request_router_configs, }, sort_keys=True, ).encode("utf-8") diff --git a/python/ray/serve/_private/request_router/common.py b/python/ray/serve/_private/request_router/common.py index 4653daca2185..b373f47f1528 100644 --- a/python/ray/serve/_private/request_router/common.py +++ b/python/ray/serve/_private/request_router/common.py @@ -38,7 +38,7 @@ class PendingRequest: metadata: RequestMetadata """Metadata for the request, including request ID and whether it's streaming.""" - created_at: float = field(default_factory=time.time) + created_at: float = field(default_factory=lambda: time.time()) """Timestamp when the request was created.""" future: asyncio.Future = field(default_factory=lambda: asyncio.Future()) @@ -73,7 +73,7 @@ def __init__( self._cache: Dict[ReplicaID, ReplicaQueueLengthCacheEntry] = {} self._staleness_timeout_s = staleness_timeout_s self._get_curr_time_s = ( - get_curr_time_s if get_curr_time_s is not None else time.time + get_curr_time_s if get_curr_time_s is not None else lambda: time.time() ) def _is_timed_out(self, timestamp_s: int) -> bool: diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 6787ec30ed87..e4cd1eff9f4a 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -650,7 
+650,10 @@ def request_router(self) -> Optional[RequestRouter]: # Log usage telemetry to indicate that custom request router # feature is being used in this cluster. - if self._request_router_class is not PowerOfTwoChoicesRequestRouter: + if ( + self._request_router_class.__name__ + != PowerOfTwoChoicesRequestRouter.__name__ + ): ServeUsageTag.CUSTOM_REQUEST_ROUTER_USED.record("1") return self._request_router diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index f7555cdcaef0..942d8ff8e80e 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -16,7 +16,7 @@ PrivateAttr, validator, ) -from ray._common.utils import import_attr +from ray._common.utils import import_attr, import_module_and_attr # Import types needed for AutoscalingContext from ray.serve._private.common import DeploymentID, ReplicaID, TimeSeries @@ -192,8 +192,30 @@ def __init__(self, **kwargs: dict[str, Any]): Args: **kwargs: Keyword arguments to pass to BaseModel. """ + serialized_request_router_cls = kwargs.pop( + "_serialized_request_router_cls", None + ) super().__init__(**kwargs) - self._serialize_request_router_cls() + if serialized_request_router_cls: + self._serialized_request_router_cls = serialized_request_router_cls + else: + self._serialize_request_router_cls() + + def set_serialized_request_router_cls( + self, serialized_request_router_cls: bytes + ) -> None: + self._serialized_request_router_cls = serialized_request_router_cls + + @classmethod + def from_serialized_request_router_cls( + cls, request_router_config: dict, serialized_request_router_cls: bytes + ) -> "RequestRouterConfig": + config = request_router_config.copy() + config["_serialized_request_router_cls"] = serialized_request_router_cls + return cls(**config) + + def get_serialized_request_router_cls(self) -> Optional[bytes]: + return self._serialized_request_router_cls def _serialize_request_router_cls(self) -> None: """Import and serialize request router class with cloudpickle. 
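The telemetry check in request_router() switches from an identity comparison to a name comparison because a router class restored from a by-value pickle is a reconstructed class object rather than the originally imported one. A rough sketch with a hypothetical module; the round-trip behavior is assumed from cloudpickle's by-value mode:

    from ray import cloudpickle
    import my_pkg.routers as routers  # hypothetical module defining MyRouter

    cloudpickle.register_pickle_by_value(routers)
    restored_cls = cloudpickle.loads(cloudpickle.dumps(routers.MyRouter))
    cloudpickle.unregister_pickle_by_value(routers)

    # The class name survives the round trip...
    assert restored_cls.__name__ == routers.MyRouter.__name__
    # ...but object identity may not, so `restored_cls is routers.MyRouter`
    # cannot be relied on; hence the __name__ comparison above.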
@@ -209,9 +231,13 @@ def _serialize_request_router_cls(self) -> None: ) request_router_path = request_router_class or DEFAULT_REQUEST_ROUTER_PATH - request_router_class = import_attr(request_router_path) + request_router_module, request_router_class = import_module_and_attr( + request_router_path + ) + cloudpickle.register_pickle_by_value(request_router_module) + self.set_serialized_request_router_cls(cloudpickle.dumps(request_router_class)) + cloudpickle.unregister_pickle_by_value(request_router_module) - self._serialized_request_router_cls = cloudpickle.dumps(request_router_class) # Update the request_router_class field to be the string path self.request_router_class = request_router_path @@ -242,8 +268,26 @@ class AutoscalingPolicy(BaseModel): ) def __init__(self, **kwargs): + serialized_policy_def = kwargs.pop("_serialized_policy_def", None) super().__init__(**kwargs) - self.serialize_policy() + if serialized_policy_def: + self._serialized_policy_def = serialized_policy_def + else: + self.serialize_policy() + + def set_serialized_policy_def(self, serialized_policy_def: bytes) -> None: + self._serialized_policy_def = serialized_policy_def + + @classmethod + def from_serialized_policy_def( + cls, policy_config: dict, serialized_policy_def: bytes + ) -> "AutoscalingPolicy": + config = policy_config.copy() + config["_serialized_policy_def"] = serialized_policy_def + return cls(**config) + + def get_serialized_policy_def(self) -> Optional[bytes]: + return self._serialized_policy_def def serialize_policy(self) -> None: """Serialize policy with cloudpickle. @@ -257,7 +301,10 @@ def serialize_policy(self) -> None: policy_path = f"{policy_path.__module__}.{policy_path.__name__}" if not self._serialized_policy_def: - self._serialized_policy_def = cloudpickle.dumps(import_attr(policy_path)) + policy_module, policy_function = import_module_and_attr(policy_path) + cloudpickle.register_pickle_by_value(policy_module) + self.set_serialized_policy_def(cloudpickle.dumps(policy_function)) + cloudpickle.unregister_pickle_by_value(policy_module) self.policy_function = policy_path diff --git a/python/ray/serve/tests/test_cli.py b/python/ray/serve/tests/test_cli.py index 9943093f3220..0c57e0197942 100644 --- a/python/ray/serve/tests/test_cli.py +++ b/python/ray/serve/tests/test_cli.py @@ -783,5 +783,33 @@ def test_deployment_contains_utils(serve_instance): ) +def test_deploy_use_custom_request_router(serve_instance): + """Test that the custom request router is initialized and used correctly.""" + config_file = os.path.join( + os.path.dirname(__file__), + "test_config_files", + "use_custom_request_router.yaml", + ) + subprocess.check_output(["serve", "deploy", config_file], stderr=subprocess.STDOUT) + wait_for_condition( + lambda: httpx.post(f"{get_application_url(app_name='app1')}/").text + == "hello_from_custom_request_router" + ) + + +def test_deploy_use_custom_autoscaling(serve_instance): + """Test that the custom autoscaling is initialized correctly.""" + config_file = os.path.join( + os.path.dirname(__file__), + "test_config_files", + "use_custom_autoscaling.yaml", + ) + subprocess.check_output(["serve", "deploy", config_file], stderr=subprocess.STDOUT) + wait_for_condition( + lambda: httpx.post(f"{get_application_url(app_name='app1')}/").text + == "hello_from_custom_autoscaling_policy" + ) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_config_files/use_custom_autoscaling.yaml 
b/python/ray/serve/tests/test_config_files/use_custom_autoscaling.yaml new file mode 100644 index 000000000000..fd479db8ff5d --- /dev/null +++ b/python/ray/serve/tests/test_config_files/use_custom_autoscaling.yaml @@ -0,0 +1,16 @@ +applications: +- name: app1 + route_prefix: / + import_path: ray.serve.tests.test_config_files.use_custom_autoscaling_policy:app + deployments: + - name: CustomAutoscalingPolicy + num_replicas: auto + ray_actor_options: + num_cpus: 0.0 + autoscaling_config: + min_replicas: 1 + max_replicas: 2 + upscale_delay_s: 1 + downscale_delay_s: 2 + policy: + policy_function: ray.serve.tests.test_config_files.use_custom_autoscaling_policy.custom_autoscaling_policy diff --git a/python/ray/serve/tests/test_config_files/use_custom_autoscaling_policy.py b/python/ray/serve/tests/test_config_files/use_custom_autoscaling_policy.py new file mode 100644 index 000000000000..a0b47d6b0ae9 --- /dev/null +++ b/python/ray/serve/tests/test_config_files/use_custom_autoscaling_policy.py @@ -0,0 +1,16 @@ +from ray import serve +from ray.serve.config import AutoscalingContext + + +def custom_autoscaling_policy(ctx: AutoscalingContext): + print("custom_autoscaling_policy") + return 2, {} + + +@serve.deployment +class CustomAutoscalingPolicy: + def __call__(self): + return "hello_from_custom_autoscaling_policy" + + +app = CustomAutoscalingPolicy.bind() diff --git a/python/ray/serve/tests/test_config_files/use_custom_request_router.py b/python/ray/serve/tests/test_config_files/use_custom_request_router.py new file mode 100644 index 000000000000..8036be081651 --- /dev/null +++ b/python/ray/serve/tests/test_config_files/use_custom_request_router.py @@ -0,0 +1,47 @@ +import random +from typing import ( + List, + Optional, +) + +from ray import serve +from ray.serve.context import _get_internal_replica_context +from ray.serve.request_router import ( + PendingRequest, + ReplicaID, + ReplicaResult, + RequestRouter, + RunningReplica, +) + + +class UniformRequestRouter(RequestRouter): + async def choose_replicas( + self, + candidate_replicas: List[RunningReplica], + pending_request: Optional[PendingRequest] = None, + ) -> List[List[RunningReplica]]: + print("UniformRequestRouter routing request") + index = random.randint(0, len(candidate_replicas) - 1) + return [[candidate_replicas[index]]] + + def on_request_routed( + self, + pending_request: PendingRequest, + replica_id: ReplicaID, + result: ReplicaResult, + ): + print("on_request_routed callback is called!!") + + +@serve.deployment +class UniformRequestRouterApp: + def __init__(self): + context = _get_internal_replica_context() + self.replica_id: ReplicaID = context.replica_id + + async def __call__(self): + return "hello_from_custom_request_router" + + +app = UniformRequestRouterApp.bind() diff --git a/python/ray/serve/tests/test_config_files/use_custom_request_router.yaml b/python/ray/serve/tests/test_config_files/use_custom_request_router.yaml new file mode 100644 index 000000000000..163d58b9e6f8 --- /dev/null +++ b/python/ray/serve/tests/test_config_files/use_custom_request_router.yaml @@ -0,0 +1,14 @@ +applications: +- name: app1 + route_prefix: / + import_path: ray.serve.tests.test_config_files.use_custom_request_router:app + deployments: + - name: UniformRequestRouterApp + num_replicas: 2 + ray_actor_options: + num_cpus: 0.0 + request_router_config: + request_router_class: ray.serve.tests.test_config_files.use_custom_request_router.UniformRequestRouter + request_router_kwargs: {} + request_routing_stats_period_s: 10 + 
request_routing_stats_timeout_s: 30 diff --git a/python/ray/serve/tests/unit/test_application_state.py b/python/ray/serve/tests/unit/test_application_state.py index b4d7fd47f60f..5d274917841d 100644 --- a/python/ray/serve/tests/unit/test_application_state.py +++ b/python/ray/serve/tests/unit/test_application_state.py @@ -218,6 +218,8 @@ def deployment_params( "deployer_job_id": "random", "route_prefix": route_prefix, "ingress": False, + "serialized_autoscaling_policy_def": None, + "serialized_request_router_cls": None, } @@ -720,7 +722,7 @@ def test_app_unhealthy(mocked_application_state): @patch("ray.serve._private.application_state.build_serve_application", Mock()) -@patch("ray.get", Mock(return_value=([deployment_params("a", "/old")], None))) +@patch("ray.get", Mock(return_value=(None, [deployment_params("a", "/old")], None))) @patch("ray.serve._private.application_state.check_obj_ref_ready_nowait") def test_apply_app_configs_succeed(check_obj_ref_ready_nowait): """Test deploying through config successfully. @@ -815,7 +817,7 @@ def test_apply_app_configs_fail(check_obj_ref_ready_nowait): Mock(return_value="123"), ) @patch("ray.serve._private.application_state.build_serve_application", Mock()) -@patch("ray.get", Mock(return_value=([deployment_params("a", "/old")], None))) +@patch("ray.get", Mock(return_value=(None, [deployment_params("a", "/old")], None))) @patch("ray.serve._private.application_state.check_obj_ref_ready_nowait") def test_apply_app_configs_deletes_existing(check_obj_ref_ready_nowait): """Test that apply_app_configs deletes existing apps that aren't in the new list. @@ -2482,6 +2484,7 @@ def _deploy_app_with_mocks(self, app_state_manager, app_config): ) mock_reconcile.return_value = ( + None, deployment_infos, BuildAppStatus.SUCCEEDED, "", @@ -2542,6 +2545,7 @@ def _deploy_multiple_apps_with_mocks(self, app_state_manager, app_configs): ) mock_reconcile.return_value = ( + None, deployment_infos, BuildAppStatus.SUCCEEDED, "", diff --git a/python/ray/serve/tests/unit/test_config.py b/python/ray/serve/tests/unit/test_config.py index 611212c6cd08..e50a585e1932 100644 --- a/python/ray/serve/tests/unit/test_config.py +++ b/python/ray/serve/tests/unit/test_config.py @@ -813,7 +813,17 @@ def test_autoscaling_policy_serializations(policy): ).autoscaling_config.policy.get_policy() if policy is None: - assert deserialized_autoscaling_policy == default_autoscaling_policy + # Compare function attributes instead of function objects since + # cloudpickle.register_pickle_by_value() causes deserialization to + # create a new function object rather than returning the same object + assert ( + deserialized_autoscaling_policy.__name__ + == default_autoscaling_policy.__name__ + ) + assert ( + deserialized_autoscaling_policy.__module__ + == default_autoscaling_policy.__module__ + ) else: # Compare function behavior instead of function objects # since serialization/deserialization creates new function objects
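The updated mocks follow the build task's new (serialized application autoscaling policy def, deploy args, error) triple, and the equality check is relaxed because a by-value round trip yields an equivalent but distinct function object. For reference, a small sketch of the reconstruction path using the new classmethod; the policy function is hypothetical, and it assumes get_policy() returns the deserialized callable as exercised in test_autoscaling_policy_serializations:

    from ray import cloudpickle
    from ray.serve.config import AutoscalingPolicy


    def my_policy(ctx):  # hypothetical policy; normally serialized by the build task
        return 2, {}


    serialized = cloudpickle.dumps(my_policy)

    policy = AutoscalingPolicy.from_serialized_policy_def(
        {"policy_function": f"{my_policy.__module__}.{my_policy.__name__}"},
        serialized,
    )
    # The constructor skips re-importing the path because the serialized
    # definition is already attached.
    policy_fn = policy.get_policy()
    assert policy_fn.__name__ == "my_policy"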