diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 716e27bda4f..7e08295ade3 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -635,29 +635,43 @@ def disaggregated(
 
     disagg_cfg = parse_disagg_config_file(config_file)
 
-    metadata_server_cfg = parse_metadata_server_config_file(
-        metadata_server_config_file)
+    # Bind the listening address before building the server so that an
+    # unavailable host/port is reported immediately.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        # SO_REUSEADDR lets a restarted server rebind while connections from
+        # a previous run are still in TIME_WAIT.
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        try:
+            s.bind((disagg_cfg.hostname, disagg_cfg.port))
+        except OSError as e:
+            raise RuntimeError(
+                f"Failed to bind socket to {disagg_cfg.hostname}:{disagg_cfg.port}: {e}"
+            ) from e
+
+        metadata_server_cfg = parse_metadata_server_config_file(
+            metadata_server_config_file)
+
+        server = OpenAIDisaggServer(
+            config=disagg_cfg,
+            req_timeout_secs=request_timeout,
+            server_start_timeout_secs=server_start_timeout,
+            metadata_server_cfg=metadata_server_cfg,
+            metrics_interval_secs=metrics_log_interval)
+
+        # Disable GC by default
+        # When concurrency is high, the number of Python objects increases, so
+        # GC runs frequently and takes a long time to process. In this case,
+        # requests are not immediately forwarded to CTX workers and GEN workers,
+        # causing them to run with small batch sizes. Disabling GC can mitigate
+        # this problem.
+        # By testing this feature, we didn't observe significant RSS or VMS
+        # increment, and observed that `count0` (obtained by `gc.get_count()`)
+        # increases by fewer than 1,000 after every 200,000 requests, while the
+        # maximum value of `count0` exceeded 3,000,000 during the test.
+        if os.getenv("TRTLLM_DISAGG_SERVER_DISABLE_GC", "1") == "1":
+            gc.disable()
 
-    server = OpenAIDisaggServer(config=disagg_cfg,
-                                req_timeout_secs=request_timeout,
-                                server_start_timeout_secs=server_start_timeout,
-                                metadata_server_cfg=metadata_server_cfg,
-                                metrics_interval_secs=metrics_log_interval)
-
-    # Disable GC by default
-    # When concurrency is high, the number of Python objects increases, so
-    # GC runs frequently and takes a long time to process. In this case,
-    # requests are not immediately forwarded to CTX workers and GEN workers,
-    # causing them to run with small batch sizes. Disabling GC can mitigate
-    # this problem.
-    # By testing this feature, we didn't observe significant RSS or VMS
-    # increment, and observed that `count0` (obtained by `gc.get_count()`)
-    # increases by fewer than 1,000 after every 200,000 requests, while the
-    # maximum value of `count0` exceeded 3,000,000 during the test.
-    if os.getenv("TRTLLM_DISAGG_SERVER_DISABLE_GC", "1") == "1":
-        gc.disable()
-
-    asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port))
+        asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port, sockets=[s]))
 
 
 def set_cuda_device():
diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py
index 55c3e136e5a..524dd9fd110 100644
--- a/tensorrt_llm/serve/openai_disagg_server.py
+++ b/tensorrt_llm/serve/openai_disagg_server.py
@@ -16,6 +16,7 @@
 # yapf: disable
 import asyncio
 import signal
+import socket
 import traceback
 from contextlib import asynccontextmanager
 from typing import Callable, Optional
@@ -190,13 +191,17 @@ async def cluster_info(self) -> JSONResponse:
     async def version(self) -> JSONResponse:
         return JSONResponse(content={"version": VERSION})
 
-    async def __call__(self, host: str, port: int):
+    async def __call__(self, host: str, port: int, sockets: list[socket.socket] | None = None):
+        """Run a uvicorn server for `self.app` configured with `host` and `port`.
+
+        `sockets`, when given, is passed through to `uvicorn.Server.serve`.
+        """
         config = uvicorn.Config(self.app,
                                 host=host,
                                 port=port,
                                 log_level=logger.level,
                                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
-        await uvicorn.Server(config).serve()
+        await uvicorn.Server(config).serve(sockets=sockets)
 
     # TODO: rework this for service discovery, now it's only for static server list
     async def _set_steady_clock_offsets(self):