Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4f6899a
[+] Feat: Support disaggregated inference pipeline for Talker and Spe…
Sy0307 Feb 2, 2026
b07bdf2
[~] Fix: Enhance compatibility in OmniConnector adapter
Sy0307 Feb 9, 2026
fbd0a92
[~] Fix: Enhance OmniGenerationScheduler and Qwen3TTSTokenizer functi…
Sy0307 Feb 9, 2026
fc6afae
[~] Refactor: codec frame rate handling in OmniModelConfig and Qwen3T…
Sy0307 Feb 9, 2026
05e9cff
[~] Refactor: Improve device handling and additional information mana…
Sy0307 Feb 10, 2026
1852a23
[~] Refactor: Streamline prompt handling and additional information e…
Sy0307 Feb 10, 2026
8fb1c61
[~] Refactor: Simplify payload handling and enhance metadata manageme…
Sy0307 Feb 10, 2026
82b2640
[~] Refactor: Remove unused logger initialization and streamline code…
Sy0307 Feb 10, 2026
885ee3d
[~] Refactor: Update TTS prompt handling and introduce new configurat…
Sy0307 Feb 10, 2026
ad6b676
[~] Fix: Enhance TTS processing to fix audio overlap issues
Sy0307 Feb 10, 2026
a304e9e
[~] Style: Fix code format errors of pre-commit
Sy0307 Feb 10, 2026
6eea932
[~] Style: Re-fix code format errors of pre-commit
Sy0307 Feb 10, 2026
3a06119
[~] Refactor: Remove Qwen3 TTS model files and update registry to ref…
Sy0307 Feb 10, 2026
14a9ddc
[~] Refactor: Optimize TTS model initialization and enhance configura…
Sy0307 Feb 11, 2026
4965b84
[~] Feat: Implement ref_audio resolution for TTS processing to solve …
Sy0307 Feb 11, 2026
6fa6afc
[~] Feat: Improve codec mask handling in Qwen3 TTS
Sy0307 Feb 11, 2026
44693a0
[~] Refactor: Clean up additional information handling in OmniGenerat…
Sy0307 Feb 12, 2026
03f66bd
[~] Style: Format error fixed
Sy0307 Feb 12, 2026
3316f49
[+] Feat: Enhance SSRF protection and improve TTS processing for cuda…
Sy0307 Feb 13, 2026
1cad15d
[~] Style: Fix pre-commit issue
Sy0307 Feb 14, 2026
b6c1928
[+] Fix: Enhance scheduling logic and chunk processing in OmniGenerat…
Sy0307 Feb 15, 2026
8fbf458
[~] Style: Fix ruff format error
Sy0307 Feb 15, 2026
5530aaa
[+] CI: Add prompt length estimation for Talker stage and refactor in…
Sy0307 Feb 16, 2026
a108e62
[~] CI: Fix unit-test for "talker_mtp_output_key" caller update for q…
Sy0307 Feb 19, 2026
1a1bc1a
Merge branch 'main' into dev/tts_disaggregation
hsliuustc0106 Feb 20, 2026
60b6603
[~] CI: Increase timeout for Omni Model Test step from 15 to 20 minutes
Sy0307 Feb 20, 2026
2773227
Merge branch 'main' into dev/tts_disaggregation
hsliuustc0106 Feb 20, 2026
b6e6972
[~] Fix: Copy input batch request IDs and indices in NPU and GPU mode…
Sy0307 Feb 20, 2026
f9d656f
[+] Style: Add comments to clarify copying operation
Sy0307 Feb 20, 2026
33fbc3b
[-] Build: Remove deprecated configuration file for Qwen3 TTS talker …
Sy0307 Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ steps:
# type: DirectoryOrCreate

