diff --git a/query_transcription.py b/query_transcription.py
new file mode 100644
index 000000000000..3e50b5d52575
--- /dev/null
+++ b/query_transcription.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+from openai import OpenAI
+
+from vllm.assets.audio import AudioAsset
+
+mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path()
+winning_call = AudioAsset('winning_call').get_asset_path()
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+with open(str(mary_had_lamb), "rb") as f:
+    transcription = client.audio.transcriptions.create(
+        file=f,
+        model="openai/whisper-large-v3",
+        language="en",
+        response_format="text",
+        temperature=0.0)
+    print("transcription result:", transcription)
diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py
index d9e51082e6ca..c465527b4e89 100644
--- a/vllm/assets/audio.py
+++ b/vllm/assets/audio.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Literal
 from urllib.parse import urljoin
 
@@ -28,6 +29,10 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
                                             s3_prefix=ASSET_DIR)
         return librosa.load(audio_path, sr=None)
 
+    def get_asset_path(self) -> Path:
+        return get_vllm_public_assets(filename=f"{self.name}.ogg",
+                                      s3_prefix=ASSET_DIR)
+
     @property
     def url(self) -> str:
         return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b8f54d6c7804..11777ec0e795 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -17,10 +17,10 @@
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
+from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union
 
 import uvloop
-from fastapi import APIRouter, FastAPI, HTTPException, Request
+from fastapi import APIRouter, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -62,6 +62,8 @@
                                               ScoreRequest, ScoreResponse,
                                               TokenizeRequest,
                                               TokenizeResponse,
+                                              TranscriptionRequest,
+                                              TranscriptionResponse,
                                               UnloadLoraAdapterRequest)
 from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 # yapf: enable
@@ -76,6 +78,8 @@
 from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
+from vllm.entrypoints.openai.serving_transcription import (
+    OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
@@ -319,6 +323,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization
 
 
+def transcription(request: Request) -> OpenAIServingTranscription:
+    return request.app.state.openai_serving_transcription
+
+
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
 
@@ -545,6 +553,32 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request):
     return await do_rerank(request, raw_request)
 
 
+@router.post("/v1/audio/transcriptions")
+@with_cancellation
+async def create_transcriptions(request: Annotated[TranscriptionRequest,
+                                                   Form()],
+                                raw_request: Request):
+
+    audio_data = await request.file.read()
+
+    handler = transcription(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Transcriptions API")
+
+    generator = await handler.create_transcription(audio_data, request,
+                                                   raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, TranscriptionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
 TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
     "generate": {
         "messages": (ChatCompletionRequest, create_chat_completion),
@@ -821,6 +855,12 @@ async def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.openai_serving_transcription = OpenAIServingTranscription(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger,
+    )
     state.task = model_config.task
 
 
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 83b841826231..790e77b093cb 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -5,9 +5,11 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
+from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set,
+                    TypeAlias, Union)
 
 import torch
+from fastapi import UploadFile
 from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
                       ValidationInfo, field_validator, model_validator)
 from typing_extensions import Annotated
@@ -1426,3 +1428,163 @@ class LoadLoraAdapterRequest(BaseModel):
 class UnloadLoraAdapterRequest(BaseModel):
     lora_name: str
     lora_int_id: Optional[int] = Field(default=None)
+
+
+## Protocols for Audio
+AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json",
+                                         "vtt"]
+
+
+class TranscriptionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    #https://platform.openai.com/docs/api-reference/audio/createTranscription
+
+    file: UploadFile
+    """
+    The audio file object (not file name) to transcribe, in one of these
+    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
+    """
+
+    model: str
+    """ID of the model to use.
+    """
+
+    language: str
+    """The language of the input audio.
+
+    Supplying the input language in
+    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
+    will improve accuracy and latency.
+    """
+
+    prompt: str = Field(default="")
+    """An optional text to guide the model's style or continue a previous audio
+    segment.
+
+    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
+    should match the audio language.
+    """
+
+    response_format: AudioResponseFormat = Field(default="json")
+    """
+    The format of the output, in one of these options: `json`, `text`, `srt`,
+    `verbose_json`, or `vtt`.
+    """
+
+    ## TODO (varun) : Support if set to 0, certain thresholds are met !!
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    timestamp_granularities: List[Literal["word", "segment"]] = Field(
+        alias="timestamp_granularities[]", default=[])
+    """The timestamp granularities to populate for this transcription.
+
+    `response_format` must be set to `verbose_json` to use timestamp
+    granularities. Either or both of these options are supported: `word`, or
+    `segment`. Note: There is no additional latency for segment timestamps,
+    but generating word timestamps incurs additional latency.
+    """
+
+    # Default sampling parameters for transcription requests.
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "temperature": 0,
+    }
+
+    def to_sampling_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = default_max_tokens
+
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+
+        return SamplingParams.from_optional(temperature=temperature,
+                                            max_tokens=max_tokens)
+
+
+# Transcription response objects
+class TranscriptionResponse(OpenAIBaseModel):
+    text: str
+    """The transcribed text."""
+
+
+class TranscriptionWord(OpenAIBaseModel):
+    end: float
+    """End time of the word in seconds."""
+
+    start: float
+    """Start time of the word in seconds."""
+
+    word: str
+    """The text content of the word."""
+
+
+class TranscriptionSegment(OpenAIBaseModel):
+    id: int
+    """Unique identifier of the segment."""
+
+    avg_logprob: float
+    """Average logprob of the segment.
+
+    If the value is lower than -1, consider the logprobs failed.
+    """
+
+    compression_ratio: float
+    """Compression ratio of the segment.
+
+    If the value is greater than 2.4, consider the compression failed.
+    """
+
+    end: float
+    """End time of the segment in seconds."""
+
+    no_speech_prob: float
+    """Probability of no speech in the segment.
+
+    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
+    this segment silent.
+ """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000000..9eef9e3c9dd2 --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import io +from typing import Any, AsyncGenerator, Dict, Optional, Union + +import librosa +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput + +logger = init_logger(__name__) + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + async def _preprocess_transcription( + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> Dict[Any, Any]: + return { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": librosa.load(io.BytesIO(audio_data)), + }, + }, + # TODO(rob): tokenize here. + "decoder_prompt": + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" + # "decoder_prompt": f"{request.prompt}", + } + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI completion API. + """ + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. 
+        # This is required for the streaming case, where we return a
+        # success status before we actually start generating text :).
+        if self.engine_client.errored:
+            raise self.engine_client.dead_error
+
+        if request.response_format not in ['text', 'json']:
+            return self.create_error_response(
+                "Currently only support response_format `text` or `json`")
+
+        request_id = f"cmpl-{self._base_request_id(raw_request)}"
+
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
+        try:
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            if lora_request:
+                return self.create_error_response(
+                    "Currently do not support LoRA for Transcription.")
+            if prompt_adapter_request:
+                return self.create_error_response(
+                    "Currently do not support PromptAdapter for Transcription."
+                )
+
+            prompt = await self._preprocess_transcription(
+                request=request,
+                audio_data=audio_data,
+            )
+
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        result_generator: AsyncGenerator[RequestOutput, None] = None
+        try:
+            # TODO(rob): subtract len of tokenized prompt.
+            default_max_tokens = self.model_config.max_model_len
+            default_params = self.model_config.get_diff_sampling_param()
+            sampling_params = request.to_sampling_params(
+                default_max_tokens, default_params)
+
+            self._log_inputs(request_id,
+                             prompt['decoder_prompt'],
+                             params=sampling_params,
+                             lora_request=None,
+                             prompt_adapter_request=None)
+
+            result_generator = self.engine_client.generate(
+                prompt,
+                sampling_params,
+                request_id,
+            )
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+        # TODO(rob): figure out a way to pipe streaming in.
+        stream = False
+        if stream:
+            return None
+
+        # Non-streaming response.
+        try:
+            async for op in result_generator:
+                result = op
+            return TranscriptionResponse(text=result.outputs[0].text)
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
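
For reference, the new route can also be exercised without the OpenAI SDK. The sketch below posts a multipart form with Python's requests library; the host/port, model name, and local audio path are assumptions mirroring query_transcription.py above, not part of this diff. Because TranscriptionRequest is parsed through FastAPI's Form(), each field maps to a multipart form field alongside the uploaded file, and in this implementation both the `text` and `json` response formats come back as a TranscriptionResponse serialized to JSON.

import requests

# Hypothetical local copy of the audio asset; any supported format
# (flac, mp3, ogg, wav, ...) works here.
with open("mary_had_lamb.ogg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        # `file` is the uploaded audio; the remaining TranscriptionRequest
        # fields are sent as plain multipart form fields.
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "language": "en",
            "response_format": "json",
            "temperature": "0.0",
        },
    )

# The handler returns TranscriptionResponse as JSON: {"text": "..."}.
print(resp.json()["text"])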