From f1b4b2802cde3faa894510acb65beead29f57597 Mon Sep 17 00:00:00 2001
From: Tianchen Ding
Date: Mon, 19 Jan 2026 15:18:42 +0800
Subject: [PATCH 1/7] rework again

Signed-off-by: Tianchen Ding
---
 docs/features/disagg_prefill.md             |   3 +-
 docs/features/mooncake_connector_usage.md   |  25 +-
 .../disaggregated_serving/README.md         |   1 +
 .../mooncake_connector_proxy.py             | 384 ++++++++++
 .../run_mooncake_connector.sh               | 222 ++++++
 vllm/config/parallel.py                     |  10 +
 .../kv_transfer/kv_connector/factory.py     |   2 +-
 .../v1/{ => mooncake}/mooncake_connector.py | 716 +++++++++++++-----
 .../v1/mooncake/mooncake_utils.py           | 225 ++++++
 vllm/v1/engine/core.py                      |   6 +
 10 files changed, 1413 insertions(+), 181 deletions(-)
 create mode 100644 examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
 create mode 100644 examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh
 rename vllm/distributed/kv_transfer/kv_connector/v1/{ => mooncake}/mooncake_connector.py (53%)
 create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py

diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md
index df69849bb922..af5f77747fac 100644
--- a/docs/features/disagg_prefill.md
+++ b/docs/features/disagg_prefill.md
@@ -19,12 +19,13 @@ Two main reasons:

 Please refer to [examples/online_serving/disaggregated_prefill.sh](../../examples/online_serving/disaggregated_prefill.sh) for the example usage of disaggregated prefilling.

-Now supports 5 types of connectors:
+Now supports 6 types of connectors:

 - **ExampleConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of ExampleConnector disaggregated prefilling.
 - **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
 - **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md).
 - **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling.
+- **MooncakeConnector**: refer to [examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh) for the example usage of MooncakeConnector disaggregated prefilling. For detailed usage guide, see [MooncakeConnector Usage Guide](mooncake_connector_usage.md).
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:

 ```bash

diff --git a/docs/features/mooncake_connector_usage.md b/docs/features/mooncake_connector_usage.md
index 653ea29ad943..0e2478924ead 100644
--- a/docs/features/mooncake_connector_usage.md
+++ b/docs/features/mooncake_connector_usage.md
@@ -31,11 +31,9 @@ vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_conne
 ### Proxy

 ```bash
-python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
+python examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py --prefill http://192.168.0.2:8010 --decode http://192.168.0.3:8020
 ```

-> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
-
 Now you can send requests to the proxy server through port 8000.

 ## Environment Variables
@@ -43,16 +41,29 @@ Now you can send requests to the proxy server through port 8000.
 - `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server
   - Default: 8998
   - Required only for prefiller instances
-  - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
-  - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
-  - Used for the decoder notifying the prefiller
+  - For headless instances, must be the same as the master instance
+  - Each instance needs a unique port on its host; using the same port number across different hosts is fine
 - `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
   - Default: 480
   - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.

-## KV Role Options
+## KV Transfer Config
+
+### KV Role Options

 - **kv_producer**: For prefiller instances that generate KV caches
 - **kv_consumer**: For decoder instances that consume KV caches from prefiller
 - **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
+
+### kv_connector_extra_config
+
+- **num_workers**: Size of the thread pool each prefiller worker uses to transfer KV caches via Mooncake. (default 10)
+- **mooncake_protocol**: Transport protocol used by the Mooncake transfer engine. (default "rdma")
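+
+For example, a prefiller can set both options through `--kv-transfer-config` (the values shown are illustrative):
+
+```bash
+vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer","kv_connector_extra_config":{"num_workers":16,"mooncake_protocol":"rdma"}}'
+```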
+
+## Example Scripts/Code
+
+Refer to these example scripts in the vLLM repository:
+
+- [run_mooncake_connector.sh](../../examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh)
+- [mooncake_connector_proxy.py](../../examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py)
diff --git a/examples/online_serving/disaggregated_serving/README.md b/examples/online_serving/disaggregated_serving/README.md
index 090afd7515ee..1e3284299346 100644
--- a/examples/online_serving/disaggregated_serving/README.md
+++ b/examples/online_serving/disaggregated_serving/README.md
@@ -6,3 +6,4 @@ This example contains scripts that demonstrate the disaggregated serving feature

 - `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances).
 - `kv_events.sh` - Demonstrates KV cache event publishing.
+- `mooncake_connector` - An example proxy and launch script for MooncakeConnector.
diff --git a/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
new file mode 100644
index 000000000000..c5867b2846ac
--- /dev/null
+++ b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
@@ -0,0 +1,384 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import argparse
+import asyncio
+import ipaddress
+import itertools
+import logging
+import os
+import urllib.parse
+import uuid
+from contextlib import asynccontextmanager
+from typing import Any
+
+import httpx
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import StreamingResponse
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def maybe_wrap_ipv6_address(address: str) -> str:
+    try:
+        ipaddress.IPv6Address(address)
+        return f"[{address}]"
+    except ValueError:
+        return address
+
+
+def make_http_path(host: str, port: int) -> str:
+    return f"http://{host}:{port}"
+
+
+def prefiller_cycle(prefill_clients: list[Any]):
+    # Round-robin over (prefill client, dp_rank) pairs.
+    while True:
+        for prefill_client in prefill_clients:
+            for i in range(prefill_client["dp_size"]):
+                yield prefill_client, i
+
+
+async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
+    for prefill_client in prefill_clients:
+        while True:
+            try:
+                response = await prefill_client["client"].get(
+                    prefill_client["bootstrap_addr"] + "/query"
+                )
+                response.raise_for_status()
+                data = response.json()
+                break
+            except Exception:
+                await asyncio.sleep(1)
+
+        dp_size = 0
+        for engine_id, engine_entry in data.items():
+            dp_size += len(engine_entry)
+            for dp_rank in engine_entry:
+                prefill_client["dp_engine_id"][int(dp_rank)] = engine_id
+        prefill_client["dp_size"] = dp_size
+
+    ready.set()
+    logger.info("All prefiller instances are ready.")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan context manager to handle startup and shutdown events.
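+
+    On startup, create one httpx client per prefill/decode URL, kick off a
+    background task that queries each prefiller's bootstrap server for its
+    engine/DP layout, and build round-robin iterators over the clients.
+    On shutdown, close all clients.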
+ """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + app.state.ready = asyncio.Event() + + # Create prefill clients + for i, (url, bootstrap_port) in enumerate(global_args.prefill): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998), + "dp_engine_id": {}, + } + ) + + # Create decode clients + for i, url in enumerate(global_args.decode): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.decode_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + } + ) + + asyncio.create_task(get_prefiller_info(app.state.prefill_clients, app.state.ready)) + + # Initialize round-robin iterators + app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + + print( + f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." + ) + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info["client"].aclose() + + for client_info in app.state.decode_clients: + await client_info["client"].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI + parser.add_argument("--host", type=str, default="127.0.0.1") + + # For prefiller instances + parser.add_argument( + "--prefill", + nargs="+", + action="append", + dest="prefill_raw", + metavar=("URL", "bootstrap_port"), + help=( + "Prefill server URL and optional bootstrap port. " + "Can be specified multiple times. " + "Format: --prefill URL [BOOTSTRAP_PORT]. " + "BOOTSTRAP_PORT can be a port number, " + "'none', or omitted (defaults to none)." + ), + ) + + # For decoder instances + parser.add_argument( + "--decode", + nargs=1, + action="append", + dest="decode_raw", + metavar=("URL",), + help="Decode server URL. Can be specified multiple times.", + ) + + args = parser.parse_args() + args.prefill = _parse_prefill_urls(args.prefill_raw) + args.decode = _parse_decode_urls(args.decode_raw) + + return args + + +# From sglang router_args.py +def _parse_prefill_urls(prefill_list): + """Parse prefill URLs from --prefill arguments. 
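+
+    Returns a list of (url, bootstrap_port) tuples, where bootstrap_port is
+    None when omitted or given as 'none'.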
+ + Format: --prefill URL [BOOTSTRAP_PORT] + Example: + --prefill http://prefill1:8080 9000 # With bootstrap port + --prefill http://prefill2:8080 none # Explicitly no bootstrap port + --prefill http://prefill3:8080 # Defaults to no bootstrap port + """ + if not prefill_list: + return [] + + prefill_urls = [] + for prefill_args in prefill_list: + url = prefill_args[0] + + # Handle optional bootstrap port + if len(prefill_args) >= 2: + bootstrap_port_str = prefill_args[1] + # Handle 'none' as None + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError as e: + raise ValueError( + f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" # noqa: E501 + ) from e + else: + # No bootstrap port specified, default to None + bootstrap_port = None + + prefill_urls.append((url, bootstrap_port)) + + return prefill_urls + + +def _parse_decode_urls(decode_list): + """Parse decode URLs from --decode arguments. + + Format: --decode URL + Example: --decode http://decode1:8081 --decode http://decode2:8081 + """ + if not decode_list: + return [] + + # decode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in decode_list] + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == "prefill": + return next(app.state.prefill_iterator) + elif service_type == "decode": + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service( + client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a service using a client from the pool. + """ + req_data = req_data.copy() + req_data["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + "X-data-parallel-rank": str(dp_rank), + } + + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) + response.raise_for_status() + + # CRITICAL: Release connection back to pool + await response.aclose() + + +async def stream_service_response( + prefill_client_info: dict, + prefill_dp_rank: int, + decode_client_info: dict, + endpoint: str, + req_data: dict, + request_id: str, +): + """ + Asynchronously stream response from a service using a client from the pool. 
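+
+    The prefill worker's bootstrap address, engine id and DP rank are attached
+    to kv_transfer_params so the decoder knows where to pull the KV cache from.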
+ """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + req_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"], + "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank], + "remote_dp_rank": prefill_dp_rank, + } + + async with decode_client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +async def _handle_completions(api: str, request: Request): + if not app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill") + + # Send request to prefill service + asyncio.create_task( + send_request_to_service( + prefill_client_info, prefill_dp_rank, api, req_data, request_id + ) + ) + + decode_client_info = get_next_client(request.app, "decode") + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + prefill_client_info, + prefill_dp_rank, + decode_client_info, + api, + req_data, + request_id=request_id, + ): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle_completions("/v1/completions", request) + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + return await _handle_completions("/v1/chat/completions", request) + + +@app.get("/healthcheck") +async def healthcheck(): + """Simple endpoint to check if the server is running.""" + return { + "status": "ok", + "prefill_instances": len(app.state.prefill_clients), + "decode_instances": len(app.state.decode_clients), + } + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + import uvicorn + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh b/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh new file mode 100644 index 000000000000..e38d377c331a --- /dev/null +++ b/examples/online_serving/disaggregated_serving/mooncake_connector/run_mooncake_connector.sh @@ -0,0 +1,222 @@ +#!/bin/bash + +# ============================================================================= +# vLLM Disaggregated Serving Script for Mooncake Connector +# ============================================================================= +# This script demonstrates disaggregated prefill and decode serving using +# Mooncake Connector. 
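#
# Usage: bash run_mooncake_connector.sh
#   (requires HF_TOKEN to be set and at least 2 GPUs; see the checks below)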
#
# Configuration can be customized via environment variables:
#   MODEL: Model to serve
#   PREFILL_GPUS: Comma-separated GPU IDs for prefill servers
#   DECODE_GPUS: Comma-separated GPU IDs for decode servers
#   PREFILL_PORTS: Comma-separated ports for prefill servers
#   BOOTSTRAP_PORTS: Comma-separated bootstrap ports for the prefill servers
#   DECODE_PORTS: Comma-separated ports for decode servers
#   PROXY_PORT: Proxy server port used to set up the P/D disaggregated connection
#   TIMEOUT_SECONDS: Server startup timeout
# =============================================================================

# Configuration - can be overridden via environment variables
MODEL=${MODEL:-Qwen/Qwen2.5-7B-Instruct}
TIMEOUT_SECONDS=${TIMEOUT_SECONDS:-1200}
PROXY_PORT=${PROXY_PORT:-8000}

PREFILL_GPUS=${PREFILL_GPUS:-0}
DECODE_GPUS=${DECODE_GPUS:-1}
PREFILL_PORTS=${PREFILL_PORTS:-8010}
BOOTSTRAP_PORTS=${BOOTSTRAP_PORTS:-8998}
DECODE_PORTS=${DECODE_PORTS:-8020}

echo "Warning: Mooncake Connector support for vLLM v1 is experimental and subject to change."
echo ""
echo "Architecture Configuration:"
echo "  Model: $MODEL"
echo "  Prefill GPUs: $PREFILL_GPUS, Ports: $PREFILL_PORTS, Bootstrap Ports: $BOOTSTRAP_PORTS"
echo "  Decode GPUs: $DECODE_GPUS, Ports: $DECODE_PORTS"
echo "  Proxy Port: $PROXY_PORT"
echo "  Timeout: ${TIMEOUT_SECONDS}s"
echo ""

PIDS=()

# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"

check_required_files() {
    local files=("mooncake_connector_proxy.py")
    for file in "${files[@]}"; do
        if [[ ! -f "$file" ]]; then
            echo "Required file $file not found in $(pwd)"
            exit 1
        fi
    done
}

check_hf_token() {
    if [ -z "$HF_TOKEN" ]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        echo "Example: export HF_TOKEN=your_token_here"
        exit 1
    fi
    if [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}

check_num_gpus() {
    # Check that the number of GPUs is >= 2 via nvidia-smi
    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
    if [ "$num_gpus" -lt 2 ]; then
        echo "You need at least 2 GPUs to run disaggregated prefill."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}

ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    if ! python3 -c "import $1" > /dev/null 2>&1; then
        echo "$1 is not installed. Please install it via pip install $1."
        exit 1
    else
        echo "$1 is installed."
    fi
}

cleanup() {
    echo "Stopping everything…"
    trap - INT TERM          # prevent re-entrancy
    pkill -9 -f "mooncake_connector_proxy.py"
    kill -- -$$              # negative PID == "this whole process-group"
    wait                     # reap children so we don't leave zombies
    exit 0
}

wait_for_server() {
    local port=$1
    local timeout_seconds=$TIMEOUT_SECONDS
    local start_time=$(date +%s)

    echo "Waiting for server on port $port..."

    while true; do
        if curl -s "localhost:${port}/v1/completions" > /dev/null; then
            echo "Server on port $port is ready."
+ return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server on port $port" + return 1 + fi + + sleep 1 + done +} + +main() { + check_required_files + check_hf_token + check_num_gpus + ensure_python_library_installed vllm + ensure_python_library_installed mooncake.engine + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching disaggregated serving components..." + echo "Please check the log files for detailed output:" + echo " - prefill*.log: Prefill server logs" + echo " - decode*.log: Decode server logs" + echo " - proxy.log: Proxy server log" + + # Parse GPU and port arrays + IFS=',' read -ra PREFILL_GPU_ARRAY <<< "$PREFILL_GPUS" + IFS=',' read -ra DECODE_GPU_ARRAY <<< "$DECODE_GPUS" + IFS=',' read -ra PREFILL_PORT_ARRAY <<< "$PREFILL_PORTS" + IFS=',' read -ra BOOTSTRAP_PORT_ARRAY <<< "$BOOTSTRAP_PORTS" + IFS=',' read -ra DECODE_PORT_ARRAY <<< "$DECODE_PORTS" + + proxy_param="" + + # ============================================================================= + # Launch Prefill Servers (X Producers) + # ============================================================================= + echo "" + echo "Starting ${#PREFILL_GPU_ARRAY[@]} prefill server(s)..." + for i in "${!PREFILL_GPU_ARRAY[@]}"; do + local gpu_id=${PREFILL_GPU_ARRAY[$i]} + local port=${PREFILL_PORT_ARRAY[$i]} + local bootstrap_port=${BOOTSTRAP_PORT_ARRAY[$i]} + + echo " Prefill server $((i+1)): GPU $gpu_id, Port $port, Bootstrap Port $bootstrap_port" + VLLM_MOONCAKE_BOOTSTRAP_PORT=$bootstrap_port CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + --port $port \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_producer\"}" > prefill$((i+1)).log 2>&1 & + PIDS+=($!) + proxy_param="${proxy_param} --prefill http://0.0.0.0:${port} $bootstrap_port" + done + + # ============================================================================= + # Launch Decode Servers (Y Decoders) + # ============================================================================= + echo "" + echo "Starting ${#DECODE_GPU_ARRAY[@]} decode server(s)..." + for i in "${!DECODE_GPU_ARRAY[@]}"; do + local gpu_id=${DECODE_GPU_ARRAY[$i]} + local port=${DECODE_PORT_ARRAY[$i]} + + echo " Decode server $((i+1)): GPU $gpu_id, Port $port" + CUDA_VISIBLE_DEVICES=$gpu_id vllm serve $MODEL \ + --port $port \ + --kv-transfer-config \ + "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_consumer\"}" > decode$((i+1)).log 2>&1 & + PIDS+=($!) + proxy_param="${proxy_param} --decode http://0.0.0.0:${port}" + done + + # ============================================================================= + # Launch Proxy Server + # ============================================================================= + echo "" + echo "Starting proxy server on port $PROXY_PORT..." + python3 mooncake_connector_proxy.py $proxy_param --port $PROXY_PORT > proxy.log 2>&1 & + PIDS+=($!) + + # ============================================================================= + # Wait for All Servers to Start + # ============================================================================= + echo "" + echo "Waiting for all servers to start..." + for port in "${PREFILL_PORT_ARRAY[@]}" "${DECODE_PORT_ARRAY[@]}"; do + if ! wait_for_server $port; then + echo "Failed to start server on port $port" + cleanup + exit 1 + fi + done + + echo "" + echo "All servers are up. Starting benchmark..." 
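
    # Optional sanity check (illustrative, not part of the original flow): send
    # one request through the proxy before benchmarking.
    # curl -s http://localhost:${PROXY_PORT}/v1/completions \
    #     -H "Content-Type: application/json" \
    #     -d "{\"model\": \"$MODEL\", \"prompt\": \"Hello\", \"max_tokens\": 16}"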
+ + # ============================================================================= + # Run Benchmark + # ============================================================================= + vllm bench serve --port $PROXY_PORT --seed $(date +%s) \ + --backend vllm --model $MODEL \ + --dataset-name random --random-input-len 7500 --random-output-len 200 \ + --num-prompts 200 --burstiness 100 --request-rate 2 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup +} + +main diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 16487d744e20..dd7d1ee1eb62 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -280,6 +280,14 @@ class is dynamically inherited by the worker class. This is used to inject """Equal to the data parallel rank but not used for torch process groups and not overridden for dense models.""" + origin_data_parallel_size: int = Field(init=False) + """Equal to the data parallel size but not used for torch process groups + and not overridden for dense models.""" + + origin_data_parallel_size_local: int = Field(init=False) + """Equal to the data parallel size local but not used for torch process groups + and not overridden for dense models.""" + _api_process_count: int = Field(default=1, gt=0) """ The number of API processes initialized. @@ -594,6 +602,8 @@ def __post_init__(self) -> None: ) self.data_parallel_index = self.data_parallel_rank + self.origin_data_parallel_size = self.data_parallel_size + self.origin_data_parallel_size_local = self.data_parallel_size_local if self.distributed_executor_backend == "external_launcher": os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index f4113c91b60f..4a04408d9fd2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -198,6 +198,6 @@ def get_connector_class( ) KVConnectorFactory.register_connector( "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", "MooncakeConnector", ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py similarity index 53% rename from vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py rename to vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index ef0268b9aba0..b0e447b1eed5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -4,10 +4,13 @@ import threading import time from collections import defaultdict +from collections.abc import MutableMapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from enum import IntEnum from typing import TYPE_CHECKING, Any, Optional +import httpx import msgspec import numpy as np import torch @@ -17,6 +20,7 @@ from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( + EngineId, TpKVTopology, get_current_attn_backend, ) @@ -25,10 +29,17 @@ KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( + EngineEntry, + MooncakeBootstrapServer, + RegisterWorkerPayload, + TruncatingDict, +) from 
vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, + is_local_first_rank, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -43,7 +54,7 @@ except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " "to run VLLM with MooncakeTransferEngine." ) from e @@ -52,46 +63,74 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -EngineId = str ReqId = str -TRANS_DONE = b"trans_done" -TRANS_ERROR = b"trans_error" - logger = init_logger(__name__) -class MooncakeAgentMetadata( +class MooncakeXferMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True, ): remote_hostname: str remote_port: int - request_ids: list[ReqId] + remote_tp_size: int + remote_tp_rank: int + req_blocks: dict[ReqId, list[int]] kv_caches_base_addr: list[int] - block_ids: list[list[int]] + + +class MooncakeXferResponseStatus(IntEnum): + # Transfer finished + FINISH = 0 + # Continue to receive + CONTINUE = 1 + # Something wrong, see err_msg + ERROR = 2 + + +class MooncakeXferResponse( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] +): + status: MooncakeXferResponseStatus + ok_reqs: list[ReqId] | None = None + err_reqs: list[ReqId] | None = None + err_msg: str | None = None @dataclass -class RecvReqMeta: +class PullReqMeta: + req_id: ReqId local_block_ids: list[int] - remote_host: str - remote_port: int + remote_engine_id: EngineId + remote_bootstrap_addr: str + # Set expire time to avoid infinitely sending requests. + expire_time: float = float("inf") + # Designed for one D pairing to multiple P + pull_tasks_count: int = 0 @dataclass class SendBlockMeta: + p_req_id: ReqId local_block_ids: list[int] ready: asyncio.Event expire_time: float = float("inf") + need_send: int = 0 + sended: int = 0 + sending: int = 0 class MooncakeConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {} + # Use (engine_id, dp_rank) to group reqs with same dp. + # See comments in MooncakeBootstrapServer. 
+ self.reqs_to_recv: dict[tuple[EngineId, int], dict[ReqId, PullReqMeta]] = ( + defaultdict(dict) + ) self.reqs_to_send: dict[ReqId, list[int]] = {} + self.reqs_not_processed: set[ReqId] = set() def add_new_req( self, @@ -101,10 +140,14 @@ def add_new_req( load_remote_cache: bool = True, ): if load_remote_cache: - self.reqs_to_recv[request_id] = RecvReqMeta( + remote_engine_id = kv_transfer_params["remote_engine_id"] + self.reqs_to_recv[(remote_engine_id, kv_transfer_params["remote_dp_rank"])][ + request_id + ] = PullReqMeta( + req_id=request_id, local_block_ids=local_block_ids, - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], + remote_engine_id=remote_engine_id, + remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"], ) else: self.reqs_to_send[request_id] = local_block_ids @@ -209,12 +252,14 @@ class MooncakeConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config - self.engine_id: EngineId = engine_id - self.side_channel_host = get_ip() - self.side_channel_port = get_mooncake_side_channel_port(vllm_config) assert vllm_config.kv_transfer_config - self.kv_role = vllm_config.kv_transfer_config.kv_role + self.is_kv_producer: bool = ( + vllm_config.kv_transfer_config.kv_role == "kv_producer" + ) + self.is_kv_consumer: bool = ( + vllm_config.kv_transfer_config.kv_role == "kv_consumer" + ) logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) # Requests that need to start recv/send. @@ -222,6 +267,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_send: dict[ReqId, list[int]] = {} + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. + self._reqs_not_processed: set[ReqId] = set() def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -249,8 +297,12 @@ def get_num_new_matched_tokens( params, ) - if params is not None and params.get("do_remote_prefill"): + if not params: + return 0, False + + if params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. + assert not self.is_kv_producer token_ids = request.prompt_token_ids or [] count = len(token_ids) - num_computed_tokens if count > 0: @@ -265,7 +317,8 @@ def update_state_after_alloc( params = request.kv_transfer_params logger.debug( "MooncakeConnector update_state_after_alloc: " - "num_external_tokens=%s, kv_transfer_params=%s", + "req_id=%s num_external_tokens=%s, kv_transfer_params=%s", + request.request_id, num_external_tokens, params, ) @@ -274,8 +327,11 @@ def update_state_after_alloc( return if params.get("do_remote_prefill"): - assert self.kv_role != "kv_producer" - if all(p in params for p in ("remote_host", "remote_port")): + assert not self.is_kv_producer + if all( + p in params + for p in ("remote_engine_id", "remote_bootstrap_addr", "remote_dp_rank") + ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. @@ -295,6 +351,7 @@ def update_state_after_alloc( elif params.get("do_remote_decode"): # Add an empty list to worker to create event. 
+ assert not self.is_kv_consumer self._reqs_need_send[request.request_id] = [] def build_connector_meta( @@ -304,7 +361,7 @@ def build_connector_meta( meta = MooncakeConnectorMetadata() # Loop through scheduled reqs and convert to RecvReqMeta. - if self.kv_role != "kv_producer": + if not self.is_kv_producer: for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None meta.add_new_req( @@ -314,7 +371,7 @@ def build_connector_meta( ) self._reqs_need_recv.clear() - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: for req_id, block_ids in self._reqs_need_send.items(): meta.add_new_req( request_id=req_id, @@ -323,6 +380,8 @@ def build_connector_meta( load_remote_cache=False, ) self._reqs_need_send.clear() + meta.reqs_not_processed = self._reqs_not_processed + self._reqs_not_processed = set() return meta @@ -338,8 +397,9 @@ def request_finished( params = request.kv_transfer_params logger.debug( - "MooncakeConnector request_finished, request_status=%s, " + "MooncakeConnector request_finished, req_id=%s, request_status=%s, " "kv_transfer_params=%s", + request.request_id, request.status, params, ) @@ -353,18 +413,21 @@ def request_finished( # To avoid stranding the prefill blocks in the prefill instance, # we must add empty block_ids to _reqs_need_recv so that our # worker side will notify and free blocks in the prefill instance. - assert self.kv_role != "kv_producer" + assert not self.is_kv_producer self._reqs_need_recv[request.request_id] = (request, []) params["do_remote_prefill"] = False return False, None - if ( - not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED - ): + if not params.get("do_remote_decode"): return False, None - assert self.kv_role != "kv_consumer" + assert not self.is_kv_consumer + + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + # Also include the case of a P/D Prefill request with immediate + # block free (eg abort). Stop tracking this request. + self._reqs_not_processed.add(request.request_id) + return False, None # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below @@ -373,12 +436,7 @@ def request_finished( if delay_free_blocks: self._reqs_need_send[request.request_id] = block_ids - return delay_free_blocks, dict( - do_remote_prefill=True, - do_remote_decode=False, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - ) + return delay_free_blocks, None class MooncakeConnectorWorker: @@ -391,7 +449,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.engine = TransferEngine() self.hostname = get_ip() - protocol = self.vllm_config.kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr] + + assert (kv_transfer_config := vllm_config.kv_transfer_config) + self.is_kv_producer: bool = kv_transfer_config.kv_role == "kv_producer" + self.is_kv_consumer: bool = kv_transfer_config.kv_role == "kv_consumer" + self.num_sender_workers = kv_transfer_config.kv_connector_extra_config.get( + "num_workers", 10 + ) + # Create more tasks than workers to keep the thread pool saturated. + # Tasks can await async events, so a surplus (2x is a robust heuristic) + # prevents workers from idling. 
+ self.num_sender_tasks = self.num_sender_workers * 2 + protocol = kv_transfer_config.kv_connector_extra_config.get( # type: ignore[union-attr] "mooncake_protocol", "rdma" ) logger.info( @@ -409,33 +478,35 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.rpc_port, ) - # Mooncake handshake port. - self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config) - + self._remote_agents: dict[EngineId, EngineEntry] = {} + self._pending_bootstrap_querys: dict[str, asyncio.Event] = {} + self.side_channel_port: int = 0 # we will bind it in register_kv_caches() self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() - self.world_size = get_tensor_model_parallel_world_size() - self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() self.num_blocks = 0 - assert vllm_config.kv_transfer_config - self.kv_role = vllm_config.kv_transfer_config.kv_role - self.num_sender_workers = ( - vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "num_workers", 10 + assert (parallel_config := vllm_config.parallel_config) + dp_rank = parallel_config.data_parallel_index + dp_local_rank = parallel_config.data_parallel_rank_local + self.dp_rank = dp_local_rank if parallel_config.local_engines_only else dp_rank + pp_size = vllm_config.parallel_config.pipeline_parallel_size + if pp_size > 1: + raise ValueError( + "Mooncake Transfer Engine does not support pipeline parallelism yet." ) - ) - # Create more tasks than workers to keep the thread pool saturated. - # Tasks can await async events, so a surplus (2x is a robust heuristic) - # prevents workers from idling. - self.num_sender_tasks = self.num_sender_workers * 2 + self.pp_rank = get_pp_group().rank_in_group self.kv_caches_base_addr: list[int] = [] self.device_kv_caches: dict[str, torch.Tensor] = {} - self.reqs_need_send: dict[ReqId, SendBlockMeta] = {} + self.reqs_need_send: MutableMapping[ReqId, SendBlockMeta] = TruncatingDict() + + # Only used by prefillers. + host, port = get_mooncake_bootstrap_addr(vllm_config) + self.bootstrap_addr = make_zmq_path("http", host, port) # For kv_both, we will act both prefiller and decoder. - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: # Background threads for sending kvcaches to D. self._sender_executor = ThreadPoolExecutor( max_workers=self.num_sender_workers, @@ -454,7 +525,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): ) self._sender_listener_t.start() - if self.kv_role != "kv_producer": + # Start bootstrap server on global rank 0. 
+ if should_launch_bootstrap_server(vllm_config): + self.bootstrap_server = MooncakeBootstrapServer( + vllm_config, "0.0.0.0", port + ) + self.bootstrap_server.start() + + if not self.is_kv_producer: self.receiver_loop = asyncio.new_event_loop() self._mooncake_receiver_t = threading.Thread( target=_async_loop, args=(self.receiver_loop,), daemon=True @@ -478,7 +556,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) - self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, @@ -492,7 +570,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.async_zmq_ctx = zmq.asyncio.Context() self._encoder = msgspec.msgpack.Encoder() - self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self._xfer_meta_decoder = msgspec.msgpack.Decoder(MooncakeXferMetadata) + self._xfer_resp_decoder = msgspec.msgpack.Decoder(MooncakeXferResponse) def __del__(self): self.shutdown() @@ -500,26 +579,61 @@ def __del__(self): def shutdown(self): """Cleanup background threads on destruction.""" self.async_zmq_ctx.term() - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: self._sender_executor.shutdown(wait=False) if self.sender_loop.is_running(): self.sender_loop.call_soon_threadsafe(self.sender_loop.stop) self._sender_listener_t.join() - if self.kv_role != "kv_producer" and self.receiver_loop.is_running(): + if should_launch_bootstrap_server(self.vllm_config): + self.bootstrap_server.shutdown() + if not self.is_kv_producer and self.receiver_loop.is_running(): self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) self._mooncake_receiver_t.join() - async def _mooncake_sender_listener( - self, ready_event: threading.Event, base_port: int, tp_rank: int - ): + async def register_worker_with_bootstrap(self): + url = self.bootstrap_addr + "/register" + worker_addr = make_zmq_path("tcp", self.hostname, self.side_channel_port) + payload = RegisterWorkerPayload( + engine_id=self.engine_id, + dp_rank=self.dp_rank, + tp_rank=self.tp_rank, + pp_rank=self.pp_rank, + addr=worker_addr, + ) + while True: + try: + async with httpx.AsyncClient() as client: + response = await client.post(url, json=payload.model_dump()) + response.raise_for_status() + logger.debug("Successfully registered with bootstrap server at %s", url) + break + except httpx.ConnectError: + # Bootstrap server not ready, wait for a while and retry. + await asyncio.sleep(1) + except Exception as e: + err_msg = ( + e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e) + ) + logger.error( + "Error registering %s with bootstrap server: %s", payload, err_msg + ) + raise e + + async def _mooncake_sender_listener(self, ready_event: threading.Event): """ Background thread that listens for Mooncake requests, dispatches them to a thread pool, and sends acknowledgments upon completion. 
""" - path = make_zmq_path("tcp", self.hostname, base_port + tp_rank) - sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER) - logger.debug("Mooncake sender starting listening on path: %s", path) + sock = self.async_zmq_ctx.socket(zmq.ROUTER) + self.side_channel_port = sock.bind_to_random_port(f"tcp://{self.hostname}") + logger.debug( + "Mooncake sender starting listening on path: tcp://%s:%d", + self.hostname, + self.side_channel_port, + ) + + await self.register_worker_with_bootstrap() # Create async worker tasks that process items from the queue sender_tasks = [ @@ -531,7 +645,7 @@ async def _mooncake_sender_listener( try: while True: - identity, _, metadata_bytes = await sock.recv_multipart() + identity, metadata_bytes = await sock.recv_multipart() await self.sender_worker_queue.put((identity, metadata_bytes)) except zmq.ContextTerminated: logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") @@ -549,12 +663,16 @@ async def _sender_worker(self, sock: zmq.asyncio.Socket): try: identity, metadata_bytes = await self.sender_worker_queue.get() try: - metadata = self._decoder.decode(metadata_bytes) - await self.send_kv_to_decode(metadata) - await sock.send_multipart((identity, b"", TRANS_DONE)) + metadata = self._xfer_meta_decoder.decode(metadata_bytes) + await self.send_kv_to_decode(identity, sock, metadata) except Exception as e: logger.error("Error processing Mooncake xfer request: %s", e) - await sock.send_multipart((identity, b"", TRANS_ERROR)) + error_response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.ERROR, err_msg=str(e) + ) + await sock.send_multipart( + (identity, self._encoder.encode(error_response)) + ) finally: self.sender_worker_queue.task_done() except asyncio.CancelledError: @@ -562,55 +680,156 @@ async def _sender_worker(self, sock: zmq.asyncio.Socket): except Exception as e: logger.error("Error in _sender_worker: %s", e) - async def send_kv_to_decode(self, meta: MooncakeAgentMetadata): - send_reqs: list[tuple[ReqId, SendBlockMeta]] = [] - for req_id in meta.request_ids: - send_meta = self.reqs_need_send.get(req_id) - if send_meta is None: - logger.warning("Request %s not found in reqs_need_send", req_id) - return - # Mark it as not expired. We will send it now. - send_meta.expire_time = float("inf") - send_reqs.append((req_id, send_meta)) - - src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta) - remote_session = f"{meta.remote_hostname}:{meta.remote_port}" - ret_value = await self.sender_loop.run_in_executor( - self._sender_executor, - self._send_blocks, - remote_session, - src_ptrs, - dst_ptrs, - lengths, - ) + async def send_kv_to_decode( + self, identity: bytes, sock: zmq.asyncio.Socket, meta: MooncakeXferMetadata + ): + pending_reqs: dict[ReqId, SendBlockMeta] = {} + remote_tp_ranks = self.kv_topo.get_target_remote_ranks(meta.remote_tp_size) + if self.tp_rank not in remote_tp_ranks: + # This D worker does not pair with the P worker. + msg = f"This P tp_rank {self.tp_rank} not in remote D target ranks {remote_tp_ranks}" # noqa: E501 + logger.error(msg) + response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.ERROR, + err_msg=msg, + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + return + for d_req_id in meta.req_blocks: + if d_req_id not in self.reqs_need_send: + # This req is not enqueued in P side yet, create it here. 
+ self.reqs_need_send[d_req_id] = SendBlockMeta( + p_req_id="", local_block_ids=[], ready=asyncio.Event() + ) + send_meta = self.reqs_need_send[d_req_id] + pending_reqs[d_req_id] = send_meta - if ret_value != 0: - raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}") + async def wait_and_ret( + d_req_id: ReqId, send_meta: SendBlockMeta + ) -> tuple[ReqId, SendBlockMeta]: + await send_meta.ready.wait() + return d_req_id, send_meta - for req_id in meta.request_ids: - del self.reqs_need_send[req_id] + wait_tasks = [ + asyncio.create_task(wait_and_ret(d_req_id, send_meta)) + for d_req_id, send_meta in pending_reqs.items() + ] + + while wait_tasks: + done, pending = await asyncio.wait( + wait_tasks, + timeout=envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + return_when=asyncio.FIRST_COMPLETED, + ) - self.finished_sending_reqs.update(meta.request_ids) + if not done: + # Timeout, abort all pending requests. + for task in wait_tasks: + task.cancel() + response = MooncakeXferResponse( + status=MooncakeXferResponseStatus.FINISH, + err_reqs=list(pending_reqs), + err_msg="Timeout waiting for P side ready.", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + break + + wait_tasks = list(pending) + response_status = ( + MooncakeXferResponseStatus.CONTINUE + if wait_tasks + else MooncakeXferResponseStatus.FINISH + ) + ready_reqs: list[tuple[ReqId, SendBlockMeta]] = [] + for task in done: + d_req_id, send_meta = task.result() + del pending_reqs[d_req_id] + # Do we still in reqs_need_send (not expired)? + if d_req_id in self.reqs_need_send: + # Mark it sending to avoid expiration. + send_meta.sending += 1 + if not send_meta.need_send: + self.resolve_need_send(send_meta, remote_tp_ranks) + ready_reqs.append((d_req_id, send_meta)) + # Otherwise (expired, very unlikely), forget it. Do not let D retry. + + src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params( + ready_reqs, meta + ) + + if err_reqs: + response = MooncakeXferResponse( + status=response_status, + err_reqs=err_reqs, + err_msg="P num blocks less than D", + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + if src_ptrs: + remote_session = f"{meta.remote_hostname}:{meta.remote_port}" + ret_value = await self.sender_loop.run_in_executor( + self._sender_executor, + self._send_blocks, + remote_session, + src_ptrs, + dst_ptrs, + lengths, + ) + + if ret_value != 0: + err_reqs = [] + for d_req_id, send_meta in ready_reqs: + send_meta.sending -= 1 + err_reqs.append(d_req_id) + # Do best effort to transfer the remaining reqs. + response = MooncakeXferResponse( + status=response_status, + err_reqs=err_reqs, + err_msg=f"Mooncake transfer engine returned {ret_value}", + ) + await sock.send_multipart( + (identity, self._encoder.encode(response)) + ) + continue + + for d_req_id, send_meta in ready_reqs: + # Todo: for heterogeneous TP (one P pairs to multiple D), + # we need to check whether all headers are sent. + # If not, we should set expire_time to normal and skip the below. 
+ send_meta.sending -= 1 + send_meta.sended += 1 + if send_meta.sended == send_meta.need_send: + del self.reqs_need_send[d_req_id] + self.finished_sending_reqs.add(send_meta.p_req_id) + + response = MooncakeXferResponse( + status=response_status, + ok_reqs=[d_req_id for d_req_id, _ in ready_reqs], + ) + await sock.send_multipart((identity, self._encoder.encode(response))) + + def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int]): + # Prepare for heterogeneous TP (one P pairs to multiple D) + send_meta.need_send = len(remote_tp_ranks) + if send_meta.need_send != 1: + logger.error("Mooncake: Heterogeneous TP is not supported yet.") async def _build_transfer_params( self, - send_reqs: list[tuple[ReqId, SendBlockMeta]], - agent_meta: MooncakeAgentMetadata, - ) -> tuple[list[int], list[int], list[int]]: + ready_reqs: list[tuple[ReqId, SendBlockMeta]], + agent_meta: MooncakeXferMetadata, + ) -> tuple[list[int], list[int], list[int], list[ReqId]]: src_ptrs = [] dst_ptrs = [] lengths = [] + err_reqs: list[ReqId] = [] local_base_addr = self.kv_caches_base_addr remote_base_addr = agent_meta.kv_caches_base_addr block_len = self.block_len remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" - assert len(send_reqs) == len(agent_meta.block_ids) - for (req_id, send_meta), remote_block_ids in zip( - send_reqs, agent_meta.block_ids - ): - await send_meta.ready.wait() - + for d_req_id, send_meta in ready_reqs: + remote_block_ids = agent_meta.req_blocks[d_req_id] num_remote_blocks = len(remote_block_ids) if num_remote_blocks == 0: continue @@ -618,7 +837,15 @@ async def _build_transfer_params( local_block_ids = send_meta.local_block_ids # Partial prefix cache hit: just read uncomputed blocks. num_local_blocks = len(local_block_ids) - assert num_local_blocks >= num_remote_blocks + if num_local_blocks < num_remote_blocks: + logger.error( + "req %s: local blocks(%d) less than remote blocks(%d)!", + d_req_id, + num_local_blocks, + num_remote_blocks, + ) + err_reqs.append(d_req_id) + continue if num_local_blocks > num_remote_blocks: local_block_ids = local_block_ids[-num_remote_blocks:] @@ -643,12 +870,12 @@ async def _build_transfer_params( logger.debug( "Sending kv_caches for request %s (%d blocks) to %s", - req_id, + d_req_id, num_remote_blocks, remote_session, ) - return src_ptrs, dst_ptrs, lengths + return src_ptrs, dst_ptrs, lengths, err_reqs def _send_blocks( self, @@ -722,15 +949,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) # No need to launch server for D node. - if self.kv_role == "kv_consumer": + if self.is_kv_consumer: return ready_event = threading.Event() asyncio.run_coroutine_threadsafe( - self._mooncake_sender_listener( - ready_event, self.side_channel_port, self.tp_rank - ), - self.sender_loop, + self._mooncake_sender_listener(ready_event), self.sender_loop ) ready_event.wait() # Wait for listener ZMQ socket to be ready. @@ -746,9 +970,11 @@ async def fetch_finished_sending_reqs(self) -> set[ReqId]: # Handle timeout to avoid stranding blocks on remote. 
now = time.perf_counter() expired_reqs = [ - req_id - for req_id, send_meta in self.reqs_need_send.items() - if send_meta.expire_time < now + send_meta.p_req_id + for send_meta in self.reqs_need_send.values() + if send_meta.p_req_id + and send_meta.expire_time < now + and send_meta.sending == 0 ] for req_id in expired_reqs: logger.warning( @@ -771,12 +997,12 @@ def get_finished(self) -> tuple[set[str] | None, set[str] | None]: """ recv_fut = None send_fut = None - if self.kv_role != "kv_producer": + if not self.is_kv_producer: recv_fut = asyncio.run_coroutine_threadsafe( self.fetch_finished_recving_reqs(), self.receiver_loop ) - if self.kv_role != "kv_consumer": + if not self.is_kv_consumer: send_fut = asyncio.run_coroutine_threadsafe( self.fetch_finished_sending_reqs(), self.sender_loop ) @@ -795,69 +1021,179 @@ def get_finished(self) -> tuple[set[str] | None, set[str] | None]: return finished_sending_reqs or None, finished_recving_reqs or None - async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]): - req_ids, block_ids = map(list, zip(*req_blocks)) - metadata = MooncakeAgentMetadata( + async def receive_kv_from_single_worker( + self, + worker_addr: str, + pull_metas: dict[ReqId, PullReqMeta], + ): + req_ids = set(pull_metas) + metadata = MooncakeXferMetadata( remote_hostname=self.hostname, remote_port=self.rpc_port, - request_ids=req_ids, + remote_tp_size=self.tp_size, + remote_tp_rank=self.tp_rank, + req_blocks={ + req_id: pull_meta.local_block_ids + for req_id, pull_meta in pull_metas.items() + }, kv_caches_base_addr=self.kv_caches_base_addr, - block_ids=block_ids, ) encoded_data = self._encoder.encode(metadata) logger.debug( - "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data) + "Size of encoded MooncakeXferMetadata: %d bytes", len(encoded_data) + ) + logger.debug( + "Sending kv transfer request for %s on path: %s", req_ids, worker_addr ) - logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path) # Send query for the request. 
-        sock: zmq.asyncio.Socket = make_zmq_socket(
-            self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
-        )
-        sock.setsockopt(zmq.RCVTIMEO, 60000)
         try:
-            await sock.send(encoded_data)
-            ret_msg = await sock.recv()
-            if ret_msg != TRANS_DONE:
-                logger.error(
-                    "Error happens during tranfering kvcache for %s, see logs in prefiller.",  # noqa: E501
-                    req_ids,
-                )
-                return
+            with make_zmq_socket(
+                self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0
+            ) as sock:
+                await sock.send(encoded_data)
+                while True:
+                    ret_msg = await sock.recv()
+                    response = self._xfer_resp_decoder.decode(ret_msg)
+                    if response.status == MooncakeXferResponseStatus.ERROR:
+                        logger.error(
+                            "Error while transferring kvcache for %s: %s",
+                            req_ids,
+                            response.err_msg,
+                        )
+                        return
+                    self.process_pulling_result(response, pull_metas, worker_addr)
+                    if response.status == MooncakeXferResponseStatus.FINISH:
+                        break
         except zmq.ContextTerminated:
             logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
         except Exception as e:
-            logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
+            logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
             return
-        finally:
-            sock.close()

-        self.finished_recving_reqs.update(req_ids)
+    def process_pulling_result(
+        self,
+        response: MooncakeXferResponse,
+        pull_metas: dict[ReqId, PullReqMeta],
+        worker_addr: str,
+    ):
+        ok_reqs: list[ReqId] = response.ok_reqs or []
+
+        for req_id in ok_reqs:
+            pull_meta = pull_metas[req_id]
+            # No race: this runs on the single-threaded receiver event loop.
+            pull_meta.pull_tasks_count -= 1
+            if pull_meta.pull_tasks_count == 0:
+                self.finished_recving_reqs.add(pull_meta.req_id)
+
+        if ok_reqs:
+            logger.debug("pulling kv_caches for %s finished", ok_reqs)
+
+        if response.err_reqs:
+            logger.error(
+                "pulling kv_caches for %s failed: %s",
+                response.err_reqs,
+                response.err_msg,
+            )

-        logger.debug("pulling kv_caches for %s finished", req_ids)
+    async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
+        url = remote_bootstrap_addr + "/query"
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.get(url)
+                response.raise_for_status()
+                data: dict = response.json()
+                for engine_id, engine_entry in data.items():
+                    remote_engine_id = engine_id
+                    self._remote_agents[remote_engine_id] = {
+                        int(dp_rank): {
+                            int(tp_rank): {
+                                int(pp_rank): worker_addr
+                                for pp_rank, worker_addr in tp_entry.items()
+                            }
+                            for tp_rank, tp_entry in dp_entry.items()
+                        }
+                        for dp_rank, dp_entry in engine_entry.items()
+                    }
+                    any_dp_entry = next(iter(engine_entry.values()))
+                    self._tp_size[remote_engine_id] = len(any_dp_entry)
+        except Exception as e:
+            logger.error(
+                "Failed to connect to bootstrap server %s: %s",
+                remote_bootstrap_addr,
+                e,
+            )
+
+        # Always notify others regardless of connection success or failure.
+ self._pending_bootstrap_querys[remote_bootstrap_addr].set() + del self._pending_bootstrap_querys[remote_bootstrap_addr] + + def receive_kv( + self, + remote_engine_id: EngineId, + remote_dp_rank: int, + pull_metas: dict[ReqId, PullReqMeta], + ): + remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( + remote_engine_id + ) + count = len(remote_tp_ranks) + if count != 1: + logger.error("Mooncake: Heterogeneous TP is not supported yet.") + return + for pull_meta in pull_metas.values(): + pull_meta.pull_tasks_count = count + for remote_tp_rank in remote_tp_ranks: + worker_addr = self._remote_agents[remote_engine_id][remote_dp_rank][ + remote_tp_rank + ][0] + asyncio.create_task( + self.receive_kv_from_single_worker(worker_addr, pull_metas) ) - path = make_zmq_path( - "tcp", meta.remote_host, meta.remote_port + self.tp_rank + + async def handle_new_engine_id( + self, + remote_engine_id: EngineId, + remote_dp_rank: int, + pull_metas: dict[ReqId, PullReqMeta], + ): + remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr + if remote_bootstrap_addr not in self._pending_bootstrap_querys: + self._pending_bootstrap_querys[remote_bootstrap_addr] = asyncio.Event() + await self._connect_to_prefiller_bootstrap(remote_bootstrap_addr) + else: + await self._pending_bootstrap_querys[remote_bootstrap_addr].wait() + + if remote_engine_id not in self._remote_agents: + logger.error( + "Failed to find remote engine_id %s from bootstrap server %s", + remote_engine_id, + remote_bootstrap_addr, ) - kv_pulls[path].append((req_id, meta.local_block_ids)) + return - return kv_pulls + self.receive_kv(remote_engine_id, remote_dp_rank, pull_metas) + + async def _start_load_kv( + self, reqs_to_recv: dict[tuple[EngineId, int], dict[ReqId, PullReqMeta]] + ): + for (remote_engine_id, remote_dp_rank), pull_metas in reqs_to_recv.items(): + if remote_engine_id not in self._remote_agents: + asyncio.create_task( + self.handle_new_engine_id( + remote_engine_id, remote_dp_rank, pull_metas + ) + ) + continue + self.receive_kv(remote_engine_id, remote_dp_rank, pull_metas) async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): - for req_id, block_ids in metadata.reqs_to_send.items(): + for p_req_id, block_ids in metadata.reqs_to_send.items(): if block_ids: # Already gone through request_finished() - send_meta = self.reqs_need_send[req_id] + send_meta = self.reqs_need_send[p_req_id] + send_meta.p_req_id = p_req_id send_meta.local_block_ids = block_ids send_meta.expire_time = ( time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT @@ -866,20 +1202,28 @@ async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): else: # From update_state_after_alloc(), # but not reach request_finished() yet - self.reqs_need_send[req_id] = SendBlockMeta( - local_block_ids=[], - ready=asyncio.Event(), - ) + # This may be already created by send_kv_to_decode() + # when D is sending MooncakeXferMetadata. 
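The `handle_new_engine_id()` and `_connect_to_prefiller_bootstrap()` pair above collapses concurrent lookups of the same bootstrap address into a single HTTP round trip by parking latecomers on an `asyncio.Event`, mirroring `_pending_bootstrap_querys`. A minimal, self-contained sketch of that pattern, with a stub standing in for the HTTP query and all names illustrative:

```python
import asyncio

_pending: dict[str, asyncio.Event] = {}
_results: dict[str, str] = {}


async def _query_bootstrap(addr: str) -> None:
    # Stand-in for the HTTP GET to <bootstrap>/query in the real code.
    await asyncio.sleep(0.1)
    _results[addr] = f"workers@{addr}"


async def get_workers(addr: str) -> str | None:
    if addr not in _pending:
        # First caller performs the query...
        _pending[addr] = asyncio.Event()
        try:
            await _query_bootstrap(addr)
        finally:
            # ...and always wakes the waiters, even on failure,
            # so nobody hangs on a dead bootstrap server.
            _pending.pop(addr).set()
    else:
        # Concurrent callers for the same address just wait.
        await _pending[addr].wait()
    return _results.get(addr)


async def main() -> None:
    # Ten concurrent lookups trigger exactly one bootstrap query.
    results = await asyncio.gather(*(get_workers("http://p0:8998") for _ in range(10)))
    assert results == ["workers@http://p0:8998"] * 10


asyncio.run(main())
```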
+                if p_req_id not in self.reqs_need_send:
+                    self.reqs_need_send[p_req_id] = SendBlockMeta(
+                        p_req_id=p_req_id,
+                        local_block_ids=[],
+                        ready=asyncio.Event(),
+                    )
+        for p_req_id in metadata.reqs_not_processed:
+            send_meta = self.reqs_need_send.pop(p_req_id)
+            if send_meta:
+                assert not send_meta.ready.is_set()
 
     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
-        if self.kv_role != "kv_producer":
-            kv_pulls = self.group_kv_pull(metadata)
-            for path, req_blocks in kv_pulls.items():
-                asyncio.run_coroutine_threadsafe(
-                    self.receive_kv(path, req_blocks), self.receiver_loop
-                )
+        if not self.is_kv_producer and metadata.reqs_to_recv:
+            asyncio.run_coroutine_threadsafe(
+                self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
+            )
 
-        if self.kv_role != "kv_consumer":
+        if not self.is_kv_consumer and (
+            metadata.reqs_to_send or metadata.reqs_not_processed
+        ):
             asyncio.run_coroutine_threadsafe(
                 self.record_send_reqs(metadata), self.sender_loop
             )
@@ -914,3 +1258,31 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
 def _async_loop(loop: asyncio.AbstractEventLoop):
     asyncio.set_event_loop(loop)
     loop.run_forever()
+
+
+def should_launch_bootstrap_server(vllm_config: VllmConfig) -> bool:
+    assert (parallel_config := vllm_config.parallel_config)
+    # In hybrid or external LB mode,
+    # each instance should have its own bootstrap server.
+    #
+    # In internal LB mode,
+    # only the real global first rank needs to launch the bootstrap server.
+    return is_local_first_rank() and (
+        parallel_config.local_engines_only or parallel_config.data_parallel_index == 0
+    )
+
+
+def get_mooncake_bootstrap_addr(vllm_config: VllmConfig) -> tuple[str, int]:
+    """
+    Returns the address of the Mooncake bootstrap server.
+    This is only used by prefillers to register workers.
+    Decoders should get the address from kv_transfer_params.
+    """
+    assert (parallel_config := vllm_config.parallel_config)
+    if parallel_config.local_engines_only:
+        # In hybrid or external LB mode, connect to the local server.
+        host = "127.0.0.1"
+    else:
+        host = parallel_config.data_parallel_master_ip
+    port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+    return (host, port)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py
new file mode 100644
index 000000000000..80ed3efbb6a5
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py
@@ -0,0 +1,225 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import threading
+import time
+from collections.abc import MutableMapping
+
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import EngineId
+from vllm.logger import init_logger
+
+WorkerAddr = str
+
+logger = init_logger(__name__)
+
+
+class RegisterWorkerPayload(BaseModel):
+    engine_id: EngineId
+    dp_rank: int
+    tp_rank: int
+    pp_rank: int
+    addr: WorkerAddr
+
+
+# {dp_rank: {tp_rank: {pp_rank: worker_addr}}}
+EngineEntry = dict[int, dict[int, dict[int, WorkerAddr]]]
+
+
+class MooncakeBootstrapServer:
+    """
+    A centralized server running on the global rank 0 prefiller worker.
+    Prefiller workers register their connection info (IP, port, ranks) here.
+    """
+
+    def __init__(self, vllm_config: VllmConfig, host: str, port: int):
+        # Since #30739, dp with non-Moe models are treated as separate worlds.
+ # Multiple dp ranks may have the same engine id because + # DPEngineCoreProc._init_data_parallel() is not called. + # So we cannot simply use engine id to distinguish dp ranks. + # Instead, we use [engine_id][dp_rank] to double check. + # + # For example, for vllm instance in 2 nodes and each with dp_size==2: + # + # Internal LB (non-Moe models): + # engine_id0 dp_rank=0 + # engine_id0 dp_rank=1 + # engine_id1 dp_rank=2 + # engine_id1 dp_rank=3 + # + # Internal LB (Moe models): + # engine_id0_dp0 dp_rank=0 + # engine_id0_dp1 dp_rank=1 + # engine_id1_dp0 dp_rank=2 + # engine_id3_dp1 dp_rank=3 + # + # Hybrid LB (non-Moe models): + # engine_id0 dp_rank=0 + # engine_id0 dp_rank=1 + # engine_id1 dp_rank=0 * + # engine_id1 dp_rank=1 * + # + # Hybrid LB (Moe models): + # engine_id0_dp0 dp_rank=0 + # engine_id0_dp1 dp_rank=1 + # engine_id1_dp0 dp_rank=0 * + # engine_id1_dp1 dp_rank=1 * + # + # External LB: + # engine_id0 dp_rank=0 + # engine_id1 dp_rank=0 * + # engine_id2 dp_rank=0 * + # engine_id3 dp_rank=0 * + # + # * here we use local dp_rank + + self.workers: dict[EngineId, EngineEntry] = {} + + assert (parallel_config := vllm_config.parallel_config) + dp_size = parallel_config.origin_data_parallel_size + dp_local_size = parallel_config.origin_data_parallel_size_local + self.dp_size = dp_local_size if parallel_config.local_engines_only else dp_size + # We should have these workers registered before serving requests. + self.total_count = parallel_config.world_size * self.dp_size + self.registered_count = 0 + + self.host = host + self.port = port + self.app = FastAPI() + self._register_routes() + self.server_thread: threading.Thread | None = None + self.server: uvicorn.Server | None = None + + def __del__(self): + self.shutdown() + + def _register_routes(self): + # All methods are async. No need to use lock to protect data. 
+ self.app.post("/register")(self.register_worker) + self.app.get("/query", response_model=dict[EngineId, EngineEntry])(self.query) + + def start(self): + if self.server_thread: + return + + config = uvicorn.Config(app=self.app, host=self.host, port=self.port) + self.server = uvicorn.Server(config=config) + self.server_thread = threading.Thread( + target=self.server.run, name="mooncake_bootstrap_server", daemon=True + ) + self.server_thread.start() + while not self.server.started: + time.sleep(0.1) # Wait for the server to start + logger.info("Mooncake Bootstrap Server started at %s:%d", self.host, self.port) + + def shutdown(self): + if self.server_thread is None or self.server is None or not self.server.started: + return + + self.server.should_exit = True + self.server_thread.join() + logger.info("Mooncake Bootstrap Server stopped.") + + async def register_worker(self, payload: RegisterWorkerPayload): + """Handles registration of a prefiller worker.""" + if self.registered_count >= self.total_count: + raise HTTPException( + status_code=400, + detail=(f"All {self.total_count} workers have been registered"), + ) + if payload.engine_id not in self.workers: + self.workers[payload.engine_id] = {} + + engine_entry = self.workers[payload.engine_id] + if payload.dp_rank not in engine_entry: + engine_entry[payload.dp_rank] = {} + + dp_entry = engine_entry[payload.dp_rank] + if payload.tp_rank not in dp_entry: + dp_entry[payload.tp_rank] = {} + + tp_entry = dp_entry[payload.tp_rank] + if payload.pp_rank in tp_entry: + raise HTTPException( + status_code=400, + detail=( + f"Worker with dp_rank={payload.dp_rank}, " + f"tp_rank={payload.tp_rank}, pp_rank={payload.pp_rank} " + f"is already registered at " + f"{tp_entry[payload.pp_rank]}, " + f"but still want to register at {payload.addr}" + ), + ) + tp_entry[payload.pp_rank] = payload.addr + + logger.debug( + "Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s", + payload.engine_id, + payload.dp_rank, + payload.tp_rank, + payload.pp_rank, + payload.addr, + ) + + self.registered_count += 1 + return {"status": "ok"} + + async def query(self) -> dict[EngineId, EngineEntry]: + if self.registered_count < self.total_count: + raise HTTPException( + status_code=503, + detail=( + "Workers still registering: " + f"{self.registered_count}/{self.total_count}" + ), + ) + return self.workers + + +# Workaround for #27987 +# Drop the last "-{random_uuid():.8}" +# After #32630 or other solution is merged, we can remove this workaround. 
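Before the `TruncatingDict` workaround that the comment above introduces, here is a hypothetical client-side view of the handshake the server defines: each prefiller worker POSTs its ranks and ZMQ address to `/register`, and consumers poll `/query` until registration is complete. Host, port, and rank values below are made up:

```python
import time

import httpx

BOOTSTRAP = "http://192.168.0.2:8998"  # illustrative address

# Prefiller side: every worker announces its ZMQ side-channel address.
payload = {
    "engine_id": "engine-0",
    "dp_rank": 0,
    "tp_rank": 0,
    "pp_rank": 0,
    "addr": "tcp://192.168.0.2:9100",
}
httpx.post(f"{BOOTSTRAP}/register", json=payload).raise_for_status()

# Decoder/proxy side: poll /query; the server answers 503 until every
# expected worker (world_size * dp_size) has checked in.
while True:
    resp = httpx.get(f"{BOOTSTRAP}/query")
    if resp.status_code == 503:
        time.sleep(1)
        continue
    resp.raise_for_status()
    # {engine_id: {dp_rank: {tp_rank: {pp_rank: worker_addr}}}}
    workers = resp.json()
    break
```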
+class TruncatingDict(MutableMapping): + def __init__(self, *args, **kwargs): + self._store = dict() + self.update(dict(*args, **kwargs)) + + @staticmethod + def _truncate_key(key): + if not isinstance(key, str) or len(key) < 10: + raise TypeError("Keys must be strings with at least 10 characters") + return key[:-9] + + def __setitem__(self, key, value): + truncated_key = self._truncate_key(key) + self._store[truncated_key] = value + + def __getitem__(self, key): + truncated_key = self._truncate_key(key) + return self._store[truncated_key] + + def __delitem__(self, key): + truncated_key = self._truncate_key(key) + del self._store[truncated_key] + + def __contains__(self, key): + truncated_key = self._truncate_key(key) + return truncated_key in self._store + + def __iter__(self): + return iter(self._store) + + def __len__(self): + return len(self._store) + + def values(self): + return self._store.values() + + def items(self): + return self._store.items() + + def __repr__(self): + return f"{type(self).__name__}({self._store})" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 141e5a459c5b..717f4b89adf8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -912,6 +912,12 @@ def signal_handler(signum, frame): decorate_logs() parallel_config.data_parallel_index = dp_rank + parallel_config.origin_data_parallel_size = ( + parallel_config.data_parallel_size + ) + parallel_config.origin_data_parallel_size_local = ( + parallel_config.data_parallel_size_local + ) if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank From 652d08cba34cf95413d2a8e92192e939f3b0e002 Mon Sep 17 00:00:00 2001 From: Tianchen Ding Date: Fri, 23 Jan 2026 15:52:33 +0800 Subject: [PATCH 2/7] add init.py Signed-off-by: Tianchen Ding --- vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From fde56139103e4894d7b6c3d2f681f71f4126de21 Mon Sep 17 00:00:00 2001 From: Tianchen Ding Date: Fri, 23 Jan 2026 16:00:59 +0800 Subject: [PATCH 3/7] fix a typo in comment Signed-off-by: Tianchen Ding --- .../kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index 80ed3efbb6a5..061e651e3811 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -54,7 +54,7 @@ def __init__(self, vllm_config: VllmConfig, host: str, port: int): # engine_id0_dp0 dp_rank=0 # engine_id0_dp1 dp_rank=1 # engine_id1_dp0 dp_rank=2 - # engine_id3_dp1 dp_rank=3 + # engine_id1_dp1 dp_rank=3 # # Hybrid LB (non-Moe models): # engine_id0 dp_rank=0 From 5c512330429931f2e70242f3fcb2cd3a11965da9 Mon Sep 17 00:00:00 2001 From: Tianchen Ding Date: Wed, 28 Jan 2026 14:12:39 +0800 Subject: [PATCH 4/7] remove workaround for engine_id, dp_size health check and req_id Signed-off-by: Tianchen Ding --- .../mooncake_connector_proxy.py | 46 ++--- vllm/config/parallel.py | 10 - 
 .../v1/mooncake/mooncake_connector.py | 175 +++++++++---------
 .../v1/mooncake/mooncake_utils.py | 101 +++-------
 vllm/v1/engine/core.py | 6 -
 5 files changed, 132 insertions(+), 206 deletions(-)

diff --git a/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
index c5867b2846ac..09880a32aa35 100644
--- a/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
+++ b/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py
@@ -5,7 +5,6 @@
 import asyncio
 import ipaddress
 import itertools
-import logging
 import os
 import urllib
 import uuid
@@ -16,9 +15,6 @@
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
 
 def maybe_wrap_ipv6_address(address: str) -> str:
     try:
@@ -43,24 +39,28 @@ async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
     for prefill_client in prefill_clients:
         while True:
             try:
-                response = await prefill_client["client"].get(
-                    prefill_client["bootstrap_addr"] + "/query"
-                )
+                # Wait for prefill service to be ready
+                response = await prefill_client["client"].get("/health")
                 response.raise_for_status()
-                data = response.json()
-                break
             except Exception:
                 await asyncio.sleep(1)
+                continue
+
+            response = await prefill_client["client"].get(
+                prefill_client["bootstrap_addr"] + "/query"
+            )
+            response.raise_for_status()
+            data = response.json()
+            break
 
-        dp_size = 0
-        for engine_id, engine_entry in data.items():
-            dp_size += len(engine_entry)
-            for dp_rank in engine_entry:
-                prefill_client["dp_engine_id"][int(dp_rank)] = engine_id
+        for dp_rank, dp_entry in data.items():
+            prefill_client["dp_engine_id"][int(dp_rank)] = dp_entry["engine_id"]
+        dp_size = len(data)
         prefill_client["dp_size"] = dp_size
+        print(f"Initialized prefiller {prefill_client['url']} with dp_size={dp_size}")
 
     ready.set()
-    logger.info("All prefiller instances are ready.")
+    print("All prefiller instances are ready.")
 
 
 @asynccontextmanager
@@ -87,6 +87,7 @@ async def lifespan(app: FastAPI):
                     max_keepalive_connections=None,
                 ),
             ),
+            "url": url,
             "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998),
             "dp_engine_id": {},
         }
@@ -116,7 +117,7 @@ async def lifespan(app: FastAPI):
     app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
 
     print(
-        f"Initialized {len(app.state.prefill_clients)} prefill clients "
+        f"Got {len(app.state.prefill_clients)} prefill clients "
         f"and {len(app.state.decode_clients)} decode clients."
) @@ -256,6 +257,7 @@ async def send_request_to_service( req_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -299,7 +301,7 @@ async def stream_service_response( "do_remote_prefill": True, "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"], "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank], - "remote_dp_rank": prefill_dp_rank, + "transfer_id": f"xfer-{request_id}", } async with decode_client_info["client"].stream( @@ -365,16 +367,6 @@ async def handle_chat_completions(request: Request): return await _handle_completions("/v1/chat/completions", request) -@app.get("/healthcheck") -async def healthcheck(): - """Simple endpoint to check if the server is running.""" - return { - "status": "ok", - "prefill_instances": len(app.state.prefill_clients), - "decode_instances": len(app.state.decode_clients), - } - - if __name__ == "__main__": global global_args global_args = parse_args() diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index dd7d1ee1eb62..16487d744e20 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -280,14 +280,6 @@ class is dynamically inherited by the worker class. This is used to inject """Equal to the data parallel rank but not used for torch process groups and not overridden for dense models.""" - origin_data_parallel_size: int = Field(init=False) - """Equal to the data parallel size but not used for torch process groups - and not overridden for dense models.""" - - origin_data_parallel_size_local: int = Field(init=False) - """Equal to the data parallel size local but not used for torch process groups - and not overridden for dense models.""" - _api_process_count: int = Field(default=1, gt=0) """ The number of API processes initialized. 
@@ -602,8 +594,6 @@ def __post_init__(self) -> None: ) self.data_parallel_index = self.data_parallel_rank - self.origin_data_parallel_size = self.data_parallel_size - self.origin_data_parallel_size_local = self.data_parallel_size_local if self.distributed_executor_backend == "external_launcher": os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index b0e447b1eed5..7cf8650bdb23 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -4,7 +4,6 @@ import threading import time from collections import defaultdict -from collections.abc import MutableMapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from enum import IntEnum @@ -30,10 +29,8 @@ KVConnectorRole, ) from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( - EngineEntry, MooncakeBootstrapServer, RegisterWorkerPayload, - TruncatingDict, ) from vllm.distributed.parallel_state import ( get_pp_group, @@ -63,7 +60,8 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -ReqId = str +ReqId = str # Internal scheduler request ID +TransferId = str # KV transfer coordination ID (shared by P/D) logger = init_logger(__name__) @@ -76,7 +74,7 @@ class MooncakeXferMetadata( remote_port: int remote_tp_size: int remote_tp_rank: int - req_blocks: dict[ReqId, list[int]] + req_blocks: dict[ReqId, tuple[TransferId, list[int]]] kv_caches_base_addr: list[int] @@ -101,7 +99,8 @@ class MooncakeXferResponse( @dataclass class PullReqMeta: - req_id: ReqId + d_req_id: ReqId + transfer_id: TransferId local_block_ids: list[int] remote_engine_id: EngineId remote_bootstrap_addr: str @@ -114,6 +113,7 @@ class PullReqMeta: @dataclass class SendBlockMeta: p_req_id: ReqId + transfer_id: TransferId local_block_ids: list[int] ready: asyncio.Event expire_time: float = float("inf") @@ -126,11 +126,9 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): def __init__(self): # Use (engine_id, dp_rank) to group reqs with same dp. # See comments in MooncakeBootstrapServer. 
- self.reqs_to_recv: dict[tuple[EngineId, int], dict[ReqId, PullReqMeta]] = ( - defaultdict(dict) - ) - self.reqs_to_send: dict[ReqId, list[int]] = {} - self.reqs_not_processed: set[ReqId] = set() + self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) + self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {} + self.reqs_not_processed: set[TransferId] = set() def add_new_req( self, @@ -139,18 +137,18 @@ def add_new_req( kv_transfer_params: dict[str, Any], load_remote_cache: bool = True, ): + transfer_id = kv_transfer_params["transfer_id"] if load_remote_cache: remote_engine_id = kv_transfer_params["remote_engine_id"] - self.reqs_to_recv[(remote_engine_id, kv_transfer_params["remote_dp_rank"])][ - request_id - ] = PullReqMeta( - req_id=request_id, + self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta( + d_req_id=request_id, local_block_ids=local_block_ids, remote_engine_id=remote_engine_id, remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"], + transfer_id=transfer_id, ) else: - self.reqs_to_send[request_id] = local_block_ids + self.reqs_to_send[request_id] = (transfer_id, local_block_ids) class MooncakeConnector(KVConnectorBase_V1): @@ -266,10 +264,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_send: dict[ReqId, list[int]] = {} + self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} # Reqs to remove from processed set because they're not to send after # remote prefill or aborted. - self._reqs_not_processed: set[ReqId] = set() + self._reqs_not_processed: set[TransferId] = set() def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -330,7 +328,7 @@ def update_state_after_alloc( assert not self.is_kv_producer if all( p in params - for p in ("remote_engine_id", "remote_bootstrap_addr", "remote_dp_rank") + for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id") ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call @@ -350,9 +348,16 @@ def update_state_after_alloc( params["do_remote_prefill"] = False elif params.get("do_remote_decode"): - # Add an empty list to worker to create event. assert not self.is_kv_consumer - self._reqs_need_send[request.request_id] = [] + if not params.get("transfer_id"): + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + else: + # Add an empty list to worker to create event. + self._reqs_need_send[request.request_id] = (request, []) def build_connector_meta( self, @@ -360,7 +365,7 @@ def build_connector_meta( ) -> KVConnectorMetadata: meta = MooncakeConnectorMetadata() - # Loop through scheduled reqs and convert to RecvReqMeta. + # Loop through scheduled reqs and convert to PullReqMeta. 
if not self.is_kv_producer: for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None @@ -372,11 +377,12 @@ def build_connector_meta( self._reqs_need_recv.clear() if not self.is_kv_consumer: - for req_id, block_ids in self._reqs_need_send.items(): + for req_id, (req, block_ids) in self._reqs_need_send.items(): + assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, - kv_transfer_params={}, + kv_transfer_params=req.kv_transfer_params, load_remote_cache=False, ) self._reqs_need_send.clear() @@ -403,7 +409,7 @@ def request_finished( request.status, params, ) - if not params: + if not params or not params.get("transfer_id"): return False, None if params.get("do_remote_prefill"): @@ -426,7 +432,7 @@ def request_finished( if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. - self._reqs_not_processed.add(request.request_id) + self._reqs_not_processed.add(params["transfer_id"]) return False, None # TODO: check whether block_ids actually ever be 0. If not we could @@ -434,7 +440,7 @@ def request_finished( delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: - self._reqs_need_send[request.request_id] = block_ids + self._reqs_need_send[request.request_id] = (request, block_ids) return delay_free_blocks, None @@ -478,7 +484,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.rpc_port, ) - self._remote_agents: dict[EngineId, EngineEntry] = {} + self._remote_agents: dict[EngineId, dict[int, dict[int, str]]] = {} self._pending_bootstrap_querys: dict[str, asyncio.Event] = {} self.side_channel_port: int = 0 # we will bind it in register_kv_caches() self.engine_id: EngineId = engine_id @@ -499,7 +505,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.kv_caches_base_addr: list[int] = [] self.device_kv_caches: dict[str, torch.Tensor] = {} - self.reqs_need_send: MutableMapping[ReqId, SendBlockMeta] = TruncatingDict() + self.reqs_need_send: dict[TransferId, SendBlockMeta] = {} # Only used by prefillers. host, port = get_mooncake_bootstrap_addr(vllm_config) @@ -695,13 +701,16 @@ async def send_kv_to_decode( ) await sock.send_multipart((identity, self._encoder.encode(response))) return - for d_req_id in meta.req_blocks: - if d_req_id not in self.reqs_need_send: + for d_req_id, (transfer_id, _) in meta.req_blocks.items(): + if transfer_id not in self.reqs_need_send: # This req is not enqueued in P side yet, create it here. - self.reqs_need_send[d_req_id] = SendBlockMeta( - p_req_id="", local_block_ids=[], ready=asyncio.Event() + self.reqs_need_send[transfer_id] = SendBlockMeta( + p_req_id="", + transfer_id=transfer_id, + local_block_ids=[], + ready=asyncio.Event(), ) - send_meta = self.reqs_need_send[d_req_id] + send_meta = self.reqs_need_send[transfer_id] pending_reqs[d_req_id] = send_meta async def wait_and_ret( @@ -745,7 +754,7 @@ async def wait_and_ret( d_req_id, send_meta = task.result() del pending_reqs[d_req_id] # Do we still in reqs_need_send (not expired)? - if d_req_id in self.reqs_need_send: + if send_meta.transfer_id in self.reqs_need_send: # Mark it sending to avoid expiration. 
send_meta.sending += 1 if not send_meta.need_send: @@ -799,7 +808,7 @@ async def wait_and_ret( send_meta.sending -= 1 send_meta.sended += 1 if send_meta.sended == send_meta.need_send: - del self.reqs_need_send[d_req_id] + del self.reqs_need_send[send_meta.transfer_id] self.finished_sending_reqs.add(send_meta.p_req_id) response = MooncakeXferResponse( @@ -829,7 +838,7 @@ async def _build_transfer_params( remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" for d_req_id, send_meta in ready_reqs: - remote_block_ids = agent_meta.req_blocks[d_req_id] + _, remote_block_ids = agent_meta.req_blocks[d_req_id] num_remote_blocks = len(remote_block_ids) if num_remote_blocks == 0: continue @@ -969,23 +978,25 @@ async def fetch_finished_sending_reqs(self) -> set[ReqId]: # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() - expired_reqs = [ - send_meta.p_req_id - for send_meta in self.reqs_need_send.values() - if send_meta.p_req_id - and send_meta.expire_time < now - and send_meta.sending == 0 - ] - for req_id in expired_reqs: - logger.warning( - "Request %s timed out after %d seconds without " - "being sent. Freeing its blocks on the producer side.", - req_id, - envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, - ) - del self.reqs_need_send[req_id] - if expired_reqs: - finished_sending_reqs.update(expired_reqs) + + expired_transfer_id = [] + for transfer_id, send_meta in self.reqs_need_send.items(): + if ( + send_meta.p_req_id + and send_meta.expire_time < now + and send_meta.sending == 0 + ): + logger.warning( + "Request %s timed out after %d seconds without " + "being sent. Freeing its blocks on the producer side.", + send_meta.p_req_id, + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + ) + finished_sending_reqs.add(send_meta.p_req_id) + expired_transfer_id.append(transfer_id) + + for transfer_id in expired_transfer_id: + del self.reqs_need_send[transfer_id] return finished_sending_reqs @@ -1033,7 +1044,7 @@ async def receive_kv_from_single_worker( remote_tp_size=self.tp_size, remote_tp_rank=self.tp_rank, req_blocks={ - req_id: pull_meta.local_block_ids + req_id: (pull_meta.transfer_id, pull_meta.local_block_ids) for req_id, pull_meta in pull_metas.items() }, kv_caches_base_addr=self.kv_caches_base_addr, @@ -1063,7 +1074,7 @@ async def receive_kv_from_single_worker( response.err_msg, ) return - self.process_pulling_result(response, pull_metas, worker_addr) + self.process_pulling_result(response, pull_metas) if response.status == MooncakeXferResponseStatus.FINISH: break except zmq.ContextTerminated: @@ -1076,7 +1087,6 @@ def process_pulling_result( self, response: MooncakeXferResponse, pull_metas: dict[ReqId, PullReqMeta], - worker_addr: str, ): ok_reqs: list[ReqId] = response.ok_reqs or [] @@ -1085,7 +1095,7 @@ def process_pulling_result( # No race because we are in async loop. 
pull_meta.pull_tasks_count -= 1 if pull_meta.pull_tasks_count == 0: - self.finished_recving_reqs.add(pull_meta.req_id) + self.finished_recving_reqs.add(pull_meta.d_req_id) if ok_reqs: logger.debug("pulling kv_caches for %s finished", ok_reqs) @@ -1104,20 +1114,16 @@ async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str): response = await client.get(url) response.raise_for_status() data: dict = response.json() - for engine_id, engine_entry in data.items(): - remote_engine_id = engine_id + for _, dp_entry in data.items(): + remote_engine_id = dp_entry["engine_id"] self._remote_agents[remote_engine_id] = { - int(dp_rank): { - int(tp_rank): { - int(pp_rank): worker_addr - for pp_rank, worker_addr in tp_entry.items() - } - for tp_rank, tp_entry in dp_entry.items() + int(tp_rank): { + int(pp_rank): worker_addr + for pp_rank, worker_addr in tp_entry.items() } - for dp_rank, dp_entry in engine_entry.items() + for tp_rank, tp_entry in dp_entry["worker_addr"].items() } - any_dp_entry = next(iter(engine_entry.values())) - self._tp_size[remote_engine_id] = len(any_dp_entry) + self._tp_size[remote_engine_id] = len(dp_entry["worker_addr"]) except Exception as e: logger.error( "Failed to connect to bootstrap server %s: %s", @@ -1132,7 +1138,6 @@ async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str): def receive_kv( self, remote_engine_id: EngineId, - remote_dp_rank: int, pull_metas: dict[ReqId, PullReqMeta], ): remote_tp_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( @@ -1145,9 +1150,7 @@ def receive_kv( for pull_meta in pull_metas.values(): pull_meta.pull_tasks_count = count for remote_tp_rank in remote_tp_ranks: - worker_addr = self._remote_agents[remote_engine_id][remote_dp_rank][ - remote_tp_rank - ][0] + worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0] asyncio.create_task( self.receive_kv_from_single_worker(worker_addr, pull_metas) ) @@ -1155,7 +1158,6 @@ def receive_kv( async def handle_new_engine_id( self, remote_engine_id: EngineId, - remote_dp_rank: int, pull_metas: dict[ReqId, PullReqMeta], ): remote_bootstrap_addr = next(iter(pull_metas.values())).remote_bootstrap_addr @@ -1173,26 +1175,24 @@ async def handle_new_engine_id( ) return - self.receive_kv(remote_engine_id, remote_dp_rank, pull_metas) + self.receive_kv(remote_engine_id, pull_metas) async def _start_load_kv( - self, reqs_to_recv: dict[tuple[EngineId, int], dict[ReqId, PullReqMeta]] + self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] ): - for (remote_engine_id, remote_dp_rank), pull_metas in reqs_to_recv.items(): + for remote_engine_id, pull_metas in reqs_to_recv.items(): if remote_engine_id not in self._remote_agents: asyncio.create_task( - self.handle_new_engine_id( - remote_engine_id, remote_dp_rank, pull_metas - ) + self.handle_new_engine_id(remote_engine_id, pull_metas) ) continue - self.receive_kv(remote_engine_id, remote_dp_rank, pull_metas) + self.receive_kv(remote_engine_id, pull_metas) async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): - for p_req_id, block_ids in metadata.reqs_to_send.items(): + for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items(): if block_ids: # Already gone through request_finished() - send_meta = self.reqs_need_send[p_req_id] + send_meta = self.reqs_need_send[transfer_id] send_meta.p_req_id = p_req_id send_meta.local_block_ids = block_ids send_meta.expire_time = ( @@ -1204,14 +1204,15 @@ async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): # but not reach 
request_finished() yet # This may be already created by send_kv_to_decode() # when D is sending MooncakeXferMetadata. - if p_req_id not in self.reqs_need_send: - self.reqs_need_send[p_req_id] = SendBlockMeta( + if transfer_id not in self.reqs_need_send: + self.reqs_need_send[transfer_id] = SendBlockMeta( p_req_id=p_req_id, + transfer_id=transfer_id, local_block_ids=[], ready=asyncio.Event(), ) - for p_req_id in metadata.reqs_not_processed: - send_meta = self.reqs_need_send.pop(p_req_id) + for transfer_id in metadata.reqs_not_processed: + send_meta = self.reqs_need_send.pop(transfer_id) if send_meta: assert not send_meta.ready.is_set() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index 061e651e3811..102aa988f6fd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -3,6 +3,7 @@ import threading import time from collections.abc import MutableMapping +from dataclasses import dataclass import uvicorn from fastapi import FastAPI, HTTPException @@ -25,8 +26,11 @@ class RegisterWorkerPayload(BaseModel): addr: WorkerAddr -# {dp_rank: {tp_rank: {pp_rank: worker_addr}}} -EngineEntry = dict[int, dict[int, dict[int, WorkerAddr]]] +@dataclass +class EngineEntry: + engine_id: EngineId + # {tp_rank: {pp_rank: worker_addr}} + worker_addr: dict[int, dict[int, WorkerAddr]] class MooncakeBootstrapServer: @@ -36,55 +40,7 @@ class MooncakeBootstrapServer: """ def __init__(self, vllm_config: VllmConfig, host: str, port: int): - # Since #30739, dp with non-Moe models are treated as separate worlds. - # Multiple dp ranks may have the same engine id because - # DPEngineCoreProc._init_data_parallel() is not called. - # So we cannot simply use engine id to distinguish dp ranks. - # Instead, we use [engine_id][dp_rank] to double check. - # - # For example, for vllm instance in 2 nodes and each with dp_size==2: - # - # Internal LB (non-Moe models): - # engine_id0 dp_rank=0 - # engine_id0 dp_rank=1 - # engine_id1 dp_rank=2 - # engine_id1 dp_rank=3 - # - # Internal LB (Moe models): - # engine_id0_dp0 dp_rank=0 - # engine_id0_dp1 dp_rank=1 - # engine_id1_dp0 dp_rank=2 - # engine_id1_dp1 dp_rank=3 - # - # Hybrid LB (non-Moe models): - # engine_id0 dp_rank=0 - # engine_id0 dp_rank=1 - # engine_id1 dp_rank=0 * - # engine_id1 dp_rank=1 * - # - # Hybrid LB (Moe models): - # engine_id0_dp0 dp_rank=0 - # engine_id0_dp1 dp_rank=1 - # engine_id1_dp0 dp_rank=0 * - # engine_id1_dp1 dp_rank=1 * - # - # External LB: - # engine_id0 dp_rank=0 - # engine_id1 dp_rank=0 * - # engine_id2 dp_rank=0 * - # engine_id3 dp_rank=0 * - # - # * here we use local dp_rank - - self.workers: dict[EngineId, EngineEntry] = {} - - assert (parallel_config := vllm_config.parallel_config) - dp_size = parallel_config.origin_data_parallel_size - dp_local_size = parallel_config.origin_data_parallel_size_local - self.dp_size = dp_local_size if parallel_config.local_engines_only else dp_size - # We should have these workers registered before serving requests. - self.total_count = parallel_config.world_size * self.dp_size - self.registered_count = 0 + self.workers: dict[int, EngineEntry] = {} self.host = host self.port = port @@ -99,7 +55,7 @@ def __del__(self): def _register_routes(self): # All methods are async. No need to use lock to protect data. 
self.app.post("/register")(self.register_worker) - self.app.get("/query", response_model=dict[EngineId, EngineEntry])(self.query) + self.app.get("/query", response_model=dict[int, EngineEntry])(self.query) def start(self): if self.server_thread: @@ -125,23 +81,25 @@ def shutdown(self): async def register_worker(self, payload: RegisterWorkerPayload): """Handles registration of a prefiller worker.""" - if self.registered_count >= self.total_count: + if payload.dp_rank not in self.workers: + self.workers[payload.dp_rank] = EngineEntry( + engine_id=payload.engine_id, + worker_addr={}, + ) + + dp_entry = self.workers[payload.dp_rank] + if dp_entry.engine_id != payload.engine_id: raise HTTPException( status_code=400, - detail=(f"All {self.total_count} workers have been registered"), + detail=( + f"Engine ID mismatch for dp_rank={payload.dp_rank}: " + f"expected {dp_entry.engine_id}, got {payload.engine_id}" + ), ) - if payload.engine_id not in self.workers: - self.workers[payload.engine_id] = {} - - engine_entry = self.workers[payload.engine_id] - if payload.dp_rank not in engine_entry: - engine_entry[payload.dp_rank] = {} - - dp_entry = engine_entry[payload.dp_rank] - if payload.tp_rank not in dp_entry: - dp_entry[payload.tp_rank] = {} + if payload.tp_rank not in dp_entry.worker_addr: + dp_entry.worker_addr[payload.tp_rank] = {} - tp_entry = dp_entry[payload.tp_rank] + tp_entry = dp_entry.worker_addr[payload.tp_rank] if payload.pp_rank in tp_entry: raise HTTPException( status_code=400, @@ -153,8 +111,8 @@ async def register_worker(self, payload: RegisterWorkerPayload): f"but still want to register at {payload.addr}" ), ) - tp_entry[payload.pp_rank] = payload.addr + tp_entry[payload.pp_rank] = payload.addr logger.debug( "Registered worker: engine_id=%s, dp_rank=%d, tp_rank=%d, pp_rank=%d at %s", payload.engine_id, @@ -164,18 +122,9 @@ async def register_worker(self, payload: RegisterWorkerPayload): payload.addr, ) - self.registered_count += 1 return {"status": "ok"} - async def query(self) -> dict[EngineId, EngineEntry]: - if self.registered_count < self.total_count: - raise HTTPException( - status_code=503, - detail=( - "Workers still registering: " - f"{self.registered_count}/{self.total_count}" - ), - ) + async def query(self) -> dict[int, EngineEntry]: return self.workers diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ad0ec1f121ef..d5e75824d2e3 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -923,12 +923,6 @@ def signal_handler(signum, frame): ) parallel_config.data_parallel_index = dp_rank - parallel_config.origin_data_parallel_size = ( - parallel_config.data_parallel_size - ) - parallel_config.origin_data_parallel_size_local = ( - parallel_config.data_parallel_size_local - ) if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. 
parallel_config.data_parallel_rank = dp_rank From f6770637a20d40576ffe70ea4cd88d44ebe43f52 Mon Sep 17 00:00:00 2001 From: Tianchen Ding Date: Thu, 29 Jan 2026 17:59:13 +0800 Subject: [PATCH 5/7] cleanup and fix Signed-off-by: Tianchen Ding --- .../v1/mooncake/mooncake_connector.py | 4 ++ .../v1/mooncake/mooncake_utils.py | 47 ------------------- 2 files changed, 4 insertions(+), 47 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 7cf8650bdb23..1a384090e7ca 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -1063,6 +1063,10 @@ async def receive_kv_from_single_worker( with make_zmq_socket( self.async_zmq_ctx, worker_addr, zmq.DEALER, bind=False, linger=0 ) as sock: + # If something goes wrong, let P wait timeout first (in asyncio.wait()). + sock.setsockopt( + zmq.RCVTIMEO, (envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + 60) * 1000 + ) await sock.send(encoded_data) while True: ret_msg = await sock.recv() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index 102aa988f6fd..d1a9946709d8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading import time -from collections.abc import MutableMapping from dataclasses import dataclass import uvicorn @@ -126,49 +125,3 @@ async def register_worker(self, payload: RegisterWorkerPayload): async def query(self) -> dict[int, EngineEntry]: return self.workers - - -# Workaround for #27987 -# Drop the last "-{random_uuid():.8}" -# After #32630 or other solution is merged, we can remove this workaround. 
-class TruncatingDict(MutableMapping): - def __init__(self, *args, **kwargs): - self._store = dict() - self.update(dict(*args, **kwargs)) - - @staticmethod - def _truncate_key(key): - if not isinstance(key, str) or len(key) < 10: - raise TypeError("Keys must be strings with at least 10 characters") - return key[:-9] - - def __setitem__(self, key, value): - truncated_key = self._truncate_key(key) - self._store[truncated_key] = value - - def __getitem__(self, key): - truncated_key = self._truncate_key(key) - return self._store[truncated_key] - - def __delitem__(self, key): - truncated_key = self._truncate_key(key) - del self._store[truncated_key] - - def __contains__(self, key): - truncated_key = self._truncate_key(key) - return truncated_key in self._store - - def __iter__(self): - return iter(self._store) - - def __len__(self): - return len(self._store) - - def values(self): - return self._store.values() - - def items(self): - return self._store.items() - - def __repr__(self): - return f"{type(self).__name__}({self._store})" From 2bf31b039828938f207b1d69a7fa1d9f29369152 Mon Sep 17 00:00:00 2001 From: dtc Date: Thu, 29 Jan 2026 18:32:38 +0800 Subject: [PATCH 6/7] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Nicolò Lucchesi Signed-off-by: Tianchen Ding --- .../kv_connector/v1/mooncake/mooncake_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 1a384090e7ca..b00efbb36ea9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -118,7 +118,7 @@ class SendBlockMeta: ready: asyncio.Event expire_time: float = float("inf") need_send: int = 0 - sended: int = 0 + sent: int = 0 sending: int = 0 @@ -802,7 +802,7 @@ async def wait_and_ret( continue for d_req_id, send_meta in ready_reqs: - # Todo: for heterogeneous TP (one P pairs to multiple D), + # TODO: for heterogeneous TP (one P pairs to multiple D), # we need to check whether all headers are sent. # If not, we should set expire_time to normal and skip the below. 
send_meta.sending -= 1 @@ -1189,8 +1189,8 @@ async def _start_load_kv( asyncio.create_task( self.handle_new_engine_id(remote_engine_id, pull_metas) ) - continue - self.receive_kv(remote_engine_id, pull_metas) + else: + self.receive_kv(remote_engine_id, pull_metas) async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items(): From 998877182d8c92f3414af168f938fa6ca3a3e45c Mon Sep 17 00:00:00 2001 From: Tianchen Ding Date: Thu, 29 Jan 2026 18:52:47 +0800 Subject: [PATCH 7/7] fix Signed-off-by: Tianchen Ding --- .../v1/mooncake/mooncake_connector.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index b00efbb36ea9..616caef4b39c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -350,11 +350,7 @@ def update_state_after_alloc( elif params.get("do_remote_decode"): assert not self.is_kv_consumer if not params.get("transfer_id"): - logger.warning( - "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", - params, - ) + logger.warning("Missing transfer_id in kv_transfer_params from router!") else: # Add an empty list to worker to create event. self._reqs_need_send[request.request_id] = (request, []) @@ -507,10 +503,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.device_kv_caches: dict[str, torch.Tensor] = {} self.reqs_need_send: dict[TransferId, SendBlockMeta] = {} - # Only used by prefillers. - host, port = get_mooncake_bootstrap_addr(vllm_config) - self.bootstrap_addr = make_zmq_path("http", host, port) - # For kv_both, we will act both prefiller and decoder. if not self.is_kv_consumer: # Background threads for sending kvcaches to D. @@ -533,6 +525,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Start bootstrap server on global rank 0. if should_launch_bootstrap_server(vllm_config): + _, port = get_mooncake_bootstrap_addr(vllm_config) self.bootstrap_server = MooncakeBootstrapServer( vllm_config, "0.0.0.0", port ) @@ -597,7 +590,8 @@ def shutdown(self): self._mooncake_receiver_t.join() async def register_worker_with_bootstrap(self): - url = self.bootstrap_addr + "/register" + host, port = get_mooncake_bootstrap_addr(self.vllm_config) + url = make_zmq_path("http", host, port) + "/register" worker_addr = make_zmq_path("tcp", self.hostname, self.side_channel_port) payload = RegisterWorkerPayload( engine_id=self.engine_id, @@ -735,6 +729,9 @@ async def wait_and_ret( # Timeout, abort all pending requests. for task in wait_tasks: task.cancel() + logger.warning( + "Timeout waiting for P side ready: %s", list(pending_reqs) + ) response = MooncakeXferResponse( status=MooncakeXferResponseStatus.FINISH, err_reqs=list(pending_reqs), @@ -760,7 +757,11 @@ async def wait_and_ret( if not send_meta.need_send: self.resolve_need_send(send_meta, remote_tp_ranks) ready_reqs.append((d_req_id, send_meta)) - # Otherwise (expired, very unlikely), forget it. Do not let D retry. + else: + # Otherwise (expired, very unlikely), just forget it. 
+ logger.warning( + "Request %s expired before sending on P side.", d_req_id + ) src_ptrs, dst_ptrs, lengths, err_reqs = await self._build_transfer_params( ready_reqs, meta @@ -806,8 +807,8 @@ async def wait_and_ret( # we need to check whether all headers are sent. # If not, we should set expire_time to normal and skip the below. send_meta.sending -= 1 - send_meta.sended += 1 - if send_meta.sended == send_meta.need_send: + send_meta.sent += 1 + if send_meta.sent == send_meta.need_send: del self.reqs_need_send[send_meta.transfer_id] self.finished_sending_reqs.add(send_meta.p_req_id) @@ -822,6 +823,9 @@ def resolve_need_send(self, send_meta: SendBlockMeta, remote_tp_ranks: list[int] send_meta.need_send = len(remote_tp_ranks) if send_meta.need_send != 1: logger.error("Mooncake: Heterogeneous TP is not supported yet.") + raise NotImplementedError( + "Mooncake: Heterogeneous TP is not supported yet." + ) async def _build_transfer_params( self, @@ -1150,7 +1154,9 @@ def receive_kv( count = len(remote_tp_ranks) if count != 1: logger.error("Mooncake: Heterogeneous TP is not supported yet.") - return + raise NotImplementedError( + "Mooncake: Heterogeneous TP is not supported yet." + ) for pull_meta in pull_metas.values(): pull_meta.pull_tasks_count = count for remote_tp_rank in remote_tp_ranks: @@ -1190,7 +1196,7 @@ async def _start_load_kv( self.handle_new_engine_id(remote_engine_id, pull_metas) ) else: - self.receive_kv(remote_engine_id, pull_metas) + self.receive_kv(remote_engine_id, pull_metas) async def record_send_reqs(self, metadata: MooncakeConnectorMetadata): for p_req_id, (transfer_id, block_ids) in metadata.reqs_to_send.items():
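A closing note on the thread running through patches 4 to 7: coordination between P and D is keyed on a proxy-minted `transfer_id` instead of on scheduler request IDs, which differ between the two instances. A data-only sketch of the handoff; every concrete value below is illustrative:

```python
# Sketch of the transfer_id handoff; all values are made up.
import uuid

request_id = uuid.uuid4().hex
transfer_id = f"xfer-{request_id}"  # minted once by the proxy

# Proxy -> prefiller (prefill pass): P indexes reqs_need_send by this ID.
prefill_params = {
    "do_remote_decode": True,
    "do_remote_prefill": False,
    "transfer_id": transfer_id,
}

# Proxy -> decoder (decode pass), issued after the prefill response returns.
decode_params = {
    "do_remote_prefill": True,
    "remote_bootstrap_addr": "http://192.168.0.2:8998",
    "remote_engine_id": "engine-0",
    "transfer_id": transfer_id,
}

# Decoder -> prefiller worker: block lists keyed by the decoder's own request
# ID, each carrying the shared transfer_id for lookup on the producer side.
req_blocks = {"d-req-123": (transfer_id, [0, 1, 2, 3])}
```

Because both sides quote the same ID, the producer can match an incoming pull against its `reqs_need_send` table even though its local request ID for the prefill differs from the decoder's.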