Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 109 additions & 80 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

# --- Request Processing & Routing ---
import asyncio
import json
import os
import time
Expand Down Expand Up @@ -136,8 +137,54 @@ async def process_request(
)


def perform_service_discovery(request, request_json, request_endpoint, requested_model):
    """
    Find the serving engines that can handle a request.

    Resolves a model alias when the service discovery object exposes an
    ``aliases`` mapping, then narrows the discovered endpoints to engines that
    serve the model and are not sleeping, optionally pinned to one engine Id.

    Args:
        request: The incoming request. ``request.app.state`` supplies the
            engine stats scraper and the request stats monitor, and the request
            is handed to ``update_content_length`` when an alias is applied.
        request_json: The parsed request body, used to substitute the model
            name when an alias is applied.
        request_endpoint: Engine Id to pin the request to, or a falsy value to
            consider every matching engine.
        requested_model: Model name (or alias) the client asked for.

    Returns:
        A tuple ``(endpoints, engine_stats, request_stats)``. Both stats values
        are ``None`` when ``request_endpoint`` is set. When no endpoint
        matches, a ``JSONResponse`` with status code 400 is returned instead,
        so callers must check the result before unpacking it.
    """
    service_discovery = get_service_discovery()
    endpoints = service_discovery.get_endpoint_info()

    aliases = getattr(service_discovery, "aliases", None)
    if aliases and requested_model in aliases:
        requested_model = aliases[requested_model]
        request_body = replace_model_in_request_body(request_json, requested_model)
        update_content_length(request, request_body)
        # NOTE(review): the rewritten body and the resolved model name stay
        # local to this function and are not part of the return value; confirm
        # that callers do not need them.

    # One predicate covers both modes: when no engine Id is pinned, the Id
    # comparison is skipped and every awake engine serving the model matches.
    endpoints = [
        candidate
        for candidate in endpoints
        if requested_model in candidate.model_names
        and (not request_endpoint or candidate.Id == request_endpoint)
        and not candidate.sleep
    ]

    if request_endpoint:
        # A pinned request is not load-balanced, so no stats are gathered.
        engine_stats, request_stats = None, None
    else:
        engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
        request_stats = request.app.state.request_stats_monitor.get_request_stats(
            time.time()
        )

    if not endpoints:
        return JSONResponse(
            status_code=400,
            content={
                "error": f"Model {requested_model} not found or vLLM engine is sleeping."
            },
        )
    return endpoints, engine_stats, request_stats


async def route_general_request(
request: Request, endpoint: str, background_tasks: BackgroundTasks
request: Request,
endpoint: str,
background_tasks: BackgroundTasks,
attempted_reroutes: int = 0,
):
"""
Route the incoming request to the backend server and stream the response back to the client.
Expand Down Expand Up @@ -203,97 +250,79 @@ async def route_general_request(
status_code=400, detail="Request body is not JSON parsable."
)

# TODO (ApostaC): merge two awaits into one
service_discovery = get_service_discovery()
endpoints = service_discovery.get_endpoint_info()
# Perform service discovery to request path a number of times equal to reroutes + 1
for _ in range(attempted_reroutes + 1):
endpoints, engine_stats, request_stats = await asyncio.to_thread(
perform_service_discovery,
request,
request_json,
request_endpoint,
requested_model,
)

aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases.keys():
requested_model = aliases[requested_model]
request_body = replace_model_in_request_body(request_json, requested_model)
update_content_length(request, request_body)
logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
)

if not request_endpoint:
endpoints = list(
filter(
lambda x: requested_model in x.model_names and not x.sleep,
endpoints,
elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
)
)
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)
else:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and x.Id == request_endpoint
and not x.sleep,
endpoints,
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
)

if not endpoints:
return JSONResponse(
status_code=400,
content={
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
},
curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)

logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(
f"Debug session extraction - Request headers: {dict(request.headers)}"
)
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")

curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
error = None
try:
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
# Break out of the loop when the request's stream is fully generated
break
except Exception as e:
error = e

logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
if error:
raise error
return StreamingResponse(
stream_generator,
status_code=status,
Expand Down