diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 1cdc80dd3546..337ad8f5882e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -44,7 +44,6 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index a03810625466..1d7da8ac3c03 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -38,7 +38,6 @@ def make_request(request_id, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f40d477a0036..80a8978cdca6 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,7 +138,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) requests.append(request) return requests @@ -732,7 +731,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): prompt_logprobs_dict={}, ) engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + model_runner_output)[0] for i in range(len(requests)): running_req = scheduler.running[i] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 04be7c033998..4fe1d3a780ce 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,22 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import multiprocessing +import os import signal +import sys +from multiprocessing.context import SpawnProcess +from typing import Any import uvloop +import zmq import vllm.envs as envs from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr, + wait_for_engine_startup) logger = init_logger(__name__) @@ -34,9 +46,12 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - if args.headless: + if args.headless or args.api_server_count < 1: run_headless(args) + elif args.api_server_count > 1: + run_multi_api_server(args) else: + # Single API server (this process). uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: @@ -67,6 +82,11 @@ def subparser_init( type=int, default=0, help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') serve_parser.add_argument( "--config", type=str, @@ -86,6 +106,9 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): + if args.api_server_count > 1: + raise RuntimeError("api_server_count can't be set in headless mode") + # Create the EngineConfig. engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER @@ -98,7 +121,7 @@ def run_headless(args: argparse.Namespace): local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = get_tcp_uri(host, port) + handshake_address = get_tcp_uri(host, port) if local_engine_count <= 0: raise RuntimeError("data_parallel_size_local must be > 0 in " @@ -114,7 +137,7 @@ def signal_handler(signum, frame): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, input_address) + "with head node address %s.", local_engine_count, handshake_address) # Create the engines. engine_manager = CoreEngineProcManager( @@ -124,7 +147,7 @@ def signal_handler(signum, frame): local_start_index=0, vllm_config=vllm_config, on_head_node=False, - input_address=input_address, + handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, ) @@ -134,3 +157,130 @@ def signal_handler(signum, frame): finally: logger.info("Shutting down.") engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + assert num_api_servers > 1 + + setup_multiprocess_prometheus() + + listen_address, sock = setup_server(args) + + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + parallel_config = vllm_config.parallel_config + + assert parallel_config.data_parallel_rank == 0 + + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + local_only = local_engine_count == dp_size + + # Set up input and output addresses. + input_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + output_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + + addresses: dict[str, Any] = { + "input_addresses": input_addresses, + "output_addresses": output_addresses, + } + + # Set up coordinator for dp > 1. + coordinator = None + stats_update_address = None + if dp_size > 1: + # TODO "ready" event for coordinator + coordinator = DPCoordinator(parallel_config) + addresses.update(coordinator.get_engine_socket_addresses()) + stats_update_address = coordinator.get_stats_publish_address() + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. + if not local_engine_count: + local_engine_manager = None + else: + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0, + local_start_index=0) + + # Start API servers. + spawn_context = multiprocessing.get_context("spawn") + api_server_workers: list[SpawnProcess] = [] + for i, in_addr, out_addr in zip(range(num_api_servers), + input_addresses, output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + # TODO check signal propagation + proc = spawn_context.Process(target=run_api_server_worker, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + api_server_workers.append(proc) + proc.start() + + # Wait for engine handshakes to complete. + core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + + wait_for_engine_startup( + handshake_socket, + addresses, + core_engines, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + # TODO handle failures / clean shutdown here + for proc in api_server_workers: + proc.join() + + +def run_api_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + # Add process-specific prefix to stdout and stderr. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a954a9ff90bc..7261bdfa03ce 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,13 +17,15 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Annotated, Optional, Union +from typing import Annotated, Any, Optional, Union import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import State from starlette.routing import Mount @@ -97,6 +99,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -142,14 +145,17 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing) as engine: + engine_args, args.disable_frontend_multiprocessing, + client_config) as engine: yield engine @@ -157,6 +163,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 try: async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_index=client_index) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -315,22 +326,9 @@ async def validate_json_request(raw_request: Request): def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, - multiprocess) - from prometheus_fastapi_instrumentator import Instrumentator - - registry = REGISTRY - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() Instrumentator( excluded_handlers=[ @@ -1080,16 +1078,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: return sock -async def run_server(args, **uvicorn_kwargs) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - log_non_default_args(args) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - +def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") @@ -1100,6 +1092,16 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + +def setup_server(args): + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 @@ -1116,22 +1118,39 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - def _listen_addr(a: str) -> str: - if is_valid_ipv6_address(a): - return '[' + a + ']' - return a or "0.0.0.0" - - is_ssl = args.ssl_keyfile and args.ssl_certfile - logger.info("Starting vLLM API server on http%s://%s:%d", - "s" if is_ssl else "", _listen_addr(sock_addr[0]), - sock_addr[1]) - + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) shutdown_task = await serve_http( app, sock=sock, diff --git a/vllm/utils.py b/vllm/utils.py index 9a7da8067ba4..dd8c61bfe353 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2357,6 +2357,7 @@ def make_zmq_socket( socket_type: Any, bind: Optional[bool] = None, identity: Optional[bytes] = None, + linger: Optional[int] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2376,7 +2377,7 @@ def make_zmq_socket( buf_size = -1 # Use system default buffer size if bind is None: - bind = socket_type != zmq.PUSH + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): socket.setsockopt(zmq.RCVHWM, 0) @@ -2389,6 +2390,9 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c17f80b6ae78..055ce446051e 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -45,7 +45,7 @@ def update_from_output( self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> "EngineCoreOutputs": + ) -> dict[int, "EngineCoreOutputs"]: """Update the scheduler state based on the model runner output. This method is called after the model runner has processed the scheduled @@ -55,7 +55,8 @@ def update_from_output( for each request. Returns: - A EngineCoreOutputs object containing the outputs for each request. + A dict of client index to EngineCoreOutputs object containing the + outputs for each request originating from that client. """ raise NotImplementedError @@ -126,6 +127,11 @@ def reset_prefix_cache(self) -> bool: """ raise NotImplementedError + @abstractmethod + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + raise NotImplementedError + @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: """Make a SchedulerStats object for logging. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7773853b096a..4fc7abc90371 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -59,7 +59,8 @@ def __init__( # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.include_finished_set = include_finished_set + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -668,7 +669,7 @@ def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, - ) -> EngineCoreOutputs: + ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs @@ -676,7 +677,7 @@ def update_from_output( num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below @@ -772,7 +773,7 @@ def update_from_output( if new_token_ids or kv_transfer_params: # Add EngineCoreOutput for this Request. - outputs.append( + outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, @@ -802,17 +803,35 @@ def update_from_output( self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) + + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids is not None: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. + next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) return engine_core_outputs + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -854,8 +873,11 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]: delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) if not delay_free_blocks: self._free_blocks(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 122a5a72cc36..243c4eafe227 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -44,10 +44,6 @@ class EngineCoreRequest( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - request_id: str prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] @@ -59,6 +55,8 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] + client_index: int = 0 + # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0d646d8dd575..7026793befb0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -34,6 +34,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) +from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -52,6 +53,8 @@ def __init__( log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> None: """ Create an AsyncLLM. @@ -119,6 +122,8 @@ def __init__( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, + client_addresses=client_addresses, + client_index=client_index, ) if self.stat_loggers: for stat_logger in self.stat_loggers[0]: @@ -140,6 +145,8 @@ def from_vllm_config( stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( @@ -157,6 +164,8 @@ def from_vllm_config( log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, + client_addresses=client_addresses, + client_index=client_index, ) @classmethod @@ -190,6 +199,8 @@ def __del__(self): def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() + if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() @@ -393,7 +404,6 @@ async def output_handler(): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. if stat_loggers: - assert outputs.scheduler_stats is not None AsyncLLM._record_stats( stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, @@ -417,7 +427,7 @@ async def abort(self, request_id: str) -> None: @staticmethod def _record_stats( stat_loggers: list[StatLoggerBase], - scheduler_stats: SchedulerStats, + scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py new file mode 100644 index 000000000000..ef5402196c67 --- /dev/null +++ b/vllm/v1/engine/coordinator.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing +import sys +import time +from typing import Optional + +import msgspec.msgpack +import zmq + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType +from vllm.v1.serial_utils import MsgpackDecoder +from vllm.v1.utils import get_engine_client_zmq_addr + +logger = init_logger(__name__) + + +class DPCoordinator: + + def __init__(self, parallel_config: ParallelConfig): + + # Assume coordinator is colocated with front-end procs. + front_publish_address = get_open_zmq_ipc_path() + + dp_size = parallel_config.data_parallel_size + assert dp_size > 1, "Coordinator only used for data parallel" + + local_only = dp_size == parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + back_publish_address = get_engine_client_zmq_addr(local_only, host) + back_output_address = get_engine_client_zmq_addr(local_only, host) + + context = get_mp_context() + self.proc: multiprocessing.Process = context.Process( + target=CoordinatorProc.run_coordinator, + name="VLLM_DP_Coordinator", + kwargs={ + "engine_count": parallel_config.data_parallel_size, + "front_publish_address": front_publish_address, + "back_output_address": back_output_address, + "back_publish_address": back_publish_address, + }, + daemon=True) + self.proc.start() + + self.stats_publish_address = front_publish_address + self.coord_in_address = back_publish_address + self.coord_out_address = back_output_address + + def get_stats_publish_address(self) -> str: + return self.stats_publish_address + + def get_engine_socket_addresses(self) -> dict[str, str]: + return { + "coord_in_address": self.coord_in_address, + "coord_out_address": self.coord_out_address, + } + + def close(self): + self.proc.terminate() + + +class EngineState: + + def __init__(self): + self.request_counts = [0, 0] # [waiting, running] + + +class CoordinatorProc: + + def __init__(self, engine_count: int): + + self.ctx = zmq.Context() + + self.engines = [EngineState() for _ in range(engine_count)] + + self.current_wave = 0 + self.engines_running = False + self.stats_changed = False + + @staticmethod + def run_coordinator( + engine_count: int, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): + coordinator = CoordinatorProc(engine_count=engine_count) + + try: + coordinator.process_input_socket( + front_publish_address, + back_output_address, + back_publish_address, + ) + except KeyboardInterrupt: + logger.info("DP Coordinator process exiting") + + def process_input_socket(self, front_publish_address: str, + back_output_address: str, + back_publish_address: str): + + decoder = MsgpackDecoder(EngineCoreOutputs) + + with make_zmq_socket( + path=front_publish_address, # IPC + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_front, make_zmq_socket( + path=back_output_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.PULL, + bind=True, + ) as output_back, make_zmq_socket( + path=back_publish_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_back: + + poller = zmq.Poller() + poller.register(publish_front, zmq.POLLIN) + poller.register(output_back, zmq.POLLIN) + last_publish = 0 + while True: + elapsed = int(time.time() * 1000) - last_publish + wait_for = 100 if self.stats_changed else 3000 + events = poller.poll(timeout=max(0, wait_for - elapsed)) + if not events: + engine_list = self._get_engine_counts() + to_publish = (engine_list, self.current_wave, + self.engines_running) + msg = msgspec.msgpack.encode(to_publish) + publish_front.send(msg) + last_publish = int(time.time() * 1000) + self.stats_changed = False + continue + + events = dict(events) + + if publish_front in events: + buffer = publish_front.recv() + if buffer == b'\x01': + # Ignore subscription messages. + continue + engine_index, wave = msgspec.msgpack.decode(buffer) + if wave < self.current_wave: + engine_index = None + if not self.engines_running: + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, self.current_wave, + engine_index) + + if output_back in events: + buffer = output_back.recv() + outputs: EngineCoreOutputs = decoder.decode(buffer) + + assert not outputs.outputs + assert outputs.utility_output is None + + eng_index = outputs.engine_index + if outputs.scheduler_stats: + stats = self.engines[eng_index].request_counts + stats[0] = outputs.scheduler_stats.num_waiting_reqs + stats[1] = outputs.scheduler_stats.num_running_reqs + self.stats_changed = True + + #TODO record prometheus metrics here? + + if outputs.wave_complete is not None: + if self.current_wave <= wave: + self.current_wave = wave + 1 + self.engines_running = False + self.stats_changed = True + elif outputs.start_wave is not None and ( + wave > self.current_wave or + (wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so + # we must ensure that other engines progress to the + # next wave. + self.current_wave = wave + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, wave, eng_index) + + @staticmethod + def _send_start_wave(socket: zmq.Socket, wave: int, + exclude_engine_index: Optional[int]): + wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) + socket.send_multipart( + (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + + def _get_engine_counts(self) -> list[list[int]]: + return [e.request_counts for e in self.engines] + + def _get_engine_list(self) -> Optional[list[int]]: + shortlist: list[int] = [] + min_counts = [sys.maxsize, sys.maxsize] + for i, e in enumerate(self.engines): + if e.request_counts <= min_counts: + if e.request_counts < min_counts: + min_counts = e.request_counts + shortlist.clear() + shortlist.append(i) + return None if len(shortlist) == len(self.engines) else shortlist diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index edc79ae20b9f..4cc1da05bb66 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,6 +7,7 @@ import time from collections import deque from concurrent.futures import Future +from contextlib import ExitStack from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union @@ -22,7 +23,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -33,6 +34,7 @@ from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -212,16 +214,13 @@ def execute_model(self, scheduler_output: SchedulerOutput): # Re-raise exception raise err - def step(self) -> EngineCoreOutputs: + def step(self) -> dict[int, EngineCoreOutputs]: """Schedule, execute, and make output.""" # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): - return EngineCoreOutputs( - outputs=[], - scheduler_stats=self.scheduler.make_stats(), - ) + return {} scheduler_output = self.scheduler.schedule() model_output = self.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -229,7 +228,7 @@ def step(self) -> EngineCoreOutputs: return engine_core_outputs - def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: + def step_with_batch_queue(self) -> Optional[dict[int, EngineCoreOutputs]]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -271,8 +270,8 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: # Blocking until the first result is available. model_output = future.result() self.batch_queue.task_done() - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + engine_core_outputs = (self.scheduler.update_from_output( + scheduler_output, model_output)) return engine_core_outputs @@ -350,7 +349,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -363,15 +362,22 @@ def __init__( # Create input socket. input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") - input_socket = make_zmq_socket(input_ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False) - try: + with make_zmq_socket(input_ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False) as handshake_socket: + # Register engine with front-end. - output_address = self.startup_handshake( - input_socket, on_head_node, vllm_config.parallel_config) + addresses = self.startup_handshake(handshake_socket, on_head_node, + vllm_config.parallel_config) + input_addresses: list[str] = addresses["input_addresses"] + output_addresses: list[str] = addresses["output_addresses"] + coord_in_addr: Optional[str] = addresses.get("coord_in_address") + coord_out_addr: Optional[str] = addresses.get("coord_out_address") + self.client_count = len(output_addresses) + self.coordinator = coord_out_addr is not None # Update config which may have changed from the handshake. vllm_config.__post_init__() @@ -383,45 +389,44 @@ def __init__( super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + self.engine_index = engine_index self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) self.engines_running = False + self.last_counts = (0, 0) # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", "local": on_head_node, "num_gpu_blocks": num_gpu_blocks, })) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - input_socket = None - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True) - self.output_thread.start() - finally: - if input_socket is not None: - input_socket.close(linger=0) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + threading.Thread(target=self.process_input_sockets, + args=(input_addresses, coord_in_addr, identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(output_addresses, coord_out_addr, engine_index), + daemon=True) + self.output_thread.start() @staticmethod - def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> str: + def startup_handshake(handshake_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> dict[str, Any]: # Send registration message. - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "HELLO", "local": on_head_node, @@ -429,22 +434,19 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Receive initialization message. logger.info("Waiting for init message from front-end.") - if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): raise RuntimeError("Did not receive response from front-end " f"process within {HANDSHAKE_TIMEOUT_MINS} " f"minutes") - init_bytes = input_socket.recv() + init_bytes = handshake_socket.recv() init_message = msgspec.msgpack.decode(init_bytes) logger.debug("Received init message: %s", init_message) - output_socket_address = init_message["output_socket_address"] - #TBD(nick) maybe replace IP with configured head node address - - received_parallel_config = init_message["parallel_config"] + received_parallel_config = init_message.pop("parallel_config") for key, value in received_parallel_config.items(): setattr(parallel_config, key, value) - return output_socket_address + return init_message["addresses"] @staticmethod def run_engine_core(*args, @@ -536,9 +538,22 @@ def _process_engine_step(self): # Step the engine core. outputs = self.step_fn() + if not outputs: + return + # Put EngineCoreOutputs into the output queue. - if outputs is not None: - self.output_queue.put_nowait(outputs) + for output in outputs.items(): + self.output_queue.put_nowait(output) + + if self.coordinator: + # If there is a DP coordinator, publish our request counts + # (if they've changed) + counts = self.scheduler.get_request_counts() + if counts != self.last_counts: + self.last_counts = counts + stats = SchedulerStats(*counts) + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(scheduler_stats=stats))) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -549,7 +564,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: - call_id, method_name, args = request + client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) try: method = getattr(self, method_name) @@ -560,7 +575,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, output.failure_message = (f"Call to {method_name} method" f" failed: {str(e)}") self.output_queue.put_nowait( - EngineCoreOutputs(utility_output=output)) + (client_idx, EngineCoreOutputs(utility_output=output))) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -593,27 +608,68 @@ def _send_engine_dead(self): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_socket: zmq.Socket): + def process_input_sockets(self, input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - while True: - # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) - - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + with ExitStack() as stack, zmq.Context() as ctx: + input_sockets = [ + stack.enter_context( + make_zmq_socket(ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False)) + for input_address in input_addresses + ] + if coord_input_address is None: + coord_socket = None + else: + coord_socket = stack.enter_context( + make_zmq_socket(ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False)) + # Send subscription message to coordinator. + coord_socket.send(b'\x01') + + # Register sockets with poller. + poller = zmq.Poller() + for input_socket in input_sockets: + # Send initial message to each input socket - this is required + # before the front-end ROUTER socket can send input messages + # back to us. + input_socket.send(b'') + poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) - def process_output_socket(self, output_path: str, engine_index: int): + while True: + for input_socket, _ in poller.poll(): + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart( + copy=False) + request_type = EngineCoreRequestType( + bytes(type_frame.buffer)) + + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) + + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) + + def process_output_sockets(self, output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -627,30 +683,50 @@ def process_output_socket(self, output_path: str, engine_index: int): # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. - with zmq_socket_ctx(output_path, zmq.constants.PUSH, - linger=4000) as socket: + with ExitStack() as stack, zmq.Context() as ctx: + sockets = [ + stack.enter_context( + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + for output_path in output_paths + ] + coord_socket = stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, + linger=4000)) if coord_output_path is not None else None + max_reuse_bufs = len(sockets) + 1 + while True: - outputs = self.output_queue.get() - if outputs == EngineCoreProc.ENGINE_CORE_DEAD: - socket.send(outputs, copy=False) + output = self.output_queue.get() + if output == EngineCoreProc.ENGINE_CORE_DEAD: + for socket in sockets: + socket.send(output) + #TODO also send to coordinator here? break - assert not isinstance(outputs, bytes) + assert not isinstance(output, bytes) + client_index, outputs = output outputs.engine_index = engine_index + if client_index == -1: + # Don't reuse buffer for coordinator message + # which will be very small. + assert coord_socket is not None + coord_socket.send_multipart(encoder.encode(outputs)) + continue + # Reclaim buffers that zmq is finished with. while pending and pending[-1][0].done: reuse_buffers.append(pending.pop()[2]) buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = socket.send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart(buffers, + copy=False, + track=True) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) - elif len(reuse_buffers) < 2: - # Keep at most 2 buffers to reuse. + elif len(reuse_buffers) < max_reuse_bufs: + # Limit the number of buffers to reuse. reuse_buffers.append(buffer) @@ -662,7 +738,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -677,10 +753,11 @@ def __init__( # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. self.counter = 0 + self.current_wave = 0 # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, input_address, + super().__init__(vllm_config, on_head_node, handshake_address, executor_class, log_stats, dp_rank) def _init_data_parallel(self, vllm_config: VllmConfig): @@ -703,7 +780,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig): self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - self.current_wave = 0 def shutdown(self): super().shutdown() @@ -718,15 +794,16 @@ def add_request(self, request: EngineCoreRequest): # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - EngineCoreOutputs(start_wave=self.current_wave)) + (-1, EngineCoreOutputs(start_wave=self.current_wave))) super().add_request(request) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave: int = request - if new_wave >= self.current_wave: + new_wave, exclude_eng_index = request + if exclude_eng_index != self.engine_index and ( + new_wave >= self.current_wave): self.current_wave = new_wave if not self.engines_running: logger.debug("EngineCore starting idle loop for wave %d.", @@ -779,7 +856,8 @@ def run_busy_loop(self): logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) self.output_queue.put_nowait( - EngineCoreOutputs(wave_complete=self.current_wave)) + (-1, + EngineCoreOutputs(wave_complete=self.current_wave))) self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0d52bc9a6814..287cfe2b1760 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import queue +import sys import uuid import weakref from abc import ABC, abstractmethod @@ -9,26 +10,27 @@ from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass -from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union -import msgspec +import msgspec.msgpack import zmq import zmq.asyncio -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, + zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import CoreEngineProcManager +from vllm.v1.utils import (CoreEngine, CoreEngineProcManager, + get_engine_client_zmq_addr, wait_for_engine_startup) logger = init_logger(__name__) @@ -36,8 +38,6 @@ _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -206,7 +206,7 @@ def __init__(self, *args, **kwargs): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - return self.engine_core.step() + return self.engine_core.step().get(0) or EngineCoreOutputs() def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) @@ -265,24 +265,6 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") - - self.state = CoreEngineState.NEW - self.num_reqs_in_flight = 0 - - @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -290,9 +272,11 @@ class BackgroundResources: ctx: Union[zmq.Context] local_engine_manager: Optional[CoreEngineProcManager] = None + coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_queue_task: Optional[asyncio.Task] = None + stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None # Set if any of the engines are dead. Here so that the output @@ -305,9 +289,13 @@ def __call__(self): self.engine_dead = True if self.local_engine_manager is not None: self.local_engine_manager.close() + if self.coordinator is not None: + self.coordinator.close() if self.output_queue_task is not None: self.output_queue_task.cancel() + if self.stats_update_task is not None: + self.stats_update_task.cancel() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. @@ -349,6 +337,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, ): self.vllm_config = vllm_config # Serialization setup. @@ -368,8 +357,8 @@ def __init__( try: parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local - start_index = parallel_config.data_parallel_rank local_start_index = parallel_config.data_parallel_rank_local + dp_size = parallel_config.data_parallel_size # SPMD mode is where there is an LLM instance per DP rank and # one core engine per LLM, see @@ -381,42 +370,53 @@ def __init__( CoreEngine(index=local_start_index, local=True) ] else: - assert start_index == 0 + assert parallel_config.data_parallel_rank == 0 local_start_index = 0 self.core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(parallel_config.data_parallel_size) + for i in range(dp_size) ] - input_address, output_address = self._get_zmq_addresses( - parallel_config, spmd_mode) + local_only = spmd_mode or local_engine_count == dp_size + + self.stats_update_address: Optional[str] = None + if client_addresses is not None: + input_address = client_addresses["input_address"] + output_address = client_addresses["output_address"] + self.stats_update_address = client_addresses.get( + "stats_update_address") + else: + host = parallel_config.data_parallel_master_ip + input_address = get_engine_client_zmq_addr(local_only, host) + output_address = get_engine_client_zmq_addr(local_only, host) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True) - self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.constants.PULL) - # Start local engines. - if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. - self.resources.local_engine_manager = CoreEngineProcManager( - EngineCoreProc.run_engine_core, - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=log_stats, - input_address=input_address, - on_head_node=True, - local_engine_count=local_engine_count, - start_index=start_index, - local_start_index=local_start_index) + self.ctx, output_address, zmq.PULL) + + if client_addresses is None: + self._init_engines_direct(vllm_config, local_only, + local_start_index, input_address, + output_address, executor_class, + log_stats) + coordinator = self.resources.coordinator + if coordinator: + self.stats_update_address = ( + coordinator.get_stats_publish_address()) + + # Wait for ready messages from each engine on the input socket. + identities = set(e.identity for e in self.core_engines) + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError("Timed out waiting for engines to send" + "initial message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + identities.remove(identity) self.core_engine = self.core_engines[0] - - # Wait for engine core process(es) to start. - self._wait_for_engine_startup(output_address, parallel_config) - self.utility_results: dict[int, AnyFuture] = {} # Request objects which may contain pytorch-allocated tensors @@ -429,116 +429,66 @@ def __init__( if not success: self._finalizer() - @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig, - spmd_mode: bool) -> tuple[str, str]: - """Returns (input_address, output_address).""" - dp_size = parallel_config.data_parallel_size + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip - if local_engine_count == dp_size or spmd_mode: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = get_tcp_uri(host, input_port) - output_address = get_tcp_uri(host, output_port) - - return input_address, output_address - - def _wait_for_engine_startup(self, output_address: str, - parallel_config: ParallelConfig): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) - - # Wait for engine core process(es) to send ready messages. - local_count = parallel_config.data_parallel_size_local - remote_count = len(self.core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] - - poller = zmq.Poller() - poller.register(sync_input_socket, zmq.POLLIN) - proc_manager = self.resources.local_engine_manager - if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs( - ) if proc_manager else {} - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") - - # Receive HELLO and READY messages from the input socket. - eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, byteorder="little") - engine = next( - (e for e in self.core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") - - if status == "HELLO" and engine.state == CoreEngineState.NEW: - - # Send init message with DP config info. - init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - }, - }) - sync_input_socket.send_multipart((eng_identity, *init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state - == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. - cache_config = self.vllm_config.cache_config - num_gpu_blocks = cache_config.num_gpu_blocks or 0 - num_gpu_blocks += msg['num_gpu_blocks'] - cache_config.num_gpu_blocks = num_gpu_blocks - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) + + # Wait for engine core process(es) to start. + self._wait_for_engine_startup(handshake_socket, input_address, + output_address) + + def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, + input_address: str, output_address: str): + addresses: dict[str, Any] = { + "input_addresses": [input_address], + "output_addresses": [output_address], + } + + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.update(coordinator.get_engine_socket_addresses()) + + wait_for_engine_startup( + handshake_socket, + addresses, + self.core_engines, + self.vllm_config.parallel_config, + self.vllm_config.cache_config, + self.resources.local_engine_manager, + coordinator.proc if coordinator else None, + ) def shutdown(self): # Terminate background resources. @@ -604,8 +554,8 @@ def process_outputs_socket(): try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() - poller.register(shutdown_socket) - poller.register(out_socket) + poller.register(shutdown_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) while True: socks = poller.poll() if not socks: @@ -667,7 +617,7 @@ def call_utility(self, method: str, *args) -> Any: future: Future[Any] = Future() self.utility_results[call_id] = future self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) + (0, call_id, method, args)) return future.result() @@ -729,15 +679,21 @@ def save_sharded_state(self, class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, + client_addresses=client_addresses, ) + self.client_index = client_index self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: @@ -853,12 +809,13 @@ async def _call_utility_async(self, method: str, *args, future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (call_id, method, args))) + (self.client_index, call_id, method, args))) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + request.client_index = self.client_index await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() @@ -920,17 +877,119 @@ class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): self.current_wave = 0 self.engines_running = False + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, CoreEngine] = {} - super().__init__(vllm_config, executor_class, log_stats) + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) assert len(self.core_engines) > 1 + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] + + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True) + try: + # If we are running in an asyncio event loop, start the stats task. + # Otherwise, it will be started lazily. + asyncio.get_running_loop() + self._ensure_stats_update_task() + except RuntimeError: + pass + + def _ensure_stats_update_task(self): + resources = self.resources + if resources.stats_update_task is not None: + return + + assert self.stats_update_address is not None + + async def run_engine_stats_update_task(): + with make_zmq_socket(self.ctx, self.stats_update_address, + zmq.XSUB) as socket, make_zmq_socket( + self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False) as first_req_rcv_socket: + # Send subscription message. + await socket.send(b'\x01') + + poller = zmq.asyncio.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(first_req_rcv_socket, zmq.POLLIN) + + while True: + events = await poller.poll() + if not self.engines_running and len(events) == 2 or ( + events[0][0] == first_req_rcv_socket): + # Send a message to notify the coordinator that + # we're sending a request while the engines are + # paused, so that it can wake the others up + # (to run dummy EP loop). + self.engines_running = True + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + target_eng_index = int.from_bytes(buf, "little") + msg = msgspec.msgpack.encode( + (target_eng_index, self.current_wave)) + await socket.send(msg) + + buf = None + while True: + # Drain all stats events (we only care about latest). + future: asyncio.Future[bytes] = socket.recv( + flags=zmq.NOBLOCK) + if isinstance(future.exception(), zmq.Again): + break + buf = future.result() + if buf is None: + continue + + # Update local load-balancing state. + counts, wave, running = msgspec.msgpack.decode(buf) + self.current_wave = wave + self.engines_running = running + self.lb_engines = counts + + resources.stats_update_task = asyncio.create_task( + run_engine_stats_update_task()) + + def get_core_engine_for_request(self) -> CoreEngine: + if not self.lb_engines: + return self.core_engines[0] + # TODO use P2C alg for larger DP sizes + num_engines = len(self.lb_engines) + min_counts = [sys.maxsize, sys.maxsize] + eng_index = 0 + for i in range(num_engines): + # Start from client_index for to help with balancing when + # engines are empty. + idx = (self.client_index + i) % num_engines + counts = self.lb_engines[idx] + if counts < min_counts: + min_counts = counts + eng_index = idx + # Adjust local counts for better balancing between stats updates + # from the coordinator (these are overwritten 10x per second). + if min_counts[0]: + min_counts[0] += 1 + else: + min_counts[1] += 1 + return self.core_engines[eng_index] + async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ @@ -939,62 +998,30 @@ async def call_utility_async(self, method: str, *args) -> Any: ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: + self._ensure_stats_update_task() + request.current_wave = self.current_wave + request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine - chosen_engine.num_reqs_in_flight += 1 to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: - # Send request to chosen engine and dp start loop - # control message to all other engines. - self.engines_running = True - to_await = asyncio.gather( - to_await, # type: ignore[assignment] - *self._start_wave_coros(exclude_index=chosen_engine.index)) + # Notify coordinator that we're sending a request + await self.first_req_send_socket.send(chosen_engine.identity) await to_await self._ensure_output_queue_task() - def get_core_engine_for_request(self) -> CoreEngine: - return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) - @staticmethod async def process_engine_outputs(self: "DPAsyncMPClient", outputs: EngineCoreOutputs): - if self.reqs_in_flight: - for req_id in outputs.finished_requests or (): - if engine := self.reqs_in_flight.pop(req_id, None): - engine.num_reqs_in_flight -= 1 - - if outputs.wave_complete is not None: - # Current wave is complete, move to next wave number - # and mark engines as paused. - if self.current_wave <= outputs.wave_complete: - self.current_wave = outputs.wave_complete + 1 - self.engines_running = False - - elif outputs.start_wave is not None and ( - outputs.start_wave > self.current_wave or - (outputs.start_wave == self.current_wave - and not self.engines_running)): - # Engine received request for a non-current wave so we must ensure - # that other engines progress to the next wave. - self.current_wave = outputs.start_wave - self.engines_running = True - await asyncio.gather(*self._start_wave_coros( - exclude_index=outputs.engine_index)) - - def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: - logger.debug("Sending start DP wave %d.", self.current_wave) - return [ - self._send_input(EngineCoreRequestType.START_DP_WAVE, - self.current_wave, engine) - for engine in self.core_engines if engine.index != exclude_index - ] + if outputs.finished_requests and self.reqs_in_flight: + for req_id in outputs.finished_requests: + self.reqs_in_flight.pop(req_id, None) async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 6ee40850beb1..48e03f4e99d4 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,13 +12,12 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason +from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5.0 - StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] @@ -35,7 +34,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): ... @@ -78,20 +77,22 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) - self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats is not None: + self.prefix_caching_metrics.observe( + scheduler_stats.prefix_cache_stats) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_logging.observe( + scheduler_stats.spec_decoding_stats) - self.last_scheduler_stats = scheduler_stats + self.last_scheduler_stats = scheduler_stats def log(self): now = time.monotonic() @@ -131,16 +132,18 @@ def log(self): self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): - logger.info( - "vllm cache_config_info with initialization " \ - "after num_gpu_blocks is: %d", - self.vllm_config.cache_config.num_gpu_blocks) + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "Engine %03d: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks) class PrometheusStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - self._unregister_vllm_metrics() + + unregister_vllm_metrics() self.vllm_config = vllm_config self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in @@ -165,11 +168,13 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_scheduler_running = prometheus_client.Gauge( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = prometheus_client.Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) # @@ -178,6 +183,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_gpu_cache_usage = prometheus_client.Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( @@ -238,6 +244,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. + # See: https://github.com/vllm-project/vllm/pull/18053 self.histogram_iteration_tokens = \ prometheus_client.Histogram( name="vllm:iteration_tokens_total", @@ -336,6 +345,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # # LoRA metrics # + + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -346,13 +358,16 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): prometheus_client.Gauge( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", + multiprocess_mode="sum", labelnames=[ self.labelname_max_lora, self.labelname_waiting_lora_adapters, self.labelname_running_lora_adapters, - ]) + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() metrics_info["engine"] = self.engine_index @@ -368,25 +383,28 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): info_gauge = prometheus_client.Gauge( name=name, documentation=documentation, - labelnames=metrics_info.keys()).labels(**metrics_info) + multiprocess_mode="mostrecent", + labelnames=metrics_info.keys(), + ).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log to prometheus.""" - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + if scheduler_stats is not None: + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( - scheduler_stats.prefix_cache_stats.hits) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -441,13 +459,6 @@ def record(self, scheduler_stats: SchedulerStats, self.gauge_lora_info.labels(**lora_info_labels)\ .set_to_current_time() - @staticmethod - def _unregister_vllm_metrics(): - # Unregister any existing vLLM collectors (for CI/CD - for collector in list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py new file mode 100644 index 000000000000..c958d7cd31f0 --- /dev/null +++ b/vllm/v1/metrics/prometheus.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Optional + +from prometheus_client import REGISTRY, CollectorRegistry, multiprocess + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global temporary directory for prometheus multiprocessing +_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None + + +def setup_multiprocess_prometheus(): + """Set up prometheus multiprocessing directory if not already configured. + + """ + global _prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + _prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name + logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", + _prometheus_multiproc_dir.name) + else: + logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + +def get_prometheus_registry(): + """Get the appropriate prometheus registry based on multiprocessing + configuration. + + Returns: + Registry: A prometheus registry + """ + if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None: + logger.debug("Using multiprocess registry for prometheus metrics") + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + return registry + + return REGISTRY + + +def unregister_vllm_metrics(): + """Unregister any existing vLLM collectors from the prometheus registry. + + This is useful for testing and CI/CD where metrics may be registered + multiple times across test runs. + """ + registry = get_prometheus_registry() + # Unregister any existing vLLM collectors + for collector in list(registry._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + registry.unregister(collector) + + +def shutdown_prometheus(): + """Shutdown prometheus metrics.""" + try: + pid = os.getpid() + multiprocess.mark_process_dead(pid) + logger.debug("Marked Prometheus metrics for process %d as dead", pid) + except Exception as e: + logger.error("Error during metrics cleanup: %s", str(e)) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fc6b738546f4..53bb130bbe7f 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -21,18 +21,19 @@ class Request: def __init__( self, request_id: str, + client_index: int, prompt_token_ids: list[int], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], - arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id + self.client_index = client_index self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id @@ -91,13 +92,13 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, + client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83cc..81d08d3e9632 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,22 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 -import os import time import weakref from collections import defaultdict from collections.abc import Sequence +from enum import Enum, auto from multiprocessing import Process, connection -from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, - overload) +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) +import msgspec import torch +import zmq -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import get_mp_context, kill_process_tree +from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, + get_tcp_uri, kill_process_tree) from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: @@ -26,6 +29,8 @@ T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -95,6 +100,13 @@ def __repr__(self): return f"ConstantList({self._x})" +def get_engine_client_zmq_addr(local_only: bool, + host: str, + port: int = 0) -> str: + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( + host, port or get_open_port())) + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -109,7 +121,7 @@ def __init__( local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -117,7 +129,7 @@ def __init__( common_kwargs = { "vllm_config": vllm_config, "on_head_node": on_head_node, - "input_address": input_address, + "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } @@ -135,8 +147,7 @@ def __init__( "local_dp_rank": local_index, })) - self._finalizer = weakref.finalize(self, shutdown, self.processes, - input_address) + self._finalizer = weakref.finalize(self, shutdown, self.processes) try: for proc in self.processes: proc.start() @@ -164,9 +175,122 @@ def finished_procs(self) -> dict[str, int]: } +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.index = index + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: dict[str, Any], + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. + local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode({ + "addresses": addresses, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": parallel_config.data_parallel_size, + }, + }) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg['num_gpu_blocks'] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + # Note(rob): shutdown function cannot be a bound method, -# else the gc cannot collect the objedecoupct. -def shutdown(procs: list[Process], input_address: str): +# else the gc cannot collect the object. +def shutdown(procs: list[Process]): # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -185,12 +309,6 @@ def shutdown(procs: list[Process], input_address: str): if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) - # Remove zmq ipc socket files. - if input_address.startswith("ipc://"): - socket_file = input_address[len("ipc://"):] - if os and os.path.exists(socket_file): - os.remove(socket_file) - def bind_kv_cache( kv_caches: dict[str, torch.Tensor],