diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 7862778464dd..00756e719992 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -57,6 +57,8 @@ We currently support the following OpenAI APIs: - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). - [Transcriptions API][transcriptions-api] (`/v1/audio/transcriptions`) - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). +- [Translation API][translations-api] (`/v1/audio/translations`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: @@ -374,6 +376,34 @@ The following extra parameters are supported: ```python --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" ``` + +[](){ #translations-api } + +### Translations API + +Our Translation API is compatible with [OpenAI's Translations API](https://platform.openai.com/docs/api-reference/audio/createTranslation); +you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. +Whisper models can translate audio from one of the 55 non-English supported languages into English. +Please mind that the popular `openai/whisper-large-v3-turbo` model does not support translating. + +!!! note + To use the Translation API, please install with extra audio dependencies using `pip install vllm[audio]`. + +Code example: + +#### Extra Parameters + +The following [sampling parameters][sampling-params] are supported. + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:translation-sampling-params" +``` + +The following extra parameters are supported: + +```python +--8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params" +``` [](){ #tokenizer-api } diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index ae43cb5da790..755038a76139 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -26,23 +26,12 @@ from vllm.assets.audio import AudioAsset -mary_had_lamb = AudioAsset("mary_had_lamb").get_local_path() -winning_call = AudioAsset("winning_call").get_local_path() -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - - -def sync_openai(): +def sync_openai(audio_path: str, client: OpenAI): """ Perform synchronous transcription using OpenAI-compatible API. """ - with open(str(mary_had_lamb), "rb") as f: + with open(audio_path, "rb") as f: transcription = client.audio.transcriptions.create( file=f, model="openai/whisper-large-v3", @@ -58,8 +47,7 @@ def sync_openai(): print("transcription result:", transcription.text) -# OpenAI Transcription API client does not support streaming. -async def stream_openai_response(): +async def stream_openai_response(audio_path: str, base_url: str, api_key: str): """ Perform streaming transcription using vLLM's raw HTTP streaming API. """ @@ -68,11 +56,12 @@ async def stream_openai_response(): "stream": True, "model": "openai/whisper-large-v3", } - url = openai_api_base + "/audio/transcriptions" - headers = {"Authorization": f"Bearer {openai_api_key}"} + url = base_url + "/audio/transcriptions" + headers = {"Authorization": f"Bearer {api_key}"} print("transcription result:", end=" ") + # OpenAI Transcription API client does not support streaming. async with httpx.AsyncClient() as client: - with open(str(winning_call), "rb") as f: + with open(audio_path, "rb") as f: async with client.stream( "POST", url, files={"file": f}, data=data, headers=headers ) as response: @@ -93,10 +82,20 @@ async def stream_openai_response(): def main(): - sync_openai() - + mary_had_lamb = str(AudioAsset("mary_had_lamb").get_local_path()) + winning_call = str(AudioAsset("winning_call").get_local_path()) + + # Modify OpenAI's API key and API base to use vLLM's API server. + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + sync_openai(mary_had_lamb, client) # Run the asynchronous function - asyncio.run(stream_openai_response()) + asyncio.run(stream_openai_response(winning_call, openai_api_base, openai_api_key)) if __name__ == "__main__": diff --git a/examples/online_serving/openai_translation_client.py b/examples/online_serving/openai_translation_client.py new file mode 100644 index 000000000000..6f7253e2a789 --- /dev/null +++ b/examples/online_serving/openai_translation_client.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import json + +import httpx +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + + +def sync_openai(audio_path: str, client: OpenAI): + with open(audio_path, "rb") as f: + translation = client.audio.translations.create( + file=f, + model="openai/whisper-large-v3", + response_format="json", + temperature=0.0, + # Additional params not provided by OpenAI API. + extra_body=dict( + language="it", + seed=4419, + repetition_penalty=1.3, + ), + ) + print("translation result:", translation.text) + + +async def stream_openai_response(audio_path: str, base_url: str, api_key: str): + data = { + "language": "it", + "stream": True, + "model": "openai/whisper-large-v3", + } + url = base_url + "/audio/translations" + headers = {"Authorization": f"Bearer {api_key}"} + print("translation result:", end=" ") + # OpenAI translation API client does not support streaming. + async with httpx.AsyncClient() as client: + with open(audio_path, "rb") as f: + async with client.stream( + "POST", url, files={"file": f}, data=data, headers=headers + ) as response: + async for line in response.aiter_lines(): + # Each line is a JSON object prefixed with 'data: ' + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + # Last chunk, stream ends + if line.strip() == "[DONE]": + break + # Parse the JSON response + chunk = json.loads(line) + # Extract and print the content + content = chunk["choices"][0].get("delta", {}).get("content") + print(content, end="") + + +def main(): + foscolo = str(AudioAsset("azacinto_foscolo").get_local_path()) + + # Modify OpenAI's API key and API base to use vLLM's API server. + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + sync_openai(foscolo, client) + # Run the asynchronous function + asyncio.run(stream_openai_response(foscolo, openai_api_base, openai_api_key)) + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 8117e774951e..dab14f1d7d03 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -82,6 +82,8 @@ async def test_long_audio_request(mary_had_lamb): mary_had_lamb.seek(0) audio, sr = librosa.load(mary_had_lamb) + # Add small silence after each audio for repeatability in the split process + audio = np.pad(audio, (0, 1600)) repeated_audio = np.tile(audio, 10) # Repeated audio to buffer buffer = io.BytesIO() diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py new file mode 100644 index 000000000000..0c2cb367f330 --- /dev/null +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +# imports for guided decoding tests +import json +from unittest.mock import patch + +import librosa +import numpy as np +import pytest +import soundfile as sf +from openai._base_client import AsyncAPIClient + +from vllm.assets.audio import AudioAsset + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture +def foscolo(): + # Test translation it->en + path = AudioAsset('azacinto_foscolo').get_local_path() + with open(str(path), "rb") as f: + yield f + + +# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation! +@pytest.mark.asyncio +async def test_basic_audio(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + translation = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="text", + # TODO remove once language detection is implemented + extra_body=dict(language="it"), + temperature=0.0) + out = json.loads(translation)['text'].strip() + assert "Nor will I ever touch the sacred" in out + + +@pytest.mark.asyncio +async def test_audio_prompt(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + # Condition whisper on starting text + prompt = "Nor have I ever" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.translations.create( + model=model_name, + file=foscolo, + prompt=prompt, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Nor will I ever touch the sacred" not in out + assert prompt not in out + + +@pytest.mark.asyncio +async def test_non_asr_model(foscolo): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.translations.create(model=model_name, + file=foscolo, + temperature=0.0) + assert res.code == 400 and not res.text + assert res.message == "The model does not support Translations API" + + +@pytest.mark.asyncio +async def test_streaming_response(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + translation = "" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res_no_stream = await client.audio.translations.create( + model=model_name, + file=foscolo, + response_format="json", + extra_body=dict(language="it"), + temperature=0.0) + # Unfortunately this only works when the openai client is patched + # to use streaming mode, not exposed in the translation api. + original_post = AsyncAPIClient.post + + async def post_with_stream(*args, **kwargs): + kwargs['stream'] = True + return await original_post(*args, **kwargs) + + with patch.object(AsyncAPIClient, "post", new=post_with_stream): + client = remote_server.get_async_client() + res = await client.audio.translations.create(model=model_name, + file=foscolo, + temperature=0.0, + extra_body=dict( + stream=True, + language="it")) + # Reconstruct from chunks and validate + async for chunk in res: + # just a chunk + text = chunk.choices[0]['delta']['content'] + translation += text + + assert translation == res_no_stream.text + + +@pytest.mark.asyncio +async def test_stream_options(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + original_post = AsyncAPIClient.post + + async def post_with_stream(*args, **kwargs): + kwargs['stream'] = True + return await original_post(*args, **kwargs) + + with patch.object(AsyncAPIClient, "post", new=post_with_stream): + client = remote_server.get_async_client() + res = await client.audio.translations.create( + model=model_name, + file=foscolo, + temperature=0.0, + extra_body=dict(language="it", + stream=True, + stream_include_usage=True, + stream_continuous_usage_stats=True)) + final = False + continuous = True + async for chunk in res: + if not len(chunk.choices): + # final usage sent + final = True + else: + continuous = continuous and hasattr(chunk, 'usage') + assert final and continuous + + +@pytest.mark.asyncio +async def test_long_audio_request(foscolo): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + + foscolo.seek(0) + audio, sr = librosa.load(foscolo) + repeated_audio = np.tile(audio, 2) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + translation = await client.audio.translations.create( + model=model_name, + file=buffer, + extra_body=dict(language="it"), + response_format="text", + temperature=0.0) + out = json.loads(translation)['text'].strip().lower() + # TODO investigate higher model uncertainty in for longer translations. + assert out.count("nor will i ever") == 2 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a23736470f66..681633a2aff7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -73,6 +73,8 @@ TokenizeResponse, TranscriptionRequest, TranscriptionResponse, + TranslationRequest, + TranslationResponse, UnloadLoRAAdapterRequest) # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -88,7 +90,7 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.serving_transcription import ( - OpenAIServingTranscription) + OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, with_cancellation) @@ -401,6 +403,10 @@ def transcription(request: Request) -> OpenAIServingTranscription: return request.app.state.openai_serving_transcription +def translation(request: Request) -> OpenAIServingTranslation: + return request.app.state.openai_serving_translation + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -774,6 +780,47 @@ async def create_transcriptions(raw_request: Request, return StreamingResponse(content=generator, media_type="text/event-stream") +@router.post("/v1/audio/translations", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_translations(request: Annotated[TranslationRequest, + Form()], + raw_request: Request): + handler = translation(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Translations API") + + audio_data = await request.file.read() + generator = await handler.create_translation(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranslationResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post("/rerank", dependencies=[Depends(validate_json_request)], responses={ @@ -1248,6 +1295,12 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, ) if model_config.runner_type == "transcription" else None + state.openai_serving_translation = OpenAIServingTranslation( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None state.task = model_config.task state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b278d0d00586..3b5281962b2d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1947,3 +1947,190 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[list[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" + + +class TranslationResponseStreamChoice(OpenAIBaseModel): + delta: DeltaMessage + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class TranslationStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}") + object: Literal["translation.chunk"] = "translation.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[TranslationResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class TranslationRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranslation + + file: UploadFile + """ + The audio file object (not file name) to translate, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: Optional[str] = None + """ID of the model to use. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + # TODO support additional sampling parameters + # --8<-- [start:translation-sampling-params] + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + # --8<-- [end:translation-sampling-params] + + # --8<-- [start:translation-extra-params] + language: Optional[str] = None + """The language of the input audio we translate from. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy. + """ + + stream: Optional[bool] = False + """Custom field not present in the original OpenAI definition. When set, + it will enable output to be streamed in a similar fashion as the Chat + Completion endpoint. + """ + # Flattened stream option to simplify form data. + stream_include_usage: Optional[bool] = False + stream_continuous_usage_stats: Optional[bool] = False + # --8<-- [end:translation-extra-params] + + # Default sampling parameters for translation requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY) + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] + stream = data.get("stream", False) + if any(bool(data.get(so, False)) for so in stream_opts) and not stream: + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +# Translation response objects +class TranslationResponse(OpenAIBaseModel): + text: str + """The translated text.""" + + +class TranslationWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranslationSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: list[int] + """Array of token IDs for the text content.""" + + +class TranslationResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The translated text.""" + + segments: Optional[list[TranslationSegment]] = None + """Segments of the translated text and their corresponding details.""" + + words: Optional[list[TranslationWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 4bf790bbb298..cf2b738ba55e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -58,7 +58,8 @@ TokenizeCompletionRequest, TokenizeResponse, TranscriptionRequest, - TranscriptionResponse) + TranscriptionResponse, + TranslationRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -89,9 +90,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] - -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, - TranscriptionRequest] +SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest] AnyResponse = Union[ CompletionResponse, diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 60d66434ea5a..0d6989fe91bf 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,155 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import io -import math -import time from collections.abc import AsyncGenerator -from math import ceil -from typing import Final, Optional, Union, cast +from typing import Optional, Union -import numpy as np from fastapi import Request from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest, + ErrorResponse, RequestResponseMetadata, TranscriptionRequest, TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing + TranscriptionStreamResponse, TranslationRequest, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import PromptType +from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.transformers_utils.processor import cached_get_processor -from vllm.utils import PlaceholderModule - -try: - import librosa -except ImportError: - librosa = PlaceholderModule("librosa") # type: ignore[assignment] logger = init_logger(__name__) -# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages -# TODO these configs should live somewhere with the model so we can support -# additional ones - -ISO639_1_SUPPORTED_LANGS = { - "af": "Afrikaans", - "ar": "Arabic", - "hy": "Armenian", - "az": "Azerbaijani", - "be": "Belarusian", - "bs": "Bosnian", - "bg": "Bulgarian", - "ca": "Catalan", - "zh": "Chinese", - "hr": "Croatian", - "cs": "Czech", - "da": "Danish", - "nl": "Dutch", - "en": "English", - "et": "Estonian", - "fi": "Finnish", - "fr": "French", - "gl": "Galician", - "de": "German", - "el": "Greek", - "he": "Hebrew", - "hi": "Hindi", - "hu": "Hungarian", - "is": "Icelandic", - "id": "Indonesian", - "it": "Italian", - "ja": "Japanese", - "kn": "Kannada", - "kk": "Kazakh", - "ko": "Korean", - "lv": "Latvian", - "lt": "Lithuanian", - "mk": "Macedonian", - "ms": "Malay", - "mr": "Marathi", - "mi": "Maori", - "ne": "Nepali", - "no": "Norwegian", - "fa": "Persian", - "pl": "Polish", - "pt": "Portuguese", - "ro": "Romanian", - "ru": "Russian", - "sr": "Serbian", - "sk": "Slovak", - "sl": "Slovenian", - "es": "Spanish", - "sw": "Swahili", - "sv": "Swedish", - "tl": "Tagalog", - "ta": "Tamil", - "th": "Thai", - "tr": "Turkish", - "uk": "Ukrainian", - "ur": "Urdu", - "vi": "Vietnamese", - "cy": "Welsh" -} -ISO639_1_OTHER_LANGS = { - "lo": "Lao", - "jw": "Javanese", - "tk": "Turkmen", - "yi": "Yiddish", - "so": "Somali", - "bn": "Bengali", - "nn": "Norwegian Nynorsk", - "si": "Sinhala", - "yo": "Yoruba", - "sa": "Sanskrit", - "mi": "Māori", - "fo": "Faroese", # codespell:ignore - "mt": "Maltese", - "tg": "Tajik", - "mg": "Malagasy", - "haw": "Hawaiian", - "km": "Khmer", - "br": "Breton", - "ps": "Pashto", - "ln": "Lingala", - "la": "Latin", - "ml": "Malayalam", - "sq": "Albanian", - "su": "Sundanese", - "eu": "Basque", - "ka": "Georgian", - "uz": "Uzbek", - "sn": "Shona", - "ht": "Haitian", - "as": "Assamese", - "mn": "Mongolian", - "te": "Telugu", - "pa": "Panjabi", - "tt": "Tatar", - "gu": "Gujarati", - "oc": "Occitan", - "ha": "Hausa", - "ba": "Bashkir", - "my": "Burmese", - "sd": "Sindhi", - "am": "Amharic", - "lb": "Luxembourgish", - "bo": "Tibetan" -} - -# As per https://platform.openai.com/docs/guides/speech-to-text#overview. -# TODO configurable -MAX_AUDIO_CLIP_FILESIZE_MB = 25 -OVERLAP_CHUNK_SECOND = 1 -MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio - -class OpenAIServingTranscription(OpenAIServing): +class OpenAIServingTranscription(OpenAISpeechToText): + """Handles transcription requests.""" def __init__( self, @@ -164,70 +37,9 @@ def __init__( model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) - - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) - processor = cached_get_processor(model_config.model) - self.max_audio_clip_s = processor.feature_extractor.chunk_length - self.model_sr = processor.feature_extractor.sampling_rate - self.hop_length = processor.feature_extractor.hop_length - - if self.default_sampling_params: - logger.info( - "Overwriting default completion sampling param with: %s", - self.default_sampling_params) - - async def _preprocess_transcription( - self, - request: TranscriptionRequest, - audio_data: bytes, - ) -> tuple[list[PromptType], float]: - # Validate request - # TODO language should be optional and can be guessed. - # For now we default to en. See - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 - lang_token = f"<|{request.language}|>" if request.language else "<|en|>" - if request.language: - if request.language in ISO639_1_SUPPORTED_LANGS: - pass - elif request.language in ISO639_1_OTHER_LANGS: - logger.warning( - "The selected language %s has limited accuracy with" - " reported WER>=0.5. Results may be less accurate " - "for this choice.", request.language) - else: - raise ValueError( - f"Unsupported language: {request.language}." - "Language should be one of:" + - f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + - f"or {list(ISO639_1_OTHER_LANGS.values())}") - - if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: - raise ValueError("Maximum file size exceeded.") - - with io.BytesIO(audio_data) as bytes_: - y, sr = librosa.load(bytes_) - - duration = librosa.get_duration(y=y, sr=sr) - chunks = [y] if duration < 30 else self._split_audio(y, sr) - prompts = [] - for i, chunk in enumerate(chunks): - prompt = { - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": (chunk, sr), - }, - }, - "decoder_prompt": - f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" - if i == 0 else "" - } - prompts.append(cast(PromptType, prompt)) - return prompts, duration + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="transcribe") - # TODO (varun) : Make verbose response work ! async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request @@ -238,250 +50,83 @@ async def create_transcription( See https://platform.openai.com/docs/api-reference/audio/createTranscription for the API specification. This API mimics the OpenAI transcription API. """ - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - # If the engine is dead, raise the engine's DEAD_ERROR. - # This is required for the streaming case, where we return a - # success status before we actually start generating text :). - if self.engine_client.errored: - raise self.engine_client.dead_error - - if request.response_format not in ['text', 'json']: - return self.create_error_response( - "Currently only support response_format `text` or `json`") - - request_id = f"trsc-{self._base_request_id(raw_request)}" - - request_metadata = RequestResponseMetadata(request_id=request_id) - if raw_request: - raw_request.state.request_metadata = request_metadata - - try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) - - if lora_request: - return self.create_error_response( - "Currently do not support LoRA for Transcription.") - if prompt_adapter_request: - return self.create_error_response( - "Currently do not support PromptAdapter for Transcription." - ) - - prompts, duration_s = await self._preprocess_transcription( - request=request, - audio_data=audio_data, - ) - - except ValueError as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - list_result_generator: Optional[list[AsyncGenerator[RequestOutput, - None]]] = None - try: - # Unlike most decoder-only models, whisper generation length is not - # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. - default_max_tokens = self.model_config.max_model_len - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - - self._log_inputs( - request_id, - prompts[0]['decoder_prompt'], # type: ignore - params=sampling_params, - lora_request=None, - prompt_adapter_request=None) - - list_result_generator = [ - self.engine_client.generate( - prompt, - sampling_params, - request_id, - ) for prompt in prompts - ] - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) - - if request.stream: - return self.transcription_stream_generator(request, - list_result_generator, - request_id, - request_metadata, - duration_s) - # Non-streaming response. - try: - assert list_result_generator is not None - text = "" - for result_generator in list_result_generator: - async for op in result_generator: - text += op.outputs[0].text - return TranscriptionResponse(text=text) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranscriptionResponse, + stream_generator_method=self.transcription_stream_generator, + ) async def transcription_stream_generator( self, request: TranscriptionRequest, - list_result_generator: list[AsyncGenerator[RequestOutput, None]], + result_generator: list[AsyncGenerator[RequestOutput, None]], request_id: str, request_metadata: RequestResponseMetadata, audio_duration_s: float) -> AsyncGenerator[str, None]: - created_time = int(time.time()) - model_name = request.model - chunk_object_type: Final = "transcription.chunk" - - completion_tokens = 0 - num_prompt_tokens = 0 - - include_usage = request.stream_include_usage \ - if request.stream_include_usage else False - include_continuous_usage = request.stream_continuous_usage_stats\ - if include_usage and request.stream_continuous_usage_stats\ - else False - - try: - for result_generator in list_result_generator: - async for res in result_generator: - # On first result. - if res.prompt_token_ids is not None: - # Do not account the 4-tokens `<|startoftranscript|>..` - # Could be negative when language token - # is not specified. - num_prompt_tokens = max( - len(res.prompt_token_ids) - 4, 0) - # NOTE(NickLucche) user can't pass encoder - # prompts directly at least not to Whisper. - # One indicator of the encoder amount of processing - # is the log-mel spectogram length. - num_prompt_tokens += ceil( - audio_duration_s * self.model_sr / self.hop_length) - - # We need to do it here, because if there are exceptions in - # the result_generator, it needs to be sent as the FIRST - # response (by the try...catch). - - # Just one output (n=1) supported. - assert len(res.outputs) == 1 - output = res.outputs[0] + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="transcription.chunk", + response_stream_choice_class=TranscriptionResponseStreamChoice, + stream_response_class=TranscriptionStreamResponse, + ) + async for chunk in generator: + yield chunk + + +class OpenAIServingTranslation(OpenAISpeechToText): + """Handles translation requests.""" - delta_message = DeltaMessage(content=output.text) - completion_tokens += len(output.token_ids) - - if output.finish_reason is None: - # Still generating, send delta update. - choice_data = TranscriptionResponseStreamChoice( - delta=delta_message) - else: - # Model is finished generating. - choice_data = TranscriptionResponseStreamChoice( - delta=delta_message, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - - chunk = TranscriptionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - - # handle usage stats if requested & if continuous - if include_continuous_usage: - chunk.usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens, - ) - - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - - # Once the final token is handled, if stream_options.include_usage - # is sent, send the usage. - if include_usage: - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) - - final_usage_chunk = TranscriptionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[], - model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) - yield f"data: {final_usage_data}\n\n" - - # report to FastAPI middleware aggregate usage across all choices - request_metadata.final_usage_info = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens) - - except Exception as e: - # TODO: Use a vllm-specific Validation Error - logger.exception("Error in chat completion stream generator.") - data = self.create_streaming_error_response(str(e)) - yield f"data: {data}\n\n" - # Send the final done message after all response.n are finished - yield "data: [DONE]\n\n" - - def _split_audio(self, audio_data: np.ndarray, - sample_rate: int) -> list[np.ndarray]: - chunk_size = sample_rate * self.max_audio_clip_s - overlap_size = sample_rate * OVERLAP_CHUNK_SECOND - chunks = [] - i = 0 - while i < audio_data.shape[-1]: - if i + chunk_size >= audio_data.shape[-1]: - # handle last chunk - chunks.append(audio_data[..., i:]) - break - - # Find the best split point in the overlap region - search_start = i + chunk_size - overlap_size - search_end = min(i + chunk_size, audio_data.shape[-1]) - split_point = self._find_split_point(audio_data, search_start, - search_end) + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="translate") - # Extract chunk up to the split point - chunks.append(audio_data[..., i:split_point]) - i = split_point - return chunks + async def create_translation( + self, audio_data: bytes, request: TranslationRequest, + raw_request: Request + ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: + """Translation API similar to OpenAI's API. - def _find_split_point(self, wav: np.ndarray, start_idx: int, - end_idx: int) -> int: - """Find the best point to split audio by - looking for silence or low amplitude. - Args: - wav: Audio tensor [1, T] - start_idx: Start index of search region - end_idx: End index of search region - Returns: - Index of best splitting point + See https://platform.openai.com/docs/api-reference/audio/createTranslation + for the API specification. This API mimics the OpenAI translation API. """ - segment = wav[start_idx:end_idx] - - # Calculate RMS energy in small windows - min_energy = math.inf - quietest_idx = 0 - for i in range(0, - len(segment) - MIN_ENERGY_WINDOW_SIZE, - MIN_ENERGY_WINDOW_SIZE): - window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] - energy = (window**2).mean()**0.5 - if energy < min_energy: - quietest_idx = i + start_idx - min_energy = energy - return quietest_idx + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranslationResponse, + stream_generator_method=self.translation_stream_generator, + ) + + async def translation_stream_generator( + self, request: TranslationRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, request_metadata: RequestResponseMetadata, + audio_duration_s: float) -> AsyncGenerator[str, None]: + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="translation.chunk", + response_stream_choice_class=TranslationResponseStreamChoice, + stream_response_class=TranslationStreamResponse, + ) + async for chunk in generator: + yield chunk diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py new file mode 100644 index 000000000000..b23cf6cab097 --- /dev/null +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import io +import math +import time +from collections.abc import AsyncGenerator +from math import ceil +from typing import Callable, Literal, Optional, TypeVar, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, ErrorResponse, RequestResponseMetadata, + TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + SpeechToTextRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse] +T = TypeVar("T", bound=SpeechToTextResponse) + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. +# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +OVERLAP_CHUNK_SECOND = 1 +MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio + + +class OpenAISpeechToText(OpenAIServing): + """Base class for speech-to-text operations like transcription and + translation.""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + task_type: Literal["transcribe", "translate"] = "transcribe", + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + processor = cached_get_processor(model_config.model) + self.max_audio_clip_s = processor.feature_extractor.chunk_length + self.model_sr = processor.feature_extractor.sampling_rate + self.hop_length = processor.feature_extractor.hop_length + self.task_type = task_type + + if self.default_sampling_params: + logger.info( + "Overwriting default completion sampling param with: %s", + self.default_sampling_params) + + async def _preprocess_speech_to_text( + self, + request: SpeechToTextRequest, + audio_data: bytes, + ) -> tuple[list[PromptType], float]: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + # NOTE resample to model SR here for efficiency. This is also a + # pre-requisite for chunking, as it assumes Whisper SR. + y, sr = librosa.load(bytes_, sr=self.model_sr) + + duration = librosa.get_duration(y=y, sr=sr) + chunks = [y] if duration < 30 else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (chunk, sr), + }, + }, + "decoder_prompt": + (f"<|startoftranscript|>{lang_token}" + f"<|{self.task_type}|><|notimestamps|>{request.prompt}") + } + prompts.append(cast(PromptType, prompt)) + return prompts, duration + + async def _create_speech_to_text( + self, + audio_data: bytes, + request: SpeechToTextRequest, + raw_request: Request, + response_class: type[T], + stream_generator_method: Callable[..., AsyncGenerator[str, None]], + ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: + """Base method for speech-to-text operations like transcription and + translation.""" + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for " + f"{self.task_type.title()}.") + if prompt_adapter_request: + return self.create_error_response( + f"Currently do not support PromptAdapter for " + f"{self.task_type.title()}.") + + prompts, duration_s = await self._preprocess_speech_to_text( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + list_result_generator: Optional[list[AsyncGenerator[RequestOutput, + None]]] = None + try: + # Unlike most decoder-only models, whisper generation length is not + # constrained by the size of the input audio, which is mapped to a + # fixed-size log-mel-spectogram. + default_max_tokens = self.model_config.max_model_len + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs( + request_id, + prompts[0]['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + list_result_generator = [ + self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) for prompt in prompts + ] + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if request.stream: + return stream_generator_method(request, list_result_generator, + request_id, request_metadata, + duration_s) + # Non-streaming response. + try: + assert list_result_generator is not None + text = "" + for result_generator in list_result_generator: + async for op in result_generator: + text += op.outputs[0].text + return cast(T, response_class(text=text)) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _speech_to_text_stream_generator( + self, + request: SpeechToTextRequest, + list_result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + chunk_object_type: Literal["translation.chunk", "transcription.chunk"], + response_stream_choice_class: Union[ + type[TranscriptionResponseStreamChoice], + type[TranslationResponseStreamChoice]], + stream_response_class: Union[type[TranscriptionStreamResponse], + type[TranslationStreamResponse]], + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + model_name = request.model + + completion_tokens = 0 + num_prompt_tokens = 0 + + include_usage = request.stream_include_usage \ + if request.stream_include_usage else False + include_continuous_usage = request.stream_continuous_usage_stats\ + if include_usage and request.stream_continuous_usage_stats\ + else False + + try: + for result_generator in list_result_generator: + async for res in result_generator: + # On first result. + if res.prompt_token_ids is not None: + # Do not account the 4-tokens `<|startoftranscript|>..` + # Could be negative when language token + # is not specified. + num_prompt_tokens = max( + len(res.prompt_token_ids) - 4, 0) + # NOTE(NickLucche) user can't pass encoder + # prompts directly at least not to Whisper. + # One indicator of the encoder amount of processing + # is the log-mel spectogram length. + num_prompt_tokens += ceil( + audio_duration_s * self.model_sr / self.hop_length) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + + # Just one output (n=1) supported. + assert len(res.outputs) == 1 + output = res.outputs[0] + + delta_message = DeltaMessage(content=output.text) + completion_tokens += len(output.token_ids) + + if output.finish_reason is None: + # Still generating, send delta update. + choice_data = response_stream_choice_class( + delta=delta_message) + else: + # Model is finished generating. + choice_data = response_stream_choice_class( + delta=delta_message, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + + chunk = stream_response_class(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Once the final token is handled, if stream_options.include_usage + # is sent, send the usage. + if include_usage: + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + + final_usage_chunk = stream_response_class( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in %s stream generator.", self.task_type) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + def _split_audio(self, audio_data: np.ndarray, + sample_rate: int) -> list[np.ndarray]: + chunk_size = sample_rate * self.max_audio_clip_s + overlap_size = sample_rate * OVERLAP_CHUNK_SECOND + chunks = [] + i = 0 + while i < audio_data.shape[-1]: + if i + chunk_size >= audio_data.shape[-1]: + # handle last chunk + chunks.append(audio_data[..., i:]) + break + + # Find the best split point in the overlap region + search_start = i + chunk_size - overlap_size + search_end = min(i + chunk_size, audio_data.shape[-1]) + split_point = self._find_split_point(audio_data, search_start, + search_end) + + # Extract chunk up to the split point + chunks.append(audio_data[..., i:split_point]) + i = split_point + return chunks + + def _find_split_point(self, wav: np.ndarray, start_idx: int, + end_idx: int) -> int: + """Find the best point to split audio by + looking for silence or low amplitude. + Args: + wav: Audio tensor [1, T] + start_idx: Start index of search region + end_idx: End index of search region + Returns: + Index of best splitting point + """ + segment = wav[start_idx:end_idx] + + # Calculate RMS energy in small windows + min_energy = math.inf + quietest_idx = 0 + for i in range(0, + len(segment) - MIN_ENERGY_WINDOW_SIZE, + MIN_ENERGY_WINDOW_SIZE): + window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] + energy = (window**2).mean()**0.5 + if energy < min_energy: + quietest_idx = i + start_idx + min_energy = energy + return quietest_idx