- label: "Omni Model Test"
timeout_in_minutes: 15
timeout_in_minutes: 20
depends_on: image-build
commands:
- export VLLM_LOGGING_LEVEL=DEBUG
Expand Down
204 changes: 120 additions & 84 deletions examples/offline_inference/qwen3_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
tasks, then runs Omni generation and saves output wav files.
"""

import logging
import os
from typing import NamedTuple
from typing import Any, NamedTuple

import soundfile as sf
import torch

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser

from vllm_omni import Omni

logger = logging.getLogger(__name__)


class QueryResult(NamedTuple):
"""Container for a prepared Omni request."""
Expand All @@ -24,6 +27,44 @@ class QueryResult(NamedTuple):
model_name: str


def _estimate_prompt_len(
    additional_information: dict[str, Any],
    model_name: str,
    _cache: dict[str, Any] | None = None,
) -> int:
    """Estimate prompt_token_ids placeholder length for the Talker stage.

    The AR Talker replaces all input embeddings via ``preprocess``, so the
    placeholder values are irrelevant but the **length** must match the
    embeddings that ``preprocess`` will produce.

    Args:
        additional_information: Per-request TTS fields. Only ``task_type`` is
            read here (first list element, defaulting to ``"CustomVoice"``);
            the whole dict is forwarded to the talker's estimator.
        model_name: Model id or path passed to ``from_pretrained`` for both the
            tokenizer and the TTS config.
        _cache: Internal memo mapping ``model_name`` to a
            ``(tokenizer, talker_config)`` tuple. When omitted, a single memo
            shared by all calls is used so each model is loaded only once.

    Returns:
        The estimated prompt length, or ``2048`` if the estimate could not be
        computed for any reason (a warning is logged in that case).
    """
    if _cache is None:
        # Keep the shared memo on the function object rather than in a mutable
        # default argument; it is created lazily on the first call.
        _cache = vars(_estimate_prompt_len).setdefault("_shared_cache", {})

    # Best-effort by design: any failure (missing package, download error,
    # unexpected input) falls back to a fixed placeholder length.
    try:
        from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
        from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
            Qwen3TTSTalkerForConditionalGeneration,
        )

        if model_name not in _cache:
            from transformers import AutoTokenizer

            tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
            cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
            _cache[model_name] = (tok, getattr(cfg, "talker_config", None))

        tok, tcfg = _cache[model_name]
        task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
        return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
            additional_information=additional_information,
            task_type=task_type,
            tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
            codec_language_id=getattr(tcfg, "codec_language_id", None),
            spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
        )
    except Exception as exc:
        logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
        return 2048


def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
"""Build CustomVoice sample inputs.

Expand All @@ -34,47 +75,48 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
QueryResult with Omni inputs and the CustomVoice model path.
"""
task_type = "CustomVoice"
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
if use_batch_sample:
texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."]
instructs = ["", "Very happy."]
languages = ["Chinese", "English"]
speakers = ["Vivian", "Ryan"]
inputs = []
for text, instruct, language, speaker in zip(texts, instructs, languages, speakers):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"text": [text],
"instruct": [instruct],
"language": [language],
"speaker": [speaker],
"max_new_tokens": [2048],
}
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"instruct": [instruct],
"language": [language],
"speaker": [speaker],
"max_new_tokens": [2048],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
)
else:
text = "其实我真的有发现,我是一个特别善于观察别人情绪的人。"
language = "Chinese"
speaker = "Vivian"
instruct = "用特别愤怒的语气说"
prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"text": [text],
"language": [language],
"speaker": [speaker],
"instruct": [instruct],
"max_new_tokens": [2048],
}
inputs = {
"prompt": prompts,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"speaker": [speaker],
"instruct": [instruct],
"max_new_tokens": [2048],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
model_name=model_name,
)


Expand All @@ -88,6 +130,7 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
QueryResult with Omni inputs and the VoiceDesign model path.
"""
task_type = "VoiceDesign"
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign"
if use_batch_sample:
texts = [
"哥哥,你回来啦,人家等了你好久好久了,要抱抱!",
Expand All @@ -100,39 +143,39 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
languages = ["Chinese", "English"]
inputs = []
for text, instruct, language in zip(texts, instructs, languages):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
}
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
)
else:
text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!"
instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。"
language = "Chinese"
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
}
inputs = {
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
model_name=model_name,
)


