diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md index 2178228523..15a5fac236 100644 --- a/docs/user_guide/examples/offline_inference/qwen3_tts.md +++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md @@ -90,6 +90,43 @@ Examples: python end2end.py --query-type Base --mode-tag icl ``` +## Voice and Language Control + +### Supported Voices (CustomVoice) + +Predefined speaker voices are set via the `speaker` (or `voice_type`) field in `additional_information`. Available speakers depend on the loaded checkpoint; check `talker_config.spk_id` in the model config for the full list. Common voices include `vivian`, `ryan`, `aiden`, `ethan`, `serena` (case-insensitive). + +Pass the speaker name in your request: + +```python +additional_information = { + "text": ["你好,我是通义千问"], + "task_type": ["CustomVoice"], + "speaker": ["Vivian"], + "language": ["Chinese"], +} +``` + +### Supported Languages + +The `language` field controls the codec-level language tag. Use `"Auto"` (default) for automatic detection. + +Supported values: `Auto`, `Chinese`, `English`, `Japanese`, `Korean`, `German`, `French`, `Russian`, `Portuguese`, `Spanish`, `Italian`. + +```python +additional_information = { + "text": ["Hello, nice to meet you."], + "task_type": ["CustomVoice"], + "speaker": ["Aiden"], + "language": ["English"], +} +``` + +### VoiceDesign and Base + +- **VoiceDesign**: Use `instruct` for natural-language voice description; no `speaker` needed. +- **Base**: Use `ref_audio` and `ref_text` for voice cloning; `language` is optional. + ## Streaming Mode Add `--streaming` to stream audio chunks progressively via `AsyncOmni` (requires `async_chunk: true` in the stage config): diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md index 976d39a966..084e6a3f04 100644 --- a/docs/user_guide/examples/online_serving/qwen2_5_omni.md +++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md @@ -223,10 +223,6 @@ sudo apt install ffmpeg ``````py --8<-- "examples/online_serving/qwen2_5_omni/gradio_demo.py" `````` -??? abstract "openai_chat_completion_client_for_multimodal_generation.py" - ``````py - --8<-- "examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py" - `````` ??? abstract "run_curl_multimodal_generation.sh" ``````sh --8<-- "examples/online_serving/qwen2_5_omni/run_curl_multimodal_generation.sh" diff --git a/docs/user_guide/examples/online_serving/qwen3_tts.md b/docs/user_guide/examples/online_serving/qwen3_tts.md index faf2fe7814..65023258ef 100644 --- a/docs/user_guide/examples/online_serving/qwen3_tts.md +++ b/docs/user_guide/examples/online_serving/qwen3_tts.md @@ -103,13 +103,13 @@ cd examples/online_serving/qwen3_tts # CustomVoice: Use predefined speaker python openai_speech_client.py \ --text "你好,我是通义千问" \ - --voice vivian \ + --speaker vivian \ --language Chinese # CustomVoice with style instruction python openai_speech_client.py \ --text "今天天气真好" \ - --voice ryan \ + --speaker ryan \ --instructions "用开心的语气说" # VoiceDesign: Describe the voice style @@ -134,7 +134,7 @@ The Python client supports the following command-line arguments: - `--model` (or `-m`): Model name/path (default: `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`) - `--task-type` (or `-t`): TTS task type. Options: `CustomVoice`, `VoiceDesign`, `Base` - `--text`: Text to synthesize (required) -- `--voice`: Speaker/voice name (default: `vivian`). 
Options: `vivian`, `ryan`, `aiden`, etc. +- `--speaker`: Speaker name (default: `vivian`). Options: `vivian`, `ryan`, `aiden`, etc. - `--language`: Language. Options: `Auto`, `Chinese`, `English`, `Japanese`, `Korean`, `German`, `French`, `Russian`, `Portuguese`, `Spanish`, `Italian` - `--instructions`: Voice style/emotion instructions - `--ref-audio`: Reference audio file path or URL for voice cloning (Base task) @@ -150,7 +150,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \ -H "Content-Type: application/json" \ -d '{ "input": "Hello, how are you?", - "voice": "vivian", + "speaker": "vivian", "language": "English" }' --output output.wav @@ -159,7 +159,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \ -H "Content-Type: application/json" \ -d '{ "input": "I am so excited!", - "voice": "vivian", + "speaker": "vivian", "instructions": "Speak with great enthusiasm" }' --output excited.wav @@ -176,7 +176,7 @@ client = OpenAI(base_url="http://localhost:8091/v1", api_key="none") response = client.audio.speech.create( model="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", - voice="vivian", + speaker="vivian", input="Hello, how are you?", ) @@ -192,7 +192,7 @@ response = httpx.post( "http://localhost:8091/v1/audio/speech", json={ "input": "Hello, how are you?", - "voice": "vivian", + "speaker": "vivian", "language": "English", }, timeout=300.0, @@ -279,7 +279,7 @@ This endpoint follows the [OpenAI Audio Speech API](https://platform.openai.com/ ```json { "input": "Text to synthesize", - "voice": "vivian", + "speaker": "vivian", "response_format": "wav", "task_type": "CustomVoice", "language": "Auto", @@ -297,6 +297,12 @@ This endpoint follows the [OpenAI Audio Speech API](https://platform.openai.com/ Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/wav`). +### Voice and language (summary) + +- **Speaker**: Use the `speaker` request field to select the speaker (e.g., `vivian`, `ryan`, `aiden`). List available speakers with `GET /v1/audio/voices`. +- **Language**: Use the `language` field for the codec language tag (`Auto`, `Chinese`, `English`, etc.). Default is `Auto` for automatic detection. +- **CustomVoice**: Requires a valid `voice` from the model’s speaker set. **VoiceDesign**: Use `instructions` to describe the voice. **Base**: Use `ref_audio` and `ref_text` for voice cloning. 
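+
+For example, you can check which speakers the loaded checkpoint accepts before picking one. This is a minimal sketch; it assumes the `GET /v1/audio/voices` endpoint mentioned above is enabled and that the server listens on the default port:
+
+```python
+import httpx
+
+# Ask the server which speakers the loaded checkpoint supports.
+resp = httpx.get("http://localhost:8091/v1/audio/voices", timeout=30.0)
+resp.raise_for_status()
+print(resp.json())  # e.g. speaker names such as "vivian", "ryan", "aiden"
+```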
+ ## Parameters ### OpenAI Standard Parameters @@ -305,7 +311,7 @@ Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/w | ----------------- | ------ | -------------- | ----------------------------------------------------------- | | `input` | string | **required** | Text to synthesize | | `model` | string | server's model | Model to use (optional, should match server if specified) | -| `voice` | string | "vivian" | Speaker name (e.g., vivian, ryan, aiden) | +| `speaker` | string | "vivian" | Speaker name (e.g., vivian, ryan, aiden) | | `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | | `speed` | float | 1.0 | Playback speed (0.25-4.0, not supported with `stream=true`) | @@ -340,7 +346,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \ -H "Content-Type: application/json" \ -d '{ "input": "Hello, how are you?", - "voice": "vivian", + "speaker": "vivian", "language": "English", "stream": true, "response_format": "pcm" diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py similarity index 81% rename from examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py rename to examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py index 9675280e1c..da6933821e 100644 --- a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py @@ -16,6 +16,14 @@ class QueryResult(NamedTuple): limit_mm_per_prompt: dict[str, int] +def make_audio_output_filename(request_id: str | None, index: int) -> str: + """Build a stable output filename using request ID when available.""" + if not request_id: + request_id = f"unknown_{index}" + safe_request_id = "".join(ch if (ch.isalnum() or ch in ("-", "_")) else "_" for ch in request_id) + return f"audio_{safe_request_id}_{index}.wav" + + def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" @@ -165,6 +173,34 @@ def get_system_prompt(): } +def _parse_csv_arg(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def _build_prompt_for_query_type( + query_type: str, + custom_prompt: str | None, + video_path: str | None, + image_path: str | None, + audio_path: str | None, +): + query_func = query_map[query_type] + if query_type == "use_video": + return query_func(video_path=video_path, custom_prompt=custom_prompt) + if query_type == "use_image": + return query_func(image_path=image_path, custom_prompt=custom_prompt) + if query_type == "use_audio": + return query_func(audio_path=audio_path, custom_prompt=custom_prompt) + if query_type == "text": + return query_func(custom_prompt=custom_prompt) + if query_type == "use_audio_in_video": + return query_func(video_path=video_path, custom_prompt=custom_prompt) + # use_mixed_modalities / use_multi_audios + return query_func(custom_prompt=custom_prompt) + + def get_text_query(custom_prompt: str | None = None): question = ( custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." 
@@ -379,31 +415,6 @@ def run_multimodal_generation(args, client: OpenAI) -> None: audio_path = getattr(args, "audio_path", None) custom_prompt = getattr(args, "prompt", None) - # Get the query function and call it with appropriate parameters - query_func = query_map[args.query_type] - if args.query_type == "use_video": - prompt = query_func(video_path=video_path, custom_prompt=custom_prompt) - elif args.query_type == "use_image": - prompt = query_func(image_path=image_path, custom_prompt=custom_prompt) - elif args.query_type == "use_audio": - prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt) - elif args.query_type == "text": - prompt = query_func(custom_prompt=custom_prompt) - elif args.query_type == "use_audio_in_video": - prompt = query_func( - video_path=video_path, - custom_prompt=custom_prompt, - ) - else: - prompt = query_func() - - extra_body = { - "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. - } - - if args.query_type == "use_audio_in_video": - extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} - if args.modalities is not None: output_modalities = args.modalities.split(",") else: @@ -411,6 +422,37 @@ def run_multimodal_generation(args, client: OpenAI) -> None: # Test multiple concurrent completions num_concurrent_requests = args.num_concurrent_requests + prompt_list = _parse_csv_arg(getattr(args, "prompts", None)) + speaker_list = _parse_csv_arg(getattr(args, "speakers", None)) + + request_payloads = [] + for idx in range(num_concurrent_requests): + per_req_prompt = ( + prompt_list[idx] + if idx < len(prompt_list) + else (custom_prompt if idx == 0 or not prompt_list else prompt_list[-1]) + ) + per_req_speaker = ( + speaker_list[idx] + if idx < len(speaker_list) + else (args.speaker if idx == 0 or not speaker_list else speaker_list[-1]) + ) + prompt = _build_prompt_for_query_type( + query_type=args.query_type, + custom_prompt=per_req_prompt, + video_path=video_path, + image_path=image_path, + audio_path=audio_path, + ) + extra_body = { + # Optional, it has default settings in stage configs. 
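+            # One dict per pipeline stage, in stage order (for the Omni
+            # pipeline: thinker, talker, code2wav).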
+ "sampling_params_list": sampling_params_list + } + if args.query_type == "use_audio_in_video": + extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} + if per_req_speaker and per_req_speaker.strip(): + extra_body["speaker"] = per_req_speaker.strip() + request_payloads.append({"prompt": prompt, "extra_body": extra_body}) with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor: # Submit multiple completion requests concurrently @@ -419,14 +461,14 @@ def run_multimodal_generation(args, client: OpenAI) -> None: client.chat.completions.create, messages=[ get_system_prompt(), - prompt, + payload["prompt"], ], model=model_name, modalities=output_modalities, - extra_body=extra_body, + extra_body=payload["extra_body"], stream=args.stream, ) - for _ in range(num_concurrent_requests) + for payload in request_payloads ] # Wait for all requests to complete and collect results @@ -437,10 +479,11 @@ def run_multimodal_generation(args, client: OpenAI) -> None: if not args.stream: # Verify all completions succeeded for chat_completion in chat_completions: + request_id = getattr(chat_completion, "id", None) for choice in chat_completion.choices: if choice.message.audio: audio_data = base64.b64decode(choice.message.audio.data) - audio_file_path = f"audio_{count}.wav" + audio_file_path = make_audio_output_filename(request_id=request_id, index=count) with open(audio_file_path, "wb") as f: f.write(audio_data) print(f"Audio saved to {audio_file_path}") @@ -459,7 +502,8 @@ def run_multimodal_generation(args, client: OpenAI) -> None: if getattr(chunk, "modality", None) == "audio" and content: audio_data = base64.b64decode(content) - audio_file_path = f"audio_{count}.wav" + request_id = getattr(chunk, "id", None) + audio_file_path = make_audio_output_filename(request_id=request_id, index=count) with open(audio_file_path, "wb") as f: f.write(audio_data) print(f"\nAudio saved to {audio_file_path}") @@ -546,6 +590,30 @@ def parse_args(): default="localhost", help="Host/IP of the vLLM Omni API server.", ) + parser.add_argument( + "--speaker", + type=str, + default=None, + help="TTS speaker/voice for audio output (e.g. Ethan, Vivian). Passed via extra_body to the talker stage.", + ) + parser.add_argument( + "--speakers", + type=str, + default=None, + help=( + "Comma-separated speakers for concurrent requests, e.g. " + "'Ethan,Vivian,Ryan'. Overrides --speaker per request." + ), + ) + parser.add_argument( + "--prompts", + type=str, + default=None, + help=( + "Comma-separated prompts for concurrent requests. " + "If fewer than --num-concurrent-requests, the last prompt is reused." 
+ ), + ) return parser.parse_args() diff --git a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py deleted file mode 100644 index 639f696a72..0000000000 --- a/examples/online_serving/qwen2_5_omni/openai_chat_completion_client_for_multimodal_generation.py +++ /dev/null @@ -1,459 +0,0 @@ -import base64 -import os - -import requests -from openai import OpenAI -from vllm.assets.audio import AudioAsset -from vllm.utils.argparse_utils import FlexibleArgumentParser - -SEED = 42 - - -def encode_base64_content_from_url(content_url: str) -> str: - """Encode a content retrieved from a remote url to base64 format.""" - - with requests.get(content_url) as response: - response.raise_for_status() - result = base64.b64encode(response.content).decode("utf-8") - - return result - - -def encode_base64_content_from_file(file_path: str) -> str: - """Encode a local file to base64 format.""" - with open(file_path, "rb") as f: - content = f.read() - result = base64.b64encode(content).decode("utf-8") - return result - - -def get_video_url_from_path(video_path: str | None) -> str: - """Convert a video path (local file or URL) to a video URL format for the API. - - If video_path is None or empty, returns the default URL. - If video_path is a local file path, encodes it to base64 data URL. - If video_path is a URL, returns it as-is. - """ - if not video_path: - # Default video URL - return "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" - - # Check if it's a URL (starts with http:// or https://) - if video_path.startswith(("http://", "https://")): - return video_path - - # Otherwise, treat it as a local file path - if not os.path.exists(video_path): - raise FileNotFoundError(f"Video file not found: {video_path}") - - # Detect video MIME type from file extension - video_path_lower = video_path.lower() - if video_path_lower.endswith(".mp4"): - mime_type = "video/mp4" - elif video_path_lower.endswith(".webm"): - mime_type = "video/webm" - elif video_path_lower.endswith(".mov"): - mime_type = "video/quicktime" - elif video_path_lower.endswith(".avi"): - mime_type = "video/x-msvideo" - elif video_path_lower.endswith(".mkv"): - mime_type = "video/x-matroska" - else: - # Default to mp4 if extension is unknown - mime_type = "video/mp4" - - video_base64 = encode_base64_content_from_file(video_path) - return f"data:{mime_type};base64,{video_base64}" - - -def get_image_url_from_path(image_path: str | None) -> str: - """Convert an image path (local file or URL) to an image URL format for the API. - - If image_path is None or empty, returns the default URL. - If image_path is a local file path, encodes it to base64 data URL. - If image_path is a URL, returns it as-is. 
- """ - if not image_path: - # Default image URL - return "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" - - # Check if it's a URL (starts with http:// or https://) - if image_path.startswith(("http://", "https://")): - return image_path - - # Otherwise, treat it as a local file path - if not os.path.exists(image_path): - raise FileNotFoundError(f"Image file not found: {image_path}") - - # Detect image MIME type from file extension - image_path_lower = image_path.lower() - if image_path_lower.endswith((".jpg", ".jpeg")): - mime_type = "image/jpeg" - elif image_path_lower.endswith(".png"): - mime_type = "image/png" - elif image_path_lower.endswith(".gif"): - mime_type = "image/gif" - elif image_path_lower.endswith(".webp"): - mime_type = "image/webp" - else: - # Default to jpeg if extension is unknown - mime_type = "image/jpeg" - - image_base64 = encode_base64_content_from_file(image_path) - return f"data:{mime_type};base64,{image_base64}" - - -def get_audio_url_from_path(audio_path: str | None) -> str: - """Convert an audio path (local file or URL) to an audio URL format for the API. - - If audio_path is None or empty, returns the default URL. - If audio_path is a local file path, encodes it to base64 data URL. - If audio_path is a URL, returns it as-is. - """ - if not audio_path: - # Default audio URL - return AudioAsset("mary_had_lamb").url - - # Check if it's a URL (starts with http:// or https://) - if audio_path.startswith(("http://", "https://")): - return audio_path - - # Otherwise, treat it as a local file path - if not os.path.exists(audio_path): - raise FileNotFoundError(f"Audio file not found: {audio_path}") - - # Detect audio MIME type from file extension - audio_path_lower = audio_path.lower() - if audio_path_lower.endswith((".mp3", ".mpeg")): - mime_type = "audio/mpeg" - elif audio_path_lower.endswith(".wav"): - mime_type = "audio/wav" - elif audio_path_lower.endswith(".ogg"): - mime_type = "audio/ogg" - elif audio_path_lower.endswith(".flac"): - mime_type = "audio/flac" - elif audio_path_lower.endswith(".m4a"): - mime_type = "audio/mp4" - else: - # Default to wav if extension is unknown - mime_type = "audio/wav" - - audio_base64 = encode_base64_content_from_file(audio_path) - return f"data:{mime_type};base64,{audio_base64}" - - -def get_system_prompt(): - return { - "role": "system", - "content": [ - { - "type": "text", - "text": ( - "You are Qwen, a virtual human developed by the Qwen Team, " - "Alibaba Group, capable of perceiving auditory and visual inputs, " - "as well as generating text and speech." - ), - } - ], - } - - -def get_text_query(custom_prompt: str | None = None): - question = ( - custom_prompt or "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words." - ) - prompt = { - "role": "user", - "content": [ - { - "type": "text", - "text": f"{question}", - } - ], - } - return prompt - - -def get_mixed_modalities_query( - video_path: str | None = None, - image_path: str | None = None, - audio_path: str | None = None, - custom_prompt: str | None = None, -): - question = ( - custom_prompt or "What is recited in the audio? What is the content of this image? Why is this video funny?" 
- ) - video_url = get_video_url_from_path(video_path) - image_url = get_image_url_from_path(image_path) - audio_url = get_audio_url_from_path(audio_path) - prompt = { - "role": "user", - "content": [ - { - "type": "audio_url", - "audio_url": {"url": audio_url}, - }, - { - "type": "image_url", - "image_url": {"url": image_url}, - }, - { - "type": "video_url", - "video_url": {"url": video_url}, - }, - { - "type": "text", - "text": f"{question}", - }, - ], - } - - return prompt - - -def get_use_audio_in_video_query(video_path: str | None = None, custom_prompt: str | None = None): - question = custom_prompt or "Describe the content of the video, then convert what the baby say into text." - video_url = get_video_url_from_path(video_path) - - prompt = { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url, - "num_frames": 16, - }, - }, - { - "type": "text", - "text": f"{question}", - }, - ], - } - - return prompt - - -def get_multi_audios_query(audio_path: str | None = None, custom_prompt: str | None = None): - question = custom_prompt or "Are these two audio clips the same?" - audio_url = get_audio_url_from_path(audio_path) - prompt = { - "role": "user", - "content": [ - { - "type": "audio_url", - "audio_url": {"url": audio_url}, - }, - { - "type": "audio_url", - "audio_url": {"url": AudioAsset("winning_call").url}, - }, - { - "type": "text", - "text": f"{question}", - }, - ], - } - return prompt - - -query_map = { - "mixed_modalities": get_mixed_modalities_query, - "use_audio_in_video": get_use_audio_in_video_query, - "multi_audios": get_multi_audios_query, - "text": get_text_query, -} - - -def run_multimodal_generation(args, client: OpenAI) -> None: - model_name = "Qwen/Qwen2.5-Omni-7B" - thinker_sampling_params = { - "temperature": 0.0, # Deterministic - no randomness - "top_p": 1.0, # Disable nucleus sampling - "top_k": -1, # Disable top-k sampling - "max_tokens": 2048, - "seed": SEED, # Fixed seed for sampling - "detokenize": True, - "repetition_penalty": 1.1, - } - talker_sampling_params = { - "temperature": 0.9, - "top_p": 0.8, - "top_k": 40, - "max_tokens": 2048, - "seed": SEED, # Fixed seed for sampling - "detokenize": True, - "repetition_penalty": 1.05, - "stop_token_ids": [8294], - } - code2wav_sampling_params = { - "temperature": 0.0, # Deterministic - no randomness - "top_p": 1.0, # Disable nucleus sampling - "top_k": -1, # Disable top-k sampling - "max_tokens": 2048, - "seed": SEED, # Fixed seed for sampling - "detokenize": True, - "repetition_penalty": 1.1, - } - - sampling_params_list = [ - thinker_sampling_params, - talker_sampling_params, - code2wav_sampling_params, - ] - - # Get paths and custom prompt from args - video_path = getattr(args, "video_path", None) - image_path = getattr(args, "image_path", None) - audio_path = getattr(args, "audio_path", None) - custom_prompt = getattr(args, "prompt", None) - - # Get the query function and call it with appropriate parameters - query_func = query_map[args.query_type] - if args.query_type == "mixed_modalities": - prompt = query_func( - video_path=video_path, image_path=image_path, audio_path=audio_path, custom_prompt=custom_prompt - ) - elif args.query_type == "use_audio_in_video": - prompt = query_func(video_path=video_path, custom_prompt=custom_prompt) - elif args.query_type == "multi_audios": - prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt) - elif args.query_type == "text": - prompt = query_func(custom_prompt=custom_prompt) - else: - prompt = query_func() - - 
extra_body = { - "sampling_params_list": sampling_params_list # Optional, it has a default setting in stage_configs of the corresponding model. - } - - if args.query_type == "use_audio_in_video": - extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True} - - if args.modalities is not None: - output_modalities = args.modalities.split(",") - else: - output_modalities = None - - chat_completion = client.chat.completions.create( - messages=[ - get_system_prompt(), - prompt, - ], - model=model_name, - modalities=output_modalities, - extra_body=extra_body, - stream=args.stream, - ) - - count = 0 - if not args.stream: - for choice in chat_completion.choices: - if choice.message.audio: - audio_data = base64.b64decode(choice.message.audio.data) - audio_file_path = f"audio_{count}.wav" - with open(audio_file_path, "wb") as f: - f.write(audio_data) - print(f"Audio saved to {audio_file_path}") - count += 1 - elif choice.message.content: - print("Chat completion output from text:", choice.message.content) - else: - printed_content = False - for chunk in chat_completion: - for choice in chunk.choices: - if hasattr(choice, "delta"): - content = getattr(choice.delta, "content", None) - else: - content = None - - if getattr(chunk, "modality", None) == "audio" and content: - audio_data = base64.b64decode(content) - audio_file_path = f"audio_{count}.wav" - with open(audio_file_path, "wb") as f: - f.write(audio_data) - print(f"\nAudio saved to {audio_file_path}") - count += 1 - - elif getattr(chunk, "modality", None) == "text": - if not printed_content: - printed_content = True - print("\ncontent:", end="", flush=True) - print(content, end="", flush=True) - - -def parse_args(): - parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models") - parser.add_argument( - "--query-type", - "-q", - type=str, - default="mixed_modalities", - choices=query_map.keys(), - help="Query type.", - ) - parser.add_argument( - "--video-path", - "-v", - type=str, - default=None, - help="Path to local video file or URL. If not provided and query-type uses video, uses default video URL.", - ) - parser.add_argument( - "--image-path", - "-i", - type=str, - default=None, - help="Path to local image file or URL. If not provided and query-type uses image, uses default image URL.", - ) - parser.add_argument( - "--audio-path", - "-a", - type=str, - default=None, - help="Path to local audio file or URL. 
If not provided and query-type uses audio, uses default audio URL.", - ) - parser.add_argument( - "--prompt", - "-p", - type=str, - default=None, - help="Custom text prompt/question to use instead of the default prompt for the selected query type.", - ) - parser.add_argument( - "--modalities", - type=str, - default=None, - help="Output modalities to use for the prompts.", - ) - parser.add_argument( - "--stream", - action="store_true", - help="Stream the response.", - ) - parser.add_argument( - "--port", - type=int, - default=8091, - help="Port of the vLLM Omni API server.", - ) - parser.add_argument( - "--host", - type=str, - default="localhost", - help="Host/IP of the vLLM Omni API server.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - host = args.host - port = args.port - openai_api_base = f"http://{host}:{port}/v1" - client = OpenAI( - api_key="EMPTY", - base_url=openai_api_base, - ) - run_multimodal_generation(args, client) diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md index 9af5200642..02497f15ae 100644 --- a/examples/online_serving/qwen3_omni/README.md +++ b/examples/online_serving/qwen3_omni/README.md @@ -44,6 +44,7 @@ The Python client supports the following command-line arguments: - `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type is `use_image`, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` - `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type is `use_audio`, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` - `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` +- `--speaker`: TTS speaker/voice for audio output when requesting audio (e.g. `ethan`, `chelsie`, `aiden`). Omit to use the model default. Example: `--speaker "chelsie"` For example, to use a local video file with custom prompt: @@ -150,6 +151,57 @@ print(response.choices[0].message.content) # Text response print(response.choices[1].message.audio) # Audio response ``` +## Speaker selection + +When requesting audio output, you can choose the TTS speaker (voice) used for synthesis. If not specified, the model uses its default speaker. 
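+
+Speaker names are matched case-insensitively: the server lowercases the value before it reaches the talker stage, so `Chelsie` and `chelsie` select the same voice.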
+ +### Using curl + +Pass a `speaker` field in the request body: + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Say hello in one sentence."}], + "modalities": ["audio"], + "speaker": "chelsie" + }' +``` + +### Using Python client + +Use the `--speaker` argument when generating audio: + +```bash +python openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --modalities audio \ + --speaker "chelsie" +``` + +### Using OpenAI Python SDK + +Pass `speaker` in `extra_body`: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Qwen/Qwen3-Omni-30B-A3B-Instruct", + messages=[{"role": "user", "content": "Say hello in one sentence."}], + modalities=["audio"], + extra_body={"speaker": "chelsie"} +) +# Audio uses the specified speaker +print(response.choices[1].message.audio) +``` + +Supported speaker names depend on the model (e.g. `Ethan`, `Chelsie`, `Aiden`). Omit `speaker` to use the default. + ## Streaming Output If you want to enable streaming output, please set the argument as below. The final output will be obtained just after generated by corresponding stage. Now we only support text streaming output. Other modalities can output normally. ```bash diff --git a/examples/online_serving/qwen3_tts/openai_speech_client.py b/examples/online_serving/qwen3_tts/openai_speech_client.py index e008eb139a..064ae50bc2 100644 --- a/examples/online_serving/qwen3_tts/openai_speech_client.py +++ b/examples/online_serving/qwen3_tts/openai_speech_client.py @@ -65,7 +65,7 @@ def run_tts_generation(args) -> None: payload = { "model": args.model, "input": args.text, - "voice": args.voice, + "speaker": args.speaker, "response_format": args.response_format, } @@ -93,7 +93,7 @@ def run_tts_generation(args) -> None: print(f"Model: {args.model}") print(f"Task type: {args.task_type or 'CustomVoice'}") print(f"Text: {args.text}") - print(f"Voice: {args.voice}") + print(f"Speaker: {args.speaker}") print("Generating audio...") # Make the API call @@ -176,10 +176,10 @@ def parse_args(): # Voice/speaker parser.add_argument( - "--voice", + "--speaker", type=str, default="vivian", - help="Speaker/voice name (default: vivian). Options: vivian, ryan, aiden, etc.", + help="Speaker name (default: vivian). 
Options: vivian, ryan, aiden, etc.", ) parser.add_argument( "--language", diff --git a/examples/online_serving/qwen3_tts/streaming_speech_client.py b/examples/online_serving/qwen3_tts/streaming_speech_client.py index 785c6a0e8d..8f09409cea 100644 --- a/examples/online_serving/qwen3_tts/streaming_speech_client.py +++ b/examples/online_serving/qwen3_tts/streaming_speech_client.py @@ -163,7 +163,7 @@ def main(): # Session config options parser.add_argument("--model", default=None, help="Model name") - parser.add_argument("--voice", default="Vivian", help="Speaker voice") + parser.add_argument("--speaker", default="Vivian", help="Speaker name") parser.add_argument( "--task-type", default="CustomVoice", @@ -215,7 +215,7 @@ def main(): config = {} for key in [ "model", - "voice", + "speaker", "task_type", "language", "instructions", diff --git a/examples/online_serving/qwen3_tts/tts_common.py b/examples/online_serving/qwen3_tts/tts_common.py index b31a75d808..4a44b997f7 100644 --- a/examples/online_serving/qwen3_tts/tts_common.py +++ b/examples/online_serving/qwen3_tts/tts_common.py @@ -100,7 +100,7 @@ def build_payload( if task_type == "CustomVoice": if voice: - payload["voice"] = voice + payload["speaker"] = voice if instructions and instructions.strip(): payload["instructions"] = instructions.strip() diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index e577432262..a479766638 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -9,7 +9,7 @@ class OpenAICreateSpeechRequest(BaseModel): model: str | None = None voice: str | None = Field( default=None, - description="Voice to use. For OpenAI: alloy, echo, etc. For Qwen3-TTS: Vivian, Ryan, etc.", + description="Speaker/voice to use. For Qwen3-TTS: vivian, ryan, aiden, etc.", ) instructions: str | None = Field( default=None, diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 39828e95c7..ad472e94c1 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -533,6 +533,18 @@ async def _preprocess_chat( if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt + speaker = getattr(request, "speaker", None) + if speaker is not None and isinstance(speaker, str) and speaker.strip(): + if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None: + engine_prompt["additional_information"] = {} + engine_prompt["additional_information"]["speaker"] = [speaker.lower().strip()] + + language = getattr(request, "language", None) + if language is not None and isinstance(language, str) and language.strip(): + if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None: + engine_prompt["additional_information"] = {} + engine_prompt["additional_information"]["language"] = [language.strip()] + return conversation, [engine_prompt] async def _inject_audio_from_video_urls( diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 2906288d50..3c8278f14a 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -592,7 +592,7 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str "or use a CustomVoice model." 
) if request.voice is not None and request.voice not in self.supported_speakers: - return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" + return f"Invalid voice '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" # Validate Base task requirements if task_type == "Base": diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py index 726cc2643a..a83aaff8e6 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py @@ -697,7 +697,7 @@ def thinker_to_talker_process( # TODO(Peiqi): add voice_type support req_input_ids, req_embeds = self._thinker_to_talker_prefill( - voice_type=self.voice_type, + speaker=self.voice_type, output_prompt_embeds=thinker_result.to(input_embeds.dtype).to(self._module_device(self.model)), output_token_ids=thinker_output_token_ids, thinker_prompt_embeds=prompt_embeds.to(input_embeds.dtype).to(self._module_device(self.model)), @@ -711,7 +711,7 @@ def thinker_to_talker_process( def _thinker_to_talker_prefill( self, - voice_type: str, + speaker: str, output_prompt_embeds, output_token_ids, thinker_prompt_embeds, @@ -726,7 +726,7 @@ def _thinker_to_talker_prefill( prompt_embeds = torch.cat( [ thinker_prompt_embeds, - self._get_embed_text_spk_token(voice_type) + self.embed_codec_pad_token, + self._get_embed_text_spk_token(speaker) + self.embed_codec_pad_token, output_prompt_embeds[:1] + self.embed_codec_bos_token, ], dim=0, diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 21f0185aa3..6402650ed7 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -322,10 +322,6 @@ def forward( if inputs_embeds is None and input_ids is not None: inputs_embeds = self.talker.embed_input_ids(input_ids) - # TODO(Peiqi): temporal hack here to support voice_type. 
- if not hasattr(self, "voice_type"): - self.voice_type = voice_type - # Run talker forward with torch.inference_mode(): talker_hidden = self.talker.forward( @@ -377,7 +373,7 @@ def forward( left_context_size.append(info["left_context_size"]) else: logger.debug("No additional_information provided to code2wav stage.") - audio_tensors = self.generate_audio(codes, voice_type, left_context_size, seq_token_counts) + audio_tensors = self.generate_audio(codes, left_context_size, seq_token_counts) return audio_tensors @@ -457,7 +453,6 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) - def generate_audio( self, code: torch.Tensor, - voice_type: str, left_context_size: list[int] | None = None, seq_token_counts: list[int] | None = None, ) -> list[torch.Tensor]: @@ -466,7 +461,6 @@ def generate_audio( Args: code: [batch, num_quantizers, T] - RVQ codec codes - voice_type: Voice type (not used in Qwen3, kept for compatibility) left_context_size: Left context size for streaming decode seq_token_counts: Token count for each request in batch @@ -678,8 +672,16 @@ def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor: def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict): # Containers to return per-request updates (e.g., code_predictor_hidden_per_request) update_dict: dict[str, dict] = {} - # TODO(Peiqi): add voice_type support - voice_type = self.voice_type + + voice_type = info_dict.get("speaker") + logger.info("talker_preprocess_prefill speaker: %s", voice_type) + if voice_type is not None and isinstance(voice_type, (list, tuple)) and len(voice_type) > 0: + voice_type = voice_type[0] + if not isinstance(voice_type, str) or not voice_type.strip(): + # Fall back to model default; speaker is per-request. 
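+            # (covers a missing key, an empty string, and non-string values)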
+ voice_type = self.default_tts_text_spk_type + else: + voice_type = str(voice_type).lower().strip() start_index = info_dict.get("num_processed_tokens", 0) end_index = start_index + input_embeds.shape[0] # Read thinker outputs for prefill diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index de248f0f33..39d9dba2d2 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -700,7 +700,7 @@ def _first(x: object, default: object) -> object: info: dict[str, Any] = additional_information or {} text = _first(info.get("text"), "") language = _first(info.get("language"), "Auto") - speaker = _first(info.get("speaker"), "") + speaker = _first(info.get("speaker"), "").lower().strip() instruct = _first(info.get("instruct"), "") non_streaming_mode_raw = _first(info.get("non_streaming_mode"), None) @@ -1438,11 +1438,14 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: ) elif task_type == "CustomVoice": - speaker = (info_dict.get("speaker") or [""])[0] - if not isinstance(speaker, str) or not speaker.strip(): + _speaker_raw = info_dict.get("speaker") or [""] + speaker = ( + ((_speaker_raw[0] if isinstance(_speaker_raw, (list, tuple)) else _speaker_raw) or "").lower().strip() + ) + if not speaker: raise ValueError("CustomVoice requires additional_information.speaker.") spk_id_map = getattr(self.talker_config, "spk_id", None) or {} - if speaker.lower() not in spk_id_map: + if speaker not in spk_id_map: raise ValueError(f"Unsupported speaker: {speaker}") spk_id = spk_id_map[speaker.lower()] # Keep it at least 1D; embedding on a 0-d tensor can return 1D. diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 7cfc59f79c..f4828fddaa 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -11,6 +11,12 @@ from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.model_executor.stage_input_processors.tts_utils import ( + extract_language_from_prompt, + extract_language_from_request, + extract_speaker_from_prompt, + extract_speaker_from_request, +) def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: @@ -115,6 +121,12 @@ def thinker2talker_async_chunk( "tts_pad_embed": pooling_output.get("tts_pad_embed").detach().cpu(), "finished": torch.tensor(is_finished, dtype=torch.bool), } + speaker = extract_speaker_from_request(request) + if speaker is not None: + talker_additional_info["speaker"] = speaker + language = extract_language_from_request(request) + if language is not None: + talker_additional_info["language"] = language if transfer_manager.request_payload.get(request_id) is None: if not is_finished: transfer_manager.request_payload[request_id] = talker_additional_info @@ -140,6 +152,13 @@ def thinker2talker_async_chunk( talker_additional_info = { "finished": torch.tensor(is_finished, dtype=torch.bool), } + speaker = extract_speaker_from_request(request) + if speaker is not None: + talker_additional_info["speaker"] = speaker + language = extract_language_from_request(request) + if language is not None: + talker_additional_info["language"] = language + if output_token_ids: talker_additional_info["override_keys"] = 
["thinker_decode_embeddings", "thinker_output_token_ids"] talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu() @@ -148,6 +167,7 @@ def thinker2talker_async_chunk( # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. talker_additional_info["thinker_prefill_embeddings"] = pooling_output.get("0").detach().cpu() talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu() + return talker_additional_info @@ -180,7 +200,7 @@ def thinker2talker( device = torch.device(current_platform.device_type) # Process each thinker output - for thinker_output in thinker_outputs: + for i, thinker_output in enumerate(thinker_outputs): output = thinker_output.outputs[0] info = { @@ -195,6 +215,12 @@ def thinker2talker( "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), } + speaker = extract_speaker_from_prompt(prompt, index=i) + if speaker is not None: + info["speaker"] = speaker + language = extract_language_from_prompt(prompt, index=i) + if language is not None: + info["language"] = language prompt_len = _compute_talker_prompt_ids_length(info, device=device) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 69724dfc09..934dfb20e9 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -9,6 +9,12 @@ compute_dynamic_initial_chunk_size, max_ic_for_chunk_size, ) +from vllm_omni.model_executor.stage_input_processors.tts_utils import ( + extract_language_from_prompt, + extract_language_from_request, + extract_speaker_from_prompt, + extract_speaker_from_request, +) logger = init_logger(__name__) @@ -25,7 +31,7 @@ def talker2code2wav( talker_outputs = _validate_stage_inputs(stage_list, engine_input_source) code2wav_inputs: list[OmniTokensPrompt] = [] - for talker_output in talker_outputs: + for i, talker_output in enumerate(talker_outputs): output = talker_output.outputs[0] # audio_codes shape: [num_frames, Q] where Q=num_quantizers (16) audio_codes = output.multimodal_output["audio_codes"].to(torch.long) @@ -45,13 +51,24 @@ def talker2code2wav( ref_code_len = 0 # Code2Wav expects codebook-major flat: [Q*num_frames] codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist() - additional_information = {"left_context_size": [ref_code_len]} if ref_code_len > 0 else None + additional_information: dict[str, Any] = {} + if ref_code_len > 0: + additional_information["left_context_size"] = [ref_code_len] + # Propagate speaker and language from the original prompt so they are + # available as runtime_additional_information in later pipeline stages, + # consistent with qwen3-omni and qwen2.5-omni stage input processors. 
+ speaker = extract_speaker_from_prompt(prompt, index=i) + if speaker is not None: + additional_information["speaker"] = speaker + language = extract_language_from_prompt(prompt, index=i) + if language is not None: + additional_information["language"] = language code2wav_inputs.append( OmniTokensPrompt( prompt_token_ids=codec_codes, multi_modal_data=None, mm_processor_kwargs=None, - additional_information=additional_information, + additional_information=additional_information if additional_information else None, ) ) return code2wav_inputs @@ -194,8 +211,18 @@ def talker2code2wav_async_chunk( num_frames = len(window_frames) code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)] - return { + info: dict[str, Any] = { "code_predictor_codes": code_predictor_codes, "left_context_size": left_context_size, "finished": finished, } + # Propagate speaker and language from the request so they are available + # as runtime_additional_information in subsequent pipeline stages, consistent + # with qwen3-omni and qwen2.5-omni stage input processors. + speaker = extract_speaker_from_request(request) + if speaker is not None: + info["speaker"] = speaker + language = extract_language_from_request(request) + if language is not None: + info["language"] = language + return info diff --git a/vllm_omni/model_executor/stage_input_processors/tts_utils.py b/vllm_omni/model_executor/stage_input_processors/tts_utils.py new file mode 100644 index 0000000000..1bc78b4a20 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/tts_utils.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +"""Shared TTS utility functions for speaker and language extraction. + +These utilities are model-agnostic and can be used by any TTS model stage +processor (qwen3_omni, qwen2_5_omni, qwen3_tts, etc.). +""" + +from typing import Any + +# ============================================================================= +# Speaker helpers +# ============================================================================= + + +def extract_speaker_from_runtime_info( + runtime_additional_information: list[dict[str, Any]] | None, +) -> str | None: + """Extract speaker from per-request runtime info dicts. + + Iterates through the list of per-request info dicts and returns the first + non-empty speaker string found, normalized to lowercase. + + Args: + runtime_additional_information: List of per-request additional info + dicts, as passed to the model's forward() method. + + Returns: + The speaker string (lowercase, stripped), or None if not present. + """ + if not runtime_additional_information: + return None + for info in runtime_additional_information: + vt = info.get("speaker") + if vt is None: + continue + if isinstance(vt, (list, tuple)) and len(vt) > 0: + vt = vt[0] + if isinstance(vt, str) and vt.strip(): + return vt.lower().strip() + if vt is not None: + return str(vt).lower().strip() + return None + + +def extract_speaker_from_request(request: Any) -> str | None: + """Extract speaker from a request's additional_information field. + + Reads from the structured ``additional_information.entries["speaker"]`` + field used by the engine serialization layer. + + Args: + request: An OmniEngineCoreRequest (or compatible object) with an + ``additional_information`` attribute. + + Returns: + The speaker string (lowercase, stripped), or None if not present. 
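+
+    Example:
+        An entry whose ``list_data`` is ``["Vivian"]`` yields ``"vivian"``.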
+ """ + additional_information = getattr(request, "additional_information", None) + if additional_information is None: + return None + entries = getattr(additional_information, "entries", None) + if not isinstance(entries, dict): + return None + entry = entries.get("speaker") + if entry is None: + return None + list_data = getattr(entry, "list_data", None) + if isinstance(list_data, list) and list_data: + val = list_data[0] + return val.lower().strip() if isinstance(val, str) else str(val).lower().strip() + return None + + +def extract_speaker_from_prompt( + prompt: Any, + index: int = 0, +) -> list[str] | None: + """Extract speaker from a prompt's additional_information dict. + + Used in non-async stage processors where the prompt is an + OmniTokensPrompt / TextPrompt dict (or a list of them). + + Args: + prompt: A single prompt dict, or a list of prompt dicts. + index: Which element to pick when prompt is a list. + + Returns: + The speaker as a list (for serialization compatibility), or None. + """ + if prompt is None: + return None + p = prompt[index] if isinstance(prompt, list) and index < len(prompt) else prompt + if p is None: + return None + add_info = p.get("additional_information") + if not isinstance(add_info, dict): + return None + speaker = add_info.get("speaker") + if isinstance(speaker, list) and speaker: + return speaker + return None + + +# ============================================================================= +# Language helpers +# ============================================================================= + + +def extract_language_from_runtime_info( + runtime_additional_information: list[dict[str, Any]] | None, +) -> str | None: + """Extract language from per-request runtime info dicts. + Args: + runtime_additional_information: List of per-request additional info + dicts, as passed to the model's forward() method. + + Returns: + The language string (e.g. "Chinese", "English", "Auto"), or None. + """ + if not runtime_additional_information: + return None + for info in runtime_additional_information: + lang = info.get("language") + if lang is None: + continue + if isinstance(lang, (list, tuple)) and len(lang) > 0: + return lang + if isinstance(lang, str) and lang.strip(): + return [lang.strip()] + return None + + +def extract_language_from_request(request: Any) -> str | None: + """Extract language from a request's additional_information field. + + Args: + request: An OmniEngineCoreRequest (or compatible object) with an + ``additional_information`` attribute. + + Returns: + The language string, or None if not present. + """ + additional_information = getattr(request, "additional_information", None) + if additional_information is None: + return None + entries = getattr(additional_information, "entries", None) + if not isinstance(entries, dict): + return None + entry = entries.get("language") + if entry is None: + return None + list_data = getattr(entry, "list_data", None) + if isinstance(list_data, list) and list_data: + return list_data + return None + + +def extract_language_from_prompt( + prompt: Any, + index: int = 0, +) -> list[str] | None: + """Extract language from a prompt's additional_information dict. + Args: + prompt: A single prompt dict, or a list of prompt dicts. + index: Which element to pick when prompt is a list. + + Returns: + The language as a list (for serialization compatibility), or None. 
+ """ + if prompt is None: + return None + p = prompt[index] if isinstance(prompt, list) and index < len(prompt) else prompt + if p is None: + return None + add_info = p.get("additional_information") + if not isinstance(add_info, dict): + return None + language = add_info.get("language") + if isinstance(language, list) and language: + return language + return None