diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py
index 9e594563b456..d1105db5afa8 100644
--- a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py
+++ b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py
@@ -437,10 +437,14 @@ async def __init__(
         """
         await super().__init__(llm_config)
 
-        self._engine_cls = engine_cls or self._default_engine_cls
-        self.engine = self._get_engine_class(self._llm_config)
-        await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
-
+        self._engine_cls = engine_cls or self._get_default_engine_class()
+        self.engine: Optional[LLMEngine] = None
+        if self._engine_cls is not None:
+            self.engine = self._engine_cls(self._llm_config)
+            await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)
+
+        # TODO (Kourosh): I think we can completely remove the image retriever.
+        # It should have been removed earlier but was missed.
         self.image_retriever = (
             image_retriever_cls()
             if image_retriever_cls
@@ -466,25 +470,20 @@ async def __init__(
 
         self.response_postprocessor = ResponsePostprocessor()
 
-    @property
-    def _get_engine_class(self) -> Type[LLMEngine]:
+    def _get_default_engine_class(self) -> Type[LLMEngine]:
         """Helper to load the engine class from the environment variable.
 
-        This is used for testing or escape-hatch for patching purposes.
         If env variable is not set, it will fallback to the default engine class.
         """
         engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
         if engine_cls_path:
-            try:
-                return import_attr(engine_cls_path)
-            except AttributeError:
-                logger.warning(
-                    f"Failed to import engine class {engine_cls_path}. "
-                    f"Using the default engine class {self._engine_cls}."
-                )
-                return self._engine_cls
+            return import_attr(engine_cls_path)
+        return self._default_engine_cls
 
     async def _start_engine(self):
+        if self.engine is None:
+            raise ValueError("Engine is not set")
+
         await self.engine.start()
 
         # Push telemetry reports for the model in the current deployment.
@@ -616,7 +615,13 @@ async def check_health(self) -> None:
         Check the health of the replica. Does not return anything. Raise error when
         the engine is dead and needs to be restarted.
         """
-        return await self.engine.check_health()
+        if self.engine is None:
+            return
+        try:
+            return await self.engine.check_health()
+        except Exception as e:
+            logger.error("Engine health check failed in LLMServer.check_health: %s", e)
+            raise e
 
     async def embeddings(self, request: EmbeddingRequest) -> LLMEmbeddingsResponse:
         """Runs an embeddings request to the vllm engine, and return the response.
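Reviewer note: `_get_default_engine_class` keys off the environment variable named by the `RAYLLM_VLLM_ENGINE_CLS_ENV` constant and imports whatever dotted path it holds. A minimal sketch of that escape hatch, assuming a variable name of `RAYLLM_VLLM_ENGINE_CLS` and a hypothetical `my_pkg.engines.MockEngine` class (neither is defined by this diff):

```python
import os

# Assumed value of the RAYLLM_VLLM_ENGINE_CLS_ENV constant; check the constants
# module for the real variable name before relying on it.
os.environ["RAYLLM_VLLM_ENGINE_CLS"] = "my_pkg.engines.MockEngine"  # hypothetical class

# Any LLMServer constructed after this point will import_attr() the dotted path
# and use it as the engine class; with the variable unset, the server falls back
# to _default_engine_cls.
```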
The replica will be restarted") raise e from None @staticmethod diff --git a/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py b/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py index ba7fc4684b23..399ddbba584b 100644 --- a/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py +++ b/python/ray/llm/_internal/serve/deployments/prefill_decode_disagg/prefill_decode_disagg.py @@ -1,6 +1,5 @@ """Using Ray Serve to deploy LLM models with P/D disaggregation. """ -import asyncio import logging import uuid from typing import Any, AsyncGenerator, Dict, Union @@ -63,6 +62,7 @@ def parse_configs_and_cast_type(config: Union[str, LLMConfig]) -> LLMConfig: class PDProxyServer(LLMServer): + _default_engine_cls = None """ Proxy between P/D LLM servers. @@ -83,14 +83,6 @@ async def __init__( prefill_server: DeploymentHandle, decode_server: DeploymentHandle, ): - class FakeEngine: - """Provide a fake engine such that proxy don't really start any engine.""" - - def __init__(self, *args, **kwargs): - pass - - async def start(self, *args, **kwargs): - pass # We pass `llm_config` here to let super() extract the model_id, such that /v1/models # endpoint can work correctly. @@ -98,7 +90,6 @@ async def start(self, *args, **kwargs): # API, instead of passing it in as an argument. await super().__init__( llm_config, - engine_cls=FakeEngine, ) self.prefill_server = prefill_server @@ -160,13 +151,6 @@ async def _predict( ): yield chunk - async def check_health(self) -> None: - """Check the health of the llm engine.""" - await asyncio.gather( - self.prefill_server.check_health.remote(), - self.decode_server.check_health.remote(), - ) - @classmethod def as_deployment(cls) -> serve.Deployment: """Turns PDProxyServer into a Ray Serve deployment.""" diff --git a/python/ray/llm/_internal/serve/deployments/routers/router.py b/python/ray/llm/_internal/serve/deployments/routers/router.py index b25276611d94..e488f269605c 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/router.py +++ b/python/ray/llm/_internal/serve/deployments/routers/router.py @@ -232,12 +232,6 @@ async def _setup_handle_and_config_maps( async def check_health(self): await self._init_completed.wait() - await asyncio.gather( - *[ - handle.check_health.remote() - for handle in self._default_serve_handles.values() - ] - ) def _get_configured_serve_handle(self, model_id: str): """Gets a ServeHandle to a model deployment. 
diff --git a/python/ray/llm/_internal/serve/deployments/routers/router.py b/python/ray/llm/_internal/serve/deployments/routers/router.py
index b25276611d94..e488f269605c 100644
--- a/python/ray/llm/_internal/serve/deployments/routers/router.py
+++ b/python/ray/llm/_internal/serve/deployments/routers/router.py
@@ -232,12 +232,6 @@ async def _setup_handle_and_config_maps(
 
     async def check_health(self):
         await self._init_completed.wait()
-        await asyncio.gather(
-            *[
-                handle.check_health.remote()
-                for handle in self._default_serve_handles.values()
-            ]
-        )
 
     def _get_configured_serve_handle(self, model_id: str):
         """Gets a ServeHandle to a model deployment.
diff --git a/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py b/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py
index 90076a235cef..5ba14036df08 100644
--- a/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py
+++ b/python/ray/llm/tests/serve/cpu/deployments/routers/test_router.py
@@ -170,8 +170,6 @@ async def test_check_health(self, llm_config: LLMConfig):
 
         await router.check_health()
 
-        assert server.check_health.remote.call_count == 1
-
 
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/release/llm_tests/serve/test_llm_serve_fault_tolerance.py b/release/llm_tests/serve/test_llm_serve_fault_tolerance.py
new file mode 100644
index 000000000000..d81a234c3aa7
--- /dev/null
+++ b/release/llm_tests/serve/test_llm_serve_fault_tolerance.py
@@ -0,0 +1,95 @@
+import time
+from typing import Literal, List, Generator
+
+import pytest
+import ray
+from ray import serve
+from ray.serve.llm import LLMConfig, ModelLoadingConfig, build_llm_deployment
+
+MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
+RAY_MODEL_ID = "qwen-0.5b"
+
+
+def get_llm_config(
+    tensor_parallel_size: int = 1,
+) -> LLMConfig:
+    """Create LLMConfig with specified parallelism parameters."""
+    return LLMConfig(
+        model_loading_config=ModelLoadingConfig(
+            model_id=RAY_MODEL_ID,
+            model_source=MODEL_ID,
+        ),
+        deployment_config=dict(
+            name="test",
+            num_replicas=2,
+        ),
+        engine_kwargs=dict(
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=True,
+        ),
+        runtime_env={"env_vars": {"VLLM_USE_V1": "1"}},
+    )
+
+
+def find_replica_ids(deployment_name: str) -> List[str]:
+    actors = ray.util.list_named_actors("serve")
+    found_replica_ids = []
+    for actor in actors:
+        if deployment_name in actor["name"]:
+            found_replica_ids.append(actor["name"])
+    return found_replica_ids
+
+
+def kill_replica(replica_id: str) -> None:
+    actor = ray.get_actor(replica_id, namespace="serve")
+    ray.kill(actor)
+
+
+@pytest.fixture(name="app", scope="function")
+def start_ray_serve(
+    tensor_parallel_size: int = 1,
+) -> Generator:
+    """Start Ray Serve with specified parallelism parameters."""
+    llm_config: LLMConfig = get_llm_config(tensor_parallel_size)
+    app = build_llm_deployment(llm_config, name_prefix="LLM:")
+    serve.run(app, blocking=False)
+    yield app
+    serve.shutdown()
+
+
+def wait_for_deployment_status(
+    deployment_name: str, status: Literal["HEALTHY", "UNHEALTHY"], timeout_s: int = 120
+) -> None:
+    s = time.time()
+    while time.time() - s < timeout_s:
+        print(f"Waiting for deployment {deployment_name} to become {status}")
+        state = serve.status()
+        if state.applications["default"].deployments[deployment_name].status == status:
+            return
+        time.sleep(1)
+    raise TimeoutError(
+        f"Deployment {deployment_name} did not become "
+        f"{status} within {timeout_s} seconds"
+    )
+
+
+def test_recovery_from_replica_failure(app) -> None:
+    """Tests that the deployment recovers from replica failure."""
+    dname = "LLM:test"
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+    # Kill both replicas
+    replica_ids = find_replica_ids(dname)
+    for replica_id in replica_ids:
+        print(f"Killing replica {replica_id}")
+        kill_replica(replica_id)
+
+    # Wait for deployment to get unhealthy
+    wait_for_deployment_status(dname, "UNHEALTHY", timeout_s=60)
+
+    # Wait again for deployment to get healthy
+    wait_for_deployment_status(dname, "HEALTHY", timeout_s=60)
+
+
+if __name__ == "__main__":
+    pytest.main(["-xvs", __file__])
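Reviewer note: the polling helper in the new test reads replica health through `serve.status()`. A small sketch of the same introspection, handy when debugging the release test locally (the `"default"` application name is what `serve.run` uses when no name is given; nothing here is specific to this PR):

```python
from ray import serve

# Print the status of every deployment in every running Serve application.
for app_name, app in serve.status().applications.items():
    for deployment_name, deployment in app.deployments.items():
        print(app_name, deployment_name, deployment.status)
```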
diff --git a/release/llm_tests/serve/test_llm_serve_integration.py b/release/llm_tests/serve/test_llm_serve_integration.py
index 97ee5c0d8fd4..c88dc8044b19 100644
--- a/release/llm_tests/serve/test_llm_serve_integration.py
+++ b/release/llm_tests/serve/test_llm_serve_integration.py
@@ -27,6 +27,7 @@ async def test_engine_metrics():
         model="Qwen/Qwen2.5-0.5B-Instruct",
         dtype="auto",
         disable_log_stats=False,
+        enforce_eager=True,
     )
 
     engine = AsyncLLM.from_engine_args(
@@ -75,6 +76,7 @@ def remote_model_app(request):
             enable_chunked_prefill=True,
             enable_prefix_caching=True,
             trust_remote_code=remote_code,
+            enforce_eager=True,
         ),
     }
 
diff --git a/release/release_tests.yaml b/release/release_tests.yaml
index 9fb5e96a1eea..f0098df5010f 100644
--- a/release/release_tests.yaml
+++ b/release/release_tests.yaml
@@ -4288,7 +4288,7 @@
     long_running: false
     script: pytest -vs test_llm_serve_correctness.py
 
-- name: llm_serve_integration
+- name: llm_serve_vllm_integration_tests
  frequency: nightly
  python: "3.11"
  group: llm-serve
@@ -4307,7 +4307,7 @@
  run:
    timeout: 3600
    long_running: false
-    script: pytest -vs test_llm_serve_integration.py
+    script: pytest -vs test_llm_serve_integration.py test_llm_serve_fault_tolerance.py
 
- name: llm_serve_llama_3dot1_8B_quantized_tp1_1p1d
  frequency: nightly
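Reviewer note: the renamed release test now runs both suites. The same selection can be reproduced locally with the pytest entry point the test files already use, assuming the working directory is `release/llm_tests/serve`:

```python
import sys

import pytest

# Mirror the release-test script: run the integration and fault-tolerance suites together.
sys.exit(
    pytest.main(
        ["-vs", "test_llm_serve_integration.py", "test_llm_serve_fault_tolerance.py"]
    )
)
```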