Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/vllm_router/routers/main_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vllm_router.services.request_service.request import (
route_general_request,
route_sleep_wakeup_request,
route_general_transcriptions,
)
from vllm_router.stats.engine_stats import get_engine_stats_scraper
from vllm_router.version import __version__
Expand Down Expand Up @@ -232,3 +233,7 @@ async def health() -> Response:
)
else:
return JSONResponse(content={"status": "healthy"}, status_code=200)

@main_router.post("/v1/audio/transcriptions")
async def route_v1_audio_transcriptions(
    request: Request, background_tasks: BackgroundTasks
):
    """Handle POST requests to ``/v1/audio/transcriptions``.

    Delegates to ``route_general_transcriptions`` with the same endpoint
    path and returns its result unchanged.

    Args:
        request: The incoming client request.
        background_tasks: Background task collection passed through to the
            routing helper.
    """
    transcription_path = "/v1/audio/transcriptions"
    response = await route_general_transcriptions(
        request, transcription_path, background_tasks
    )
    return response
158 changes: 145 additions & 13 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ async def process_request(
backend_url,
request_id,
endpoint,
is_streaming,
background_tasks: BackgroundTasks,
debug_request=None,
**kwargs,
):
"""
Process a request by sending it to the chosen backend.
Expand All @@ -80,9 +81,7 @@ async def process_request(
backend_url: The URL of the backend to send the request to.
request_id: A unique identifier for the request.
endpoint: The endpoint to send the request to on the backend.
debug_request: The original request object from the client, used for
optional debug logging.

is_streaming: Whether the request is a streaming request.
Yields:
The response headers and status code, followed by the response content.

Expand All @@ -95,23 +94,19 @@ async def process_request(
request.app.state.request_stats_monitor.on_new_request(
backend_url, request_id, start_time
)
# Check if this is a streaming request
is_streaming = False
try:
request_json = json.loads(body)
is_streaming = request_json.get("stream", False)
except:
# If we can't parse the body as JSON, assume it's not streaming
pass

# For non-streaming requests, collect the full response to cache it properly
full_response = bytearray()

async with request.app.state.httpx_client_wrapper().stream(
method=request.method,
url=backend_url + endpoint,
headers=dict(request.headers),
headers={
k: v for k, v in request.headers.items() if k.lower() != "content-length"
},
content=body,
files=kwargs.get("form_files", None),
data=kwargs.get("form_datas", None),
timeout=None,
) as backend_response:
# Yield headers and status code first.
Expand Down Expand Up @@ -292,12 +287,16 @@ async def route_general_request(
logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)

is_streaming = request_json.get("stream", False)

stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
is_streaming,
background_tasks,
)
headers, status_code = await anext(stream_generator)
Expand Down Expand Up @@ -458,3 +457,136 @@ async def route_sleep_wakeup_request(
content={"status": "success"},
headers={"X-Request-Id": request_id},
)

async def route_general_transcriptions(
    request: Request, endpoint: str, background_tasks: BackgroundTasks
):
    """Route a multipart/form-data request to a backend that serves the model.

    The form is split into file parts and plain fields, the ``model`` field
    (after alias resolution) is used to select candidate endpoints, and the
    request is forwarded through ``process_request``. If the query string
    carries an ``id`` parameter, only the endpoint with that ``Id`` is
    considered; otherwise the application's router chooses among candidates.

    Args:
        request: The incoming client request; must be ``multipart/form-data``.
        endpoint: The backend path to forward the request to.
        background_tasks: Background task collection handed to
            ``process_request``.

    Returns:
        A ``JSONResponse`` with status 400 when the content type is not
        multipart, the ``model`` field is missing, or no matching awake
        endpoint exists; otherwise a ``StreamingResponse`` relaying the
        backend's status code, headers and body. Every response carries the
        ``X-Request-Id`` header.
    """
    in_router_time = time.time()
    # Same as vLLM: reuse the X-Request-Id header when the client supplies one.
    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())

    # Validate the content type before touching the body so that a
    # non-multipart payload is rejected without being parsed.
    if "multipart/form-data" not in request.headers.get("content-type", ""):
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request: form-data not valid."},
            headers={"X-Request-Id": request_id},
        )

    request_form = await request.form()

    # Split the form into uploaded files and plain fields; they are forwarded
    # separately to process_request as form_files / form_datas.
    form_files = {}
    form_datas = {}
    for key, value in request_form.items():
        if hasattr(value, "file"):
            form_files[key] = (value.filename, await value.read(), value.content_type)
        else:
            form_datas[key] = value

    # An optional "id" query parameter pins the request to a specific engine.
    request_endpoint = request.query_params.get("id")

    requested_model = form_datas.get("model")
    if requested_model is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request: missing 'model' in request form."},
            headers={"X-Request-Id": request_id},
        )

    service_discovery = get_service_discovery()
    endpoints = service_discovery.get_endpoint_info()

    # Resolve a model alias and forward the resolved name to the backend.
    aliases = getattr(service_discovery, "aliases", None)
    if aliases and requested_model in aliases:
        requested_model = aliases[requested_model]
        form_datas["model"] = requested_model

    # Keep only awake endpoints serving the model (and matching the pinned Id,
    # if one was requested). The explicit comparison with False is kept so
    # endpoints whose sleep state is not a plain boolean are treated as before.
    endpoints = [
        candidate
        for candidate in endpoints
        if requested_model in candidate.model_names
        and (not request_endpoint or candidate.Id == request_endpoint)
        and candidate.sleep == False  # noqa: E712
    ]

    if not request_endpoint:
        # Stats are only needed when the router has to pick an endpoint.
        engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
        request_stats = request.app.state.request_stats_monitor.get_request_stats(
            time.time()
        )

    if not endpoints:
        return JSONResponse(
            status_code=400,
            content={
                "error": f"Model {requested_model} not found or vLLM engine is sleeping."
            },
            headers={"X-Request-Id": request_id},
        )

    logger.debug(f"Routing request {request_id} for model: {requested_model}")
    if request_endpoint:
        server_url = endpoints[0].url
        logger.debug(
            f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
        )
    else:
        server_url = request.app.state.router.route_request(
            endpoints, engine_stats, request_stats, request
        )

    curr_time = time.time()
    # Extract the actual session ID from the request headers for logging.
    session_key = getattr(request.app.state.router, "session_key", None)
    session_id = request.headers.get(session_key) if session_key is not None else None
    session_id_display = session_id if session_id is not None else "None"

    # Debug logging to help troubleshoot session ID extraction
    logger.debug(
        f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
    )
    logger.debug(f"Debug session extraction - Session key config: {session_key}")
    logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
    logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")

    logger.info(
        f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
    )

    # Form fields arrive as strings, so "false" or "0" would be truthy if used
    # directly; interpret the value explicitly.
    stream_flag = form_datas.get("stream", False)
    if isinstance(stream_flag, str):
        is_streaming = stream_flag.strip().lower() in ("1", "true", "yes", "on")
    else:
        is_streaming = bool(stream_flag)

    # The body is passed as None: the payload is rebuilt from the parsed
    # form parts rather than forwarded as raw bytes.
    stream_generator = process_request(
        request,
        None,
        server_url,
        request_id,
        endpoint,
        is_streaming,
        background_tasks,
        form_files=form_files,
        form_datas=form_datas,
    )
    # The first item yielded is the backend's (headers, status_code) pair.
    headers, status_code = await anext(stream_generator)
    headers_dict = dict(headers.items())
    headers_dict["X-Request-Id"] = request_id
    return StreamingResponse(
        stream_generator,
        status_code=status_code,
        headers=headers_dict,
        media_type="text/event-stream",
    )