Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions examples/online_serving/openai_transcription_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@


def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
audio_path: str,
client: OpenAI,
model: str,
*,
repetition_penalty: float = 1.3,
hotwords: str = None,
):
"""
Perform synchronous transcription using OpenAI-compatible API.
Expand All @@ -43,12 +48,15 @@ def sync_openai(
extra_body=dict(
seed=4419,
repetition_penalty=repetition_penalty,
hotwords=hotwords,
),
)
print("transcription result [sync]:", transcription.text)


async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
async def stream_openai_response(
audio_path: str, client: AsyncOpenAI, model: str, hotwords: str = None
):
"""
Perform asynchronous transcription using OpenAI-compatible API.
"""
Expand All @@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st
extra_body=dict(
seed=420,
top_p=0.6,
hotwords=hotwords,
),
stream=True,
)
Expand Down Expand Up @@ -136,6 +145,7 @@ def main(args):
client=client,
model=model,
repetition_penalty=args.repetition_penalty,
hotwords=args.hotwords,
)

# Run the asynchronous function
Expand All @@ -146,7 +156,10 @@ def main(args):
)
asyncio.run(
stream_openai_response(
args.audio_path if args.audio_path else winning_call, client, model
args.audio_path if args.audio_path else winning_call,
client,
model,
hotwords=args.hotwords,
)
)
else:
Expand Down Expand Up @@ -174,5 +187,11 @@ def main(args):
default=1.3,
help="repetition penalty",
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="hotwords",
)
args = parser.parse_args()
main(args)
13 changes: 13 additions & 0 deletions vllm/entrypoints/openai/speech_to_text/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,19 @@ class TranscriptionRequest(OpenAIBaseModel):

language: str | None = None
"""The language of the input audio.

Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy and latency.
"""

hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""

prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
Expand Down Expand Up @@ -446,6 +453,12 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy.
"""

hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""

to_language: str | None = None
"""The language of the input audio we translate to.

Expand Down
3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/speech_to_text/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ async def _preprocess_speech_to_text(
else None
)

hotwords = request.hotwords if request.hotwords else None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Accessing request.hotwords directly will raise an AttributeError when the request is a TranslationRequest, as the hotwords field is currently only defined in TranscriptionRequest within protocol.py. Since _preprocess_speech_to_text is shared between transcription and translation tasks, you should use getattr to safely access this field.

Suggested change
hotwords = request.hotwords if request.hotwords else None
hotwords = getattr(request, "hotwords", None) or None


if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise VLLMValidationError(
"Maximum file size exceeded",
Expand Down Expand Up @@ -277,6 +279,7 @@ async def _preprocess_speech_to_text(
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
hotwords=hotwords,
)

parsed_prompt: DictPrompt
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/cohere_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
if language is None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/fireredasr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
if language is None:
raise ValueError(
Expand Down
9 changes: 8 additions & 1 deletion vllm/model_executor/models/funasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,13 +884,20 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the funasr prompt"
)

funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
if hotwords is not None:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n热词列表:[{}]\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n".format( # noqa: E501
hotwords
)
else:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501

prompt = {
"prompt": funasr_prompt,
"multi_modal_data": {
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/gemma3n_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/glmasr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
tokenizer = cached_tokenizer_from_config(model_config)
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
# Audio placeholders don't use an index, so value doesn't matter
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/kimi_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
tokenizer = cached_tokenizer_from_config(model_config)
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
"""
Construct a transcription/translation prompt for Qwen3-Omni.
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/voxtral_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ def get_generation_prompt(
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
hotwords: str | None,
) -> PromptType:
if language is None:
raise ValueError(
Expand Down
Loading