Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ ENV INSTALL_OPTIONAL_DEP=${INSTALL_OPTIONAL_DEP}
RUN pip install --upgrade --no-cache-dir pip setuptools_scm && \
pip install --no-cache-dir .[$INSTALL_OPTIONAL_DEP]

# Set environment variable for workers (can be overridden)
ENV VLLM_ROUTER_WORKERS=1

# Set the entrypoint
ENTRYPOINT ["vllm-router"]
CMD []
99 changes: 81 additions & 18 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,52 @@

logger = logging.getLogger("uvicorn")

# Global variable to store parsed arguments for multi-worker mode
_global_args: object | None = None


def set_global_args(args):
"""Set global arguments for multi-worker mode.

This function should be called before importing the app module
in multi-worker scenarios.
"""
global _global_args
_global_args = args


@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize all components for this worker process
# This ensures each worker has its own properly initialized state
args = getattr(app.state, "args", None)
if args is None:
# Fallback: use global args or parse args if not available
global _global_args
if _global_args is not None:
args = _global_args
else:
# This should only happen in single-worker mode
args = parse_args()

# Initialize all components
initialize_all(app, args)

# Start log stats thread if enabled
# Note: In multi-worker mode, each worker will have its own log stats thread
# This is actually fine as each worker can log its own stats independently
if args.log_stats and not getattr(app.state, "log_stats_started", False):
threading.Thread(
target=log_stats,
args=(
app,
args.log_stats_interval,
),
daemon=True,
).start()
app.state.log_stats_started = True

# Start aiohttp client wrapper
app.state.aiohttp_client_wrapper.start()
if hasattr(app.state, "batch_processor"):
await app.state.batch_processor.initialize()
Expand Down Expand Up @@ -272,32 +315,52 @@ def initialize_all(app: FastAPI, args):
app.state.request_rewriter = get_request_rewriter()


app = FastAPI(lifespan=lifespan)
app.include_router(main_router)
app.include_router(files_router)
app.include_router(batches_router)
app.include_router(metrics_router)
app.state.aiohttp_client_wrapper = AiohttpClientWrapper()
app.state.semantic_cache_available = semantic_cache_available
def create_app():
"""Create and configure the FastAPI application."""
app = FastAPI(lifespan=lifespan)
app.include_router(main_router)
app.include_router(files_router)
app.include_router(batches_router)
app.include_router(metrics_router)
app.state.aiohttp_client_wrapper = AiohttpClientWrapper()
app.state.semantic_cache_available = semantic_cache_available
return app


def setup_app_with_args(app: FastAPI, args):
"""Set up the application with parsed arguments.

This function is called to store args in app.state so they can be
accessed by the lifespan context manager in each worker process.
"""
app.state.args = args


# Create the app instance
app = create_app()


def main():
args = parse_args()
initialize_all(app, args)
if args.log_stats:
threading.Thread(
target=log_stats,
args=(
app,
args.log_stats_interval,
),
daemon=True,
).start()

# Set up the app with arguments (for single-worker mode)
setup_app_with_args(app, args)

# Set global args for multi-worker mode
set_global_args(args)

# Workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active.
set_ulimit()
uvicorn.run(app, host=args.host, port=args.port)

# Use import string for multi-worker support
uvicorn.run(
"vllm_router.app:app",
host=args.host,
port=args.port,
workers=args.workers,
log_level=args.log_level,
)


if __name__ == "__main__":
Expand Down
18 changes: 18 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import argparse
import json
import logging
import os
import sys

from vllm_router import utils
Expand Down Expand Up @@ -379,6 +380,23 @@ def parse_args():
help="The threshold for kv-aware routing.",
)

# Get default workers from environment variable or use 1
try:
default_workers = int(os.environ.get("VLLM_ROUTER_WORKERS", "1"))
except ValueError:
logger.warning(
"Invalid value for VLLM_ROUTER_WORKERS environment variable. "
"It must be an integer. Defaulting to 1."
)
default_workers = 1

parser.add_argument(
"--workers",
type=int,
default=default_workers,
help="The number of worker processes to run. Default is 1 (single process). Can also be set via VLLM_ROUTER_WORKERS environment variable.",
)

args = parser.parse_args()
args = load_initial_config_from_config_file_if_required(parser, args)

Expand Down
Loading