From e59841b8c62f08efa96d7045d8c3bc638b11a795 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 23 Jan 2025 14:06:50 -0500 Subject: [PATCH 1/3] whisper-async working poc --- query_transcription.py | 24 +++ vllm/entrypoints/openai/api_server.py | 44 ++++- vllm/entrypoints/openai/protocol.py | 160 +++++++++++++++++- .../openai/serving_transcription.py | 137 +++++++++++++++ 4 files changed, 361 insertions(+), 4 deletions(-) create mode 100644 query_transcription.py create mode 100644 vllm/entrypoints/openai/serving_transcription.py diff --git a/query_transcription.py b/query_transcription.py new file mode 100644 index 000000000000..8132958a077b --- /dev/null +++ b/query_transcription.py @@ -0,0 +1,24 @@ +from openai import OpenAI +from openai.types.audio import TranscriptionCreateParams +from pathlib import Path +import io + +mary_had_lamb = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/mary_had_lamb.ogg') +winning_call = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/winning_call.ogg') + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + prompt="<|startoftranscript|>", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9bb11907f740..280db0b9a7c4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,10 +14,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union, List, Annotated import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Request, Form from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -58,7 +58,11 @@ ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, - UnloadLoraAdapterRequest) + UnloadLoraAdapterRequest, + TranscriptionRequest, + TranscriptionResponse, + AudioResponseFormat, + ) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -70,6 +74,7 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import OpenAIServingTranscription from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -303,6 +308,9 @@ def score(request: Request) -> Optional[OpenAIServingScores]: def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -495,6 +503,29 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/v1/audio/transcriptions") +@with_cancellation +async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()], + raw_request: Request): + + audio_data = await request.file.read() + + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + generator = await handler.create_transcription(audio_data, request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { "generate": { @@ -682,6 +713,7 @@ async def init_app_state( state: State, args: Namespace, ) -> None: + if args.served_model_name is not None: served_model_names = args.served_model_name else: @@ -761,6 +793,12 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) state.task = model_config.task diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 14e41346df77..cf5674d492eb 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,11 +3,13 @@ import re import time from argparse import Namespace -from typing import Any, Dict, List, Literal, Optional, Union +from os import PathLike +from typing import Any, Dict, List, Literal, Optional, Union, TypeAlias, TYPE_CHECKING, Tuple, Mapping import torch from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated +from fastapi import UploadFile from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1334,3 +1336,159 @@ class LoadLoraAdapterRequest(BaseModel): class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + #https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these formats: + flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: str + """ID of the model to use. + """ + + language: str + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will + improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. If set to 0, the model will use + [log probability](https://en.wikipedia.org/wiki/Log_probability) to + automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: List[Literal["word", "segment"]] = Field(alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: There + is no additional latency for segment timestamps, but generating word timestamps + incurs additional latency. + """ + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + # TODO (varun) : ATM the max_tokens are set to the max-model-len - len(prompt_ids). + # Tbis makes sense. but is this okay ? + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens) + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this + segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" + diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000000..c1d46ca0581a --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,137 @@ +import asyncio +import io +#import time +from typing import Any, AsyncGenerator, Dict, Optional, Union + +## TODO (varun) : This is used for testing.. use pydub instead ????? +import librosa +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput + +logger = init_logger(__name__) + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + # TODO (varun) : pass in a tokenizer and return tokenized values !! + async def _preprocess_transcription( + self, audio_data: bytes, + request: TranscriptionRequest) -> Dict[Any, Any]: + return { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": librosa.load(io.BytesIO(audio_data)), + }, + }, + # TODO (Varun) : Should this instead be encoder prompt ??? + "decoder_prompt": f"{request.prompt}", + } + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI completion API. + """ + + assert request.response_format in ['text', 'json'] + + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + request_id = f"cmpl-{self._base_request_id(raw_request)}" + # TODO (varun) : other serving_* files use this -- we should use + # it as well. + #created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # TODO (varun) : Does Whisper have LoRA ? + #tokenizer = await self.engine_client.get_tokenizer(None) + + prompt = await self._preprocess_transcription(audio_data, request) + + default_sampling_params = (self.model_config.get_diff_sampling_param()) + + # TODO (Varun) : figure out default_max_tokens by tokenizing first + default_max_tokens = 200 + sampling_params = request.to_sampling_params(default_max_tokens, + default_sampling_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], + params=sampling_params, + lora_request=None, + prompt_adapter_request=None, + ) + + generator: AsyncGenerator[RequestOutput, None] = None + try: + generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # Non-streaming response + result: Optional[RequestOutput] = None + + try: + async for op in generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) From 7fd9b4d40b3082572c487652d35675f8d45d25d0 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Feb 2025 02:16:28 +0000 Subject: [PATCH 2/3] updated Signed-off-by: rshaw@neuralmagic.com --- query_transcription.py | 20 ++--- vllm/assets/audio.py | 5 ++ vllm/entrypoints/openai/api_server.py | 20 +++-- vllm/entrypoints/openai/protocol.py | 55 ++++++------ .../openai/serving_transcription.py | 88 +++++++++++-------- 5 files changed, 107 insertions(+), 81 deletions(-) diff --git a/query_transcription.py b/query_transcription.py index 8132958a077b..b0c9c95c8818 100644 --- a/query_transcription.py +++ b/query_transcription.py @@ -1,10 +1,9 @@ from openai import OpenAI -from openai.types.audio import TranscriptionCreateParams -from pathlib import Path -import io -mary_had_lamb = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/mary_had_lamb.ogg') -winning_call = Path('/home/varun/.cache/vllm/assets/vllm_public_assets/winning_call.ogg') +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_asset_path() +winning_call = AudioAsset('winning_call').get_asset_path() # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" @@ -15,10 +14,9 @@ ) with open(str(mary_had_lamb), "rb") as f: transcription = client.audio.transcriptions.create( - file=f, - model="openai/whisper-large-v3", - language="en", - prompt="<|startoftranscript|>", - response_format="text", - temperature=0.0) + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) print("transcription result:", transcription) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index a46c67ad7e00..db94ae0b7fb6 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from typing import Literal from urllib.parse import urljoin @@ -26,6 +27,10 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) + def get_asset_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + @property def url(self) -> str: return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 280db0b9a7c4..e8814aecd162 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,10 +14,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union, List, Annotated +from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request, Form +from fastapi import APIRouter, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -58,11 +58,9 @@ ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, - UnloadLoraAdapterRequest, TranscriptionRequest, TranscriptionResponse, - AudioResponseFormat, - ) + UnloadLoraAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -74,7 +72,8 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) -from vllm.entrypoints.openai.serving_transcription import OpenAIServingTranscription +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -308,6 +307,7 @@ def score(request: Request) -> Optional[OpenAIServingScores]: def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization + def transcription(request: Request) -> OpenAIServingTranscription: return request.app.state.openai_serving_transcription @@ -503,9 +503,11 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) + @router.post("/v1/audio/transcriptions") @with_cancellation -async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()], +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], raw_request: Request): audio_data = await request.file.read() @@ -515,7 +517,8 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest, Form()] return base(raw_request).create_error_response( message="The model does not support Transcriptions API") - generator = await handler.create_transcription(audio_data, request, raw_request) + generator = await handler.create_transcription(audio_data, request, + raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -713,7 +716,6 @@ async def init_app_state( state: State, args: Namespace, ) -> None: - if args.served_model_name is not None: served_model_names = args.served_model_name else: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index cf5674d492eb..66152a5a4999 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,13 +3,12 @@ import re import time from argparse import Namespace -from os import PathLike -from typing import Any, Dict, List, Literal, Optional, Union, TypeAlias, TYPE_CHECKING, Tuple, Mapping +from typing import Any, Dict, List, Literal, Optional, TypeAlias, Union import torch +from fastapi import UploadFile from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated -from fastapi import UploadFile from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1337,17 +1336,20 @@ class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + ## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + class TranscriptionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation #https://platform.openai.com/docs/api-reference/audio/createTranscription - file: UploadFile + file: UploadFile """ - The audio file object (not file name) to transcribe, in one of these formats: - flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. """ model: str @@ -1358,8 +1360,8 @@ class TranscriptionRequest(OpenAIBaseModel): """The language of the input audio. Supplying the input language in - [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will - improve accuracy and latency. + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. """ prompt: str = Field(default="") @@ -1376,23 +1378,24 @@ class TranscriptionRequest(OpenAIBaseModel): `verbose_json`, or `vtt`. """ - ## TODO (varun) : Support if set to 0, certain thresholds are met !! + ## TODO (varun) : Support if set to 0, certain thresholds are met !! temperature: float = Field(default=0.0) """The sampling temperature, between 0 and 1. - Higher values like 0.8 will make the output more random, while lower values like - 0.2 will make it more focused and deterministic. If set to 0, the model will use - [log probability](https://en.wikipedia.org/wiki/Log_probability) to - automatically increase the temperature until certain thresholds are hit. + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. """ - timestamp_granularities: List[Literal["word", "segment"]] = Field(alias="timestamp_granularities[]", default=[]) + timestamp_granularities: List[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) """The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. - Either or both of these options are supported: `word`, or `segment`. Note: There - is no additional latency for segment timestamps, but generating word timestamps - incurs additional latency. + Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. """ # Default sampling parameters for transcription requests. @@ -1414,17 +1417,16 @@ def to_sampling_params( temperature = default_sampling_params.get( "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) - # TODO (varun) : ATM the max_tokens are set to the max-model-len - len(prompt_ids). - # Tbis makes sense. but is this okay ? - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens) + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens) + # Transcription response objects class TranscriptionResponse(OpenAIBaseModel): text: str """The transcribed text.""" + class TranscriptionWord(OpenAIBaseModel): end: float """End time of the word in seconds.""" @@ -1435,6 +1437,7 @@ class TranscriptionWord(OpenAIBaseModel): word: str """The text content of the word.""" + class TranscriptionSegment(OpenAIBaseModel): id: int """Unique identifier of the segment.""" @@ -1457,8 +1460,8 @@ class TranscriptionSegment(OpenAIBaseModel): no_speech_prob: float """Probability of no speech in the segment. - If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this - segment silent. + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. """ seek: int @@ -1476,6 +1479,7 @@ class TranscriptionSegment(OpenAIBaseModel): tokens: List[int] """Array of token IDs for the text content.""" + class TranscriptionResponseVerbose(OpenAIBaseModel): duration: str """The duration of the input audio.""" @@ -1491,4 +1495,3 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[List[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" - diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index c1d46ca0581a..572a6d1ae73c 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,9 +1,7 @@ import asyncio import io -#import time from typing import Any, AsyncGenerator, Dict, Optional, Union -## TODO (varun) : This is used for testing.. use pydub instead ????? import librosa from fastapi import Request @@ -45,10 +43,11 @@ def __init__( "Overwriting default completion sampling param with: %s", diff_sampling_param) - # TODO (varun) : pass in a tokenizer and return tokenized values !! async def _preprocess_transcription( - self, audio_data: bytes, - request: TranscriptionRequest) -> Dict[Any, Any]: + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> Dict[Any, Any]: return { "encoder_prompt": { "prompt": "", @@ -56,8 +55,10 @@ async def _preprocess_transcription( "audio": librosa.load(io.BytesIO(audio_data)), }, }, - # TODO (Varun) : Should this instead be encoder prompt ??? - "decoder_prompt": f"{request.prompt}", + # TODO(rob): tokenize here. + "decoder_prompt": + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" + # "decoder_prompt": f"{request.prompt}", } # TODO (varun) : Make verbose response work ! @@ -71,8 +72,6 @@ async def create_transcription( for the API specification. This API mimics the OpenAI completion API. """ - assert request.response_format in ['text', 'json'] - error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -83,38 +82,54 @@ async def create_transcription( if self.engine_client.errored: raise self.engine_client.dead_error + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + request_id = f"cmpl-{self._base_request_id(raw_request)}" - # TODO (varun) : other serving_* files use this -- we should use - # it as well. - #created_time = int(time.time()) request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: raw_request.state.request_metadata = request_metadata - # TODO (varun) : Does Whisper have LoRA ? - #tokenizer = await self.engine_client.get_tokenizer(None) - - prompt = await self._preprocess_transcription(audio_data, request) - - default_sampling_params = (self.model_config.get_diff_sampling_param()) - - # TODO (Varun) : figure out default_max_tokens by tokenizing first - default_max_tokens = 200 - sampling_params = request.to_sampling_params(default_max_tokens, - default_sampling_params) + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) - self._log_inputs( - request_id, - prompt['decoder_prompt'], - params=sampling_params, - lora_request=None, - prompt_adapter_request=None, - ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) - generator: AsyncGenerator[RequestOutput, None] = None + result_generator: AsyncGenerator[RequestOutput, None] = None try: - generator = self.engine_client.generate( + # TODO(rob): subtract len of tokenized prompt. + default_max_tokens = self.model_config.max_model_len + default_params = self.model_config.get_diff_sampling_param() + sampling_params = request.to_sampling_params( + default_max_tokens, default_params) + + self._log_inputs(request_id, + prompt['decoder_prompt'], + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( prompt, sampling_params, request_id, @@ -123,11 +138,14 @@ async def create_transcription( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - # Non-streaming response - result: Optional[RequestOutput] = None + # TODO(rob): figure out a way to pipe streaming in. + stream = False + if stream: + return None + # Non-streaming response. try: - async for op in generator: + async for op in result_generator: result = op return TranscriptionResponse(text=result.outputs[0].text) except asyncio.CancelledError: From f13fd97393ef44697c6c9678b5f6d66167400133 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Feb 2025 02:18:11 +0000 Subject: [PATCH 3/3] license identifier Signed-off-by: rshaw@neuralmagic.com --- query_transcription.py | 1 + vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/protocol.py | 1 + vllm/entrypoints/openai/serving_transcription.py | 1 + 4 files changed, 4 insertions(+) diff --git a/query_transcription.py b/query_transcription.py index b0c9c95c8818..3e50b5d52575 100644 --- a/query_transcription.py +++ b/query_transcription.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from openai import OpenAI from vllm.assets.audio import AudioAsset diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e8814aecd162..dead8d2f0da7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import asyncio import atexit import importlib diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 66152a5a4999..a89ccb4781ee 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import re diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 572a6d1ae73c..9eef9e3c9dd2 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 import asyncio import io from typing import Any, AsyncGenerator, Dict, Optional, Union