diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md
index df69849bb922..af5f77747fac 100644
--- a/docs/features/disagg_prefill.md
+++ b/docs/features/disagg_prefill.md
@@ -19,12 +19,13 @@ Two main reasons:

Please refer to [examples/online_serving/disaggregated_prefill.sh](../../examples/online_serving/disaggregated_prefill.sh) for the example usage of disaggregated prefilling.

-Now supports 5 types of connectors:
+Now supports 6 types of connectors:

- **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling, which uses NIXL as the underlying KV transmission.
- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling, which supports fully async send/recv. For a detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
+- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of MooncakeConnector disaggregated prefilling. For a detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
- **MultiConnector**: take advantage of the `kv_connector_extra_config: dict[str, Any]` already present in `KVTransferConfig` to stash all the connectors we want in an ordered list of kwargs, such as:

```bash
diff --git a/docs/features/mooncake_connector_usage.md b/docs/features/mooncake_connector_usage.md
index 653ea29ad943..0e2478924ead 100644
--- a/docs/features/mooncake_connector_usage.md
+++ b/docs/features/mooncake_connector_usage.md
@@ -31,11 +31,9 @@ vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_conne
### Proxy

```bash
-python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
+python examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py --prefill http://192.168.0.2:8010 --decode http://192.168.0.3:8020
```

-> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
-
Now you can send requests to the proxy server through port 8000.

## Environment Variables
@@ -43,16 +41,29 @@ Now you can send requests to the proxy server through port 8000.
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for the Mooncake bootstrap server
  - Default: 8998
  - Required only for prefiller instances
-  - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
-  - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
-  - Used for the decoder notifying the prefiller
+  - For headless instances, must be the same as on the master instance
+  - Each instance needs a unique port on its host; reusing the same port number across different hosts is fine
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller's KV cache for a particular request. (Optional)
  - Default: 480
  - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.

-## KV Role Options
+## KV Transfer Config
+
+### KV Role Options

- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from the prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
+
+### kv_connector_extra_config

+- **num_workers**: Size of the thread pool each prefiller worker uses to transfer KV caches via Mooncake. (default: 10)
+- **mooncake_protocol**: Transport protocol passed to the Mooncake transfer engine. (default: "rdma")
+
+## Example Scripts/Code
+
+Refer to these example scripts in the vLLM repository:
+
+- [run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh)
+- [mooncake_connector_proxy.py](../../examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py)
diff --git a/examples/online_serving/disaggregated_serving/README.md b/examples/online_serving/disaggregated_serving/README.md
index 090afd7515ee..1e3284299346 100644
--- a/examples/online_serving/disaggregated_serving/README.md
+++ b/examples/online_serving/disaggregated_serving/README.md
@@ -6,3 +6,4 @@ This example contains scripts that demonstrate the disaggregated serving feature

- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances).
- `kv_events.sh` - Demonstrates KV cache event publishing.
+- `mooncake_connector` - Proxy and launch scripts demonstrating MooncakeConnector.
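For reference, here is a minimal single-prefiller/single-decoder sketch that ties the options above together. It is only a sketch: the hosts and model follow the examples earlier in the usage guide, and the `kv_connector_extra_config` values shown are the defaults and may be omitted.

```bash
# Prefiller (kv_producer); also runs the bootstrap server on port 8998
VLLM_MOONCAKE_BOOTSTRAP_PORT=8998 vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer","kv_connector_extra_config":{"num_workers":10,"mooncake_protocol":"rdma"}}'

# Decoder (kv_consumer)
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'

# Proxy (listens on port 8000 by default), then a test request through it
python examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py \
    --prefill http://192.168.0.2:8010 8998 --decode http://192.168.0.3:8020
curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" \
    -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "prompt": "Hello", "max_tokens": 32}'
```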
diff --git a/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
new file mode 100644
index 000000000000..09880a32aa35
--- /dev/null
+++ b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
@@ -0,0 +1,376 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import argparse
+import asyncio
+import ipaddress
+import itertools
+import os
+import urllib.parse
+import uuid
+from contextlib import asynccontextmanager
+from typing import Any
+
+import httpx
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import StreamingResponse
+
+
+def maybe_wrap_ipv6_address(address: str) -> str:
+    try:
+        ipaddress.IPv6Address(address)
+        return f"[{address}]"
+    except ValueError:
+        return address
+
+
+def make_http_path(host: str, port: int) -> str:
+    return f"http://{host}:{port}"
+
+
+def prefiller_cycle(prefill_clients: list[Any]):
+    while True:
+        for prefill_client in prefill_clients:
+            for i in range(prefill_client["dp_size"]):
+                yield prefill_client, i
+
+
+async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
+    for prefill_client in prefill_clients:
+        while True:
+            try:
+                # Wait for prefill service to be ready
+                response = await prefill_client["client"].get("/health")
+                response.raise_for_status()
+            except Exception:
+                await asyncio.sleep(1)
+                continue
+
+            response = await prefill_client["client"].get(
+                prefill_client["bootstrap_addr"] + "/query"
+            )
+            response.raise_for_status()
+            data = response.json()
+            break
+
+        for dp_rank, dp_entry in data.items():
+            prefill_client["dp_engine_id"][int(dp_rank)] = dp_entry["engine_id"]
+        dp_size = len(data)
+        prefill_client["dp_size"] = dp_size
+        print(f"Initialized prefiller {prefill_client['url']} with dp_size={dp_size}")
+
+    ready.set()
+    print("All prefiller instances are ready.")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan context manager to handle startup and shutdown events.
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + app.state.ready = asyncio.Event() + + # Create prefill clients + for i, (url, bootstrap_port) in enumerate(global_args.prefill): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + "url": url, + "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998), + "dp_engine_id": {}, + } + ) + + # Create decode clients + for i, url in enumerate(global_args.decode): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.decode_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + } + ) + + asyncio.create_task(get_prefiller_info(app.state.prefill_clients, app.state.ready)) + + # Initialize round-robin iterators + app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + + print( + f"Got {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." + ) + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info["client"].aclose() + + for client_info in app.state.decode_clients: + await client_info["client"].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI + parser.add_argument("--host", type=str, default="127.0.0.1") + + # For prefiller instances + parser.add_argument( + "--prefill", + nargs="+", + action="append", + dest="prefill_raw", + metavar=("URL", "bootstrap_port"), + help=( + "Prefill server URL and optional bootstrap port. " + "Can be specified multiple times. " + "Format: --prefill URL [BOOTSTRAP_PORT]. " + "BOOTSTRAP_PORT can be a port number, " + "'none', or omitted (defaults to none)." + ), + ) + + # For decoder instances + parser.add_argument( + "--decode", + nargs=1, + action="append", + dest="decode_raw", + metavar=("URL",), + help="Decode server URL. Can be specified multiple times.", + ) + + args = parser.parse_args() + args.prefill = _parse_prefill_urls(args.prefill_raw) + args.decode = _parse_decode_urls(args.decode_raw) + + return args + + +# From sglang router_args.py +def _parse_prefill_urls(prefill_list): + """Parse prefill URLs from --prefill arguments. 
+ + Format: --prefill URL [BOOTSTRAP_PORT] + Example: + --prefill http://prefill1:8080 9000 # With bootstrap port + --prefill http://prefill2:8080 none # Explicitly no bootstrap port + --prefill http://prefill3:8080 # Defaults to no bootstrap port + """ + if not prefill_list: + return [] + + prefill_urls = [] + for prefill_args in prefill_list: + url = prefill_args[0] + + # Handle optional bootstrap port + if len(prefill_args) >= 2: + bootstrap_port_str = prefill_args[1] + # Handle 'none' as None + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError as e: + raise ValueError( + f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" # noqa: E501 + ) from e + else: + # No bootstrap port specified, default to None + bootstrap_port = None + + prefill_urls.append((url, bootstrap_port)) + + return prefill_urls + + +def _parse_decode_urls(decode_list): + """Parse decode URLs from --decode arguments. + + Format: --decode URL + Example: --decode http://decode1:8081 --decode http://decode2:8081 + """ + if not decode_list: + return [] + + # decode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in decode_list] + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == "prefill": + return next(app.state.prefill_iterator) + elif service_type == "decode": + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service( + client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a service using a client from the pool. + """ + req_data = req_data.copy() + req_data["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + "X-data-parallel-rank": str(dp_rank), + } + + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) + response.raise_for_status() + + # CRITICAL: Release connection back to pool + await response.aclose() + + +async def stream_service_response( + prefill_client_info: dict, + prefill_dp_rank: int, + decode_client_info: dict, + endpoint: str, + req_data: dict, + request_id: str, +): + """ + Asynchronously stream response from a service using a client from the pool. 
+ """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + req_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"], + "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank], + "transfer_id": f"xfer-{request_id}", + } + + async with decode_client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +async def _handle_completions(api: str, request: Request): + if not app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill") + + # Send request to prefill service + asyncio.create_task( + send_request_to_service( + prefill_client_info, prefill_dp_rank, api, req_data, request_id + ) + ) + + decode_client_info = get_next_client(request.app, "decode") + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + prefill_client_info, + prefill_dp_rank, + decode_client_info, + api, + req_data, + request_id=request_id, + ): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle_completions("/v1/completions", request) + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + return await _handle_completions("/v1/chat/completions", request) + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + import uvicorn + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh b/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh new file mode 100644 index 000000000000..e38d377c331a --- /dev/null +++ b/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh @@ -0,0 +1,222 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Disaggregated Serving Script for Mooncake Connector +# ============================================================================= +# This script demonstrates disaggregated prefill and decode serving using +# Mooncake Connector. +# +# Configuration can be customized via environment variables: +# MODEL: Model to serve +# PREFILL_GPUS: Comma-separated GPU IDs for prefill servers +# DECODE_GPUS: Comma-separated GPU IDs for decode servers +# PREFILL_PORTS: Comma-separated ports for prefill servers +# BOOTSTRAP_PORTS: Bootstrap server port launched by prefill servers +# DECODE_PORTS: Comma-separated ports for decode servers +# PROXY_PORT: Proxy server port used to setup P/D disaggregated connection. 
+# TIMEOUT_SECONDS: Server startup timeout
+# =============================================================================
+
+# Configuration - can be overridden via environment variables
+MODEL=${MODEL:-Qwen/Qwen2.5-7B-Instruct}
+TIMEOUT_SECONDS=${TIMEOUT_SECONDS:-1200}
+PROXY_PORT=${PROXY_PORT:-8000}
+
+PREFILL_GPUS=${PREFILL_GPUS:-0}
+DECODE_GPUS=${DECODE_GPUS:-1}
+PREFILL_PORTS=${PREFILL_PORTS:-8010}
+BOOTSTRAP_PORTS=${BOOTSTRAP_PORTS:-8998}
+DECODE_PORTS=${DECODE_PORTS:-8020}
+
+echo "Warning: Mooncake Connector support for vLLM v1 is experimental and subject to change."
+echo ""
+echo "Architecture Configuration:"
+echo "  Model: $MODEL"
+echo "  Prefill GPUs: $PREFILL_GPUS, Ports: $PREFILL_PORTS, Bootstrap Ports: $BOOTSTRAP_PORTS"
+echo "  Decode GPUs: $DECODE_GPUS, Ports: $DECODE_PORTS"
+echo "  Proxy Port: $PROXY_PORT"
+echo "  Timeout: ${TIMEOUT_SECONDS}s"
+echo ""
+
+PIDS=()
+
+# Switch to the directory of the current script
+cd "$(dirname "${BASH_SOURCE[0]}")"
+
+check_required_files() {
+    local files=("mooncake_connector_proxy.py")
+    for file in "${files[@]}"; do
+        if [[ ! -f "$file" ]]; then
+            echo "Required file $file not found in $(pwd)"
+            exit 1
+        fi
+    done
+}
+
+check_hf_token() {
+    if [ -z "$HF_TOKEN" ]; then
+        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
+        echo "Example: export HF_TOKEN=your_token_here"
+        exit 1
+    fi
+    if [[ "$HF_TOKEN" != hf_* ]]; then
+        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
+        exit 1
+    fi
+    echo "HF_TOKEN is set and valid."
+}
+
+check_num_gpus() {
+    # Check that the number of GPUs is >= 2 via nvidia-smi
+    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+    if [ "$num_gpus" -lt 2 ]; then
+        echo "You need at least 2 GPUs to run disaggregated prefill."
+        exit 1
+    else
+        echo "Found $num_gpus GPUs."
+    fi
+}
+
+ensure_python_library_installed() {
+    echo "Checking if $1 is installed..."
+    if ! python3 -c "import $1" > /dev/null 2>&1; then
+        echo "$1 is not installed. Please install it via pip install $1."
+        exit 1
+    else
+        echo "$1 is installed."
+    fi
+}
+
+cleanup() {
+    echo "Stopping everything…"
+    trap - INT TERM        # prevent re-entrancy
+    pkill -9 -f "mooncake_connector_proxy.py"
+    kill -- -$$            # negative PID == "this whole process-group"
+    wait                   # reap children so we don't leave zombies
+    exit 0
+}
+
+wait_for_server() {
+    local port=$1
+    local timeout_seconds=$TIMEOUT_SECONDS
+    local start_time=$(date +%s)
+
+    echo "Waiting for server on port $port..."
+
+    while true; do
+        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
+            echo "Server on port $port is ready."
+            return 0
+        fi
+
+        local now=$(date +%s)
+        if (( now - start_time >= timeout_seconds )); then
+            echo "Timeout waiting for server on port $port"
+            return 1
+        fi
+
+        sleep 1
+    done
+}
+
+main() {
+    check_required_files
+    check_hf_token
+    check_num_gpus
+    ensure_python_library_installed vllm
+    ensure_python_library_installed mooncake.engine
+
+    trap cleanup INT
+    trap cleanup USR1
+    trap cleanup TERM
+
+    echo "Launching disaggregated serving components..."
+ echo "Please check the log files for detailed output:" + echo " - prefill*.log: Prefill server logs" + echo " - decode*.log: Decode server logs" + echo " - proxy.log: Proxy server log" + + # Parse GPU and port arrays + IFS=',' read -ra PREFILL_GPU_ARRAY <<< "$PREFILL_GPUS" + IFS=',' read -ra DECODE_GPU_ARRAY <<< "$DECODE_GPUS" + IFS=',' read -ra PREFILL_PORT_ARRAY <<< "$PREFILL_PORTS" + IFS=',' read -ra BOOTSTRAP_PORT_ARRAY <<< "$BOOTSTRAP_PORTS" + IFS=',' read -ra DECODE_PORT_ARRAY <<< "$DECODE_PORTS" + + proxy_param="" + + # ============================================================================= + # Launch Prefill Servers (X Producers) + # ============================================================================= + echo "" + echo "Starting ${#PREFILL_GPU_ARRAY[@]} prefill server(s)..." + for i in "${!PREFILL_GPU_ARRAY[@]}"; do + local gpu_id=${PREFILL_GPU_ARRAY[$i]} + local port=${PREFILL_PORT_ARRAY[$i]} + local bootstrap_port=${BOOTSTRAP_PORT_ARRAY[$i]} + + echo " Prefill server $((i+1)): GPU $gpu_id, Port $port, Bootstrap Port $bootstrap_port" + VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap_port CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + --port $port \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" > prefill$((i+1)).log 2>&1 & + PIDS+=($!) + proxy_param="${proxy_param} --prefill http://0.0.0.0:${port} $bootstrap_port" + done + + # ============================================================================= + # Launch Decode Servers (Y Decoders) + # ============================================================================= + echo "" + echo "Starting ${#DECODE_GPU_ARRAY[@]} decode server(s)..." + for i in "${!DECODE_GPU_ARRAY[@]}"; do + local gpu_id=${DECODE_GPU_ARRAY[$i]} + local port=${DECODE_PORT_ARRAY[$i]} + + echo " Decode server $((i+1)): GPU $gpu_id, Port $port" + CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + --port $port \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\"}" > decode$((i+1)).log 2>&1 & + PIDS+=($!) + proxy_param="${proxy_param} --decode http://0.0.0.0:${port}" + done + + # ============================================================================= + # Launch Proxy Server + # ============================================================================= + echo "" + echo "Starting proxy server on port $PROXY_PORT..." + python3 mooncake_connector_proxy.py $proxy_param --port $PROXY_PORT > proxy.log 2>&1 & + PIDS+=($!) + + # ============================================================================= + # Wait for All Servers to Start + # ============================================================================= + echo "" + echo "Waiting for all servers to start..." + for port in "${PREFILL_PORT_ARRAY[@]}" "${DECODE_PORT_ARRAY[@]}"; do + if ! wait_for_server $port; then + echo "Failed to start server on port $port" + cleanup + exit 1 + fi + done + + echo "" + echo "All servers are up. Starting benchmark..." + + # ============================================================================= + # Run Benchmark + # ============================================================================= + vllm bench serve --port $PROXY_PORT --seed $(date +%s) \ + --backend vllm --model $MODEL \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 2 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." 
+ + cleanup +} + +main diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 3933f6d6569b..1ceac39711b2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -198,6 +198,6 @@ def get_connector_class( ) KVConnectorFactory.register_connector( "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", "MooncakeConnector", ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py similarity index 51% rename from vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py rename to vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index b2b6411f0e05..f105d34928fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -6,8 +6,10 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from enum import IntEnum from typing import TYPE_CHECKING, Any +import httpx import msgspec import numpy as np import torch @@ -17,6 +19,7 @@ from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineId, TpKVTopology, get_current_attn_backend, ) @@ -25,10 +28,15 @@ KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( + MooncakeBootstrapServer, + RegisterWorkerPayload, +) from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, + is_local_first_rank, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -43,7 +51,7 @@ except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " "to run VLLM with MooncakeTransferEngine." ) from e @@ -52,46 +60,75 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -EngineId = str -ReqId = str - -TRANS_DONE = b"trans_done" -TRANS_ERROR = b"trans_error" +ReqId = str # Internal scheduler request ID +TransferId = str # KV transfer coordination ID (shared by P/D) logger = init_logger(__name__) -class MooncakeAgentMetadata( +class MooncakeXferMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. 
- dict=True, ): remote_hostname: str remote_port: int - request_ids: list[ReqId] + remote_tp_size: int + remote_tp_rank: int + req_blocks: dict[ReqId, tuple[TransferId, list[int]]] kv_caches_base_addr: list[int] - block_ids: list[list[int]] + + +class MooncakeXferResponseStatus(IntEnum): + # Transfer finished + FINISH = 0 + # Continue to receive + CONTINUE = 1 + # Something wrong, see err_msg + ERROR = 2 + + +class MooncakeXferResponse( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] +): + status: MooncakeXferResponseStatus + ok_reqs: list[ReqId] | None = None + err_reqs: list[ReqId] | None = None + err_msg: str | None = None @dataclass -class RecvReqMeta: +class PullReqMeta: + d_req_id: ReqId + transfer_id: TransferId local_block_ids: list[int] - remote_host: str - remote_port: int + remote_engine_id: EngineId + remote_bootstrap_addr: str + # Set expire time to avoid infinitely sending requests. + expire_time: float = float("inf") + # Designed for one D pairing to multiple P + pull_tasks_count: int = 0 @dataclass class SendBlockMeta: + p_req_id: ReqId + transfer_id: TransferId local_block_ids: list[int] ready: asyncio.Event expire_time: float = float("inf") + need_send: int = 0 + sent: int = 0 + sending: int = 0 class MooncakeConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {} - self.reqs_to_send: dict[ReqId, list[int]] = {} + # Use (engine_id, dp_rank) to group reqs with same dp. + # See comments in MooncakeBootstrapServer. + self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) + self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {} + self.reqs_not_processed: set[TransferId] = set() def add_new_req( self, @@ -100,14 +137,18 @@ def add_new_req( kv_transfer_params: dict[str, Any], load_remote_cache: bool = True, ): + transfer_id = kv_transfer_params["transfer_id"] if load_remote_cache: - self.reqs_to_recv[request_id] = RecvReqMeta( + remote_engine_id = kv_transfer_params["remote_engine_id"] + self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta( + d_req_id=request_id, local_block_ids=local_block_ids, - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], + remote_engine_id=remote_engine_id, + remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"], + transfer_id=transfer_id, ) else: - self.reqs_to_send[request_id] = local_block_ids + self.reqs_to_send[request_id] = (transfer_id, local_block_ids) class MooncakeConnector(KVConnectorBase_V1): @@ -209,19 +250,24 @@ class MooncakeConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config - self.engine_id: EngineId = engine_id - self.side_channel_host = get_ip() - self.side_channel_port = get_mooncake_side_channel_port(vllm_config) assert vllm_config.kv_transfer_config - self.kv_role = vllm_config.kv_transfer_config.kv_role + self.is_kv_producer: bool = ( + vllm_config.kv_transfer_config.kv_role == "kv_producer" + ) + self.is_kv_consumer: bool = ( + vllm_config.kv_transfer_config.kv_role == "kv_consumer" + ) logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. 
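+        # Both dicts keep the full Request so that the router-supplied
+        # kv_transfer_params (notably transfer_id) survive until
+        # build_connector_meta() copies them into the worker metadata.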
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_send: dict[ReqId, list[int]] = {} + self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. + self._reqs_not_processed: set[TransferId] = set() def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -249,8 +295,12 @@ def get_num_new_matched_tokens( params, ) - if params is not None and params.get("do_remote_prefill"): + if not params: + return 0, False + + if params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. + assert not self.is_kv_producer token_ids = request.prompt_token_ids or [] count = len(token_ids) - num_computed_tokens if count > 0: @@ -265,7 +315,8 @@ def update_state_after_alloc( params = request.kv_transfer_params logger.debug( "MooncakeConnector update_state_after_alloc: " - "num_external_tokens=%s, kv_transfer_params=%s", + "req_id=%s num_external_tokens=%s, kv_transfer_params=%s", + request.request_id, num_external_tokens, params, ) @@ -274,8 +325,11 @@ def update_state_after_alloc( return if params.get("do_remote_prefill"): - assert self.kv_role != "kv_producer" - if all(p in params for p in ("remote_host", "remote_port")): + assert not self.is_kv_producer + if all( + p in params + for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id") + ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. @@ -294,8 +348,12 @@ def update_state_after_alloc( params["do_remote_prefill"] = False elif params.get("do_remote_decode"): - # Add an empty list to worker to create event. - self._reqs_need_send[request.request_id] = [] + assert not self.is_kv_consumer + if not params.get("transfer_id"): + logger.warning("Missing transfer_id in kv_transfer_params from router!") + else: + # Add an empty list to worker to create event. + self._reqs_need_send[request.request_id] = (request, []) def build_connector_meta( self, @@ -303,8 +361,8 @@ def build_connector_meta( ) -> KVConnectorMetadata: meta = MooncakeConnectorMetadata() - # Loop through scheduled reqs and convert to RecvReqMeta. - if self.kv_role != "kv_producer": + # Loop through scheduled reqs and convert to PullReqMeta. 
+ if not self.is_kv_producer: for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None meta.add_new_req( @@ -314,15 +372,18 @@ def build_connector_meta( ) self._reqs_need_recv.clear() - if self.kv_role != "kv_consumer": - for req_id, block_ids in self._reqs_need_send.items(): + if not self.is_kv_consumer: + for req_id, (req, block_ids) in self._reqs_need_send.items(): + assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, - kv_transfer_params={}, + kv_transfer_params=req.kv_transfer_params, load_remote_cache=False, ) self._reqs_need_send.clear() + meta.reqs_not_processed = self._reqs_not_processed + self._reqs_not_processed = set() return meta @@ -338,12 +399,13 @@ def request_finished( params = request.kv_transfer_params logger.debug( - "MooncakeConnector request_finished, request_status=%s, " + "MooncakeConnector request_finished, req_id=%s, request_status=%s, " "kv_transfer_params=%s", + request.request_id, request.status, params, ) - if not params: + if not params or not params.get("transfer_id"): return False, None if params.get("do_remote_prefill"): @@ -353,32 +415,30 @@ def request_finished( # To avoid stranding the prefill blocks in the prefill instance, # we must add empty block_ids to _reqs_need_recv so that our # worker side will notify and free blocks in the prefill instance. - assert self.kv_role != "kv_producer" + assert not self.is_kv_producer self._reqs_need_recv[request.request_id] = (request, []) params["do_remote_prefill"] = False return False, None - if ( - not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED - ): + if not params.get("do_remote_decode"): return False, None - assert self.kv_role != "kv_consumer" + assert not self.is_kv_consumer + + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + # Also include the case of a P/D Prefill request with immediate + # block free (eg abort). Stop tracking this request. + self._reqs_not_processed.add(params["transfer_id"]) + return False, None # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: - self._reqs_need_send[request.request_id] = block_ids + self._reqs_need_send[request.request_id] = (request, block_ids) - return delay_free_blocks, dict( - do_remote_prefill=True, - do_remote_decode=False, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - ) + return delay_free_blocks, None class MooncakeConnectorWorker: @@ -391,7 +451,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.engine = TransferEngine() self.hostname = get_ip() - protocol = self.vllm_config.kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr] + + assert (kv_transfer_config := vllm_config.kv_transfer_config) + self.is_kv_producer: bool = kv_transfer_config.kv_role == "kv_producer" + self.is_kv_consumer: bool = kv_transfer_config.kv_role == "kv_consumer" + self.num_sender_workers = kv_transfer_config.kv_connector_extra_config.get( + "num_workers", 10 + ) + # Create more tasks than workers to keep the thread pool saturated. + # Tasks can await async events, so a surplus (2x is a robust heuristic) + # prevents workers from idling. 
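+        # e.g. with the default num_workers=10 this yields 20 sender tasks
+        # feeding a 10-thread pool.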
+ self.num_sender_tasks = self.num_sender_workers * 2 + protocol = kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr] "mooncake_protocol", "rdma" ) logger.info( @@ -409,33 +480,31 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.rpc_port, ) - # Mooncake handshake port. - self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config) - + self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = {} + self._pending_bootstrap_querys: dict[str, asyncio.Event] = {} + self.side_channel_port: int = 0 # we will bind it in register_kv_caches() self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() - self.world_size = get_tensor_model_parallel_world_size() - self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() self.num_blocks = 0 - assert vllm_config.kv_transfer_config - self.kv_role = vllm_config.kv_transfer_config.kv_role - self.num_sender_workers = ( - vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "num_workers", 10 + assert (parallel_config := vllm_config.parallel_config) + dp_rank = parallel_config.data_parallel_index + dp_local_rank = parallel_config.data_parallel_rank_local + self.dp_rank = dp_local_rank if parallel_config.local_engines_only else dp_rank + pp_size = vllm_config.parallel_config.pipeline_parallel_size + if pp_size > 1: + raise ValueError( + "Mooncake Transfer Engine does not support pipeline parallelism yet." ) - ) - # Create more tasks than workers to keep the thread pool saturated. - # Tasks can await async events, so a surplus (2x is a robust heuristic) - # prevents workers from idling. - self.num_sender_tasks = self.num_sender_workers * 2 + self.pp_rank = get_pp_group().rank_in_group self.kv_caches_base_addr: list[int] = [] self.device_kv_caches: dict[str, torch.Tensor] = {} - self.reqs_need_send: dict[ReqId, SendBlockMeta] = {} + self.reqs_need_send: dict[TransferId, SendBlockMeta] = {} # For kv_both, we will act both prefiller and decoder. - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: # Background threads for sending kvcaches to D. self._sender_executor = ThreadPoolExecutor( max_workers=self.num_sender_workers, @@ -454,7 +523,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): ) self._sender_listener_t.start() - if self.kv_role != "kv_producer": + # Start bootstrap server on global rank 0. 
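+        # See should_launch_bootstrap_server(): with local_engines_only
+        # (hybrid/external LB) every node runs its own server; in internal LB
+        # mode only the dp_rank 0 node does.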
+ if should_launch_bootstrap_server(vllm_config): + _, port = get_mooncake_bootstrap_addr(vllm_config) + self.bootstrap_server = MooncakeBootstrapServer( + vllm_config, "0.0.0.0", port + ) + self.bootstrap_server.start() + + if not self.is_kv_producer: self.receiver_loop = asyncio.new_event_loop() self._mooncake_receiver_t = threading.Thread( target=_async_loop, args=(self.receiver_loop,), daemon=True @@ -478,7 +555,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) - self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, @@ -492,7 +569,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.async_zmq_ctx = zmq.asyncio.Context() self._encoder = msgspec.msgpack.Encoder() - self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata) + self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse) def __del__(self): self.shutdown() @@ -500,26 +578,62 @@ def __del__(self): def shutdown(self): """Cleanup background threads on destruction.""" self.async_zmq_ctx.term() - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: self._sender_executor.shutdown(wait=False) if self.sender_loop.is_running(): self.sender_loop.call_soon_threadsafe(self.sender_loop.stop) self._sender_listener_t.join() - if self.kv_role != "kv_producer" and self.receiver_loop.is_running(): + if should_launch_bootstrap_server(self.vllm_config): + self.bootstrap_server.shutdown() + if not self.is_kv_producer and self.receiver_loop.is_running(): self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) self._mooncake_receiver_t.join() - async def _mooncake_sender_listener( - self, ready_event: threading.Event, base_port: int, tp_rank: int - ): + async def register_worker_with_bootstrap(self): + host, port = get_mooncake_bootstrap_addr(self.vllm_config) + url = make_zmq_path("http", host, port) + "/register" + worker_addr = make_zmq_path("tcp", self.hostname, self.side_channel_port) + payload = RegisterWorkerPayload( + engine_id=self.engine_id, + dp_rank=self.dp_rank, + tp_rank=self.tp_rank, + pp_rank=self.pp_rank, + addr=worker_addr, + ) + while True: + try: + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload.model_dump()) + response.raise_for_status() + logger.debug("Successfully registered with bootstrap server at %s", url) + break + except httpx.ConnectError: + # Bootstrap server not ready, wait for a while and retry. + await asyncio.sleep(1) + except Exception as e: + err_msg = ( + e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e) + ) + logger.error( + "Error registering %s with bootstrap server: %s", payload, err_msg + ) + raise e + + async def _mooncake_sender_listener(self, ready_event: threading.Event): """ Background thread that listens for Mooncake requests, dispatches them to a thread pool, and sends acknowledgments upon completion. 
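+        Replies are MooncakeXferResponse frames: CONTINUE while some requests
+        from the same query are still pending, FINISH with the last batch,
+        and ERROR if the request cannot be decoded or dispatched.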
""" - path = make_zmq_path("tcp", self.hostname, base_port + tp_rank) - sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER) - logger.debug("Mooncake sender starting listening on path: %s", path) + sock = self.async_zmq_ctx.socket(zmq.ROUTER) + self.side_channel_port = sock.bind_to_random_port(f"tcp://{self.hostname}") + logger.debug( + "Mooncake sender starting listening on path: tcp://%s:%d", + self.hostname, + self.side_channel_port, + ) + + await self.register_worker_with_bootstrap() # Create async worker tasks that process items from the queue sender_tasks = [ @@ -531,7 +645,7 @@ async def _mooncake_sender_listener( try: while True: - identity, _, metadata_bytes = await sock.recv_multipart() + identity, metadata_bytes = await sock.recv_multipart() await self.sender_worker_queue.put((identity, metadata_bytes)) except zmq.ContextTerminated: logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") @@ -549,12 +663,16 @@ async def _sender_worker(self, sock: zmq.asyncio.Socket): try: identity, metadata_bytes = await self.sender_worker_queue.get() try: - metadata = self._decoder.decode(metadata_bytes) - await self.send_kv_to_decode(metadata) - await sock.send_multipart((identity, b"", TRANS_DONE)) + metadata = self._xfer_meta_decoder.decode(metadata_bytes) + await self.send_kv_to_decode(identity, sock, metadata) except Exception as e: logger.error("Error processing Mooncake xfer request: %s", e) - await sock.send_multipart((identity, b"", TRANS_ERROR)) + error_response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.ERROR, err_msg=str(e) + ) + await sock.send_multipart( + (identity, self._encoder.encode(error_response)) + ) finally: self.sender_worker_queue.task_done() except asyncio.CancelledError: @@ -562,55 +680,169 @@ async def _sender_worker(self, sock: zmq.asyncio.Socket): except Exception as e: logger.error("Error in _sender_worker: %s", e) - async def send_kv_to_decode(self, meta: MooncakeAgentMetadata): - send_reqs: list[tuple[ReqId, SendBlockMeta]] = [] - for req_id in meta.request_ids: - send_meta = self.reqs_need_send.get(req_id) - if send_meta is None: - logger.warning("Request %s not found in reqs_need_send", req_id) - return - # Mark it as not expired. We will send it now. - send_meta.expire_time = float("inf") - send_reqs.append((req_id, send_meta)) - - src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta) - remote_session = f"{meta.remote_hostname}:{meta.remote_port}" - ret_value = await self.sender_loop.run_in_executor( - self._sender_executor, - self._send_blocks, - remote_session, - src_ptrs, - dst_ptrs, - lengths, - ) + async def send_kv_to_decode( + self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata + ): + pending_reqs: dict[ReqId, SendBlockMeta] = {} + remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) + if self.tp_rank not in remote_tp_ranks: + # This D worker does not pair with the P worker. + msg = f"This P tp_rank {self.tp_rank} not in remote D target ranks {remote_tp_ranks}" # noqa: E501 + logger.error(msg) + response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.ERROR, + err_msg=msg, + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + return + for d_req_id, (transfer_id, _) in meta.req_blocks.items(): + if transfer_id not in self.reqs_need_send: + # This req is not enqueued in P side yet, create it here. 
+ self.reqs_need_send[transfer_id] = SendBlockMeta( + p_req_id="", + transfer_id=transfer_id, + local_block_ids=[], + ready=asyncio.Event(), + ) + send_meta = self.reqs_need_send[transfer_id] + pending_reqs[d_req_id] = send_meta - if ret_value != 0: - raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}") + async def wait_and_ret( + d_req_id: ReqId, send_meta: SendBlockMeta + ) -> tuple[ReqId, SendBlockMeta]: + await send_meta.ready.wait() + return d_req_id, send_meta + + wait_tasks = [ + asyncio.create_task(wait_and_ret(d_req_id, send_meta)) + for d_req_id, send_meta in pending_reqs.items() + ] + + while wait_tasks: + done, pending = await asyncio.wait( + wait_tasks, + timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + return_when=asyncio.FIRST_COMPLETED, + ) + + if not done: + # Timeout, abort all pending requests. + for task in wait_tasks: + task.cancel() + logger.warning( + "Timeout waiting for P side ready: %s", list(pending_reqs) + ) + response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.FINISH, + err_reqs=list(pending_reqs), + err_msg="Timeout waiting for P side ready.", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + break + + wait_tasks = list(pending) + response_status = ( + MooncakeXferResponseStatus.CONTINUE + if wait_tasks + else MooncakeXferResponseStatus.FINISH + ) + ready_reqs: list[tuple[ReqId, SendBlockMeta]] = [] + for task in done: + d_req_id, send_meta = task.result() + del pending_reqs[d_req_id] + # Do we still in reqs_need_send (not expired)? + if send_meta.transfer_id in self.reqs_need_send: + # Mark it sending to avoid expiration. + send_meta.sending += 1 + if not send_meta.need_send: + self.resolve_need_send(send_meta, remote_tp_ranks) + ready_reqs.append((d_req_id, send_meta)) + else: + # Otherwise (expired, very unlikely), just forget it. + logger.warning( + "Request %s expired before sending on P side.", d_req_id + ) - for req_id in meta.request_ids: - del self.reqs_need_send[req_id] + src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params( + ready_reqs, meta + ) - self.finished_sending_reqs.update(meta.request_ids) + if err_reqs: + response = MooncakeXferResponse( + status=response_status, + err_reqs=err_reqs, + err_msg="P num blocks less than D", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + if src_ptrs: + remote_session = f"{meta.remote_hostname}:{meta.remote_port}" + ret_value = await self.sender_loop.run_in_executor( + self._sender_executor, + self._send_blocks, + remote_session, + src_ptrs, + dst_ptrs, + lengths, + ) + + if ret_value != 0: + err_reqs = [] + for d_req_id, send_meta in ready_reqs: + send_meta.sending -= 1 + err_reqs.append(d_req_id) + # Do best effort to transfer the remaining reqs. + response = MooncakeXferResponse( + status=response_status, + err_reqs=err_reqs, + err_msg=f"Mooncake transfer engine returned {ret_value}", + ) + await sock.send_multipart( + (identity, self._encoder.encode(response)) + ) + continue + + for d_req_id, send_meta in ready_reqs: + # TODO: for heterogeneous TP (one P pairs to multiple D), + # we need to check whether all headers are sent. + # If not, we should set expire_time to normal and skip the below. 
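+                # need_send was resolved above to the number of remote ranks
+                # pulling from this worker (currently always 1), so the entry
+                # is freed only after every target rank has been served.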
+ send_meta.sending -= 1 + send_meta.sent += 1 + if send_meta.sent == send_meta.need_send: + del self.reqs_need_send[send_meta.transfer_id] + self.finished_sending_reqs.add(send_meta.p_req_id) + + response = MooncakeXferResponse( + status=response_status, + ok_reqs=[d_req_id for d_req_id, _ in ready_reqs], + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]): + # Prepare for heterogeneous TP (one P pairs to multiple D) + send_meta.need_send = len(remote_tp_ranks) + if send_meta.need_send != 1: + logger.error("Mooncake: Heterogeneous TP is not supported yet.") + raise NotImplementedError( + "Mooncake: Heterogeneous TP is not supported yet." + ) async def _build_transfer_params( self, - send_reqs: list[tuple[ReqId, SendBlockMeta]], - agent_meta: MooncakeAgentMetadata, - ) -> tuple[list[int], list[int], list[int]]: + ready_reqs: list[tuple[ReqId, SendBlockMeta]], + agent_meta: MooncakeXferMetadata, + ) -> tuple[list[int], list[int], list[int], list[ReqId]]: src_ptrs = [] dst_ptrs = [] lengths = [] + err_reqs: list[ReqId] = [] local_base_addr = self.kv_caches_base_addr remote_base_addr = agent_meta.kv_caches_base_addr block_len = self.block_len remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" - assert len(send_reqs) == len(agent_meta.block_ids) - for (req_id, send_meta), remote_block_ids in zip( - send_reqs, agent_meta.block_ids - ): - await send_meta.ready.wait() - + for d_req_id, send_meta in ready_reqs: + _, remote_block_ids = agent_meta.req_blocks[d_req_id] num_remote_blocks = len(remote_block_ids) if num_remote_blocks == 0: continue @@ -618,7 +850,15 @@ async def _build_transfer_params( local_block_ids = send_meta.local_block_ids # Partial prefix cache hit: just read uncomputed blocks. num_local_blocks = len(local_block_ids) - assert num_local_blocks >= num_remote_blocks + if num_local_blocks < num_remote_blocks: + logger.error( + "req %s: local blocks(%d) less than remote blocks(%d)!", + d_req_id, + num_local_blocks, + num_remote_blocks, + ) + err_reqs.append(d_req_id) + continue if num_local_blocks > num_remote_blocks: local_block_ids = local_block_ids[-num_remote_blocks:] @@ -643,12 +883,12 @@ async def _build_transfer_params( logger.debug( "Sending kv_caches for request %s (%d blocks) to %s", - req_id, + d_req_id, num_remote_blocks, remote_session, ) - return src_ptrs, dst_ptrs, lengths + return src_ptrs, dst_ptrs, lengths, err_reqs def _send_blocks( self, @@ -722,15 +962,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) # No need to launch server for D node. - if self.kv_role == "kv_consumer": + if self.is_kv_consumer: return ready_event = threading.Event() asyncio.run_coroutine_threadsafe( - self._mooncake_sender_listener( - ready_event, self.side_channel_port, self.tp_rank - ), - self.sender_loop, + self._mooncake_sender_listener(ready_event), self.sender_loop ) ready_event.wait() # Wait for listener ZMQ socket to be ready. @@ -745,21 +982,25 @@ async def fetch_finished_sending_reqs(self) -> set[ReqId]: # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() - expired_reqs = [ - req_id - for req_id, send_meta in self.reqs_need_send.items() - if send_meta.expire_time < now - ] - for req_id in expired_reqs: - logger.warning( - "Request %s timed out after %d seconds without " - "being sent. 
Freeing its blocks on the producer side.", - req_id, - envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, - ) - del self.reqs_need_send[req_id] - if expired_reqs: - finished_sending_reqs.update(expired_reqs) + + expired_transfer_id = [] + for transfer_id, send_meta in self.reqs_need_send.items(): + if ( + send_meta.p_req_id + and send_meta.expire_time < now + and send_meta.sending == 0 + ): + logger.warning( + "Request %s timed out after %d seconds without " + "being sent. Freeing its blocks on the producer side.", + send_meta.p_req_id, + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + ) + finished_sending_reqs.add(send_meta.p_req_id) + expired_transfer_id.append(transfer_id) + + for transfer_id in expired_transfer_id: + del self.reqs_need_send[transfer_id] return finished_sending_reqs @@ -771,12 +1012,12 @@ def get_finished(self) -> tuple[set[str] | None, set[str] | None]: """ recv_fut = None send_fut = None - if self.kv_role != "kv_producer": + if not self.is_kv_producer: recv_fut = asyncio.run_coroutine_threadsafe( self.fetch_finished_recving_reqs(), self.receiver_loop ) - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: send_fut = asyncio.run_coroutine_threadsafe( self.fetch_finished_sending_reqs(), self.sender_loop ) @@ -795,69 +1036,174 @@ def get_finished(self) -> tuple[set[str] | None, set[str] | None]: return finished_sending_reqs or None, finished_recving_reqs or None - async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]): - req_ids, block_ids = map(list, zip(*req_blocks)) - metadata = MooncakeAgentMetadata( + async def receive_kv_from_single_worker( + self, + worker_addr: str, + pull_metas: dict[ReqId, PullReqMeta], + ): + req_ids = set(pull_metas) + metadata = MooncakeXferMetadata( remote_hostname=self.hostname, remote_port=self.rpc_port, - request_ids=req_ids, + remote_tp_size=self.tp_size, + remote_tp_rank=self.tp_rank, + req_blocks={ + req_id: (pull_meta.transfer_id, pull_meta.local_block_ids) + for req_id, pull_meta in pull_metas.items() + }, kv_caches_base_addr=self.kv_caches_base_addr, - block_ids=block_ids, ) encoded_data = self._encoder.encode(metadata) logger.debug( - "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data) + "Size of encoded MooncakeXferMetadata: %d bytes", len(encoded_data) + ) + logger.debug( + "Sending kv transfer request for %s on path: %s", req_ids, worker_addr ) - logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path) # Send query for the request. - sock: zmq.asyncio.Socket = make_zmq_socket( - self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0 - ) - sock.setsockopt(zmq.RCVTIMEO, 60000) try: - await sock.send(encoded_data) - ret_msg = await sock.recv() - if ret_msg != TRANS_DONE: - logger.error( - "Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501 - req_ids, + with make_zmq_socket( + self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0 + ) as sock: + # If something goes wrong, let P wait timeout first (in asyncio.wait()). 
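+                # The extra 60 s margin lets the producer-side abort timeout
+                # fire first, so P can free its blocks before D stops waiting.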
+                sock.setsockopt(
+                    zmq.RCVTIMEO, (envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + 60) * 1000
+                )
+                await sock.send(encoded_data)
+                while True:
+                    ret_msg = await sock.recv()
+                    response = self._xfer_resp_decoder.decode(ret_msg)
+                    if response.status == MooncakeXferResponseStatus.ERROR:
+                        logger.error(
+                            "Error while transferring kvcache for %s: %s",
+                            req_ids,
+                            response.err_msg,
+                        )
+                        return
+                    self.process_pulling_result(response, pull_metas)
+                    if response.status == MooncakeXferResponseStatus.FINISH:
+                        break
        except zmq.ContextTerminated:
            logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
        except Exception as e:
-            logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
+            logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
            return
-        finally:
-            sock.close()

-        self.finished_recving_reqs.update(req_ids)
+    def process_pulling_result(
+        self,
+        response: MooncakeXferResponse,
+        pull_metas: dict[ReqId, PullReqMeta],
+    ):
+        ok_reqs: list[ReqId] = response.ok_reqs or []
+
+        for req_id in ok_reqs:
+            pull_meta = pull_metas[req_id]
+            # No race because we are in the async loop.
+            pull_meta.pull_tasks_count -= 1
+            if pull_meta.pull_tasks_count == 0:
+                self.finished_recving_reqs.add(pull_meta.d_req_id)
+
+        if ok_reqs:
+            logger.debug("pulling kv_caches for %s finished", ok_reqs)
+
+        if response.err_reqs:
+            logger.error(
+                "pulling kv_caches for %s failed: %s",
+                response.err_reqs,
+                response.err_msg,
+            )

-        logger.debug("pulling kv_caches for %s finished", req_ids)
+    async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
+        url = remote_bootstrap_addr + "/query"
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(url)
+                response.raise_for_status()
+                data: dict = response.json()
+                for _, dp_entry in data.items():
+                    remote_engine_id = dp_entry["engine_id"]
+                    self._remote_agents[remote_engine_id] = {
+                        int(tp_rank): {
+                            int(pp_rank): worker_addr
+                            for pp_rank, worker_addr in tp_entry.items()
+                        }
+                        for tp_rank, tp_entry in dp_entry["worker_addr"].items()
+                    }
+                    self._tp_size[remote_engine_id] = len(dp_entry["worker_addr"])
+        except Exception as e:
+            logger.error(
+                "Failed to connect to bootstrap server %s: %s",
+                remote_bootstrap_addr,
+                e,
+            )

-    def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
-        kv_pulls = defaultdict(list)
-        for req_id, meta in metadata.reqs_to_recv.items():
-            logger.debug(
-                "start_load_kv for request %s from remote engine. "
-                "Num local_block_ids: %s.",
-                req_id,
-                len(meta.local_block_ids),
+        # Always notify others regardless of connection success or failure.
+        self._pending_bootstrap_querys[remote_bootstrap_addr].set()
+        del self._pending_bootstrap_querys[remote_bootstrap_addr]
+
+    def receive_kv(
+        self,
+        remote_engine_id: EngineId,
+        pull_metas: dict[ReqId, PullReqMeta],
+    ):
+        remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
+            remote_engine_id
+        )
+        count = len(remote_tp_ranks)
+        if count != 1:
+            logger.error("Mooncake: Heterogeneous TP is not supported yet.")
+            raise NotImplementedError(
+                "Mooncake: Heterogeneous TP is not supported yet."
             )
-        path = make_zmq_path(
-            "tcp", meta.remote_host, meta.remote_port + self.tp_rank
+        for pull_meta in pull_metas.values():
+            pull_meta.pull_tasks_count = count
+        for remote_tp_rank in remote_tp_ranks:
+            # _remote_agents maps {tp_rank: {pp_rank: addr}}; pull from pp_rank 0.
+            worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0]
+            asyncio.create_task(
+                self.receive_kv_from_single_worker(worker_addr, pull_metas)
+            )
+
+    async def handle_new_engine_id(
+        self,
+        remote_engine_id: EngineId,
+        pull_metas: dict[ReqId, PullReqMeta],
+    ):
+        remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr
+        if remote_bootstrap_addr not in self._pending_bootstrap_querys:
+            self._pending_bootstrap_querys[remote_bootstrap_addr] = asyncio.Event()
+            await self._connect_to_prefiller_bootstrap(remote_bootstrap_addr)
+        else:
+            await self._pending_bootstrap_querys[remote_bootstrap_addr].wait()
+
+        if remote_engine_id not in self._remote_agents:
+            logger.error(
+                "Failed to find remote engine_id %s from bootstrap server %s",
+                remote_engine_id,
+                remote_bootstrap_addr,
             )
-            kv_pulls[path].append((req_id, meta.local_block_ids))
+            return
+
+        self.receive_kv(remote_engine_id, pull_metas)

-        return kv_pulls
+    async def _start_load_kv(
+        self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
+    ):
+        for remote_engine_id, pull_metas in reqs_to_recv.items():
+            if remote_engine_id not in self._remote_agents:
+                asyncio.create_task(
+                    self.handle_new_engine_id(remote_engine_id, pull_metas)
+                )
+            else:
+                self.receive_kv(remote_engine_id, pull_metas)

     async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
-        for req_id, block_ids in metadata.reqs_to_send.items():
+        for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
             if block_ids:
                 # Already gone through request_finished()
-                send_meta = self.reqs_need_send[req_id]
+                send_meta = self.reqs_need_send[transfer_id]
+                send_meta.p_req_id = p_req_id
                 send_meta.local_block_ids = block_ids
                 send_meta.expire_time = (
                     time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
@@ -866,20 +1212,29 @@ async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
             else:
                 # From update_state_after_alloc(),
                 # but not reach request_finished() yet
-                self.reqs_need_send[req_id] = SendBlockMeta(
-                    local_block_ids=[],
-                    ready=asyncio.Event(),
-                )
+                # This may already have been created by send_kv_to_decode()
+                # when D is sending MooncakeXferMetadata.
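+                # Keep the existing entry if the decoder's pull already
+                # materialized it; only create a placeholder otherwise.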
+                if transfer_id not in self.reqs_need_send:
+                    self.reqs_need_send[transfer_id] = SendBlockMeta(
+                        p_req_id=p_req_id,
+                        transfer_id=transfer_id,
+                        local_block_ids=[],
+                        ready=asyncio.Event(),
+                    )
+        for transfer_id in metadata.reqs_not_processed:
+            # pop() with a default: the transfer may never have been recorded.
+            send_meta = self.reqs_need_send.pop(transfer_id, None)
+            if send_meta:
+                assert not send_meta.ready.is_set()

     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
-        if self.kv_role != "kv_producer":
-            kv_pulls = self.group_kv_pull(metadata)
-            for path, req_blocks in kv_pulls.items():
-                asyncio.run_coroutine_threadsafe(
-                    self.receive_kv(path, req_blocks), self.receiver_loop
-                )
+        if not self.is_kv_producer and metadata.reqs_to_recv:
+            asyncio.run_coroutine_threadsafe(
+                self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
+            )

-        if self.kv_role != "kv_consumer":
+        if not self.is_kv_consumer and (
+            metadata.reqs_to_send or metadata.reqs_not_processed
+        ):
             asyncio.run_coroutine_threadsafe(
                 self.record_send_reqs(metadata), self.sender_loop
             )
@@ -914,3 +1269,31 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
 def _async_loop(loop: asyncio.AbstractEventLoop):
     asyncio.set_event_loop(loop)
     loop.run_forever()
+
+
+def should_launch_bootstrap_server(vllm_config: VllmConfig) -> bool:
+    assert (parallel_config := vllm_config.parallel_config)
+    # In hybrid or external LB mode,
+    # each instance should have its own bootstrap server.
+    #
+    # In internal LB mode,
+    # only the real global first rank needs to launch the bootstrap server.
+    return is_local_first_rank() and (
+        parallel_config.local_engines_only or parallel_config.data_parallel_index == 0
+    )
+
+
+def get_mooncake_bootstrap_addr(vllm_config: VllmConfig) -> tuple[str, int]:
+    """
+    Returns the address of the Mooncake bootstrap server.
+    This is only used by prefillers to register workers.
+    Decoders should get the address from kv_transfer_params.
+    """
+    assert (parallel_config := vllm_config.parallel_config)
+    if parallel_config.local_engines_only:
+        # In hybrid or external LB mode, connect to the local server.
+        host = "127.0.0.1"
+    else:
+        host = parallel_config.data_parallel_master_ip
+    port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+    return (host, port)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py
new file mode 100644
index 000000000000..d1a9946709d8
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py
@@ -0,0 +1,127 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import threading
+import time
+from dataclasses import dataclass
+
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
+from vllm.logger import init_logger
+
+WorkerAddr = str
+
+logger = init_logger(__name__)
+
+
+class RegisterWorkerPayload(BaseModel):
+    engine_id: EngineId
+    dp_rank: int
+    tp_rank: int
+    pp_rank: int
+    addr: WorkerAddr
+
+
+@dataclass
+class EngineEntry:
+    engine_id: EngineId
+    # {tp_rank: {pp_rank: worker_addr}}
+    worker_addr: dict[int, dict[int, WorkerAddr]]
+
+
+class MooncakeBootstrapServer:
+    """
+    A centralized server running on the global rank 0 prefiller worker.
+    Prefiller workers register their connection info (IP, port, ranks) here.
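+    Decoders discover the registered workers by querying the /query endpoint.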
+    """
+
+    def __init__(self, vllm_config: VllmConfig, host: str, port: int):
+        self.workers: dict[int, EngineEntry] = {}
+
+        self.host = host
+        self.port = port
+        self.app = FastAPI()
+        self._register_routes()
+        self.server_thread: threading.Thread | None = None
+        self.server: uvicorn.Server | None = None
+
+    def __del__(self):
+        self.shutdown()
+
+    def _register_routes(self):
+        # All handlers are async and run on a single event loop,
+        # so no lock is needed to protect the data.
+        self.app.post("/register")(self.register_worker)
+        self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
+
+    def start(self):
+        if self.server_thread:
+            return
+
+        config = uvicorn.Config(app=self.app, host=self.host, port=self.port)
+        self.server = uvicorn.Server(config=config)
+        self.server_thread = threading.Thread(
+            target=self.server.run, name="mooncake_bootstrap_server", daemon=True
+        )
+        self.server_thread.start()
+        while not self.server.started:
+            time.sleep(0.1)  # Wait for the server to start
+        logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port)
+
+    def shutdown(self):
+        if self.server_thread is None or self.server is None or not self.server.started:
+            return
+
+        self.server.should_exit = True
+        self.server_thread.join()
+        logger.info("Mooncake Bootstrap Server stopped.")
+
+    async def register_worker(self, payload: RegisterWorkerPayload):
+        """Handles registration of a prefiller worker."""
+        if payload.dp_rank not in self.workers:
+            self.workers[payload.dp_rank] = EngineEntry(
+                engine_id=payload.engine_id,
+                worker_addr={},
+            )
+
+        dp_entry = self.workers[payload.dp_rank]
+        if dp_entry.engine_id != payload.engine_id:
+            raise HTTPException(
+                status_code=400,
+                detail=(
+                    f"Engine ID mismatch for dp_rank={payload.dp_rank}: "
+                    f"expected {dp_entry.engine_id}, got {payload.engine_id}"
+                ),
+            )
+        if payload.tp_rank not in dp_entry.worker_addr:
+            dp_entry.worker_addr[payload.tp_rank] = {}
+
+        tp_entry = dp_entry.worker_addr[payload.tp_rank]
+        if payload.pp_rank in tp_entry:
+            raise HTTPException(
+                status_code=400,
+                detail=(
+                    f"Worker with dp_rank={payload.dp_rank}, "
+                    f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} "
+                    f"is already registered at "
+                    f"{tp_entry[payload.pp_rank]}, "
+                    f"but a new registration was attempted from {payload.addr}"
+                ),
+            )
+
+        tp_entry[payload.pp_rank] = payload.addr
+        logger.debug(
+            "Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s",
+            payload.engine_id,
+            payload.dp_rank,
+            payload.tp_rank,
+            payload.pp_rank,
+            payload.addr,
+        )
+
+        return {"status": "ok"}
+
+    async def query(self) -> dict[int, EngineEntry]:
+        return self.workers
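
To sanity-check the bootstrap HTTP API above without spinning up vLLM, a minimal client sketch like the following can register workers and read them back. The host, port, engine id, and worker addresses below are placeholder assumptions rather than values from this PR; the only requirements are `httpx` and a `MooncakeBootstrapServer` already listening on the chosen port.

```python
# Hypothetical smoke test for the bootstrap HTTP API (not part of this PR).
# Assumes a MooncakeBootstrapServer is listening on 127.0.0.1:8998
# (the default VLLM_MOONCAKE_BOOTSTRAP_PORT); all addresses are placeholders.
import httpx

BASE = "http://127.0.0.1:8998"


def register(engine_id: str, dp_rank: int, tp_rank: int, pp_rank: int, addr: str) -> None:
    # Field names mirror RegisterWorkerPayload in mooncake_utils.py.
    payload = {
        "engine_id": engine_id,
        "dp_rank": dp_rank,
        "tp_rank": tp_rank,
        "pp_rank": pp_rank,
        "addr": addr,
    }
    resp = httpx.post(f"{BASE}/register", json=payload)
    resp.raise_for_status()


# Register two TP workers of a single prefiller engine.
register("engine-0", dp_rank=0, tp_rank=0, pp_rank=0, addr="tcp://192.168.0.2:8010")
register("engine-0", dp_rank=0, tp_rank=1, pp_rank=0, addr="tcp://192.168.0.2:8011")

# Re-registering the same (dp, tp, pp) triple is rejected with HTTP 400.
try:
    register("engine-0", dp_rank=0, tp_rank=0, pp_rank=0, addr="tcp://192.168.0.9:9999")
except httpx.HTTPStatusError as e:
    print("expected rejection:", e.response.status_code)

# Decoders read back the {dp_rank: EngineEntry} mapping from /query,
# the same shape parsed by _connect_to_prefiller_bootstrap().
print(httpx.get(f"{BASE}/query").json())
```

The duplicate-registration check matters because the server keeps a single flat mapping per (dp, tp, pp) triple; a silently overwritten address would leave decoders pulling KV caches from a stale worker.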