From 521424a9cfd0c151b3d8806244c8059f63c83470 Mon Sep 17 00:00:00 2001
From: huiwq1990
Date: Thu, 26 Jun 2025 16:52:28 +0800
Subject: [PATCH] add transcriptions route

Signed-off-by: huiwq1990
---
 src/vllm_router/routers/main_router.py        |  11 ++
 .../services/request_service/request.py       | 169 ++++++++++++++++--
 2 files changed, 168 insertions(+), 12 deletions(-)

diff --git a/src/vllm_router/routers/main_router.py b/src/vllm_router/routers/main_router.py
index e4e77b018..4283bcef9 100644
--- a/src/vllm_router/routers/main_router.py
+++ b/src/vllm_router/routers/main_router.py
@@ -23,6 +23,7 @@
 from vllm_router.services.request_service.request import (
     route_general_request,
     route_sleep_wakeup_request,
+    route_general_transcriptions,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
 from vllm_router.version import __version__
@@ -232,3 +233,13 @@ async def health() -> Response:
         )
     else:
         return JSONResponse(content={"status": "healthy"}, status_code=200)
+
+
+@main_router.post("/v1/audio/transcriptions")
+async def route_v1_audio_transcriptions(
+    request: Request, background_tasks: BackgroundTasks
+):
+    """Forward an audio transcription request to a serving engine."""
+    return await route_general_transcriptions(
+        request, "/v1/audio/transcriptions", background_tasks
+    )
diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py
index 9842d1821..d0e0aeea0 100644
--- a/src/vllm_router/services/request_service/request.py
+++ b/src/vllm_router/services/request_service/request.py
@@ -68,8 +68,9 @@ async def process_request(
     backend_url,
     request_id,
     endpoint,
+    is_streaming,
     background_tasks: BackgroundTasks,
-    debug_request=None,
+    **kwargs,
 ):
     """
     Process a request by sending it to the chosen backend.
@@ -80,8 +81,9 @@ async def process_request(
         backend_url: The URL of the backend to send the request to.
         request_id: A unique identifier for the request.
         endpoint: The endpoint to send the request to on the backend.
-        debug_request: The original request object from the client, used for
-            optional debug logging.
+        is_streaming: Whether the request is a streaming request.
+        **kwargs: Optional ``form_files`` and ``form_datas``, passed to httpx
+            as the multipart ``files`` and ``data`` of the backend request.
 
     Yields:
         The response headers and status code, followed by the response content.
@@ -95,14 +97,6 @@ async def process_request(
     request.app.state.request_stats_monitor.on_new_request(
         backend_url, request_id, start_time
     )
-    # Check if this is a streaming request
-    is_streaming = False
-    try:
-        request_json = json.loads(body)
-        is_streaming = request_json.get("stream", False)
-    except:
-        # If we can't parse the body as JSON, assume it's not streaming
-        pass
 
     # For non-streaming requests, collect the full response to cache it properly
     full_response = bytearray()
@@ -110,8 +104,12 @@ async def process_request(
     async with request.app.state.httpx_client_wrapper().stream(
         method=request.method,
         url=backend_url + endpoint,
-        headers=dict(request.headers),
+        headers={
+            k: v for k, v in request.headers.items() if k.lower() != "content-length"
+        },
         content=body,
+        files=kwargs.get("form_files", None),
+        data=kwargs.get("form_datas", None),
         timeout=None,
     ) as backend_response:
         # Yield headers and status code first.
@@ -292,12 +290,16 @@ async def route_general_request(
     logger.info(
         f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
     )
+
+    is_streaming = request_json.get("stream", False)
+
     stream_generator = process_request(
         request,
         request_body,
         server_url,
         request_id,
         endpoint,
+        is_streaming,
         background_tasks,
     )
     headers, status_code = await anext(stream_generator)
@@ -458,3 +460,146 @@ async def route_sleep_wakeup_request(
         content={"status": "success"},
         headers={"X-Request-Id": request_id},
     )
+
+
+def _form_flag(value) -> bool:
+    """Interpret a multipart form value as a boolean flag.
+
+    Form fields arrive as strings, so ``"false"`` must not count as truthy.
+    """
+    if isinstance(value, str):
+        return value.strip().lower() in ("1", "true", "yes", "on")
+    return bool(value)
+
+
+async def route_general_transcriptions(
+    request: Request, endpoint: str, background_tasks: BackgroundTasks
+):
+    """
+    Route a multipart audio transcription request to a serving engine.
+
+    The multipart form is split into file parts and plain fields, the target
+    engine is picked for the requested model, and the form is re-sent to it.
+
+    Args:
+        request: The incoming ``multipart/form-data`` request.
+        endpoint: The endpoint to send the request to on the backend.
+        background_tasks: Background tasks handed on to process_request.
+
+    Returns:
+        A StreamingResponse relaying the backend response, or a JSONResponse
+        with status 400 when the request is invalid or no engine matches.
+    """
+    in_router_time = time.time()
+    # Same as vllm, Get request_id from X-Request-Id header if available
+    request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
+
+    # Check the content type before parsing: only multipart bodies are valid.
+    if "multipart/form-data" not in request.headers.get("content-type", ""):
+        return JSONResponse(
+            status_code=400,
+            content={"error": "Invalid request: form-data not valid."},
+            headers={"X-Request-Id": request_id},
+        )
+    request_form = await request.form()
+
+    form_files = {}
+    form_datas = {}
+    for key, value in request_form.items():
+        if hasattr(value, "file"):
+            form_files[key] = (value.filename, await value.read(), value.content_type)
+        else:
+            form_datas[key] = value
+
+    # An optional "id" query parameter pins the request to a single engine.
+    request_endpoint = request.query_params.get("id")
+
+    requested_model = form_datas.get("model", None)
+    if requested_model is None:
+        return JSONResponse(
+            status_code=400,
+            content={"error": "Invalid request: missing 'model' in request form."},
+            headers={"X-Request-Id": request_id},
+        )
+
+    # TODO (ApostaC): merge two awaits into one
+    service_discovery = get_service_discovery()
+    endpoints = service_discovery.get_endpoint_info()
+
+    aliases = getattr(service_discovery, "aliases", None)
+    if aliases and requested_model in aliases:
+        requested_model = aliases[requested_model]
+        form_datas["model"] = requested_model
+
+    endpoints = [
+        ep
+        for ep in endpoints
+        if requested_model in ep.model_names
+        and not ep.sleep
+        and (not request_endpoint or ep.Id == request_endpoint)
+    ]
+    if not endpoints:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "error": f"Model {requested_model} not found or vLLM engine is sleeping."
+            },
+            headers={"X-Request-Id": request_id},
+        )
+
+    logger.debug(f"Routing request {request_id} for model: {requested_model}")
+    if request_endpoint:
+        server_url = endpoints[0].url
+        logger.debug(
+            f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
+        )
+    else:
+        # Stats are only needed when the router has to choose an engine.
+        engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
+        request_stats = request.app.state.request_stats_monitor.get_request_stats(
+            time.time()
+        )
+        server_url = request.app.state.router.route_request(
+            endpoints, engine_stats, request_stats, request
+        )
+
+    curr_time = time.time()
+    # Extract actual session ID from request headers for logging
+    session_key = getattr(request.app.state.router, "session_key", None)
+    session_id = (
+        request.headers.get(session_key, None) if session_key is not None else None
+    )
+    session_id_display = session_id if session_id is not None else "None"
+
+    # Debug logging to help troubleshoot session ID extraction
+    logger.debug(
+        f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
+    )
+    logger.debug(f"Debug session extraction - Session key config: {session_key}")
+    logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
+    logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
+
+    logger.info(
+        f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
+    )
+
+    stream_generator = process_request(
+        request,
+        None,
+        server_url,
+        request_id,
+        endpoint,
+        _form_flag(form_datas.get("stream", False)),
+        background_tasks,
+        form_files=form_files,
+        form_datas=form_datas,
+    )
+    headers, status_code = await anext(stream_generator)
+    headers_dict = {key: value for key, value in headers.items()}
+    headers_dict["X-Request-Id"] = request_id
+    return StreamingResponse(
+        stream_generator,
+        status_code=status_code,
+        headers=headers_dict,
+        media_type="text/event-stream",
+    )