From 487345c6bc8b6517ee47bcd8faafd31ef8626e6c Mon Sep 17 00:00:00 2001 From: akshatvishu Date: Sat, 18 Apr 2026 22:40:42 +0530 Subject: [PATCH 01/54] feat(ming-tts): add dense omni pipeline Signed-off-by: akshatvishu --- docs/models/supported_models.md | 1 + .../examples/offline_inference/ming_tts.md | 131 +++ .../examples/online_serving/ming_tts.md | 163 ++++ examples/offline_inference/ming_tts/README.md | 224 +++++ .../offline_inference/ming_tts/end2end.py | 654 +++++++++++++ examples/online_serving/ming_tts/README.md | 312 +++++++ .../ming_tts/openai_speech_client.py | 223 +++++ examples/online_serving/ming_tts/run_curl.sh | 217 +++++ .../online_serving/ming_tts/run_server.sh | 22 + .../test_chunk_transfer_adapter.py | 25 + tests/e2e/offline_inference/test_ming_tts.py | 231 +++++ tests/e2e/online_serving/test_ming_tts.py | 95 ++ tests/engine/test_async_omni_engine_input.py | 54 ++ .../openai_api/test_serving_speech.py | 378 ++++++++ .../ming_tts/test_ming_tts_components.py | 505 ++++++++++ .../ming_tts/test_ming_tts_config_shim.py | 51 ++ .../models/ming_tts/test_ming_tts_loaders.py | 524 +++++++++++ .../ming_tts/test_ming_tts_prompt_builder.py | 375 ++++++++ .../test_ming_tts_async_chunk.py | 421 +++++++++ tests/worker/test_ming_tts_runner.py | 674 ++++++++++++++ tests/worker/test_omni_gpu_model_runner.py | 22 + vllm_omni/engine/arg_utils.py | 4 + vllm_omni/engine/async_omni_engine.py | 39 +- vllm_omni/engine/stage_init_utils.py | 8 + .../entrypoints/openai/protocol/audio.py | 66 +- .../entrypoints/openai/serving_speech.py | 287 +++++- vllm_omni/inputs/preprocess.py | 26 + .../models/ming_tts/__init__.py | 13 + .../ming_tts/audio_tokenizer/__init__.py | 2 + .../ming_tts/audio_tokenizer/audio_encoder.py | 135 +++ .../configuration_audio_vae.py | 40 + .../models/ming_tts/audio_tokenizer/istft.py | 188 ++++ .../audio_tokenizer/modeling_audio_vae.py | 178 ++++ .../ming_tts/audio_tokenizer/vae_modules.py | 208 +++++ .../models/ming_tts/config_ming_tts.py | 364 
++++++++ .../ming_tts/configuration_ming_dense.py | 57 ++ .../models/ming_tts/fm/__init__.py | 2 + .../model_executor/models/ming_tts/fm/cfm.py | 207 +++++ .../model_executor/models/ming_tts/fm/dit.py | 216 +++++ .../models/ming_tts/fm/flowloss.py | 54 ++ .../models/ming_tts/fm/modules.py | 147 +++ .../model_executor/models/ming_tts/ingress.py | 177 ++++ .../models/ming_tts/ming_tts.py | 581 ++++++++++++ .../models/ming_tts/ming_tts_audio_vae.py | 318 +++++++ .../models/ming_tts/ming_tts_llm.py | 864 ++++++++++++++++++ .../models/ming_tts/prompt_builder.py | 429 +++++++++ .../models/ming_tts/speaker_extractor.py | 66 ++ vllm_omni/model_executor/models/registry.py | 16 + .../stage_configs/ming_tts.yaml | 65 ++ .../stage_configs/ming_tts_async_chunk.yaml | 86 ++ .../stage_input_processors/ming_tts.py | 278 ++++++ vllm_omni/worker/gpu_model_runner.py | 2 + 52 files changed, 10402 insertions(+), 23 deletions(-) create mode 100644 docs/user_guide/examples/offline_inference/ming_tts.md create mode 100644 docs/user_guide/examples/online_serving/ming_tts.md create mode 100644 examples/offline_inference/ming_tts/README.md create mode 100644 examples/offline_inference/ming_tts/end2end.py create mode 100644 examples/online_serving/ming_tts/README.md create mode 100644 examples/online_serving/ming_tts/openai_speech_client.py create mode 100755 examples/online_serving/ming_tts/run_curl.sh create mode 100755 examples/online_serving/ming_tts/run_server.sh create mode 100644 tests/e2e/offline_inference/test_ming_tts.py create mode 100644 tests/e2e/online_serving/test_ming_tts.py create mode 100644 tests/model_executor/models/ming_tts/test_ming_tts_components.py create mode 100644 tests/model_executor/models/ming_tts/test_ming_tts_config_shim.py create mode 100644 tests/model_executor/models/ming_tts/test_ming_tts_loaders.py create mode 100644 tests/model_executor/models/ming_tts/test_ming_tts_prompt_builder.py create mode 100644 
tests/model_executor/stage_input_processors/test_ming_tts_async_chunk.py create mode 100644 tests/worker/test_ming_tts_runner.py create mode 100644 vllm_omni/model_executor/models/ming_tts/__init__.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/__init__.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/audio_encoder.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/configuration_audio_vae.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/istft.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/modeling_audio_vae.py create mode 100644 vllm_omni/model_executor/models/ming_tts/audio_tokenizer/vae_modules.py create mode 100644 vllm_omni/model_executor/models/ming_tts/config_ming_tts.py create mode 100644 vllm_omni/model_executor/models/ming_tts/configuration_ming_dense.py create mode 100644 vllm_omni/model_executor/models/ming_tts/fm/__init__.py create mode 100644 vllm_omni/model_executor/models/ming_tts/fm/cfm.py create mode 100644 vllm_omni/model_executor/models/ming_tts/fm/dit.py create mode 100644 vllm_omni/model_executor/models/ming_tts/fm/flowloss.py create mode 100644 vllm_omni/model_executor/models/ming_tts/fm/modules.py create mode 100644 vllm_omni/model_executor/models/ming_tts/ingress.py create mode 100644 vllm_omni/model_executor/models/ming_tts/ming_tts.py create mode 100644 vllm_omni/model_executor/models/ming_tts/ming_tts_audio_vae.py create mode 100644 vllm_omni/model_executor/models/ming_tts/ming_tts_llm.py create mode 100644 vllm_omni/model_executor/models/ming_tts/prompt_builder.py create mode 100644 vllm_omni/model_executor/models/ming_tts/speaker_extractor.py create mode 100644 vllm_omni/model_executor/stage_configs/ming_tts.yaml create mode 100644 vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml create mode 100644 vllm_omni/model_executor/stage_input_processors/ming_tts.py diff --git 
a/docs/models/supported_models.md b/docs/models/supported_models.md index 8ece14f9c00..b298415e177 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -55,6 +55,7 @@ th { | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-VoiceDesign | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-Base | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | +| `MingTTSForConditionalGeneration` | Ming-omni-tts-0.5B | `inclusionAI/Ming-omni-tts-0.5B` | ✅︎ | | | | | `NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` | ✅︎ | ✅︎ | | ✅︎ | | `MiMoAudioForConditionalGeneration` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` | ✅︎ | ✅︎ | | | | `Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | ✅︎ | ✅︎ | | | diff --git a/docs/user_guide/examples/offline_inference/ming_tts.md b/docs/user_guide/examples/offline_inference/ming_tts.md new file mode 100644 index 00000000000..7a8cd65ed32 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/ming_tts.md @@ -0,0 +1,131 @@ +# Ming-omni-tts + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/ming_tts>. + +This directory contains an offline Ming example that uses the in-repo Ming prompt builder directly. It now covers the broader upstream dense TTS cookbook surface: style, IP, music-only generation, emotion, dialect, zero-shot clone, podcast, speech+bgm, and speech+sound. 
+ +## Quick Start + +Run a zero-speaker style case: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case style \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run emotion-controlled speech: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case emotion \ + --ref-audio /path/to/emotion_prompt.wav \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run zero-shot cloning with a transcript: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case zero_shot \ + --ref-audio /path/to/reference.wav \ + --ref-text "在此奉劝大家别乱打美白针。" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run podcast generation: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case podcast \ + --ref-audio-paths /path/to/CTS-CN-F2F-2019-11-11-423-012-A.wav /path/to/CTS-CN-F2F-2019-11-11-423-012-B.wav \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run with stats and a manifest: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case style \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager \ + --enable-stats \ + --stats-log-file output_audio/ming_style_pipeline.log \ + --metadata-json output_audio/ming_style_manifest.json +``` + +## Built-in Cases + +- `style`: zero-speaker style-conditioned speech +- `ip`: zero-speaker IP voice generation +- `bgm`: music generation +- `emotion`: reference-audio speech with emotion control +- `basic`: reference-audio cloning with speed / pitch / volume control +- `dialect`: reference-audio cloning with dialect control +- `zero_shot`: reference-audio cloning with explicit transcript +- `podcast`: multi-reference dialogue generation with automatic speaker embedding extraction +- `speech_bgm`: 
speech with background music conditioning +- `speech_sound`: speech with environment sound conditioning + +`TTA` from the upstream Ming notebook is not included here because it uses `inclusionAI/Ming-omni-tta-0.5B`, not the dense TTS model covered by this example. + +## Streaming + +Use async_chunk streaming with `AsyncOmni`: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case basic \ + --ref-audio /path/to/10002287-00000095.wav \ + --streaming \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \ + --enforce-eager +``` + +`--streaming` currently supports one prompt per process invocation. Use +blocking mode for `--num-prompts > 1`. + +## Validation matrix + +The example is intended to cover the dense TTS workflows used by the Ming +validation helper: + +| Case | Blocking | Async chunk | Extra inputs | +|---|---:|---:|---| +| `style` | Yes | Optional smoke test | none | +| `ip` | Yes | Optional smoke test | none | +| `bgm` | Yes | Optional smoke test | none | +| `emotion` | Yes | Yes | reference WAV | +| `basic` | Yes | Yes | reference WAV | +| `dialect` | Yes | Yes | reference WAV | +| `zero_shot` | Yes | Yes | reference WAV and transcript | +| `podcast` | Yes | Yes | two reference WAVs | +| `speech_bgm` | Yes | Yes | reference WAV | +| `speech_sound` | Yes | Yes | reference WAV | + +The offline example also exposes vLLM-Omni runtime/reporting controls such as: + +- `--num-prompts` +- `--enable-stats` +- `--stats-log-file` +- `--metadata-json` +- `--stage-init-timeout` +- `--init-timeout` +- `--batch-timeout` +- `--worker-backend` +- `--ray-address` + +## Example materials + +??? abstract "README.md" + ``````md + --8<-- "examples/offline_inference/ming_tts/README.md" + `````` +??? 
abstract "end2end.py" + ``````py + --8<-- "examples/offline_inference/ming_tts/end2end.py" + `````` diff --git a/docs/user_guide/examples/online_serving/ming_tts.md b/docs/user_guide/examples/online_serving/ming_tts.md new file mode 100644 index 00000000000..e5bc5144dda --- /dev/null +++ b/docs/user_guide/examples/online_serving/ming_tts.md @@ -0,0 +1,163 @@ +# Ming-omni-tts + +Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/ming_tts>. + +This example shows how to serve Ming through the OpenAI-compatible `/v1/audio/speech` endpoint. The server builds Ming prompts directly with the in-repo prompt builder, so online requests support Ming-specific structured controls instead of the Qwen placeholder path. + +## Installation + +Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md) + +## Launch the Server + +```bash +vllm-omni serve inclusionAI/Ming-omni-tts-0.5B \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \ + --omni \ + --port 8091 \ + --enforce-eager +``` + +Or: + +```bash +cd examples/online_serving/ming_tts +./run_server.sh +``` + +The canonical Ming online client is `openai_speech_client.py`. It targets the +local vLLM-Omni server, not OpenAI's cloud API, so `api_key=EMPTY` is enough +for local testing. + +## Example Requests + +Basic TTS: + +```bash +python openai_speech_client.py \ + --text "你好,这是 Ming 在线语音合成测试。" +``` + +Style-conditioned speech: + +```bash +python openai_speech_client.py \ + --text "我会一直在这里陪着你。" \ + --instructions "轻柔的ASMR耳语,慢速,贴近麦克风" +``` + +Structured Ming control: + +```bash +python openai_speech_client.py \ + --text "我觉得社会企业同个人都有责任" \ + --instruction-json '{"方言":"广粤话"}' +``` + +IP voice generation: + +```bash +python openai_speech_client.py \ + --text "这款产品的名字,叫变态坑爹牛肉丸。" \ + --voice 灵小甄 +``` + +Reference-audio cloning: + +Use `ref_audio` by itself for Ming prompt-waveform conditioning. Add +`ref_text` when the request is transcript cloning, such as zero-shot or +podcast-style prompts. 
+ +```bash +python openai_speech_client.py \ + --task-type Base \ + --text "我们的愿景是构建未来服务业的数字化基础设施。" \ + --ref-audio /path/to/reference.wav \ + --ref-text "在此奉劝大家别乱打美白针。" +``` + +Speaker-embedding cloning: + +```bash +python openai_speech_client.py \ + --task-type Base \ + --text "你好,这是一段使用说话人向量的合成语音。" \ + --speaker-embedding /path/to/ming_speaker_embedding.json +``` + +Streaming PCM: + +```bash +python openai_speech_client.py \ + --text "你好,这是流式输出测试。" \ + --instructions "平静,普通话" \ + --stream \ + --output ming_output.pcm +``` + +## Curl Helper + +Use the bundled helper for common request types: + +```bash +./run_curl.sh basic +./run_curl.sh style +./run_curl.sh ip +REF_AUDIO=/path/to/emotion_prompt.wav ./run_curl.sh emotion +REF_AUDIO=/path/to/yue_prompt.wav ./run_curl.sh dialect +REF_AUDIO=/path/to/reference.wav REF_TEXT="在此奉劝大家别乱打美白针。" ./run_curl.sh zero_shot +REF_AUDIO=/path/to/speaker_1.wav REF_AUDIO_2=/path/to/speaker_2.wav REF_TEXT="speaker_1:你好。 speaker_2:你好。" ./run_curl.sh podcast +REF_AUDIO=/path/to/00000309-00000300.wav ./run_curl.sh speech_bgm +REF_AUDIO=/path/to/00000309-00000300.wav ./run_curl.sh speech_sound +REF_AUDIO=/path/to/reference.wav REF_TEXT="在此奉劝大家别乱打美白针。" ./run_curl.sh clone_ref_audio +SPEAKER_EMBEDDING=/path/to/ming_speaker_embedding.json ./run_curl.sh clone_embedding +./run_curl.sh stream +``` + +## Audio Inputs + +- `ref_audio` accepts a local path, remote URL, or `data:` URL +- The Python client converts local files into a base64 `data:` URL +- `speaker_embedding` must be a JSON file with exactly 192 numeric values +- Ming prompt-waveform cases can use `ref_audio` without `ref_text` +- Zero-shot and podcast-style transcript cloning should include `ref_text` + +The bundled `run_curl.sh basic` mode is plain/default TTS and does not require +`REF_AUDIO`. The upstream cookbook-style `basic` case uses `ref_audio` plus +structured speed / pitch / volume instructions. 
+ +## Field Mapping + +For Ming, the generic OpenAI request fields map to Ming controls like this: + +- `input` -> target text +- `instructions` -> Ming instruction string, or a JSON string for the structured Ming control object +- `voice` -> Ming `IP` +- `language` -> Ming `方言` +- `ref_audio` -> Ming prompt waveform +- `ref_text` -> optional transcript for zero-shot and podcast-style cloning +- `speaker_embedding` -> 192-d Ming speaker embedding + +## Voice Listing + +- `/v1/audio/voices` lists uploaded voices for Ming. +- Built-in Ming IP labels can still be used as `voice`, but they are not enumerated by the API. + +## Example materials + +??? abstract "README.md" + ``````md + --8<-- "examples/online_serving/ming_tts/README.md" + `````` +??? abstract "run_server.sh" + ``````sh + --8<-- "examples/online_serving/ming_tts/run_server.sh" + `````` +??? abstract "openai_speech_client.py" + ``````py + --8<-- "examples/online_serving/ming_tts/openai_speech_client.py" + `````` +??? abstract "run_curl.sh" + ``````sh + --8<-- "examples/online_serving/ming_tts/run_curl.sh" + `````` diff --git a/examples/offline_inference/ming_tts/README.md b/examples/offline_inference/ming_tts/README.md new file mode 100644 index 00000000000..c67772c43c1 --- /dev/null +++ b/examples/offline_inference/ming_tts/README.md @@ -0,0 +1,224 @@ +# Ming-omni-tts Offline Inference + +`end2end.py` runs Ming dense 0.5B end to end with vLLM-Omni. It uses the in-repo Ming prompt builder directly, so the example request shape matches the real integration instead of a simplified wrapper. 
+ +## Model Overview + +Ming dense 0.5B is exposed here as a two-stage offline pipeline: + +- **Stage 0**: Qwen2-based AR generation with Ming prompt formatting and inline flow controls +- **Stage 1**: audio VAE decode to mono 44.1 kHz waveform + +The example supports both: + +- **Sequential eager** via `ming_tts.yaml` +- **Async chunk eager** via `ming_tts_async_chunk.yaml` + +## Setup + +Install vLLM-Omni with the platform requirements for your accelerator: + +```bash +uv pip install -e . +``` + +The Ming offline example does not require a separate upstream Ming package. +Reference-audio cases use the repo dependencies for audio loading, +resampling, and CampPlus speaker extraction, including `soundfile`, +`torchaudio`, and `onnxruntime`. + +## Supported Cases + +These cases cover the upstream dense TTS cookbook surface that maps cleanly onto the current vLLM-Omni example: + +- `style`: zero-speaker style-conditioned speech +- `ip`: zero-speaker IP voice generation +- `bgm`: music-only generation +- `emotion`: reference-audio speech with emotion control +- `basic`: reference-audio speech with speed / pitch / volume control +- `dialect`: reference-audio speech with dialect control +- `zero_shot`: reference-audio cloning with explicit transcript +- `podcast`: multi-reference dialogue generation with automatic speaker embedding extraction +- `speech_bgm`: speech with background music conditioning +- `speech_sound`: speech with environmental sound conditioning + +Not included: + +- `TTA` from the upstream cookbook. That notebook switches to `inclusionAI/Ming-omni-tta-0.5B`, which is a different model family and is out of scope for this dense TTS example. 
+ +## Quick Start + +Run the zero-speaker style example: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case style \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run zero-shot cloning with a transcript: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case zero_shot \ + --ref-audio /path/to/10002287-00000094.wav \ + --ref-text "在此奉劝大家别乱打美白针。" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run emotion-controlled speech: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case emotion \ + --ref-audio /path/to/emotion_prompt.wav \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +Run podcast generation with two reference clips: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case podcast \ + --ref-audio-paths /path/to/CTS-CN-F2F-2019-11-11-423-012-A.wav /path/to/CTS-CN-F2F-2019-11-11-423-012-B.wav \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager +``` + +The script automatically extracts one 192-d speaker embedding per reference WAV using the Ming model's `campplus.onnx`. + +If you already have precomputed multi-speaker embeddings, you can override extraction with: + +```bash +--speaker-embedding /path/to/podcast_speaker_embeddings.json +``` + +where the JSON is a list of speaker embeddings, one 192-d vector per speaker. + +Use async_chunk streaming: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case basic \ + --ref-audio /path/to/10002287-00000095.wav \ + --streaming \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \ + --enforce-eager +``` + +`--streaming` uses `AsyncOmni` and the async_chunk stage config. 
It currently +supports one prompt per process invocation; use blocking mode for +`--num-prompts > 1`. + +Collect runtime stats and a manifest: + +```bash +python examples/offline_inference/ming_tts/end2end.py \ + --case style \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts.yaml \ + --enforce-eager \ + --enable-stats \ + --stats-log-file output_audio/ming_style_pipeline.log \ + --metadata-json output_audio/ming_style_manifest.json +``` + +## Reference Fixtures + +The upstream Ming cookbook uses these public audio fixtures from `inclusionAI/Ming-omni-tts/data/wavs`: + +- `10002287-00000094.wav` for zero-shot cloning +- `10002287-00000095.wav` for `basic` +- `emotion_prompt.wav` for `emotion` +- `yue_prompt.wav` for `dialect` +- `00000309-00000300.wav` for `speech_bgm` and `speech_sound` +- `CTS-CN-F2F-2019-11-11-423-012-A.wav` and `CTS-CN-F2F-2019-11-11-423-012-B.wav` for `podcast` + +## Validation Matrix + +The repo-facing example is intended to cover the same dense TTS workflows used +by the local Ming validation script: + +| Case | Blocking `ming_tts.yaml` | Async chunk `ming_tts_async_chunk.yaml` | Extra inputs | +|---|---:|---:|---| +| `style` | Yes | Optional smoke test | none | +| `ip` | Yes | Optional smoke test | none | +| `bgm` | Yes | Optional smoke test | none | +| `emotion` | Yes | Yes | `--ref-audio emotion_prompt.wav` | +| `basic` | Yes | Yes | `--ref-audio 10002287-00000095.wav` | +| `dialect` | Yes | Yes | `--ref-audio yue_prompt.wav` | +| `zero_shot` | Yes | Yes | `--ref-audio 10002287-00000094.wav --ref-text ...` | +| `podcast` | Yes | Yes | two `--ref-audio-paths` | +| `speech_bgm` | Yes | Yes | `--ref-audio 00000309-00000300.wav` | +| `speech_sound` | Yes | Yes | `--ref-audio 00000309-00000300.wav` | + +## Validated Outputs + +Validation on an L4 GPU completed the full blocking matrix and the default +async_chunk matrix. 
Default async_chunk matched blocking output frame counts +and Stage-1 patch counts for every case: + +| Case | Blocking frames / patches / sec | Async chunk frames / patches / sec | +|---|---:|---:| +| `style` | 409248 / 29 / 9.28 | 409248 / 29 / 9.28 | +| `ip` | 183456 / 13 / 4.16 | 183456 / 13 / 4.16 | +| `bgm` | 1326528 / 94 / 30.08 | 1326528 / 94 / 30.08 | +| `emotion` | 324576 / 23 / 7.36 | 324576 / 23 / 7.36 | +| `basic` | 211680 / 15 / 4.80 | 211680 / 15 / 4.80 | +| `dialect` | 239904 / 17 / 5.44 | 239904 / 17 / 5.44 | +| `zero_shot` | 409248 / 29 / 9.28 | 409248 / 29 / 9.28 | +| `podcast` | 437472 / 31 / 9.92 | 437472 / 31 / 9.92 | +| `speech_bgm` | 296352 / 21 / 6.72 | 296352 / 21 / 6.72 | +| `speech_sound` | 352800 / 25 / 8.00 | 352800 / 25 / 8.00 | + +## Key Arguments + +| Argument | Description | +|---|---| +| `--model` | Hugging Face repo or local Ming checkpoint path | +| `--stage-configs-path` | Stage config YAML. Use `ming_tts.yaml` for blocking generation or `ming_tts_async_chunk.yaml` for streaming | +| `--case` | Built-in demo case | +| `--ref-audio` | Single reference wav path for cloning-style cases | +| `--ref-audio-paths` | Multiple reference wav paths, used by `podcast` | +| `--ref-text` | Reference transcript. Required for `zero_shot` | +| `--instructions` | Free-form Ming instruction string | +| `--instruction-json` | Structured Ming instruction JSON | +| `--speaker-embedding` | JSON file containing a 192-d speaker embedding | +| `--extract-speaker-embeddings` | Force CampPlus speaker extraction from the provided reference audio paths | +| `--max-decode-steps` | Override `ming_max_decode_steps` | +| `--num-prompts` | Repeat the same case N times. 
Outputs are indexed when `N > 1` | +| `--streaming` | Use `AsyncOmni` and async_chunk transport | +| `--enforce-eager` | Recommended for Ming dense; non-eager is out of scope | +| `--enable-stats` / `--log-stats` | Enable vLLM-Omni per-request stats logging | +| `--stats-log-file` | Optional path for the stats log | +| `--metadata-json` | Optional path for the run manifest JSON | +| `--stage-init-timeout` | Per-stage initialization timeout in seconds | +| `--init-timeout` | Total initialization timeout in seconds | +| `--batch-timeout` | Batch timeout in seconds | +| `--worker-backend` | `multi_process` or `ray` | +| `--ray-address` | Ray cluster address when using `--worker-backend ray` | + +## Output + +- The script writes one mono 44.1 kHz WAV file per run +- Default output directory: `output_audio/` +- Default filename: `ming_<case>.wav` +- When `--num-prompts > 1`, outputs are indexed as `ming_<case>_00000.wav`, `..._00001.wav`, etc. +- When stats are enabled, the script can also write: + - a stats log file such as `ming_style_pipeline.log` + - a manifest JSON with per-output metadata, stage durations, peak memory info, + and streaming client latency metrics when `--streaming` is used + +## Notes + +- `style` and `ip` are zero-speaker paths and do not require a reference clip +- `emotion`, `basic`, `dialect`, `speech_bgm`, and `speech_sound` require one reference clip +- `zero_shot` requires both `--ref-audio` and `--ref-text` +- `podcast` requires at least two reference clips via `--ref-audio-paths` +- `podcast` automatically extracts one speaker embedding per reference clip +- `--speaker-embedding` may contain either one 192-d vector or a list of 192-d vectors +- `--enforce-eager` was used for the validated runs +- Validation on the L4 GPU used SDPA for the Ming audio VAE instead of + FlashAttention2, which is the preferred default when available. 
diff --git a/examples/offline_inference/ming_tts/end2end.py b/examples/offline_inference/ming_tts/end2end.py new file mode 100644 index 00000000000..9e9742f4e7e --- /dev/null +++ b/examples/offline_inference/ming_tts/end2end.py @@ -0,0 +1,654 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Offline inference demo for Ming-omni-tts via vLLM Omni.""" + +import asyncio +import json +import os +import time +import uuid +import wave +from pathlib import Path + +import soundfile as sf +import torch +import torchaudio +from transformers import AutoTokenizer +from vllm import SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import AsyncOmni, Omni +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import ( + KEY_CFG, + KEY_MAX_DECODE_STEPS, + KEY_SIGMA, + KEY_SPEAKER_EMBEDDING, + KEY_TEMPERATURE, + SAMPLE_RATE, + TEXT_EOS_TOKEN_ID, +) +from vllm_omni.model_executor.models.ming_tts.prompt_builder import build_ming_dense_prompt +from vllm_omni.model_executor.models.ming_tts.speaker_extractor import MingSpeakerEmbeddingExtractor + +DEFAULT_MODEL = "inclusionAI/Ming-omni-tts-0.5B" +DEFAULT_STAGE_CONFIG = "vllm_omni/model_executor/stage_configs/ming_tts.yaml" +DEFAULT_STREAM_STAGE_CONFIG = "vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml" +DEFAULT_OUTPUT_DIR = "output_audio" +DEFAULT_SPEECH_PROMPT = "Please generate speech based on the following description.\n" +DEFAULT_MUSIC_PROMPT = "Please generate music based on the following description.\n" +DEFAULT_PODCAST_TEXT = ( + " speaker_1:你可以说一下,就大概说一下,可能虽然我也不知道,我看过那部电影没有。\n" + " speaker_2:就是那个叫什么,变相一节课的嘛。\n" + " speaker_1:嗯。\n" + " speaker_2:一部搞笑的电影。\n" + " speaker_1:一部搞笑的。\n" +) +DEFAULT_PODCAST_PROMPT_TEXT = ( + " speaker_1:并且我们还要进行每个月还要考核 笔试的话还要进行笔试,做个,当服务员还要去笔试了\n" + " speaker_2:对啊,这真的很奇怪,就是 单纯的因,单纯自己工资不高,只是因为可能人家那个店比较出名一点,就对你苛刻要求\n" +) + +CASE_DEFAULTS = { + "style": { + "prompt": 
DEFAULT_SPEECH_PROMPT, + "text": "我会一直在这里陪着你,直到你慢慢、慢慢地沉入那个最温柔的梦里……好吗?", + "instruction": { + "风格": ( + "这是一种ASMR耳语,属于一种旨在引发特殊感官体验的创意风格。" + "这个女性使用轻柔的普通话进行耳语,声音气音成分重。" + "音量极低,紧贴麦克风,语速极慢,旨在制造触发听者颅内快感的声学刺激。" + ) + }, + "use_zero_spk_emb": True, + "max_decode_steps": 200, + }, + "ip": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "这款产品的名字,叫变态坑爹牛肉丸。", + "instruction": {"IP": "灵小甄"}, + "use_zero_spk_emb": True, + "max_decode_steps": 200, + }, + "bgm": { + "prompt": DEFAULT_MUSIC_PROMPT, + "text": "Genre: 电子舞曲. Mood: 自信 / 坚定. Instrument: 架子鼓. Theme: 节日. Duration: 30s.", + "instruction": None, + "use_zero_spk_emb": False, + "max_decode_steps": 400, + }, + "tta": { + "prompt": "Please generate audio events based on given text.\n", + "text": "Thunder and a gentle rain", + "instruction": None, + "use_zero_spk_emb": False, + "max_decode_steps": 200, + "cfg": 4.5, + "sigma": 0.3, + "temperature": 2.5, + }, + "emotion": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "我竟然抢到了陈奕迅的演唱会门票!太棒了!终于可以现场听一听他的歌声了!", + "instruction": {"情感": "高兴"}, + "requires_ref_audio": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "basic": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "简单地说,这相当于惠普把消费领域市场拱手相让了。", + "instruction": {"语速": "快速", "基频": "中", "音量": "高"}, + "requires_ref_audio": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "dialect": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "我觉得社会企业同个人都有责任", + "instruction": {"方言": "广粤话"}, + "requires_ref_audio": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "zero_shot": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "我们的愿景是构建未来服务业的数字化基础设施,为世界带来更多微小而美好的改变。", + "instruction": None, + "requires_ref_audio": True, + "requires_ref_text": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "podcast": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": DEFAULT_PODCAST_TEXT, + "instruction": None, + "prompt_text": 
DEFAULT_PODCAST_PROMPT_TEXT, + "requires_ref_audio_count": 2, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "speech_bgm": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "此次业绩下滑原因,可归结为企业停止服务某些品牌,而带来的负面影响。", + "instruction": { + "BGM": { + "Genre": "当代古典音乐.", + "Mood": "温暖 / 友善.", + "Instrument": "电吉他", + "Theme": "节日.", + "SNR": 10.0, + "ENV": None, + } + }, + "requires_ref_audio": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, + "speech_sound": { + "prompt": DEFAULT_SPEECH_PROMPT, + "text": "此次业绩下滑原因,可归结为企业停止服务某些品牌,而带来的负面影响。", + "instruction": { + "BGM": { + "ENV": "Birds chirping", + "SNR": 10.0, + "Genre": None, + "Mood": None, + "Instrument": None, + "Theme": None, + } + }, + "requires_ref_audio": True, + "auto_extract_speaker_embeddings": True, + "max_decode_steps": 200, + }, +} + + +def _load_reference_waveform(path: str) -> torch.Tensor: + samples, sample_rate = sf.read(path, dtype="float32") + waveform = torch.as_tensor(samples, dtype=torch.float32) + if waveform.ndim == 2: + waveform = waveform.mean(dim=1) + waveform = waveform.reshape(1, -1) + if int(sample_rate) != SAMPLE_RATE: + waveform = torchaudio.functional.resample(waveform, int(sample_rate), SAMPLE_RATE) + return waveform + + +def _load_speaker_embedding(path: str) -> torch.Tensor: + data = json.loads(Path(path).read_text(encoding="utf-8")) + return torch.as_tensor(data, dtype=torch.float32) + + +def _resolve_reference_inputs(args, case): + if args.ref_audio is not None and args.ref_audio_paths is not None: + raise RuntimeError("Use either --ref-audio or --ref-audio-paths, not both") + + if args.ref_audio_paths is not None: + ref_audio_paths = list(args.ref_audio_paths) + elif args.ref_audio is not None: + ref_audio_paths = [args.ref_audio] + else: + ref_audio_paths = [] + + required_count = int(case.get("requires_ref_audio_count", 0)) + if required_count > 0: + if len(ref_audio_paths) < required_count: + raise RuntimeError( + f"Case 
'{args.case}' requires at least {required_count} reference audio paths via --ref-audio-paths" + ) + elif case.get("requires_ref_audio") and not ref_audio_paths: + raise RuntimeError(f"--ref-audio is required for case '{args.case}'") + + if not ref_audio_paths: + return None + if len(ref_audio_paths) == 1: + return _load_reference_waveform(ref_audio_paths[0]) + return [_load_reference_waveform(path) for path in ref_audio_paths] + + +def _resolve_reference_audio_paths(args): + if args.ref_audio is not None and args.ref_audio_paths is not None: + raise RuntimeError("Use either --ref-audio or --ref-audio-paths, not both") + if args.ref_audio_paths is not None: + return list(args.ref_audio_paths) + if args.ref_audio is not None: + return [args.ref_audio] + return [] + + +def _resolve_speaker_embedding(args, case, ref_audio_paths): + if args.speaker_embedding: + return _load_speaker_embedding(args.speaker_embedding) + + should_extract = bool(case.get("auto_extract_speaker_embeddings", False) or args.extract_speaker_embeddings) + if not should_extract or not ref_audio_paths: + return None + + extractor = MingSpeakerEmbeddingExtractor(args.model) + embeddings = extractor.extract_many(ref_audio_paths) + if not embeddings: + raise RuntimeError("Speaker extraction produced no embeddings") + if len(embeddings) == 1: + return embeddings[0] + return torch.stack(embeddings, dim=0) + + +def _coerce_audio_tensor(audio, *, async_chunk: bool) -> torch.Tensor: + if isinstance(audio, list): + if async_chunk: + parts = [] + for item in audio: + tensor = torch.as_tensor(item, dtype=torch.float32).reshape(-1) + if tensor.numel() > 0: + parts.append(tensor) + if not parts: + return torch.zeros((0,), dtype=torch.float32) + return torch.cat(parts, dim=0) + + for item in reversed(audio): + tensor = torch.as_tensor(item, dtype=torch.float32).reshape(-1) + if tensor.numel() > 0: + return tensor + return torch.zeros((0,), dtype=torch.float32) + + return torch.as_tensor(audio, 
dtype=torch.float32).reshape(-1) + + +def _resolve_sr(sr) -> int: + if isinstance(sr, list): + sr = sr[-1] + if hasattr(sr, "item"): + return int(sr.item()) + return int(sr) + + +def _extract_sample_rate(multimodal_output: dict) -> int: + sr = multimodal_output.get("sr") + if sr is None: + raise RuntimeError("Expected multimodal_output['sr']") + return _resolve_sr(sr) + + +def _write_wav(path: str, audio: torch.Tensor, sample_rate: int) -> None: + audio = audio.clamp(-1.0, 1.0) + pcm16 = (audio * 32767.0).round().to(torch.int16).cpu().numpy() + with wave.open(path, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(int(sample_rate)) + wav_file.writeframes(pcm16.tobytes()) + + +def _request_index(request_id: str | None, fallback: int) -> int: + try: + return int(request_id) + except (TypeError, ValueError): + if isinstance(request_id, str): + head = request_id.split("_", 1)[0] + if head.isdigit(): + return int(head) + return fallback + + +def _audio_summary(audio: torch.Tensor, sample_rate: int) -> dict: + waveform = audio.detach().cpu().reshape(-1).to(torch.float32) + return { + "sample_rate": int(sample_rate), + "num_samples": int(waveform.numel()), + "duration_seconds": float(waveform.numel()) / float(sample_rate), + "max_abs_amplitude": float(waveform.abs().max().item()) if waveform.numel() > 0 else 0.0, + } + + +def _resolve_output_name(output_name: str | None, case: str, index: int, total: int) -> str: + if total == 1: + return output_name or f"ming_{case}.wav" + base = Path(output_name or f"ming_{case}.wav") + return f"{base.stem}_{index:05d}{base.suffix or '.wav'}" + + +def _resolve_stats_log_file(args) -> str | None: + if not args.log_stats: + return None + if args.stats_log_file: + return args.stats_log_file + base = Path(args.output_name or f"ming_{args.case}.wav").stem + return str(Path(args.output_dir) / f"{base}_pipeline.log") + + +def _resolve_metadata_json(args) -> str | None: + if args.metadata_json: + 
return args.metadata_json + if args.log_stats: + base = Path(args.output_name or f"ming_{args.case}.wav").stem + return str(Path(args.output_dir) / f"{base}_manifest.json") + return None + + +def _build_manifest(args, prompt_payload, stats_log_file: str | None, outputs: list[dict]) -> dict: + additional_information = {} + if isinstance(prompt_payload, dict): + additional_information = dict(prompt_payload.get("additional_information", {})) + return { + "model": args.model, + "case": args.case, + "streaming": bool(args.streaming), + "stage_configs_path": args.stage_configs_path, + "enforce_eager": bool(args.enforce_eager), + "num_prompts": int(args.num_prompts), + "log_stats": bool(args.log_stats), + "stats_log_file": stats_log_file, + "prompt_text": additional_information.get("prompt_text"), + "instruction": additional_information.get("instruction"), + "speaker_embedding_shape": ( + list(additional_information[KEY_SPEAKER_EMBEDDING].shape) + if KEY_SPEAKER_EMBEDDING in additional_information + and hasattr(additional_information[KEY_SPEAKER_EMBEDDING], "shape") + else None + ), + "outputs": outputs, + "generated_at_unix": time.time(), + } + + +def _build_engine_kwargs(args, stats_log_file: str | None) -> dict: + kwargs = { + "model": args.model, + "stage_configs_path": args.stage_configs_path, + "enforce_eager": args.enforce_eager, + "trust_remote_code": args.trust_remote_code, + "log_stats": args.log_stats, + "stage_init_timeout": args.stage_init_timeout, + "init_timeout": args.init_timeout, + "batch_timeout": args.batch_timeout, + "shm_threshold_bytes": args.shm_threshold_bytes, + "worker_backend": args.worker_backend, + } + if stats_log_file is not None: + kwargs["log_file"] = stats_log_file + if args.ray_address is not None: + kwargs["ray_address"] = args.ray_address + return kwargs + + +def _extract_audio_output(outputs, *, async_chunk: bool): + output = next((item for item in outputs if item.final_output_type == "audio"), None) + if output is None: + raise 
RuntimeError("Expected one final output with final_output_type='audio'") + + multimodal_output = output.multimodal_output or {} + audio = multimodal_output.get("audio") + sr = multimodal_output.get("sr") + if audio is None or sr is None: + raise RuntimeError("Expected multimodal_output['audio'] and multimodal_output['sr']") + + waveform = _coerce_audio_tensor(audio, async_chunk=async_chunk) + if waveform.numel() == 0: + raise RuntimeError("Generated audio waveform is empty") + return waveform, _resolve_sr(sr) + + +def _build_instruction(args, case): + if args.instruction_json is not None: + return json.loads(args.instruction_json) + if args.instructions is not None: + return args.instructions + return case.get("instruction") + + +def _build_prompt(tokenizer, args): + case = CASE_DEFAULTS[args.case] + prompt = args.prompt or case["prompt"] + text = args.text or case["text"] + instruction = _build_instruction(args, case) + prompt_text = args.ref_text if args.ref_text is not None else case.get("prompt_text") + ref_audio_paths = _resolve_reference_audio_paths(args) + prompt_waveform = _resolve_reference_inputs(args, case) if prompt_text is not None else None + + required_count = int(case.get("requires_ref_audio_count", 0)) + if required_count > 0 and len(ref_audio_paths) < required_count: + raise RuntimeError( + f"Case '{args.case}' requires at least {required_count} reference audio paths via --ref-audio-paths" + ) + if required_count <= 0 and case.get("requires_ref_audio") and not ref_audio_paths: + raise RuntimeError(f"--ref-audio is required for case '{args.case}'") + + if case.get("requires_ref_text") and not prompt_text: + raise RuntimeError(f"--ref-text is required for case '{args.case}'") + + speaker_embedding = _resolve_speaker_embedding(args, case, ref_audio_paths) + use_zero_spk_emb = ( + bool(case.get("use_zero_spk_emb", False)) and prompt_waveform is None and speaker_embedding is None + ) + + runtime_controls = { + KEY_MAX_DECODE_STEPS: 
args.max_decode_steps or case["max_decode_steps"], + } + if "cfg" in case: + runtime_controls[KEY_CFG] = case["cfg"] + if "sigma" in case: + runtime_controls[KEY_SIGMA] = case["sigma"] + if "temperature" in case: + runtime_controls[KEY_TEMPERATURE] = case["temperature"] + return build_ming_dense_prompt( + tokenizer, + prompt=prompt, + text=text, + runtime_controls=runtime_controls, + instruction=instruction, + prompt_text=prompt_text, + prompt_waveform=prompt_waveform, + speaker_embedding=speaker_embedding, + use_zero_spk_emb=use_zero_spk_emb, + ) + + +async def _run_streaming(args, prompt_payload, sampling_params_list, output_dir, stats_log_file): + engine = AsyncOmni(**_build_engine_kwargs(args, stats_log_file)) + try: + all_audio_chunks = [] + accumulated_samples = 0 + chunk_idx = 0 + start_time = time.time() + chunk_times = [] + ttfp_seconds = None + final_stage_output = None + async for stage_output in engine.generate( + prompt=prompt_payload, + request_id=str(uuid.uuid4()), + sampling_params_list=sampling_params_list, + ): + final_stage_output = stage_output + multimodal_output = stage_output.multimodal_output or {} + audio = multimodal_output.get("audio") + if audio is None: + continue + + finished = stage_output.finished + if isinstance(audio, torch.Tensor): + if finished: + audio_chunk = audio[accumulated_samples:].float().detach().cpu() + else: + audio_chunk = audio.float().detach().cpu() + elif isinstance(audio, list): + audio_chunk = torch.as_tensor(audio[chunk_idx], dtype=torch.float32).reshape(-1).cpu() + else: + audio_chunk = torch.as_tensor(audio, dtype=torch.float32).reshape(-1).cpu() + + accumulated_samples += int(audio_chunk.numel()) + chunk_idx += 1 + if audio_chunk.numel() > 0: + now = time.time() + if ttfp_seconds is None: + ttfp_seconds = now - start_time + chunk_times.append(now) + all_audio_chunks.append(audio_chunk) + + if not all_audio_chunks: + raise RuntimeError("Streaming Ming example produced no audio chunks") + + waveform = 
torch.cat(all_audio_chunks, dim=0) + output_name = _resolve_output_name(args.output_name, args.case, 0, 1) + output_path = str(Path(output_dir) / output_name) + _write_wav(output_path, waveform, SAMPLE_RATE) + summary = { + "request_id": getattr(final_stage_output, "request_id", None), + "stage_id": getattr(final_stage_output, "stage_id", None), + "output_path": output_path, + "stage_durations": getattr(final_stage_output, "stage_durations", {}), + "peak_memory_mb": getattr(final_stage_output, "peak_memory_mb", 0.0), + "ttfp_seconds": ttfp_seconds, + "mean_inter_chunk_seconds": ( + sum(t1 - t0 for t0, t1 in zip(chunk_times, chunk_times[1:])) / (len(chunk_times) - 1) + if len(chunk_times) > 1 + else None + ), + } + summary.update(_audio_summary(waveform, SAMPLE_RATE)) + print(f"Saved streaming output to {output_path}") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + return [summary] + finally: + engine.shutdown() + + +def _run_non_streaming(args, prompt_payload, sampling_params_list, output_dir, stats_log_file): + engine = Omni(**_build_engine_kwargs(args, stats_log_file)) + try: + outputs = engine.generate( + prompts=[prompt_payload for _ in range(args.num_prompts)], + sampling_params_list=sampling_params_list, + py_generator=False, + ) + summaries = [] + for fallback_index, output in enumerate(outputs): + if output.final_output_type != "audio": + continue + multimodal_output = output.multimodal_output or {} + waveform = _coerce_audio_tensor(multimodal_output.get("audio"), async_chunk=False) + sample_rate = _extract_sample_rate(multimodal_output) + request_index = _request_index(output.request_id, fallback_index) + output_name = _resolve_output_name(args.output_name, args.case, request_index, args.num_prompts) + output_path = str(Path(output_dir) / output_name) + _write_wav(output_path, waveform, sample_rate) + summary = { + "request_id": output.request_id, + "stage_id": output.stage_id, + "output_path": output_path, + "stage_durations": 
output.stage_durations, + "peak_memory_mb": output.peak_memory_mb, + } + summary.update(_audio_summary(waveform, sample_rate)) + summaries.append(summary) + print(f"Saved output to {output_path}") + print(json.dumps(summary, ensure_ascii=False, indent=2)) + if not summaries: + raise RuntimeError("Non-streaming Ming example produced no audio outputs") + return summaries + finally: + engine.close() + + +def main(): + parser = FlexibleArgumentParser(description="Offline Ming-omni-tts example") + parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or local path") + parser.add_argument( + "--stage-configs-path", + default=None, + help="Stage config path. Defaults to ming_tts.yaml or ming_tts_async_chunk.yaml when --streaming is set.", + ) + parser.add_argument("--case", choices=sorted(CASE_DEFAULTS), default="style", help="Built-in demo case") + parser.add_argument("--text", default=None, help="Override case text") + parser.add_argument("--prompt", default=None, help="Override the system prompt prefix") + parser.add_argument("--instructions", default=None, help="Free-form Ming instruction string") + parser.add_argument( + "--instruction-json", + default=None, + help='Structured Ming instruction JSON, for example \'{"方言":"广粤话"}\'', + ) + parser.add_argument("--ref-audio", default=None, help="Reference audio path for cloning") + parser.add_argument( + "--ref-audio-paths", + nargs="+", + default=None, + help="Multiple reference audio paths, used by multi-speaker cases like podcast", + ) + parser.add_argument("--ref-text", default=None, help="Reference transcript for cloning") + parser.add_argument("--speaker-embedding", default=None, help="Path to a JSON speaker embedding file") + parser.add_argument( + "--extract-speaker-embeddings", + action="store_true", + help="Extract 192-d Ming speaker embeddings from --ref-audio or --ref-audio-paths using campplus.onnx", + ) + parser.add_argument("--max-decode-steps", type=int, default=None, help="Override 
ming_max_decode_steps") + parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR, help="Directory for output wav files") + parser.add_argument("--output-name", default=None, help="Output wav filename") + parser.add_argument("--num-prompts", type=int, default=1, help="Repeat the same prompt N times") + parser.add_argument("--streaming", action="store_true", help="Use AsyncOmni with async_chunk streaming") + parser.add_argument("--trust-remote-code", action="store_true", help="Pass trust_remote_code to Omni") + parser.add_argument("--enforce-eager", action="store_true", help="Pass enforce_eager to Omni") + parser.add_argument( + "--log-stats", "--enable-stats", dest="log_stats", action="store_true", help="Enable Omni stats logging" + ) + parser.add_argument("--stats-log-file", default=None, help="Optional path for the Omni stats log file") + parser.add_argument("--metadata-json", default=None, help="Optional path for a run manifest JSON file") + parser.add_argument( + "--stage-init-timeout", type=int, default=300, help="Per-stage initialization timeout in seconds" + ) + parser.add_argument("--init-timeout", type=int, default=600, help="Total initialization timeout in seconds") + parser.add_argument("--batch-timeout", type=int, default=5, help="Batch timeout in seconds") + parser.add_argument("--shm-threshold-bytes", type=int, default=65536, help="Shared memory threshold in bytes") + parser.add_argument( + "--worker-backend", + type=str, + default="multi_process", + choices=["multi_process", "ray"], + help="Worker backend", + ) + parser.add_argument("--ray-address", default=None, help="Ray cluster address when --worker-backend ray is used") + args = parser.parse_args() + + if args.instructions is not None and args.instruction_json is not None: + raise RuntimeError("Use either --instructions or --instruction-json, not both") + if args.num_prompts < 1: + raise RuntimeError("--num-prompts must be at least 1") + if args.streaming and args.num_prompts != 1: + raise 
RuntimeError("--streaming currently supports exactly one prompt") + + if args.stage_configs_path is None: + args.stage_configs_path = DEFAULT_STREAM_STAGE_CONFIG if args.streaming else DEFAULT_STAGE_CONFIG + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=False) + prompt_payload = _build_prompt(tokenizer, args) + + max_decode_steps = args.max_decode_steps or CASE_DEFAULTS[args.case]["max_decode_steps"] + sampling_params_list = [ + SamplingParams( + temperature=0.0, + max_tokens=max_decode_steps + 1, + stop_token_ids=[int(TEXT_EOS_TOKEN_ID)], + ), + SamplingParams(temperature=0.0, max_tokens=1), + ] + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + stats_log_file = _resolve_stats_log_file(args) + + if args.streaming: + summaries = asyncio.run(_run_streaming(args, prompt_payload, sampling_params_list, output_dir, stats_log_file)) + else: + summaries = _run_non_streaming(args, prompt_payload, sampling_params_list, output_dir, stats_log_file) + + metadata_json = _resolve_metadata_json(args) + manifest = _build_manifest(args, prompt_payload, stats_log_file, summaries) + if metadata_json is not None: + Path(metadata_json).write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"Saved run manifest to {metadata_json}") + + +if __name__ == "__main__": + os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + main() diff --git a/examples/online_serving/ming_tts/README.md b/examples/online_serving/ming_tts/README.md new file mode 100644 index 00000000000..76f8521a4fe --- /dev/null +++ b/examples/online_serving/ming_tts/README.md @@ -0,0 +1,312 @@ +# Ming-omni-tts + +## Installation + +Please refer to [README.md](../../../README.md) + +## Model + +| Model | Description | +|-------|-------------| +| `inclusionAI/Ming-omni-tts-0.5B` | Dense 0.5B Ming two-stage TTS model for speech generation with dialect, style, IP voice, and cloning controls | + +## Launch the Server + 
+```bash +vllm-omni serve inclusionAI/Ming-omni-tts-0.5B \ + --stage-configs-path vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml \ + --omni \ + --port 8091 \ + --enforce-eager +``` + +Or use the convenience script: + +```bash +cd examples/online_serving/ming_tts +./run_server.sh +``` + +The recommended online-serving path is eager async-chunk mode through +`/v1/audio/speech`. `run_server.sh` defaults to: + +- model: `inclusionAI/Ming-omni-tts-0.5B` +- stage config: `vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml` +- auth: local testing only, no real OpenAI key required + +## Send Requests + +The canonical Ming online client is: + +```bash +cd examples/online_serving/ming_tts +python openai_speech_client.py --text "你好,世界" +``` + +This talks to the local vLLM-Omni server at `http://localhost:8091/v1` and +uses `api_key=EMPTY`. It does not call OpenAI's cloud API. + +### Basic TTS + +```bash +python openai_speech_client.py \ + --text "你好,这是 Ming 在线语音合成测试。" \ + --max-new-tokens 200 +``` + +### Style-conditioned speech without a reference clip + +```bash +python openai_speech_client.py \ + --text "我会一直在这里陪着你。" \ + --instructions "轻柔的ASMR耳语,慢速,贴近麦克风" \ + --max-new-tokens 200 +``` + +### Structured Ming control via JSON + +```bash +python openai_speech_client.py \ + --text "我觉得社会企业同个人都有责任" \ + --instruction-json '{"方言":"广粤话"}' \ + --max-new-tokens 200 +``` + +### IP voice generation + +```bash +python openai_speech_client.py \ + --text "这款产品的名字,叫变态坑爹牛肉丸。" \ + --voice 灵小甄 \ + --max-new-tokens 200 +``` + +### Reference-audio cloning + +Ming has two reference-audio paths: + +- prompt-waveform conditioning, where `ref_audio` steers the voice/style and + `ref_text` is not required +- transcript cloning, where `ref_audio` and `ref_text` are paired + +```bash +python openai_speech_client.py \ + --task-type Base \ + --text "我们的愿景是构建未来服务业的数字化基础设施。" \ + --ref-audio /path/to/reference.wav \ + --max-new-tokens 200 +``` + +Pass `--ref-text` when the 
prompt case needs a transcript, such as zero-shot +voice cloning: + +```bash +python openai_speech_client.py \ + --task-type Base \ + --text "我们的愿景是构建未来服务业的数字化基础设施。" \ + --ref-audio /path/to/reference.wav \ + --ref-text "在此奉劝大家别乱打美白针。" \ + --max-new-tokens 200 +``` + +### Podcast-style multi-speaker prompt + +```bash +python openai_speech_client.py \ + --text "speaker_1:你可以说一下。 speaker_2:我也不知道。" \ + --ref-audio /path/to/speaker_1.wav \ + --ref-audio /path/to/speaker_2.wav \ + --ref-text "speaker_1:你好。 speaker_2:你好。" +``` + +### x-vector style cloning with a precomputed embedding + +```bash +python openai_speech_client.py \ + --task-type Base \ + --text "你好,这是一段使用说话人向量的合成语音。" \ + --speaker-embedding /path/to/ming_speaker_embedding.json \ + --max-new-tokens 200 +``` + +### Curl examples + +Use the helper script for the common request types: + +```bash +./run_curl.sh basic +./run_curl.sh style +./run_curl.sh ip +REF_AUDIO=/path/to/emotion_prompt.wav ./run_curl.sh emotion +REF_AUDIO=/path/to/yue_prompt.wav ./run_curl.sh dialect +REF_AUDIO=/path/to/reference.wav REF_TEXT="在此奉劝大家别乱打美白针。" ./run_curl.sh zero_shot +REF_AUDIO=/path/to/speaker_1.wav REF_AUDIO_2=/path/to/speaker_2.wav REF_TEXT="speaker_1:你好。 speaker_2:你好。" ./run_curl.sh podcast +REF_AUDIO=/path/to/00000309-00000300.wav ./run_curl.sh speech_bgm +REF_AUDIO=/path/to/00000309-00000300.wav ./run_curl.sh speech_sound +REF_AUDIO=/path/to/reference.wav REF_TEXT="在此奉劝大家别乱打美白针。" ./run_curl.sh clone_ref_audio +SPEAKER_EMBEDDING=/path/to/ming_speaker_embedding.json ./run_curl.sh clone_embedding +./run_curl.sh stream +``` + +Or send a direct request: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer EMPTY" \ + -d '{ + "model": "inclusionAI/Ming-omni-tts-0.5B", + "input": "你好,这是 Ming 在线语音合成测试。", + "response_format": "wav" + }' \ + --output ming_output.wav +``` + +## Request Types + +Ming online serving supports these main request families through 
+`/v1/audio/speech`: + +| Case | Online support | Required fields | +|------|----------------|-----------------| +| default TTS | Supported | `input`, `max_new_tokens=200` | +| `style` | Supported | `input`, `instructions`, `max_new_tokens=200` | +| `ip` | Supported | `input`, `voice`, `max_new_tokens=200` | +| `basic` helper | Supported | `input`, `max_new_tokens=200` | +| upstream `basic` case | Supported | `input`, `ref_audio`, structured speed / pitch / volume `instructions`, `max_new_tokens=200` | +| `emotion` | Supported | `input`, `ref_audio`, structured emotion `instructions`, `max_new_tokens=200` | +| `dialect` | Supported | `input`, `language` or structured `instructions`, `ref_audio`, `max_new_tokens=200` | +| `zero_shot` | Supported | `input`, `ref_audio`, `ref_text`, `max_new_tokens=200` | +| `podcast` | Supported | `input`, repeated/list `ref_audio`, `ref_text`, `max_new_tokens=200` | +| `speech_bgm` | Supported | `input`, `ref_audio`, structured `instructions` with `{"BGM": ...}`, `max_new_tokens=200` | +| `speech_sound` | Supported | `input`, `ref_audio`, structured `instructions` with `{"BGM": {"ENV": ...}}`, `max_new_tokens=200` | +| `bgm` | Not supported online | Requires a future `prompt_mode=music` API extension | + +This matrix intentionally mirrors the local online validation flow. The +music-only `bgm` case remains offline-only because `/v1/audio/speech` always +uses Ming's speech prompt path today. 
+ +## Output + +- Non-streaming requests return the complete audio in a single response, usually saved as a `.wav` file +- WAV outputs are expected to have a 44.1 kHz sample rate +- Streaming requests return raw PCM bytes progressively; wrap or convert them to WAV + before browser playback +- The default Python client outputs: + - `ming_output.wav` for non-streaming + - `ming_output.pcm` for streaming + +## Validated Outputs + +Validation on an L4 GPU passed the online async_chunk `/v1/audio/speech` flow +for every speech-mode case in the local validation script: + +| Case | Output | Size bytes | Sample rate | Frames | +|------|--------|-----------:|------------:|-------:| +| `style` | WAV | 790316 | 44100 | 395136 | +| `ip` | WAV | 366956 | 44100 | 183456 | +| `basic` | WAV | 536300 | 44100 | 268128 | +| `emotion` | WAV | 649196 | 44100 | 324576 | +| `dialect` | WAV | 395180 | 44100 | 197568 | +| `zero_shot` | WAV | 931436 | 44100 | 465696 | +| `podcast` | WAV | 846764 | 44100 | 423360 | +| `speech_bgm` | WAV | 677420 | 44100 | 338688 | +| `speech_sound` | WAV | 649196 | 44100 | 324576 | +| `streaming` | PCM | 338688 | N/A | N/A | + +`bgm` is intentionally not included in the online pass list. It is a +music-prompt workflow, while `/v1/audio/speech` currently routes Ming through +the speech prompt path. 
+ +## Performance + +Benchmark via `/v1/audio/speech`, `inclusionAI/Ming-omni-tts-0.5B`, +10 prompts, concurrency 1, eager mode: + +| Config | Mean TTFP | Mean E2E | Mean RTF | +|--------|----------:|---------:|---------:| +| Sequential eager | 3354.83ms | 3357.01ms | 0.561 | +| Async chunk eager | 3450.28ms | 3452.35ms | 0.577 | + +## Audio Inputs + +- `ref_audio` accepts: + - a local file path + - a remote `http://` or `https://` URL + - a `data:` URL + - repeated values for podcast-style multi-speaker prompts +- `openai_speech_client.py` converts local reference audio files into a base64 + `data:` URL before sending them to the server +- `--speaker-embedding` must point to a JSON file containing a flat list of exactly 192 numeric values; the client sends it as the `speaker_embedding` array +- Ming prompt-waveform cases can use `ref_audio` without `ref_text` +- Zero-shot and podcast-style transcript cloning should include `ref_text` + +## API Field Mapping + +The OpenAI-compatible `/v1/audio/speech` endpoint stays generic. Ming-specific controls are mapped like this: + +- `input` -> target text +- `instructions` -> Ming instruction string, or a JSON string that becomes the structured Ming control object +- `voice` -> Ming `IP` field when using built-in character voices +- `language` -> Ming `方言` field +- `ref_audio` -> Ming `prompt_waveform` +- `ref_text` -> Ming `prompt_text` +- `speaker_embedding` -> 192-d Ming speaker embedding +- `max_new_tokens` -> Ming `max_decode_steps` + +## Voice Listing + +- `/v1/audio/voices` reflects uploaded voices for Ming. +- Built-in Ming IP labels like `灵小甄` are passed through as `voice` values, but they are not enumerated by the API. + +## Streaming + +Set `stream=true` in the request to get progressive PCM output; the client's `--stream` flag does this and forces `response_format=pcm`: + +```bash +python openai_speech_client.py \ + --text "你好,这是流式输出测试。" \ + --instructions "平静,普通话" \ + --stream \ + --output ming_output.pcm +``` + +## Not Supported Online Yet + +`bgm` music-prompt generation is not exposed through `/v1/audio/speech` today. 
+It needs a future `prompt_mode=music` API extension so the server can select +Ming's music system prompt instead of the speech system prompt. + +## Troubleshooting + +### No real OpenAI key + +The example targets a local vLLM-Omni server. `api_key=EMPTY` is expected and +is sufficient for local testing. + +### `--ref-audio` fails + +- Confirm the local file exists +- If using zero-shot or podcast transcript cloning, also provide `--ref-text` +- If passing a URL, make sure the server can fetch it + +### `--speaker-embedding` fails + +- Make sure the JSON file contains exactly 192 numeric values +- Do not wrap the list in another object + +### Connection refused + +- Check that the server is running on `localhost:8091` +- Confirm the stage config path is correct + +### No audio or wrong output file + +- Use non-streaming for `.wav` +- Use `--stream` for `.pcm` + +### `bgm` is missing online + +Use the offline example for music-only `bgm`. Online support needs an explicit +Ming prompt-mode API extension so the server can select the music prompt +instead of the speech prompt. diff --git a/examples/online_serving/ming_tts/openai_speech_client.py b/examples/online_serving/ming_tts/openai_speech_client.py new file mode 100644 index 00000000000..af1a70685b6 --- /dev/null +++ b/examples/online_serving/ming_tts/openai_speech_client.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""OpenAI-compatible client for Ming-omni-tts via /v1/audio/speech. 
+ +Examples: + python openai_speech_client.py --text "你好,世界" + python openai_speech_client.py --text "我会一直在这里陪着你。" \ + --instructions "轻柔的ASMR耳语,慢速,贴近麦克风" --max-new-tokens 200 + python openai_speech_client.py --text "你好,这是零样本克隆测试。" \ + --ref-audio prompt.wav --ref-text "参考音频的转录文本" --max-new-tokens 200 + python openai_speech_client.py --text "speaker_1:你好。 speaker_2:你好。" \ + --ref-audio speaker_1.wav --ref-audio speaker_2.wav --ref-text "speaker_1:你好。 speaker_2:你好。" + python openai_speech_client.py --text "你好,这是流式输出测试。" \ + --stream --output ming_output.pcm +""" + +import argparse +import base64 +import json +import os + +import httpx + +DEFAULT_API_BASE = "http://localhost:8091" +DEFAULT_API_KEY = "EMPTY" +DEFAULT_MODEL = "inclusionAI/Ming-omni-tts-0.5B" +EXPECTED_SPEAKER_EMBEDDING_DIM = 192 + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to a base64 data URL.""" + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + ext = audio_path.lower().rsplit(".", 1)[-1] + mime_map = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "flac": "audio/flac", + "ogg": "audio/ogg", + "aac": "audio/aac", + } + mime_type = mime_map.get(ext, "audio/wav") + with open(audio_path, "rb") as f: + audio_b64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime_type};base64,{audio_b64}" + + +def load_speaker_embedding(path: str) -> list[float]: + """Load and validate a 192-d Ming speaker embedding JSON file.""" + with open(path, encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, list): + raise ValueError("speaker_embedding file must contain a JSON list") + if len(data) != EXPECTED_SPEAKER_EMBEDDING_DIM: + raise ValueError( + f"Ming dense speaker_embedding must have {EXPECTED_SPEAKER_EMBEDDING_DIM} values, got {len(data)}" + ) + + values = [] + for index, value in enumerate(data): + try: + values.append(float(value)) + except (TypeError, ValueError) as exc: + raise 
ValueError(f"speaker_embedding[{index}] must be a number, got {value!r}") from exc + return values + + +def build_instruction_payload(args) -> str | None: + """Return a string payload for the API `instructions` field.""" + if args.instructions and args.instruction_json: + raise ValueError("Use either --instructions or --instruction-json, not both") + if args.instruction_json: + parsed = json.loads(args.instruction_json) + return json.dumps(parsed, ensure_ascii=False) + return args.instructions + + +def validate_args(args) -> None: + """Fail fast on invalid combinations before hitting the server.""" + if args.ref_text and not args.ref_audio: + raise ValueError("--ref-audio is required when --ref-text is provided") + if args.speaker_embedding and args.ref_audio and len(args.ref_audio) > 1: + raise ValueError("--speaker-embedding cannot be combined with multiple --ref-audio values") + + +def run_tts(args) -> None: + """Generate speech via the OpenAI-compatible /v1/audio/speech API.""" + validate_args(args) + + payload = { + "model": args.model, + "input": args.text, + "response_format": args.response_format, + } + + if args.voice: + payload["voice"] = args.voice + if args.task_type: + payload["task_type"] = args.task_type + if args.dialect: + payload["language"] = args.dialect + + instructions = build_instruction_payload(args) + if instructions: + payload["instructions"] = instructions + + if args.ref_audio: + ref_audio = [] + for audio in args.ref_audio: + if audio.startswith(("http://", "https://", "data:")): + ref_audio.append(audio) + else: + ref_audio.append(encode_audio_to_base64(audio)) + payload["ref_audio"] = ref_audio[0] if len(ref_audio) == 1 else ref_audio + if args.ref_text: + payload["ref_text"] = args.ref_text + if args.speaker_embedding: + payload["speaker_embedding"] = load_speaker_embedding(args.speaker_embedding) + if args.max_new_tokens: + payload["max_new_tokens"] = args.max_new_tokens + if args.stream: + payload["stream"] = True + 
payload["response_format"] = "pcm" + + api_url = f"{args.api_base}/v1/audio/speech" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {args.api_key}", + } + + print(f"Model: {args.model}") + print(f"Text: {args.text}") + print(f"Payload keys: {sorted(payload)}") + + if args.stream: + output_path = args.output or "ming_output.pcm" + with httpx.Client(timeout=300.0) as client: + with client.stream("POST", api_url, json=payload, headers=headers) as response: + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.read().decode()) + return + with open(output_path, "wb") as f: + for chunk in response.iter_bytes(): + f.write(chunk) + print(f"Streamed PCM audio to: {output_path}") + return + + with httpx.Client(timeout=300.0) as client: + response = client.post(api_url, json=payload, headers=headers) + + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + return + + try: + text = response.content.decode("utf-8") + if text.startswith('{"error"'): + print(f"Error: {text}") + return + except UnicodeDecodeError: + pass + + output_path = args.output or "ming_output.wav" + with open(output_path, "wb") as f: + f.write(response.content) + print(f"Audio saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="OpenAI-compatible client for Ming-omni-tts via /v1/audio/speech") + parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--text", required=True, help="Text to synthesize") + parser.add_argument( + "--task-type", + default=None, + choices=["CustomVoice", "VoiceDesign", "Base"], + help="Optional compatibility task type. 
Ming accepts the same field but primarily uses prompt metadata.", + ) + parser.add_argument( + "--voice", + default=None, + help="Maps to Ming `IP` when using built-in character voices, or to an uploaded voice sample name", + ) + parser.add_argument("--dialect", default=None, help="Maps to Ming `方言`") + parser.add_argument("--instructions", default=None, help="Free-form Ming instruction string") + parser.add_argument( + "--instruction-json", + default=None, + help='Structured Ming instruction JSON, for example \'{"情感":"高兴"}\'', + ) + parser.add_argument( + "--ref-audio", + action="append", + default=None, + help="Reference audio path, URL, or data URL. Repeat for podcast-style multi-speaker prompts.", + ) + parser.add_argument("--ref-text", default=None, help="Reference transcript for cloning") + parser.add_argument( + "--speaker-embedding", default=None, help="Path to a JSON file containing a 192-d speaker embedding" + ) + parser.add_argument("--max-new-tokens", type=int, default=None, help="Override ming_max_decode_steps") + parser.add_argument("--stream", action="store_true", help="Enable streaming PCM output") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "mp3", "flac", "pcm", "aac", "opus"], + help="Audio format when not streaming", + ) + parser.add_argument("--output", "-o", default=None, help="Output file path") + args = parser.parse_args() + try: + run_tts(args) + except (FileNotFoundError, ValueError, json.JSONDecodeError) as exc: + raise SystemExit(f"Error: {exc}") from exc + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/ming_tts/run_curl.sh b/examples/online_serving/ming_tts/run_curl.sh new file mode 100755 index 00000000000..92762462e25 --- /dev/null +++ b/examples/online_serving/ming_tts/run_curl.sh @@ -0,0 +1,217 @@ +#!/bin/bash +# Common curl examples for Ming-omni-tts via /v1/audio/speech. 
+# +# Usage: +# ./run_curl.sh basic +# ./run_curl.sh style +# ./run_curl.sh ip +# REF_AUDIO=/path/to/ref.wav ./run_curl.sh emotion +# REF_AUDIO=/path/to/ref.wav ./run_curl.sh dialect +# REF_AUDIO=/path/to/ref.wav REF_TEXT="参考文本" ./run_curl.sh zero_shot +# REF_AUDIO=/path/to/speaker1.wav REF_AUDIO_2=/path/to/speaker2.wav REF_TEXT="speaker_1:... speaker_2:..." ./run_curl.sh podcast +# REF_AUDIO=/path/to/mix_ref.wav ./run_curl.sh speech_bgm +# REF_AUDIO=/path/to/mix_ref.wav ./run_curl.sh speech_sound +# REF_AUDIO=/path/to/ref.wav REF_TEXT="参考文本" ./run_curl.sh clone_ref_audio +# SPEAKER_EMBEDDING=/path/to/ming_embedding.json ./run_curl.sh clone_embedding +# ./run_curl.sh stream + +set -euo pipefail + +MODE="${1:-basic}" +HOST="${HOST:-localhost}" +PORT="${PORT:-8091}" +MODEL="${MODEL:-inclusionAI/Ming-omni-tts-0.5B}" +API_URL="http://${HOST}:${PORT}/v1/audio/speech" +TEXT="${TEXT:-你好,这是 Ming 在线语音合成测试。}" +OUTPUT="${OUTPUT:-ming_output.wav}" +STREAM_OUTPUT="${STREAM_OUTPUT:-ming_output.pcm}" +REF_AUDIO="${REF_AUDIO:-}" +REF_AUDIO_2="${REF_AUDIO_2:-}" +REF_TEXT="${REF_TEXT:-}" +SPEAKER_EMBEDDING="${SPEAKER_EMBEDDING:-}" + +build_payload() { + MODEL="$1" \ + TEXT="$2" \ + VOICE="$3" \ + INSTRUCTIONS="$4" \ + TASK_TYPE="$5" \ + REF_AUDIO_PATH="$6" \ + REF_TEXT="$7" \ + SPEAKER_EMBEDDING_PATH="$8" \ + STREAM="$9" \ + REF_AUDIO_PATH_2="${10:-}" \ + python - <<'PY' +import base64 +import json +import mimetypes +import os +import pathlib +import sys + +payload = { + "model": os.environ["MODEL"], + "input": os.environ["TEXT"], +} + +voice = os.environ["VOICE"] +instructions = os.environ["INSTRUCTIONS"] +task_type = os.environ["TASK_TYPE"] +ref_audio_path = os.environ["REF_AUDIO_PATH"] +ref_audio_path_2 = os.environ["REF_AUDIO_PATH_2"] +ref_text = os.environ["REF_TEXT"] +speaker_embedding_path = os.environ["SPEAKER_EMBEDDING_PATH"] + +if voice: + payload["voice"] = voice +if instructions: + payload["instructions"] = instructions +if task_type: + payload["task_type"] = task_type 
+ref_audio_items = [] +if ref_audio_path: + path = pathlib.Path(ref_audio_path) + mime_type = mimetypes.guess_type(path.name)[0] or "audio/wav" + data = base64.b64encode(path.read_bytes()).decode("utf-8") + ref_audio_items.append(f"data:{mime_type};base64,{data}") +if ref_audio_path_2: + path = pathlib.Path(ref_audio_path_2) + mime_type = mimetypes.guess_type(path.name)[0] or "audio/wav" + data = base64.b64encode(path.read_bytes()).decode("utf-8") + ref_audio_items.append(f"data:{mime_type};base64,{data}") +if ref_audio_items: + payload["ref_audio"] = ref_audio_items[0] if len(ref_audio_items) == 1 else ref_audio_items +if ref_text: + payload["ref_text"] = ref_text +if speaker_embedding_path: + path = pathlib.Path(speaker_embedding_path) + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, list): + raise SystemExit("speaker embedding file must contain a JSON list") + payload["speaker_embedding"] = data + +stream = os.environ["STREAM"] == "true" +if stream: + payload["stream"] = True + payload["response_format"] = "pcm" +else: + payload["response_format"] = "wav" + +print(json.dumps(payload, ensure_ascii=False)) +PY +} + +require_file() { + local path="$1" + local flag_name="$2" + if [ -z "$path" ]; then + echo "Missing ${flag_name}" >&2 + exit 1 + fi + if [ ! 
-f "$path" ]; then + echo "File not found for ${flag_name}: $path" >&2 + exit 1 + fi +} + +base_headers=( + -H "Content-Type: application/json" + -H "Authorization: Bearer EMPTY" +) + +post_payload() { + local payload="$1" + local output_path="$2" + local payload_file + payload_file="$(mktemp)" + trap 'rm -f "$payload_file"' RETURN + printf '%s' "$payload" > "$payload_file" + curl -X POST "$API_URL" "${base_headers[@]}" \ + --data-binary "@${payload_file}" \ + --output "$output_path" +} + +case "$MODE" in + basic) + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "" "" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + style) + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "轻柔的ASMR耳语,慢速,贴近麦克风" "" "" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + ip) + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "灵小甄" "" "" "" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + emotion) + require_file "$REF_AUDIO" "REF_AUDIO" + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" '{"情感":"高兴"}' "" "$REF_AUDIO" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + dialect) + require_file "$REF_AUDIO" "REF_AUDIO" + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "" "$REF_AUDIO" "" "" "false")" + PAYLOAD="$(TEXT="$PAYLOAD" python - <<'PY' +import json +import os +payload = json.loads(os.environ["TEXT"]) +payload["language"] = "广粤话" +print(json.dumps(payload, ensure_ascii=False)) +PY +)" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + zero_shot) + require_file "$REF_AUDIO" "REF_AUDIO" + if [ -z "$REF_TEXT" ]; then + echo "Missing REF_TEXT" >&2 + exit 1 + fi + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "Base" "$REF_AUDIO" "$REF_TEXT" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + podcast) + require_file "$REF_AUDIO" "REF_AUDIO" + require_file "$REF_AUDIO_2" "REF_AUDIO_2" + if [ -z "$REF_TEXT" ]; then + echo "Missing REF_TEXT" >&2 + exit 1 + fi + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "Base" "$REF_AUDIO" "$REF_TEXT" "" "false" "$REF_AUDIO_2")" + 
post_payload "$PAYLOAD" "$OUTPUT" + ;; + speech_bgm) + require_file "$REF_AUDIO" "REF_AUDIO" + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" '{"BGM":"舒缓的背景音乐"}' "" "$REF_AUDIO" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + speech_sound) + require_file "$REF_AUDIO" "REF_AUDIO" + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" '{"BGM":{"ENV":"轻微的环境声"}}' "" "$REF_AUDIO" "" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + clone_ref_audio) + require_file "$REF_AUDIO" "REF_AUDIO" + if [ -z "$REF_TEXT" ]; then + echo "Missing REF_TEXT" >&2 + exit 1 + fi + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "Base" "$REF_AUDIO" "$REF_TEXT" "" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + clone_embedding) + require_file "$SPEAKER_EMBEDDING" "SPEAKER_EMBEDDING" + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "" "Base" "" "" "$SPEAKER_EMBEDDING" "false")" + post_payload "$PAYLOAD" "$OUTPUT" + ;; + stream) + PAYLOAD="$(build_payload "$MODEL" "$TEXT" "" "平静,普通话" "" "" "" "" "true")" + post_payload "$PAYLOAD" "$STREAM_OUTPUT" + ;; + *) + echo "Unknown mode: $MODE" >&2 + echo "Supported: basic, style, ip, emotion, dialect, zero_shot, podcast, speech_bgm, speech_sound, clone_ref_audio, clone_embedding, stream" >&2 + exit 1 + ;; +esac diff --git a/examples/online_serving/ming_tts/run_server.sh b/examples/online_serving/ming_tts/run_server.sh new file mode 100755 index 00000000000..a35d4abe512 --- /dev/null +++ b/examples/online_serving/ming_tts/run_server.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Launch vLLM-Omni server for Ming-omni-tts. 
+# +# Usage: +# ./run_server.sh +# PORT=8000 ./run_server.sh + +set -e + +MODEL="${MODEL:-inclusionAI/Ming-omni-tts-0.5B}" +PORT="${PORT:-8091}" +STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/model_executor/stage_configs/ming_tts_async_chunk.yaml}" + +echo "Starting Ming-omni-tts server with model: $MODEL" +echo "Stage config: $STAGE_CONFIG" + +vllm-omni serve "$MODEL" \ + --stage-configs-path "$STAGE_CONFIG" \ + --host 0.0.0.0 \ + --port "$PORT" \ + --enforce-eager \ + --omni diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py index 256e3e0a3f8..03df9ad4eb0 100644 --- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py +++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py @@ -122,6 +122,31 @@ def test_load_poll(build_adapter): assert "req-1" not in adapter._pending_load_reqs +def test_generation_load_preserves_payload_metadata(build_adapter): + adapter, connector = build_adapter(stage_id=1, model_mode="generation") + request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1") + payload = { + "code_predictor_codes": [0], + "left_context_size": 3, + "ming_latent_patches": torch.ones((10, 4, 64), dtype=torch.float32), + "ming_request_id": "external-1", + "ming_chunk_id": 7, + "finished": torch.tensor(False), + } + connector.get.return_value = (payload, 16) + + adapter._poll_single_request(request) + + assert request.prompt_token_ids == [0] + assert request.additional_information["left_context_size"] == 3 + assert request.additional_information["ming_request_id"] == "external-1" + assert request.additional_information["ming_chunk_id"] == 7 + assert request.additional_information["ming_latent_patches"].shape == (10, 4, 64) + assert "code_predictor_codes" not in request.additional_information + assert "finished" not in request.additional_information + assert request.num_computed_tokens == 0 + + def test_save_async(build_adapter): 
adapter, _ = build_adapter(stage_id=1) request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1") diff --git a/tests/e2e/offline_inference/test_ming_tts.py b/tests/e2e/offline_inference/test_ming_tts.py new file mode 100644 index 00000000000..128b84e2896 --- /dev/null +++ b/tests/e2e/offline_inference/test_ming_tts.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""End-to-end offline inference tests for Ming-omni-tts.""" + +import asyncio +import os +import uuid +from pathlib import Path + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer +from vllm import SamplingParams + +from tests.utils import hardware_test +from vllm_omni import AsyncOmni, Omni +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import ( + KEY_MAX_DECODE_STEPS, + SAMPLE_RATE, + TEXT_EOS_TOKEN_ID, +) +from vllm_omni.model_executor.models.ming_tts.prompt_builder import build_ming_dense_prompt + +MODEL = "inclusionAI/Ming-omni-tts-0.5B" +STAGE_CONFIG = str( + Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "ming_tts.yaml" +) +STREAM_STAGE_CONFIG = str( + Path(__file__).parent.parent.parent.parent + / "vllm_omni" + / "model_executor" + / "stage_configs" + / "ming_tts_async_chunk.yaml" +) +TEST_TEXT = "我会一直在这里陪着你,直到你慢慢地沉入那个最温柔的梦里。" +TEST_INSTRUCTION = "轻柔的ASMR耳语,慢速,贴近麦克风" +MIN_AUDIO_SAMPLES = 1000 + + +def _build_prompt( + *, + text: str = TEST_TEXT, + instruction=TEST_INSTRUCTION, + use_zero_spk_emb: bool = True, +) -> dict: + tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=False) + return build_ming_dense_prompt( + tokenizer, + prompt="Please generate speech based on the following description.\n", + text=text, + instruction=instruction, + runtime_controls={KEY_MAX_DECODE_STEPS: 
200}, + use_zero_spk_emb=use_zero_spk_emb, + ) + + +def _sampling_params_list() -> list[SamplingParams]: + return [ + SamplingParams( + temperature=0.0, + max_tokens=201, + stop_token_ids=[int(TEXT_EOS_TOKEN_ID)], + ), + SamplingParams(temperature=0.0, max_tokens=1), + ] + + +def _flatten_audio(audio) -> torch.Tensor: + if isinstance(audio, list): + parts = [torch.as_tensor(item, dtype=torch.float32).reshape(-1).cpu() for item in audio] + parts = [item for item in parts if item.numel() > 0] + if not parts: + return torch.zeros((0,), dtype=torch.float32) + return torch.cat(parts, dim=0) + return torch.as_tensor(audio, dtype=torch.float32).reshape(-1).cpu() + + +def _extract_audio(multimodal_output: dict) -> torch.Tensor: + audio = multimodal_output.get("audio") + if audio is None: + raise RuntimeError("Expected multimodal_output['audio']") + waveform = _flatten_audio(audio) + if waveform.numel() == 0: + raise RuntimeError("Generated audio waveform is empty") + return waveform + + +def _extract_sample_rate(multimodal_output: dict) -> int: + sample_rate = multimodal_output.get("sr") + if sample_rate is None: + raise RuntimeError("Expected multimodal_output['sr']") + if isinstance(sample_rate, list): + sample_rate = sample_rate[-1] + if hasattr(sample_rate, "item"): + sample_rate = sample_rate.item() + return int(sample_rate) + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_ming_tts_offline_basic() -> None: + """Test blocking Ming generation through Omni.""" + omni = Omni( + model=MODEL, + stage_configs_path=STAGE_CONFIG, + stage_init_timeout=300, + enforce_eager=True, + ) + try: + outputs = omni.generate( + prompts=[_build_prompt()], + sampling_params_list=_sampling_params_list(), + py_generator=False, + ) + final_output = next((item for item in outputs if item.final_output_type == "audio"), None) + assert final_output is not None, "No final audio output produced" + multimodal_output = 
final_output.multimodal_output or {} + waveform = _extract_audio(multimodal_output) + sample_rate = _extract_sample_rate(multimodal_output) + assert waveform.ndim == 1 + assert waveform.shape[0] == waveform.numel() + assert waveform.numel() > MIN_AUDIO_SAMPLES + assert np.max(np.abs(waveform.numpy())) > 0.01, "Audio appears silent" + assert sample_rate == SAMPLE_RATE, f"Expected Ming output sample rate {SAMPLE_RATE}, got {sample_rate}" + finally: + omni.close() + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_ming_tts_speaker_conditioning_differs() -> None: + """Test that different Ming speaker controls produce different waveform outputs.""" + omni = Omni( + model=MODEL, + stage_configs_path=STAGE_CONFIG, + stage_init_timeout=300, + enforce_eager=True, + ) + try: + style_outputs = omni.generate( + prompts=[_build_prompt()], + sampling_params_list=_sampling_params_list(), + py_generator=False, + ) + ip_outputs = omni.generate( + prompts=[_build_prompt(text=TEST_TEXT, instruction={"IP": "灵小甄"}, use_zero_spk_emb=True)], + sampling_params_list=_sampling_params_list(), + py_generator=False, + ) + + style_final_output = next((item for item in style_outputs if item.final_output_type == "audio"), None) + ip_final_output = next((item for item in ip_outputs if item.final_output_type == "audio"), None) + assert style_final_output is not None, "No style audio output produced" + assert ip_final_output is not None, "No IP audio output produced" + + style_waveform = _extract_audio(style_final_output.multimodal_output or {}) + ip_waveform = _extract_audio(ip_final_output.multimodal_output or {}) + assert style_waveform.numel() > MIN_AUDIO_SAMPLES + assert ip_waveform.numel() > MIN_AUDIO_SAMPLES + assert np.max(np.abs(style_waveform.numpy())) > 0.01, "Style audio appears silent" + assert np.max(np.abs(ip_waveform.numpy())) > 0.01, "IP audio appears silent" + + overlap = min(int(style_waveform.numel()), 
int(ip_waveform.numel())) + mean_abs_diff = torch.mean(torch.abs(style_waveform[:overlap] - ip_waveform[:overlap])).item() + assert style_waveform.shape != ip_waveform.shape or mean_abs_diff > 1e-4, ( + "Speaker-conditioned outputs should differ, but style and IP waveforms were effectively identical" + ) + finally: + omni.close() + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +def test_ming_tts_offline_streaming() -> None: + """Test async_chunk streaming Ming generation through AsyncOmni.""" + + async def _run() -> None: + async_omni = AsyncOmni( + model=MODEL, + stage_configs_path=STREAM_STAGE_CONFIG, + stage_init_timeout=300, + enforce_eager=True, + ) + try: + all_audio_chunks = [] + accumulated_samples = 0 + chunk_idx = 0 + sample_rate = None + async for stage_output in async_omni.generate( + prompt=_build_prompt(), + request_id=str(uuid.uuid4()), + sampling_params_list=_sampling_params_list(), + ): + multimodal_output = stage_output.multimodal_output or {} + audio = multimodal_output.get("audio") + if "sr" in multimodal_output: + sample_rate = _extract_sample_rate(multimodal_output) + if audio is None: + continue + finished = stage_output.finished + if isinstance(audio, torch.Tensor): + if finished: + audio_chunk = audio[accumulated_samples:].float().detach().cpu() + else: + audio_chunk = audio.float().detach().cpu() + elif isinstance(audio, list): + audio_chunk = torch.as_tensor(audio[chunk_idx], dtype=torch.float32).reshape(-1).cpu() + else: + audio_chunk = torch.as_tensor(audio, dtype=torch.float32).reshape(-1).cpu() + accumulated_samples += int(audio_chunk.numel()) + chunk_idx += 1 + if audio_chunk.numel() > 0: + all_audio_chunks.append(audio_chunk) + assert all_audio_chunks, "No streaming audio chunks received" + waveform = torch.cat(all_audio_chunks, dim=0) + assert waveform.numel() > MIN_AUDIO_SAMPLES + assert np.max(np.abs(waveform.numpy())) > 0.01, "Audio appears silent" + assert sample_rate is not 
None, "Streaming path did not return a sample rate" + assert sample_rate == SAMPLE_RATE, f"Expected Ming output sample rate {SAMPLE_RATE}, got {sample_rate}" + finally: + async_omni.shutdown() + + asyncio.run(_run()) diff --git a/tests/e2e/online_serving/test_ming_tts.py b/tests/e2e/online_serving/test_ming_tts.py new file mode 100644 index 00000000000..6b3e21c09bd --- /dev/null +++ b/tests/e2e/online_serving/test_ming_tts.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""E2E online-serving tests for Ming-omni-tts.""" + +import concurrent.futures +import io +import os +import wave +from pathlib import Path + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +import pytest + +from tests.conftest import OmniServerParams +from tests.utils import hardware_test +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import SAMPLE_RATE + +MODEL = "inclusionAI/Ming-omni-tts-0.5B" +STAGE_CONFIG = str( + Path(__file__).parent.parent.parent.parent + / "vllm_omni" + / "model_executor" + / "stage_configs" + / "ming_tts_async_chunk.yaml" +) + +SERVER_PARAMS = [ + pytest.param( + OmniServerParams( + model=MODEL, + stage_config_path=STAGE_CONFIG, + server_args=["--enforce-eager", "--disable-log-stats"], + ), + id="async_chunk", + ) +] + + +def _wav_sample_rate(audio_bytes: bytes) -> int: + with wave.open(io.BytesIO(audio_bytes), "rb") as wav_file: + return int(wav_file.getframerate()) + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +@pytest.mark.parametrize("omni_server", SERVER_PARAMS, indirect=True) +def test_ming_tts_audio_speech_non_streaming(omni_server, openai_client) -> None: + """Test non-streaming Ming generation through /v1/audio/speech.""" + request_config = { + "model": omni_server.model, + "input": "我会一直在这里陪着你,直到你慢慢地沉入那个最温柔的梦里。", + "stream": False, + "response_format": 
"wav", + } + request_inputs = [ + "我会一直在这里陪着你,直到你慢慢地沉入那个最温柔的梦里。", + "这款产品的名字,叫变态坑爹牛肉丸。", + ] + + def _send_one(text): + per_request_config = {**request_config, "input": text} + responses = openai_client.send_audio_speech_request(per_request_config) + assert len(responses) == 1 + return text, responses[0] + + with concurrent.futures.ThreadPoolExecutor(max_workers=len(request_inputs)) as executor: + futures = [executor.submit(_send_one, text) for text in request_inputs] + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + assert {text for text, _ in results} == set(request_inputs) + assert len(results) == len(request_inputs) + for _, response in results: + assert response.audio_bytes is not None, "Expected WAV bytes from /v1/audio/speech" + sample_rate = _wav_sample_rate(response.audio_bytes) + assert sample_rate == SAMPLE_RATE, f"Expected Ming output sample rate {SAMPLE_RATE}, got {sample_rate}" + + +@pytest.mark.advanced_model +@pytest.mark.omni +@hardware_test(res={"cuda": "L4"}, num_cards=1) +@pytest.mark.parametrize("omni_server", SERVER_PARAMS, indirect=True) +def test_ming_tts_audio_speech_streaming(omni_server, openai_client) -> None: + """Test streaming Ming generation through /v1/audio/speech.""" + request_config = { + "model": omni_server.model, + "input": "这款产品的名字,叫变态坑爹牛肉丸。", + "voice": "灵小甄", + "stream": True, + "response_format": "wav", + } + openai_client.send_audio_speech_request(request_config) diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py index 3700e426d42..f5433443a09 100644 --- a/tests/engine/test_async_omni_engine_input.py +++ b/tests/engine/test_async_omni_engine_input.py @@ -1,4 +1,7 @@ +from unittest.mock import Mock + import pytest +import torch from pytest_mock import MockerFixture from vllm.sampling_params import SamplingParams from vllm.v1.engine import EngineCoreRequest @@ -88,3 +91,54 @@ def 
test_build_add_request_message_with_resumable_streaming(mocker: MockerFixtur assert msg["type"] == "streaming_update" input_processor.process_inputs.assert_called_once() assert input_processor.process_inputs.call_args.kwargs["resumable"] is True + + +def test_build_add_request_message_uses_ingress_processed_prompt_for_additional_information(): + engine = object.__new__(AsyncOmniEngine) + params = SamplingParams(max_tokens=8) + engine.default_sampling_params_list = [params] + engine.stage_metadata = [{"stage_type": "llm"}] + engine.supported_tasks = ("speech",) + + input_processor = Mock() + input_processor.process_inputs.return_value = _make_engine_core_request() + input_processor.input_preprocessor = Mock() + prompt_latents = torch.ones((4, 64), dtype=torch.float32) + processed_prompt = { + "prompt_token_ids": [1, 2, 3, 4], + "additional_information": { + "ming_prompt_latents": prompt_latents, + "global_request_id": ["req-1"], + }, + } + input_processor.input_preprocessor.consume_last_processed_prompt.return_value = processed_prompt + engine.input_processor = input_processor + + output_processor = Mock() + engine.output_processors = [output_processor] + + raw_prompt = { + "prompt_token_ids": [1, 2, 3], + "additional_information": {}, + } + + msg = engine._build_add_request_message( + request_id="req-1", + prompt=raw_prompt, + sampling_params_list=[params], + final_stage_id=0, + arrival_time=0.0, + ) + + request = msg["prompt"] + assert isinstance(request, OmniEngineCoreRequest) + assert request.additional_information is not None + assert request.additional_information.entries["ming_prompt_latents"].tensor_shape == [4, 64] + input_processor.input_preprocessor.consume_last_processed_prompt.assert_called_once() + output_processor.add_request.assert_called_once() + call_kwargs = output_processor.add_request.call_args.kwargs + assert call_kwargs["request"] is request + assert call_kwargs["prompt"] is None + assert call_kwargs["parent_req"] is None + assert 
call_kwargs["request_index"] == 0 + assert call_kwargs["queue"] is None diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index b78d62d9eda..816d87e592a 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -6,6 +6,7 @@ from inspect import Signature, signature from pathlib import Path from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock import numpy as np import pytest @@ -1929,6 +1930,48 @@ def fish_speech_server(mocker: MockerFixture): server.shutdown() +@pytest.fixture +def ming_speech_server(mocker: MockerFixture): + mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value={"灵小甄"}) + mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None) + + mock_engine_client = mocker.MagicMock() + mock_engine_client.errored = False + mock_engine_client.model_config = mocker.MagicMock(model="inclusionAI/Ming-omni-tts-0.5B") + mock_engine_client.default_sampling_params_list = [ + SimpleNamespace(max_tokens=512, stop_token_ids=[]), + SimpleNamespace(max_tokens=1, stop_token_ids=[]), + ] + mock_engine_client.tts_batch_max_items = 32 + mock_engine_client.generate = mocker.MagicMock(return_value="generator") + mock_engine_client.stage_configs = [ + SimpleNamespace( + engine_args=SimpleNamespace( + model_stage="llm", + model_arch="MingTTSForConditionalGeneration", + worker_type="ar", + ), + tts_args={}, + ) + ] + + mock_models = mocker.MagicMock() + mock_models.is_base_model.return_value = True + + server = OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=mocker.MagicMock(), + ) + server._build_ming_prompt = MagicMock( + return_value={ + "prompt_token_ids": [1, 2, 3], + "additional_information": {}, + } + ) + return server + + class TestFishSpeechServing: def 
test_build_fish_prompt_normalizes_legacy_speaker_tags(self, fish_speech_server): tokenizer = _FakeFishTokenizer() @@ -2065,6 +2108,341 @@ def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_serve fish_speech_server._generate_audio_bytes.assert_awaited_once() +class TestMingSpeechServing: + class _FakeMingTokenizer: + def __init__(self): + self._token_to_id = { + "": 9001, + "<|vision_start|>": 9002, + "<|vision_pad|>": 9003, + "<|vision_end|>\n": 9004, + } + self._next = 100 + + def encode(self, text): + if text not in self._token_to_id: + self._token_to_id[text] = self._next + self._next += 1 + return [self._token_to_id[text]] + + def convert_tokens_to_ids(self, token): + if token not in self._token_to_id: + self._token_to_id[token] = self._next + self._next += 1 + return self._token_to_id[token] + + def test_protocol_accepts_ming_podcast_ref_audio_and_nested_embeddings(self): + request = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ref_text=" speaker_1:参考一。\n speaker_2:参考二。\n", + speaker_embedding=[[0.1] * 192, [0.2] * 192], + ) + + assert request.ref_audio == ["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"] + assert request.speaker_embedding == [[0.1] * 192, [0.2] * 192] + + def test_protocol_preserves_single_ming_ref_audio_and_flat_embedding(self): + single_ref = OpenAICreateSpeechRequest( + input="Hello", + ref_audio="data:audio/wav;base64,aaa", + ref_text="reference", + ) + single_embedding = OpenAICreateSpeechRequest( + input="Hello", + speaker_embedding=[0.1] * 192, + ) + + assert single_ref.ref_audio == "data:audio/wav;base64,aaa" + assert single_embedding.speaker_embedding == [0.1] * 192 + + def test_validate_ming_podcast_rules(self, ming_speech_server): + valid = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ref_text=" 
speaker_1:参考一。\n speaker_2:参考二。\n", + ) + one_clip = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n", + ref_audio=["data:audio/wav;base64,aaa"], + ref_text=" speaker_1:参考一。\n", + ) + missing_ref_text = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ) + mismatched_embeddings = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ref_text=" speaker_1:参考一。\n speaker_2:参考二。\n", + speaker_embedding=[[0.1] * 192], + ) + + assert ming_speech_server._validate_ming_tts_request(valid) is None + assert "at least two" in ming_speech_server._validate_ming_tts_request(one_clip) + assert "ref_text" in ming_speech_server._validate_ming_tts_request(missing_ref_text) + assert "one speaker embedding per ref_audio" in ming_speech_server._validate_ming_tts_request( + mismatched_embeddings + ) + + def test_validate_ming_single_speaker_clone_still_accepts_existing_shape(self, ming_speech_server): + request = OpenAICreateSpeechRequest( + input="Hello", + ref_audio="data:audio/wav;base64,aaa", + ref_text="reference text", + ) + + assert ming_speech_server._validate_ming_tts_request(request) is None + + def test_resolve_ref_audio_many_preserves_order(self, ming_speech_server): + ming_speech_server._resolve_ref_audio = AsyncMock( + side_effect=[ + ([0.1, 0.2], 24000), + ([0.3, 0.4], 44100), + ] + ) + + resolved = asyncio.run( + ming_speech_server._resolve_ref_audio_many(["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"]) + ) + + assert resolved == [([0.1, 0.2], 24000), ([0.3, 0.4], 44100)] + ming_speech_server._resolve_ref_audio.assert_any_await("data:audio/wav;base64,aaa") + ming_speech_server._resolve_ref_audio.assert_any_await("data:audio/wav;base64,bbb") + + def test_extract_ming_speaker_embeddings_uses_one_call_per_wav(self, ming_speech_server, mocker: MockerFixture): + calls = [] + + 
class _FakeExtractor: + def __init__(self, model, target_sr=16000): + self.model = model + self.target_sr = target_sr + + def extract_from_waveform(self, waveform, sample_rate): + calls.append( + { + "model": self.model, + "target_sr": self.target_sr, + "shape": tuple(waveform.shape), + "sample_rate": int(sample_rate), + } + ) + return torch.full((192,), float(len(calls)), dtype=torch.float32) + + mocker.patch( + "vllm_omni.model_executor.models.ming_tts.speaker_extractor.MingSpeakerEmbeddingExtractor", + _FakeExtractor, + ) + + embeddings = ming_speech_server._extract_ming_speaker_embeddings_from_ref_audio( + [ + ([0.1, 0.2], 22050), + ([0.3, 0.4, 0.5], 44100), + ] + ) + + assert len(embeddings) == 2 + assert embeddings[0] == [1.0] * 192 + assert embeddings[1] == [2.0] * 192 + assert calls == [ + { + "model": "inclusionAI/Ming-omni-tts-0.5B", + "target_sr": 16000, + "shape": (1, 2), + "sample_rate": 22050, + }, + { + "model": "inclusionAI/Ming-omni-tts-0.5B", + "target_sr": 16000, + "shape": (1, 3), + "sample_rate": 44100, + }, + ] + + def test_build_ming_prompt_handles_multi_speaker_podcast_inputs(self, ming_speech_server): + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import KEY_SPEAKER_EMBEDDING + + ming_speech_server._tts_tokenizer = self._FakeMingTokenizer() + request = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ref_text=" speaker_1:参考一。\n speaker_2:参考二。\n", + speaker_embedding=[[0.1] * 192, [0.2] * 192], + ) + + prompt = OmniOpenAIServingSpeech._build_ming_prompt( + ming_speech_server, + request, + ref_audio_data=[ + ([0.1] * 10, 44100), + ([0.2] * 20, 44100), + ], + ) + + info = prompt["additional_information"] + assert tuple(info[KEY_SPEAKER_EMBEDDING].shape) == (2, 192) + assert int(info["prompt_waveform_length"].item()) >= 30 + assert info["prompt_text"] == " speaker_1:参考一。\n speaker_2:参考二。\n" + assert ( + 
prompt["prompt_token_ids"].count( + ming_speech_server._tts_tokenizer.convert_tokens_to_ids("<|vision_start|>") + ) + == 2 + ) + + def test_build_ming_prompt_concatenates_podcast_waveforms_before_builder( + self, ming_speech_server, mocker: MockerFixture + ): + captured = {} + + def _fake_build_ming_dense_prompt(*args, **kwargs): + captured.update(kwargs) + return {"prompt_token_ids": [1], "additional_information": {}} + + mocker.patch( + "vllm_omni.model_executor.models.ming_tts.prompt_builder.build_ming_dense_prompt", + side_effect=_fake_build_ming_dense_prompt, + ) + ming_speech_server._tts_tokenizer = object() + request = OpenAICreateSpeechRequest( + input=" speaker_1:你好。\n speaker_2:你好。\n", + ref_audio=["data:audio/wav;base64,aaa", "data:audio/wav;base64,bbb"], + ref_text=" speaker_1:参考一。\n speaker_2:参考二。\n", + speaker_embedding=[[0.1] * 192, [0.2] * 192], + ) + + OmniOpenAIServingSpeech._build_ming_prompt( + ming_speech_server, + request, + ref_audio_data=[ + ([0.1] * 10, 44100), + ([0.2] * 20, 44100), + ], + ) + + assert tuple(captured["prompt_waveform"].shape) == (1, 30) + assert captured["speaker_embedding"] == [[0.1] * 192, [0.2] * 192] + assert captured["prompt_text"] == " speaker_1:参考一。\n speaker_2:参考二。\n" + + def test_build_ming_prompt_uses_single_ref_audio_as_speaker_only_without_ref_text( + self, ming_speech_server, mocker: MockerFixture + ): + captured = {} + + def _fake_build_ming_dense_prompt(*args, **kwargs): + captured.update(kwargs) + return {"prompt_token_ids": [1], "additional_information": {}} + + mocker.patch( + "vllm_omni.model_executor.models.ming_tts.prompt_builder.build_ming_dense_prompt", + side_effect=_fake_build_ming_dense_prompt, + ) + ming_speech_server._tts_tokenizer = object() + request = OpenAICreateSpeechRequest( + input="我竟然抢到了陈奕迅的演唱会门票!", + ref_audio="data:audio/wav;base64,aaa", + speaker_embedding=[0.1] * 192, + instructions='{"情感":"高兴"}', + ) + + OmniOpenAIServingSpeech._build_ming_prompt( + ming_speech_server, + request, + 
ref_audio_data=([0.1] * 10, 44100), + ) + + assert captured["prompt_waveform"] is None + assert captured["prompt_text"] is None + assert captured["speaker_embedding"] == [0.1] * 192 + + def test_build_ming_prompt_keeps_single_ref_audio_waveform_with_ref_text( + self, ming_speech_server, mocker: MockerFixture + ): + captured = {} + + def _fake_build_ming_dense_prompt(*args, **kwargs): + captured.update(kwargs) + return {"prompt_token_ids": [1], "additional_information": {}} + + mocker.patch( + "vllm_omni.model_executor.models.ming_tts.prompt_builder.build_ming_dense_prompt", + side_effect=_fake_build_ming_dense_prompt, + ) + ming_speech_server._tts_tokenizer = object() + request = OpenAICreateSpeechRequest( + input="我们的愿景是构建未来服务业的数字化基础设施。", + ref_audio="data:audio/wav;base64,aaa", + ref_text="在此奉劝大家别乱打美白针。", + speaker_embedding=[0.1] * 192, + ) + + OmniOpenAIServingSpeech._build_ming_prompt( + ming_speech_server, + request, + ref_audio_data=([0.1] * 10, 44100), + ) + + assert tuple(captured["prompt_waveform"].shape) == (1, 10) + assert captured["prompt_text"] == "在此奉劝大家别乱打美白针。" + assert captured["speaker_embedding"] == [0.1] * 192 + + def test_prepare_speech_generation_sets_ming_stop_token(self, ming_speech_server): + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import TEXT_EOS_TOKEN_ID + + request = OpenAICreateSpeechRequest( + input="这款产品的名字,叫变态坑爹牛肉丸。", + voice="灵小甄", + ) + + request_id, generator, _ = asyncio.run(ming_speech_server._prepare_speech_generation(request)) + + assert request_id.startswith("speech-") + assert generator == "generator" + sampling_params_list = ming_speech_server.engine_client.generate.call_args.kwargs["sampling_params_list"] + assert sampling_params_list[0].stop_token_ids == [int(TEXT_EOS_TOKEN_ID)] + assert sampling_params_list[0].max_tokens == 512 + assert ming_speech_server.engine_client.default_sampling_params_list[0].stop_token_ids == [] + assert 
ming_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 512 + + def test_prepare_speech_generation_overrides_ming_stage_max_tokens(self, ming_speech_server): + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import TEXT_EOS_TOKEN_ID + + request = OpenAICreateSpeechRequest( + input="这款产品的名字,叫变态坑爹牛肉丸。", + voice="灵小甄", + max_new_tokens=16, + ) + + request_id, generator, _ = asyncio.run(ming_speech_server._prepare_speech_generation(request)) + + assert request_id.startswith("speech-") + assert generator == "generator" + sampling_params_list = ming_speech_server.engine_client.generate.call_args.kwargs["sampling_params_list"] + assert sampling_params_list[0].stop_token_ids == [int(TEXT_EOS_TOKEN_ID)] + assert sampling_params_list[0].max_tokens == 17 + assert ming_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 512 + + def test_prepare_speech_generation_extracts_ming_single_ref_audio_speaker_embedding( + self, ming_speech_server, mocker: MockerFixture + ): + request = OpenAICreateSpeechRequest( + input="我竟然抢到了陈奕迅的演唱会门票!", + ref_audio="data:audio/wav;base64,aaa", + instructions='{"情感":"高兴"}', + ) + ming_speech_server._resolve_ref_audio = AsyncMock(return_value=([0.1, 0.2], 44100)) + ming_speech_server._extract_ming_speaker_embeddings_from_ref_audio = mocker.MagicMock( + return_value=[[0.3] * 192] + ) + + asyncio.run(ming_speech_server._prepare_speech_generation(request)) + + ming_speech_server._extract_ming_speaker_embeddings_from_ref_audio.assert_called_once_with( + [([0.1, 0.2], 44100)] + ) + assert request.speaker_embedding == [0.3] * 192 + + class TestWAVStreaming: """Integration tests for WAV format streaming.""" diff --git a/tests/model_executor/models/ming_tts/test_ming_tts_components.py b/tests/model_executor/models/ming_tts/test_ming_tts_components.py new file mode 100644 index 00000000000..14c4c02db05 --- /dev/null +++ b/tests/model_executor/models/ming_tts/test_ming_tts_components.py @@ -0,0 +1,505 
@@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.model_executor.models.ming_tts.audio_tokenizer.configuration_audio_vae import AudioVAEconfig +from vllm_omni.model_executor.models.ming_tts.audio_tokenizer.istft import ISTFT, ISTFTHead +from vllm_omni.model_executor.models.ming_tts.audio_tokenizer.modeling_audio_vae import AudioVAE +from vllm_omni.model_executor.models.ming_tts.audio_tokenizer.vae_modules import StreamingLinearUpsample +from vllm_omni.model_executor.models.ming_tts.fm.cfm import CFM, Solver, get_epss_timesteps +from vllm_omni.model_executor.models.ming_tts.fm.dit import ( + Aggregator, + CondEmbedder, + DiT, + SinusPositionEmbedding, + TimestepEmbedder, +) +from vllm_omni.model_executor.models.ming_tts.fm.flowloss import FlowLoss +from vllm_omni.model_executor.models.ming_tts.fm.modules import Attention, DiTBlock, RMSNorm +from vllm_omni.model_executor.models.ming_tts.ming_tts import ( + _coerce_prompt_latents, + _find_audio_placeholder_positions, + _initial_history, +) +from vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae import _coerce_finished, _coerce_latent_chunk +from vllm_omni.model_executor.models.ming_tts.ming_tts_llm import _coerce_latent_history +from vllm_omni.model_executor.stage_input_processors.ming_tts import llm2audio_vae_async_chunk + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _tiny_qwen_config(hidden_size=8): + return { + "hidden_size": hidden_size, + "intermediate_size": hidden_size * 2, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "vocab_size": 32, + "max_position_embeddings": 64, + } + + +def _tiny_audio_vae_config(): + return AudioVAEconfig( + sample_rate=16000, + patch_size=2, + enc_kwargs={ + "backbone": _tiny_qwen_config(), + "input_dim": 4, + "hop_size": 4, + "latent_dim": 2, + }, + 
dec_kwargs={ + "backbone": _tiny_qwen_config(), + "output_dim": 4, + "latent_dim": 2, + }, + semantic_module_kwargs=None, + ) + + +class _DummyCFMModel(nn.Module): + def __init__(self): + super().__init__() + self.anchor = nn.Parameter(torch.zeros(())) + + def forward(self, x, t, c, latent_history, mask=None): + del t, c, latent_history + if mask is not None: + x = x.masked_fill(~mask.unsqueeze(-1), 0.0) + return x + + def forward_with_cfg(self, x, t, c, cfg_scale, latent_history, patch_size): + del t, c, cfg_scale, latent_history + cond = x[:, -patch_size:, :] + 1.0 + uncond = x[:, -patch_size:, :] + return torch.cat([cond, uncond], dim=0) + + +def test_rmsnorm_preserves_shape_and_dtype(): + norm = RMSNorm(dim=8, eps=1e-6) + x = torch.randn(2, 3, 8, dtype=torch.float32) + + out = norm(x) + + assert out.shape == x.shape + assert out.dtype == x.dtype + + +def test_attention_forward_shape_and_mask(): + attn = Attention(dim=8, heads=2, dim_head=4, dropout=0.0) + x = torch.randn(1, 5, 8) + mask = torch.tensor([[True, True, True, True, False]]) + + out = attn(x, mask=mask) + + assert out.shape == x.shape + assert torch.allclose(out[:, -1], torch.zeros_like(out[:, -1])) + + +def test_attention_rejects_bad_mask_shape(): + attn = Attention(dim=8, heads=2, dim_head=4, dropout=0.0) + x = torch.randn(1, 5, 8) + + with pytest.raises(ValueError, match="Mask shape mismatch"): + attn(x, mask=torch.ones(1, 4, dtype=torch.bool)) + + +def test_dit_block_forward_shape(): + block = DiTBlock(hidden_size=8, num_heads=2, mlp_ratio=2.0, dropout=0.0) + x = torch.randn(1, 5, 8) + mask = torch.ones(1, 5, dtype=torch.bool) + + out = block(x, mask, rope=None) + + assert out.shape == x.shape + + +def test_sinus_position_embedding_shape(): + embed = SinusPositionEmbedding(dim=8) + t = torch.tensor([0.0, 1.0], dtype=torch.float32) + + out = embed(t) + + assert out.shape == (2, 8) + + +def test_timestep_embedder_distinguishes_steps(): + embedder = TimestepEmbedder(dim=8, freq_embed_dim=8) + + 
out_a = embedder(torch.tensor([0.0], dtype=torch.float32)) + out_b = embedder(torch.tensor([1.0], dtype=torch.float32)) + + assert out_a.shape == (1, 8) + assert not torch.allclose(out_a, out_b) + + +def test_cond_embedder_rejects_bad_rank(): + embedder = CondEmbedder(input_feature_size=4, hidden_size=8, dropout_prob=0.0) + + with pytest.raises(ValueError, match="rank-3"): + embedder(torch.randn(1, 4), train=False) + + +def test_cond_drop_preserves_conditioning_dtype(): + embedder = CondEmbedder(input_feature_size=4, hidden_size=8, dropout_prob=1.0) + llm_cond = torch.randn(1, 1, 4, dtype=torch.float16) + + out = embedder.cond_drop(llm_cond) + + assert out.dtype == llm_cond.dtype + + +def test_dit_forward_shape(): + model = DiT( + in_channels=2, + hidden_size=8, + depth=1, + num_heads=2, + mlp_ratio=2.0, + llm_cond_dim=4, + cfg_dropout_prob=0.0, + ) + x = torch.randn(1, 2, 2) + latent_history = torch.randn(1, 4, 2) + c = torch.randn(1, 1, 4) + mask = torch.ones(1, 2, dtype=torch.bool) + + out = model(x=x, t=torch.tensor([0.5]), c=c, latent_history=latent_history, mask=mask) + + assert out.shape == (1, 7, 2) + + +def test_dit_forward_with_cfg_preserves_conditioning_dtype(monkeypatch): + model = DiT( + in_channels=2, + hidden_size=8, + depth=1, + num_heads=2, + mlp_ratio=2.0, + llm_cond_dim=4, + cfg_dropout_prob=0.0, + ) + seen = {} + + def _fake_forward(x, t, c, latent_history, mask=None): + del x, t, latent_history, mask + seen["dtype"] = c.dtype + return torch.zeros((c.shape[0], 7, 2), dtype=torch.float32) + + monkeypatch.setattr(model, "forward", _fake_forward) + x = torch.randn(1, 2, 2, dtype=torch.float16) + latent_history = torch.randn(1, 4, 2, dtype=torch.float16) + c = torch.randn(1, 1, 4, dtype=torch.float16) + + model.forward_with_cfg( + x=x, + t=torch.tensor([0.5], dtype=torch.float16), + c=c, + cfg_scale=2.0, + latent_history=latent_history, + patch_size=2, + ) + + assert seen["dtype"] == c.dtype + + +def test_aggregator_forward_shape(): + agg = 
Aggregator( + in_channels=2, + hidden_size=8, + depth=1, + num_heads=2, + mlp_ratio=2.0, + llm_input_dim=4, + ) + x = torch.randn(2, 3, 2) + mask = torch.ones(2, 3, dtype=torch.bool) + + out = agg(x, mask=mask) + + assert out.shape == (2, 1, 4) + + +def test_get_epss_timesteps_predefined_and_fallback(): + predefined = get_epss_timesteps(10, device=torch.device("cpu"), dtype=torch.float32) + fallback = get_epss_timesteps(9, device=torch.device("cpu"), dtype=torch.float32) + + assert predefined.shape == (11,) + assert torch.allclose(predefined[-1], torch.tensor(1.0)) + assert fallback.shape == (10,) + assert torch.allclose(fallback, torch.linspace(0, 1, 10)) + + +def test_solver_integrate_zero_function_is_stable(): + y0 = torch.ones(1, 2, 2) + solver = Solver(lambda t, y: torch.zeros_like(y), y0=y0, sigma=0.0, temperature=0.0) + t = torch.linspace(0, 1, 4) + + out = solver.integrate(t) + + assert out.shape == (4, 1, 2, 2) + assert torch.allclose(out[0], y0) + assert torch.allclose(out[-1], y0) + + +def test_cfm_forward_returns_scalar_loss(): + torch.manual_seed(0) + cfm = CFM(model=_DummyCFMModel()) + cond = torch.randn(1, 1, 4) + target = torch.randn(1, 2, 2) + latent_history = torch.randn(1, 4, 2) + mask = torch.ones(1, 2, dtype=torch.bool) + + loss = cfm(cond=cond, target=target, latent_history=latent_history, mask=mask, patch_size=2) + + assert loss.ndim == 0 + assert torch.isfinite(loss) + + +def test_cfm_sample_returns_sample_and_trajectory(): + torch.manual_seed(0) + cfm = CFM(model=_DummyCFMModel()) + noise = torch.randn(1, 2, 2) + cond = torch.randn(1, 1, 4) + latent_history = torch.randn(1, 4, 2) + + out, trajectory = cfm.sample(noise=noise, c=cond, latent_history=latent_history, steps=4, patch_size=2) + + assert out.shape == (1, 2, 2) + assert trajectory.shape == (5, 1, 2, 2) + + +def test_cfm_sample_rejects_low_cfg_scale(): + cfm = CFM(model=_DummyCFMModel()) + noise = torch.randn(1, 2, 2) + cond = torch.randn(1, 1, 4) + latent_history = torch.randn(1, 4, 
2) + + out, trajectory = cfm.sample( + noise=noise, + c=cond, + latent_history=latent_history, + cfg_scale=0.0, + patch_size=2, + ) + + assert out.shape == (1, 2, 2) + assert trajectory.ndim == 4 + + +def test_flowloss_sample_returns_tensor_shape_and_dtype(monkeypatch): + flow = FlowLoss( + z_channels=2, + llm_cond_dim=4, + hidden_size=8, + depth=1, + num_heads=2, + mlp_ratio=2.0, + cfg_dropout_prob=0.0, + ) + + def _fake_sample(**kwargs): + noise = kwargs["noise"] + return noise.transpose(1, 2), torch.zeros(1) + + monkeypatch.setattr(flow.cfm, "sample", _fake_sample) + z = torch.randn(1, 1, 4, dtype=torch.float32) + latent_history = torch.randn(1, 4, 2, dtype=torch.float32) + + out = flow.sample(z=z, latent_history=latent_history, patch_size=3) + + assert out.shape == (1, 3, 2) + assert out.dtype == z.dtype + + +def test_streaming_linear_upsample_rejects_empty_final_flush(): + upsample = StreamingLinearUpsample(scale_factor=2) + + with pytest.raises(ValueError, match="end-of-stream"): + upsample(None, state=None, is_last=True) + + +def test_streaming_linear_upsample_streams_and_flushes(): + upsample = StreamingLinearUpsample(scale_factor=2) + chunk_a = torch.randn(1, 2, 3) + chunk_b = torch.randn(1, 2, 3) + + out_a, state = upsample(chunk_a, state=None, is_last=False) + out_b, state = upsample(chunk_b, state=state, is_last=True) + + assert out_a is None + assert out_b is not None + assert out_b.shape[0] == 1 + assert out_b.shape[-1] == 3 + assert state is None + + +def test_istft_rejects_bad_rank(): + istft = ISTFT(n_fft=16, hop_length=4, win_length=16, padding="same") + + with pytest.raises(ValueError, match="rank-3"): + istft(torch.randn(1, 9)) + + +def test_istft_head_output_shape(): + head = ISTFTHead(dim=8, n_fft=16, hop_length=4, padding="same") + x = torch.randn(1, 3, 8) + + audio, spec, audio_buffer, window_buffer = head(x) + + assert audio.shape[0] == 1 + assert audio.shape[1] == 1 + assert spec.shape == (1, 18, 3) + assert audio_buffer is None + assert 
window_buffer is None + + +def test_audio_vae_encode_and_decode_shapes(): + torch.manual_seed(0) + vae = AudioVAE(_tiny_audio_vae_config()) + waveform = torch.randn(1, 12) + waveform_length = torch.tensor([12], dtype=torch.int32) + + latent, frame_num = vae.encode_latent(waveform, waveform_length) + audio, stream_state, past_key_values = vae.decode(latent, use_cache=False) + + assert latent.ndim == 3 + assert latent.shape[0] == 1 + assert latent.shape[-1] == 2 + assert frame_num.tolist() == [2] + assert audio.ndim == 3 + assert audio.shape[0] == 1 + assert audio.shape[1] == 1 + assert stream_state == (None, None, None) + assert past_key_values is None + + +def test_audio_vae_rejects_invalid_inputs(): + vae = AudioVAE(_tiny_audio_vae_config()) + + with pytest.raises(ValueError, match="waveform rank-2"): + vae.encode_latent(torch.randn(12), torch.tensor([12], dtype=torch.int32)) + + with pytest.raises(ValueError, match="Latent dim mismatch"): + vae.decode(torch.randn(1, 2, 3)) + + +def test_coerce_prompt_latents_supports_frames_and_patch_groups(): + frames = torch.arange(8, dtype=torch.float32).reshape(4, 2) + patches = torch.arange(16, dtype=torch.float32).reshape(2, 2, 4) + + out_frames = _coerce_prompt_latents(frames, patch_size=2, latent_dim=2) + out_patches = _coerce_prompt_latents(patches, patch_size=2, latent_dim=4) + + assert out_frames["patches"].shape == (2, 2, 2) + assert out_frames["frames"].shape == (4, 2) + assert out_patches["patches"].shape == (2, 2, 4) + assert out_patches["frames"].shape == (4, 4) + + +def test_initial_history_keeps_tail(): + frames = torch.arange(12, dtype=torch.float32).reshape(6, 2) + + history = _initial_history( + frames, + history_size=4, + latent_dim=2, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + assert history.shape == (4, 2) + assert torch.allclose(history, frames[-4:]) + + +def test_find_audio_placeholder_positions_uses_audio_span(): + cfg = SimpleNamespace( + audio_dummy_token_id=151705, + 
audio_start_token_id=151706, + audio_end_token_id=151707, + ) + input_ids = torch.tensor([151705, 1, 151706, 151705, 151705, 151707, 151705], dtype=torch.long) + + out = _find_audio_placeholder_positions(input_ids, cfg) + + assert out.tolist() == [3, 4] + + +def test_helper_coercions_fail_loudly(): + cfg = SimpleNamespace(history_patch_size=4, latent_dim=2) + + assert _coerce_finished(torch.tensor([1], dtype=torch.bool)) is True + latent_chunk = _coerce_latent_chunk( + torch.ones(4, 2), + device=torch.device("cpu"), + dtype=torch.float32, + latent_dim=2, + patch_size=4, + ) + assert latent_chunk.shape == (1, 4, 2) + + grouped_chunk = _coerce_latent_chunk( + torch.ones(2, 4, 2), + device=torch.device("cpu"), + dtype=torch.float32, + latent_dim=2, + patch_size=4, + ) + assert grouped_chunk.shape == (2, 4, 2) + + with pytest.raises(RuntimeError, match="latent_history shape mismatch"): + _coerce_latent_history(torch.ones(3, 2), device=torch.device("cpu"), dtype=torch.float32, cfg=cfg) + + with pytest.raises(ValueError, match="Latent patch size mismatch"): + _coerce_latent_chunk( + torch.ones(1, 3, 2), + device=torch.device("cpu"), + dtype=torch.float32, + latent_dim=2, + patch_size=4, + ) + + with pytest.raises(ValueError, match="Latent dim mismatch"): + _coerce_latent_chunk( + torch.ones(4, 3), + device=torch.device("cpu"), + dtype=torch.float32, + latent_dim=2, + patch_size=4, + ) + + +def test_ming_async_chunk_rejects_left_context_replay(): + transfer_manager = SimpleNamespace( + connector=SimpleNamespace(config={"extra": {"latent_chunk_size": 10, "latent_left_context": 1}}), + put_req_chunk={"req-1": 0}, + request_payload={}, + ) + request = SimpleNamespace(external_req_id="req-1", is_finished=lambda: False) + + with pytest.raises(ValueError, match="latent_left_context replay"): + llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=None, + request=request, + is_finished=False, + ) + + +def 
test_coerce_latent_history_casts_to_requested_dtype(): + cfg = SimpleNamespace(history_patch_size=4, latent_dim=2) + + history = _coerce_latent_history( + torch.ones(1, 4, 2, dtype=torch.float16), + device=torch.device("cpu"), + dtype=torch.float32, + cfg=cfg, + ) + + assert history.dtype == torch.float32 diff --git a/tests/model_executor/models/ming_tts/test_ming_tts_config_shim.py b/tests/model_executor/models/ming_tts/test_ming_tts_config_shim.py new file mode 100644 index 00000000000..06cd4a8a787 --- /dev/null +++ b/tests/model_executor/models/ming_tts/test_ming_tts_config_shim.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import AutoConfig + +from vllm_omni.engine.arg_utils import _register_omni_hf_configs +from vllm_omni.model_executor.models.ming_tts.configuration_ming_dense import MingDenseConfig + + +def test_ming_dense_autoconfig_registration_uses_local_config(tmp_path): + _register_omni_hf_configs() + model_dir = tmp_path / "ming" + model_dir.mkdir() + (model_dir / "config.json").write_text( + """ +{ + "model_type": "dense", + "auto_map": {"AutoConfig": "configuration_bailingmm.BailingMMConfig"}, + "llm_config": { + "model_type": "qwen2", + "hidden_size": 896, + "intermediate_size": 4864, + "num_hidden_layers": 24, + "num_attention_heads": 14, + "num_key_value_heads": 2, + "vocab_size": 151936 + }, + "audio_tokenizer_config": { + "sample_rate": 44100, + "patch_size": 4, + "enc_kwargs": { + "latent_dim": 64, + "input_dim": 882, + "hop_size": 882, + "backbone": {"attn_implementation": "flash_attention_2"} + }, + "dec_kwargs": { + "latent_dim": 64, + "output_dim": 882, + "backbone": {"_attn_implementation": "flash_attention_2"} + } + } +} +""".strip() + ) + + cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=False, local_files_only=True) + + assert isinstance(cfg, MingDenseConfig) + assert cfg.get_text_config().num_attention_heads == 14 + assert 
cfg.audio_tokenizer_config.sample_rate == 44100 + assert cfg.audio_tokenizer_config.patch_size == 4 diff --git a/tests/model_executor/models/ming_tts/test_ming_tts_loaders.py b/tests/model_executor/models/ming_tts/test_ming_tts_loaders.py new file mode 100644 index 00000000000..b7f95469bf4 --- /dev/null +++ b/tests/model_executor/models/ming_tts/test_ming_tts_loaders.py @@ -0,0 +1,524 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + +import pytest +import torch +from vllm.v1.outputs import SamplerOutput + +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import KEY_PROMPT_LATENTS, KEY_REQUEST_ID, MingTTSConfig +from vllm_omni.model_executor.models.ming_tts.ming_tts import MingTTSForConditionalGeneration +from vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae import MingAudioVAEModel +from vllm_omni.model_executor.models.ming_tts.ming_tts_llm import MingLLMModel + + +class _DummyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Module() + self.model.layers = torch.nn.ModuleList([torch.nn.Linear(2, 2, bias=False)]) + self.last_forward_kwargs = None + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return torch.zeros((input_ids.shape[0], 2), dtype=torch.float32) + + def forward(self, *args, **kwargs): + del args + self.last_forward_kwargs = dict(kwargs) + return torch.zeros((1, 2), dtype=torch.float32) + + +class _DummyAggregator(torch.nn.Module): + def __init__(self, in_channels: int, llm_input_dim: int, **kwargs): + super().__init__() + del kwargs + self.proj_in = torch.nn.Linear(in_channels, llm_input_dim, bias=False) + + def forward(self, patch: torch.Tensor) -> torch.Tensor: + return self.proj_in(patch.mean(dim=1)).unsqueeze(1) + + +class _DummyFlowLoss(torch.nn.Module): + def __init__(self, z_channels: int, llm_cond_dim: int, **kwargs): + super().__init__() + del z_channels, 
kwargs + self.dummy = torch.nn.Linear(llm_cond_dim, 64, bias=False) + + def sample(self, **kwargs): + del kwargs + return torch.zeros((1, 4, 64), dtype=torch.float32) + + +class _DummyAudioVAE(torch.nn.Module): + def __init__(self, config): + super().__init__() + del config + self.encoder = torch.nn.Linear(2, 2, bias=False) + self.decoder = torch.nn.Linear(2, 2, bias=False) + self.last_chunk_values = [] + + def encode_latent(self, waveform: torch.Tensor, waveform_length: torch.Tensor): + del waveform_length + batch = int(waveform.shape[0]) + return torch.zeros((batch, 8, 64), dtype=torch.float32), None + + def decode( + self, + latent_patch: torch.Tensor, + *, + past_key_values=None, + use_cache=True, + stream_state=None, + last_chunk=False, + ): + del past_key_values, use_cache, stream_state + self.last_chunk_values.append(last_chunk) + samples = int(latent_patch.shape[1]) * 8 + waveform = torch.ones((1, 1, samples), dtype=torch.float32) + return waveform, (None, None, None), None + + +def _make_audio_cfg(): + return SimpleNamespace( + enc_kwargs={ + "backbone": {"hidden_size": 2}, + "input_dim": 882, + "hop_size": 882, + "latent_dim": 64, + }, + dec_kwargs={ + "backbone": {"hidden_size": 2}, + "output_dim": 882, + "latent_dim": 64, + }, + patch_size=4, + sample_rate=44100, + semantic_module_kwargs=None, + ) + + +def _make_config() -> MingTTSConfig: + cfg = MingTTSConfig(audio_tokenizer_config=_make_audio_cfg()) + cfg.validate() + return cfg + + +def _make_vllm_config(model_stage: str): + return SimpleNamespace( + model_config=SimpleNamespace(hf_config=SimpleNamespace(), model_stage=model_stage), + quant_config=None, + device_config=SimpleNamespace(device=torch.device("cpu")), + ) + + +def test_ming_llm_load_weights_maps_and_loads_expected_prefixes(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + 
monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: _DummyBackbone()) + monkeypatch.setattr(llm_mod, "Aggregator", _DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", _DummyFlowLoss) + + model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + weights = [ + ("model.model.layers.0.weight", torch.full((2, 2), 1.0, dtype=torch.float32)), + ("linear_proj_audio.proj_in.weight", torch.full((896, 64), 2.0, dtype=torch.float32)), + ("flowloss.dummy.weight", torch.full((64, 896), 3.0, dtype=torch.float32)), + ("stop_head.weight", torch.full((2, 896), 4.0, dtype=torch.float32)), + ("stop_head.bias", torch.full((2,), 5.0, dtype=torch.float32)), + ("spk_head.weight", torch.full((896, 192), 6.0, dtype=torch.float32)), + ("spk_head.bias", torch.full((896,), 7.0, dtype=torch.float32)), + ] + + loaded = model.load_weights(weights) + + assert "model.model.layers.0.weight" in loaded + assert "linear_proj_audio.proj_in.weight" in loaded + assert "flowloss.dummy.weight" in loaded + assert "stop_head.weight" in loaded + assert "spk_head.weight" in loaded + assert torch.allclose(model.model.model.layers[0].weight, torch.full((2, 2), 1.0)) + assert torch.allclose(model.linear_proj_audio.proj_in.weight, torch.full((896, 64), 2.0)) + assert torch.allclose(model.flowloss.dummy.weight, torch.full((64, 896), 3.0)) + + +def test_ming_llm_load_weights_accepts_complete_checkpoint_and_forward_shape(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: _DummyBackbone()) + monkeypatch.setattr(llm_mod, "Aggregator", _DummyAggregator) + 
monkeypatch.setattr(llm_mod, "FlowLoss", _DummyFlowLoss) + + model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + model.load_weights( + [ + ("model.layers.0.weight", torch.ones((2, 2), dtype=torch.float32)), + ("linear_proj_audio.proj_in.weight", torch.ones((896, 64), dtype=torch.float32)), + ("flowloss.dummy.weight", torch.ones((64, 896), dtype=torch.float32)), + ("stop_head.weight", torch.ones((2, 896), dtype=torch.float32)), + ("stop_head.bias", torch.ones((2,), dtype=torch.float32)), + ("spk_head.weight", torch.ones((896, 192), dtype=torch.float32)), + ("spk_head.bias", torch.ones((896,), dtype=torch.float32)), + ] + ) + + output = model.forward( + input_ids=torch.tensor([1], dtype=torch.long), + positions=torch.tensor([0], dtype=torch.long), + ) + + assert output.text_hidden_states.shape == (1, 2) + assert output.multimodal_outputs is None + + +def test_ming_llm_load_weights_fails_when_custom_heads_missing(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: _DummyBackbone()) + monkeypatch.setattr(llm_mod, "Aggregator", _DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", _DummyFlowLoss) + + model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + weights = [ + ("model.layers.0.weight", torch.full((2, 2), 1.0, dtype=torch.float32)), + ("stop_head.weight", torch.full((2, 896), 4.0, dtype=torch.float32)), + ("stop_head.bias", torch.full((2,), 5.0, dtype=torch.float32)), + ("spk_head.weight", torch.full((896, 192), 6.0, dtype=torch.float32)), + ("spk_head.bias", torch.full((896,), 7.0, dtype=torch.float32)), + ] + + with pytest.raises(RuntimeError, match="flowloss|linear_proj_audio"): + model.load_weights(weights) + + +def 
test_ming_llm_load_weights_rejects_incomplete_checkpoint(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: _DummyBackbone()) + monkeypatch.setattr(llm_mod, "Aggregator", _DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", _DummyFlowLoss) + + model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + + with pytest.raises(RuntimeError, match="flowloss|linear_proj_audio|stop_head|spk_head"): + model.load_weights( + [ + ("model.layers.0.weight", torch.ones((2, 2), dtype=torch.float32)), + ("stop_head.weight", torch.ones((2, 896), dtype=torch.float32)), + ("stop_head.bias", torch.ones((2,), dtype=torch.float32)), + ] + ) + + +def test_ming_audio_vae_load_weights_fails_when_audio_params_missing(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae as vae_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(vae_mod, "AudioVAE", _DummyAudioVAE) + + model = MingAudioVAEModel(vllm_config=_make_vllm_config("audio_vae")) + + with pytest.raises(RuntimeError, match="params not loaded"): + model.load_weights( + [ + ("audio.encoder.weight", torch.full((2, 2), 1.0, dtype=torch.float32)), + ] + ) + + +def test_ming_audio_vae_load_weights_accepts_complete_checkpoint_and_forward_shape(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae as vae_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: 
def test_ming_audio_vae_load_weights_rejects_incomplete_checkpoint(monkeypatch):
    """An encoder-only checkpoint must be refused by the audio VAE stage."""
    import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod
    import vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae as vae_mod

    stub_cfg = _make_config()
    monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: stub_cfg))
    monkeypatch.setattr(vae_mod, "AudioVAE", _DummyAudioVAE)

    vae_stage = MingAudioVAEModel(vllm_config=_make_vllm_config("audio_vae"))
    incomplete = [
        ("audio.encoder.weight", torch.ones((2, 2), dtype=torch.float32)),
    ]

    with pytest.raises(RuntimeError, match="params not loaded|no checkpoint weights"):
        vae_stage.load_weights(incomplete)
def test_ming_llm_forward_drops_runner_only_kwargs(monkeypatch):
    """Runner-only keyword arguments must not be forwarded to the backbone."""
    import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod
    import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod

    stub_cfg = _make_config()
    recording_backbone = _DummyBackbone()
    monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: stub_cfg))
    monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: recording_backbone)
    for attr_name, replacement in (("Aggregator", _DummyAggregator), ("FlowLoss", _DummyFlowLoss)):
        monkeypatch.setattr(llm_mod, attr_name, replacement)

    llm = MingLLMModel(vllm_config=_make_vllm_config("llm"))
    runner_only_kwargs = {
        "sampling_metadata": object(),
        "logits_index": 0,
        "sampler": object(),
        "additional_information": {"text": "hello"},
    }
    result = llm.forward(
        input_ids=torch.tensor([1], dtype=torch.long),
        positions=torch.tensor([0], dtype=torch.long),
        **runner_only_kwargs,
    )

    # Only the four backbone arguments may reach the wrapped model.
    assert set(recording_backbone.last_forward_kwargs) == {
        "input_ids",
        "positions",
        "intermediate_tensors",
        "inputs_embeds",
    }
    # Expected tensors are built fresh so the comparison never aliases the inputs.
    expected_tensors = (
        ("input_ids", torch.tensor([1], dtype=torch.long)),
        ("positions", torch.tensor([0], dtype=torch.long)),
    )
    for kwarg_name, expected in expected_tensors:
        assert torch.equal(recording_backbone.last_forward_kwargs[kwarg_name], expected)
    assert recording_backbone.last_forward_kwargs["intermediate_tensors"] is None
    assert torch.allclose(
        recording_backbone.last_forward_kwargs["inputs_embeds"],
        torch.zeros((1, 2), dtype=torch.float32),
    )
    assert result.text_hidden_states.shape == (1, 2)
    assert result.multimodal_outputs is None
"from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: backbone) + monkeypatch.setattr(llm_mod, "Aggregator", _DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", _DummyFlowLoss) + + model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + output = model.forward( + input_ids=torch.tensor([1], dtype=torch.long), + positions=torch.tensor([0], dtype=torch.long), + runtime_additional_information=[{"decode_step": 0}], + ) + + assert set(backbone.last_forward_kwargs) == { + "input_ids", + "positions", + "intermediate_tensors", + "inputs_embeds", + } + assert torch.equal(backbone.last_forward_kwargs["input_ids"], torch.tensor([1], dtype=torch.long)) + assert torch.equal(backbone.last_forward_kwargs["positions"], torch.tensor([0], dtype=torch.long)) + assert backbone.last_forward_kwargs["intermediate_tensors"] is None + assert torch.allclose(backbone.last_forward_kwargs["inputs_embeds"], torch.zeros((1, 2), dtype=torch.float32)) + assert output.text_hidden_states.shape == (1, 2) + assert output.multimodal_outputs is None + + +def test_ming_stage0_sampler_uses_model_sample(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts as ming_mod + + class _DummyStage0(torch.nn.Module): + def sample(self, logits, sampling_metadata): + del logits, sampling_metadata + return SamplerOutput( + sampled_token_ids=torch.tensor([[151705]], dtype=torch.int32), + logprobs_tensors=None, + ) + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(ming_mod, "init_vllm_registered_model", lambda **kwargs: _DummyStage0()) + + model = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("llm")) + sampler_output = model.sampler( + torch.zeros((1, cfg.llm_vocab_size), dtype=torch.float32), + SimpleNamespace(seq_groups=[]), + ) 
def test_ming_stage0_load_weights_does_not_load_audio_weights(monkeypatch):
    """Entries under the ``audio.`` prefix must not be passed to the stage-0 model."""
    import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod
    import vllm_omni.model_executor.models.ming_tts.ming_tts as ming_mod

    class _DummyStage0(torch.nn.Module):
        """Stage-0 stand-in that records whatever weights it is handed."""

        def __init__(self):
            super().__init__()
            self.loaded = None

        def load_weights(self, weights):
            self.loaded = list(weights)
            return {entry[0] for entry in self.loaded}

    stub_cfg = _make_config()
    recorder = _DummyStage0()
    monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: stub_cfg))
    monkeypatch.setattr(ming_mod, "init_vllm_registered_model", lambda **kwargs: recorder)

    wrapper = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("llm"))

    # (parameter name, shape); the last entry is the audio weight under test.
    tensor_specs = (
        ("model.layers.0.weight", (2, 2)),
        ("linear_proj_audio.proj_in.weight", (896, 64)),
        ("flowloss.dummy.weight", (64, 896)),
        ("stop_head.weight", (2, 896)),
        ("spk_head.weight", (896, 192)),
        ("audio.encoder.weight", (2, 2)),
    )
    reported = wrapper.load_weights(
        [(param_name, torch.ones(shape, dtype=torch.float32)) for param_name, shape in tensor_specs]
    )

    assert "model.audio.encoder.weight" not in reported
    assert all(not param_name.startswith("audio.") for param_name, _ in recorder.loaded)
    assert not hasattr(wrapper, "_prompt_audio_encoder")
def test_ming_resolve_prompt_latents_rejects_dual_truth_waveform_and_latents(monkeypatch):
    """Supplying both prompt latents and a raw prompt waveform must raise ValueError."""
    import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod
    import vllm_omni.model_executor.models.ming_tts.ming_tts as ming_mod

    class _DummyStage0(torch.nn.Module):
        def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
            token_count = input_ids.shape[0]
            return torch.zeros((token_count, 2), dtype=torch.float32)

    stub_cfg = _make_config()
    monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: stub_cfg))
    monkeypatch.setattr(ming_mod, "init_vllm_registered_model", lambda **kwargs: _DummyStage0())

    wrapper = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("llm"))

    # Both sources of reference audio are present at once.
    conflicting_info = {
        KEY_PROMPT_LATENTS: torch.ones((8, 64), dtype=torch.float32),
        "prompt_waveform": torch.ones((1, 1000), dtype=torch.float32),
        "prompt_waveform_length": torch.tensor([1000], dtype=torch.int32),
        "prompt_text": "Reference words.",
    }

    with pytest.raises(ValueError, match="Choose exactly one source of truth"):
        wrapper._resolve_prompt_latents(conflicting_info)
KEY_MAX_DECODE_STEPS, + KEY_MIN_DECODE_STEPS, + KEY_PROMPT_LATENTS, + KEY_SPEAKER_EMBEDDING, + PATCH_SIZE, + SAMPLE_RATE, +) +from vllm_omni.model_executor.models.ming_tts.ingress import MingIngressProcessor +from vllm_omni.model_executor.models.ming_tts.prompt_builder import ( + build_dense_prompt_token_ids, + build_ming_dense_prompt, + count_prompt_waveform_patches, + pad_prompt_waveform, +) + + +class _DummyTokenizer: + def __init__(self): + self._token_to_id = {"": 9001, "<|vision_start|>": 9002} + self._id_to_token = {token_id: token for token, token_id in self._token_to_id.items()} + self._next = 100 + + def encode(self, text): + if text not in self._token_to_id: + self._token_to_id[text] = self._next + self._id_to_token[self._next] = text + self._next += 1 + return [self._token_to_id[text]] + + def convert_tokens_to_ids(self, token): + if token not in self._token_to_id: + self._token_to_id[token] = self._next + self._id_to_token[self._next] = token + self._next += 1 + return self._token_to_id[token] + + def decode(self, token_ids): + return "".join(self._id_to_token[int(token_id)] for token_id in token_ids) + + +def _make_dummy_ingress_processor(tokenizer): + processor = MingIngressProcessor.__new__(MingIngressProcessor) + processor.tokenizer = tokenizer + processor.profile_ingress = False + processor.ming_config = SimpleNamespace(patch_size=4, latent_dim=64, vae_patch_size=4, audio_frame_hop=882) + return processor + + +def test_build_dense_prompt_token_ids_matches_ming_dense_layout(): + tokenizer = _DummyTokenizer() + + prompt_ids = build_dense_prompt_token_ids( + tokenizer, + prompt="Prompt text.", + text="Target text.", + instruction="instruction-json", + prompt_text="reference transcript", + speaker_count=2, + prompt_patch_count=3, + ) + + assert prompt_ids.count(tokenizer.convert_tokens_to_ids("")) == 3 + assert prompt_ids.count(tokenizer.convert_tokens_to_ids("<|vision_start|>")) == 2 + assert tokenizer.encode("instruction-json")[0] in prompt_ids + 
def test_build_ming_dense_prompt_pads_prompt_waveform_and_zero_speaker():
    """A 1000-sample waveform is padded to 14112 samples and a 192-value speaker embedding is set."""
    tok = _DummyTokenizer()
    reference_audio = torch.ones((1, 1000), dtype=torch.float32)

    built = build_ming_dense_prompt(
        tok,
        prompt="Please imitate the reference speech.",
        text="Hello world.",
        prompt_text="Reference words.",
        prompt_waveform=reference_audio,
        use_zero_spk_emb=True,
    )

    extras = built["additional_information"]

    assert extras["prompt_waveform"].shape == (1, 14112)
    assert int(extras[KEY_SPEAKER_EMBEDDING].numel()) == 192

    patch_token_id = tok.convert_tokens_to_ids("")
    patches_expected = count_prompt_waveform_patches(reference_audio)
    assert built["prompt_token_ids"].count(patch_token_id) == patches_expected
def test_build_ming_dense_prompt_uses_prompt_latents_to_set_patch_count():
    """The leading dimension of explicit prompt latents sets the placeholder count."""
    tok = _DummyTokenizer()
    latents = torch.ones((15, 4, 64), dtype=torch.float32)
    speaker = torch.ones((192,), dtype=torch.float32)

    built = build_ming_dense_prompt(
        tok,
        prompt="Please generate speech based on the following description.\n",
        text="Target text.",
        prompt_text="Reference words.",
        prompt_latents=latents,
        speaker_embedding=speaker,
    )

    stored_latents = built["additional_information"][KEY_PROMPT_LATENTS]
    assert torch.equal(stored_latents, latents)

    placeholder_hits = built["prompt_token_ids"].count(tok.convert_tokens_to_ids(""))
    assert placeholder_hits == 15
def test_build_ming_dense_prompt_rejects_prompt_waveform_without_prompt_text():
    """A prompt waveform passed without ``prompt_text`` must raise ValueError."""
    tok = _DummyTokenizer()
    reference_audio = torch.ones((1, 1000), dtype=torch.float32)

    # No prompt_text on purpose: that omission is what must be rejected.
    builder_kwargs = {
        "prompt": "Please generate speech based on the following description.\n",
        "text": "我竟然抢到了陈奕迅的演唱会门票!",
        "instruction": {"情感": "高兴"},
        "prompt_waveform": reference_audio,
    }

    with pytest.raises(ValueError, match="prompt_waveform requires prompt_text"):
        build_ming_dense_prompt(tok, **builder_kwargs)
def test_ming_ingress_processor_rebuilds_podcast_prompt_with_prompt_text_before_target_text():
    """For a two-speaker prompt the reference transcript must precede the target text."""
    tok = _DummyTokenizer()
    prefix = "Please generate speech based on the following description.\n"
    reference_script = " speaker_1:reference one\n speaker_2:reference two\n"
    target_script = " speaker_1:target one\n speaker_2:target two\n"
    speakers = torch.ones((2, 192), dtype=torch.float32)
    reference_audio = [torch.ones((1, sample_count), dtype=torch.float32) for sample_count in (1000, 2000)]

    built = build_ming_dense_prompt(
        tok,
        prompt=prefix,
        text=target_script,
        prompt_text=reference_script,
        prompt_waveform=reference_audio,
        speaker_embedding=speakers,
    )

    ingress = _make_dummy_ingress_processor(tok)
    finalized = ingress(built)
    token_ids = finalized["prompt_token_ids"]
    rendered = tok.decode(token_ids)
    patches_expected = count_prompt_waveform_patches(reference_audio)

    assert rendered.index(reference_script) < rendered.index(target_script)
    assert token_ids.count(tok.convert_tokens_to_ids("<|vision_start|>")) == 2
    assert token_ids.count(tok.convert_tokens_to_ids("")) == patches_expected

    extras = finalized["additional_information"]
    assert "prompt_waveform" in extras
    assert KEY_PROMPT_LATENTS not in extras
def test_pad_prompt_waveform_matches_upstream_ming_alignment():
    """A 3529-sample waveform pads to 14112 samples, a multiple of both alignment strides."""
    source = torch.ones((1, 3529), dtype=torch.float32)
    padded_length = int(pad_prompt_waveform(source).shape[-1])

    assert padded_length == 14112

    # Two ways of expressing the samples-per-patch stride; both must divide evenly.
    stride_from_rate = int((float(SAMPLE_RATE) / 12.5) * int(PATCH_SIZE))
    stride_from_hop = int(AUDIO_FRAME_HOP * PATCH_SIZE)
    assert padded_length % stride_from_rate == 0
    assert padded_length % stride_from_hop == 0
def test_build_ming_dense_prompt_preserves_explicit_decode_window_overrides():
    """Explicit min/max decode-step controls must be stored unchanged."""
    tok = _DummyTokenizer()
    overrides = {
        KEY_MIN_DECODE_STEPS: 11,
        KEY_MAX_DECODE_STEPS: 13,
    }

    built = build_ming_dense_prompt(
        tok,
        prompt="Please generate music based on the following description.\n",
        text=" Genre: electronic. Mood: confident. Instrument: drums. Theme: festival. Duration: 30s.",
        runtime_controls=dict(overrides),
    )

    extras = built["additional_information"]
    for control_key, expected_steps in overrides.items():
        assert int(extras[control_key].item()) == expected_steps
Duration: nope.", + runtime_controls={KEY_CFG: 2.0}, + ) + + for prompt in (prompt_missing, prompt_malformed): + info = prompt["additional_information"] + assert KEY_MIN_DECODE_STEPS not in info + assert KEY_MAX_DECODE_STEPS not in info diff --git a/tests/model_executor/stage_input_processors/test_ming_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_ming_tts_async_chunk.py new file mode 100644 index 00000000000..1a7acd04263 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_ming_tts_async_chunk.py @@ -0,0 +1,421 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import ( + KEY_REQUEST_ID, + LATENT_CHUNK_SIZE, + LATENT_LEFT_CONTEXT, + PATCH_SIZE, +) +from vllm_omni.model_executor.stage_input_processors.ming_tts import ( + MING_EMIT_PATCH_COUNT_KEY, + MING_ESTIMATED_BYTES_KEY, + MING_FINAL_DECODE_STEP_KEY, + MING_FINAL_FLUSH_KEY, + MING_LATENT_SHAPE_KEY, + MING_STOP_REASON_KEY, + _extract_last_patch, + llm2audio_vae, + llm2audio_vae_async_chunk, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +_LATENT_D = 64 + + +def _req(external_req_id: str, *, finished: bool): + return SimpleNamespace( + external_req_id=external_req_id, + is_finished=lambda: finished, + ) + + +def _manager(*, chunk_size: int | None = 2, left_context: int | None = 0, raw_config=None): + if raw_config is None: + extra = {} + if chunk_size is not None: + extra["latent_chunk_size"] = chunk_size + if left_context is not None: + extra["latent_left_context"] = left_context + raw_config = {"extra": extra} + return SimpleNamespace( + code_prompt_token_ids=defaultdict(list), + put_req_chunk=defaultdict(int), + request_payload={}, + connector=SimpleNamespace(config=raw_config), + ) + + +def _patch(fill: float) -> torch.Tensor: + 
def _payload(fill: float, *, has_patch=True, decode_step=None, stop_reason=None) -> dict[str, object]:
    """Assemble a fake stage-0 pooling output carrying one latent patch.

    ``ming_decode_step`` and the stop-reason entry are added only when the
    corresponding argument is not ``None``.
    """
    result = {}
    result["ming_has_patch"] = torch.tensor([has_patch])
    # Leading axis of size 1 added in front of the single patch.
    result["ming_latent_patch"] = _patch(fill).unsqueeze(0)

    if decode_step is not None:
        result["ming_decode_step"] = torch.tensor([decode_step], dtype=torch.int32)
    if stop_reason is not None:
        result[MING_STOP_REASON_KEY] = (stop_reason,)
    return result
transfer_manager = _manager(chunk_size=2) + request_id = "rid-full" + request = _req(request_id, finished=False) + transfer_manager.code_prompt_token_ids[request_id].append(_patch(1.0)) + + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(2.0), + request=request, + ) + + assert payload is not None + assert payload["finished"].item() is False + assert payload["stream_finished"].item() is False + assert payload[KEY_REQUEST_ID] == request_id + assert payload["code_predictor_codes"] == [0] + assert payload["ming_latent_patches"].shape == (2, PATCH_SIZE, _LATENT_D) + assert payload[MING_EMIT_PATCH_COUNT_KEY] == 2 + assert payload[MING_LATENT_SHAPE_KEY] == (2, PATCH_SIZE, _LATENT_D) + assert payload[MING_ESTIMATED_BYTES_KEY] == int( + payload["ming_latent_patches"].numel() * payload["ming_latent_patches"].element_size() + ) + assert payload[MING_ESTIMATED_BYTES_KEY] > 0 + assert payload[MING_FINAL_FLUSH_KEY] is False + assert torch.allclose(payload["ming_latent_patches"][0], _patch(1.0)) + assert torch.allclose(payload["ming_latent_patches"][1], _patch(2.0)) + assert transfer_manager.request_payload[request_id]["_ming_async_state"]["seen_patch_len"] == 2 + + +def test_llm2audio_vae_async_chunk_multi_request_interleaving_has_no_state_bleed(): + transfer_manager = _manager(chunk_size=2) + req_a = _req("rid-a", finished=False) + req_b = _req("rid-b", finished=False) + + assert ( + llm2audio_vae_async_chunk(transfer_manager=transfer_manager, pooling_output=_payload(1.0), request=req_a) + is None + ) + assert ( + llm2audio_vae_async_chunk(transfer_manager=transfer_manager, pooling_output=_payload(10.0), request=req_b) + is None + ) + + payload_a = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(2.0), + request=req_a, + ) + assert payload_a is not None + assert payload_a[KEY_REQUEST_ID] == "rid-a" + assert torch.allclose(payload_a["ming_latent_patches"][0], _patch(1.0)) + assert 
def test_llm2audio_vae_async_chunk_finish_after_full_chunk_only_emits_eof():
    """After a full chunk was emitted, the finishing call yields only an EOF payload."""
    rid = "rid-drain"
    manager = _manager(chunk_size=2)
    manager.code_prompt_token_ids[rid].append(_patch(1.0))

    # The second patch completes the two-patch chunk.
    first_emit = llm2audio_vae_async_chunk(
        transfer_manager=manager,
        pooling_output=_payload(2.0),
        request=_req(rid, finished=False),
    )

    assert first_emit is not None
    assert manager.request_payload[rid]["_ming_async_state"]["seen_patch_len"] == 2

    eof_payload = llm2audio_vae_async_chunk(
        transfer_manager=manager,
        pooling_output=None,
        request=_req(rid, finished=True),
    )

    expected_eof = {
        "code_predictor_codes": [],
        "finished": torch.tensor(True, dtype=torch.bool),
        "stream_finished": torch.tensor(True, dtype=torch.bool),
        "ming_chunk_id": 0,
        KEY_REQUEST_ID: rid,
        MING_EMIT_PATCH_COUNT_KEY: 0,
        MING_LATENT_SHAPE_KEY: None,
        MING_ESTIMATED_BYTES_KEY: 0,
        MING_FINAL_FLUSH_KEY: True,
    }
    assert eof_payload == expected_eof
_patch(2.0), + ] + + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=None, + request=request, + ) + + assert payload is not None + assert payload["finished"].item() is True + assert payload["stream_finished"].item() is True + assert payload[KEY_REQUEST_ID] == request_id + assert payload["ming_latent_patches"].shape == (2, PATCH_SIZE, _LATENT_D) + assert payload[MING_EMIT_PATCH_COUNT_KEY] == 2 + assert payload[MING_LATENT_SHAPE_KEY] == (2, PATCH_SIZE, _LATENT_D) + assert payload[MING_ESTIMATED_BYTES_KEY] > 0 + assert payload[MING_FINAL_FLUSH_KEY] is True + assert torch.allclose(payload["ming_latent_patches"][0], _patch(1.0)) + assert torch.allclose(payload["ming_latent_patches"][1], _patch(2.0)) + + +def test_llm2audio_vae_async_chunk_final_flush_emits_partial_chunk_with_new_patch(): + transfer_manager = _manager(chunk_size=3) + request_id = "rid-tail-new" + + transfer_manager.code_prompt_token_ids[request_id].append(_patch(1.0)) + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(2.0, decode_step=7, stop_reason="stop_head"), + request=_req(request_id, finished=True), + ) + + assert payload is not None + assert payload["finished"].item() is True + assert payload["stream_finished"].item() is True + assert payload[MING_EMIT_PATCH_COUNT_KEY] == 2 + assert payload[MING_FINAL_FLUSH_KEY] is True + assert payload[MING_FINAL_DECODE_STEP_KEY] == 7 + assert payload[MING_STOP_REASON_KEY] == "stop_head" + assert torch.allclose(payload["ming_latent_patches"][0], _patch(1.0)) + assert torch.allclose(payload["ming_latent_patches"][1], _patch(2.0)) + + +def test_llm2audio_vae_async_chunk_emits_eof_when_finished_without_frames(): + transfer_manager = _manager(chunk_size=2) + request = _req("rid-eof", finished=True) + + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=None, + request=request, + ) + + assert payload == { + "code_predictor_codes": [], + 
"finished": torch.tensor(True, dtype=torch.bool), + "stream_finished": torch.tensor(True, dtype=torch.bool), + "ming_chunk_id": 0, + KEY_REQUEST_ID: "rid-eof", + MING_EMIT_PATCH_COUNT_KEY: 0, + MING_LATENT_SHAPE_KEY: None, + MING_ESTIMATED_BYTES_KEY: 0, + MING_FINAL_FLUSH_KEY: True, + } + + +def test_llm2audio_vae_async_chunk_zero_latent_final_flush_returns_empty_payload_not_error(): + transfer_manager = _manager(chunk_size=2) + + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output={ + "ming_has_patch": torch.tensor([False]), + "ming_latent_patch": torch.zeros((1, PATCH_SIZE, _LATENT_D), dtype=torch.float32), + }, + request=_req("rid-zero-final", finished=True), + ) + + assert payload == { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + "stream_finished": torch.tensor(True, dtype=torch.bool), + "ming_chunk_id": 0, + KEY_REQUEST_ID: "rid-zero-final", + MING_EMIT_PATCH_COUNT_KEY: 0, + MING_LATENT_SHAPE_KEY: None, + MING_ESTIMATED_BYTES_KEY: 0, + MING_FINAL_FLUSH_KEY: True, + } + + +def test_llm2audio_vae_async_chunk_rejects_left_context_config(): + transfer_manager = _manager(chunk_size=2, left_context=1) + request = _req("rid-bad-cfg", finished=False) + + with pytest.raises( + ValueError, + match="does not support latent_left_context replay.*Got latent_left_context=1", + ): + llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(1.0), + request=request, + ) + + +def test_llm2audio_vae_async_chunk_rejects_non_positive_chunk_size(): + transfer_manager = _manager(chunk_size=0, left_context=0) + + with pytest.raises(ValueError, match="Invalid Ming latent_chunk_size=0"): + llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(1.0), + request=_req("rid-bad-chunk", finished=False), + ) + + +def test_llm2audio_vae_async_chunk_missing_config_uses_fallback_defaults(): + transfer_manager = _manager(raw_config={"extra": {}}) + 
request_id = "rid-fallback" + + for idx in range(LATENT_CHUNK_SIZE - 1): + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(float(idx + 1)), + request=_req(request_id, finished=False), + ) + assert payload is None + + payload = llm2audio_vae_async_chunk( + transfer_manager=transfer_manager, + pooling_output=_payload(float(LATENT_CHUNK_SIZE)), + request=_req(request_id, finished=False), + ) + + assert payload is not None + assert payload[MING_EMIT_PATCH_COUNT_KEY] == LATENT_CHUNK_SIZE + assert payload[MING_LATENT_SHAPE_KEY] == (LATENT_CHUNK_SIZE, PATCH_SIZE, _LATENT_D) + assert LATENT_LEFT_CONTEXT == 0 + + +def test_llm2audio_vae_builds_generation_prompt_from_stage_output(): + patches = torch.arange(2 * PATCH_SIZE * _LATENT_D, dtype=torch.float32).reshape(2, PATCH_SIZE, _LATENT_D) + stage_output = SimpleNamespace( + request_id="rid-stage", + finished=True, + outputs=[ + SimpleNamespace( + multimodal_output={ + "ming_has_patch": torch.tensor([True, True]), + "ming_latent_patch": patches, + "ming_decode_step": torch.tensor([26, 27], dtype=torch.int32), + "ming_stop_reason": ("continue", "stop_head"), + } + ) + ], + ) + stage = SimpleNamespace(engine_outputs=[stage_output]) + + prompts = llm2audio_vae(stage_list=[stage], engine_input_source=[0]) + + assert len(prompts) == 1 + info = prompts[0]["additional_information"] + assert info[KEY_REQUEST_ID] == "rid-stage" + assert info["finished"].item() is True + assert info["ming_latent_patches"].shape == (2, PATCH_SIZE, _LATENT_D) + assert torch.allclose(info["ming_latent_patches"], patches) + assert info[MING_FINAL_DECODE_STEP_KEY] == 27 + assert info[MING_STOP_REASON_KEY] == "stop_head" + + +def test_llm2audio_vae_skips_unfinished_stage_output(): + patch = torch.arange(PATCH_SIZE * _LATENT_D, dtype=torch.float32).reshape(1, PATCH_SIZE, _LATENT_D) + stage_output = SimpleNamespace( + request_id="rid-unfinished", + finished=False, + outputs=[ + SimpleNamespace( + 
multimodal_output={ + "ming_has_patch": torch.tensor([True]), + "ming_latent_patch": patch, + } + ) + ], + ) + stage = SimpleNamespace(engine_outputs=[stage_output]) + + prompts = llm2audio_vae(stage_list=[stage], engine_input_source=[0]) + + assert prompts == [] diff --git a/tests/worker/test_ming_tts_runner.py b/tests/worker/test_ming_tts_runner.py new file mode 100644 index 00000000000..89deda2ddb1 --- /dev/null +++ b/tests/worker/test_ming_tts_runner.py @@ -0,0 +1,674 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.models.ming_tts.config_ming_tts import ( + KEY_LATENT_HISTORY, + KEY_NEXT_EMBEDS, + KEY_PROMPT_LATENTS, + KEY_REQUEST_ID, + KEY_SPEAKER_EMBEDDING, + MingTTSConfig, +) +from vllm_omni.model_executor.models.ming_tts.ming_tts import MingTTSForConditionalGeneration +from vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae import MingAudioVAEModel +from vllm_omni.model_executor.models.ming_tts.ming_tts_llm import ( + MING_STOP_REASON_CONTINUE, + MING_STOP_REASON_KEY, + MING_STOP_REASON_MAX_DECODE_STEPS, + MING_STOP_REASON_STOP_HEAD, + MingLLMModel, + _resolve_ming_stop_decision, +) +from vllm_omni.model_executor.models.output_templates import OmniOutput + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class DummyBackbone(torch.nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.hidden_size = hidden_size + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + ids = input_ids.to(torch.float32).reshape(-1, 1) + return ids.repeat(1, self.hidden_size) / 100.0 + + def get_input_embeddings(self): + return None + + def forward(self, input_ids, positions, intermediate_tensors=None, inputs_embeds=None, **kwargs): + del input_ids, positions, intermediate_tensors, kwargs + return inputs_embeds + + +class 
NaNOnSecondDecodeBackbone(DummyBackbone): + def __init__(self, hidden_size: int): + super().__init__(hidden_size) + self.decode_calls = 0 + + def forward(self, input_ids, positions, intermediate_tensors=None, inputs_embeds=None, **kwargs): + del input_ids, positions, intermediate_tensors, kwargs + self.decode_calls += 1 + if self.decode_calls >= 2: + return torch.full_like(inputs_embeds, float("nan")) + return inputs_embeds + + +class DummyAggregator(torch.nn.Module): + def __init__(self, in_channels: int, llm_input_dim: int, **kwargs): + super().__init__() + del in_channels, kwargs + self.hidden_size = llm_input_dim + + def forward(self, patch: torch.Tensor) -> torch.Tensor: + pooled = patch.mean(dim=1) + repeats = self.hidden_size // pooled.shape[-1] + return pooled.repeat(1, repeats).reshape(pooled.shape[0], 1, self.hidden_size) + + +class DummyFlowLoss(torch.nn.Module): + def __init__(self, z_channels: int, llm_cond_dim: int, **kwargs): + super().__init__() + del z_channels, llm_cond_dim, kwargs + + def sample(self, z, latent_history, cfg, patch_size, sigma, temperature): + del latent_history, cfg, sigma, temperature + base = z[:, 0, :64] + return torch.stack([base + float(i + 1) for i in range(patch_size)], dim=1) + + +class DummyAudioVAE(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.weight = torch.nn.Parameter(torch.tensor(1.0)) + self.decode_calls: list[dict[str, object]] = [] + + def encode_latent(self, waveform: torch.Tensor, waveform_length: torch.Tensor): + if waveform.ndim == 2: + frames = waveform.shape[-1] // 64 + latent = waveform[:, : frames * 64].reshape(waveform.shape[0], frames, 64) + else: + latent = waveform + frame_num = torch.full((latent.shape[0],), latent.shape[1], dtype=torch.int32, device=latent.device) + return latent.to(torch.float32), frame_num + + def decode(self, latent, past_key_values=None, use_cache=False, stream_state=(None, None, None), last_chunk=False): + del use_cache, 
last_chunk + prev_frames = int((past_key_values or {}).get("frames", 0)) + waveform = latent.sum(dim=-1).reshape(latent.shape[0], -1).to(torch.float32) + prev_frames * 10.0 + new_stream_state = ("stream", prev_frames + latent.shape[1], tuple(latent.shape)) + new_past = {"frames": prev_frames + int(latent.shape[1])} + self.decode_calls.append( + { + "stream_state": stream_state, + "past_key_values": past_key_values, + "latent_shape": tuple(latent.shape), + } + ) + return waveform, new_stream_state, new_past + + +class _DummySamplingMetadata: + def __init__(self, step: int): + self.output_token_ids = [[0] * int(step)] + + +def _make_config() -> MingTTSConfig: + audio_cfg = SimpleNamespace( + enc_kwargs={"latent_dim": 64, "input_dim": 882, "hop_size": 882}, + dec_kwargs={"latent_dim": 64, "output_dim": 882}, + patch_size=4, + sample_rate=44100, + ) + cfg = MingTTSConfig(audio_tokenizer_config=audio_cfg) + cfg.validate() + return cfg + + +def _make_vllm_config(model_stage: str, **hf_overrides): + return SimpleNamespace( + model_config=SimpleNamespace(hf_config=SimpleNamespace(**hf_overrides), model_stage=model_stage), + quant_config=None, + device_config=SimpleNamespace(device=torch.device("cpu")), + ) + + +def _make_runner_for_ming(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts as wrapper_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae as vae_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: DummyBackbone(cfg.llm_hidden_size)) + monkeypatch.setattr(llm_mod, "Aggregator", DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", DummyFlowLoss) + monkeypatch.setattr(wrapper_mod, "AudioVAE", DummyAudioVAE, 
raising=False) + monkeypatch.setattr(vae_mod, "AudioVAE", DummyAudioVAE) + + llm_model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + vae_model = MingAudioVAEModel(vllm_config=_make_vllm_config("audio_vae")) + + def _wrapper_loader(*, architectures, **kwargs): + arch = architectures[0] + if arch == "MingLLMModel": + return llm_model + if arch == "MingAudioVAEModel": + return vae_model + raise AssertionError(f"unexpected architecture {arch}") + + monkeypatch.setattr(wrapper_mod, "init_vllm_registered_model", _wrapper_loader) + + stage1 = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("llm")) + stage2 = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("audio_vae")) + + return SimpleNamespace(config=cfg, llm=llm_model, vae=vae_model, stage1=stage1, stage2=stage2) + + +def test_ming_llm_step_shapes(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + prefill_ids = torch.tensor( + [1, cfg.audio_start_token_id, cfg.audio_dummy_token_id, cfg.audio_dummy_token_id, cfg.audio_end_token_id, 2], + dtype=torch.long, + ) + prefill_embeds = torch.zeros((prefill_ids.shape[0], cfg.llm_hidden_size), dtype=torch.float32) + prompt_latents = torch.arange(8 * 64, dtype=torch.float32).reshape(1, 8, 64) + + _, prefill_out_embeds, prefill_info = runner.stage1.preprocess_input( + prefill_ids, + prefill_embeds, + **{KEY_PROMPT_LATENTS: prompt_latents}, + **{KEY_REQUEST_ID: "req-1"}, + ) + + assert prefill_info[KEY_LATENT_HISTORY].shape == (32, 64) + assert torch.allclose(prefill_info[KEY_LATENT_HISTORY][-8:], prompt_latents.reshape(8, 64)) + assert torch.count_nonzero(prefill_out_embeds[1]).item() > 0 + assert torch.count_nonzero(prefill_out_embeds[2]).item() > 0 + + decode_ids = torch.tensor([cfg.audio_dummy_token_id], dtype=torch.long) + decode_embeds = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + _, decode_embeds, decode_info = runner.stage1.preprocess_input( + decode_ids, + decode_embeds, + 
**prefill_info, + ) + + output = runner.llm.forward( + decode_ids, + positions=torch.tensor([0], dtype=torch.long), + inputs_embeds=decode_embeds, + model_intermediate_buffer=[decode_info], + seq_token_counts=[1], + ) + mm = output.multimodal_outputs + + assert mm["ming_latent_patch"].shape == (1, 4, 64) + assert mm["ming_next_embeds"].shape == (1, 1, cfg.llm_hidden_size) + assert mm["ming_new_history"].shape == (1, 32, 64) + + update = runner.stage1.postprocess(output.text_hidden_states, multimodal_outputs=mm, **decode_info) + assert update[KEY_LATENT_HISTORY].shape == (1, 32, 64) + assert torch.allclose(update[KEY_LATENT_HISTORY][0, -4:], mm["ming_latent_patch"][0].cpu()) + assert update[KEY_NEXT_EMBEDS].shape == (1, 1, cfg.llm_hidden_size) + + +def test_ming_prefill_injects_speaker_into_dense_placeholder(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts as wrapper_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: DummyBackbone(cfg.llm_hidden_size)) + monkeypatch.setattr(llm_mod, "Aggregator", DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", DummyFlowLoss) + monkeypatch.setattr( + wrapper_mod, "init_vllm_registered_model", lambda **kwargs: MingLLMModel(vllm_config=_make_vllm_config("llm")) + ) + + vision_start_token_id = 32001 + stage1 = MingTTSForConditionalGeneration( + vllm_config=_make_vllm_config("llm", vision_start_token_id=vision_start_token_id) + ) + + input_ids = torch.tensor( + [ + 1, + vision_start_token_id, + 77, + cfg.audio_start_token_id, + cfg.audio_dummy_token_id, + cfg.audio_end_token_id, + ], + dtype=torch.long, + ) + input_embeds = torch.zeros((input_ids.shape[0], cfg.llm_hidden_size), dtype=torch.float32) + 
baseline_embeds = stage1.model.embed_input_ids(input_ids).clone() + speaker = torch.ones((192,), dtype=torch.float32) + + _, out_embeds, _ = stage1.preprocess_input( + input_ids, + input_embeds, + **{KEY_SPEAKER_EMBEDDING: speaker}, + ) + + assert torch.count_nonzero(out_embeds[2]).item() > 0 + assert not torch.allclose(out_embeds[2], baseline_embeds[2]) + assert torch.allclose(out_embeds[3], baseline_embeds[3]) + + +def test_ming_prefill_injects_multiple_speakers_into_multiple_dense_placeholders(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts as wrapper_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr(llm_mod, "init_vllm_registered_model", lambda **kwargs: DummyBackbone(cfg.llm_hidden_size)) + monkeypatch.setattr(llm_mod, "Aggregator", DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", DummyFlowLoss) + monkeypatch.setattr( + wrapper_mod, "init_vllm_registered_model", lambda **kwargs: MingLLMModel(vllm_config=_make_vllm_config("llm")) + ) + + vision_start_token_id = 32001 + stage1 = MingTTSForConditionalGeneration( + vllm_config=_make_vllm_config("llm", vision_start_token_id=vision_start_token_id) + ) + + input_ids = torch.tensor( + [ + 1, + vision_start_token_id, + 77, + 2, + vision_start_token_id, + 88, + cfg.audio_start_token_id, + cfg.audio_dummy_token_id, + cfg.audio_end_token_id, + ], + dtype=torch.long, + ) + input_embeds = torch.zeros((input_ids.shape[0], cfg.llm_hidden_size), dtype=torch.float32) + baseline_embeds = stage1.model.embed_input_ids(input_ids).clone() + speaker = torch.ones((2, 192), dtype=torch.float32) + + _, out_embeds, _ = stage1.preprocess_input( + input_ids, + input_embeds, + **{KEY_SPEAKER_EMBEDDING: speaker}, + ) + + assert 
torch.count_nonzero(out_embeds[2]).item() > 0 + assert torch.count_nonzero(out_embeds[5]).item() > 0 + assert not torch.allclose(out_embeds[2], baseline_embeds[2]) + assert not torch.allclose(out_embeds[5], baseline_embeds[5]) + assert torch.allclose(out_embeds[6], baseline_embeds[6]) + + +def test_ming_stop_logic_no_stop_before_min_required_decode_steps(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + def _high_stop(_hidden_states): + return torch.tensor([[0.0, 10.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _high_stop) + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + stop_reason, stop_now, force_stop, min_required_decode_steps, next_token_id = _resolve_ming_stop_decision( + step=4, + stop_prob=1.0, + stop_threshold=float(cfg.stop_head_threshold), + min_stop_step=int(cfg.stop_head_min_steps), + min_decode_steps=7, + max_decode_steps=int(cfg.max_decode_steps), + audio_dummy_token_id=int(cfg.audio_dummy_token_id), + text_eos_token_id=int(cfg.text_eos_token_id), + ) + assert stop_reason == MING_STOP_REASON_CONTINUE + assert stop_now is False + assert force_stop is False + assert min_required_decode_steps == 7 + assert next_token_id == cfg.audio_dummy_token_id + + logits_step3 = runner.llm.compute_logits( + OmniOutput( + text_hidden_states=hidden, + multimodal_outputs={"ming_min_decode_steps": torch.tensor([7], dtype=torch.int32)}, + ), + _DummySamplingMetadata(step=3), + ) + out_step3 = runner.llm.sample(logits_step3, _DummySamplingMetadata(step=3)) + assert int(out_step3.sampled_token_ids[0, 0]) == cfg.audio_dummy_token_id + assert torch.isfinite(logits_step3[0, int(cfg.audio_dummy_token_id)]) + assert not torch.isfinite(logits_step3[0, int(cfg.text_eos_token_id)]) + + +def test_ming_stop_logic_stop_head_inside_window(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + def _high_stop(_hidden_states): + return torch.tensor([[0.0, 
10.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _high_stop) + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + stop_reason, stop_now, force_stop, min_required_decode_steps, next_token_id = _resolve_ming_stop_decision( + step=4, + stop_prob=1.0, + stop_threshold=float(cfg.stop_head_threshold), + min_stop_step=int(cfg.stop_head_min_steps), + min_decode_steps=0, + max_decode_steps=int(cfg.max_decode_steps), + audio_dummy_token_id=int(cfg.audio_dummy_token_id), + text_eos_token_id=int(cfg.text_eos_token_id), + ) + assert stop_reason == MING_STOP_REASON_STOP_HEAD + assert stop_now is True + assert force_stop is False + assert min_required_decode_steps == int(cfg.stop_head_min_steps) + 1 + assert next_token_id == cfg.text_eos_token_id + + logits_step4 = runner.llm.compute_logits(hidden, _DummySamplingMetadata(step=4)) + out_step4 = runner.llm.sample(logits_step4, _DummySamplingMetadata(step=4)) + assert int(out_step4.sampled_token_ids[0, 0]) == cfg.text_eos_token_id + + +def test_ming_stop_logic_rejects_impossible_decode_window(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + with pytest.raises(RuntimeError, match="Invalid Ming decode window"): + runner.llm.compute_logits( + OmniOutput( + text_hidden_states=hidden, + multimodal_outputs={ + "ming_min_decode_steps": torch.tensor([7], dtype=torch.int32), + "ming_max_decode_steps": torch.tensor([5], dtype=torch.int32), + }, + ), + _DummySamplingMetadata(step=4), + ) + + +def test_ming_stop_logic_max_decode_guard(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + cfg.max_decode_steps = 5 + + def _high_stop(_hidden_states): + return torch.tensor([[0.0, 10.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _high_stop) + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + 
stop_reason, stop_now, force_stop, min_required_decode_steps, next_token_id = _resolve_ming_stop_decision( + step=4, + stop_prob=1.0, + stop_threshold=float(cfg.stop_head_threshold), + min_stop_step=int(cfg.stop_head_min_steps), + min_decode_steps=0, + max_decode_steps=int(cfg.max_decode_steps), + audio_dummy_token_id=int(cfg.audio_dummy_token_id), + text_eos_token_id=int(cfg.text_eos_token_id), + ) + assert stop_reason == MING_STOP_REASON_MAX_DECODE_STEPS + assert stop_now is True + assert force_stop is True + assert min_required_decode_steps == int(cfg.stop_head_min_steps) + 1 + assert next_token_id == cfg.text_eos_token_id + + logits = runner.llm.compute_logits(hidden, _DummySamplingMetadata(step=4)) + out = runner.llm.sample(logits, _DummySamplingMetadata(step=4)) + assert int(out.sampled_token_ids[0, 0]) == cfg.text_eos_token_id + + +def test_ming_compute_logits_uses_forward_stop_prob_payload(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + def _low_stop(_hidden_states): + return torch.tensor([[10.0, 0.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _low_stop) + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + logits = runner.llm.compute_logits( + OmniOutput( + text_hidden_states=hidden, + multimodal_outputs={ + "ming_stop_prob": torch.tensor([1.0], dtype=torch.float32), + "ming_decode_step": torch.tensor([4], dtype=torch.int32), + }, + ), + _DummySamplingMetadata(step=4), + ) + out = runner.llm.sample(logits, _DummySamplingMetadata(step=4)) + assert int(out.sampled_token_ids[0, 0]) == cfg.text_eos_token_id + + +def test_ming_compute_logits_uses_cached_forward_stop_prob_for_tensor_path(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + def _low_stop(_hidden_states): + return torch.tensor([[10.0, 0.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _low_stop) + runner.llm._last_sample_stop_probs = 
torch.tensor([1.0], dtype=torch.float32) + runner.llm._last_sample_decode_steps = torch.tensor([4], dtype=torch.int32) + hidden = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + + logits = runner.llm.compute_logits(hidden, _DummySamplingMetadata(step=4)) + out = runner.llm.sample(logits, _DummySamplingMetadata(step=4)) + assert int(out.sampled_token_ids[0, 0]) == cfg.text_eos_token_id + + +def test_ming_forward_exposes_stop_reason_in_outputs_and_pending_state(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + def _low_stop(_hidden_states): + return torch.tensor([[10.0, 0.0]], dtype=torch.float32) + + monkeypatch.setattr(runner.llm.stop_head, "forward", _low_stop) + decode_ids = torch.tensor([cfg.audio_dummy_token_id], dtype=torch.long) + decode_embeds = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + output = runner.llm.forward( + decode_ids, + positions=torch.tensor([0], dtype=torch.long), + inputs_embeds=decode_embeds, + model_intermediate_buffer=[ + { + KEY_LATENT_HISTORY: torch.zeros((cfg.history_patch_size, cfg.latent_dim), dtype=torch.float32), + KEY_REQUEST_ID: "req-stop-reason", + } + ], + seq_token_counts=[1], + ) + + assert output.multimodal_outputs[MING_STOP_REASON_KEY] == (MING_STOP_REASON_CONTINUE,) + pending = runner.llm.pop_postprocess_update("req-stop-reason") + assert pending[MING_STOP_REASON_KEY] == MING_STOP_REASON_CONTINUE + + +def test_ming_postprocess_forwards_stop_reason(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + cfg = runner.config + + decode_ids = torch.tensor([cfg.audio_dummy_token_id], dtype=torch.long) + decode_embeds = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + decode_info = { + KEY_LATENT_HISTORY: torch.zeros((cfg.history_patch_size, cfg.latent_dim), dtype=torch.float32), + KEY_REQUEST_ID: "req-postprocess-stop-reason", + } + + output = runner.llm.forward( + decode_ids, + positions=torch.tensor([0], dtype=torch.long), + 
inputs_embeds=decode_embeds, + model_intermediate_buffer=[decode_info], + seq_token_counts=[1], + ) + update = runner.stage1.postprocess(output.text_hidden_states, **decode_info) + + assert update[MING_STOP_REASON_KEY] == MING_STOP_REASON_CONTINUE + + +def test_ming_vae_incremental_decode(monkeypatch): + runner = _make_runner_for_ming(monkeypatch) + + chunk_a = torch.stack( + [ + torch.ones((4, 64), dtype=torch.float32), + torch.full((4, 64), 2.0, dtype=torch.float32), + ], + dim=0, + ) + out_a = runner.stage2.forward( + model_intermediate_buffer=[ + { + "ming_latent_patches": chunk_a, + "finished": torch.tensor(False), + "stream_finished": torch.tensor(False), + KEY_REQUEST_ID: "r1", + } + ] + ) + wav_a = out_a.multimodal_outputs["model_outputs"][0] + state_a = runner.vae._stream_state["r1"] + past_a = runner.vae._past_key_values["r1"] + + chunk_b = torch.full((1, 4, 64), 3.0, dtype=torch.float32) + out_b = runner.stage2.forward( + model_intermediate_buffer=[ + { + "ming_latent_patches": chunk_b, + "finished": torch.tensor(False), + "stream_finished": torch.tensor(False), + KEY_REQUEST_ID: "r1", + } + ] + ) + wav_b = out_b.multimodal_outputs["model_outputs"][0] + state_b = runner.vae._stream_state["r1"] + + assert len(runner.vae.audio.decode_calls) == 3 + assert runner.vae.audio.decode_calls[1]["latent_shape"] == (1, 4, 64) + assert runner.vae.audio.decode_calls[1]["past_key_values"] == {"frames": 4} + assert runner.vae.audio.decode_calls[2]["stream_state"] == state_a + assert runner.vae.audio.decode_calls[2]["past_key_values"] == past_a + assert state_b != state_a + + expected_a = torch.cat( + [ + chunk_a[0].sum(dim=-1), + chunk_a[1].sum(dim=-1) + 4 * 10.0, + ] + ) + expected_b = chunk_b[0].sum(dim=-1) + 8 * 10.0 + assert torch.allclose(wav_a, expected_a) + assert torch.allclose(wav_b, expected_b) + assert torch.allclose(torch.cat([wav_a, wav_b]), torch.cat([expected_a, expected_b])) + + +def test_ming_vae_finalizes_when_stream_finished_is_absent(monkeypatch): + 
runner = _make_runner_for_ming(monkeypatch) + chunk = torch.stack( + [ + torch.ones((4, 64), dtype=torch.float32), + torch.full((4, 64), 2.0, dtype=torch.float32), + ], + dim=0, + ) + + out = runner.stage2.forward( + model_intermediate_buffer=[ + { + "ming_latent_patches": chunk, + "finished": torch.tensor(True), + KEY_REQUEST_ID: "r-sequential", + } + ] + ) + + wav = out.multimodal_outputs["model_outputs"][0] + assert wav.numel() > 0 + assert "r-sequential" not in runner.vae._stream_state + assert "r-sequential" not in runner.vae._past_key_values + + +def test_ming_recurrent_backbone_can_poison_hidden_states_before_flowloss(monkeypatch): + import vllm_omni.model_executor.models.ming_tts.config_ming_tts as cfg_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts as wrapper_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_audio_vae as vae_mod + import vllm_omni.model_executor.models.ming_tts.ming_tts_llm as llm_mod + + cfg = _make_config() + monkeypatch.setattr(cfg_mod.MingTTSConfig, "from_hf_config", classmethod(lambda cls, hf: cfg)) + monkeypatch.setattr( + llm_mod, "init_vllm_registered_model", lambda **kwargs: NaNOnSecondDecodeBackbone(cfg.llm_hidden_size) + ) + monkeypatch.setattr(llm_mod, "Aggregator", DummyAggregator) + monkeypatch.setattr(llm_mod, "FlowLoss", DummyFlowLoss) + monkeypatch.setattr(vae_mod, "AudioVAE", DummyAudioVAE) + + llm_model = MingLLMModel(vllm_config=_make_vllm_config("llm")) + + def _wrapper_loader(*, architectures, **kwargs): + arch = architectures[0] + if arch == "MingLLMModel": + return llm_model + raise AssertionError(f"unexpected architecture {arch}") + + monkeypatch.setattr(wrapper_mod, "init_vllm_registered_model", _wrapper_loader) + stage1 = MingTTSForConditionalGeneration(vllm_config=_make_vllm_config("llm")) + + decode_ids = torch.tensor([cfg.audio_dummy_token_id], dtype=torch.long) + decode_embeds = torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32) + decode_info = { + KEY_LATENT_HISTORY: 
torch.zeros((cfg.history_patch_size, cfg.latent_dim), dtype=torch.float32), + KEY_REQUEST_ID: "req-nan", + } + + _, decode_embeds, decode_info = stage1.preprocess_input(decode_ids, decode_embeds, **decode_info) + output = llm_model.forward( + decode_ids, + positions=torch.tensor([0], dtype=torch.long), + inputs_embeds=decode_embeds, + model_intermediate_buffer=[decode_info], + seq_token_counts=[1], + ) + mm = output.multimodal_outputs + assert torch.isfinite(mm["ming_next_embeds"]).all() + + update = stage1.postprocess(output.text_hidden_states, multimodal_outputs=mm, **decode_info) + _, next_decode_embeds, next_decode_info = stage1.preprocess_input( + decode_ids, + torch.zeros((1, cfg.llm_hidden_size), dtype=torch.float32), + **update, + ) + assert torch.isfinite(next_decode_embeds).all() + + with pytest.raises(RuntimeError, match="Non-finite z_diff_cond before FlowLoss.sample"): + llm_model.forward( + decode_ids, + positions=torch.tensor([1], dtype=torch.long), + inputs_embeds=next_decode_embeds, + model_intermediate_buffer=[next_decode_info], + seq_token_counts=[1], + ) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index a74c9ffc2d2..738d4c6457b 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -300,6 +300,28 @@ def test_update_intermediate_buffer_skips_unknown_req_id(): assert "unknown_req" not in runner.model_intermediate_buffer +def test_update_additional_information_uses_legacy_additional_information(): + runner = _make_runner(req_ids=("r1",), hidden_size=4) + + scheduler_output = SimpleNamespace( + scheduled_new_reqs=[ + SimpleNamespace( + req_id="r1", + additional_information={"new_field": 1}, + ) + ], + scheduled_cached_reqs=SimpleNamespace( + additional_information={"r1": {"cached_field": 3}}, + ), + ) + + OmniGPUModelRunner._update_additional_information(runner, scheduler_output) + + info = runner.model_intermediate_buffer["r1"] + assert 
info["new_field"] == 1 + assert info["cached_field"] == 3 + + def test_maybe_attach_mimo_audio_req_infos_enriches_dict(): runner = _make_runner_for_mimo() req_id = "r_mimo" diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 89139bf1b0b..783a146bc94 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -36,6 +36,9 @@ def _register_omni_hf_configs() -> None: from transformers import AutoConfig from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config + from vllm_omni.model_executor.models.ming_tts.configuration_ming_dense import ( + MingDenseConfig, + ) from vllm_omni.model_executor.models.omnivoice.config import OmniVoiceConfig from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import ( Qwen3TTSConfig, @@ -58,6 +61,7 @@ def _register_omni_hf_configs() -> None: _CONFIG_REGISTRY = None for model_type, config_cls in [ + ("dense", MingDenseConfig), ("qwen3_tts", Qwen3TTSConfig), ("cosyvoice3", CosyVoice3Config), ("omnivoice", OmniVoiceConfig), diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 61da4388be0..ceec612c499 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -220,6 +220,27 @@ def _apply_omni_final_stage_metadata( ) +def _consume_processed_prompt( + input_processor: InputProcessor | None, + fallback_prompt: Any, +) -> Any: + """Return the prompt dict actually seen by the stage-0 preprocessor.""" + if input_processor is None: + return fallback_prompt + preprocessor = getattr(input_processor, "input_preprocessor", None) + if preprocessor is None: + return fallback_prompt + consume = getattr(preprocessor, "consume_last_processed_prompt", None) + if consume is None: + return fallback_prompt + processed_prompt = consume() + if processed_prompt is None: + return fallback_prompt + if fallback_prompt is not None and not isinstance(processed_prompt, type(fallback_prompt)): + return 
fallback_prompt + return processed_prompt + + def _weak_shutdown_async_omni_engine( orchestrator_thread: threading.Thread | None, request_queue: janus.Queue[dict[str, Any]] | None, @@ -658,10 +679,19 @@ def _attach_llm_stage( # Use omni preprocessor so text-only prompts with # mm_processor_kwargs (e.g. GLM-Image t2i target_h/target_w) # still go through multimodal processor path. - input_processor.input_preprocessor = OmniInputPreprocessor( + omni_preprocessor = OmniInputPreprocessor( vllm_config=started.vllm_config, renderer=input_processor.renderer, ) + ingress_processor_factory = getattr(started.metadata, "initial_prompt_processor_factory", None) + if ingress_processor_factory is not None: + omni_preprocessor.set_initial_prompt_processor( + ingress_processor_factory( + vllm_config=started.vllm_config, + tokenizer=tokenizer, + ) + ) + input_processor.input_preprocessor = omni_preprocessor except Exception: try: stage_client.shutdown() @@ -1051,9 +1081,10 @@ def _build_add_request_message( data_parallel_rank=data_parallel_rank, resumable=resumable, ) + processed_prompt = _consume_processed_prompt(self.input_processor, prompt) # TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have # additional_information field in the prompt. 
- request = _upgrade_to_omni_request(request, prompt) + request = _upgrade_to_omni_request(request, processed_prompt) if reasoning_ended is not None: request.reasoning_ended = reasoning_ended @@ -1121,11 +1152,13 @@ def _enqueue_cfg_companions( params=companion_params, supported_tasks=self.supported_tasks, ) + processed_prompt = _consume_processed_prompt(self.input_processor, companion_prompt) + request = _upgrade_to_omni_request(request, processed_prompt) request.external_req_id = cid self.output_processors[0].add_request( request=request, - prompt=companion_prompt, + prompt=processed_prompt, parent_req=None, request_index=0, queue=None, diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index cc7676ba5d4..c5d6f01fa85 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -255,6 +255,7 @@ class StageMetadata: final_output_type: str | None default_sampling_params: OmniSamplingParams custom_process_input_func: Callable | None + initial_prompt_processor_factory: Callable | None model_stage: str | None runtime_cfg: Any prompt_expand_func: Callable | None = None @@ -309,6 +310,11 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: mod_path, fn_name = _cpif_path.rsplit(".", 1) custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name) + initial_prompt_processor_factory: Callable | None = None + if hasattr(stage_config, "initial_prompt_processor"): + mod_path, fn_name = stage_config.initial_prompt_processor.rsplit(".", 1) + initial_prompt_processor_factory = getattr(importlib.import_module(mod_path), fn_name) + prompt_expand_func: Callable | None = None _pef_path = getattr(stage_config, "prompt_expand_func", None) if _pef_path: @@ -333,6 +339,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: final_output_type=final_output_type, default_sampling_params=default_sampling_params, custom_process_input_func=custom_process_input_func, + 
initial_prompt_processor_factory=initial_prompt_processor_factory, model_stage=None, runtime_cfg=runtime_cfg, cfg_kv_collect_func=cfg_kv_collect_func, @@ -354,6 +361,7 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: final_output_type=final_output_type, default_sampling_params=default_sampling_params, custom_process_input_func=custom_process_input_func, + initial_prompt_processor_factory=initial_prompt_processor_factory, model_stage=model_stage, runtime_cfg=runtime_cfg, prompt_expand_func=prompt_expand_func, diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index 59b5777a874..b5b5e13b915 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -7,6 +7,43 @@ _MAX_EMBEDDING_DIM = 8192 +def _normalize_ref_audio_value(value): + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)): + items = [] + for item in value: + if not isinstance(item, str): + raise TypeError("'ref_audio' list entries must be strings") + items.append(item) + if not items: + raise ValueError("'ref_audio' list cannot be empty") + return items + raise TypeError("'ref_audio' must be a string or list of strings") + + +def _normalize_speaker_embedding_value(value): + if value is None: + return None + if not isinstance(value, (list, tuple)): + raise TypeError("'speaker_embedding' must be a list of numbers or list of embedding vectors") + if not value: + return [] + + first = value[0] + if isinstance(first, (list, tuple)): + embeddings = [] + for item in value: + if not isinstance(item, (list, tuple)): + raise TypeError("'speaker_embedding' must not mix flat and nested values") + embeddings.append([float(x) for x in item]) + return embeddings + + return [float(x) for x in value] + + class OpenAICreateSpeechRequest(BaseModel): input: str model: str | None = None @@ -46,7 +83,7 @@ class 
OpenAICreateSpeechRequest(BaseModel): default=None, description="Language code (e.g., 'Chinese', 'English', 'Auto')", ) - ref_audio: str | None = Field( + ref_audio: str | list[str] | None = Field( default=None, description="Reference audio for voice cloning (Base task). URL, base64, or file URI.", ) @@ -58,7 +95,7 @@ class OpenAICreateSpeechRequest(BaseModel): default=None, description="Use speaker embedding only without in-context learning (Base task)", ) - speaker_embedding: list[float] | None = Field( + speaker_embedding: list[float] | list[list[float]] | None = Field( default=None, max_length=_MAX_EMBEDDING_DIM, description="Pre-computed speaker embedding vector (1024-dim for 0.6B, " @@ -86,17 +123,36 @@ def validate_stream_format(cls, v: str) -> str: raise ValueError("'sse' is not a supported stream_format yet. Please use 'audio'.") return v + @field_validator("ref_audio", mode="before") + @classmethod + def normalize_ref_audio(cls, v): + return _normalize_ref_audio_value(v) + @field_validator("speaker_embedding") @classmethod - def validate_speaker_embedding(cls, v: list[float] | None) -> list[float] | None: - if v is not None and not all(math.isfinite(x) for x in v): + def validate_speaker_embedding( + cls, v: list[float] | list[list[float]] | None + ) -> list[float] | list[list[float]] | None: + v = _normalize_speaker_embedding_value(v) + if v is None: + return None + if not v: + return [] + if isinstance(v[0], list): + for item in v: + if not item: + raise ValueError("'speaker_embedding' nested vectors must be non-empty") + if not all(math.isfinite(x) for x in item): + raise ValueError("'speaker_embedding' values must be finite (no NaN or Inf)") + return v + if not all(math.isfinite(x) for x in v): raise ValueError("'speaker_embedding' values must be finite (no NaN or Inf)") return v @model_validator(mode="after") def validate_embedding_constraints(self) -> "OpenAICreateSpeechRequest": if self.speaker_embedding is not None: - if self.ref_audio is not 
None: + if self.ref_audio is not None and not isinstance(self.ref_audio, list): raise ValueError("'speaker_embedding' and 'ref_audio' are mutually exclusive") return self diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 4946368e904..93322ce4120 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -56,6 +56,7 @@ logger = init_logger(__name__) # TTS Configuration +_MING_TTS_MODEL_ARCHS = {"MingTTSForConditionalGeneration"} _VOXTRAL_TTS_MODEL_STAGES = {"audio_generation"} _QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"} _FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"} @@ -92,6 +93,7 @@ _TTS_MAX_INSTRUCTIONS_LENGTH = 500 _TTS_MAX_NEW_TOKENS_MIN = 1 _TTS_MAX_NEW_TOKENS_MAX = 4096 +_MING_DEFAULT_PROMPT = "Please generate speech based on the following description.\n" def _create_wav_header(sample_rate: int, num_channels: int = 1, bits_per_sample: int = 16) -> bytes: @@ -291,7 +293,13 @@ def shutdown(self) -> None: def _find_tts_stage(self): """Find and return the TTS stage config, or None if not found.""" for stage in self.engine_client.stage_configs: - if stage.engine_args.model_stage in _TTS_MODEL_STAGES: + engine_args = getattr(stage, "engine_args", None) + model_stage = getattr(engine_args, "model_stage", None) + model_arch = getattr(engine_args, "model_arch", None) + worker_type = getattr(engine_args, "worker_type", None) + if model_stage in _TTS_MODEL_STAGES: + return stage + if model_arch in _MING_TTS_MODEL_ARCHS and worker_type == "ar": return stage return None @@ -323,6 +331,8 @@ def _detect_tts_model_type(self) -> str | None: return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2" if model_stage in _MING_TTS_MODEL_STAGES: return "ming_flash_omni_tts" + if model_arch in _MING_TTS_MODEL_ARCHS: + return "ming_tts" return None def _compute_max_instructions_length(self) -> int: @@ -354,6 +364,8 @@ def _load_supported_speakers(self) 
-> set[str]: try: if self._tts_model_type == "voxcpm": return set() + if self._tts_model_type == "ming_tts": + return set() if self._tts_model_type == "voxtral_tts": config = self.engine_client.model_config.hf_config.audio_config else: @@ -751,11 +763,15 @@ async def upload_voice_embedding(self, embedding_json: str, consent: str, name: raise ValueError("'speaker_embedding' values must be finite (no NaN or Inf)") emb_dim = len(embedding) - if emb_dim not in {1024, 2048}: - logger.warning( - "speaker_embedding has %d dimensions; expected 1024 (0.6B) or 2048 (1.7B)", - emb_dim, - ) + expected_dims = {192} if self._tts_model_type == "ming_tts" else {1024, 2048} + if emb_dim not in expected_dims: + if self._tts_model_type == "ming_tts": + logger.warning("speaker_embedding has %d dimensions; Ming dense expects 192", emb_dim) + else: + logger.warning( + "speaker_embedding has %d dimensions; expected 1024 (0.6B) or 2048 (1.7B)", + emb_dim, + ) voice_name_lower = name.lower() if voice_name_lower in self.uploaded_speakers: @@ -838,7 +854,7 @@ async def delete_voice(self, name: str) -> bool: def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" - return any(stage.engine_args.model_stage in _TTS_MODEL_STAGES for stage in self.engine_client.stage_configs) + return self._find_tts_stage() is not None def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate TTS request parameters. 
Returns error message or None.""" @@ -853,6 +869,8 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non if self._tts_model_type == "voxcpm2": return None # VoxCPM2 accepts any text input if self._tts_model_type == "ming_flash_omni_tts": + return self._validate_ming_flash_omni_tts_request(request) + if self._tts_model_type == "ming_tts": return self._validate_ming_tts_request(request) return self._validate_qwen_tts_request(request) @@ -875,7 +893,7 @@ def _voxcpm2_encode(self, text: str) -> list[int]: ids = self._voxcpm2_tokenizer.encode(text, add_special_tokens=True) return split_multichar_chinese(ids, self._voxcpm2_split_map) - def _validate_ming_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + def _validate_ming_flash_omni_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate Ming-flash-omni standalone-talker request parameters.""" if not request.input or not request.input.strip(): return "Input text cannot be empty" @@ -912,6 +930,8 @@ def _validate_ming_tts_request(self, request: OpenAICreateSpeechRequest) -> str def _validate_ref_audio_format(self, ref_audio: str) -> str | None: """Validate ref_audio is a supported URI format. Returns error or None.""" + if not isinstance(ref_audio, str): + return "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" if not ( ref_audio.startswith(("http://", "https://")) or ref_audio.startswith("data:") @@ -1175,6 +1195,99 @@ def _validate_cosyvoice3_request(self, request: OpenAICreateSpeechRequest) -> st return None + def _validate_ming_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate Ming TTS request parameters. 
Returns error message or None.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + if isinstance(request.ref_audio, list): + return self._validate_ming_tts_podcast_request(request) + return self._validate_ming_tts_single_speaker_request(request) + + def _validate_ming_tts_single_speaker_request(self, request: OpenAICreateSpeechRequest) -> str | None: + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + + if request.speaker_embedding is not None: + if not request.speaker_embedding: + return "'speaker_embedding' must be a non-empty list of floats" + emb_len = len(request.speaker_embedding) + if emb_len != 192: + logger.warning( + "speaker_embedding has %d dimensions; Ming dense expects 192. " + "Wrong dimensions will likely fail or degrade output.", + emb_len, + ) + + voice_lower = request.voice.lower() if isinstance(request.voice, str) else None + uploaded_voice = bool(voice_lower and voice_lower in self.uploaded_speakers) + clone_source_present = request.ref_audio is not None or request.speaker_embedding is not None or uploaded_voice + + if request.task_type == "Base" and not clone_source_present: + return "Base task requires 'ref_audio', 'speaker_embedding', or an uploaded voice sample" + + if request.ref_audio is not None and request.ref_text is not None and not request.ref_text.strip(): + return "'ref_text' must be non-empty when provided with 'ref_audio'" + + # Ming offline ref-audio cases use prompt_waveform without prompt_text; + # keep the transcript requirement for other TTS models. 
+ if request.ref_audio is not None and request.speaker_embedding is None and not self._is_ming_tts_model(): + uploaded_ref_text = self.uploaded_speakers[voice_lower].get("ref_text") if uploaded_voice else None + if not (request.ref_text and request.ref_text.strip()) and not uploaded_ref_text: + return "Reference-audio cloning requires non-empty 'ref_text'" + + if request.ref_text is not None and request.ref_audio is None and not uploaded_voice: + return "'ref_text' requires 'ref_audio' or an uploaded voice sample" + + if request.instructions and len(request.instructions) > self._max_instructions_length: + return f"Instructions too long (max {self._max_instructions_length} characters)" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + + def _validate_ming_tts_podcast_request(self, request: OpenAICreateSpeechRequest) -> str | None: + if len(request.ref_audio) < 2: + return "Podcast-style Ming requests require at least two 'ref_audio' clips" + + for ref_audio in request.ref_audio: + fmt_err = self._validate_ref_audio_format(ref_audio) + if fmt_err: + return fmt_err + + if not request.ref_text or not request.ref_text.strip(): + return "Podcast-style Ming requests require non-empty 'ref_text'" + + if request.speaker_embedding is not None: + embeddings = request.speaker_embedding + embedding_count = len(embeddings) if embeddings and isinstance(embeddings[0], list) else 1 + if embedding_count != len(request.ref_audio): + return ( + "Podcast-style Ming requests require one speaker embedding per ref_audio clip; " + f"got {embedding_count} embeddings for {len(request.ref_audio)} clips" + ) + if embeddings and isinstance(embeddings[0], list): + for item in embeddings: + if len(item) != 192: + return "Podcast-style Ming 
speaker embeddings must each have 192 dimensions" + + if request.instructions and len(request.instructions) > self._max_instructions_length: + return f"Instructions too long (max {self._max_instructions_length} characters)" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int]: """Resolve ref_audio to (wav_samples, sample_rate). @@ -1209,13 +1322,123 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int ) return wav_np.tolist(), sr - async def _generate_audio_chunks( + async def _resolve_ref_audio_many(self, ref_audio_list: list[str]) -> list[tuple[list[float], int]]: + resolved = [] + for ref_audio in ref_audio_list: + resolved.append(await self._resolve_ref_audio(ref_audio)) + return resolved + + # ---- Ming TTS helpers ---- + + def _is_ming_tts_model(self) -> bool: + return self._tts_model_type == "ming_tts" + + def _coerce_ming_prompt_waveform(self, wav_samples, sample_rate): + from torchaudio.functional import resample as resample_audio + + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import SAMPLE_RATE + + waveform = torch.as_tensor(wav_samples, dtype=torch.float32).reshape(1, -1) + if int(sample_rate) != SAMPLE_RATE: + waveform = resample_audio(waveform, int(sample_rate), SAMPLE_RATE) + return waveform + + def _build_ming_prompt_waveform( self, - generator, - request_id: str, - response_format: str = "pcm", - raw_request: Request | None = None, + ref_audio_data: tuple[list[float], int] | list[tuple[list[float], int]] | None, ): + if isinstance(ref_audio_data, list): + return torch.cat( + [self._coerce_ming_prompt_waveform(item[0], item[1]) for item in ref_audio_data], + dim=-1, + ) 
+ if ref_audio_data is not None: + return self._coerce_ming_prompt_waveform(ref_audio_data[0], ref_audio_data[1]) + return None + + def _extract_ming_speaker_embeddings_from_ref_audio( + self, + ref_audio_data_list: list[tuple[list[float], int]], + ) -> list[list[float]]: + from vllm_omni.model_executor.models.ming_tts.speaker_extractor import MingSpeakerEmbeddingExtractor + + extractor = MingSpeakerEmbeddingExtractor(self.engine_client.model_config.model, target_sr=16000) + embeddings = [] + for wav_samples, sr in ref_audio_data_list: + waveform = torch.as_tensor(wav_samples, dtype=torch.float32).reshape(1, -1) + embedding = extractor.extract_from_waveform(waveform, int(sr)) + flat = embedding.detach().reshape(-1).to(torch.float32).cpu() + if int(flat.numel()) != 192: + raise ValueError(f"Ming speaker extractor returned {int(flat.numel())} dims; expected 192") + embeddings.append(flat.tolist()) + return embeddings + + def _parse_ming_instruction(self, request: OpenAICreateSpeechRequest) -> Any: + """Build a Ming instruction payload from OpenAI speech fields.""" + instruction_text = request.instructions.strip() if isinstance(request.instructions, str) else None + instruction_dict: dict[str, Any] = {} + + if request.language not in (None, "", "Auto"): + instruction_dict["方言"] = request.language + + voice_lower = request.voice.lower() if isinstance(request.voice, str) else None + if request.voice and not (voice_lower and voice_lower in self.uploaded_speakers): + instruction_dict["IP"] = request.voice + + if instruction_text: + try: + parsed = json.loads(instruction_text) + except json.JSONDecodeError: + parsed = None + if isinstance(parsed, dict): + instruction_dict.update(parsed) + elif instruction_dict: + instruction_dict["风格"] = instruction_text + else: + return instruction_text + + return instruction_dict or None + + def _build_ming_dense_prompt( + self, + request: OpenAICreateSpeechRequest, + *, + ref_audio_data: tuple[list[float], int] | list[tuple[list[float], 
int]] | None = None, + ) -> dict[str, Any]: + """Build a Ming dense prompt directly from the OpenAI speech request.""" + from transformers import AutoTokenizer + + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import KEY_MAX_DECODE_STEPS + from vllm_omni.model_executor.models.ming_tts.prompt_builder import build_ming_dense_prompt + + if self._tts_tokenizer is None: + model_name = self.engine_client.model_config.model + self._tts_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False) + + ref_text = request.ref_text + prompt_waveform = self._build_ming_prompt_waveform(ref_audio_data) if ref_text is not None else None + speaker_embedding = request.speaker_embedding + use_zero_spk_emb = prompt_waveform is None and speaker_embedding is None + + runtime_controls = {} + if request.max_new_tokens is not None: + runtime_controls[KEY_MAX_DECODE_STEPS] = request.max_new_tokens + + return build_ming_dense_prompt( + self._tts_tokenizer, + # bgm / music-prompt mode not supported online; + # requires prompt_mode API extension (deferred). + prompt=_MING_DEFAULT_PROMPT, + text=request.input, + runtime_controls=runtime_controls or None, + instruction=self._parse_ming_instruction(request), + prompt_text=ref_text, + prompt_waveform=prompt_waveform, + speaker_embedding=speaker_embedding, + use_zero_spk_emb=use_zero_spk_emb, + ) + + async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"): """Generate audio chunks for streaming response. Handles two audio output modes from the engine: @@ -1574,7 +1797,7 @@ async def _build_cosyvoice3_prompt( # ---- Ming-flash-omni standalone-talker (TTS) helpers ---- - def _build_ming_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + def _build_ming_flash_omni_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: # request.instructions accepts two forms: # 1. Plain text: mapped to the caption's 风格 (style) field # 2. 
JSON object: parsed and splatted into the caption. Unlocks @@ -1681,8 +1904,29 @@ async def _prepare_speech_generation( prompt = await self._build_cosyvoice3_prompt(request) tts_params = {} elif self._tts_model_type == "ming_flash_omni_tts": - prompt = self._build_ming_prompt(request) + prompt = self._build_ming_flash_omni_prompt(request) tts_params = {} + elif self._tts_model_type == "ming_tts": + ref_audio_source = request.ref_audio + voice_lower = request.voice.lower() if isinstance(request.voice, str) else None + if ref_audio_source is None and voice_lower in self.uploaded_speakers: + ref_audio_source = self._get_uploaded_audio_data(request.voice) + if request.ref_text is None: + request.ref_text = self.uploaded_speakers[voice_lower].get("ref_text") + ref_audio_data = None + if isinstance(ref_audio_source, list): + ref_audio_data = await self._resolve_ref_audio_many(ref_audio_source) + if request.speaker_embedding is None: + request.speaker_embedding = self._extract_ming_speaker_embeddings_from_ref_audio(ref_audio_data) + elif ref_audio_source is not None and isinstance(ref_audio_source, str): + wav_list, sr = await self._resolve_ref_audio(ref_audio_source) + ref_audio_data = (wav_list, sr) + if request.speaker_embedding is None: + request.speaker_embedding = self._extract_ming_speaker_embeddings_from_ref_audio( + [ref_audio_data] + )[0] + prompt = self._build_ming_dense_prompt(request, ref_audio_data=ref_audio_data) + tts_params = prompt.get("additional_information", {}) else: tts_params = self._build_tts_params(request) # Resolve ref_audio (explicit or auto-set for uploaded voices) @@ -1732,6 +1976,8 @@ async def _prepare_speech_generation( model_type = "voxcpm2" elif self._tts_model_type == "ming_flash_omni_tts": model_type = "ming_flash_omni_tts" + elif self._tts_model_type == "ming_tts": + model_type = "ming_tts" elif self._is_tts: model_type = tts_params.get("task_type", ["unknown"])[0] else: @@ -1790,6 +2036,17 @@ async def _prepare_speech_generation( 
sampling_params_list = copy.deepcopy(sampling_params_list) sampling_params_list[0].max_tokens = request.max_new_tokens + elif self._tts_model_type == "ming_tts" and sampling_params_list: + import copy + + from vllm_omni.model_executor.models.ming_tts.config_ming_tts import TEXT_EOS_TOKEN_ID + + sampling_params_list = copy.deepcopy(sampling_params_list) + sampling_params_list[0].stop_token_ids = [int(TEXT_EOS_TOKEN_ID)] + if request.max_new_tokens is not None: + # Ming emits TEXT_EOS after the latent decode budget is exhausted, so + # Stage-0 needs one extra token beyond ming_max_decode_steps. + sampling_params_list[0].max_tokens = int(request.max_new_tokens) + 1 generator = self.engine_client.generate( prompt=prompt, diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index cca6ce56870..a5bba6cb4df 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -25,6 +25,30 @@ class OmniInputPreprocessor(InputPreprocessor): Supports processing tokens, embeddings, text, and multimodal inputs. 
""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.initial_prompt_processor = None + self._last_processed_prompt = None + + def set_initial_prompt_processor(self, processor: Any) -> None: + self.initial_prompt_processor = processor + + def consume_last_processed_prompt(self) -> Any: + prompt = self._last_processed_prompt + self._last_processed_prompt = None + return prompt + + def _apply_initial_prompt_processor(self, prompt: SingletonDictPrompt) -> SingletonDictPrompt: + self._last_processed_prompt = prompt + processor = self.initial_prompt_processor + if processor is None or not isinstance(prompt, dict): + return prompt + processed = processor(prompt) + if not isinstance(processed, dict): + raise TypeError(f"Initial prompt processor must return a prompt dict, got {type(processed).__name__}") + self._last_processed_prompt = processed + return processed + def _process_text( self, parsed_content: OmniTextPrompt, @@ -164,6 +188,8 @@ def _prompt_to_llm_inputs( * [`SingletonInput`][vllm.inputs.engine.SingletonInput] instance """ + prompt = self._apply_initial_prompt_processor(prompt) + if "prompt_embeds" in prompt: return self._process_embeds(prompt) # type: ignore[arg-type] diff --git a/vllm_omni/model_executor/models/ming_tts/__init__.py b/vllm_omni/model_executor/models/ming_tts/__init__.py new file mode 100644 index 00000000000..5c945c5410e --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .configuration_ming_dense import MingDenseConfig +from .ming_tts import MingTTSForConditionalGeneration +from .ming_tts_audio_vae import MingAudioVAEModel +from .ming_tts_llm import MingLLMModel + +__all__ = [ + "MingDenseConfig", + "MingTTSForConditionalGeneration", + "MingLLMModel", + "MingAudioVAEModel", +] diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/__init__.py 
b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/audio_encoder.py b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/audio_encoder.py new file mode 100644 index 00000000000..d5d11121791 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/audio_encoder.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/audio_tokenizer/audio_encoder.py + +from collections.abc import Iterable + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchtune.modules import RotaryPositionalEmbeddings +from transformers.cache_utils import DynamicCache + +try: + from flash_attn import flash_attn_func + + _FLASH_ATTN_AVAILABLE = True +except (ImportError, RuntimeError, OSError): + _FLASH_ATTN_AVAILABLE = False + flash_attn_func = None # guarded by semantic_module_kwargs check above + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int, layer_idx: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + self.layer_idx = layer_idx + 
self.rotary_embed = RotaryPositionalEmbeddings(dim=n_state // n_head) + + def forward(self, x: Tensor, past_key_values=None): + q = self.query(x) + k = self.key(x) + v = self.value(x) + + wv, qk, past_key_values = self.qkv_attention(q, k, v, past_key_values=past_key_values) + return self.out(wv), qk, past_key_values + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, past_key_values=None): + if not _FLASH_ATTN_AVAILABLE: + raise ImportError("flash_attn is required for Ming semantic audio encoder attention.") + q = q.view(*q.shape[:2], self.n_head, -1) # [B, T, nhead, dm] + k = k.view(*k.shape[:2], self.n_head, -1) # [B, T, nhead, dm] + v = v.view(*v.shape[:2], self.n_head, -1) # [B, T, nhead, dm] + + if past_key_values is not None: + past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + q.size(1), device=q.device) + cache_position = cache_position.unsqueeze(0) + else: + cache_position = None + + q = self.rotary_embed(q, input_pos=cache_position) + k = self.rotary_embed(k, input_pos=cache_position) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if past_key_values is not None: + k, v = past_key_values.update(k, v, self.layer_idx, {"cache_position": cache_position}) + + a = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), causal=True) + out = a.flatten(start_dim=2) + qk = None + + return out, qk, past_key_values + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, layer_idx: int): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head, layer_idx) + self.attn_ln = LayerNorm(n_state) + n_mlp = n_state * 4 + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + self.layer_idx = layer_idx + + def forward(self, x: Tensor, past_key_values=None): + 
attn_out, _, past_key_values = self.attn(self.attn_ln(x), past_key_values=past_key_values) + x = x + attn_out + x = x + self.mlp(self.mlp_ln(x)) + return x, past_key_values + + +class WhisperAudioEncoder(nn.Module): + def __init__(self, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, layer_idx=i) for i in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, whisper_feats: Tensor, use_cache=False, past_key_values=None, **kwargs): + if past_key_values is None and use_cache: + past_key_values = DynamicCache() + + x = whisper_feats + + for block in self.blocks: + x, past_key_values = block(x, past_key_values=past_key_values) + + x = self.ln_post(x) + + return x, past_key_values + + @classmethod + def from_pretrained(cls, dims): + audio_encoder = cls( + dims["n_state"], + dims["n_head"], + dims["n_layer"], + ) + + audio_encoder.audio_emb_dim = dims["n_state"] + return audio_encoder diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/configuration_audio_vae.py b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/configuration_audio_vae.py new file mode 100644 index 00000000000..ce9c069c277 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/configuration_audio_vae.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/audio_tokenizer/configuration_audio_vae.py + + +from transformers import PretrainedConfig + + +class AudioVAEconfig(PretrainedConfig): + def __init__( + self, + sample_rate: int = 16000, + enc_kwargs: dict = None, + semantic_module_kwargs: dict = None, + dec_kwargs: dict = None, + hifi_gan_disc_kwargs: dict = None, + spec_disc_kwargs: dict = None, + lambda_disc=1.0, + lambda_mel_loss=15, + lambda_adv=1.0, + 
lambda_feat_match_loss=1.0, + lambda_semantic=5.0, + init_method="normal", + patch_size=-1, + **kwargs, + ): + self.sample_rate = sample_rate + self.enc_kwargs = enc_kwargs + self.semantic_module_kwargs = semantic_module_kwargs + self.dec_kwargs = dec_kwargs + self.hifi_gan_disc_kwargs = hifi_gan_disc_kwargs + self.spec_disc_kwargs = spec_disc_kwargs + self.lambda_disc = lambda_disc + self.lambda_mel_loss = lambda_mel_loss + self.lambda_adv = lambda_adv + self.lambda_feat_match_loss = lambda_feat_match_loss + self.lambda_semantic = lambda_semantic + self.init_method = init_method + self.patch_size = patch_size + super().__init__(**kwargs) diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/istft.py b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/istft.py new file mode 100644 index 00000000000..c365381c87f --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/istft.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/audio_tokenizer/istft.py + +import torch +import torch.nn as nn + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + self.audio_buffer = None + self.window_buffer = None + self.buffer_len = self.win_length - self.hop_length + + def __buffer_process(self, x, buffer, pad, last_chunk=False, streaming=False): + if streaming: + if buffer is None: + # first chunk + x = x[:, pad:] + if buffer is not None: + # next chunk + x[:, : self.buffer_len] += buffer + buffer = x[:, -self.buffer_len :] + if not last_chunk: + x = x[:, : -self.buffer_len] + else: + x = x[:, :-pad] + else: + x = x[:, pad:-pad] + + return x, buffer + + def forward(self, spec: torch.Tensor, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False): + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + audio_buffer (Tensor): [Streaming Input/State] The audio overlap buffer from the previous chunk. + Shape: (B, win_length - hop_length) + window_buffer (Tensor): [Streaming Input/State] The window overlap buffer from the previous chunk. + streaming: If `True`, the function operates in streaming mode, processing `spec` as a single chunk. + last_chunk: When `streaming=True` and `last_chunk=True`, the function can perform final "flush" operations + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
+ """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + if spec.dim() != 3: + raise ValueError(f"Expected spec rank-3 [Batch, Freq, Time], got {tuple(spec.shape)}") + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, :] + + y, audio_buffer = self.__buffer_process(y, audio_buffer, pad, last_chunk=last_chunk, streaming=streaming) + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = ( + torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ) + .squeeze(0) + .squeeze(0) + ) + + window_envelope, window_buffer = self.__buffer_process( + window_envelope, window_buffer, pad, last_chunk=last_chunk, streaming=streaming + ) + window_envelope = window_envelope.squeeze() + + # Normalize + if not (window_envelope > 1e-11).all(): + raise RuntimeError("ISTFT window envelope underflowed; invalid overlap-add state.") + y = y / window_envelope + + return y, audio_buffer, window_buffer + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. 
+ + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False): + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x_pred = self.out(x) + # x_pred = x + x_pred = x_pred.transpose(1, 2) + mag, p = x_pred.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. 
These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio, audio_buffer, window_buffer = self.istft( + S, audio_buffer=audio_buffer, window_buffer=window_buffer, streaming=streaming, last_chunk=last_chunk + ) + return audio.unsqueeze(1), x_pred, audio_buffer, window_buffer diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/modeling_audio_vae.py b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/modeling_audio_vae.py new file mode 100644 index 00000000000..f72c12184fd --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/modeling_audio_vae.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/audio_tokenizer/modeling_audio_vae.py +# audio_tokenizer/modeling_audio_vae.py +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from .configuration_audio_vae import AudioVAEconfig +from .vae_modules import Decoder, Encoder + +logger = init_logger(__name__) + + +def _get_backbone(config: AudioVAEconfig, branch: str): + branch_cfg = getattr(config, branch, None) + if not isinstance(branch_cfg, dict): + return None + backbone = branch_cfg.get("backbone") + if not isinstance(backbone, dict): + return None + return backbone + + +def _maybe_fallback_attention(config: AudioVAEconfig) -> None: + enc_backbone = _get_backbone(config, "enc_kwargs") + dec_backbone = _get_backbone(config, "dec_kwargs") + requested_attn_impl = "flash_attention_2" + + if dec_backbone is not None: + requested_attn_impl = dec_backbone.get( + "_attn_implementation", + dec_backbone.get("attn_implementation", requested_attn_impl), + ) + elif 
enc_backbone is not None: + requested_attn_impl = enc_backbone.get( + "_attn_implementation", + enc_backbone.get("attn_implementation", requested_attn_impl), + ) + + if requested_attn_impl != "flash_attention_2": + return + + try: + import flash_attn # noqa: F401 + except ImportError: + if enc_backbone is not None: + enc_backbone["_attn_implementation"] = "sdpa" + enc_backbone["attn_implementation"] = "sdpa" + if dec_backbone is not None: + dec_backbone["_attn_implementation"] = "sdpa" + dec_backbone["attn_implementation"] = "sdpa" + logger.warning("flash_attn not available, falling back to sdpa for Ming audio VAE") + + +class AudioVAE(nn.Module): + def __init__(self, config: AudioVAEconfig): + super().__init__() + self.config = config + _maybe_fallback_attention(self.config) + + # --- Ming/Bailing config sanity (fail early on bad nested config parsing) --- + enc_kwargs = config.enc_kwargs + dec_kwargs = config.dec_kwargs + + # Required nested fields + for k in ("backbone", "input_dim", "latent_dim"): + if k not in enc_kwargs: + raise ValueError(f"AudioVAE.enc_kwargs missing required key: {k}") + for k in ("backbone", "output_dim", "latent_dim"): + if k not in dec_kwargs: + raise ValueError(f"AudioVAE.dec_kwargs missing required key: {k}") + + # Ming-specific geometry checks (safe because this integration targets Ming checkpoint family) + hop_size = enc_kwargs.get("hop_size", enc_kwargs["input_dim"]) + if enc_kwargs["input_dim"] != hop_size: + raise ValueError(f"AudioVAE encoder input_dim ({enc_kwargs['input_dim']}) != hop_size ({hop_size}).") + if hop_size != dec_kwargs["output_dim"]: + raise ValueError( + f"AudioVAE encoder hop_size ({hop_size}) != decoder output_dim ({dec_kwargs['output_dim']})." + ) + + self.encoder = Encoder( + encoder_args=enc_kwargs["backbone"], + input_dim=enc_kwargs["input_dim"], + hop_size=hop_size, + latent_dim=enc_kwargs["latent_dim"], + patch_size=config.patch_size, + ) + + # Semantic module is null for this checkpoint. 
+ if config.semantic_module_kwargs is not None: + from .audio_encoder import WhisperAudioEncoder + + semantic_model = WhisperAudioEncoder.from_pretrained(dims=config.semantic_module_kwargs["whisper_encoder"]) + else: + semantic_model = None + + self.decoder = Decoder( + decoder_args=dec_kwargs["backbone"], # IMPORTANT: decoder uses dec_kwargs.backbone + output_dim=dec_kwargs["output_dim"], # Ming checkpoint uses 882 + latent_dim=dec_kwargs["latent_dim"], + semantic_model=semantic_model, + patch_size=config.patch_size, + ) + + @torch.inference_mode() + def encode_latent( + self, + waveform: torch.Tensor, + waveform_length: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode waveform -> acoustic latent. + """ + if waveform.ndim != 2: + raise ValueError(f"Expected waveform rank-2 [Batch, Time], got {tuple(waveform.shape)}") + if waveform_length.ndim != 1: + raise ValueError(f"Expected waveform_length rank-1 [Batch], got {tuple(waveform_length.shape)}") + if waveform.shape[0] != waveform_length.shape[0]: + raise ValueError( + "Batch mismatch: " + f"waveform batch={waveform.shape[0]} vs " + f"waveform_length batch={waveform_length.shape[0]}" + ) + if torch.any(waveform_length <= 0): + raise ValueError("waveform_length must be strictly positive.") + + frame_num = torch.ceil(waveform_length / self.config.enc_kwargs["input_dim"]).to(torch.int32) + if self.config.patch_size != -1: + frame_num = torch.ceil(frame_num / self.config.patch_size) + + h, _ = self.encoder(waveform) + h = h.transpose(1, 2) # [B, 2*latent_dim, T] (posterior params: mean + logvar) + + # Inline OobleckDiagonalGaussianDistribution.sample() + mean, logvar = torch.chunk(h, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + latent = mean + std * torch.randn_like(mean) # [B, latent_dim, T] + latent = latent.transpose(1, 2) # [B, T, d/2] + + return latent, frame_num + + @torch.inference_mode() + def decode( + self, + latent: torch.Tensor, + 
past_key_values=None, + use_cache: bool = False, + stream_state: tuple = (None, None, None), + last_chunk: bool = False, + ) -> tuple[torch.Tensor, tuple, object]: + """ + Decode acoustic latent -> waveform. + """ + if latent.dim() != 3: + raise ValueError(f"Expected latent rank-3 [B,T,D], got shape={tuple(latent.shape)}") + if latent.shape[0] <= 0: + raise ValueError("latent batch size must be positive.") + + target_dtype = next(self.decoder.parameters()).dtype + target_device = next(self.decoder.parameters()).device + if latent.dtype != target_dtype or latent.device != target_device: + latent = latent.to(device=target_device, dtype=target_dtype) + + expected_latent_dim = self.config.dec_kwargs["latent_dim"] + if latent.shape[-1] != expected_latent_dim: + raise ValueError(f"Latent dim mismatch in decode(): got {latent.shape[-1]}, expected {expected_latent_dim}") + + waveform, stream_state, past_key_values = self.decoder.low_level_reconstruct( + latent, + past_key_values=past_key_values, + use_cache=use_cache, + stream_state=stream_state, + last_chunk=last_chunk, + ) + return waveform, stream_state, past_key_values diff --git a/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/vae_modules.py b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/vae_modules.py new file mode 100644 index 00000000000..3920f4be7d4 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/audio_tokenizer/vae_modules.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/audio_tokenizer/vae_modules.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen2Config, Qwen2Model + +from .istft import ISTFTHead + + +class StreamingLinearUpsample(nn.Module): + def __init__(self, scale_factor=4): + super().__init__() + self.scale_factor = scale_factor + self.upsampler = 
nn.Upsample(scale_factor=scale_factor, mode="linear", align_corners=False) + + def forward(self, x, state=None, is_last=False): + if x is None and is_last and (state is None or state.get("prev_chunk") is None): + raise ValueError("Received end-of-stream without any latent chunk to upsample.") + # Initialize state + if state is None: + state = {"prev_chunk": None, "history_last": None, "is_first": True} + + if x is None and not is_last: + return None, state + + if state["is_first"] and is_last: + out = self.upsampler(x.transpose(1, 2)).transpose(1, 2) + return out, None # clear state once the stream has ended + + output_chunks = [] + + if state["is_first"]: + state["prev_chunk"] = x + state["is_first"] = False + if not is_last: + return None, state + + if state["prev_chunk"] is not None: + p = state["prev_chunk"].transpose(1, 2) + + if state["history_last"] is None: + lookahead = x[:, :1, :].transpose(1, 2) + inp = torch.cat([p, lookahead], dim=2) + up = self.upsampler(inp) + out_prev = up[:, :, : p.size(2) * self.scale_factor] + else: + lookahead = x[:, :1, :].transpose(1, 2) + inp = torch.cat([state["history_last"], p, lookahead], dim=2) + up = self.upsampler(inp) + start = self.scale_factor + end = start + p.size(2) * self.scale_factor + out_prev = up[:, :, start:end] + + output_chunks.append(out_prev.transpose(1, 2)) + state["history_last"] = p[:, :, -1:] + state["prev_chunk"] = x + + if is_last: + p = state["prev_chunk"].transpose(1, 2) + inp = torch.cat([state["history_last"], p], dim=2) + up = self.upsampler(inp) + out_last = up[:, :, self.scale_factor :] + output_chunks.append(out_last.transpose(1, 2)) + state = None # end of stream + + final_out = torch.cat(output_chunks, dim=1) if output_chunks else None + return final_out, state + + +class Encoder(nn.Module): + def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1): + super().__init__() + config = Qwen2Config.from_dict(config_dict=encoder_args) + self.encoder = Qwen2Model(config) + self.input_dim = input_dim +
self.hop_size = hop_size + self.latent_dim = latent_dim + self.fc1 = nn.Linear(input_dim, config.hidden_size, bias=False) + self.fc2 = nn.Linear(config.hidden_size, config.hidden_size) + self.fc3 = nn.Linear(config.hidden_size, latent_dim * 2) + self.norm = nn.LayerNorm(config.hidden_size) + self.patch_size = patch_size + if patch_size != -1: + aggregator_config = Qwen2Config.from_dict({**encoder_args, "num_hidden_layers": 4}) + self.aggregator = Qwen2Model(aggregator_config) + self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size)) + self.cls_embed.data.normal_(0, 0.02) + + def get_frames(self, x): + num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size # number of frames, rounded up + expected_len = (num_frames_total - 1) * self.hop_size + self.input_dim + padding_needed = expected_len - x.size(-1) + waveform = F.pad(x, (0, padding_needed), value=0.0) + + frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) # [B, T, d] + return frames + + def pad_patch_insert_cls(self, x): + bsz, _, dim = x.size() + num_frame = x.size(1) + r = num_frame % self.patch_size + pad_num = self.patch_size - r if r else 0 + x = F.pad(x, (0, 0, 0, pad_num), value=0.0) # pad frame count up to a multiple of patch_size + x = x.reshape(-1, self.patch_size, dim) + x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1) # append one cls token after each patch + x = x.reshape(bsz, -1, dim) + return x + + def forward(self, waveform): + x = self.get_frames(waveform) + + x = self.fc1(x) + x = self.fc2(x) + x = self.encoder(inputs_embeds=x) + x = x.last_hidden_state + + # downsample + if self.patch_size != -1: + x = self.pad_patch_insert_cls(x) + x = self.aggregator(inputs_embeds=x) + x = x.last_hidden_state + bsz, _, dim = x.size() + x = x.reshape(-1, self.patch_size + 1, dim) + x = x[:, -1:, :].reshape(bsz, -1, dim) + + x = self.fc3(x) + return x, waveform.unsqueeze(1) + + +class Decoder(nn.Module): + def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=None, patch_size=-1): +
super().__init__() + config = Qwen2Config.from_dict(config_dict=decoder_args) + self.decoder = Qwen2Model(config) + self.output_dim = output_dim + self.latent_dim = latent_dim + self.fc1 = nn.Linear(latent_dim, config.hidden_size) + + if semantic_model is not None: + self.gelu = nn.GELU() + self.fc2 = nn.Linear(config.hidden_size, semantic_model.audio_emb_dim) + self.semantic_model = semantic_model + self.fc3 = nn.Linear(semantic_model.audio_emb_dim, config.hidden_size) + else: + self.semantic_model = None + + self.hop_length = output_dim + self.head = ISTFTHead( + dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same" + ) + self.patch_size = patch_size + if self.patch_size != -1: + self.upsampling = StreamingLinearUpsample(scale_factor=patch_size) + + def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=False): + x = self.fc1(x) + + if self.semantic_model is not None: + x = self.fc2(self.gelu(x)) + x, past_key_values = self.semantic_model( + whisper_feats=x, past_key_values=past_key_values, use_cache=use_cache + ) + unified_emb = x + if only_semantic_emb: + return unified_emb, past_key_values + x = self.fc3(x) + else: + unified_emb = None + + if self.patch_size != -1: + x = self.upsampling(x.transpose(1, 2)).transpose(1, 2) + + x = self.decoder(inputs_embeds=x) + x = x.last_hidden_state + + x, _ = self.head(x) + + return x, unified_emb + + def low_level_reconstruct(self, x, past_key_values=None, use_cache=False, stream_state=None, last_chunk=False): + # Guard against None on first chunk (connector initialises per-request) + if stream_state is None: + stream_state = (None, None, None) + upsample_state, audio_buffer, window_buffer = stream_state + bsz, device, dtype = x.size(0), x.device, x.dtype + x = self.fc1(x) + if self.patch_size != -1: + if use_cache: + # streaming + x, upsample_state = self.upsampling(x, state=upsample_state, is_last=last_chunk) + if x is None: + stream_state = (upsample_state, 
audio_buffer, window_buffer) + return torch.empty(bsz, 1, 0, device=device, dtype=dtype), stream_state, past_key_values + else: + x = self.upsampling.upsampler(x.transpose(1, 2)).transpose(1, 2) + + outputs = self.decoder(inputs_embeds=x, past_key_values=past_key_values, use_cache=use_cache) + past_key_values = outputs.past_key_values + x = outputs.last_hidden_state + + x, _, audio_buffer, window_buffer = self.head( + x, streaming=use_cache, audio_buffer=audio_buffer, window_buffer=window_buffer, last_chunk=last_chunk + ) + + stream_state = (upsample_state, audio_buffer, window_buffer) + return x, stream_state, past_key_values diff --git a/vllm_omni/model_executor/models/ming_tts/config_ming_tts.py b/vllm_omni/model_executor/models/ming_tts/config_ming_tts.py new file mode 100644 index 00000000000..09ae85be69a --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/config_ming_tts.py @@ -0,0 +1,364 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# config_ming_tts.py +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from transformers import PretrainedConfig, Qwen2Config + +from .audio_tokenizer.configuration_audio_vae import AudioVAEconfig + +# --------------------------------------------------------------------------- +# Token IDs (confirmed from tokenizer_config.json) +# --------------------------------------------------------------------------- + +AUDIO_DUMMY_TOKEN_ID: int = 151705 # +AUDIO_START_TOKEN_ID: int = 151706 # +AUDIO_EOS_TOKEN_ID: int = 151704 # +VISION_START_TOKEN_ID: int = 151652 # <|vision_start|> + +TEXT_EOS_TOKEN_ID: int = 151669 # +PAD_TOKEN_ID: int = 151643 # <|endoftext|> + +# Backward-compat alias for older code paths +EOS_TOKEN_ID: int = TEXT_EOS_TOKEN_ID + + +# --------------------------------------------------------------------------- +# Architectural constants (confirmed from original config.json) +# 
--------------------------------------------------------------------------- + +LATENT_DIM: int = 64 +PATCH_SIZE: int = 4 +HISTORY_PATCH_SIZE: int = 32 +LLM_HIDDEN_SIZE: int = 896 +LLM_VOCAB_SIZE: int = 151936 +AGGREGATOR_HIDDEN_SIZE: int = 1024 +VAE_PATCH_SIZE: int = 4 +SAMPLE_RATE: int = 44100 + +# AudioVAE frame/hop geometry (confirmed) +AUDIO_FRAME_HOP: int = 882 # enc input_dim / hop_size / dec output_dim + +# stop_head defaults +STOP_HEAD_MIN_STEPS: int = 3 +STOP_HEAD_THRESHOLD: float = 0.5 + +# FlowLoss sampling defaults +DEFAULT_CFG: float = 2.0 +DEFAULT_SIGMA: float = 0.25 +DEFAULT_TEMPERATURE: float = 0.0 + +# Connector / Stage-2 streaming defaults (runtime tuning) +LATENT_CHUNK_SIZE: int = 25 +LATENT_LEFT_CONTEXT: int = 0 +MAX_DECODE_STEPS: int = 200 + +# seq_data.extra_data keys +KEY_LATENT_HISTORY: str = "ming_latent_history" +KEY_DECODE_STEP: str = "ming_decode_step" +KEY_LAST_STOP_PROB: str = "ming_last_stop_prob" +KEY_NEXT_EMBEDS: str = "ming_next_embeds" +KEY_PROMPT_LATENTS: str = "ming_prompt_latents" +KEY_PROMPT_LATENT_TAIL: str = "ming_prompt_latent_tail" +KEY_SPEAKER_EMBEDDING: str = "ming_speaker_embedding" +KEY_REQUEST_ID: str = "ming_request_id" +KEY_CHUNK_ID: str = "ming_chunk_id" +KEY_CFG: str = "ming_cfg" +KEY_SIGMA: str = "ming_sigma" +KEY_TEMPERATURE: str = "ming_temperature" +KEY_MAX_DECODE_STEPS: str = "ming_max_decode_steps" +KEY_MIN_DECODE_STEPS: str = "ming_min_decode_steps" +KEY_TEXT_MODE: str = "ming_text_mode" + + +@dataclass +class MingTTSConfig: + """Flat config object shared by Stage-1 and Stage-2. 
Build via from_hf_config().""" + + # --- LLM backbone --- + llm_hidden_size: int = LLM_HIDDEN_SIZE + llm_vocab_size: int = LLM_VOCAB_SIZE + llm_config: dict[str, Any] = field(default_factory=dict) + + # --- Audio latent space --- + latent_dim: int = LATENT_DIM + patch_size: int = PATCH_SIZE + history_patch_size: int = HISTORY_PATCH_SIZE + + # --- Flow / Aggregator sub-configs --- + ditar_config: dict[str, Any] = field(default_factory=dict) + aggregator_config: dict[str, Any] = field(default_factory=dict) + + # --- AudioVAE --- + audio_tokenizer_config: AudioVAEconfig | None = None + vae_patch_size: int = VAE_PATCH_SIZE + sample_rate: int = SAMPLE_RATE + audio_frame_hop: int = AUDIO_FRAME_HOP + + # --- Generation control --- + cfg: float = DEFAULT_CFG + sigma: float = DEFAULT_SIGMA + temperature: float = DEFAULT_TEMPERATURE + stop_head_min_steps: int = STOP_HEAD_MIN_STEPS + stop_head_threshold: float = STOP_HEAD_THRESHOLD + max_decode_steps: int = MAX_DECODE_STEPS + + # --- Stage-2 chunking (runtime tuning) --- + latent_chunk_size: int = LATENT_CHUNK_SIZE + latent_left_context: int = LATENT_LEFT_CONTEXT + + # --- Token IDs --- + text_eos_token_id: int = TEXT_EOS_TOKEN_ID + eos_token_id: int = TEXT_EOS_TOKEN_ID # compat alias + pad_token_id: int = PAD_TOKEN_ID + audio_dummy_token_id: int = AUDIO_DUMMY_TOKEN_ID + audio_start_token_id: int = AUDIO_START_TOKEN_ID + audio_end_token_id: int = AUDIO_END_TOKEN_ID + audio_eos_token_id: int = AUDIO_EOS_TOKEN_ID + + @classmethod + def from_hf_config(cls, hf_config: PretrainedConfig) -> MingTTSConfig: + """ + Build from vllm-omni's hf_config. Supports nested configs as objects or dicts. 
+ """ + + # --- Read nested sub-configs (must NOT read flat hf_config attrs for these) --- + llm_raw = getattr(hf_config, "llm_config", {}) or {} + ditar_raw = getattr(hf_config, "ditar_config", {}) or {} + agg_raw = getattr(hf_config, "aggregator_config", {}) or {} + atc_raw = getattr(hf_config, "audio_tokenizer_config", None) + + llm_dict = _to_plain_dict(llm_raw) + ditar = _to_plain_dict(ditar_raw) + agg = _to_plain_dict(agg_raw) + + # Keep Ming DiT backend explicit; original checkpoint uses "torch" + ditar.setdefault("attn_backend", "torch") + + atc = _coerce_audio_vae_config(atc_raw) + + # --- Pull nested values safely --- + atc_enc_latent_dim = _nested_get(atc, "enc_kwargs", "latent_dim", default=LATENT_DIM) + atc_patch_size = _nested_get(atc, "patch_size", default=VAE_PATCH_SIZE) + atc_sample_rate = _nested_get(atc, "sample_rate", default=SAMPLE_RATE) + + enc_input_dim = _nested_get(atc, "enc_kwargs", "input_dim", default=AUDIO_FRAME_HOP) + enc_hop_size = _nested_get(atc, "enc_kwargs", "hop_size", default=AUDIO_FRAME_HOP) + dec_output_dim = _nested_get(atc, "dec_kwargs", "output_dim", default=AUDIO_FRAME_HOP) + + cfg = cls( + llm_hidden_size=llm_dict.get("hidden_size", LLM_HIDDEN_SIZE), + llm_vocab_size=llm_dict.get("vocab_size", LLM_VOCAB_SIZE), + llm_config=llm_dict, + latent_dim=atc_enc_latent_dim, + patch_size=ditar.get("patch_size", PATCH_SIZE), + history_patch_size=ditar.get("history_patch_size", HISTORY_PATCH_SIZE), + ditar_config=ditar, + aggregator_config=agg, + audio_tokenizer_config=atc, + vae_patch_size=atc_patch_size, + sample_rate=atc_sample_rate, + audio_frame_hop=enc_hop_size if enc_hop_size is not None else AUDIO_FRAME_HOP, + ) + + # Optional debug cache (safe to keep) + cfg._enc_input_dim = enc_input_dim + cfg._enc_hop_size = enc_hop_size + cfg._dec_output_dim = dec_output_dim + + return cfg + + def validate(self) -> None: + """Run before GPU allocation/weight loading. 
Raises ValueError on mismatches.""" + + # --- Token IDs --- + if self.audio_dummy_token_id != 151705: + raise ValueError( + f"audio_dummy_token_id={self.audio_dummy_token_id}, expected 151705 (). " + "Wrong tokenizer/checkpoint?" + ) + if self.audio_eos_token_id != 151704: + raise ValueError( + f"audio_eos_token_id={self.audio_eos_token_id}, expected 151704 (). " + "Wrong tokenizer/checkpoint?" + ) + if self.text_eos_token_id != 151669: + raise ValueError( + f"text_eos_token_id={self.text_eos_token_id}, expected 151669 (). Wrong tokenizer/checkpoint?" + ) + + # --- Required sub-config --- + if self.audio_tokenizer_config is None: + raise ValueError("audio_tokenizer_config is None. Nested AudioVAE config was not deserialized correctly.") + + # --- Confirmed checkpoint-family constants --- + if self.latent_dim != LATENT_DIM: + raise ValueError( + f"latent_dim mismatch: got {self.latent_dim}, expected {LATENT_DIM}. " + "Check audio_tokenizer_config.enc_kwargs.latent_dim." + ) + if self.patch_size != PATCH_SIZE: + raise ValueError( + f"patch_size mismatch: got {self.patch_size}, expected {PATCH_SIZE}. Check ditar_config.patch_size." + ) + if self.history_patch_size != HISTORY_PATCH_SIZE: + raise ValueError( + f"history_patch_size mismatch: got {self.history_patch_size}, expected {HISTORY_PATCH_SIZE}. " + "Check ditar_config.history_patch_size." + ) + if self.llm_hidden_size != LLM_HIDDEN_SIZE: + raise ValueError( + f"llm_hidden_size mismatch: got {self.llm_hidden_size}, expected {LLM_HIDDEN_SIZE}. " + "Check llm_config.hidden_size." 
+ ) + if self.llm_vocab_size != LLM_VOCAB_SIZE: + raise ValueError(f"llm_vocab_size mismatch: got {self.llm_vocab_size}, expected {LLM_VOCAB_SIZE}.") + if self.sample_rate != SAMPLE_RATE: + raise ValueError(f"sample_rate mismatch: got {self.sample_rate}, expected {SAMPLE_RATE}.") + + # --- Cross-config consistency checks --- + if self.vae_patch_size != self.patch_size: + raise ValueError(f"VAE patch size ({self.vae_patch_size}) != flow/DiT patch size ({self.patch_size}).") + + llm_hidden_from_cfg = self.llm_config.get("hidden_size") + if llm_hidden_from_cfg is not None and llm_hidden_from_cfg != self.llm_hidden_size: + raise ValueError( + f"llm_hidden_size ({self.llm_hidden_size}) != llm_config.hidden_size ({llm_hidden_from_cfg})." + ) + + agg_h = self.aggregator_config.get("hidden_size") + dit_h = self.ditar_config.get("hidden_size") + if agg_h is not None and dit_h is not None and agg_h != dit_h: + raise ValueError(f"aggregator_config.hidden_size ({agg_h}) != ditar_config.hidden_size ({dit_h}).") + if agg_h is not None and agg_h != AGGREGATOR_HIDDEN_SIZE: + raise ValueError(f"aggregator hidden_size mismatch: got {agg_h}, expected {AGGREGATOR_HIDDEN_SIZE}.") + if dit_h is not None and dit_h != AGGREGATOR_HIDDEN_SIZE: + raise ValueError(f"ditar hidden_size mismatch: got {dit_h}, expected {AGGREGATOR_HIDDEN_SIZE}.") + + atc = self.audio_tokenizer_config + enc_latent = _nested_get(atc, "enc_kwargs", "latent_dim", default=None) + dec_latent = _nested_get(atc, "dec_kwargs", "latent_dim", default=None) + if enc_latent is not None and enc_latent != self.latent_dim: + raise ValueError(f"audio enc latent_dim ({enc_latent}) != Ming latent_dim ({self.latent_dim}).") + if dec_latent is not None and dec_latent != self.latent_dim: + raise ValueError(f"audio dec latent_dim ({dec_latent}) != Ming latent_dim ({self.latent_dim}).") + + atc_patch = _nested_get(atc, "patch_size", default=None) + if atc_patch is not None and atc_patch != self.vae_patch_size: + raise ValueError( + 
f"audio_tokenizer_config.patch_size ({atc_patch}) != vae_patch_size ({self.vae_patch_size})." + ) + + atc_sr = _nested_get(atc, "sample_rate", default=None) + if atc_sr is not None and atc_sr != self.sample_rate: + raise ValueError(f"audio_tokenizer_config.sample_rate ({atc_sr}) != sample_rate ({self.sample_rate}).") + + enc_input_dim = _nested_get(atc, "enc_kwargs", "input_dim", default=None) + enc_hop_size = _nested_get(atc, "enc_kwargs", "hop_size", default=None) + dec_output_dim = _nested_get(atc, "dec_kwargs", "output_dim", default=None) + + if enc_input_dim is not None and enc_hop_size is not None and enc_input_dim != enc_hop_size: + raise ValueError(f"AudioVAE encoder input_dim ({enc_input_dim}) != hop_size ({enc_hop_size}).") + if enc_hop_size is not None and dec_output_dim is not None and enc_hop_size != dec_output_dim: + raise ValueError( + f"AudioVAE encoder hop_size ({enc_hop_size}) != decoder output_dim ({dec_output_dim}). " + "Expected 882 in this checkpoint family." + ) + + # Runtime tuning sanity + if self.latent_chunk_size <= 0: + raise ValueError(f"latent_chunk_size must be > 0, got {self.latent_chunk_size}.") + if self.latent_left_context < 0: + raise ValueError(f"latent_left_context must be >= 0, got {self.latent_left_context}.") + if self.max_decode_steps <= 0: + raise ValueError(f"max_decode_steps must be > 0, got {self.max_decode_steps}.") + if not (0.0 <= self.stop_head_threshold <= 1.0): + raise ValueError(f"stop_head_threshold must be in [0,1], got {self.stop_head_threshold}.") + if self.stop_head_min_steps < 0: + raise ValueError(f"stop_head_min_steps must be >= 0, got {self.stop_head_min_steps}.") + + def make_qwen2_config(self) -> Qwen2Config: + """Reconstruct Qwen2Config for Stage-1 LLM backbone init.""" + if not self.llm_config: + raise ValueError("llm_config is empty; from_hf_config() failed to parse nested llm_config.") + return Qwen2Config.from_dict(self.llm_config) + + @property + def latent_patch_shape(self) -> tuple[int, int]: + 
return (self.patch_size, self.latent_dim) + + @property + def chunk_frames(self) -> int: + return self.latent_chunk_size * self.patch_size + + @property + def approx_chunk_seconds(self) -> float: + # One latent frame ~ one 882-sample hop in this checkpoint family. + return (self.chunk_frames * self.audio_frame_hop) / float(self.sample_rate) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _to_plain_dict(obj: Any) -> dict[str, Any]: + """Normalize nested config objects into plain dicts when possible.""" + if obj is None: + return {} + if isinstance(obj, dict): + return dict(obj) + if isinstance(obj, PretrainedConfig): + return obj.to_dict() + if hasattr(obj, "to_dict") and callable(obj.to_dict): + try: + return dict(obj.to_dict()) + except Exception: + pass + try: + return dict(vars(obj)) + except Exception: + return {} + + +def _coerce_audio_vae_config(atc_raw: Any) -> AudioVAEconfig | None: + """ + Normalize audio_tokenizer_config into AudioVAEconfig when possible. 
+ Handles: + - already AudioVAEconfig + - dict + - PretrainedConfig-like object + """ + if atc_raw is None: + return None + atc_dict = _to_plain_dict(atc_raw) + if not atc_dict: + # Return raw object as fallback; _nested_get/validate can still work + return atc_raw # type: ignore[return-value] + + if hasattr(AudioVAEconfig, "from_dict") and callable(getattr(AudioVAEconfig, "from_dict")): + try: + return AudioVAEconfig.from_dict(atc_dict) # type: ignore[misc] + except Exception: + pass + try: + return AudioVAEconfig(**atc_dict) # type: ignore[arg-type] + except Exception: + return atc_raw # type: ignore[return-value] + + +def _nested_get(obj: Any, *keys: str, default: Any = None) -> Any: + """Safe nested attribute/key access for dicts and config-like objects.""" + cur = obj + for k in keys: + if cur is None: + return default + if isinstance(cur, dict): + cur = cur.get(k) + else: + cur = getattr(cur, k, None) + return cur if cur is not None else default diff --git a/vllm_omni/model_executor/models/ming_tts/configuration_ming_dense.py b/vllm_omni/model_executor/models/ming_tts/configuration_ming_dense.py new file mode 100644 index 00000000000..d6f5c8182e9 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/configuration_ming_dense.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from typing import Any + +from transformers import PretrainedConfig, Qwen2Config + +from .audio_tokenizer.configuration_audio_vae import AudioVAEconfig + + +def _coerce_qwen2_config(value: Any) -> Qwen2Config: + if isinstance(value, Qwen2Config): + return value + if isinstance(value, PretrainedConfig): + return Qwen2Config.from_dict(value.to_dict()) + if isinstance(value, dict): + return Qwen2Config.from_dict(dict(value)) + raise TypeError(f"Unsupported llm_config type for Ming dense config: {type(value)!r}") + + +def _coerce_audio_vae_config(value: Any) -> 
AudioVAEconfig | None: + if value is None: + return None + if isinstance(value, AudioVAEconfig): + value = value.to_dict() + elif isinstance(value, PretrainedConfig): + value = value.to_dict() + elif isinstance(value, dict): + value = dict(value) + else: + raise TypeError(f"Unsupported audio_tokenizer_config type for Ming dense config: {type(value)!r}") + + return AudioVAEconfig(**value) + + +class MingDenseConfig(PretrainedConfig): + model_type = "dense" + + def __init__( + self, + llm_config: Qwen2Config | dict[str, Any] | None = None, + ditar_config: dict[str, Any] | None = None, + aggregator_config: dict[str, Any] | None = None, + audio_tokenizer_config: AudioVAEconfig | dict[str, Any] | None = None, + architectures: list[str] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(architectures=architectures, **kwargs) + self.llm_config = _coerce_qwen2_config(llm_config or {}) + self.ditar_config = dict(ditar_config or {}) + self.aggregator_config = dict(aggregator_config or {}) + self.audio_tokenizer_config = _coerce_audio_vae_config(audio_tokenizer_config) + + def get_text_config(self, decoder: bool = False, **kwargs: Any) -> Qwen2Config: + del decoder, kwargs + return self.llm_config diff --git a/vllm_omni/model_executor/models/ming_tts/fm/__init__.py b/vllm_omni/model_executor/models/ming_tts/fm/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/fm/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/model_executor/models/ming_tts/fm/cfm.py b/vllm_omni/model_executor/models/ming_tts/fm/cfm.py new file mode 100644 index 00000000000..b1924973b47 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/fm/cfm.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from 
https://github.com/inclusionAI/Ming-omni-tts/blob/main/fm/CFM.py + + +import torch +import torch.nn.functional as F +from torch import nn + + +class Solver: + def __init__(self, func, y0, sigma=0.25, temperature=1.5) -> None: + self.func = func + self.y0 = y0 + self.sigma = sigma + self.temperature = temperature + + def integrate(self, t): + solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + solution[0] = self.y0 + + j = 1 + y0 = self.y0 + for t0, t1 in zip(t[:-1], t[1:]): + dt = t1 - t0 + f0 = self.func(t0, y0) + dy = dt * f0 + y1 = y0 + dy + + while j < len(t) and t1 >= t[j]: + solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) + j += 1 + + noise = torch.randn_like(y0) + shift = self.sigma * (self.temperature**0.5) * (abs(dt) ** 0.5) * noise + y0 = y1 + shift + + return solution + + def _linear_interp(self, t0, t1, y0, y1, t): + if t == t0: + return y0 + if t == t1: + return y1 + slope = (t - t0) / (t1 - t0) + return y0 + slope * (y1 - y0) + + +def get_epss_timesteps(n, device, dtype): + dt = 1 / 32 + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + return dt * torch.tensor(t, device=device, dtype=dtype) + + +class CFM(nn.Module): + def __init__( + self, + model: nn.Module, + ): + super().__init__() + self.model = model + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + cond, + target, + latent_history, + mask, + patch_size, + ): + if patch_size <= 0: + raise ValueError(f"patch_size must be positive, got {patch_size}") + if cond.ndim != 3: + raise ValueError(f"Expected cond rank-3 [Batch, Time, Dimension], got 
{tuple(cond.shape)}") + if target.ndim != 3: + raise ValueError(f"Expected target rank-3 [Batch, Time, Dimension], got {tuple(target.shape)}") + if latent_history.ndim != 3: + raise ValueError( + f"Expected latent_history rank-3 [Batch, Time, Dimension], got {tuple(latent_history.shape)}" + ) + if cond.shape[0] != target.shape[0] or cond.shape[0] != latent_history.shape[0]: + raise ValueError( + "Batch mismatch across cond, target, and latent_history: " + f"{cond.shape[0]}, {target.shape[0]}, {latent_history.shape[0]}" + ) + token_mask = _coerce_token_mask( + mask, batch_size=target.shape[0], target_steps=target.shape[1], device=target.device + ) + + x1 = target + batch, dtype = x1.shape[0], x1.dtype + x0 = torch.randn_like(x1) + time = torch.rand((batch,), dtype=dtype, device=self.device) + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + x = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + pred = self.model(x=x, t=time, c=cond, latent_history=latent_history, mask=token_mask) + pred = pred[:, -patch_size:, :] + + loss = F.mse_loss(pred, flow, reduction="none") + loss_mask = token_mask.unsqueeze(-1).expand_as(loss) + loss = loss[loss_mask] + + return loss.mean() + + @torch.no_grad() + def sample( + self, + noise, + c, + latent_history, + steps=10, + cfg_scale=1.0, + sway_sampling_coef=-1.0, + use_epss=True, + patch_size=1, + sigma=0.25, + temperature=1.5, + ): + if steps <= 0: + raise ValueError(f"steps must be positive, got {steps}") + if patch_size <= 0: + raise ValueError(f"patch_size must be positive, got {patch_size}") + if noise.ndim != 3: + raise ValueError(f"Expected noise rank-3 [Batch, Dimension, Time], got {tuple(noise.shape)}") + if c.ndim != 3: + raise ValueError(f"Expected conditioning rank-3 [Batch, Time, Dimension], got {tuple(c.shape)}") + if latent_history.ndim != 3: + raise ValueError( + f"Expected latent_history rank-3 [Batch, Time, Dimension], got {tuple(latent_history.shape)}" + ) + if noise.shape[0] != c.shape[0] or 
noise.shape[0] != latent_history.shape[0]: + raise ValueError( + "Batch mismatch across noise, conditioning, and latent_history: " + f"{noise.shape[0]}, {c.shape[0]}, {latent_history.shape[0]}" + ) + if noise.shape[-1] != patch_size: + raise ValueError(f"noise time dim mismatch: got {noise.shape[-1]}, expected patch_size={patch_size}") + + def fn(t, x): + if cfg_scale < 1e-5: + if t.ndim == 0: + t = t.repeat(x.shape[0]) + pred = self.model( + x=x, + t=t, + c=torch.zeros_like(c), + latent_history=latent_history, + ) + return pred[:, -patch_size:, :] + + # predict flow (cond and uncond), for classifier-free guidance + pred_cfg = self.model.forward_with_cfg( + x=x, + t=t, + c=c, + latent_history=latent_history, + cfg_scale=cfg_scale, + patch_size=patch_size, + ) + pred, null_pred = torch.chunk(pred_cfg, 2, dim=0) + return pred + (pred - null_pred) * cfg_scale + + y0 = noise.transpose(1, 2) + t_start = 0 + + if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE + t = get_epss_timesteps(steps, device=self.device, dtype=noise.dtype) + else: + t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=noise.dtype) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + solver = Solver(fn, y0, sigma=sigma, temperature=temperature) + trajectory = solver.integrate(t) + sampled = trajectory[-1] + out = sampled + + return out, trajectory + + +def _coerce_token_mask(mask, *, batch_size, target_steps, device): + if not isinstance(mask, torch.Tensor): + mask = torch.as_tensor(mask, device=device) + if mask.ndim == 3 and mask.shape[-1] == 1: + mask = mask.squeeze(-1) + if mask.ndim != 2: + raise ValueError(f"Expected mask rank-2 [Batch, Time] or rank-3 [Batch, Time, 1], got {tuple(mask.shape)}") + if mask.shape[0] != batch_size or mask.shape[1] != target_steps: + raise ValueError(f"Mask shape mismatch: got {tuple(mask.shape)}, expected {(batch_size, target_steps)}") + return 
mask.to(device=device, dtype=torch.bool) diff --git a/vllm_omni/model_executor/models/ming_tts/fm/dit.py b/vllm_omni/model_executor/models/ming_tts/fm/dit.py new file mode 100644 index 00000000000..2024f26ca2d --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/fm/dit.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/fm/dit.py + +import math + +import torch +import torch.nn as nn +from x_transformers.x_transformers import RotaryEmbedding + +from .modules import DiTBlock, FinalLayer + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + if x.ndim == 0: + x = x.reshape(1) + if x.ndim != 1: + raise ValueError(f"Expected timestep rank-1 [Batch], got {tuple(x.shape)}") + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class TimestepEmbedder(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + def forward(self, timestep): + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time + + +class CondEmbedder(nn.Module): + def __init__(self, input_feature_size, hidden_size, dropout_prob): + super().__init__() + self.dropout_prob = dropout_prob + self.cond_embedder = nn.Linear(input_feature_size, hidden_size) + + def cond_drop(self, llm_cond): + if llm_cond.ndim != 3: + raise ValueError(f"Expected conditioning rank-3 [Batch, Time, 
Dimension], got {tuple(llm_cond.shape)}") + bsz = llm_cond.shape[0] + drop_latent_mask = torch.rand(bsz) < self.dropout_prob + drop_latent_mask = drop_latent_mask.unsqueeze(-1).unsqueeze(-1).to(llm_cond.dtype).to(llm_cond.device) + fake_latent = torch.zeros_like(llm_cond) + llm_cond = drop_latent_mask * fake_latent + (1 - drop_latent_mask) * llm_cond + + return llm_cond + + def forward(self, llm_cond, train): + if llm_cond.ndim != 3: + raise ValueError(f"Expected conditioning rank-3 [Batch, Time, Dimension], got {tuple(llm_cond.shape)}") + use_dropout = self.dropout_prob > 0 + if train and use_dropout: + llm_cond = self.cond_drop(llm_cond) + + llm_cond = self.cond_embedder(llm_cond) + + return llm_cond + + +class DiT(nn.Module): + def __init__( + self, + in_channels=4, + hidden_size=1024, + depth=28, + num_heads=16, + mlp_ratio=4.0, + llm_cond_dim=896, + cfg_dropout_prob=0.1, + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + self.t_embedder = TimestepEmbedder(hidden_size) + self.x_embedder = nn.Linear(in_channels, hidden_size) + self.c_embedder = CondEmbedder(llm_cond_dim, hidden_size, cfg_dropout_prob) + self.hidden_size = hidden_size + self.rotary_embed = RotaryEmbedding(hidden_size // num_heads) + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)] + ) + self.final_layer = FinalLayer(hidden_size, self.out_channels) + + def forward(self, x, t, c, latent_history, mask=None): + if x.ndim != 3: + raise ValueError(f"Expected x rank-3 [Batch, Time, Dimension], got {tuple(x.shape)}") + if latent_history.ndim != 3: + raise ValueError( + f"Expected latent_history rank-3 [Batch, Time, Dimension], got {tuple(latent_history.shape)}" + ) + if c.ndim != 3: + raise ValueError(f"Expected conditioning rank-3 [Batch, Time, Dimension], got {tuple(c.shape)}") + if x.shape[0] != latent_history.shape[0] or x.shape[0] != c.shape[0]: 
+ raise ValueError( + "Batch mismatch across x, conditioning, and latent_history: " + f"{x.shape[0]}, {c.shape[0]}, {latent_history.shape[0]}" + ) + if x.shape[-1] != self.in_channels: + raise ValueError(f"x feature dim mismatch: got {x.shape[-1]}, expected {self.in_channels}") + if latent_history.shape[-1] != self.in_channels: + raise ValueError( + f"latent_history feature dim mismatch: got {latent_history.shape[-1]}, expected {self.in_channels}" + ) + if t.ndim == 0: + t = t.reshape(1) + if t.ndim != 1: + raise ValueError(f"Expected timestep rank-1 [Batch], got {tuple(t.shape)}") + if t.shape[0] != x.shape[0]: + raise ValueError(f"Timestep batch mismatch: got {t.shape[0]}, expected {x.shape[0]}") + + t = self.t_embedder(t).unsqueeze(1) + x_now = self.x_embedder(x) + x_history = self.x_embedder(latent_history) + x = torch.cat([x_history, x_now], dim=1) + c = self.c_embedder(c, self.training) + y = t + c + x = torch.cat([y, x], dim=1) + rope = self.rotary_embed.forward_from_seq_len(x.shape[1]) + + if mask is not None: + if mask.ndim != 2: + raise ValueError(f"Expected mask rank-2 [Batch, Time], got {tuple(mask.shape)}") + if mask.shape[0] != x_now.shape[0] or mask.shape[1] != x_now.shape[1]: + raise ValueError( + f"Mask shape mismatch: got {tuple(mask.shape)}, expected {(x_now.shape[0], x_now.shape[1])}" + ) + mask_pad = mask.clone().detach()[:, :1].expand(-1, x_history.shape[1] + c.shape[1]) + mask = torch.cat([mask_pad, mask], dim=-1) + for block in self.blocks: + x = block(x, mask, rope) + x = self.final_layer(x) + return x + + def forward_with_cfg(self, x, t, c, cfg_scale, latent_history, patch_size): + if patch_size <= 0: + raise ValueError(f"patch_size must be positive, got {patch_size}") + if not cfg_scale == 1: + x = torch.cat([x, x], dim=0) + latent_history = torch.cat([latent_history, latent_history], dim=0) + fake_latent = torch.zeros_like(c) + c = torch.cat([c, fake_latent], dim=0) + if t.ndim == 0: + t = t.repeat(x.shape[0]) + model_out = 
self.forward(x, t, c, latent_history) + return model_out[:, -patch_size:, :] + + +class Aggregator(nn.Module): + def __init__( + self, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + llm_input_dim=896, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + + self.word_embedder = nn.Embedding(1, hidden_size) + self.x_embedder = nn.Linear(in_channels, hidden_size) + self.hidden_size = hidden_size + + self.rotary_embed = RotaryEmbedding(hidden_size // num_heads) + + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)] + ) + self.final_layer = FinalLayer(hidden_size, llm_input_dim) + + def forward(self, x, mask=None): + if x.ndim != 3: + raise ValueError(f"Expected x rank-3 [Batch, Time, Dimension], got {tuple(x.shape)}") + if x.shape[-1] != self.in_channels: + raise ValueError(f"x feature dim mismatch: got {x.shape[-1]}, expected {self.in_channels}") + x = self.x_embedder(x) + cls_embed = self.word_embedder(torch.zeros((x.shape[0], 1), dtype=torch.long, device=x.device)) + x = torch.cat([cls_embed, x], dim=1) + + rope = self.rotary_embed.forward_from_seq_len(x.shape[1]) + if mask is not None: + if mask.ndim != 2: + raise ValueError(f"Expected mask rank-2 [Batch, Time], got {tuple(mask.shape)}") + if mask.shape[0] != x.shape[0] or mask.shape[1] != x.shape[1] - 1: + raise ValueError( + f"Mask shape mismatch: got {tuple(mask.shape)}, expected {(x.shape[0], x.shape[1] - 1)}" + ) + mask_pad = mask.clone().detach()[:, :1] + mask = torch.cat([mask_pad, mask], dim=-1) + for block in self.blocks: + x = block(x, mask, rope) + x = self.final_layer(x) + x = x[:, :1, :] + return x diff --git a/vllm_omni/model_executor/models/ming_tts/fm/flowloss.py b/vllm_omni/model_executor/models/ming_tts/fm/flowloss.py new file mode 100644 index 00000000000..18c59186c3a --- /dev/null +++ 
b/vllm_omni/model_executor/models/ming_tts/fm/flowloss.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/fm/flowloss.py + +import torch +import torch.nn as nn + +from .cfm import CFM +from .dit import DiT + + +class FlowLoss(nn.Module): + """Diffusion Loss""" + + def __init__(self, z_channels, llm_cond_dim, **kwargs): + super().__init__() + self.z_channels = z_channels + self.cfm = CFM(model=DiT(in_channels=z_channels, llm_cond_dim=llm_cond_dim, **kwargs)) + + def forward(self, cond, target, latent_history, mask, patch_size): + return self.cfm(cond=cond, target=target, latent_history=latent_history, mask=mask, patch_size=patch_size) + + def sample(self, z, latent_history, cfg=2.0, patch_size=1, sigma=0.25, temperature=0): + if z.ndim != 3: + raise ValueError(f"Expected z rank-3 [Batch, Time, Dimension], got {tuple(z.shape)}") + if z.shape[1] != 1: + raise ValueError(f"Expected z time dim to be 1 for Ming dense decode, got {z.shape[1]}") + if latent_history.ndim != 3: + raise ValueError( + f"Expected latent_history rank-3 [Batch, Time, Dimension], got {tuple(latent_history.shape)}" + ) + if z.shape[0] != latent_history.shape[0]: + raise ValueError(f"Batch mismatch: z batch={z.shape[0]} vs latent_history batch={latent_history.shape[0]}") + if patch_size <= 0: + raise ValueError(f"patch_size must be positive, got {patch_size}") + if not torch.isfinite(z).all(): + raise RuntimeError("Non-finite conditioning z in FlowLoss.sample().") + if not torch.isfinite(latent_history).all(): + raise RuntimeError("Non-finite latent_history in FlowLoss.sample().") + noise = torch.randn(z.shape[0], self.z_channels, patch_size, device=z.device) + if not torch.isfinite(noise).all(): + raise RuntimeError("Non-finite noise in FlowLoss.sample().") + noise = noise.to(dtype=z.dtype) # match conditioning dtype — no autocast in vllm-omni + out, _ 
= self.cfm.sample( + noise=noise, + c=z, + latent_history=latent_history, + cfg_scale=cfg, + patch_size=patch_size, + sigma=sigma, + temperature=temperature, + ) + # out shape: [B, patch_size, z_channels] + return out diff --git a/vllm_omni/model_executor/models/ming_tts/fm/modules.py b/vllm_omni/model_executor/models/ming_tts/fm/modules.py new file mode 100644 index 00000000000..1163f8f1837 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/fm/modules.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adopted from https://github.com/inclusionAI/Ming-omni-tts/blob/main/fm/modules.py +import torch +import torch.nn.functional as F +from torch import nn +from x_transformers.x_transformers import apply_rotary_pos_emb + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + self.native_rms_norm = float(torch.__version__[:3]) >= 2.4 + + def forward(self, x): + if self.native_rms_norm: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps) + else: + variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.to(self.weight.dtype) + x = x * self.weight + + return x + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.ff(x) + + +class 
Attention(nn.Module): + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("SDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + def forward( + self, + x: float, # noised input x + mask=None, + rope=None, # rotary position embedding for x + ) -> torch.Tensor: + if x.ndim != 3: + raise ValueError(f"Expected x rank-3 [Batch, Time, Dimension], got {tuple(x.shape)}") + if x.shape[-1] != self.dim: + raise ValueError(f"x feature dim mismatch: got {x.shape[-1]}, expected {self.dim}") + if mask is not None: + if mask.ndim != 2: + raise ValueError(f"Expected mask rank-2 [Batch, Time], got {tuple(mask.shape)}") + if mask.shape[0] != x.shape[0] or mask.shape[1] != x.shape[1]: + raise ValueError(f"Mask shape mismatch: got {tuple(mask.shape)}, expected {tuple(x.shape[:2])}") + + batch_size = x.shape[0] + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, 
freqs, k_xpos_scale) + + x = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + x = x.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + x = x.to(query.dtype) + x = self.to_out[0](x) + x = self.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, dropout=0.1, **kwargs): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = Attention(dim=hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, dropout=dropout) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(dim=hidden_size, mult=mlp_ratio, dropout=dropout, approximate="tanh") + + def forward(self, x, mask, rope): + x = x + self.attn(self.norm1(x), mask=mask, rope=rope) + x = x + self.mlp(self.norm2(x)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. 
+ """ + + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = RMSNorm(hidden_size, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + def forward(self, x): + x = self.norm_final(x) + x = self.linear(x) + return x diff --git a/vllm_omni/model_executor/models/ming_tts/ingress.py b/vllm_omni/model_executor/models/ming_tts/ingress.py new file mode 100644 index 00000000000..77428164d5a --- /dev/null +++ b/vllm_omni/model_executor/models/ming_tts/ingress.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import copy +import os +import time +from typing import Any + +from vllm.logger import init_logger + +from .config_ming_tts import ( + AUDIO_DUMMY_TOKEN_ID, + AUDIO_START_TOKEN_ID, + KEY_PROMPT_LATENTS, + KEY_SPEAKER_EMBEDDING, + KEY_TEXT_MODE, + MingTTSConfig, +) +from .prompt_builder import ( + build_dense_prompt_token_ids, + coerce_speaker_embeddings, + count_prompt_waveform_patches, + create_instruction, +) + +logger = init_logger(__name__) + + +def _rebuild_prompt_token_ids_with_exact_patch_count(prompt_token_ids: Any, prompt_patch_count: int) -> list[int]: + if not isinstance(prompt_token_ids, list) or not prompt_token_ids: + raise ValueError("Ming prompt finalization requires existing prompt_token_ids") + + audio_start_index = -1 + for idx in range(len(prompt_token_ids) - 1, -1, -1): + if int(prompt_token_ids[idx]) == AUDIO_START_TOKEN_ID: + audio_start_index = idx + break + if audio_start_index < 0: + raise ValueError("Ming prompt finalization could not locate