Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 201 additions & 7 deletions examples/offline_inference/qwen3_tts/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

Provides single and batch sample inputs for CustomVoice, VoiceDesign, and Base
tasks, then runs Omni generation and saves output wav files.

Also includes streaming generation tests for verifying streaming consistency.
"""

import json
import os
import time
from typing import NamedTuple

import torch
import soundfile as sf

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
Expand Down Expand Up @@ -136,6 +140,79 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
)


def get_streaming_test_query(use_batch_sample: bool = False) -> QueryResult:
    """Assemble CustomVoice inputs that request streaming generation.

    Every request sets ``stream`` to ``True`` with a ``chunk_size`` of 5 and a
    ``left_context_size`` of 25, so the same prompts can be used to exercise
    the streaming path.

    Args:
        use_batch_sample: When True, build a list of three requests (mixed
            Chinese/English); otherwise build a single request dict.

    Returns:
        QueryResult holding the Omni inputs and the CustomVoice model name.
    """
    task_type = "CustomVoice"
    model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

    if not use_batch_sample:
        # Single request with a moderate generation budget.
        text = "这是一个流式生成测试的例子,我们来验证流式生成和批量生成的一致性。"
        single_request = {
            "prompt": f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n",
            "additional_information": {
                "task_type": [task_type],
                "text": [text],
                "language": ["Chinese"],
                "speaker": ["Vivian"],
                "instruct": ["用清晰自然的语气说"],
                "max_new_tokens": [80],
                "stream": [True],
                "chunk_size": [5],
                "left_context_size": [25],
            },
        }
        return QueryResult(inputs=single_request, model_name=model_name)

    # Batch: one (text, instruct, language, speaker) row per request.
    batch_rows = [
        ("其实我真的有发现,我是一个特别善于观察别人情绪的人。", "", "Chinese", "Vivian"),
        (
            "She said she would be here by noon, but I'm starting to worry.",
            "Slightly worried tone.",
            "English",
            "Ryan",
        ),
        ("今天天气真不错,我们一起去公园散步吧!", "开心愉快的语气", "Chinese", "Vivian"),
    ]
    batch_requests = [
        {
            "prompt": f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n",
            "additional_information": {
                "task_type": [task_type],
                "text": [text],
                "instruct": [instruct],
                "language": [language],
                "speaker": [speaker],
                "max_new_tokens": [100],  # shorter budget for the batch case
                "stream": [True],
                "chunk_size": [5],
                "left_context_size": [25],
            },
        }
        for text, instruct, language, speaker in batch_rows
    ]
    return QueryResult(inputs=batch_requests, model_name=model_name)


def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> QueryResult:
"""Build Base (voice clone) sample inputs.

Expand Down Expand Up @@ -198,6 +275,108 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que
)


