diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
new file mode 100644
index 00000000000..3af8fbe02db
--- /dev/null
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
@@ -0,0 +1,518 @@
+# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
+
+# SPDX-License-Identifier: Apache-2.0
+#
+# Tutorial: Using the Load Balance Proxy Server Example
+#
+# This proxy server distributes requests across multiple "prefiller" and
+# "decoder" backend servers for large language model inference. It is useful
+# for scaling out inference workloads and balancing load across multiple
+# backend instances.
+#
+# Features:
+# - Load balances requests across multiple prefiller and decoder servers.
+# - Supports the OpenAI-compatible /v1/completions and /v1/chat/completions
+#   endpoints.
+# - Streams responses from the backend servers to clients.
+#
+# Prerequisites:
+# - Python 3.8+
+# - Install dependencies:
+#     pip install fastapi httpx uvicorn vllm
+#
+# Step 1: Start Your Backend Servers
+# ----------------------------------
+# You need at least one prefiller and one decoder backend running. These can
+# be mock servers or actual vLLM servers, for example:
+#
+#   vllm serve --host 0.0.0.0 --port 8100 ...   # Prefiller 1
+#   vllm serve --host 0.0.0.0 --port 8101 ...   # Prefiller 2
+#   vllm serve --host 0.0.0.0 --port 8200 ...   # Decoder 1
+#   vllm serve --host 0.0.0.0 --port 8201 ...   # Decoder 2
+#
+# Step 2: Start the Proxy Server
+# ------------------------------
+# Run the proxy server, specifying the host/port of each prefiller and decoder:
+#
+#   python load_balance_proxy_server_example.py \
+#       --host 0.0.0.0 --port 9000 \
+#       --prefiller-hosts 127.0.0.1 127.0.0.1 \
+#       --prefiller-ports 8100 8101 \
+#       --decoder-hosts 127.0.0.1 127.0.0.1 \
+#       --decoder-ports 8200 8201
+#
+# This starts the proxy on port 9000, load balancing between two prefiller
+# and two decoder servers.
+#
+# Step 3: Send a Request to the Proxy
+# -----------------------------------
+# You can now send OpenAI-compatible requests to the proxy. For example:
+#
+#   curl -X POST http://localhost:9000/v1/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{
+#       "model": "your-model",
+#       "prompt": "The quick brown fox jumps over the lazy dog",
+#       "max_tokens": 16
+#     }'
+#
+# Or for chat completions:
+#
+#   curl -X POST http://localhost:9000/v1/chat/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{
+#       "model": "your-model",
+#       "messages": [{"role": "user", "content": "Hello!"}],
+#       "max_tokens": 16
+#     }'
+#
+# Step 4: Health Check
+# --------------------
+# To check whether the proxy is running and see how many backend instances
+# are connected, use:
+#
+#   curl http://localhost:9000/healthcheck
+#
+# This returns a JSON object with the status and the number of prefiller and
+# decoder instances.
+#
+# Notes:
+# - You can scale the number of prefiller and decoder servers as needed.
+# - The proxy routes each request to the least-loaded prefiller and decoder
+#   instance rather than using simple round-robin.
+# - For production, ensure your backend servers are robust and secure.
+#
+# For more details, see the code and comments in this file.
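+#
+# Example: Querying the Proxy from Python
+# ---------------------------------------
+# A minimal sketch, assuming the `openai` Python client (>= 1.0) is installed
+# and that "your-model" matches the model served by your backends; the api_key
+# is a placeholder, since the proxy itself does not validate it:
+#
+#   from openai import OpenAI
+#
+#   client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+#   completion = client.completions.create(
+#       model="your-model",
+#       prompt="The quick brown fox jumps over the lazy dog",
+#       max_tokens=16,
+#   )
+#   print(completion.choices[0].text)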
+
+
+import argparse
+import asyncio
+import heapq
+import os
+import sys
+from contextlib import asynccontextmanager
+from typing import List
+
+import httpx
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+# Use uvloop for a faster event loop if it is available
+try:
+    import uvloop
+    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+except ImportError:
+    pass
+
+
+class ServerState:
+
+    def __init__(self, host, port):
+        self.host = host
+        self.port = port
+        self.url = f'http://{host}:{port}/v1'
+        self.client = httpx.AsyncClient(timeout=None,
+                                        base_url=self.url,
+                                        limits=httpx.Limits(
+                                            max_connections=100000,
+                                            max_keepalive_connections=100000))
+        self.active_tokens = 0
+        self.active_kv_cache = 0  # Only for prefillers
+        self.active_requests = 0  # Number of active requests
+        self.aborted_requests = set()  # Track aborted requests
+        # No per-server lock: all state updates happen on the event loop thread
+
+
+class ProxyState:
+
+    def __init__(self, prefiller_instances, decoder_instances):
+        self.prefillers: List[ServerState] = [
+            ServerState(h, p) for h, p in prefiller_instances
+        ]
+        self.decoders: List[ServerState] = [
+            ServerState(h, p) for h, p in decoder_instances
+        ]
+        self.req_to_prefiller = {}
+        self.req_id_lock = asyncio.Lock()
+        self.req_id_counter = 0
+        # Selection is synchronous, so no selection locks are needed
+
+        # Initialize priority queues for efficient server selection.
+        # Each entry is (priority_score, server_index, server_reference);
+        # a lower priority score means a higher priority (less loaded).
+        self.prefiller_heap = [(0, i, server)
+                               for i, server in enumerate(self.prefillers)]
+        self.decoder_heap = [(0, i, server)
+                             for i, server in enumerate(self.decoders)]
+        heapq.heapify(self.prefiller_heap)
+        heapq.heapify(self.decoder_heap)
+
+    def _update_prefiller_priority(self, server_idx: int):
+        """Update the priority of a prefiller server in the heap."""
+        server = self.prefillers[server_idx]
+        # Priority based on active_tokens and active_kv_cache
+        priority = server.active_tokens + server.active_kv_cache * 0.3
+        # Remove old entry and add new one
+        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
+                               if i != server_idx]
+        heapq.heappush(self.prefiller_heap,
+                       (priority, server_idx, server))  # type: ignore
+
+    def _update_decoder_priority(self, server_idx: int):
+        """Update the priority of a decoder server in the heap."""
+        server = self.decoders[server_idx]
+        priority = server.active_tokens
+        # Remove old entry and add new one
+        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
+                             if i != server_idx]
+        heapq.heappush(self.decoder_heap,
+                       (priority, server_idx, server))  # type: ignore
+
+    def abort_prefiller_request(self, server_idx: int, request_id):
+        """
+        Mark a request as aborted so that the prefiller node can release
+        its KV cache.
+        """
+        # No lock needed - atomic operation
+        self.prefillers[server_idx].aborted_requests.add(request_id)
+
+    def acquire_aborted_prefiller_requests(self, server_idx: int):
+        """
+        Return the set of aborted requests and clear it. The prefiller uses
+        this information to release the corresponding KV cache.
+ """ + # No lock needed - atomic operation + aborted_requests = self.prefillers[server_idx].aborted_requests.copy() + self.prefillers[server_idx].aborted_requests.clear() + return aborted_requests + + async def next_req_id(self): + async with self.req_id_lock: + self.req_id_counter += 1 + return str(self.req_id_counter) + + def select_prefiller(self, token_count): # Changed to synchronous + # No lock needed - entire function is atomic + if not self.prefiller_heap: + raise RuntimeError("No prefiller servers available") + + priority, chosen, server = heapq.heappop(self.prefiller_heap) + + # Update the chosen server atomically + self.prefillers[chosen].active_tokens += token_count + self.prefillers[chosen].active_kv_cache += token_count + + # Update priority and re-add to heap + self._update_prefiller_priority(chosen) + + return chosen + + def release_prefiller(self, idx, token_count): # Changed to synchronous + # No lock needed - atomic operation + self.prefillers[idx].active_tokens -= token_count + # Update priority queue after releasing + self._update_prefiller_priority(idx) + + def release_prefiller_kv(self, idx, token_count): # Changed to synchronous + # No lock needed - atomic operation + if self.prefillers[idx].active_kv_cache > 0: + self.prefillers[idx].active_kv_cache -= token_count + # Update priority queue after releasing + self._update_prefiller_priority(idx) + + def select_decoder(self, token_count): # Changed to synchronous + # No lock needed - entire function is atomic + if not self.decoder_heap: + raise RuntimeError("No decoder servers available") + + priority, chosen, server = heapq.heappop(self.decoder_heap) + + # Update the chosen server atomically + self.decoders[chosen].active_tokens += token_count + + # Update priority and re-add to heap + self._update_decoder_priority(chosen) + + return chosen + + def release_decoder(self, idx, token_count): # Changed to synchronous + # No lock needed - atomic operation + self.decoders[idx].active_tokens -= token_count + # Update priority queue after releasing + self._update_decoder_priority(idx) + + # Omni_infer's calculate_input_scores function + def calculate_prefill_scores(self, request_length: int) -> float: + length_score = request_length / 4.0 + input_score = length_score * 0.0345 + 120.0745 + return input_score + + def calculate_decode_scores(self, request_length: int) -> float: + return request_length + + +proxy_state = None + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-hosts", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--prefiller-ports", + type=int, + nargs="+", + default=[8001]) + parser.add_argument("--decoder-hosts", + type=str, + nargs="+", + default=["localhost"]) + parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002]) + parser.add_argument("--max-retries", + type=int, + default=3, + help="Maximum number of retries for HTTP requests") + parser.add_argument( + "--retry-delay", + type=float, + default=0.001, + help="Base delay (seconds) for exponential backoff retries") + args = parser.parse_args() + if len(args.prefiller_hosts) != len(args.prefiller_ports): + raise ValueError( + "Number of prefiller hosts must match number of prefiller ports") + if len(args.decoder_hosts) != len(args.decoder_ports): + raise ValueError( + "Number of decoder hosts must match number of decoder ports") + 
+    args.prefiller_instances = list(
+        zip(args.prefiller_hosts, args.prefiller_ports))
+    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
+    return args
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global proxy_state
+    proxy_state = ProxyState(global_args.prefiller_instances,
+                             global_args.decoder_instances)
+    print(f"Initialized {len(proxy_state.prefillers)} prefill clients "
+          f"and {len(proxy_state.decoders)} decode clients.")
+    yield
+    for p in proxy_state.prefillers:
+        await p.client.aclose()
+    for d in proxy_state.decoders:
+        await d.client.aclose()
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+async def send_request_to_service(client: httpx.AsyncClient,
+                                  prefiller_id: int,
+                                  endpoint: str,
+                                  req_data: dict,
+                                  request_id: str,
+                                  max_retries: int = 3,
+                                  base_delay: float = 0.2):
+    aborted_requests = proxy_state.acquire_aborted_prefiller_requests(
+        prefiller_id)
+    req_data = req_data.copy()
+    req_data['kv_transfer_params'] = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+        "aborted_request": list(aborted_requests),
+    }
+    req_data["stream"] = False
+    req_data["max_tokens"] = 1
+    if "stream_options" in req_data:
+        del req_data["stream_options"]
+    headers = {
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        "X-Request-Id": request_id
+    }
+    last_exc = None
+    for attempt in range(1, max_retries + 1):
+        try:
+            response = await client.post(endpoint,
+                                         json=req_data,
+                                         headers=headers)
+            response.raise_for_status()
+            return response
+        except (httpx.RequestError, httpx.HTTPStatusError) as e:
+            logger.warning(
+                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
+            last_exc = e
+            if attempt < max_retries:
+                await asyncio.sleep(base_delay * (2**(attempt - 1)))
+            else:
+                logger.error(
+                    f"All {max_retries} attempts failed for {endpoint}.")
+                raise last_exc
+
+
+async def stream_service_response_with_retry(client: httpx.AsyncClient,
+                                             endpoint: str,
+                                             req_data: dict,
+                                             request_id: str,
+                                             max_retries: int = 3,
+                                             base_delay: float = 0.2):
+    headers = {
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+        "X-Request-Id": request_id
+    }
+    for attempt in range(1, max_retries + 1):
+        try:
+            async with client.stream("POST",
+                                     endpoint,
+                                     json=req_data,
+                                     headers=headers) as response:
+                response.raise_for_status()
+                first_chunk_sent = False
+                async for chunk in response.aiter_bytes():
+                    first_chunk_sent = True
+                    yield chunk
+                return  # Success, exit after streaming
+        except (httpx.RequestError, httpx.HTTPStatusError) as e:
+            if attempt < max_retries:
+                logger.warning(
+                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
+                )
+                await asyncio.sleep(base_delay * (2**(attempt - 1)))
+            else:
+                logger.error(
+                    f"All {max_retries} attempts failed for streaming {endpoint}."
+                )
+                raise e
+        except Exception as e:
+            # If any chunk has already been sent, do not retry; just log and stop
+            if 'first_chunk_sent' in locals() and first_chunk_sent:
+                logger.error(
+                    f"Streaming to client interrupted after response started: {str(e)}"
+                )
+                return
+            else:
+                if attempt < max_retries:
+                    logger.warning(
+                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
+                    )
+                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
+                else:
+                    logger.error(
+                        f"All {max_retries} attempts failed for streaming {endpoint}."
+                    )
+                    raise e
+
+
+async def _handle_completions(api: str, request: Request):
+    try:
+        req_data = await request.json()
+        req_body = await request.body()
+        request_length = len(req_body)
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+        request_id = await proxy_state.next_req_id()
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        # Send request to prefiller
+        response = await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+        response_json = response.json()
+        kv_transfer_params = response_json.get('kv_transfer_params', {})
+        if kv_transfer_params:
+            req_data["kv_transfer_params"] = kv_transfer_params
+        # Select decoder
+        decoder_score = proxy_state.calculate_decode_scores(request_length)
+        logger.debug("Decoder score: %f", decoder_score)
+        # The prefiller's kv_transfer_params are forwarded to the selected decoder
+        decoder_idx = proxy_state.select_decoder(decoder_score)
+        decoder = proxy_state.decoders[decoder_idx]
+        logger.debug("Using %s %s", prefiller.url, decoder.url)
+        # Stream response from decoder
+        released_kv = False
+
+        async def generate_stream():
+            nonlocal released_kv
+            # Only one await per chunk, minimal logic in the loop
+            try:
+                async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay):
+                    if not released_kv and chunk:
+                        proxy_state.release_prefiller_kv(
+                            prefiller_idx, prefiller_score)
+                        released_kv = True
+                    yield chunk
+            except Exception as e:
+                logger.error(
+                    f"Error during streaming from decoder {decoder.url}: {str(e)}. "
+                    f"The aborted request {request_id} will be reported to its "
+                    f"prefiller when the next request is dispatched to that server.")
+                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
+                proxy_state.release_prefiller_kv(prefiller_idx,
+                                                 prefiller_score)
+
+            # After streaming is done, release the decoder tokens
+            proxy_state.release_decoder(decoder_idx, decoder_score)
+
+        return StreamingResponse(generate_stream(),
+                                 media_type="application/json")
+    except Exception as e:
+        import traceback
+        exc_info = sys.exc_info()
+        print("Error occurred in disagg prefill proxy server"
+              f" - {api} endpoint")
+        print(e)
+        print("".join(traceback.format_exception(*exc_info)))
+        raise
+
+
+@app.post("/v1/completions")
+async def handle_completions(request: Request):
+    return await _handle_completions("/completions", request)
+
+
+@app.post("/v1/chat/completions")
+async def handle_chat_completions(request: Request):
+    return await _handle_completions("/chat/completions", request)
+
+
+@app.get("/healthcheck")
+async def healthcheck():
+    return {
+        "status": "ok",
+        "prefill_instances": len(proxy_state.prefillers),
+        "decode_instances": len(proxy_state.decoders)
+    }
+
+
+if __name__ == '__main__':
+    global global_args
+    global_args = parse_args()
+    import uvicorn
+    uvicorn.run(app, host=global_args.host, port=global_args.port)
diff --git a/examples/disaggregated_prefill_v1/toy_proxy_server.py b/examples/disaggregated_prefill_v1/toy_proxy_server.py
deleted file mode 100644
index 2e26d0aee22..00000000000
--- a/examples/disaggregated_prefill_v1/toy_proxy_server.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
-
-# SPDX-License-Identifier: Apache-2.0
-
-import argparse
-import itertools
-import os
-import uuid
-from contextlib import asynccontextmanager
-
-import httpx
-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """
-    Lifespan context manager to handle startup and shutdown events.
-    """
-    # Startup: Initialize client pools for prefiller and decoder services
-    app.state.prefill_clients = []
-    app.state.decode_clients = []
-    limit = httpx.Limits(max_connections=100000,
-                         max_keepalive_connections=100000)
-
-    # Create prefill clients
-    for i, (host, port) in enumerate(global_args.prefiller_instances):
-        prefiller_base_url = f'http://{host}:{port}/v1'
-        app.state.prefill_clients.append({
-            'client':
-            httpx.AsyncClient(timeout=None,
-                              base_url=prefiller_base_url,
-                              limits=limit),
-            'host':
-            host,
-            'port':
-            port,
-            'id':
-            i
-        })
-
-    # Create decode clients
-    for i, (host, port) in enumerate(global_args.decoder_instances):
-        decoder_base_url = f'http://{host}:{port}/v1'
-        app.state.decode_clients.append({
-            'client':
-            httpx.AsyncClient(timeout=None,
-                              base_url=decoder_base_url,
-                              limits=limit),
-            'host':
-            host,
-            'port':
-            port,
-            'id':
-            i
-        })
-
-    # Initialize round-robin iterators
-    app.state.prefill_iterator = itertools.cycle(
-        range(len(app.state.prefill_clients)))
-    app.state.decode_iterator = itertools.cycle(
-        range(len(app.state.decode_clients)))
-
-    print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
-          f"and {len(app.state.decode_clients)} decode clients.")
-
-    yield
-
-    # Shutdown: Close all clients
-    for client_info in app.state.prefill_clients:
-        await client_info['client'].aclose()
-
-    for client_info in app.state.decode_clients:
-        await client_info['client'].aclose()
-
-
-# Update FastAPI app initialization to use lifespan
-app = FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--host", type=str, default="localhost")
-
-    # For prefiller instances
-    parser.add_argument("--prefiller-hosts",
-                        "--prefiller-host",
-                        type=str,
-                        nargs="+",
-                        default=["localhost"])
-    parser.add_argument("--prefiller-ports",
-                        "--prefiller-port",
-                        type=int,
-                        nargs="+",
-                        default=[8100])
-
-    # For decoder instances
-    parser.add_argument("--decoder-hosts",
-                        "--decoder-host",
-                        type=str,
-                        nargs="+",
-                        default=["localhost"])
-    parser.add_argument("--decoder-ports",
-                        "--decoder-port",
-                        type=int,
-                        nargs="+",
-                        default=[8200])
-
-    args = parser.parse_args()
-
-    # Validate and pair hosts with ports
-    if len(args.prefiller_hosts) != len(args.prefiller_ports):
-        raise ValueError(
-            "Number of prefiller hosts must match number of prefiller ports")
-
-    if len(args.decoder_hosts) != len(args.decoder_ports):
-        raise ValueError(
-            "Number of decoder hosts must match number of decoder ports")
-
-    # Create tuples of (host, port) for each service type
-    args.prefiller_instances = list(
-        zip(args.prefiller_hosts, args.prefiller_ports))
-    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
-
-    return args
-
-
-def get_next_client(app, service_type: str):
-    """
-    Get the next client in round-robin fashion.
-
-    Args:
-        app: The FastAPI app instance
-        service_type: Either 'prefill' or 'decode'
-
-    Returns:
-        The next client to use
-    """
-    if service_type == 'prefill':
-        client_idx = next(app.state.prefill_iterator)
-        return app.state.prefill_clients[client_idx]
-    elif service_type == 'decode':
-        client_idx = next(app.state.decode_iterator)
-        return app.state.decode_clients[client_idx]
-    else:
-        raise ValueError(f"Unknown service type: {service_type}")
-
-
-async def send_request_to_service(client_info: dict, endpoint: str,
-                                  req_data: dict, request_id: str):
-    """
-    Send a request to a service using a client from the pool.
-    """
-    req_data = req_data.copy()
-    req_data['kv_transfer_params'] = {
-        "do_remote_decode": True,
-        "do_remote_prefill": False,
-        "remote_engine_id": None,
-        "remote_block_ids": None,
-        "remote_host": None,
-        "remote_port": None
-    }
-    req_data["stream"] = False
-    req_data["max_tokens"] = 1
-    if "stream_options" in req_data:
-        del req_data["stream_options"]
-    headers = {
-        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-        "X-Request-Id": request_id
-    }
-
-    response = await client_info['client'].post(endpoint,
-                                                json=req_data,
-                                                headers=headers)
-    response.raise_for_status()
-
-    return response
-
-
-async def stream_service_response(client_info: dict, endpoint: str,
-                                  req_data: dict, request_id: str):
-    """
-    Asynchronously stream response from a service using a client from the pool.
-    """
-    headers = {
-        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-        "X-Request-Id": request_id
-    }
-
-    async with client_info['client'].stream("POST",
-                                            endpoint,
-                                            json=req_data,
-                                            headers=headers) as response:
-        response.raise_for_status()
-        async for chunk in response.aiter_bytes():
-            yield chunk
-
-
-async def _handle_completions(api: str, request: Request):
-    try:
-        req_data = await request.json()
-        request_id = str(uuid.uuid4())
-
-        # Get the next prefill client in round-robin fashion
-        prefill_client_info = get_next_client(request.app, 'prefill')
-
-        # Send request to prefill service
-        response = await send_request_to_service(prefill_client_info, api,
-                                                 req_data, request_id)
-
-        # Extract the needed fields
-        response_json = response.json()
-        kv_transfer_params = response_json.get('kv_transfer_params', {})
-        if kv_transfer_params:
-            req_data["kv_transfer_params"] = kv_transfer_params
-
-        # Get the next decode client in round-robin fashion
-        decode_client_info = get_next_client(request.app, 'decode')
-
-        logger.debug("Using %s %s", prefill_client_info, decode_client_info)
-
-        # Stream response from decode service
-        async def generate_stream():
-            async for chunk in stream_service_response(decode_client_info,
-                                                       api,
-                                                       req_data,
-                                                       request_id=request_id):
-                yield chunk
-
-        return StreamingResponse(generate_stream(),
-                                 media_type="application/json")
-
-    except Exception as e:
-        import sys
-        import traceback
-        exc_info = sys.exc_info()
-        print("Error occurred in disagg prefill proxy server"
-              f" - {api} endpoint")
-        print(e)
-        print("".join(traceback.format_exception(*exc_info)))
-        raise
-
-
-@app.post("/v1/completions")
-async def handle_completions(request: Request):
-    return await _handle_completions("/completions", request)
-
-
-@app.post("/v1/chat/completions")
-async def handle_chat_completions(request: Request):
-    return await _handle_completions("/chat/completions", request)
-
-
-@app.get("/healthcheck")
-async def healthcheck():
-    """Simple endpoint to check if the server is running."""
-    return {
-        "status": "ok",
-        "prefill_instances": len(app.state.prefill_clients),
-        "decode_instances": len(app.state.decode_clients)
-    }
-
-
-if __name__ == '__main__':
-    global global_args
-    global_args = parse_args()
-
-    import uvicorn
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
\ No newline at end of file