Expand All @@ -147,6 +190,7 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
QueryResult with Omni inputs and the Base model path.
"""
task_type = "Base"
model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
ref_audio_single = ref_audio_path_1
ref_text_single = (
Expand All @@ -163,38 +207,38 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
syn_lang_batch = ["Chinese", "English"]
inputs = []
for text, language in zip(syn_text_batch, syn_lang_batch):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [text],
"language": [language],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
}
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [text],
"language": [language],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
)
else:
prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n"
additional_information = {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [syn_text_single],
"language": [syn_lang_single],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
}
inputs = {
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [syn_text_single],
"language": [syn_lang_single],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
},
"prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
"additional_information": additional_information,
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
model_name=model_name,
)


Expand Down Expand Up @@ -223,30 +267,22 @@ def main(args):
stage_init_timeout=args.stage_init_timeout,
)

sampling_params = SamplingParams(
temperature=0.9,
top_p=1.0,
top_k=50,
max_tokens=2048,
seed=42,
detokenize=False,
repetition_penalty=1.05,
)

sampling_params_list = [
sampling_params,
]

output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
os.makedirs(output_dir, exist_ok=True)

omni_generator = omni.generate(query_result.inputs, sampling_params_list)
omni_generator = omni.generate(query_result.inputs, sampling_params_list=None)
for stage_outputs in omni_generator:
for output in stage_outputs.request_output:
request_id = output.request_id
audio_tensor = output.outputs[0].multimodal_output["audio"]
audio_data = output.outputs[0].multimodal_output["audio"]
# async_chunk mode returns a list of chunks; concatenate them.
if isinstance(audio_data, list):
audio_tensor = torch.cat(audio_data, dim=-1)
else:
audio_tensor = audio_data
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
sr_val = output.outputs[0].multimodal_output["sr"]
audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
# Convert to numpy array and ensure correct format
audio_numpy = audio_tensor.float().detach().cpu().numpy()

Expand Down
10 changes: 6 additions & 4 deletions tests/entrypoints/openai_api/test_serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,12 @@ def test_is_tts_model(self, speech_server):
speech_server.engine_client.stage_list = [mock_stage]
assert speech_server._is_tts_model() is True

def test_build_tts_prompt(self, speech_server):
"""Test TTS prompt format."""
prompt = speech_server._build_tts_prompt("Hello")
assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n"
def test_estimate_prompt_len_fallback(self, speech_server):
"""Test prompt length estimation falls back to 2048 when model is unavailable."""
tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]}
result = speech_server._estimate_prompt_len(tts_params)
# Without a real model, it should fall back to 2048.
assert result == 2048

def test_validate_tts_request_basic(self, speech_server):
"""Test basic validation cases."""
Expand Down
1 change: 1 addition & 0 deletions tests/worker/test_omni_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))

runner.talker_mtp = DummyTalkerMTP()
runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes")
runner.vllm_config = object()

# Provide a minimal implementation that returns the expected 4-tuple.
Expand Down
16 changes: 16 additions & 0 deletions vllm_omni/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class OmniModelConfig(ModelConfig):
}
)
omni_kv_config: dict | None = None
codec_frame_rate_hz: float | None = None

@property
def registry(self):
Expand Down Expand Up @@ -128,6 +129,21 @@ def __post_init__(
video_pruning_rate=video_pruning_rate,
)

# Qwen3-TTS: infer codec frame rate from the model config for online serving.
if self.codec_frame_rate_hz is None and self.model_arch == "Qwen3TTSTalkerForConditionalGenerationARVLLM":
talker_cfg = getattr(self.hf_config, "talker_config", None)
if isinstance(talker_cfg, dict):
pos_per_sec = talker_cfg.get("position_id_per_seconds")
else:
pos_per_sec = getattr(talker_cfg, "position_id_per_seconds", None)
if pos_per_sec is not None:
try:
fps = float(pos_per_sec)
except Exception:
fps = None
if fps is not None and fps > 0:
self.codec_frame_rate_hz = fps

# Override hf_text_config with omni-specific logic for multi-stage models
# (e.g., thinker_config, talker_config)
new_hf_text_config = self.draw_hf_text_config()
Expand Down
Loading