def run_streaming_generation_test(
    omni,
    query_result,
    sampling_params_list,
    output_dir,
    *,
    default_sample_rate=12000,
):
    """Run streaming generation, save every audio chunk, and report a summary.

    Each yielded output that carries an ``"audio"`` entry is written to
    ``streaming_chunk_<idx>_req_<request_id>.wav`` inside ``output_dir``.
    Outputs without audio are skipped. When the generator is exhausted a
    ``streaming_test_results.json`` summary is written to the same directory
    and a human-readable summary is printed.

    Args:
        omni: Object exposing ``generate(inputs, sampling_params_list)`` that
            returns an iterable of outputs with ``request_id`` and
            ``multimodal_output`` attributes.
        query_result: Object whose ``inputs`` attribute is passed to
            ``omni.generate``.
        sampling_params_list: Sampling parameters forwarded to
            ``omni.generate``.
        output_dir: Directory for the chunk wav files and the JSON summary.
            It is created if it does not exist.
        default_sample_rate: Sample rate used when an output has no ``"sr"``
            entry. Defaults to 12000.
    """
    banner = "=" * 60
    print(banner)
    print("STREAMING GENERATION TEST")
    print(banner)

    # This function writes files itself, so it must not rely on the caller
    # having created the directory beforehand.
    os.makedirs(output_dir, exist_ok=True)

    # Track timing and chunk information.
    test_results = {
        "start_time": time.time(),
        "chunks": [],
        "total_audio_generated": 0,
        "total_chunks": 0,
    }
    # Summed per chunk so that chunks with differing sample rates are still
    # accounted for correctly.
    total_audio_duration = 0.0
    chunk_idx = 0

    for stage_outputs in omni.generate(query_result.inputs, sampling_params_list):
        # NOTE: the timer starts after the chunk has been yielded, so the
        # per-chunk "processing_time" covers conversion and saving only, not
        # the time spent generating the chunk.
        chunk_start_time = time.time()

        multimodal_output = stage_outputs.multimodal_output
        if not multimodal_output or "audio" not in multimodal_output:
            continue

        request_id = stage_outputs.request_id
        sample_rate_value = multimodal_output.get("sr")
        # int() accepts both a one-element tensor and a plain number.
        audio_samplerate = default_sample_rate if sample_rate_value is None else int(sample_rate_value)

        # Convert to a flat numpy array.
        audio_numpy = multimodal_output["audio"].float().detach().cpu().numpy()
        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.flatten()

        num_samples = len(audio_numpy)
        chunk_duration = num_samples / audio_samplerate

        # Save chunk audio.
        chunk_output_wav = os.path.join(output_dir, f"streaming_chunk_{chunk_idx:03d}_req_{request_id}.wav")
        sf.write(chunk_output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")

        # Record chunk information.
        chunk_info = {
            "chunk_idx": chunk_idx,
            "request_id": request_id,
            "audio_length_samples": num_samples,
            "audio_duration_seconds": chunk_duration,
            "processing_time": time.time() - chunk_start_time,
            "output_file": chunk_output_wav,
        }
        test_results["chunks"].append(chunk_info)
        test_results["total_audio_generated"] += num_samples
        total_audio_duration += chunk_duration

        print(
            f"Chunk {chunk_idx:3d} | Request {request_id} | "
            f"Samples: {num_samples:6d} | "
            f"Duration: {chunk_duration:.2f}s | "
            f"Processing: {chunk_info['processing_time']:.3f}s"
        )

        chunk_idx += 1

    test_results["end_time"] = time.time()
    test_results["total_chunks"] = chunk_idx
    test_results["total_processing_time"] = test_results["end_time"] - test_results["start_time"]
    test_results["total_audio_duration"] = total_audio_duration

    # Save test results summary; every recorded value is already a plain
    # Python type, so no conversion pass is needed before dumping.
    results_file = os.path.join(output_dir, "streaming_test_results.json")
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(test_results, f, indent=2, ensure_ascii=False)

    # Print summary.
    total_time = test_results["total_processing_time"]
    print("\n" + banner)
    print("STREAMING TEST SUMMARY")
    print(banner)
    print(f"Total chunks generated: {test_results['total_chunks']}")
    print(f"Total audio samples: {test_results['total_audio_generated']}")
    print(f"Total audio duration: {test_results['total_audio_duration']:.2f} seconds")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(
        f"Average chunk processing time: "
        f"{total_time / max(1, test_results['total_chunks']):.3f} seconds"
    )
    if total_time > 0:
        # Audio seconds produced per wall-clock second; values above 1 mean
        # generation ran faster than playback.
        print(f"Real-time factor: {test_results['total_audio_duration'] / total_time:.2f}x")
    print(f"Results saved to: {results_file}")
    print(banner)


def main(args):
"""Run offline inference with Omni using prepared sample inputs.

Expand All @@ -212,6 +391,8 @@ def main(args):
use_batch_sample=args.use_batch_sample,
mode_tag=args.mode_tag,
)
elif args.query_type == "StreamingTest":
query_result = query_func(use_batch_sample=args.use_batch_sample)
else:
query_result = query_func()

Expand Down Expand Up @@ -240,13 +421,19 @@ def main(args):
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
os.makedirs(output_dir, exist_ok=True)

omni_generator = omni.generate(query_result.inputs, sampling_params_list)
for stage_outputs in omni_generator:
for output in stage_outputs.request_output:
request_id = output.request_id
audio_tensor = output.outputs[0].multimodal_output["audio"]
# Use streaming test runner for streaming query types
if args.query_type in {"StreamingTest"}:
run_streaming_generation_test(omni, query_result, sampling_params_list, output_dir)
else:
omni_generator = omni.generate(query_result.inputs, sampling_params_list)
for stage_outputs in omni_generator:
request_id = stage_outputs.request_id
multimodal_output = stage_outputs.multimodal_output
if not multimodal_output or "audio" not in multimodal_output:
continue
audio_tensor = multimodal_output["audio"]
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
audio_samplerate = multimodal_output.get("sr", torch.tensor(12000)).item()
# Convert to numpy array and ensure correct format
audio_numpy = audio_tensor.float().detach().cpu().numpy()

Expand Down Expand Up @@ -365,6 +552,12 @@ def parse_args():
choices=["icl", "xvec_only"],
help="Mode tag for Base query x_vector_only_mode (default: icl).",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for audio files (overrides --output-wav).",
)

return parser.parse_args()

Expand All @@ -373,6 +566,7 @@ def parse_args():
"CustomVoice": get_custom_voice_query,
"VoiceDesign": get_voice_design_query,
"Base": get_base_query,
"StreamingTest": get_streaming_test_query,
}


Expand Down
1 change: 1 addition & 0 deletions tests/model_executor/models/qwen3_tts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Tests for Qwen3 TTS model
Loading