17 changes: 17 additions & 0 deletions vllm/entrypoints/cli/serve.py
@@ -17,6 +17,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import get_mp_context
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor import Executor
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
@@ -112,6 +113,10 @@ def cmd(args: argparse.Namespace) -> None:
            )
            args.api_server_count = 1

        # Auto-enable server load tracking when max_unfinished_requests is set
        if args.max_unfinished_requests is not None:
            args.enable_server_load_tracking = True

        if args.api_server_count < 1:
            run_headless(args)
        elif args.api_server_count > 1:
@@ -276,6 +281,18 @@ def signal_handler(signum, frame):

    addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)

    # Create shared memory for tracking unfinished_requests across API servers
    # Use mp context to match APIServerProcessManager
    if args.max_unfinished_requests is not None and num_api_servers > 1:
        mp_context = get_mp_context()
        shared_unfinished_requests = mp_context.Array("i", num_api_servers, lock=False)
        # Initialize all slots to 0
        for i in range(num_api_servers):
            shared_unfinished_requests[i] = 0
        args.shared_unfinished_requests = shared_unfinished_requests
    else:
        args.shared_unfinished_requests = None

    with launch_core_engines(
        vllm_config, executor_class, log_stats, addresses, num_api_servers
    ) as (local_engine_manager, coordinator, addresses, tensor_queue):
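As context for the hunk above: a standalone sketch (not part of the diff) of how a lock-free multiprocessing array behaves. Each API server writes only its own slot, so no lock is needed; the `Array("i", n, lock=False)` call and the slot-per-server layout mirror the code above, while every other name here is illustrative.

```python
# Minimal sketch, assuming a "spawn" mp context (what vLLM's get_mp_context
# typically returns on most platforms); report_load and the loads are made up.
import multiprocessing as mp


def report_load(slots, index: int, load: int) -> None:
    # Each process owns exactly one slot, so lock-free writes are safe;
    # concurrent readers may see a slightly stale sum, which is acceptable
    # for admission control.
    slots[index] = load


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    slots = ctx.Array("i", 4, lock=False)  # C ints, zero-initialized
    procs = [
        ctx.Process(target=report_load, args=(slots, i, (i + 1) * 10))
        for i in range(4)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(sum(slots))  # 100
```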
25 changes: 24 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -499,6 +499,10 @@ async def init_render_app_state(
    state.args = args
    state.enable_server_load_tracking = False
    state.server_load_metrics = 0
    # max_unfinished_requests not applicable for render server
    state.max_unfinished_requests = None
    state.shared_unfinished_requests = None
    state.server_index = 0


def create_server_socket(addr: tuple[str, int]) -> socket.socket:
@@ -588,6 +592,7 @@ async def build_and_serve(
    listen_address: str,
    sock: socket.socket,
    args: Namespace,
    client_config: dict | None = None,
    **uvicorn_kwargs,
) -> asyncio.Task:
    """Build FastAPI app, initialize state, and start serving.
@@ -605,6 +610,19 @@

    logger.info("Supported tasks: %s", supported_tasks)
    app = build_app(args, supported_tasks, model_config)

    # Set server tracking fields from client_config
    if client_config is not None:
        app.state.server_index = client_config.get("client_index", 0)
        unfinished_arr = client_config.get("shared_unfinished_requests")
        app.state.shared_unfinished_requests = unfinished_arr
    else:
        app.state.server_index = 0
        app.state.shared_unfinished_requests = None

    # max_unfinished_requests is always from args
    app.state.max_unfinished_requests = args.max_unfinished_requests

    await init_app_state(engine_client, app.state, args, supported_tasks)

    logger.info("Starting vLLM server on %s", listen_address)
@@ -702,7 +720,12 @@ async def run_server_worker(
        client_config=client_config,
    ) as engine_client:
        shutdown_task = await build_and_serve(
            engine_client, listen_address, sock, args, **uvicorn_kwargs
            engine_client,
            listen_address,
            sock,
            args,
            client_config=client_config,
            **uvicorn_kwargs,
        )
        # NB: Await server shutdown only after the backend context is exited
        try:
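To make the `app.state` plumbing above concrete, here is a hedged, standalone FastAPI sketch: the fields are seeded once before serving and read per-request, with `client_index` defaulting to 0 for a single server. The route and values are invented for demonstration and are not part of this PR.

```python
# Illustrative only: mirrors how build_and_serve copies client_config
# entries onto app.state; the /whoami endpoint is made up.
from fastapi import FastAPI, Request

app = FastAPI()

client_config = {"client_index": 2}  # hypothetical, as passed by the parent
app.state.server_index = client_config.get("client_index", 0)
app.state.shared_unfinished_requests = client_config.get(
    "shared_unfinished_requests"
)
app.state.max_unfinished_requests = None  # would come from args


@app.get("/whoami")
async def whoami(raw_request: Request) -> dict:
    # Handlers read the same state object that the decorator in
    # vllm/entrypoints/utils.py inspects.
    return {"server_index": raw_request.app.state.server_index}
```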
9 changes: 9 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -136,6 +136,11 @@ class BaseFrontendArgs:
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
max_unfinished_requests: int | None = None
"""Maximum number of unfinished requests allowed across all API servers.
When the total unfinished requests exceeds this value, new requests
are rejected with a 503 error. Uses shared memory when multiple
API servers are running."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
@@ -375,6 +380,10 @@ def validate_parsed_serve_args(args: argparse.Namespace):
    if args.enable_log_outputs and not args.enable_log_requests:
        raise TypeError("Error: --enable-log-outputs requires --enable-log-requests")

    # max_unfinished_requests must be positive if set
    if args.max_unfinished_requests is not None and args.max_unfinished_requests <= 0:
        raise ValueError("Error: --max-unfinished-requests must be > 0")


def create_parser_for_docs() -> FlexibleArgumentParser:
    parser_for_docs = FlexibleArgumentParser(
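Operationally, the field docstring above means clients can receive 503s while the server is saturated. A hedged client-side sketch of one way to cope; the URL, payload, and retry policy are illustrative and not part of this PR.

```python
# Assumes the third-party "requests" package; endpoint and payload are made up.
import time

import requests


def post_with_backoff(url: str, payload: dict, retries: int = 3):
    resp = None
    for attempt in range(retries):
        resp = requests.post(url, json=payload)
        if resp.status_code != 503:
            return resp  # admitted (or failed for a non-overload reason)
        time.sleep(2**attempt)  # exponential backoff before retrying
    return resp
```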
33 changes: 33 additions & 0 deletions vllm/entrypoints/utils.py
@@ -102,6 +102,20 @@ def decrement_server_load(request: Request):
    request.app.state.server_load_metrics -= 1


def _make_overloaded_response():
    return JSONResponse(
        status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
        content=ErrorResponse(
            error=ErrorInfo(
                message="Server is overloaded",
                type="ServiceUnavailableError",
                code=HTTPStatus.SERVICE_UNAVAILABLE.value,
                param=None,
            )
        ).model_dump(),
    )
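For reference, the helper above should serialize to roughly this body (a sketch inferred from the ErrorResponse/ErrorInfo fields used in the code, not captured server output):

```python
# Approximate JSON body of the 503 built by _make_overloaded_response.
overloaded_body = {
    "error": {
        "message": "Server is overloaded",
        "type": "ServiceUnavailableError",
        "code": 503,  # HTTPStatus.SERVICE_UNAVAILABLE.value
        "param": None,
    }
}
```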


def load_aware_call(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
@@ -119,7 +133,26 @@ async def wrapper(*args, **kwargs):
        if not hasattr(raw_request.app.state, "server_load_metrics"):
            raw_request.app.state.server_load_metrics = 0

        max_unfinished = getattr(raw_request.app.state, "max_unfinished_requests", None)
        shared_array = getattr(
            raw_request.app.state, "shared_unfinished_requests", None
        )

        raw_request.app.state.server_load_metrics += 1

        if max_unfinished is not None:
            if shared_array is not None:
                # Multi-server: update shared array and check total
                server_index = getattr(raw_request.app.state, "server_index", 0)
                shared_array[server_index] = raw_request.app.state.server_load_metrics
                total_unfinished = sum(shared_array)
            else:
                # Single server: check local count directly
                total_unfinished = raw_request.app.state.server_load_metrics

            if total_unfinished > max_unfinished:
                raw_request.app.state.server_load_metrics -= 1
                return _make_overloaded_response()
        try:
            response = await func(*args, **kwargs)
        except Exception:
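The admission-control pattern in the wrapper above, boiled down to a standalone sketch with the FastAPI plumbing removed (the class and method names are invented): increment the local count, publish it to this server's slot, compare the cross-server total against the cap, and roll back the local count on rejection.

```python
class LoadGate:
    """Illustrative model of the check in load_aware_call, not vLLM code."""

    def __init__(self, max_unfinished: int, slots, server_index: int) -> None:
        self.max_unfinished = max_unfinished
        self.slots = slots  # shared array, one slot per API server
        self.server_index = server_index
        self.local_load = 0

    def try_admit(self) -> bool:
        self.local_load += 1
        self.slots[self.server_index] = self.local_load
        if sum(self.slots) > self.max_unfinished:
            # Over the cap: undo the local increment; as in the diff, the
            # published slot is refreshed on the next request.
            self.local_load -= 1
            return False
        return True


# Single-process usage example with a plain list standing in for the
# shared multiprocessing array:
gate = LoadGate(max_unfinished=2, slots=[0, 0], server_index=0)
print([gate.try_admit() for _ in range(3)])  # [True, True, False]
```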
3 changes: 3 additions & 0 deletions vllm/v1/utils.py
@@ -212,6 +212,9 @@ def __init__(
client_config["stats_update_address"] = stats_update_address
if tensor_queue is not None:
client_config["tensor_queue"] = tensor_queue
unfinished_arr = getattr(args, "shared_unfinished_requests", None)
if unfinished_arr is not None:
client_config["shared_unfinished_requests"] = unfinished_arr

            proc = spawn_context.Process(
                target=target_server_fn or run_api_server_worker_proc,
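Finally, a hedged standalone sketch of the hand-off this hunk enables: the shared array rides inside client_config to the spawned worker at construction time (shared ctypes arrays must be passed when the process is created; they cannot be sent over a pipe afterwards). Only the client_config keys come from the diff; everything else is invented.

```python
import multiprocessing as mp


def api_server_worker(client_config: dict) -> None:
    # The worker retrieves the array much as build_and_serve does above.
    slots = client_config.get("shared_unfinished_requests")
    if slots is not None:
        slots[client_config["client_index"]] = 7  # pretend load value


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    shared = ctx.Array("i", 2, lock=False)
    cfg = {"client_index": 0, "shared_unfinished_requests": shared}
    p = ctx.Process(target=api_server_worker, args=(cfg,))
    p.start()
    p.join()
    print(list(shared))  # [7, 0]
```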