Skip to content

Commit 07f7f48

Browse files
committed
add transcriptions route
Signed-off-by: huiwq1990 <[email protected]>
1 parent 242e9c1 commit 07f7f48

File tree

2 files changed

+150
-9
lines changed

2 files changed

+150
-9
lines changed

src/vllm_router/routers/main_router.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from vllm_router.services.request_service.request import (
2424
route_general_request,
2525
route_sleep_wakeup_request,
26+
route_general_transcriptions,
2627
)
2728
from vllm_router.stats.engine_stats import get_engine_stats_scraper
2829
from vllm_router.version import __version__
@@ -232,3 +233,7 @@ async def health() -> Response:
232233
)
233234
else:
234235
return JSONResponse(content={"status": "healthy"}, status_code=200)
236+
237+
@main_router.post("/v1/audio/transcriptions")
238+
async def route_v1_audio_transcriptions(request: Request, background_tasks: BackgroundTasks):
239+
return await route_general_transcriptions(request, "/v1/audio/transcriptions", background_tasks)

src/vllm_router/services/request_service/request.py

Lines changed: 145 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ async def process_request(
6868
backend_url,
6969
request_id,
7070
endpoint,
71+
is_streaming,
72+
form_files,
73+
form_data,
7174
background_tasks: BackgroundTasks,
7275
debug_request=None,
7376
):
@@ -95,23 +98,17 @@ async def process_request(
9598
request.app.state.request_stats_monitor.on_new_request(
9699
backend_url, request_id, start_time
97100
)
98-
# Check if this is a streaming request
99-
is_streaming = False
100-
try:
101-
request_json = json.loads(body)
102-
is_streaming = request_json.get("stream", False)
103-
except:
104-
# If we can't parse the body as JSON, assume it's not streaming
105-
pass
106101

107102
# For non-streaming requests, collect the full response to cache it properly
108103
full_response = bytearray()
109104

110105
async with request.app.state.httpx_client_wrapper().stream(
111106
method=request.method,
112107
url=backend_url + endpoint,
113-
headers=dict(request.headers),
108+
headers={k: v for k, v in request.headers.items() if k.lower() != "content-length"},
114109
content=body,
110+
files=form_files,
111+
data=form_data,
115112
timeout=None,
116113
) as backend_response:
117114
# Yield headers and status code first.
@@ -292,12 +289,18 @@ async def route_general_request(
292289
logger.info(
293290
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
294291
)
292+
293+
is_streaming = request_json.get("stream", False)
294+
295295
stream_generator = process_request(
296296
request,
297297
request_body,
298298
server_url,
299299
request_id,
300300
endpoint,
301+
is_streaming,
302+
None,
303+
None,
301304
background_tasks,
302305
)
303306
headers, status_code = await anext(stream_generator)
@@ -458,3 +461,136 @@ async def route_sleep_wakeup_request(
458461
content={"status": "success"},
459462
headers={"X-Request-Id": request_id},
460463
)
464+
465+
async def route_general_transcriptions(
466+
request: Request, endpoint: str, background_tasks: BackgroundTasks
467+
):
468+
469+
in_router_time = time.time()
470+
# Same as vllm, Get request_id from X-Request-Id header if available
471+
request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
472+
request_form = await request.form()
473+
474+
if "multipart/form-data" not in request.headers.get("content-type", ""):
475+
return JSONResponse(
476+
status_code=400,
477+
content={"error": "Invalid request: form-data not valid."},
478+
headers={"X-Request-Id": request_id},
479+
)
480+
481+
form_files = {}
482+
form_data = {}
483+
484+
for key, value in request_form.items():
485+
if hasattr(value, "file"):
486+
form_files[key] = (value.filename, await value.read(), value.content_type)
487+
else:
488+
form_data[key] = value
489+
490+
if request.query_params:
491+
request_endpoint = request.query_params.get("id")
492+
else:
493+
request_endpoint = None
494+
495+
requested_model = form_data.get("model", None)
496+
if requested_model is None:
497+
return JSONResponse(
498+
status_code=400,
499+
content={"error": "Invalid request: missing 'model' in request form."},
500+
headers={"X-Request-Id": request_id},
501+
)
502+
503+
# TODO (ApostaC): merge two awaits into one
504+
service_discovery = get_service_discovery()
505+
endpoints = service_discovery.get_endpoint_info()
506+
507+
aliases = getattr(service_discovery, "aliases", None)
508+
if aliases and requested_model in aliases.keys():
509+
requested_model = aliases[requested_model]
510+
form_data["model"] = requested_model
511+
512+
if not request_endpoint:
513+
endpoints = list(
514+
filter(
515+
lambda x: requested_model in x.model_names and x.sleep == False,
516+
endpoints,
517+
)
518+
)
519+
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
520+
request_stats = request.app.state.request_stats_monitor.get_request_stats(
521+
time.time()
522+
)
523+
else:
524+
endpoints = list(
525+
filter(
526+
lambda x: requested_model in x.model_names
527+
and x.Id == request_endpoint
528+
and x.sleep == False,
529+
endpoints,
530+
)
531+
)
532+
533+
if not endpoints:
534+
return JSONResponse(
535+
status_code=400,
536+
content={
537+
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
538+
},
539+
)
540+
541+
logger.debug(f"Routing request {request_id} for model: {requested_model}")
542+
if request_endpoint:
543+
server_url = endpoints[0].url
544+
logger.debug(
545+
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
546+
)
547+
else:
548+
server_url = request.app.state.router.route_request(
549+
endpoints, engine_stats, request_stats, request
550+
)
551+
552+
curr_time = time.time()
553+
# Extract actual session ID from request headers for logging
554+
session_key = (
555+
getattr(request.app.state.router, "session_key", None)
556+
if hasattr(request.app.state.router, "session_key")
557+
else None
558+
)
559+
session_id = (
560+
request.headers.get(session_key, None) if session_key is not None else None
561+
)
562+
session_id_display = session_id if session_id is not None else "None"
563+
564+
# Debug logging to help troubleshoot session ID extraction
565+
logger.debug(
566+
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
567+
)
568+
logger.debug(f"Debug session extraction - Session key config: {session_key}")
569+
logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
570+
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
571+
572+
logger.info(
573+
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
574+
)
575+
576+
stream_generator = process_request(
577+
request,
578+
None,
579+
server_url,
580+
request_id,
581+
endpoint,
582+
form_data.get("stream", False),
583+
form_files,
584+
form_data,
585+
background_tasks,
586+
)
587+
headers, status_code = await anext(stream_generator)
588+
headers_dict = {key: value for key, value in headers.items()}
589+
headers_dict["X-Request-Id"] = request_id
590+
return StreamingResponse(
591+
stream_generator,
592+
status_code=status_code,
593+
headers=headers_dict,
594+
media_type="text/event-stream",
595+
)
596+

0 commit comments

Comments
 (0)