diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md index 7414163fc0..be90b408d1 100644 --- a/examples/offline_inference/ming_flash_omni/README.md +++ b/examples/offline_inference/ming_flash_omni/README.md @@ -1,75 +1,91 @@ # Ming-flash-omni 2.0 -[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with outputs in text, image, and audio. For now, Ming-flash-omni-2.0 in vLLM-Omni is supported with thinker stage (multi-modal understanding). +[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with text and speech outputs. + +vLLM-Omni supports two deployment modes: + +| Mode | Stage config | Output | +|------|-------------|--------| +| Thinker only (multimodal understanding) | `ming_flash_omni_thinker.yaml` (default `--omni`) | Text | +| Thinker + Talker (omni-speech) | `ming_flash_omni.yaml` | Text + Audio | + +For standalone TTS (talker only), see [`examples/offline_inference/ming_flash_omni_tts/`](../ming_flash_omni_tts/). ## Setup Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. +The default `--omni` flag runs thinker only. For omni-speech, pass the two-stage config explicitly: + +```bash +--stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml +``` + ## Run examples -### Text-only +The end-to-end script defaults to built-in assets; pass `--image-path`, +`--audio-path`, or `--video-path` to override. + ```bash +# Text-only python examples/offline_inference/ming_flash_omni/end2end.py --query-type text + +# Image / audio / video / mixed understanding +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --num-frames 16 +python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \ + --image-path /path/to/image.jpg --audio-path /path/to/audio.wav ``` #### Reasoning (Thinking Mode) -Reasoning (Thinking) mode is enabled via applying "detailed thinking on" when building the system prompt template (in `apply_chat_template`). - -In the end2end example, a default problem for thinking mode is provided, as referred to the example usage of Ming's cookbook; -To utilize it, you have to download the example figure from https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png +Reasoning ("detailed thinking on") is applied by the script when +`--query-type reasoning` is set. The default prompt matches Ming's cookbook +and expects the reference figure from the upstream repo — see +`get_reasoning_query` in `end2end.py`. 
```bash python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png ``` -### Image understanding -```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image +### Omni-speech (thinker + talker) -# With a local image -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image --image-path /path/to/image.jpg -``` +To enable spoken output, use the two-stage config and request `audio` (or `text,audio`) modalities. +The thinker processes your multimodal input, generates text, then the talker synthesises the response as speech. -### Audio understanding +**Audio-only output** (speech response, no text): ```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio - -# With a local audio file -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --audio-path /path/to/audio.wav +python examples/offline_inference/ming_flash_omni/end2end.py \ + --query-type text \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \ + --modalities audio \ + --output-dir output_ming_omni_speech ``` -### Video understanding +**Both text and audio output**: ```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video - -# With a local video and custom frame count -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --video-path /path/to/video.mp4 --num-frames 16 +python examples/offline_inference/ming_flash_omni/end2end.py \ + --query-type use_audio \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \ + --modalities text,audio \ + --output-dir output_ming_omni_speech ``` -### Mixed modalities (image + audio) -```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \ - --image-path /path/to/image.jpg \ - --audio-path /path/to/audio.wav -``` +Generated `.wav` files are saved to `--output-dir` (default `output_ming`), one per request. -If media file paths are not provided, the script uses built-in default assets. +The stage config allocates thinker on GPUs 0–3 and talker on GPU 3 by default. Adjust `devices` in the YAML to match your hardware. ### Modality control -To control output modalities (e.g. text-only output): -```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --modalities text -``` -*For now, only text output is supported* +| `--modalities` | Thinker output | Talker | Saved files | +|---------------|----------------|--------|-------------| +| `text` (default) | Text | Not run | `.txt` | +| `audio` | Text (internal) | Runs | `.wav` | +| `text,audio` | Text | Runs | `.txt` + `.wav` | -### Custom stage config -```bash -python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image \ - --stage-configs-path /path/to/your_config.yaml -``` +Pass `--stage-configs-path /path/to/your_config.yaml` to any of the commands +above to override the stage config. 
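+
+If you adapt `end2end.py` into your own driver script, the talker's audio payload
+is read from `multimodal_output` exactly as the example handles it. A minimal
+sketch of that extraction step (the surrounding generate loop, prompt
+construction, and sampling params are omitted; see `end2end.py` for the full flow):
+
+```python
+import soundfile as sf
+
+
+def save_ming_audio(output, out_path: str) -> None:
+    """Write one finished request's talker waveform to a WAV file."""
+    mm = output.outputs[0].multimodal_output
+    if not mm or "audio" not in mm:
+        return
+    # "sr" is the AudioVAE sample rate (44.1 kHz for Ming) and may arrive as a
+    # tensor, so unwrap it before handing it to soundfile.
+    sr_raw = mm.get("sr", 44100)
+    sample_rate = int(sr_raw.item() if hasattr(sr_raw, "item") else sr_raw)
+    waveform = mm["audio"].float().squeeze().cpu().numpy()
+    sf.write(out_path, waveform, samplerate=sample_rate)
+```
+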
## Online serving diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py index 8f87301316..e00dcea7bb 100644 --- a/examples/offline_inference/ming_flash_omni/end2end.py +++ b/examples/offline_inference/ming_flash_omni/end2end.py @@ -7,6 +7,7 @@ from typing import NamedTuple import numpy as np +import soundfile as sf import vllm from PIL import Image from transformers import AutoProcessor @@ -319,7 +320,16 @@ def main(args): seed=SEED, detokenize=True, ) - sampling_params_list = [thinker_sampling_params] + # Talker (ming_tts) uses a custom generation loop (CFM + AudioVAE); + # vLLM sampling is a no-op here — max_tokens=1 just satisfies the scheduler. + talker_sampling_params = SamplingParams( + temperature=0.0, + max_tokens=1, + ) + all_sampling_params = [thinker_sampling_params, talker_sampling_params] + # Match sampling params to the number of configured stages + # (thinker-only yaml → 1, thinker+talker yaml → 2). + sampling_params_list = all_sampling_params[: omni.num_stages] prompts = [query_result.inputs for _ in range(args.num_prompts)] @@ -362,7 +372,19 @@ def main(args): print(f"Failed to write output file {out_txt}: {e}") elif stage_outputs.final_output_type == "audio": - raise NotImplementedError("Add audio example after talker supported.") + request_id = output.request_id + mm = output.outputs[0].multimodal_output + if mm and "audio" in mm: + audio = mm["audio"] + sr_raw = mm.get("sr", 44100) + sample_rate = int(sr_raw.item() if hasattr(sr_raw, "item") else sr_raw) + audio_numpy = audio.float().squeeze().cpu().numpy() + output_wav = os.path.join(output_dir, f"{request_id}.wav") + sf.write(output_wav, audio_numpy, samplerate=sample_rate, format="WAV") + print( + f"Request ID: {request_id}, audio saved to {output_wav} " + f"({len(audio_numpy) / sample_rate:.2f}s, {sample_rate}Hz)" + ) processed_count += 1 if profiler_enabled and processed_count >= total_requests: diff --git a/examples/offline_inference/ming_flash_omni_tts/README.md b/examples/offline_inference/ming_flash_omni_tts/README.md new file mode 100644 index 0000000000..15b84041df --- /dev/null +++ b/examples/offline_inference/ming_flash_omni_tts/README.md @@ -0,0 +1,47 @@ +# Ming-flash-omni Standalone TTS (Offline) + +This example runs **Ming-flash-omni-2.0 talker-only** offline inference with: + +- `model`: `Jonathan1909/Ming-flash-omni-2.0` +- `stage config`: `vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml` + +It follows the Ming cookbook parameter style: + +- `prompt`: `"Please generate speech based on the following description.\n"` +- `max_decode_steps`: `200` +- `cfg`: `2.0` +- `sigma`: `0.25` +- `temperature`: `0.0` + +## Quick Start + +```bash +python examples/offline_inference/ming_flash_omni_tts/end2end.py --case style +``` + +## Cases + +```bash +# Style +python examples/offline_inference/ming_flash_omni_tts/end2end.py --case style + +# IP +python examples/offline_inference/ming_flash_omni_tts/end2end.py --case ip + +# Basic (speed/pitch/volume control) +python examples/offline_inference/ming_flash_omni_tts/end2end.py --case basic +``` + +## Useful Arguments + +- `--text`: override default text in the selected case +- `--output`: custom output wav path +- `--model`: local model path or HF repo id +- `--stage-configs-path`: custom talker stage config path +- `--log-stats`: enable runtime stats logs + +## Notes + +- This directory is for **standalone talker deployment (TTS)**. 
+- For Ming thinker multimodal understanding examples, see: + `examples/offline_inference/ming_flash_omni/`. diff --git a/examples/offline_inference/ming_flash_omni_tts/end2end.py b/examples/offline_inference/ming_flash_omni_tts/end2end.py new file mode 100644 index 0000000000..0a2d25646f --- /dev/null +++ b/examples/offline_inference/ming_flash_omni_tts/end2end.py @@ -0,0 +1,129 @@ +"""Offline e2e example for Ming-flash-omni-2.0 standalone talker (TTS).""" + +import os +from typing import Any + +import soundfile as sf +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniTokensPrompt +from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import ( + DEFAULT_PROMPT, + create_instruction, +) + +MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0" +DEFAULT_STAGE_CONFIG = "vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml" + + +def get_messages(case: str, text_override: str | None) -> dict[str, Any]: + if case == "style": + text = text_override or "我会一直在这里陪着你,直到你慢慢、慢慢地沉入那个最温柔的梦里……好吗?" + instruction = create_instruction( + { + "风格": "这是一种ASMR耳语,属于一种旨在引发特殊感官体验的创意风格。这个女性使用轻柔的普通话进行耳语,声音气音成分重。音量极低,紧贴麦克风,语速极慢,旨在制造触发听者颅内快感的声学刺激。", + } + ) + return { + "prompt": DEFAULT_PROMPT, + "text": text, + "instruction": instruction, + "use_zero_spk_emb": True, + } + if case == "ip": + text = text_override or "这款产品的名字,叫变态坑爹牛肉丸。" + return { + "prompt": DEFAULT_PROMPT, + "text": text, + "instruction": create_instruction({"IP": "灵小甄"}), + "use_zero_spk_emb": True, + } + if case == "basic": + text = text_override or "我们当迎着阳光辛勤耕作,去摘取,去制作,去品尝,去馈赠。" + return { + "prompt": DEFAULT_PROMPT, + "text": text, + "instruction": create_instruction({"语速": "快速", "基频": "中", "音量": "中"}), + "use_zero_spk_emb": True, + } + raise ValueError(f"Unknown case: {case}") + + +def save_audio(mm: dict[str, Any], output_path: str) -> None: + if not mm or "audio" not in mm: + raise RuntimeError("No audio found in model output") + audio = mm["audio"] + sr_raw = mm.get("sr", 44100) + if isinstance(sr_raw, torch.Tensor): + sample_rate = int(sr_raw.item()) + else: + sample_rate = int(sr_raw) + waveform = audio.squeeze().float().cpu().numpy() + sf.write(output_path, waveform, sample_rate) + print(f"Saved {output_path} ({len(waveform) / sample_rate:.2f}s, {sample_rate}Hz)") + + +def parse_args(): + parser = FlexibleArgumentParser(description="Ming-flash-omni standalone talker offline e2e example") + parser.add_argument("--model", type=str, default=MODEL_NAME, help="Model name or local path.") + parser.add_argument( + "--stage-configs-path", + type=str, + default=DEFAULT_STAGE_CONFIG, + help="Path to stage configs yaml for standalone talker deployment.", + ) + parser.add_argument( + "--case", + type=str, + default="style", + choices=["style", "ip", "basic"], + help="Example case.", + ) + parser.add_argument("--text", type=str, default=None, help="Override default text for the selected case.") + parser.add_argument("--output", type=str, default=None, help="Output wav path.") + parser.add_argument("--log-stats", action="store_true", default=False, help="Enable stats logging.") + parser.add_argument("--init-timeout", type=int, default=600, help="Engine init timeout in seconds.") + parser.add_argument("--stage-init-timeout", type=int, default=300, help="Single stage init timeout in seconds.") + return parser.parse_args() + + +def main(): + args = parse_args() + + omni = 
Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + trust_remote_code=True, + log_stats=args.log_stats, + init_timeout=args.init_timeout, + stage_init_timeout=args.stage_init_timeout, + ) + + messages = get_messages(args.case, args.text) + decode_args = { + # Standalone TTS deployment + "ming_task": "instruct", + "max_decode_steps": 200, + "cfg": 2.0, + "sigma": 0.25, + "temperature": 0.0, + } + req = OmniTokensPrompt( + prompt_token_ids=[0], + additional_information={**messages, **decode_args}, + ) + + outputs = omni.generate(req) + mm = outputs[0].outputs[0].multimodal_output + + output_path = args.output or f"output_{args.case}.wav" + save_audio(mm, output_path) + omni.close() + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md index 502232725c..8b7d03e211 100644 --- a/examples/online_serving/ming_flash_omni/README.md +++ b/examples/online_serving/ming_flash_omni/README.md @@ -4,38 +4,49 @@ Please refer to [README.md](../../../README.md) +## Deployment modes + +| Mode | Launch command | Output | +|------|---------------|--------| +| Thinker only (multimodal understanding) | `vllm serve ... --omni` | Text | +| Thinker + Talker (omni-speech) | `vllm serve ... --omni --stage-configs-path ming_flash_omni.yaml` | Text + Audio | + +For standalone TTS (talker only), see [`examples/online_serving/ming_flash_omni_tts/`](../ming_flash_omni_tts/). + ## Run examples (Ming-flash-omni 2.0) ### Launch the Server +**Thinker only (text output):** ```bash vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 ``` -If you have custom stage configs file, launch the server with command below +**Thinker + Talker (omni-speech, text + audio output):** ```bash -vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml ``` +Pass `--stage-configs-path /path/to/your_config.yaml` to use a custom stage +config. + ### Send Multi-modal Request -#### Send request via python +Shared Python client (supports `text | use_image | use_audio | use_video | +use_mixed_modalities`; pass `--image-path` / `--audio-path` / `--video-path` +for local files or URLs, `--modalities text` for output, `--help` for the +full flag list): ```bash -python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Jonathan1909/Ming-flash-omni-2.0 --query-type use_mixed_modalities --port 8091 --host "localhost" --modalities text +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \ + --model Jonathan1909/Ming-flash-omni-2.0 \ + --query-type use_mixed_modalities \ + --port 8091 --host localhost \ + --modalities text ``` -The Python client supports the following command-line arguments: - -- `--query-type` (or `-q`): Query type. Options: `text`, `use_audio`, `use_image`, `use_video`, `use_mixed_modalities` -- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` -- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. 
Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` -- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` -- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` -- `--modalities`: Output modalities. For now, only `text` is supported. Example: `--modalities text` - - -#### Send request via curl +Parameterized curl wrapper in this directory: ```bash bash run_curl_multimodal_generation.sh text @@ -47,83 +58,17 @@ bash run_curl_multimodal_generation.sh use_mixed_modalities ## Modality control -Ming-flash-omni 2.0 currently supports text output only (thinker stage). - -| Modalities | Output | -|------------|--------| -| `["text"]` | Text only | -| Not specified | Text only (default) | - -### Using curl - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Jonathan1909/Ming-flash-omni-2.0", - "messages": [ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, - {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} - ], - "modalities": ["text"] - }' -``` - -### Using OpenAI Python SDK - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Jonathan1909/Ming-flash-omni-2.0", - messages=[ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, - {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}, - ], - modalities=["text"], -) -print(response.choices[0].message.content) -``` - -### Multi-modal input with OpenAI Python SDK - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Jonathan1909/Ming-flash-omni-2.0", - messages=[ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}}, - {"type": "text", "text": "Describe this image in detail."}, - ], - }, - ], - modalities=["text"], -) -print(response.choices[0].message.content) -``` - -## Streaming Output +| `modalities` | Server config | Output | +|-------------|--------------|--------| +| `["text"]` or omitted | Thinker only | Text | +| `["audio"]` | Thinker + Talker | Audio (speech) | +| `["text", "audio"]` | Thinker + Talker | Text + Audio | -To enable streaming output: - -```bash -python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \ - --query-type use_image \ - --model Jonathan1909/Ming-flash-omni-2.0 \ - --modalities text \ - --stream -``` +For ready-to-copy curl examples (text / audio / multimodal input, SSE +streaming, reasoning mode), see the recipe at 
+[`recipes/inclusionAI/Ming-flash-omni-2.0.md`](../../../recipes/inclusionAI/Ming-flash-omni-2.0.md). -Or with the OpenAI Python SDK: +## OpenAI Python SDK — streaming ```python from openai import OpenAI @@ -146,59 +91,5 @@ for chunk in response: print() ``` -Or using curl: - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Jonathan1909/Ming-flash-omni-2.0", - "messages": [ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, - {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} - ], - "modalities": ["text"], - "stream": true, - }' -``` - - -## Reasoning (Thinking Mode) - -To enable reasoning/thinking mode, change `detailed thinking off` to `detailed thinking on` in the system prompt: - -### Using curl - -```bash -curl http://localhost:8091/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "Jonathan1909/Ming-flash-omni-2.0", - "messages": [ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, - {"role": "user", "content": [ - {"type": "image_url", "image_url": {"url": "https://example.com/math_problem.png"}}, - {"type": "text", "text": "Solve this math problem step by step."} - ]} - ], - "modalities": ["text"] - }' -``` - -### Using OpenAI Python SDK - -```python -from openai import OpenAI - -client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") - -response = client.chat.completions.create( - model="Jonathan1909/Ming-flash-omni-2.0", - messages=[ - {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, - {"role": "user", "content": "If a train travels 120 km in 2 hours, what is its average speed?"}, - ], - modalities=["text"], -) -print(response.choices[0].message.content) -``` +The `--stream` flag on the Python client script above shows the same pattern +driven by the shared multimodal client. diff --git a/examples/online_serving/ming_flash_omni_tts/README.md b/examples/online_serving/ming_flash_omni_tts/README.md new file mode 100644 index 0000000000..e56d9ea12a --- /dev/null +++ b/examples/online_serving/ming_flash_omni_tts/README.md @@ -0,0 +1,54 @@ +# Ming-flash-omni Standalone TTS (Online Serving) + +This directory contains online e2e examples for **Ming-flash-omni-2.0 standalone talker deployment**. + +Server uses: + +- `model`: `Jonathan1909/Ming-flash-omni-2.0` +- `stage config`: `vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml` + +## Launch the Server + +```bash +# from repo root +bash examples/online_serving/ming_flash_omni_tts/run_server.sh +``` + +Equivalent manual command: + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml \ + --host 0.0.0.0 \ + --port 8091 \ + --trust-remote-code \ + --omni +``` + +## Send TTS Request + +Python client: + +```bash +python examples/online_serving/ming_flash_omni_tts/speech_client.py \ + --text "我们当迎着阳光辛勤耕作,去摘取,去制作,去品尝,去馈赠。" \ + --output ming_online.wav +``` + +Long-form `instructions` (e.g. ASMR whisper style) via the client: + +```bash +python examples/online_serving/ming_flash_omni_tts/speech_client.py \ + --text "我会一直在这里陪着你,直到你慢慢、慢慢地沉入那个最温柔的梦里……好吗?" 
\ + --instructions "这是一种ASMR耳语,属于一种旨在引发特殊感官体验的创意风格。这个女性使用轻柔的普通话进行耳语,声音气音成分重。音量极低,紧贴麦克风,语速极慢,旨在制造触发听者颅内快感的声学刺激。" \ + --output ming_online_asmr.wav +``` + +## Notes + +- This is the **online serving** counterpart of `examples/offline_inference/ming_flash_omni_tts/`. +- The server uses `use_zero_spk_emb=True` and the default decode args + (`max_decode_steps=200`, `cfg=2.0`, `sigma=0.25`, `temperature=0.0`). + For other caption fields (`语速`, `基频`, `IP`, BGM, etc.) or overriding + decode args, use the offline e2e example where `additional_information` + is set explicitly. diff --git a/examples/online_serving/ming_flash_omni_tts/run_server.sh b/examples/online_serving/ming_flash_omni_tts/run_server.sh new file mode 100755 index 0000000000..91be35c6c5 --- /dev/null +++ b/examples/online_serving/ming_flash_omni_tts/run_server.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Launch vLLM-Omni server for Ming-flash-omni-2.0 standalone talker (TTS). +# +# Usage: +# ./run_server.sh +# MODEL=/path/to/local/model ./run_server.sh +# PORT=8091 ./run_server.sh +# HOST=127.0.0.1 ./run_server.sh # bind only to loopback + +set -e + +MODEL="${MODEL:-Jonathan1909/Ming-flash-omni-2.0}" +HOST="${HOST:-0.0.0.0}" +PORT="${PORT:-8091}" +STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml}" + +echo "Starting Ming standalone TTS server with model: $MODEL" +echo "Stage config: $STAGE_CONFIG" + +vllm serve "$MODEL" \ + --stage-configs-path "$STAGE_CONFIG" \ + --host "$HOST" \ + --port "$PORT" \ + --trust-remote-code \ + --omni diff --git a/examples/online_serving/ming_flash_omni_tts/speech_client.py b/examples/online_serving/ming_flash_omni_tts/speech_client.py new file mode 100644 index 0000000000..b3fceba25e --- /dev/null +++ b/examples/online_serving/ming_flash_omni_tts/speech_client.py @@ -0,0 +1,93 @@ +"""Client for Ming standalone TTS via /v1/audio/speech endpoint.""" + +import argparse +import json +import sys + +import httpx + +DEFAULT_API_BASE = "http://localhost:8091" +DEFAULT_API_KEY = "EMPTY" +DEFAULT_MODEL = "Jonathan1909/Ming-flash-omni-2.0" + + +def run_tts(args) -> None: + payload = { + "model": args.model, + "input": args.text, + "response_format": args.response_format, + } + + instructions = args.instructions + if args.instruction_json: + if instructions: + sys.exit("--instructions and --instruction-json are mutually exclusive") + + try: + parsed = json.loads(args.instruction_json) + except json.JSONDecodeError as exc: + sys.exit(f"--instruction-json must be valid JSON: {exc}") + if not isinstance(parsed, dict): + sys.exit("--instruction-json must decode to a JSON object") + # Re-encode with ensure_ascii=False so UTF-8 Chinese keys/values + # arrive at the server intact rather than as \\uXXXX escapes. 
+ instructions = json.dumps(parsed, ensure_ascii=False) + if instructions: + payload["instructions"] = instructions + + print(f"Model: {args.model}") + print(f"Text: {args.text}") + print("Generating audio...") + + api_url = f"{args.api_base}/v1/audio/speech" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {args.api_key}", + } + + with httpx.Client(timeout=300.0) as client: + response = client.post(api_url, json=payload, headers=headers) + + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + return + + output_path = args.output or "ming_tts_output.wav" + with open(output_path, "wb") as f: + f.write(response.content) + print(f"Audio saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Ming standalone TTS speech client") + parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or local path") + parser.add_argument("--text", required=True, help="Text to synthesize") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "mp3", "flac", "pcm", "aac", "opus"], + help="Audio format (default: wav)", + ) + parser.add_argument("--output", "-o", default=None, help="Output file path") + parser.add_argument( + "--instructions", + default=None, + help="Free-form style description (mapped to caption 风格 on the server).", + ) + parser.add_argument( + "--instruction-json", + default=None, + help=( + "Structured caption JSON forwarded as `instructions`. Accepts Ming " + "caption keys: 方言, 风格, 语速, 基频, 音量, 情感, IP, 说话人, BGM. " + ), + ) + args = parser.parse_args() + run_tts(args) + + +if __name__ == "__main__": + main() diff --git a/recipes/README.md b/recipes/README.md index 01ecc41f18..69ce4d7504 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -29,6 +29,8 @@ recipes/ multimodal chat on `1x A100 80GB` - [`Wan-AI/Wan2.2-I2V.md`](./Wan-AI/Wan2.2-I2V.md): image-to-video serving recipe for Wan2.2 14B on `8x Ascend NPU (A2/A3)` +- [`inclusionAI/Ming-flash-omni-2.0.md`](./inclusionAI/Ming-flash-omni-2.0.md): + online serving recipe for multimodal chat (`4x H100 80GB`) and standalone TTS (`1x H100 80GB`) Within a single recipe file, include different hardware support sections such as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like diff --git a/recipes/inclusionAI/Ming-flash-omni-2.0.md b/recipes/inclusionAI/Ming-flash-omni-2.0.md new file mode 100644 index 0000000000..873158c8ad --- /dev/null +++ b/recipes/inclusionAI/Ming-flash-omni-2.0.md @@ -0,0 +1,210 @@ +# Ming-flash-omni 2.0 for omni-speech chat and standalone TTS + +## Summary + +- Vendor: inclusionAI +- Model: `Jonathan1909/Ming-flash-omni-2.0` +- Task: Multimodal chat with text, image, audio, or video input; standalone text-to-speech (TTS); +and image generation +- Mode: Online serving with the OpenAI-compatible API +- Maintainer: Community + +## When to use this recipe + +Use this recipe when you want a known-good starting point for serving +`Jonathan1909/Ming-flash-omni-2.0` with vLLM-Omni in one of three modes: + +- **Thinker only** — multimodal understanding with text output. +- **Thinker + Talker (omni-speech)** — multimodal understanding with text and spoken output. +- **Talker only (TTS)** — standalone text-to-speech via the OpenAI `/v1/audio/speech` endpoint. 
+ +## References + +- Upstream model: + [`inclusionAI/Ming`](https://github.com/inclusionAI/Ming) +- For offline inference and additional client variants, see + `examples/offline_inference/ming_flash_omni{,_tts}/` and + `examples/online_serving/ming_flash_omni{,_tts}/`. + + +## Hardware Support + +This recipe documents reference GPU configurations for the two-stage +omni-speech deployment and the standalone TTS deployment. +Other hardware and configurations are welcome as community validation lands. + +## GPU + +### 4x H100 80GB — omni-speech/chat (thinker + talker) + +The bundled `ming_flash_omni.yaml` runs the thinker with tensor parallel size +4 on GPUs 0–3 and the talker on GPU 3. +Adjust `devices` in the YAML to match your hardware. + +#### Environment + +- OS: Linux +- Python: 3.10+ +- CUDA Driver Version: 590.48.01 +- CUDA 12.5 +- vLLM version: 0.19.0 +- vLLM-Omni version or commit: 0.19.0rc1 + +#### Command + +Thinker only (text output): + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 +``` + +Thinker + talker (text and/or audio output): + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 \ + --omni \ + --port 8091 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \ + --log-stats +``` + +`--log-stats` is optional but recommended while validating the deployment. + +#### Verification + +Text output from a multimodal (image) input: + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}}, + {"type": "text", "text": "Describe this image in detail."} + ]} + ], + "modalities": ["text"] + }' +``` + +Spoken response from a text query (save the WAV bytes): + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["audio"] + }' | jq -r '.choices[0].message.audio.data' | base64 -d > ming_omni_parrot.wav +``` + +Text + audio output from an audio input (swap `audio_url` for `video_url` +or `image_url` to exercise the other multimodal input paths): + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": [ + {"type": "audio_url", "audio_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg"}}, + {"type": "text", "text": "Please recognize the language of this speech and transcribe it. 
Format: oral."} + ]} + ], + "modalities": ["text", "audio"] + }' | jq -r '.choices[0].message.content' +``` + +Streaming text output via SSE (set `"stream": true`): + +```bash +curl -N http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["text"], + "stream": true + }' +``` + +Each SSE event carries a `data:` line with a chat-completion chunk; text +deltas appear at `choices[0].delta.content`. + +#### Notes + +- Output modality is selected by the request body: `"modalities": ["text"]`, + `["audio"]`, or `["text", "audio"]`. The two-stage omni-speech server must be launched + for any request containing `audio`. +- Reasoning mode: flip the system prompt suffix from `detailed thinking off` + to `detailed thinking on` in any request above. +- Memory usage: size depends on output modalities and multimodal input; leave + headroom for video frames and audio caches. + +### 1x H100 80GB — standalone TTS (talker only) + +The bundled `ming_flash_omni_tts.yaml` runs the talker on a single GPU and exposes the OpenAI `/v1/audio/speech` endpoint. + +#### Environment + +- OS: Linux +- Python: 3.10+ +- CUDA Driver Version: 590.48.01 +- CUDA 12.5 +- vLLM version: 0.19.0 +- vLLM-Omni version or commit: 0.19.0rc1 + +#### Command + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 \ + --omni \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml \ + --port 8091 \ + --log-stats +``` + +`--log-stats` is optional but recommended while validating the deployment. + +#### Verification + +Basic curl: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "input": "我会一直在这里陪着你。", + "response_format": "wav" + }' --output ming_online.wav +``` + +Speaker selection (e.g. `lingguang`): + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "input": "春天来了,万物复苏,大地一片生机盎然。田野里的油菜花开得金灿灿的,蜜蜂在花丛中忙碌地采蜜。远处的山坡上,桃花和杏花竞相绽放,粉的白的交织在一起,美不胜收。清晨的微风带着泥土的芬芳,轻轻拂过脸颊,让人感到无比惬意。孩子们在田间小路上追逐嬉戏,老人们坐在门前晒太阳,享受着这份宁静与美好。", + "speaker": "lingguang", + "response_format": "wav" + }' --output ming_online_lingguang.wav +``` + +#### Notes + +- The OpenAI `instructions` field is forwarded to the talker as the caption JSON — pass a raw string for `风格` (style) only, or a JSON-encoded object for multiple entries such as `方言` (dialect) and `情感` (emotion). 
diff --git a/tests/e2e/offline_inference/test_ming_flash_omni.py b/tests/e2e/offline_inference/test_ming_flash_omni.py index c591e910ac..ca0b0fe0d9 100644 --- a/tests/e2e/offline_inference/test_ming_flash_omni.py +++ b/tests/e2e/offline_inference/test_ming_flash_omni.py @@ -36,6 +36,21 @@ def build_prompt(user_text: str) -> str: def get_eager_config(): + path = modify_stage_config( + str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_thinker_only_ci.yaml"), + updates={ + "stage_args": { + 0: { + "engine_args.enforce_eager": "true", + }, + }, + }, + ) + return path + + +def get_eager_tts_config(): + """Thinker+talker CI config with enforce_eager on the thinker stage.""" path = modify_stage_config( str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"), updates={ @@ -49,9 +64,14 @@ def get_eager_config(): return path +# Thinker-only config — used by text-output tests. stage_configs = [get_eager_config()] test_params = [(model, stage_config) for model in models for stage_config in stage_configs] +# Thinker+talker config — used by audio-output tests. +stage_configs_tts = [get_eager_tts_config()] +test_params_tts = [(model, stage_config) for model in models for stage_config in stage_configs_tts] + @pytest.mark.core_model @pytest.mark.omni @@ -140,3 +160,36 @@ def test_mixed_to_text(omni_runner, omni_runner_handler) -> None: request_config = {"prompts": prompt, "images": image, "audios": audio, "modalities": ["text"]} omni_runner_handler.send_request(request_config) + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params_tts, indirect=True) +def test_text_to_audio(omni_runner, omni_runner_handler) -> None: + """ + Test text input with audio output via the thinker+talker pipeline. + Input Modal: text + Output Modal: audio + """ + prompt = build_prompt("请简单介绍一下北京。") + request_config = {"prompts": prompt, "modalities": ["audio"]} + + omni_runner_handler.send_request(request_config) + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_runner", test_params_tts, indirect=True) +def test_image_to_audio(omni_runner, omni_runner_handler) -> None: + """ + Test image + text input with audio output via the thinker+talker pipeline. + Input Modal: image + text + Output Modal: audio + """ + image = generate_synthetic_image(224, 224)["np_array"] + prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.") + request_config = {"prompts": prompt, "images": image, "modalities": ["audio"]} + + omni_runner_handler.send_request(request_config) diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml index fb0c72cc51..53cc73ce09 100644 --- a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml +++ b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml @@ -1,4 +1,8 @@ -# Thinker stage only +# CI stage config for Ming-flash-omni-2.0 thinker+talker pipeline. 
+# Stage 0: Thinker (multimodal understanding -> text generation) +# Stage 1: Talker (text -> audio waveform via CFM + AudioVAE) +# +# The following config has been verified on 4x H100-80G GPUs stage_args: - stage_id: 0 stage_type: llm @@ -10,10 +14,12 @@ stage_args: model_arch: MingFlashOmniForConditionalGeneration worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 + gpu_memory_utilization: 0.74 enforce_eager: false trust_remote_code: true - engine_output_type: latent + # Ming Thinker -> talker bridge reads detokenised text from + # source_output.outputs[0].text (not hidden states). + engine_output_type: text distributed_executor_backend: "mp" enable_prefix_caching: false max_num_batched_tokens: 32768 @@ -22,14 +28,48 @@ stage_args: hf_config_name: llm_config load_format: dummy mm_processor_cache_gb: 0 + compilation_config: + pass_config: + # Disable fused all-reduce to avoid a vllm/flashinfer version mismatch. + fuse_allreduce_rms: false final_output: true final_output_type: text is_comprehension: true default_sampling_params: - temperature: 0.4 + temperature: 0.0 top_p: 0.9 max_tokens: 100 repetition_penalty: 1.05 seed: 42 detokenize: true ignore_eos: false + + - stage_id: 1 + stage_type: llm + runtime: + devices: "3" + max_batch_size: 1 + engine_args: + model_stage: ming_tts + model_arch: MingFlashOmniTalkerForConditionalGeneration + worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + gpu_memory_utilization: 0.18 + enforce_eager: true + trust_remote_code: true + engine_output_type: audio + enable_prefix_caching: false + max_num_batched_tokens: 1000000 + tokenizer_subdir: talker/llm + # The HF repo ships BailingMM2Config (thinker-only) at root; + # OmniModelConfig treats that as "stage does not share outer mrope". 
+ hf_config_name: talker_config + load_format: dummy + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.ming_flash_omni.thinker2talker + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + max_tokens: 1 + seed: 42 diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml new file mode 100644 index 0000000000..fb0c72cc51 --- /dev/null +++ b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml @@ -0,0 +1,35 @@ +# Thinker stage only +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0,1,2,3" + max_batch_size: 1 + engine_args: + model_stage: thinker + model_arch: MingFlashOmniForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: false + trust_remote_code: true + engine_output_type: latent + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + max_model_len: 32768 + tensor_parallel_size: 4 + hf_config_name: llm_config + load_format: dummy + mm_processor_cache_gb: 0 + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + max_tokens: 100 + repetition_penalty: 1.05 + seed: 42 + detokenize: true + ignore_eos: false diff --git a/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py b/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py new file mode 100644 index 0000000000..419ce00dae --- /dev/null +++ b/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from vllm_omni.model_executor.models.ming_flash_omni.talker_module import ( + CFM, + Aggregator, + CFMGraphExecutor, + CFMGraphExecutorPool, + DiT, +) + +torch = pytest.importorskip("torch") +pytest.importorskip("x_transformers") + +pytestmark = [ + pytest.mark.core_model, + pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA for graph capture"), +] + +_LATENT_DIM = 8 +_PATCH_SIZE = 4 +_HIS_PATCH_SIZE = 8 +_LLM_HIDDEN = 16 +_DIT_HIDDEN = 32 +_AGG_HIDDEN = 32 +_NUM_HEADS = 4 +_DEPTH = 2 +_STEPS = 5 +_DTYPE = torch.float32 + + +def _warmup_pipeline(cfm: CFM, aggregator: Aggregator, stop_head: torch.nn.Linear, device: torch.device) -> None: + llm_cond = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE) + lat_cond = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + y0 = torch.randn(1, _PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + t = torch.linspace(0.0, 1.0, _STEPS + 1, device=device, dtype=_DTYPE) + sde_args = torch.tensor([2.0, 0.25, 0.0], device=device, dtype=_DTYPE) + sde_rnd = torch.randn(_STEPS, 1, _PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + + with torch.no_grad(): + gen_lat = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd) + aggregator(gen_lat) + stop_head(llm_cond[:, -1, :]).softmax(dim=-1) + torch.cuda.synchronize(device) + + +def _build_pipeline(): + device = torch.device("cuda") + dit = ( + DiT( + in_channels=_LATENT_DIM, + hidden_size=_DIT_HIDDEN, + depth=_DEPTH, + num_heads=_NUM_HEADS, + mlp_ratio=2.0, + llm_cond_dim=_LLM_HIDDEN, + ) + .to(device=device, 
dtype=_DTYPE) + .eval() + ) + cfm = CFM(dit, steps=_STEPS, sway_sampling_coef=-1.0).to(device=device, dtype=_DTYPE).eval() + aggregator = ( + Aggregator( + in_channels=_LATENT_DIM, + hidden_size=_AGG_HIDDEN, + depth=_DEPTH, + num_heads=_NUM_HEADS, + mlp_ratio=2.0, + llm_input_dim=_LLM_HIDDEN, + ) + .to(device=device, dtype=_DTYPE) + .eval() + ) + stop_head = torch.nn.Linear(_LLM_HIDDEN, 2).to(device=device, dtype=_DTYPE).eval() + + config = SimpleNamespace(steps=_STEPS, patch_size=_PATCH_SIZE) + _warmup_pipeline(cfm, aggregator, stop_head, device) + return config, cfm, aggregator, stop_head, device + + +class TestCFMGraphExecutor: + """Capture once, replay twice: outputs must stay consistently-shaped.""" + + def test_execute_shapes_and_replay(self) -> None: + config, cfm, aggregator, stop_head, device = _build_pipeline() + executor = CFMGraphExecutor(config, cfm, aggregator, stop_head) + + bsz = 1 + input_tensor = torch.randn(bsz, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE) + his_lat = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + + gen_lat, inputs_embeds, stop_out = executor.execute(input_tensor, his_lat) + torch.cuda.synchronize() + + assert gen_lat.shape == (bsz, _PATCH_SIZE, _LATENT_DIM) + assert inputs_embeds.shape == (bsz, 1, _LLM_HIDDEN) + assert stop_out.shape == (bsz, 2) + assert torch.isfinite(gen_lat).all() + assert torch.isfinite(inputs_embeds).all() + # stop_head output is softmax-normalized across the last dim. + assert torch.allclose(stop_out.sum(dim=-1), torch.ones(bsz, device=device, dtype=_DTYPE), atol=1e-4) + + # Replay the captured graph with fresh inputs — shapes must match. + new_input = torch.randn_like(input_tensor) + new_his = torch.randn_like(his_lat) + gen_lat2, inputs_embeds2, stop_out2 = executor.execute(new_input, new_his) + torch.cuda.synchronize() + assert gen_lat2.shape == gen_lat.shape + assert inputs_embeds2.shape == inputs_embeds.shape + assert stop_out2.shape == stop_out.shape + assert executor.initialized is True + + def test_execute_is_noninplace_on_inputs(self) -> None: + config, cfm, aggregator, stop_head, device = _build_pipeline() + executor = CFMGraphExecutor(config, cfm, aggregator, stop_head) + + input_tensor = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE) + his_lat = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + snapshot_input = input_tensor.clone() + snapshot_his = his_lat.clone() + + executor.execute(input_tensor, his_lat) + torch.cuda.synchronize() + assert torch.equal(input_tensor, snapshot_input) + assert torch.equal(his_lat, snapshot_his) + + +class TestCFMGraphExecutorPool: + def test_pool_acquires_and_releases(self) -> None: + config, cfm, aggregator, stop_head, device = _build_pipeline() + pool = CFMGraphExecutorPool(config, cfm, aggregator, stop_head, pool_size=2) + + input_tensor = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE) + his_lat = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE) + + gen_lat, inputs_embeds, stop_out = pool.execute(input_tensor, his_lat) + torch.cuda.synchronize() + assert gen_lat.shape == (1, _PATCH_SIZE, _LATENT_DIM) + assert inputs_embeds.shape == (1, 1, _LLM_HIDDEN) + assert stop_out.shape == (1, 2) + assert pool.pool.qsize() == 2 diff --git a/tests/model_executor/models/ming_flash_omni/test_talker_modules.py b/tests/model_executor/models/ming_flash_omni/test_talker_modules.py new file mode 100644 index 0000000000..4cbbc887a5 --- /dev/null +++ 
b/tests/model_executor/models/ming_flash_omni/test_talker_modules.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import pytest + +from vllm_omni.model_executor.models.ming_flash_omni.talker_module import CFM, Aggregator, DiT + +torch = pytest.importorskip("torch") +pytest.importorskip("x_transformers") + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +_LATENT_DIM = 8 +_PATCH_SIZE = 4 +_HIS_PATCH_SIZE = 8 +_LLM_HIDDEN = 16 +_DIT_HIDDEN = 32 +_AGG_HIDDEN = 32 +_NUM_HEADS = 4 +_DEPTH = 2 +_STEPS = 5 + + +def _make_dit() -> DiT: + return DiT( + in_channels=_LATENT_DIM, + hidden_size=_DIT_HIDDEN, + depth=_DEPTH, + num_heads=_NUM_HEADS, + mlp_ratio=2.0, + llm_cond_dim=_LLM_HIDDEN, + ) + + +def _make_aggregator() -> Aggregator: + return Aggregator( + in_channels=_LATENT_DIM, + hidden_size=_AGG_HIDDEN, + depth=_DEPTH, + num_heads=_NUM_HEADS, + mlp_ratio=2.0, + llm_input_dim=_LLM_HIDDEN, + ) + + +class TestDiTDummyForward: + """DiT with dummy weights runs forward + CFG-doubled forward.""" + + def test_forward_shape(self) -> None: + dit = _make_dit().eval() + bsz = 2 + x = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM) + t = torch.zeros(bsz) + c = torch.randn(bsz, 1, _LLM_HIDDEN) + latent_history = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + out = dit(x, t, c, latent_history) + + # Output preserves the concatenated (history + time/cond prefix + x) + # token axis: history + 1 (time+cond) + patch. + assert out.shape == (bsz, _HIS_PATCH_SIZE + 1 + _PATCH_SIZE, _LATENT_DIM) + + def test_forward_with_cfg_trims_to_patch(self) -> None: + dit = _make_dit().eval() + bsz = 1 + x = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM) + t = torch.zeros(()) + c = torch.randn(bsz, 1, _LLM_HIDDEN) + latent_history = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + out = dit.forward_with_cfg(x, t, c, latent_history) + + # CFG doubles the batch and trims the output to the patch window. + assert out.shape == (2 * bsz, _PATCH_SIZE, _LATENT_DIM) + + +class TestAggregatorDummyForward: + """Aggregator with dummy weights maps latent patch -> LLM hidden.""" + + def test_forward_shape(self) -> None: + agg = _make_aggregator().eval() + bsz = 3 + gen_lat = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + out = agg(gen_lat) + + assert out.shape == (bsz, 1, _LLM_HIDDEN) + + def test_forward_is_finite(self) -> None: + agg = _make_aggregator().eval() + gen_lat = torch.randn(1, _PATCH_SIZE, _LATENT_DIM) + with torch.no_grad(): + out = agg(gen_lat) + assert torch.isfinite(out).all() + + +class TestCFMSampleDummy: + """CFM.sample drives DiT.forward_with_cfg through the integration loop.""" + + def test_sample_shape_and_finite(self) -> None: + cfm = CFM(_make_dit(), steps=_STEPS, sway_sampling_coef=-1.0).eval() + bsz = 1 + llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN) + lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM) + y0 = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM) + # Grid used by the talker; must span [0, 1] inclusive. 
+ t = torch.linspace(0.0, 1.0, _STEPS + 1) + sde_args = torch.tensor([2.0, 0.0, 0.0]) # cfg=2.0, sigma=0, temp=0 + sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + out = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd) + + assert out.shape == y0.shape + assert torch.isfinite(out).all() + + def test_sample_zero_cfg_reduces_to_unguided(self) -> None: + """With cfg=0 the guidance term drops, but output shape is still valid.""" + cfm = CFM(_make_dit(), steps=_STEPS, sway_sampling_coef=None).eval() + bsz = 2 + llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN) + lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM) + y0 = torch.zeros(bsz, _PATCH_SIZE, _LATENT_DIM) + t = torch.linspace(0.0, 1.0, _STEPS + 1) + sde_args = torch.tensor([0.0, 0.0, 0.0]) + sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + out = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd) + + assert out.shape == (bsz, _PATCH_SIZE, _LATENT_DIM) + assert torch.isfinite(out).all() + + +class TestTalkerPipelineDummyWiring: + """End-to-end wiring of DiT -> CFM.sample -> Aggregator with dummy weights.""" + + def test_cfm_then_aggregator(self) -> None: + dit = _make_dit().eval() + cfm = CFM(dit, steps=_STEPS, sway_sampling_coef=-1.0).eval() + agg = _make_aggregator().eval() + + bsz = 1 + llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN) + lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM) + y0 = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM) + t = torch.linspace(0.0, 1.0, _STEPS + 1) + sde_args = torch.tensor([2.0, 0.0, 0.0]) + sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM) + + with torch.no_grad(): + gen_lat = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd) + agg_out = agg(gen_lat) + + assert gen_lat.shape == (bsz, _PATCH_SIZE, _LATENT_DIM) + assert agg_out.shape == (bsz, 1, _LLM_HIDDEN) + assert torch.isfinite(agg_out).all() diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 96a34a8d79..82d65e5722 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -5,7 +5,10 @@ from vllm.config import ModelConfig from vllm.config.utils import config from vllm.logger import init_logger -from vllm.transformers_utils.config import get_hf_text_config +from vllm.transformers_utils.config import ( + get_hf_text_config, + thinker_uses_mrope, +) from vllm.transformers_utils.model_arch_config_convertor import ( ModelArchConfigConvertorBase, ) @@ -125,6 +128,18 @@ def architectures(self) -> list[str]: return [self.model_arch] return super().architectures + @property + def uses_mrope(self) -> bool: + if self.hf_config_name is not None: + # talker_config/thinker_config/etc + stage_config = getattr(self.hf_config, self.hf_config_name, None) + if stage_config is None: + # Check the named sub-config's text_config directly. 
+ # Handles mrope resolution of stage-specific cls + # (e.g., talker runs as a standalone cls) + return thinker_uses_mrope(self.hf_config) + return super().uses_mrope + @property def embedding_size(self): if self.hf_config_name is not None: diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 06b739e3be..dc64ecd5c5 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -559,6 +559,16 @@ async def _preprocess_chat( engine_prompt["additional_information"] = {} engine_prompt["additional_information"]["language"] = [language.strip()] + # Style instruction — used by Ming-flash-omni instruct TTS path + # (ming_task="instruct"). For the omni speech path the thinker2talker + # bridge drops this field to match upstream omni_audio_generation + # which hardcodes instruction=None. + instructions = getattr(request, "instructions", None) + if instructions is not None and isinstance(instructions, str) and instructions.strip(): + if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None: + engine_prompt["additional_information"] = {} + engine_prompt["additional_information"]["instruction"] = instructions.strip() + return conversation, [engine_prompt] async def _inject_audio_from_video_urls( @@ -1930,7 +1940,8 @@ def _create_audio_choice( final_res = omni_outputs.request_output # OMNI: Access multimodal_output from CompletionOutput (outputs[0]), not from RequestOutput # Reference: examples/offline_inference/qwen3_omni/end2end.py line 421 - audio_data = final_res.outputs[0].multimodal_output.get("audio") + mm_output = final_res.outputs[0].multimodal_output + audio_data = mm_output.get("audio") if isinstance(audio_data, list): if stream: audio_tensor = audio_data[-1] @@ -1944,9 +1955,20 @@ def _create_audio_choice( if audio_tensor.ndim > 1: audio_tensor = audio_tensor.flatten() + # Prefer the talker-reported sample rate when present. Qwen3-Omni + # omits "sr" and runs at 24kHz; Ming-flash-omni surfaces a 44.1kHz + # AudioVAE rate via multimodal_output["sr"]. 
+ sr_raw = mm_output.get("sr") + if sr_raw is None: + sample_rate = 24000 + elif hasattr(sr_raw, "item"): + sample_rate = int(sr_raw.item()) + else: + sample_rate = int(sr_raw) + audio_obj = CreateAudio( audio_tensor=audio_tensor, - sample_rate=24000, + sample_rate=sample_rate, response_format="wav", speed=1.0, stream_format="audio", diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index c275c77959..5820af3192 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -44,6 +44,12 @@ estimate_fish_voice_clone_prompt_len_from_normalized, normalize_fish_voice_clone_texts, ) +from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import ( + DEFAULT_PROMPT as MING_DEFAULT_PROMPT, +) +from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import ( + create_instruction as ming_create_instruction, +) from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -56,6 +62,7 @@ _OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"} _VOXCPM_TTS_MODEL_STAGES = {"latent_generator", "vae"} _VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"} +_MING_TTS_MODEL_STAGES = {"ming_tts"} _TTS_MODEL_STAGES: set[str] = ( _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES @@ -64,6 +71,7 @@ | _OMNIVOICE_TTS_MODEL_STAGES | _VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES + | _MING_TTS_MODEL_STAGES ) _TTS_LANGUAGES: set[str] = { "Auto", @@ -312,6 +320,8 @@ def _detect_tts_model_type(self) -> str | None: for stage in self.engine_client.stage_configs ) return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2" + if model_stage in _MING_TTS_MODEL_STAGES: + return "ming_flash_omni_tts" return None def _compute_max_instructions_length(self) -> int: @@ -335,6 +345,11 @@ def _compute_max_instructions_length(self) -> int: def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" + if self._tts_model_type == "ming_flash_omni_tts": + # Ming-flash-omni drives speaker selection via the caption JSON + # (audio_sequence[0]["说话人"]) rather than a spk_id table, so there + # is no static speaker list to surface here. 
+ return set() try: if self._tts_model_type == "voxcpm": return set() @@ -836,6 +851,8 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return self._validate_voxcpm_request(request) if self._tts_model_type == "voxcpm2": return None # VoxCPM2 accepts any text input + if self._tts_model_type == "ming_flash_omni_tts": + return self._validate_ming_tts_request(request) return self._validate_qwen_tts_request(request) def _voxcpm2_encode(self, text: str) -> list[int]: @@ -857,6 +874,41 @@ def _voxcpm2_encode(self, text: str) -> list[int]: ids = self._voxcpm2_tokenizer.encode(text, add_special_tokens=True) return split_multichar_chinese(ids, self._voxcpm2_split_map) + def _validate_ming_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate Ming-flash-omni standalone-talker request parameters.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + if request.instructions is not None: + if not isinstance(request.instructions, str): + return "instructions must be a string" + if len(request.instructions) > self._max_instructions_length: + return f"instructions exceeds max length {self._max_instructions_length}" + + if request.task_type is not None: + return "'task_type' is not supported for Ming-flash-omni TTS" + if request.language is not None: + return "'language' is not supported for Ming-flash-omni TTS (language is inferred from input text)" + if request.x_vector_only_mode is not None: + return "'x_vector_only_mode' is not supported for Ming-flash-omni TTS" + if request.initial_codec_chunk_frames is not None: + return "'initial_codec_chunk_frames' is not supported for Ming-flash-omni TTS" + + # Per-request voice cloning from raw audio is not yet wired up: Ming + # extracts spk_emb / prompt_wav_lat / prompt_wav_emb model-side via + # register_prompt_wav() at engine init. For ad-hoc cloning, callers + # should pre-compute speaker_embedding and pass it directly. + if request.ref_audio is not None: + return ( + "'ref_audio' is not yet supported for Ming-flash-omni TTS; " + "use a preset 'voice' or 'speaker_embedding' instead" + ) + if request.ref_text is not None: + return "'ref_text' is not yet supported for Ming-flash-omni TTS" + + if request.max_new_tokens is not None and request.max_new_tokens <= 0: + return "'max_new_tokens' must be a positive integer" + return None + def _validate_ref_audio_format(self, ref_audio: str) -> str | None: """Validate ref_audio is a supported URI format. Returns error or None.""" if not ( @@ -1519,6 +1571,52 @@ async def _build_cosyvoice3_prompt( }, } + # ---- Ming-flash-omni standalone-talker (TTS) helpers ---- + + def _build_ming_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + # request.instructions accepts two forms: + # 1. Plain text: mapped to the caption's 风格 (style) field + # 2. JSON object: parsed and splatted into the caption. Unlocks + # Unknown keys are dropped by `ming_create_instruction`. + caption_fields: dict[str, Any] = {} + if request.instructions: + stripped = request.instructions.strip() + if stripped.startswith("{"): + try: + parsed = json.loads(stripped) + except json.JSONDecodeError: + parsed = None + if isinstance(parsed, dict): + caption_fields.update(parsed) + else: + caption_fields["风格"] = request.instructions + else: + caption_fields["风格"] = request.instructions + + has_spk_emb = request.speaker_embedding is not None + + # TTS path applies ming task type `instruct`. 
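`ming_create_instruction` only merges keys that already exist on Ming's base caption template (`audio_sequence[0]`); everything else is dropped, so both `instructions` forms accepted here reduce to a caption-field dict. A minimal sketch of the two forms, assuming the package is importable:

```python
import json

from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import (
    create_instruction,
)

# Plain-text instructions are mapped to the caption's 风格 (style) field.
style_only = create_instruction({"风格": "温柔女声"})

# A JSON instructions payload is merged key-by-key; keys outside the base
# template (e.g. "foo") are silently ignored.
merged = create_instruction({"风格": "新闻播报", "语速": "较快", "foo": "ignored"})

assert json.loads(merged)["audio_sequence"][0]["语速"] == "较快"
assert "foo" not in json.loads(merged)["audio_sequence"][0]
```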
+ # voice_name enables talker-side voice preset resolution (e.g. "DB30"). + additional_information: dict[str, Any] = { + "ming_task": "instruct", + "prompt": MING_DEFAULT_PROMPT, + "text": request.input, + "instruction": ming_create_instruction(caption_fields), + "voice_name": request.voice or None, + "use_zero_spk_emb": not has_spk_emb, + "max_decode_steps": request.max_new_tokens or _TTS_MAX_NEW_TOKENS_MAX, + "cfg": 2.0, + "sigma": 0.25, + "temperature": 0.0, + } + if has_spk_emb: + # Passed as plain float list + additional_information["spk_emb"] = list(request.speaker_embedding) + return { + "prompt_token_ids": [0], + "additional_information": additional_information, + } + # ---- Common speech generation helpers ---- async def _prepare_speech_generation( @@ -1581,6 +1679,9 @@ async def _prepare_speech_generation( elif self._tts_model_type == "cosyvoice3": prompt = await self._build_cosyvoice3_prompt(request) tts_params = {} + elif self._tts_model_type == "ming_flash_omni_tts": + prompt = self._build_ming_prompt(request) + tts_params = {} else: tts_params = self._build_tts_params(request) # Resolve ref_audio (explicit or auto-set for uploaded voices) @@ -1628,6 +1729,8 @@ async def _prepare_speech_generation( model_type = "voxcpm" elif self._tts_model_type == "voxcpm2": model_type = "voxcpm2" + elif self._tts_model_type == "ming_flash_omni_tts": + model_type = "ming_flash_omni_tts" elif self._is_tts: model_type = tts_params.get("task_type", ["unknown"])[0] else: diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py index d7fa44fd7e..4cd086e642 100644 --- a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py +++ b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py @@ -2,6 +2,7 @@ # Copyright 2025 The vLLM-Omni team. from .ming_flash_omni import MingFlashOmniForConditionalGeneration +from .ming_flash_omni_talker import MingFlashOmniTalkerForConditionalGeneration from .ming_flash_omni_thinker import ( MingFlashOmniThinkerDummyInputsBuilder, MingFlashOmniThinkerForConditionalGeneration, @@ -11,6 +12,7 @@ __all__ = [ "MingFlashOmniForConditionalGeneration", + "MingFlashOmniTalkerForConditionalGeneration", "MingFlashOmniThinkerForConditionalGeneration", "MingFlashOmniThinkerProcessingInfo", "MingFlashOmniThinkerMultiModalProcessor", diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py new file mode 100644 index 0000000000..9d5c266b4f --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright (c) Ant Group. All rights reserved. +# Adapted from: +# https://github.com/inclusionAI/Ming/tree/e58533db227031990c5a6864dcf5f08fb53ed0d2/AudioVAE + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig, PreTrainedModel, Qwen2Config, Qwen2Model +from vllm.logger import init_logger + +logger = init_logger(__name__) +try: + import flash_attn # noqa: F401 +except (ImportError, ModuleNotFoundError): + flash_attn = None + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part." 
+ ) + + +class AudioVAEConfig(PretrainedConfig): + model_type = "audio_vae" + + def __init__( + self, + sample_rate: int = 44100, + enc_kwargs: dict | None = None, + dec_kwargs: dict | None = None, + init_method: str = "kaiming", + patch_size: int = 4, + **kwargs, + ): + self.sample_rate = sample_rate + self.enc_kwargs = enc_kwargs or {} + self.dec_kwargs = dec_kwargs or {} + self.init_method = init_method + self.patch_size = patch_size + super().__init__(**kwargs) + + +class ISTFT(nn.Module): + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + self.buffer_len = self.win_length - self.hop_length + + def _buffer_process(self, x, buffer, pad, last_chunk=False, streaming=False): + if streaming: + if buffer is None: + x = x[:, pad:] + if buffer is not None: + x[:, : self.buffer_len] += buffer + buffer = x[:, -self.buffer_len :] + if not last_chunk: + x = x[:, : -self.buffer_len] + else: + x = x[:, :-pad] + else: + x = x[:, pad:-pad] + return x, buffer + + def forward(self, spec, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False): + if self.padding == "center": + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + B, N, T = spec.shape + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, :] + + y, audio_buffer = self._buffer_process(y, audio_buffer, pad, last_chunk=last_chunk, streaming=streaming) + + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = ( + torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ) + .squeeze(0) + .squeeze(0) + ) + + window_envelope, window_buffer = self._buffer_process( + window_envelope, window_buffer, pad, last_chunk=last_chunk, streaming=streaming + ) + window_envelope = window_envelope.squeeze() + + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y, audio_buffer, window_buffer + + +class ISTFTHead(nn.Module): + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False): + x_pred = self.out(x) + x_pred = x_pred.transpose(1, 2) + mag, p = x_pred.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) + x = torch.cos(p) + y = torch.sin(p) + S = mag * (x + 1j * y) + audio, audio_buffer, window_buffer = self.istft( + S, audio_buffer=audio_buffer, window_buffer=window_buffer, streaming=streaming, last_chunk=last_chunk + ) + return audio.unsqueeze(1), x_pred, audio_buffer, window_buffer + + 
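With `padding="same"`, the fold-based overlap-add above yields `(T - 1) * hop_length + win_length` samples, and trimming `pad = (win_length - hop_length) // 2` from each end leaves exactly `T * hop_length` samples whenever `win_length - hop_length` is even. A quick arithmetic check using the `ISTFTHead` configuration the decoder builds (`n_fft = 4 * hop_length`, `win_length = n_fft`):

```python
# Frame-to-sample bookkeeping for ISTFT(padding="same"), non-streaming path.
hop_length = 320                 # Decoder default: hop_length = output_dim
win_length = 4 * hop_length      # ISTFTHead(n_fft=4 * hop_length, win_length=n_fft)
pad = (win_length - hop_length) // 2


def istft_samples(num_frames: int) -> int:
    """Samples produced for `num_frames` spectrogram frames."""
    output_size = (num_frames - 1) * hop_length + win_length  # fold() output width
    return output_size - 2 * pad                              # same-padding trim


assert istft_samples(100) == 100 * hop_length  # one hop of audio per frame
```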
+class StreamingLinearUpsample(nn.Module): + def __init__(self, scale_factor=4): + super().__init__() + self.scale_factor = scale_factor + self.upsampler = nn.Upsample(scale_factor=scale_factor, mode="linear", align_corners=False) + + def forward(self, x, state=None, is_last=False): + if state is None: + state = {"prev_chunk": None, "history_last": None, "is_first": True} + + if x is None and not is_last: + return None, state + + if state["is_first"] and is_last: + out = self.upsampler(x.transpose(1, 2)).transpose(1, 2) + return out, None + + output_chunks = [] + + if state["is_first"]: + state["prev_chunk"] = x + state["is_first"] = False + if not is_last: + return None, state + + if state["prev_chunk"] is not None: + p = state["prev_chunk"].transpose(1, 2) + + if state["history_last"] is None: + lookahead = x[:, :1, :].transpose(1, 2) + inp = torch.cat([p, lookahead], dim=2) + up = self.upsampler(inp) + out_prev = up[:, :, : p.size(2) * self.scale_factor] + else: + lookahead = x[:, :1, :].transpose(1, 2) + inp = torch.cat([state["history_last"], p, lookahead], dim=2) + up = self.upsampler(inp) + start = self.scale_factor + end = start + p.size(2) * self.scale_factor + out_prev = up[:, :, start:end] + + output_chunks.append(out_prev.transpose(1, 2)) + state["history_last"] = p[:, :, -1:] + state["prev_chunk"] = x + + if is_last: + p = state["prev_chunk"].transpose(1, 2) + inp = torch.cat([state["history_last"], p], dim=2) + up = self.upsampler(inp) + out_last = up[:, :, self.scale_factor :] + output_chunks.append(out_last.transpose(1, 2)) + state = None + + final_out = torch.cat(output_chunks, dim=1) if output_chunks else None + return final_out, state + + +class Decoder(nn.Module): + def __init__(self, decoder_args, output_dim=320, latent_dim=64, patch_size=-1): + super().__init__() + config = Qwen2Config.from_dict(config_dict=decoder_args) + if flash_attn is None: + config._attn_implementation = "sdpa" + self.decoder = Qwen2Model(config) + self.output_dim = output_dim + self.latent_dim = latent_dim + self.fc1 = nn.Linear(latent_dim, config.hidden_size) + self.hop_length = output_dim + self.head = ISTFTHead( + dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same" + ) + self.patch_size = patch_size + if self.patch_size != -1: + self.upsampling = StreamingLinearUpsample(scale_factor=patch_size) + + def low_level_reconstruct(self, x, past_key_values=None, use_cache=False, stream_state=None, last_chunk=False): + upsample_state, audio_buffer, window_buffer = stream_state + bsz, device, dtype = x.size(0), x.device, x.dtype + x = self.fc1(x) + if self.patch_size != -1: + if use_cache: + x, upsample_state = self.upsampling(x, state=upsample_state, is_last=last_chunk) + if x is None: + stream_state = (upsample_state, audio_buffer, window_buffer) + return torch.empty(bsz, 1, 0, device=device, dtype=dtype), stream_state, past_key_values + else: + x = self.upsampling.upsampler(x.transpose(1, 2)).transpose(1, 2) + + hidden_states_list = [] + + if use_cache and getattr(self.decoder.config, "sliding_window", None) is not None: + sw_size = self.decoder.config.sliding_window + target_len = sw_size - 1 + if past_key_values is None: + past_len = 0 + elif hasattr(past_key_values, "get_seq_length"): + past_len = past_key_values.get_seq_length() + elif isinstance(past_key_values, tuple) and len(past_key_values) > 0: + past_len = past_key_values[0][0].shape[-2] + else: + past_len = 0 + + curr_len = x.shape[1] + + if past_len < target_len and (past_len + curr_len) >= 
sw_size: + fill_len = target_len - past_len + x_fill = x[:, :fill_len, :] + outputs = self.decoder(inputs_embeds=x_fill, past_key_values=past_key_values, use_cache=use_cache) + hidden_states_list.append(outputs.last_hidden_state) + past_key_values = outputs.past_key_values + x = x[:, fill_len:, :] + + outputs = self.decoder(inputs_embeds=x, past_key_values=past_key_values, use_cache=use_cache) + hidden_states_list.append(outputs.last_hidden_state) + past_key_values = outputs.past_key_values + + if len(hidden_states_list) > 1: + full_hidden_state = torch.cat(hidden_states_list, dim=1) + else: + full_hidden_state = hidden_states_list[0] + + x_out, _, audio_buffer, window_buffer = self.head( + full_hidden_state, + streaming=use_cache, + audio_buffer=audio_buffer, + window_buffer=window_buffer, + last_chunk=last_chunk, + ) + + stream_state = (upsample_state, audio_buffer, window_buffer) + return x_out, stream_state, past_key_values + + +class Encoder(nn.Module): + def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1): + super().__init__() + config = Qwen2Config.from_dict(config_dict=encoder_args) + if flash_attn is None: + config._attn_implementation = "sdpa" + self.encoder = Qwen2Model(config) + self.input_dim = input_dim + self.hop_size = hop_size + self.latent_dim = latent_dim + self.fc1 = nn.Linear(input_dim, config.hidden_size, bias=False) + self.fc2 = nn.Linear(config.hidden_size, config.hidden_size) + self.fc3 = nn.Linear(config.hidden_size, latent_dim * 2) + self.norm = nn.LayerNorm(config.hidden_size) + self.patch_size = patch_size + if patch_size != -1: + config.num_hidden_layers = 4 + self.aggregator = Qwen2Model(config) + self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size)) + self.cls_embed.data.normal_(0, 0.02) + + def get_frames(self, x): + num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size + expected_len = (num_frames_total - 1) * self.hop_size + self.input_dim + padding_needed = expected_len - x.size(-1) + waveform = F.pad(x, (0, padding_needed), value=0.0) + frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) + return frames + + def pad_patch_insert_cls(self, x): + bsz, _, dim = x.size() + num_frame = x.size(1) + r = num_frame % self.patch_size + pad_num = self.patch_size - r if r else 0 + x = F.pad(x, (0, 0, 0, pad_num), value=0.0) + x = x.reshape(-1, self.patch_size, dim) + x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1) + x = x.reshape(bsz, -1, dim) + return x + + def forward(self, waveform): + x = self.get_frames(waveform) + x = self.fc1(x) + x = self.fc2(x) + x = self.encoder(inputs_embeds=x) + x = x.last_hidden_state + + if self.patch_size != -1: + x = self.pad_patch_insert_cls(x) + x = self.aggregator(inputs_embeds=x) + x = x.last_hidden_state + bsz, _, dim = x.size() + x = x.reshape(-1, self.patch_size + 1, dim) + x = x[:, -1:, :].reshape(bsz, -1, dim) + + x = self.fc3(x) + return x, waveform.unsqueeze(1) + + +class AudioVAE(PreTrainedModel): + config_class = AudioVAEConfig + + def __init__(self, config: AudioVAEConfig): + super().__init__(config) + self.encoder = Encoder( + encoder_args=config.enc_kwargs["backbone"], + input_dim=config.enc_kwargs["input_dim"], + hop_size=config.enc_kwargs.get("hop_size", 320), + latent_dim=config.enc_kwargs["latent_dim"], + patch_size=config.patch_size, + ) + self.decoder = Decoder( + decoder_args=config.dec_kwargs["backbone"], + output_dim=config.dec_kwargs["output_dim"], + latent_dim=config.dec_kwargs["latent_dim"], 
+ patch_size=config.patch_size, + ) + self.post_init() + + def _init_weights(self, module): + std = 0.02 + if isinstance(module, nn.Linear): + if self.config.init_method == "kaiming": + nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu") + else: + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def encode_latent(self, waveform, waveform_length): + from diffusers.models.autoencoders.autoencoder_oobleck import OobleckDiagonalGaussianDistribution + + frame_num = torch.ceil(waveform_length / self.config.enc_kwargs["input_dim"]).to(torch.int32) + if self.config.patch_size != -1: + frame_num = torch.ceil(frame_num / self.config.patch_size) + h, y = self.encoder(waveform) + h = h.transpose(1, 2) + + posterior = OobleckDiagonalGaussianDistribution(h) + latent = posterior.sample() + latent = latent.transpose(1, 2) + return latent, frame_num + + def decode(self, latent, past_key_values=None, use_cache=False, stream_state=(None, None, None), last_chunk=False): + waveform, stream_state, past_key_values = self.decoder.low_level_reconstruct( + latent, + past_key_values=past_key_values, + use_cache=use_cache, + stream_state=stream_state, + last_chunk=last_chunk, + ) + return waveform, stream_state, past_key_values diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py index 87728890b6..462c204386 100644 --- a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py @@ -16,7 +16,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Ming-flash-omni-2.0 unified model (thinker + imagegen + talker).""" +"""Ming-flash-omni-2.0 thinker / image-gen wrapper. + +This class is the multimodal-registered entry point for Ming stages that +share the thinker's backbone: comprehension / text generation (`thinker`) +and diffusion conditioning for image generation (`imagegen`, not yet +implemented). + +The talker deliberately lives elsewhere. Upstream Ming hands text (not +hidden states) from the thinker to the talker, and the talker then +tokenises that string with its own Qwen2 tokenizer and runs an entirely +self-contained LLM + CFM + AudioVAE pipeline. Because it has no +multimodal inputs, it belongs in the non-MM-registered +`MingFlashOmniTalkerForConditionalGeneration` — routing it through +this wrapper would force it through vLLM's multimodal preprocess path +and trigger a hidden-size mismatch between the outer Ming config +(4096, thinker's LLM) and the talker's Qwen2 backbone (896). 
+""" from collections.abc import Iterable @@ -40,7 +56,10 @@ from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights -from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config, MingFlashOmniConfig +from vllm_omni.transformers_utils.configs.ming_flash_omni import ( + BailingMM2Config, + MingFlashOmniConfig, +) from .ming_flash_omni_thinker import ( MingFlashOmniThinkerDummyInputsBuilder, @@ -63,7 +82,7 @@ class MingFlashOmniForConditionalGeneration( SupportsMRoPE, CustomProcessMixin, ): - """Unified Ming-flash-omni-2.0 model combining thinker, imagegen, and talker.""" + """Ming-flash-omni-2.0 thinker + image-gen wrapper.""" supports_multimodal = True requires_raw_input_tokens: bool = True @@ -75,19 +94,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.has_postprocess = False config = vllm_config.model_config.hf_config - - self.vllm_config = vllm_config self.config = config - - if isinstance(config, MingFlashOmniConfig): - thinker_config = config.thinker_config - else: - thinker_config = config - - self.thinker_config: BailingMM2Config = thinker_config self.model_stage = vllm_config.model_config.model_stage + if self.model_stage == "talker": + raise ValueError( + "MingFlashOmniForConditionalGeneration does not support " + "model_stage='talker'. Use " + "model_arch='MingFlashOmniTalkerForConditionalGeneration' " + "directly — the talker has a self-contained LLM that " + "tokenises text itself and does not need the multimodal " + "preprocess path. See stage_configs/ming_flash_omni.yaml " + "stage 1 and stage_configs/ming_flash_omni_tts.yaml." + ) + if self.model_stage == "thinker": + if isinstance(config, MingFlashOmniConfig): + thinker_config: BailingMM2Config = config.thinker_config + else: + thinker_config: BailingMM2Config = config + thinker_vllm_config = vllm_config.with_hf_config( thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"] ) @@ -97,8 +123,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): architectures=["MingFlashOmniThinkerForConditionalGeneration"], ) self.model = self.thinker - self.imagegen = None - self.talker = None + self.make_empty_intermediate_tensors = self.thinker.make_empty_intermediate_tensors elif self.model_stage == "imagegen": # TODO: Implement image generator stage @@ -106,22 +131,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "Image generation stage is not yet implemented. Please use model_stage='thinker' for now." ) - elif self.model_stage == "talker": - # TODO: Implement talker (TTS) stage - raise NotImplementedError( - "Talker (TTS) stage is not yet implemented. Please use model_stage='thinker' for now." - ) - else: raise ValueError( - f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'imagegen', 'talker'" + f"Invalid model_stage: {self.model_stage!r}. Must be one of: 'thinker', 'imagegen'. " + f"For the talker stage, use MingFlashOmniTalkerForConditionalGeneration directly." 
) - # Set up intermediate tensors - self.make_empty_intermediate_tensors = ( - self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None - ) - def forward( self, input_ids: torch.Tensor, @@ -154,7 +169,7 @@ def sample( ): if hasattr(self.model, "sample"): return self.model.sample(logits, sampling_metadata) - raise NotImplementedError("sample method not available on current stage") + return None def get_mrope_input_positions(self, *args, **kwargs): if hasattr(self.model, "get_mrope_input_positions"): @@ -162,36 +177,9 @@ def get_mrope_input_positions(self, *args, **kwargs): raise NotImplementedError("get_mrope_input_positions not available on current stage") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loaded_weights = set() - thinker_weights = [] - imagegen_weights = [] - talker_weights = [] - - for name, value in weights: - if name.startswith("thinker."): - thinker_weights.append((name, value)) - elif name.startswith("imagegen."): - imagegen_weights.append((name, value)) - elif name.startswith("talker."): - talker_weights.append((name, value)) - else: - # Weights without prefix go to thinker by default - thinker_weights.append((name, value)) - - if self.model_stage == "thinker" and thinker_weights: - # Remove "thinker." prefix before loading - thinker_weights_stripped = [ - (name.replace("thinker.", "", 1) if name.startswith("thinker.") else name, value) - for name, value in thinker_weights - ] - thinker_loaded = self.thinker.load_weights(thinker_weights_stripped) - thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") - loaded_weights.update(thinker_loaded) - - # TODO: Load imagegen weights when implemented - # TODO: Load talker weights when implemented - - return loaded_weights + stripped = ((name.removeprefix("thinker."), value) for name, value in weights) + thinker_loaded = self.thinker.load_weights(stripped) + return add_prefix_to_loaded_weights(thinker_loaded, "thinker") def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( @@ -202,9 +190,7 @@ def get_mm_mapping(self) -> MultiModelKeys: @property def sampler(self): - if hasattr(self.model, "sampler"): - return self.model.sampler - return None + return getattr(self.model, "sampler", None) def embed_input_ids( self, diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py new file mode 100644 index 0000000000..08ed9e8547 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py @@ -0,0 +1,586 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright (c) Ant Group. All rights reserved. 
+# Adapted from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py +"""Ming-flash-omni-2.0 talker (TTS) stage model.""" + +from __future__ import annotations + +import glob as glob_module +import os +from collections.abc import Iterable +from dataclasses import dataclass +from functools import cached_property +from typing import Any + +import torch +import torch.nn as nn +from safetensors.torch import load_file +from transformers import AutoTokenizer, Qwen2Config, Qwen2Model +from transformers.utils.hub import cached_file +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.transformers_utils.configs.ming_flash_omni import MingFlashOmniTalkerConfig + +from .audio_vae import AudioVAE, AudioVAEConfig +from .prompt_utils import DEFAULT_PROMPT as MING_DEFAULT_PROMPT +from .talker_module import CFM, Aggregator, DiT, MingAudioGenerator, build_tts_input +from .text_processing import segment_and_normalize +from .voice_presets import VoicePresetRegistry + +logger = init_logger(__name__) + + +@dataclass(slots=True) +class _GenerationParams: + """Resolved sampling / decoding parameters for one forward call.""" + + prompt: str + instruction: str | None + cfg: float + sigma: float + temperature: float + max_steps: int + use_zero_spk_emb: bool + max_text_length: int + use_static_cache: bool + stream_decode: bool + + +@dataclass(slots=True) +class _VoiceContext: + """Voice cloning inputs resolved from request info + presets.""" + + spk_emb: Any # list[Tensor] | Tensor | list[float] | None + prompt_text: str | None + prompt_wav_lat: torch.Tensor | None + prompt_wav_emb: torch.Tensor | None + already_projected: bool + + +class MingFlashOmniTalkerForConditionalGeneration(nn.Module, CustomProcessMixin): + """Ming-flash-omni-2.0 talker stage: text -> audio waveform. + + Uses Qwen2 LLM + CFM (Conditional Flow Matching with DiT) + Aggregator + in an autoregressive loop to produce continuous audio latents, then + AudioVAE decodes latents to waveforms. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + + self.vllm_config = vllm_config + root_config = vllm_config.model_config.hf_config + + model_path = vllm_config.model_config.model + self._model_path = model_path + self.talker_dir = ( + os.path.join(model_path, "talker") if os.path.isdir(os.path.join(model_path, "talker")) else model_path + ) + + # When used standalone (model_arch=MingFlashOmniTalkerForConditionalGeneration), + # the root hf_config may be BailingMM2Config (thinker-only) due to model file structure + # Resolve talker config from talker/config.json in that case. 
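The fallbacks below probe a nested `talker/` tree next to the thinker weights. A rough sketch of the layout they expect, with paths inferred from `_resolve_talker_config`, `_resolve_llm_config`, `_init_audio_vae`, and the `talker/model*.safetensors` load pattern (`model_path` is a placeholder):

```python
import os

model_path = "/path/to/Ming-flash-omni-2.0"  # placeholder

# Relative paths the standalone talker stage looks for.
probed = [
    "talker/config.json",            # MingFlashOmniTalkerConfig
    "talker/llm/config.json",        # Qwen2 backbone config (tokenizer lives here too)
    "talker/vae/config.json",        # AudioVAEConfig
    "talker/model.safetensors",      # talker weights (glob: talker/model*.safetensors)
    "talker/vae/model.safetensors",  # AudioVAE weights
]
for rel in probed:
    status = "found" if os.path.exists(os.path.join(model_path, rel)) else "missing"
    print(f"{rel}: {status}")
```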
+ config = ( + root_config + if isinstance(root_config, MingFlashOmniTalkerConfig) + else self._resolve_talker_config(root_config, self.talker_dir, model_path) + ) + self.config = config + + self._standalone = prefix in ("", "talker") + if self._standalone: + self.allow_patterns_overrides = ["talker/model*.safetensors"] + self.fall_back_to_pt_during_load = False + + # LLM + llm_config = self._resolve_llm_config(config, self.talker_dir, model_path) + llm_config._attn_implementation = "sdpa" + self.llm_config = llm_config + self.hidden_size = llm_config.hidden_size + self.latent_dim = config.latent_dim + self.patch_size = config.patch_size + self.his_patch_size = config.history_patch_size + self.cfg_strength = config.cfg_strength + + self.model = Qwen2Model(llm_config) + self.cfm = CFM( + DiT(llm_input_dim=self.hidden_size, **config.flowmodel), + steps=config.steps, + ) + self.aggregator = Aggregator(llm_input_dim=self.hidden_size, **config.aggregator) + self.stop_head = nn.Linear(self.hidden_size, 2, bias=True) + # CAMPPlus 192-dim -> hidden + self.spk_head = nn.Linear(192, self.hidden_size, bias=True) + + # AudioVAE + self.audio_vae, self._vae_weight_source = self._init_audio_vae(config, self.talker_dir, model_path) + + self._use_cuda_graphs = not vllm_config.model_config.enforce_eager + + self.audio_generator = MingAudioGenerator( + config=self.config, + llm_config=self.llm_config, + model=self.model, + cfm=self.cfm, + aggregator=self.aggregator, + stop_head=self.stop_head, + audio_vae=self.audio_vae, + patch_size=self.patch_size, + his_patch_size=self.his_patch_size, + latent_dim=self.latent_dim, + cfg_strength=self.cfg_strength, + use_cuda_graphs=self._use_cuda_graphs, + ) + self.voice_presets = VoicePresetRegistry( + talker_dir=self.talker_dir, + model_path=self._model_path, + download_dir=vllm_config.load_config.download_dir, + audio_vae=self.audio_vae, + aggregator=self.aggregator, + spk_head=self.spk_head, + patch_size=self.patch_size, + ) + + @property + def device(self) -> torch.device: + return next(self.model.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.model.parameters()).dtype + + @cached_property + def tokenizer(self): + # Lazy Qwen2 tokenizer resolution: + # 1. Try local dirs first (talker/llm, talker, and then model root). + # 2. HF repo-id fallback: talker/llm is the canonical tokenizer location. + candidates = (os.path.join(self.talker_dir, "llm"), self.talker_dir, self._model_path) + for path in candidates: + if os.path.isdir(path): + try: + logger.debug("Resolving talker tokenizer from local dir %s", path) + return AutoTokenizer.from_pretrained(path, trust_remote_code=True) + except Exception: + continue + for subfolder in ("talker/llm", "llm"): + try: + logger.debug("Resolving talker tokenizer from HF subfolder %s", subfolder) + return AutoTokenizer.from_pretrained(self._model_path, subfolder=subfolder, trust_remote_code=True) + except Exception: + continue + logger.debug("Falling back to raw model_path tokenizer resolution") + return AutoTokenizer.from_pretrained(self._model_path, trust_remote_code=True) + + @staticmethod + def _resolve_talker_config(config, talker_dir: str, model_path: str) -> MingFlashOmniTalkerConfig: + """Resolve MingFlashOmniTalkerConfig when the root config is not one. + + This happens in standalone TTS mode where hf_config is BailingMM2Config. 
+ """ + # If the root config wraps a talker_config, use it + talker_config = getattr(config, "talker_config", None) + if isinstance(talker_config, MingFlashOmniTalkerConfig): + return talker_config + + # Try loading from talker/config.json + if os.path.isdir(talker_dir): + try: + resolved = MingFlashOmniTalkerConfig.from_pretrained(talker_dir) + logger.info("Resolved talker config from %s", talker_dir) + return resolved + except Exception: + pass + + try: + resolved = MingFlashOmniTalkerConfig.from_pretrained(model_path, subfolder="talker", trust_remote_code=True) + logger.info("Resolved talker config from %s/talker (HF hub)", model_path) + return resolved + except Exception as e: + raise ValueError( + f"Cannot resolve MingFlashOmniTalkerConfig. The root config " + f"is {type(config).__name__}, and talker/config.json was not " + f"found at {talker_dir} or via HF hub: {e}" + ) from e + + @staticmethod + def _resolve_llm_config(config: MingFlashOmniTalkerConfig, talker_dir: str, model_path: str) -> Qwen2Config: + """Resolve the Qwen2 LLM config for the talker backbone.""" + + if config.llm_config is not None: + return Qwen2Config(**config.llm_config) if isinstance(config.llm_config, dict) else config.llm_config + + # Try local talker/llm directory + llm_dir = os.path.join(talker_dir, "llm") + if os.path.isdir(llm_dir): + return Qwen2Config.from_pretrained(llm_dir) + + # HF hub fallback + for subfolder in ("talker/llm", "llm"): + try: + return Qwen2Config.from_pretrained(model_path, subfolder=subfolder, trust_remote_code=True) + except Exception: + continue + + raise ValueError( + f"Cannot find talker LLM config at {llm_dir}. " + "Either provide llm_config in MingFlashOmniTalkerConfig or " + "ensure the model path contains talker/llm/config.json." + ) + + @staticmethod + def _init_audio_vae( + config: MingFlashOmniTalkerConfig, talker_dir: str, model_path: str + ) -> tuple[AudioVAE | None, str | tuple[str, str] | None]: + """Initialize AudioVAE and return (vae, weight_source). + + weight_source is either a local directory path (str) or an + (repo_id, subfolder) tuple for HF hub downloads, or None. 
+ """ + vae_path = config.audio_vae_path or os.path.join(talker_dir, "vae") + + # Try local directory first + if os.path.isdir(vae_path): + try: + vae_config = AudioVAEConfig.from_pretrained(vae_path) + vae = AudioVAE(vae_config) + logger.info("Initialized AudioVAE from %s (sr=%d)", vae_path, vae_config.sample_rate) + return vae, vae_path + except Exception as e: + logger.warning("Failed to initialize AudioVAE from %s: %s", vae_path, e) + return None, None + + # HF hub fallback + for subfolder in ("talker/vae", "vae"): + try: + vae_config = AudioVAEConfig.from_pretrained(model_path, subfolder=subfolder, trust_remote_code=True) + vae = AudioVAE(vae_config) + logger.info(f"Initialized AudioVAE from {model_path}/{subfolder}") + return vae, (model_path, subfolder) + except Exception: + continue + + logger.info("AudioVAE not found at %s; waveform decoding unavailable", vae_path) + return None, None + + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata=None) -> torch.Tensor | None: + return None + + def sample(self, logits: torch.Tensor, sampling_metadata): + return None + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + is_multimodal=None, + ) -> torch.Tensor: + return self.model.get_input_embeddings()(input_ids) + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors | None: + return None + + def get_dummy_runtime_additional_information(self, num_reqs: int) -> list[dict[str, object]]: + info: dict[str, object] = {"text": "dummy", "use_zero_spk_emb": True, "max_steps": 1} + return [info for _ in range(num_reqs)] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + runtime_additional_information: list[dict] | None = None, + **kwargs, + ) -> OmniOutput: + """Run TTS generation and return audio output. + + The full autoregressive generation loop is executed inside this method. 
+ """ + additional_info = self._extract_additional_info(runtime_additional_information) + params = self._resolve_generation_params(additional_info) + voice = self._resolve_voice(additional_info) + + latents = self._generate_latents( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + text=additional_info.get("text", ""), + params=params, + voice=voice, + ) + return self._decode_to_output(latents, stream_decode=params.stream_decode) + + @staticmethod + def _extract_additional_info( + runtime_additional_information: list[dict] | None, + ) -> dict[str, Any]: + if runtime_additional_information and len(runtime_additional_information) > 0: + return runtime_additional_information[0] or {} + return {} + + def _resolve_generation_params(self, additional_info: dict[str, Any]) -> _GenerationParams: + # "omni" : thinker -> talker hand-off with hardcoded defaults + # "instruct": standalone TTS with caller-supplied sampling knobs + ming_task = additional_info.get("ming_task", "instruct") + + if ming_task == "omni": + prompt = MING_DEFAULT_PROMPT + instruction = None + use_zero_spk_emb = additional_info.get("spk_emb") is None + cfg = 2.0 + sigma = 0.25 + temperature = 0.0 + max_steps = 200 + else: + prompt = additional_info.get("prompt", MING_DEFAULT_PROMPT) + instruction = additional_info.get("instruction", None) + use_zero_spk_emb = additional_info.get("use_zero_spk_emb", False) + cfg = additional_info.get("cfg", self.cfg_strength) + sigma = additional_info.get("sigma", 0.25) + temperature = additional_info.get("temperature", 0.0) + max_steps = int(additional_info.get("max_steps", additional_info.get("max_decode_steps", 200))) + + return _GenerationParams( + prompt=prompt, + instruction=instruction, + cfg=cfg, + sigma=sigma, + temperature=temperature, + max_steps=max_steps, + use_zero_spk_emb=use_zero_spk_emb, + max_text_length=int(additional_info.get("max_text_length", 50)), + use_static_cache=bool(additional_info.get("use_static_cache", True)), + stream_decode=bool(additional_info.get("stream_decode", True)), + ) + + def _resolve_voice(self, additional_info: dict[str, Any]) -> _VoiceContext: + spk_emb = additional_info.get("spk_emb", None) + prompt_text = additional_info.get("prompt_text", None) + prompt_wav_lat = additional_info.get("prompt_wav_lat", None) + prompt_wav_emb = additional_info.get("prompt_wav_emb", None) + already_projected = False + + voice_name = additional_info.get("voice_name", None) + if voice_name and spk_emb is None and voice_name in self.voice_presets: + preset = self.voice_presets.get(voice_name) or {} + prompt_wav_lat = preset.get("prompt_wav_lat") + prompt_wav_emb = preset.get("prompt_wav_emb") + spk_emb = preset.get("spk_emb") + already_projected = True + if prompt_text is None: + prompt_text = preset.get("prompt_text") + + return _VoiceContext( + spk_emb=spk_emb, + prompt_text=prompt_text, + prompt_wav_lat=prompt_wav_lat, + prompt_wav_emb=prompt_wav_emb, + already_projected=already_projected, + ) + + def _project_spk_emb( + self, spk_emb: Any, already_projected: bool, use_zero_spk_emb: bool + ) -> list[torch.Tensor] | None: + if spk_emb is None: + if use_zero_spk_emb: + return [torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)] + return None + + if already_projected: + return spk_emb if isinstance(spk_emb, list) else [spk_emb] + + if isinstance(spk_emb, torch.Tensor): + tensors = [spk_emb] + elif isinstance(spk_emb, list) and spk_emb and isinstance(spk_emb[0], (int, float)): + tensors = [torch.tensor(spk_emb, dtype=self.dtype).unsqueeze(0)] + elif 
isinstance(spk_emb, list): + tensors = spk_emb + else: + tensors = [spk_emb] + return [self.spk_head(t.to(device=self.device, dtype=self.dtype)) for t in tensors] + + def _generate_latents( + self, + *, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None, + text: str, + params: _GenerationParams, + voice: _VoiceContext, + ) -> list[torch.Tensor]: + generator = self.audio_generator + + if inputs_embeds is not None: + # Caller pre-built embeddings — run a single AR pass. + return generator.generate_latents( + inputs_embeds=inputs_embeds, + prompt_wav_lat=voice.prompt_wav_lat, + max_steps=params.max_steps, + cfg=params.cfg, + sigma=params.sigma, + temperature=params.temperature, + use_static_cache=params.use_static_cache, + ) + + spk_emb = self._project_spk_emb(voice.spk_emb, voice.already_projected, params.use_zero_spk_emb) + text_segments = segment_and_normalize(text, max_length=params.max_text_length) if text else [] + + if not text_segments: + # vLLM passes 1D input_ids; Qwen2Model expects (batch, seq). + inputs_embeds = self.model.get_input_embeddings()(input_ids.to(self.device)).unsqueeze(0) + return generator.generate_latents( + inputs_embeds=inputs_embeds, + prompt_wav_lat=voice.prompt_wav_lat, + max_steps=params.max_steps, + cfg=params.cfg, + sigma=params.sigma, + temperature=params.temperature, + use_static_cache=params.use_static_cache, + ) + + all_latents: list[torch.Tensor] = [] + for segment in text_segments: + seg_embeds, _ = build_tts_input( + tokenizer=self.tokenizer, + embed_tokens=self.model.get_input_embeddings(), + device=self.device, + dtype=torch.bfloat16, + text=segment, + prompt=params.prompt, + spk_emb=spk_emb, + instruction=params.instruction, + prompt_text=voice.prompt_text, + prompt_wav_emb=voice.prompt_wav_emb, + ) + effective_max_steps = generator.duration_capped_steps(len(segment), params.max_steps) + all_latents.extend( + generator.generate_latents( + inputs_embeds=seg_embeds, + prompt_wav_lat=voice.prompt_wav_lat, + max_steps=effective_max_steps, + cfg=params.cfg, + sigma=params.sigma, + temperature=params.temperature, + use_static_cache=params.use_static_cache, + ) + ) + return all_latents + + def _decode_to_output(self, latents: list[torch.Tensor], *, stream_decode: bool) -> OmniOutput: + multimodal_outputs: dict[str, Any] = {} + if latents and self.audio_vae is not None: + waveform = self.audio_generator.decode_to_waveform(latents, stream_decode=stream_decode) + if not stream_decode: + waveform = self.audio_generator.trim_trailing_silence(waveform) + multimodal_outputs["audio"] = waveform.detach().float().cpu() + multimodal_outputs["sr"] = torch.tensor(self.audio_vae.config.sample_rate) + elif latents: + all_lat = torch.cat(latents, dim=1) + multimodal_outputs["audio_latents"] = all_lat.detach().float().cpu() + + return OmniOutput(text_hidden_states=None, multimodal_outputs=multimodal_outputs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights for all talker components. + + The talker's HF checkpoint (talker/model.safetensors) stores + weights with prefixes matching this module's submodule names directly. + And AudioVAE weights live in a separate file under talker/vae/ + """ + # Standalone: bypass the default loader's iterator (torch.load on + # .safetensors crashes) and read talker/model*.safetensors directly. 
+ if self._standalone: + weights = self._iter_talker_safetensors() + + loader = AutoWeightsLoader( + self, + skip_prefixes=["audio_vae."], # loaded separately + skip_substrs=["rotary_embed.inv_freq"], # non-persistent buffer + ) + loaded = loader.load_weights(weights) + logger.info("Loaded %d talker weights from checkpoint", len(loaded)) + + if self.audio_vae is not None and self._vae_weight_source is not None: + loaded.update(self._load_vae_weights()) + + # Register voice presets after all weights (incl. VAE) are loaded. + try: + self.voice_presets.load_presets_from_manifest(device=self.device, dtype=self.dtype) + except Exception as e: # pragma: no cover — best-effort + logger.warning("Voice preset loading failed (non-fatal): %s", e) + + return loaded + + def _iter_talker_safetensors(self) -> Iterable[tuple[str, torch.Tensor]]: + """Yield (name, tensor) pairs from talker/model*.safetensors.""" + model_path = self._model_path + # Try local path first + for candidate in (os.path.join(model_path, "talker"), model_path): + sf_files = sorted(glob_module.glob(os.path.join(candidate, "model*.safetensors"))) + if sf_files: + for sf_path in sf_files: + yield from load_file(sf_path, device="cpu").items() + return + + # HF hub fallback: download only the talker checkpoint files + model_root = download_weights_from_hf_specific( + model_path, + self.vllm_config.load_config.download_dir, + allow_patterns=["talker/model*.safetensors"], + ) + talker_dir = os.path.join(model_root, "talker") + sf_files = sorted(glob_module.glob(os.path.join(talker_dir, "model*.safetensors"))) + if not sf_files: + raise RuntimeError(f"No talker safetensors found under {model_root}. Expected talker/model*.safetensors.") + for sf_path in sf_files: + yield from load_file(sf_path, device="cpu").items() + + def _load_vae_weights(self) -> set[str]: + """Load AudioVAE weights from talker/vae/model.safetensors.""" + if self.audio_vae is None or self._vae_weight_source is None: + return set() + + # Resolve safetensors file paths from the weight source + safetensors_files: list[str] = [] + source = self._vae_weight_source + if isinstance(source, str): + # Local directory path + safetensors_files = sorted(glob_module.glob(os.path.join(source, "*.safetensors"))) + elif isinstance(source, tuple): + # (repo_id, subfolder) for HF hub + repo_id, subfolder = source + for filename in ("model.safetensors", "diffusion_pytorch_model.safetensors"): + try: + cached = cached_file(repo_id, filename, subfolder=subfolder) + except Exception: + cached = None + if cached is not None: + safetensors_files.append(cached) + break + + if not safetensors_files: + logger.warning("No AudioVAE safetensors files found for source=%s", source) + return set() + + vae_state_keys = set(self.audio_vae.state_dict().keys()) + vae_loader = AutoWeightsLoader(self.audio_vae) + loaded: set[str] = set() + for sf_path in safetensors_files: + file_weights = load_file(sf_path, device="cpu") + matched = ((name, tensor) for name, tensor in file_weights.items() if name in vae_state_keys) + loaded.update(f"audio_vae.{name}" for name in vae_loader.load_weights(matched)) + + logger.info("Loaded %d AudioVAE weights from %s", len(loaded), source) + return loaded diff --git a/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py b/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py new file mode 100644 index 0000000000..4271114bc2 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright (c) Ant Group. All rights reserved. +# Adapted from Ming repo's usage cookbook: +# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb +"""Shared prompt-building helpers for Ming-flash-omni standalone talker.""" + +import copy +import json +from typing import Any + +DEFAULT_PROMPT = "Please generate speech based on the following description.\n" + +BASE_CAPTION_TEMPLATE: dict[str, Any] = { + "audio_sequence": [ + { + "序号": 1, + "说话人": "speaker_1", + "方言": None, + "风格": None, + "语速": None, + "基频": None, + "音量": None, + "情感": None, + "BGM": { + "Genre": None, + "Mood": None, + "Instrument": None, + "Theme": None, + "ENV": None, + "SNR": None, + }, + "IP": None, + } + ] +} + + +def create_instruction(user_input: dict[str, Any]) -> str: + """Return a JSON caption string for ``audio_sequence[0]``. + + Only keys already present on the base template are merged in; unknown + keys are silently ignored to keep the output schema stable. + """ + caption = copy.deepcopy(BASE_CAPTION_TEMPLATE) + item = caption["audio_sequence"][0] + for key, value in user_input.items(): + if key in item: + item[key] = value + return json.dumps(caption, ensure_ascii=False) diff --git a/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py b/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py new file mode 100644 index 0000000000..68dbfe6502 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright (c) Ant Group. All rights reserved. +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py + +from __future__ import annotations + +import torch + + +class SpkembExtractor: + """CAMPPlus ONNX-based speaker embedding extractor (runs on CPU).""" + + def __init__(self, campplus_model: str, target_sr: int = 16000): + import onnxruntime + import torchaudio.compliance.kaldi as kaldi + + self.kaldi = kaldi + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 2 + self.campplus_session = onnxruntime.InferenceSession( + campplus_model, sess_options=option, providers=["CPUExecutionProvider"] + ) + self.target_sr = target_sr + + def _extract_spk_embedding(self, speech): + feat = self.kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + self.campplus_session.run( + None, + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}, + )[0] + .flatten() + .tolist() + ) + embedding = torch.tensor([embedding]) + return embedding + + def __call__(self, waveform, **kwargs) -> torch.Tensor | None: + spk_emb = self._extract_spk_embedding(waveform) + return spk_emb diff --git a/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py b/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py new file mode 100644 index 0000000000..80acbaad06 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py @@ -0,0 +1,1151 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright (c) Ant Group. All rights reserved. 
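`SpkembExtractor` returns a `(1, 192)` CAMPPlus embedding, which matches the talker's `spk_head = nn.Linear(192, hidden_size)` projection; flattened to a plain float list it can be supplied as `speaker_embedding` on a speech request (see `_build_ming_prompt`). A minimal sketch, assuming a local CAMPPlus ONNX file and a 16 kHz mono reference clip (both paths are placeholders):

```python
import torchaudio

from vllm_omni.model_executor.models.ming_flash_omni.spk_embedding import (
    SpkembExtractor,
)

extractor = SpkembExtractor(campplus_model="/path/to/campplus.onnx")  # placeholder path

waveform, sr = torchaudio.load("/path/to/reference_16k.wav")  # 16 kHz mono expected
spk_emb = extractor(waveform)  # torch.Tensor of shape (1, 192)

# Flatten to a plain float list for the `speaker_embedding` request field;
# the serving layer forwards it as additional_information["spk_emb"].
speaker_embedding = spk_emb.flatten().tolist()
```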
+# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/dit.py +# +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/modules.py +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/cfm.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# Partial of the following source code +# is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import logging +import math +from functools import cached_property +from queue import Queue +from threading import Lock + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedTokenizerBase, Qwen2Config, Qwen2Model, StaticCache +from vllm.logger import init_logger +from x_transformers.x_transformers import RotaryEmbedding, apply_rotary_pos_emb + +from .audio_vae import AudioVAE + +logger = init_logger(__name__) + + +######################################################################## +# DiT Modules +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/modules.py +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/dit.py +######################################################################## + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) + return x + + +class FeedForward(nn.Module): + def __init__( + self, dim: int, dim_out: int | None = None, mult: float = 4, dropout: float = 0.0, approximate: str = "none" + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.ff(x) + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + qk_norm: str | None = None, + pe_attn_head: int | None = None, + attn_mask_enabled: bool = True, + ): + super().__init__() + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + if qk_norm is None: + self.q_norm = None + self.k_norm = None + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head) + self.k_norm = RMSNorm(dim_head) + else: + raise ValueError(f"Unimplemented qk_norm: {qk_norm}") + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) 
+ + self.pe_attn_head = pe_attn_head + self.attn_mask_enabled = attn_mask_enabled + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + rope: tuple[torch.Tensor, torch.Tensor | None] | None = None, + ) -> torch.Tensor: + batch_size = x.shape[0] + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.q_norm is not None: + query = self.q_norm(query) + if self.k_norm is not None: + key = self.k_norm(key) + + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + + if self.pe_attn_head is not None: + on = self.pe_attn_head + query[:, :on, :, :] = apply_rotary_pos_emb(query[:, :on, :, :], freqs, q_xpos_scale) + key[:, :on, :, :] = apply_rotary_pos_emb(key[:, :on, :, :], freqs, k_xpos_scale) + else: + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + if self.attn_mask_enabled and mask is not None: + valid_sample_indices = mask.any(dim=1) + final_output = torch.zeros_like(query).to(query.device) + + attn_mask = mask[valid_sample_indices] + query = query[valid_sample_indices] + key = key[valid_sample_indices] + value = value[valid_sample_indices] + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) + attn_mask = attn_mask.expand(valid_sample_indices.sum().item(), self.heads, query.shape[-2], key.shape[-2]) + else: + attn_mask = None + + x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) + if self.attn_mask_enabled and mask is not None: + final_output[valid_sample_indices] = x + x = final_output + + x = x.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + x = x.to(query.dtype) + + x = self.to_out[0](x) + x = self.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class DiTBlock(nn.Module): + """A DiT block with pre-norm and residual connections.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + dropout: float = 0.1, + qk_norm: str | None = None, + pe_attn_head: int | None = None, + attn_mask_enabled: bool = True, + **kwargs, + ): + super().__init__() + self.norm1 = RMSNorm(hidden_size) + self.attn = Attention( + dim=hidden_size, + heads=num_heads, + dim_head=hidden_size // num_heads, + dropout=dropout, + qk_norm=qk_norm, + pe_attn_head=pe_attn_head, + attn_mask_enabled=attn_mask_enabled, + ) + self.norm2 = RMSNorm(hidden_size) + self.mlp = FeedForward(dim=hidden_size, mult=mlp_ratio, dropout=dropout, approximate="tanh") + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None, + rope: tuple[torch.Tensor, torch.Tensor | None] | None, + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), mask=mask, rope=rope) + x = x + self.mlp(self.norm2(x)) + return x + + +class FinalLayer(nn.Module): + """The final layer of DiT.""" + + def __init__(self, hidden_size: int, out_channels: int): + super().__init__() + self.norm_final = RMSNorm(hidden_size) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm_final(x) + x = self.linear(x) + return x + 
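Each `DiTBlock` is a plain pre-norm residual pair (`x + attn(norm1(x))`, then `x + mlp(norm2(x))`), so it preserves the `(batch, seq, hidden_size)` shape end to end. A small shape sanity-check, assuming the module is importable once this change lands (sizes here are illustrative; mask and rope are optional and passed as `None`):

```python
import torch

from vllm_omni.model_executor.models.ming_flash_omni.talker_module import DiTBlock

block = DiTBlock(hidden_size=1024, num_heads=16)
x = torch.randn(2, 37, 1024)  # (batch, seq, hidden_size), arbitrary sizes

out = block(x, mask=None, rope=None)  # no attention mask, no rotary embedding
assert out.shape == x.shape           # pre-norm residual blocks keep the shape
```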
+ +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor, scale: float = 1000) -> torch.Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class TimestepEmbedder(nn.Module): + def __init__(self, dim: int, freq_embed_dim: int = 256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep: torch.Tensor) -> torch.Tensor: + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) + return time + + +class CondEmbedder(nn.Module): + """Embeds LLM hidden states with optional CFG dropout.""" + + def __init__(self, input_feature_size: int, hidden_size: int): + super().__init__() + self.cond_embedder = nn.Linear(input_feature_size, hidden_size) + + def forward(self, llm_cond: torch.Tensor) -> torch.Tensor: + return self.cond_embedder(llm_cond) + + +class DiT(nn.Module): + """Diffusion model with a Transformer backbone for audio latent generation.""" + + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 1024, + depth: int = 28, + num_heads: int = 16, + mlp_ratio: float = 4.0, + llm_cond_dim: int = 896, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + + self.t_embedder = TimestepEmbedder(hidden_size) + self.x_embedder = nn.Linear(in_channels, hidden_size) + self.c_embedder = CondEmbedder(llm_cond_dim, hidden_size) + if "spk_dim" in kwargs: + self.spk_embedder = nn.Linear(kwargs["spk_dim"], hidden_size) + else: + self.spk_embedder = None + self.hidden_size = hidden_size + + self.rotary_embed = RotaryEmbedding(hidden_size // num_heads) + + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)] + ) + self.final_layer = FinalLayer(hidden_size, self.out_channels) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + c: torch.Tensor, + latent_history: torch.Tensor, + spk_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + x = torch.cat([latent_history, x], dim=1) + x = self.x_embedder(x) + t = self.t_embedder(t).unsqueeze(1) + c = self.c_embedder(c) + y = t + c + if spk_emb is None: + assert self.spk_embedder is None + x = torch.cat([y, x], dim=1) + else: + x = torch.cat([self.spk_embedder(spk_emb), y, x], dim=1) + rope = self.rotary_embed.forward_from_seq_len(x.shape[1]) + + for block in self.blocks: + x = block(x, None, rope) + x = self.final_layer(x) + return x + + def forward_with_cfg( + self, + x: torch.Tensor, + t: torch.Tensor, + c: torch.Tensor, + latent_history: torch.Tensor, + spk_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward with classifier-free guidance (doubles batch for CFG).""" + x = torch.cat([x, x], dim=0) + latent_history = torch.cat([latent_history, latent_history], dim=0) + fake_latent = torch.zeros_like(c) + c = torch.cat([c, fake_latent], dim=0) + if t.ndim == 0: + t = t.repeat(x.shape[0]) + if spk_emb is not None: + spk_emb = torch.cat([spk_emb, spk_emb], dim=0) + model_out = self.forward(x, t, c, latent_history, spk_emb) + return model_out[:, 
-x.shape[1] :, :] + + +######################################################################################### +# CFM +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/cfm.py +######################################################################################### + + +def get_epss_timesteps(n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + +class CFM(nn.Module): + """Conditional Flow Matching module for audio latent generation.""" + + def __init__(self, model: nn.Module, steps: int = 10, sway_sampling_coef: float | None = -1.0): + """ + Args: + model: DiT used for the velocity prediction. + steps: number of integration steps per sample call. + sway_sampling_coef: coefficient used to skew the integration + grid towards low-noise timesteps. Defaults to -1.0 which + packs more steps near t=0, where prediction error is highest. + Set to `None` to use the linear grid as-is. + """ + super().__init__() + self.model = model + self.steps = steps + self.sway_sampling_coef = sway_sampling_coef + + @torch.no_grad() + def sample( + self, + llm_cond: torch.Tensor, + lat_cond: torch.Tensor, + y0: torch.Tensor, + t: torch.Tensor, + sde_args: torch.Tensor, + sde_rnd: torch.Tensor, + ): + """Sample audio latent via ODE/SDE integration with CFG. + + Args: + llm_cond: LLM hidden state (B, 1, hidden_size) + lat_cond: latent history (B, his_patch_size, latent_dim) + y0: initial noise (B, patch_size, latent_dim) + t: timesteps from get_epss_timesteps + sde_args: [cfg_strength, sigma, temperature] + sde_rnd: random noise for SDE steps (steps, B, patch_size, latent_dim) + """ + + def fn(fn_t, x): + pred_cfg = self.model.forward_with_cfg(x, fn_t, llm_cond, lat_cond, None) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) + return pred + (pred - null_pred) * sde_args[0] + + if self.sway_sampling_coef is not None: + t = t + self.sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + for step in range(self.steps): + dt = t[step + 1] - t[step] + y0 = y0 + fn(t[step], y0) * dt + y0 = y0 + sde_args[1] * (sde_args[2] ** 0.5) * (dt.abs() ** 0.5) * sde_rnd[step] + + return y0 + + +class CFMGraphExecutor: + """CUDA graph-accelerated executor for CFM + Aggregator + StopHead pipeline.""" + + def __init__(self, config, cfm, aggregator, stop_head: nn.Linear): + self.config = config + self.cfm = cfm + self.aggregator = aggregator + self.stop_head = stop_head + self.initialized = False + + self.last_hidden_state_placeholder = None + self.his_lat_placeholder = None + self.randn_like_placeholder = None + self.t_placeholder = None + self.sde_args_placeholder = None + self.sde_rnd_placeholder = None + self.gen_lat_placeholder = None + self.inputs_embeds_placeholder = None + self.stop_out_placeholder = None + self.graph = None + + def execute( + self, + input_tensor: torch.Tensor, + his_lat: torch.Tensor, + cfg_strength: float = 2.0, + sigma: float = 0.25, + temperature: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bat_size, his_patch_size, z_dim = his_lat.shape + randn_tensor = 
torch.randn( + (bat_size, self.config.patch_size, z_dim), device=input_tensor.device, dtype=input_tensor.dtype + ) + t = get_epss_timesteps(self.config.steps, device=input_tensor.device, dtype=input_tensor.dtype) + sde_rnd = torch.randn( + (self.config.steps, *randn_tensor.shape), device=input_tensor.device, dtype=input_tensor.dtype + ) + + if not self.initialized: + self._initialize_graph(input_tensor, his_lat, randn_tensor, sde_rnd) + + self.last_hidden_state_placeholder.copy_(input_tensor) + self.his_lat_placeholder.copy_(his_lat) + self.randn_like_placeholder.copy_(randn_tensor) + self.t_placeholder.copy_(t) + self.sde_args_placeholder[0] = cfg_strength + self.sde_args_placeholder[1] = sigma + self.sde_args_placeholder[2] = temperature + self.sde_rnd_placeholder.copy_(sde_rnd) + + self.graph.replay() + + gen_lat = torch.empty_like(self.gen_lat_placeholder) + gen_lat.copy_(self.gen_lat_placeholder) + + inputs_embeds = torch.empty_like(self.inputs_embeds_placeholder) + inputs_embeds.copy_(self.inputs_embeds_placeholder) + + stop_out = torch.empty_like(self.stop_out_placeholder) + stop_out.copy_(self.stop_out_placeholder) + + return gen_lat, inputs_embeds, stop_out + + def _initialize_graph( + self, + input_tensor: torch.Tensor, + his_lat: torch.Tensor, + randn_tensor: torch.Tensor, + sde_rnd: torch.Tensor, + ) -> None: + self.last_hidden_state_placeholder = torch.empty_like(input_tensor) + self.his_lat_placeholder = torch.empty_like(his_lat) + self.randn_like_placeholder = torch.empty_like(randn_tensor) + self.t_placeholder = get_epss_timesteps(self.config.steps, device=input_tensor.device, dtype=input_tensor.dtype) + self.sde_args_placeholder = torch.empty(3, device=input_tensor.device, dtype=input_tensor.dtype) + self.sde_rnd_placeholder = torch.empty_like(sde_rnd) + + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph): + self.gen_lat_placeholder = self.cfm.sample( + self.last_hidden_state_placeholder, + self.his_lat_placeholder, + self.randn_like_placeholder, + self.t_placeholder, + self.sde_args_placeholder, + self.sde_rnd_placeholder, + ) + self.inputs_embeds_placeholder = self.aggregator(self.gen_lat_placeholder) + self.stop_out_placeholder = self.stop_head(self.last_hidden_state_placeholder[:, -1, :]).softmax(dim=-1) + + self.initialized = True + + +class CFMGraphExecutorPool: + """Thread-safe pool of CFMGraphExecutors for concurrent inference.""" + + def __init__(self, config, cfm, aggregator, stop_head: nn.Linear, pool_size: int = 1): + self.config = config + self.cfm = cfm + self.aggregator = aggregator + self.stop_head = stop_head + self.pool_size = pool_size + self.pool = Queue(maxsize=pool_size) + self.lock = Lock() + + for _ in range(pool_size): + executor = CFMGraphExecutor(config, cfm, aggregator, stop_head) + self.pool.put(executor) + + def acquire(self) -> CFMGraphExecutor: + return self.pool.get() + + def release(self, executor: CFMGraphExecutor) -> None: + self.pool.put(executor) + + def execute( + self, + input_tensor: torch.Tensor, + his_lat: torch.Tensor, + cfg_strength: float = 2.0, + sigma: float = 0.25, + temperature: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + executor = self.acquire() + try: + return executor.execute( + input_tensor, his_lat, cfg_strength=cfg_strength, sigma=sigma, temperature=temperature + ) + finally: + self.release(executor) + + +######################################################################## +# Audio Postprocess +# Adapted from: +# 
https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py +######################################################################## + + +@torch.no_grad() +def resample(waveform: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor: + """Resample a waveform via linear interpolation (no torchaudio dep). + + Args: + waveform: Tensor shaped ``(..., num_samples)``. + orig_sr: Source sample rate (Hz); must be > 0. + target_sr: Target sample rate (Hz); must be > 0. + + Raises: + ValueError: If sample rates are non-positive, the waveform is empty, + or the resampled length would round to zero. + """ + if orig_sr <= 0: + raise ValueError(f"orig_sr must be positive, got {orig_sr}") + if target_sr <= 0: + raise ValueError(f"target_sr must be positive, got {target_sr}") + if waveform.numel() == 0 or waveform.shape[-1] == 0: + raise ValueError("waveform must contain at least one sample") + if orig_sr == target_sr: + return waveform + + ratio = target_sr / orig_sr + new_len = int(waveform.shape[-1] * ratio) + if new_len <= 0: + raise ValueError( + f"resampled waveform would be empty for input length {waveform.shape[-1]}, " + f"orig_sr={orig_sr}, target_sr={target_sr}" + ) + return torch.nn.functional.interpolate( + waveform.unsqueeze(0), + size=new_len, + mode="linear", + align_corners=False, + ).squeeze(0) + + +def trim_trailing_silence( + waveform: torch.Tensor, + sample_rate: int, + sil_th: float = 1e-3, + tail_silence_s: float = 0.3, +) -> torch.Tensor: + """Trim low-energy tail while keeping a short trailing silence. + + Works on 2-D ``(channels, samples)`` or 3-D ``(batch, channels, samples)`` + tensors. Any other shape is returned unchanged. + """ + if waveform.numel() == 0: + return waveform + + original_dim = waveform.dim() + if original_dim == 3: + speech = waveform[:, 0, :] + elif original_dim == 2: + speech = waveform + else: + return waveform + + frame_step = int(sample_rate * 0.1) + frame_size = int(sample_rate * 0.1) + if speech.shape[-1] < frame_size: + keep = min(speech.shape[-1], int(tail_silence_s * sample_rate)) + trimmed = speech[..., :keep] + else: + num_frame = (speech.shape[-1] - frame_size) // frame_step + 1 + cur_len = (num_frame - 1) * frame_step + frame_size + speech = speech[..., :cur_len] + spe_frames = speech.unfold(-1, frame_size, frame_step) + scores = spe_frames.abs().mean(dim=-1) + scores = scores.mean(dim=list(range(scores.dim() - 1))) + idx = scores.shape[0] - 1 + while idx >= 0 and scores[idx] <= sil_th: + idx -= 1 + if idx < 0: + keep = min(speech.shape[-1], int(tail_silence_s * sample_rate)) + trimmed = speech[..., :keep] + else: + non_sil_len = idx * frame_step + frame_size + int(tail_silence_s * sample_rate) + non_sil_len = min(non_sil_len, speech.shape[-1]) + trimmed = speech[..., :non_sil_len] + + if original_dim == 3: + return trimmed.unsqueeze(1) + return trimmed + + +def silence_holder( + speech: torch.Tensor, + sample_rate: int, + sil_cache: dict | None = None, + last_chunk: bool = True, + sil_th: float = 1e-3, + last_sil: float = 0.3, +) -> tuple[torch.Tensor, dict]: + """Ming-style streaming silence holder. + + Used during streaming VAE decode to defer emission of silent regions + until a non-silent frame arrives (or the stream ends). ``sil_cache`` + is carried across chunks and updated in place. 
+ """ + if speech.numel() == 0: + return speech, sil_cache or {"holder": [], "buffer": []} + + frame_step = int(sample_rate * 0.1) + frame_size = int(sample_rate * 0.1) + if sil_cache is None: + sil_cache = {"holder": [], "buffer": []} + + if sil_cache["buffer"]: + speech = torch.cat([*sil_cache["buffer"], speech], dim=-1) + sil_cache["buffer"] = [] + + if speech.shape[-1] < frame_size: + sil_cache["buffer"].append(speech) + if last_chunk: + speech = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1) + return speech[..., : int(last_sil * sample_rate)], sil_cache + return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache + + num_frame = (speech.shape[-1] - frame_size) // frame_step + 1 + cur_len = (num_frame - 1) * frame_step + frame_size + if speech.shape[-1] > cur_len: + sil_cache["buffer"].append(speech[..., cur_len:]) + speech = speech[..., :cur_len] + + spe_frames = speech.unfold(-1, frame_size, frame_step) + scores = spe_frames.abs().mean(dim=-1) + scores = scores.mean(dim=list(range(scores.dim() - 1))) + idx = scores.shape[0] - 1 + while idx >= 0 and scores[idx] <= sil_th: + idx -= 1 + + if idx < 0: + sil_cache["holder"].append(speech) + if last_chunk: + speech = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1) + return speech[..., : int(last_sil * sample_rate)], sil_cache + return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache + + non_sil_len = idx * frame_step + frame_size + if last_chunk: + non_sil_len += int(last_sil * sample_rate) + non_sil_len = min(non_sil_len, speech.shape[-1]) + speech_out = torch.cat([*sil_cache["holder"], speech[..., :non_sil_len]], dim=-1) + sil_cache["holder"] = [] + if non_sil_len < speech.shape[-1]: + sil_cache["holder"].append(speech[..., non_sil_len:]) + return speech_out, sil_cache + + +######################################################################## +# Audio Postprocess +# Ported from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/aggregator.py +######################################################################## + + +class Aggregator(nn.Module): + """Maps generated audio latent patches back to LLM embedding space.""" + + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 1152, + depth: int = 28, + num_heads: int = 16, + mlp_ratio: float = 4.0, + llm_input_dim: int = 896, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + + self.word_embedder = nn.Embedding(1, hidden_size) + self.x_embedder = nn.Linear(in_channels, hidden_size) + self.hidden_size = hidden_size + + self.rotary_embed = RotaryEmbedding(hidden_size // num_heads) + + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)] + ) + self.final_layer = FinalLayer(hidden_size, llm_input_dim) + + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: + x = self.x_embedder(x) + cls_embed = self.word_embedder(torch.zeros((x.shape[0], 1), dtype=torch.long, device=x.device)) + x = torch.cat([cls_embed, x], dim=1) + + rope = self.rotary_embed.forward_from_seq_len(x.shape[1]) + if mask is not None: + mask_pad = mask.clone().detach()[:, :1] + mask = torch.cat([mask_pad, mask], dim=-1) + for block in self.blocks: + x = block(x, mask, rope) + x = self.final_layer(x) + x = x[:, :1, :] + return x + + 
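+
+# Illustrative sketch only (assumed toy sizes, not exercised in this module):
+# the Aggregator prepends a learned CLS embedding, runs the DiT blocks, and
+# keeps just the CLS position, so one generated latent patch collapses to a
+# single vector in the LLM embedding space (llm_input_dim, 896 by default):
+#
+#     agg = Aggregator(in_channels=64, hidden_size=1152, depth=2, num_heads=16)
+#     lat = torch.randn(2, 8, 64)   # (batch, patch_size, latent_dim)
+#     emb = agg(lat)                # emb.shape == (2, 1, 896)
+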
+######################################################################## +# Prompt Builder +# Adapted from: +# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py +######################################################################## + +_MUSIC_TAGS = ("Genre: ", "Mood: ", "Instrument: ", "Theme: ", "Duration: ") + + +def _looks_like_music_prompt(text: str) -> bool: + return all(tag in text for tag in _MUSIC_TAGS) + + +def build_tts_input( + *, + tokenizer: PreTrainedTokenizerBase, + embed_tokens: torch.nn.Module, + device: torch.device, + dtype: torch.dtype, + text: str, + prompt: str, + spk_emb: list[torch.Tensor] | None = None, + instruction: str | None = None, + prompt_text: str | None = None, + prompt_wav_emb: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build (inputs_embeds, input_ids) for one TTS segment. + + Args: + tokenizer: HF tokenizer + embed_tokens: The LLM's input-embedding module + device: Device to place the returned tensors on. + dtype: dtype for the returned `inputs_embeds`. + text: Text to synthesize. + prompt: System-level instruction prompt prepended to the user turn. + spk_emb: Optional list of speaker embeddings already projected into + LLM hidden dim; each is injected at a `<|vision_start|>` slot. + instruction: Optional free-form instruction + prompt_text: Reference text for zero-shot voice cloning. + prompt_wav_emb: Reference-wav embeddings to inject. + """ + spk_emb_prompt: list[int] = [] + if spk_emb is not None: + for i in range(len(spk_emb)): + spk_emb_prompt.extend( + tokenizer.encode(f" speaker_{i + 1}:") + + tokenizer.encode("<|vision_start|>") + + tokenizer.encode("<|vision_pad|>") + + tokenizer.encode("<|vision_end|>\n") + ) + + instruction_prompt: list[int] = [] + if instruction is not None: + instruction_prompt = tokenizer.encode(instruction) + tokenizer.encode("<|im_end|>") + + prompt_text_token: list[int] = [] + prompt_latent_token: list[int] = [] + if prompt_wav_emb is not None and prompt_text is not None: + prompt_text_token = tokenizer.encode(prompt_text) + prompt_latent_token = tokenizer.encode("") * prompt_wav_emb.size(1) + + prompt2 = [] if _looks_like_music_prompt(text) else tokenizer.encode(" Text input:\n") + + input_part = ( + tokenizer.encode("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n") + + tokenizer.encode("<|im_start|>user\n") + + tokenizer.encode(prompt) + + spk_emb_prompt + + prompt2 + + prompt_text_token + + tokenizer.encode(text) + + tokenizer.encode("<|im_end|>\n") + + tokenizer.encode("<|im_start|>assistant\n") + + instruction_prompt + + tokenizer.encode("