diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 93aeba3ca5f..14ecfa17ca8 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -2,11 +2,15 @@ Provides single and batch sample inputs for CustomVoice, VoiceDesign, and Base tasks, then runs Omni generation and saves output wav files. + +Also includes streaming generation tests for verifying streaming consistency. """ +import json import os +import time from typing import NamedTuple - +import torch import soundfile as sf os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -136,6 +140,79 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: ) +def get_streaming_test_query(use_batch_sample: bool = False) -> QueryResult: + """Build streaming generation test inputs for comparing streaming vs blocking modes. + + This function creates test cases that can be used to verify streaming generation + consistency, similar to the test cases in test_streaming_consistency.py. + + Args: + use_batch_sample: When True, return a batch of prompts; otherwise a single prompt. + + Returns: + QueryResult with Omni inputs and the model path for streaming tests. + """ + task_type = "CustomVoice" + + if use_batch_sample: + # Batch test case - multiple texts with different parameters + texts = [ + "其实我真的有发现,我是一个特别善于观察别人情绪的人。", + "She said she would be here by noon, but I'm starting to worry.", + "今天天气真不错,我们一起去公园散步吧!", + ] + instructs = ["", "Slightly worried tone.", "开心愉快的语气"] + languages = ["Chinese", "English", "Chinese"] + speakers = ["Vivian", "Ryan", "Vivian"] + + inputs = [] + for text, instruct, language, speaker in zip(texts, instructs, languages, speakers): + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + inputs.append( + { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "instruct": [instruct], + "language": [language], + "speaker": [speaker], + "max_new_tokens": [100], # Shorter for testing + "stream": [True], # Enable streaming generation + "chunk_size": [5], # Chunk size for streaming + "left_context_size": [25], # Left context size for streaming + }, + } + ) + else: + # Single test case + text = "这是一个流式生成测试的例子,我们来验证流式生成和批量生成的一致性。" + language = "Chinese" + speaker = "Vivian" + instruct = "用清晰自然的语气说" + prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + inputs = { + "prompt": prompt, + "additional_information": { + "task_type": [task_type], + "text": [text], + "language": [language], + "speaker": [speaker], + "instruct": [instruct], + "max_new_tokens": [80], # Moderate length for testing + "stream": [True], # Enable streaming generation + "chunk_size": [5], # Chunk size for streaming + "left_context_size": [25], # Left context size for streaming + }, + } + + return QueryResult( + inputs=inputs, + model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + ) + + def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> QueryResult: """Build Base (voice clone) sample inputs. @@ -198,6 +275,108 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que ) +def run_streaming_generation_test(omni, query_result, sampling_params_list, output_dir): + """Run streaming generation test and save results with detailed analysis. + + This function runs the streaming generation and provides detailed output + about chunk information, timing, and consistency checks. + + Args: + omni: Omni instance + query_result: QueryResult with streaming test inputs + sampling_params_list: List of SamplingParams + output_dir: Output directory for saving results + """ + print("=" * 60) + print("STREAMING GENERATION TEST") + print("=" * 60) + + # Track timing and chunk information + test_results = { + "start_time": time.time(), + "chunks": [], + "total_audio_generated": 0, + "total_chunks": 0, + } + + omni_generator = omni.generate(query_result.inputs, sampling_params_list) + + chunk_idx = 0 + audio_samplerate = 12000 # default value + for stage_outputs in omni_generator: + chunk_start_time = time.time() + + request_id = stage_outputs.request_id + multimodal_output = stage_outputs.multimodal_output + if not multimodal_output or "audio" not in multimodal_output: + continue + audio_tensor = multimodal_output["audio"] + audio_samplerate = multimodal_output.get("sr", torch.tensor(12000)).item() + + # Convert to numpy array + audio_numpy = audio_tensor.float().detach().cpu().numpy() + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save chunk audio + chunk_output_wav = os.path.join(output_dir, f"streaming_chunk_{chunk_idx:03d}_req_{request_id}.wav") + sf.write(chunk_output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") + + # Record chunk information + chunk_info = { + "chunk_idx": chunk_idx, + "request_id": request_id, + "audio_length_samples": len(audio_numpy), + "audio_duration_seconds": len(audio_numpy) / audio_samplerate, + "processing_time": time.time() - chunk_start_time, + "output_file": chunk_output_wav, + } + + test_results["chunks"].append(chunk_info) + test_results["total_audio_generated"] += len(audio_numpy) + + print( + f"Chunk {chunk_idx:3d} | Request {request_id} | " + f"Samples: {len(audio_numpy):6d} | " + f"Duration: {len(audio_numpy) / audio_samplerate:.2f}s | " + f"Processing: {chunk_info['processing_time']:.3f}s" + ) + + chunk_idx += 1 + + test_results["end_time"] = time.time() + test_results["total_chunks"] = chunk_idx + test_results["total_processing_time"] = test_results["end_time"] - test_results["start_time"] + test_results["total_audio_duration"] = test_results["total_audio_generated"] / audio_samplerate + + # Save test results summary + results_file = os.path.join(output_dir, "streaming_test_results.json") + with open(results_file, "w", encoding="utf-8") as f: + # Convert chunk info output_file paths to strings for JSON serialization + serializable_results = test_results.copy() + serializable_results["chunks"] = [ + {k: str(v) if k == "output_file" else v for k, v in chunk.items()} for chunk in test_results["chunks"] + ] + json.dump(serializable_results, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 60) + print("STREAMING TEST SUMMARY") + print("=" * 60) + print(f"Total chunks generated: {test_results['total_chunks']}") + print(f"Total audio samples: {test_results['total_audio_generated']}") + print(f"Total audio duration: {test_results['total_audio_duration']:.2f} seconds") + print(f"Total processing time: {test_results['total_processing_time']:.2f} seconds") + print( + f"Average chunk processing time: " + f"{test_results['total_processing_time'] / max(1, test_results['total_chunks']):.3f} seconds" + ) + if test_results["total_processing_time"] > 0: + print(f"Real-time factor: {test_results['total_audio_duration'] / test_results['total_processing_time']:.2f}x") + print(f"Results saved to: {results_file}") + print("=" * 60) + + def main(args): """Run offline inference with Omni using prepared sample inputs. @@ -212,6 +391,8 @@ def main(args): use_batch_sample=args.use_batch_sample, mode_tag=args.mode_tag, ) + elif args.query_type == "StreamingTest": + query_result = query_func(use_batch_sample=args.use_batch_sample) else: query_result = query_func() @@ -240,13 +421,19 @@ def main(args): output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) - omni_generator = omni.generate(query_result.inputs, sampling_params_list) - for stage_outputs in omni_generator: - for output in stage_outputs.request_output: - request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output["audio"] + # Use streaming test runner for streaming query types + if args.query_type in {"StreamingTest"}: + run_streaming_generation_test(omni, query_result, sampling_params_list, output_dir) + else: + omni_generator = omni.generate(query_result.inputs, sampling_params_list) + for stage_outputs in omni_generator: + request_id = stage_outputs.request_id + multimodal_output = stage_outputs.multimodal_output + if not multimodal_output or "audio" not in multimodal_output: + continue + audio_tensor = multimodal_output["audio"] output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - audio_samplerate = output.outputs[0].multimodal_output["sr"].item() + audio_samplerate = multimodal_output.get("sr", torch.tensor(12000)).item() # Convert to numpy array and ensure correct format audio_numpy = audio_tensor.float().detach().cpu().numpy() @@ -365,6 +552,12 @@ def parse_args(): choices=["icl", "xvec_only"], help="Mode tag for Base query x_vector_only_mode (default: icl).", ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory for audio files (overrides --output-wav).", + ) return parser.parse_args() @@ -373,6 +566,7 @@ def parse_args(): "CustomVoice": get_custom_voice_query, "VoiceDesign": get_voice_design_query, "Base": get_base_query, + "StreamingTest": get_streaming_test_query, } diff --git a/tests/model_executor/models/qwen3_tts/__init__.py b/tests/model_executor/models/qwen3_tts/__init__.py new file mode 100644 index 00000000000..f21c5682bcb --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/__init__.py @@ -0,0 +1 @@ +# Tests for Qwen3 TTS model diff --git a/tests/model_executor/models/qwen3_tts/test_streaming_vs_nonstreaming.py b/tests/model_executor/models/qwen3_tts/test_streaming_vs_nonstreaming.py new file mode 100644 index 00000000000..b9a8179b1fd --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/test_streaming_vs_nonstreaming.py @@ -0,0 +1,254 @@ +""" +Test streaming vs non-streaming codec codes consistency. + +This script verifies that streaming and non-streaming generation produce +identical codec codes (token_id level comparison). + +Usage: + python test_streaming_vs_nonstreaming.py + + # Override model path + MODEL_PATH=/your/path python test_streaming_vs_nonstreaming.py +""" + +import os +import sys +from pathlib import Path + +import torch + +# Add project root to path +sys.path.insert(0, str(Path(__file__).resolve().parents[4])) + +from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts import Qwen3TTSModel + +# Model path - can be overridden via MODEL_PATH environment variable +MODEL_PATH = os.environ.get("MODEL_PATH", "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice") + + +def test_streaming_vs_nonstreaming_codes(): + """ + Test that streaming and non-streaming generation produce identical codec codes. + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Load model + print(f"Loading model from {MODEL_PATH}...") + model = Qwen3TTSModel.from_pretrained( + MODEL_PATH, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map=device, + ) + print("Model loaded successfully!") + + tts_model_type = model.model.tts_model_type + print(f"Model type: {tts_model_type}") + + # Common parameters + text = "这是一个流式生成测试的例子,我们来验证流式生成和批量生成的一致性。" + chunk_size = 25 + max_new_tokens = 2048 + + # Prepare common generation kwargs based on model type + if tts_model_type == "custom_voice": + speaker = "Vivian" + language = "Chinese" + instruct = "用清晰自然的语气说" + elif tts_model_type == "voice_design": + speaker = None + language = "Chinese" + instruct = "用清晰自然的语气说" + elif tts_model_type == "base": + speaker = None + language = "Chinese" + instruct = None + else: + raise ValueError(f"Unknown model type: {tts_model_type}") + + print("\n" + "=" * 60) + print("CODEC CODES COMPARISON: Streaming vs Non-Streaming") + print(f" text={text!r}") + print(f" speaker={speaker!r}") + print(f" language={language!r}") + print(f" instruct={instruct!r}") + print(f" chunk_size={chunk_size}") + print(f" max_new_tokens={max_new_tokens}") + print(" do_sample=False (greedy)") + print("=" * 60) + + # ========== Prepare common inputs ========== + input_ids = model._tokenize_texts([model._build_assistant_text(text)]) + print(f"\n[INPUT] input_ids shape: {input_ids[0].shape}") + + instruct_ids = [None] + if instruct is not None and instruct != "": + instruct_ids = [model._tokenize_texts([model._build_instruct_text(instruct)])[0]] + print(f"[INPUT] instruct_ids shape: {instruct_ids[0].shape}") + + # Use greedy decoding for deterministic comparison + gen_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": False, + "repetition_penalty": 1.05, + } + print(f"[INPUT] gen_kwargs: {gen_kwargs}") + + # ========== NON-STREAMING: Get codec codes via generate() ========== + print("\n" + "-" * 60) + print(">>> NON-STREAMING: Generating codec codes with generate()...") + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + + non_streaming_codes_list, _ = model.model.generate( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=[language], + speakers=[speaker], + **gen_kwargs, + ) + + non_streaming_codes = non_streaming_codes_list[0] # [T, K] + print(f"[NON-STREAMING] codec_codes shape: {non_streaming_codes.shape}") + print(f"[NON-STREAMING] Total tokens: {non_streaming_codes.shape[0]}") + print(f"[NON-STREAMING] First 10 tokens (first codebook):\n {non_streaming_codes[:10, 0].tolist()}") + + # ========== STREAMING: Get codec codes via generate_streaming_iter() ========== + print("\n" + "-" * 60) + print(">>> STREAMING: Generating codec codes with talker.generate_streaming_iter()...") + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + + # Prepare talker inputs (same as generate_streaming does internally) + talker_input_embeds, trailing_text_hiddens, tts_pad_embed = model.model._prepare_talker_inputs( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=[language], + speakers=[speaker], + ) + + # Collect all codec codes from streaming iterator + streaming_all_codes = [] + chunk_count = 0 + + for chunk_output in model.model.talker.generate_streaming_iter( + inputs_embeds=talker_input_embeds, + attention_mask=torch.ones( + (1, talker_input_embeds.shape[1]), + device=talker_input_embeds.device, + dtype=torch.long, + ), + trailing_text_hidden=trailing_text_hiddens, + tts_pad_embed=tts_pad_embed, + chunk_size=chunk_size, + max_new_tokens=max_new_tokens, + min_new_tokens=2, + do_sample=False, + repetition_penalty=1.05, + eos_token_id=model.model.config.talker_config.codec_eos_token_id, + suppress_tokens=[ + i + for i in range( + model.model.config.talker_config.vocab_size - 1024, model.model.config.talker_config.vocab_size + ) + if i not in (model.model.config.talker_config.codec_eos_token_id,) + ], + ): + chunk_count += 1 + streaming_all_codes.append(chunk_output.codec_codes) + print( + f"[STREAMING] Chunk {chunk_count}: codes shape={chunk_output.codec_codes.shape}, " + f"total_generated={chunk_output.total_generated}, finished={chunk_output.is_finished}" + ) + + # Concatenate all streaming codes + streaming_codes = torch.cat(streaming_all_codes, dim=0) if streaming_all_codes else torch.tensor([]) + print(f"\n[STREAMING] Total codec_codes shape: {streaming_codes.shape}") + print(f"[STREAMING] Total tokens: {streaming_codes.shape[0]}") + print(f"[STREAMING] First 10 tokens (first codebook):\n {streaming_codes[:10, 0].tolist()}") + + # ========== COMPARE AND ASSERT (excluding EOS token) ========== + print("\n" + "=" * 60) + print("COMPARISON RESULTS (TOKEN_ID LEVEL):") + print("=" * 60) + print(f"Non-streaming tokens (raw): {non_streaming_codes.shape[0]}") + print(f"Streaming tokens (raw): {streaming_codes.shape[0]}") + + # Move to CPU for comparison + streaming_codes_cpu = streaming_codes.cpu() + non_streaming_codes_cpu = non_streaming_codes.cpu() + + # Get EOS token ID + eos_token_id = model.model.config.talker_config.codec_eos_token_id + print(f"EOS token ID: {eos_token_id}") + + # Remove EOS tokens from the end (check first codebook for EOS) + def remove_trailing_eos(codes, eos_id): + """Remove trailing EOS tokens from codes.""" + if codes.shape[0] == 0: + return codes + # Check from the end, remove all trailing EOS tokens + end_idx = codes.shape[0] + while end_idx > 0 and codes[end_idx - 1, 0].item() == eos_id: + end_idx -= 1 + return codes[:end_idx] + + non_streaming_codes_no_eos = remove_trailing_eos(non_streaming_codes_cpu, eos_token_id) + streaming_codes_no_eos = remove_trailing_eos(streaming_codes_cpu, eos_token_id) + + print(f"Non-streaming tokens (no EOS): {non_streaming_codes_no_eos.shape[0]}") + print(f"Streaming tokens (no EOS): {streaming_codes_no_eos.shape[0]}") + + # Assert token counts match (excluding EOS) + assert non_streaming_codes_no_eos.shape[0] == streaming_codes_no_eos.shape[0], ( + f"Token count mismatch (excluding EOS): non-streaming={non_streaming_codes_no_eos.shape[0]}, " + f"streaming={streaming_codes_no_eos.shape[0]}" + ) + print("✓ Token counts match (excluding EOS)!") + + # Assert all codebooks match + num_codebooks = min(non_streaming_codes_no_eos.shape[1], streaming_codes_no_eos.shape[1]) + total_diff_count = 0 + first_diff_info = None + + for cb in range(num_codebooks): + diff_mask = non_streaming_codes_no_eos[:, cb] != streaming_codes_no_eos[:, cb] + diff_count = diff_mask.sum().item() + + if diff_count > 0: + total_diff_count += diff_count + if first_diff_info is None: + diff_positions = torch.where(diff_mask)[0] + pos = diff_positions[0].item() + first_diff_info = { + "codebook": cb, + "position": pos, + "non_streaming": non_streaming_codes_no_eos[pos, cb].item(), + "streaming": streaming_codes_no_eos[pos, cb].item(), + } + + if total_diff_count > 0: + print(f"\n✗ Found {total_diff_count} token differences!") + print(f" First diff at codebook {first_diff_info['codebook']}, position {first_diff_info['position']}:") + print(f" Non-streaming: {first_diff_info['non_streaming']}") + print(f" Streaming: {first_diff_info['streaming']}") + + assert total_diff_count == 0, ( + f"Token mismatch: {total_diff_count} differences found. " + f"First diff at codebook {first_diff_info['codebook']}, position {first_diff_info['position']}: " + f"non-streaming={first_diff_info['non_streaming']}, streaming={first_diff_info['streaming']}" + ) + + print(f"✓ All {non_streaming_codes_no_eos.shape[0]} tokens are IDENTICAL across all {num_codebooks} codebooks!") + print("\n" + "=" * 60) + print("TEST PASSED!") + print("=" * 60) + + return True + + +if __name__ == "__main__": + success = test_streaming_vs_nonstreaming_codes() + sys.exit(0 if success else 1) diff --git a/vllm_omni/metrics/stats.py b/vllm_omni/metrics/stats.py index 037db2ea8da..d767229e21a 100644 --- a/vllm_omni/metrics/stats.py +++ b/vllm_omni/metrics/stats.py @@ -225,8 +225,13 @@ def record_audio_generated_frames( output_to_yield.final_output_type == "audio" and finished and (multimodal_output := output_to_yield.multimodal_output.get("audio")) is not None + and len(multimodal_output) > 0 ): - nframes = int(multimodal_output[-1].shape[0]) + last_output = multimodal_output[-1] + # Handle case where last_output might be an empty tuple or doesn't have valid shape + if not hasattr(last_output, "shape") or len(last_output.shape) == 0: + return + nframes = int(last_output.shape[0]) stage_events_for_req = self.stage_events.get(request_id, []) if stage_events_for_req: for stage_event in stage_events_for_req: diff --git a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py index 1e759a8d2b4..c176368abaf 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py @@ -15,9 +15,13 @@ import json import os -from collections.abc import Callable +import queue +import threading +from collections.abc import Callable, Generator, Iterator from dataclasses import dataclass +from typing import Any +import numpy as np import torch from librosa.filters import mel as librosa_mel_fn from torch import nn @@ -44,7 +48,7 @@ from transformers.utils.hub import cached_file from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific - +from vllm_omni.outputs import AsyncDecodingPipeline, StreamingChunkOutput from .configuration_qwen3_tts import ( Qwen3TTSConfig, Qwen3TTSSpeakerEncoderConfig, @@ -1324,6 +1328,11 @@ class Qwen3TTSTalkerOutputWithPast(ModelOutput): tts_pad_embed: torch.FloatTensor | None = None + + + + + class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() @@ -1648,6 +1657,9 @@ def forward( # Generate else: last_id_hidden = self.get_input_embeddings()(input_ids) + # Suppress tokens >= 2048 for code_predictor (decoder codebook only has 2048 entries) + CODEC_VALID_MAX = 2048 + code_predictor_suppress_tokens = list(range(CODEC_VALID_MAX, self.code_predictor.config.vocab_size)) predictor_result = self.code_predictor.generate( inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1), max_new_tokens=self.config.num_code_groups - 1, @@ -1657,6 +1669,7 @@ def forward( temperature=subtalker_temperature, output_hidden_states=True, return_dict_in_generate=True, + suppress_tokens=code_predictor_suppress_tokens, ) codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1) codec_hiddens = torch.cat( @@ -1792,6 +1805,244 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed return model_kwargs + @torch.no_grad() + def generate_streaming_iter( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + trailing_text_hidden: torch.Tensor, + tts_pad_embed: torch.Tensor, + chunk_size: int = 25, + max_new_tokens: int = 4096, + do_sample: bool = True, + top_k: int = 50, + top_p: float = 1.0, + temperature: float = 0.9, + subtalker_dosample: bool = True, + subtalker_top_k: int = 50, + subtalker_top_p: float = 1.0, + subtalker_temperature: float = 0.9, + eos_token_id: int | None = None, + repetition_penalty: float = 1.05, + suppress_tokens: list[int] | None = None, + **kwargs, + ) -> Generator[StreamingChunkOutput, None, None]: + """ + Streaming generation iterator that yields codec tokens in chunks. + + This method generates talker codes incrementally and yields them in chunks + of `chunk_size` tokens, enabling real-time processing. + + The sampling logic is designed to match HuggingFace GenerationMixin.generate() + to ensure identical outputs between streaming and blocking modes. + + Args: + inputs_embeds: Input embeddings for the talker model + attention_mask: Attention mask for the input + trailing_text_hidden: Hidden states for trailing text + tts_pad_embed: Padding embedding + chunk_size: Number of tokens per chunk to yield + max_new_tokens: Maximum number of new tokens to generate + do_sample: Whether to use sampling + top_k: Top-k sampling parameter + top_p: Top-p sampling parameter + temperature: Sampling temperature + subtalker_dosample: Whether to use sampling for subtalker + subtalker_top_k: Subtalker top-k parameter + subtalker_top_p: Subtalker top-p parameter + subtalker_temperature: Subtalker temperature + eos_token_id: End of sequence token ID + repetition_penalty: Repetition penalty + suppress_tokens: List of token ids to suppress (set to -inf) + + Yields: + StreamingChunkOutput: Contains codec codes and metadata for each chunk + """ + if eos_token_id is None: + eos_token_id = self.config.codec_eos_token_id + + # Suppress certain tokens (same as blocking generate) + if suppress_tokens is None: + suppress_tokens = [ + i for i in range(self.config.vocab_size - 1024, self.config.vocab_size) if i not in (eos_token_id,) + ] + + # Convert suppress_tokens to tensor for efficient masking + suppress_tokens_tensor = torch.tensor(suppress_tokens, device=inputs_embeds.device) + + # Initialize cache and generation state + past_key_values = DynamicCache() + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + + # Prefill phase + outputs = self.forward( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + cache_position=cache_position, + trailing_text_hidden=trailing_text_hidden, + tts_pad_embed=tts_pad_embed, + generation_step=0, + subtalker_dosample=subtalker_dosample, + subtalker_top_k=subtalker_top_k, + subtalker_top_p=subtalker_top_p, + subtalker_temperature=subtalker_temperature, + output_hidden_states=True, + ) + + past_key_values = outputs.past_key_values + past_hidden = outputs.past_hidden + generation_step = outputs.generation_step + + # Token accumulator for chunking + accumulated_codes: list[torch.Tensor] = [] + accumulated_hidden: list[torch.Tensor] = [] + total_generated = 0 + chunk_idx = 0 + + # Repetition penalty tracking - use tensor for efficient lookup + generated_ids: list[int] = [] + + for step in range(max_new_tokens): + # Get logits for next token + logits = outputs.logits[:, -1, :].clone() # [batch, vocab], clone to avoid in-place modification + + # Step 1: Apply repetition penalty (same as HuggingFace RepetitionPenaltyLogitsProcessor) + if repetition_penalty != 1.0 and len(generated_ids) > 0: + # Create a tensor of unique generated ids for vectorized operation + unique_ids = list(set(generated_ids)) + for token_id in unique_ids: + # HuggingFace style: score < 0 -> multiply by penalty, score > 0 -> divide by penalty + if logits[0, token_id] < 0: + logits[0, token_id] = logits[0, token_id] * repetition_penalty + else: + logits[0, token_id] = logits[0, token_id] / repetition_penalty + + # Step 2: Apply suppress_tokens (set to -inf so they won't be sampled) + if len(suppress_tokens) > 0: + logits[:, suppress_tokens_tensor] = float("-inf") + + # Step 3: Sample or greedy + if do_sample: + # Apply temperature FIRST (before top-k/top-p, matching HuggingFace TemperatureLogitsWarper) + if temperature != 1.0 and temperature > 0: + logits = logits / temperature + + # Top-k filtering (HuggingFace TopKLogitsWarper style) + if top_k > 0 and top_k < logits.size(-1): + # Get the top-k threshold + top_k_values, _ = torch.topk(logits, min(top_k, logits.size(-1))) + threshold = top_k_values[..., -1, None] + # Remove tokens with logits below threshold + logits = torch.where(logits < threshold, torch.tensor(float("-inf"), device=logits.device), logits) + + # Top-p (nucleus) filtering (HuggingFace TopPLogitsWarper style) + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Keep at least one token (shift right and set first to False) + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = False + + # Scatter back to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + dim=1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, float("-inf")) + + # Sample from the filtered distribution + probs = F.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + # Greedy decoding + next_token = torch.argmax(logits, dim=-1, keepdim=True) + + next_token_id = next_token.item() + generated_ids.append(next_token_id) + + # Check for EOS + is_eos = next_token_id == eos_token_id + + # Update attention mask + attention_mask = torch.cat( + [attention_mask, torch.ones((1, 1), device=attention_mask.device, dtype=attention_mask.dtype)], + dim=1, + ) + cache_position = torch.tensor([cache_position[-1] + 1], device=inputs_embeds.device) + + # Forward pass for next token + outputs = self.forward( + input_ids=next_token, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + cache_position=cache_position, + past_hidden=past_hidden, + trailing_text_hidden=trailing_text_hidden, + tts_pad_embed=tts_pad_embed, + generation_step=generation_step, + subtalker_dosample=subtalker_dosample, + subtalker_top_k=subtalker_top_k, + subtalker_top_p=subtalker_top_p, + subtalker_temperature=subtalker_temperature, + output_hidden_states=True, + ) + + past_key_values = outputs.past_key_values + past_hidden = outputs.past_hidden + generation_step = outputs.generation_step + + # Get codec_ids from hidden_states (contains tuple of (hidden_states, codec_ids)) + codec_ids = outputs.hidden_states[1] # [batch, num_quantizers] + + if codec_ids is not None: + # Check if any codec layer contains EOS token (should terminate) + if (codec_ids == eos_token_id).any(): + is_eos = True + + accumulated_codes.append(codec_ids[0]) # Remove batch dim + if outputs.hidden_states[0] is not None: + accumulated_hidden.append(past_hidden[0, 0, :]) + total_generated += 1 + + # Yield chunk if accumulated enough or finished + if len(accumulated_codes) >= chunk_size or is_eos: + if len(accumulated_codes) > 0: + chunk_codes = torch.stack(accumulated_codes, dim=0) # [chunk_size, num_quantizers] + chunk_hidden = torch.stack(accumulated_hidden, dim=0) if accumulated_hidden else None + + yield StreamingChunkOutput( + codec_codes=chunk_codes, + hidden_states=chunk_hidden, + chunk_idx=chunk_idx, + is_finished=is_eos, + total_generated=total_generated, + ) + + accumulated_codes = [] + accumulated_hidden = [] + chunk_idx += 1 + + if is_eos or max_new_tokens is not None and total_generated >= max_new_tokens: + break + + # Yield any remaining tokens + if len(accumulated_codes) > 0: + chunk_codes = torch.stack(accumulated_codes, dim=0) + chunk_hidden = torch.stack(accumulated_hidden, dim=0) if accumulated_hidden else None + + yield StreamingChunkOutput( + codec_codes=chunk_codes, + hidden_states=chunk_hidden, + chunk_idx=chunk_idx, + is_finished=True, + total_generated=total_generated, + ) + class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): config_class = Qwen3TTSConfig @@ -2030,7 +2281,6 @@ def generate( ): talker_kwargs = { "max_new_tokens": max_new_tokens, - "min_new_tokens": 2, "do_sample": do_sample, "top_k": top_k, "top_p": top_p, @@ -2317,10 +2567,359 @@ def generate( return talker_codes_list, talker_hidden_states_list + @torch.no_grad() + def generate_streaming( + self, + input_ids: list[torch.Tensor] | None = None, + instruct_ids: list[torch.Tensor] | None = None, + ref_ids: list[torch.Tensor] | None = None, + voice_clone_prompt: list[dict] = None, + languages: list[str] = None, + speakers: list[str] = None, + non_streaming_mode: bool = False, + max_new_tokens: int = 4096, + do_sample: bool = True, + top_k: int = 50, + top_p: float = 1.0, + temperature: float = 0.9, + subtalker_dosample: bool = True, + subtalker_top_k: int = 50, + subtalker_top_p: float = 1.0, + subtalker_temperature: float = 0.9, + eos_token_id: int | None = None, + repetition_penalty: float = 1.05, + chunk_size: int = 25, + left_context_size: int = 25, + **kwargs: Any, + ) -> Generator[tuple[np.ndarray, bool, int], None, None]: + """ + True streaming generation method that yields audio chunks as tokens are generated. + + This method generates talker codes incrementally and decodes them to audio + in parallel, enabling real-time audio streaming output. Unlike the previous + implementation, this truly streams - it yields audio chunks while still + generating subsequent tokens. + + Args: + input_ids: Input token IDs for the text to synthesize + instruct_ids: Optional instruction token IDs + ref_ids: Optional reference audio token IDs for voice cloning + voice_clone_prompt: Voice cloning prompt configuration + languages: Language codes for each sample + speakers: Speaker IDs for each sample + non_streaming_mode: If True, wait for full text before generating + max_new_tokens: Maximum number of new tokens to generate + do_sample: Whether to use sampling for generation + top_k: Top-k sampling parameter + top_p: Top-p (nucleus) sampling parameter + temperature: Sampling temperature + subtalker_dosample: Whether to use sampling for subtalker + subtalker_top_k: Subtalker top-k parameter + subtalker_top_p: Subtalker top-p parameter + subtalker_temperature: Subtalker temperature + eos_token_id: End of sequence token ID + repetition_penalty: Repetition penalty for generation + chunk_size: Number of codec frames per audio chunk (default 25) + left_context_size: Context frames for smooth chunk boundaries (default 25) + + Yields: + tuple[np.ndarray, bool, int]: (audio_chunk, is_finished, sample_rate) + - audio_chunk: Decoded audio waveform numpy array for this chunk + - is_finished: True if generation is complete + - sample_rate: Audio sample rate + """ + # Prepare input embeddings (same logic as generate method) + talker_input_embeds, trailing_text_hiddens, tts_pad_embed = self._prepare_talker_inputs( + input_ids=input_ids, + instruct_ids=instruct_ids, + ref_ids=ref_ids, + voice_clone_prompt=voice_clone_prompt, + languages=languages, + speakers=speakers, + non_streaming_mode=non_streaming_mode, + ) + + # Get reference codes for voice cloning + ref_code = None + if voice_clone_prompt is not None: + ref_code_list = voice_clone_prompt.get("ref_code", None) + if ref_code_list is not None and ref_code_list[0] is not None: + ref_code = ref_code_list[0].to(self.talker.device) + + # Create async decoding pipeline + with AsyncDecodingPipeline( + speech_tokenizer=self.speech_tokenizer, + ref_code=ref_code, + left_context_size=left_context_size, + ) as decode_pipeline: + # decode_pipeline.start() is called in __enter__ + + # Use streaming iterator to generate tokens + streaming_iter = self.talker.generate_streaming_iter( + inputs_embeds=talker_input_embeds, + attention_mask=torch.ones( + (1, talker_input_embeds.shape[1]), + device=talker_input_embeds.device, + dtype=torch.long, + ), + trailing_text_hidden=trailing_text_hiddens, + tts_pad_embed=tts_pad_embed, + chunk_size=chunk_size, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + top_k=top_k, + top_p=top_p, + temperature=temperature, + subtalker_dosample=subtalker_dosample, + subtalker_top_k=subtalker_top_k, + subtalker_top_p=subtalker_top_p, + subtalker_temperature=subtalker_temperature, + eos_token_id=eos_token_id if eos_token_id is not None else self.config.talker_config.codec_eos_token_id, + repetition_penalty=repetition_penalty, + suppress_tokens=[ + i + for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size) + if i not in (self.config.talker_config.codec_eos_token_id,) + ], + ) + + # Producer: generate chunks and submit for decoding + pending_chunks = 0 + + for chunk_output in streaming_iter: + # Submit chunk for async decoding + decode_pipeline.submit_chunk( + chunk_output.codec_codes, + is_last=chunk_output.is_finished, + ) + pending_chunks += 1 + + # Consumer: yield decoded chunks as they become available + while pending_chunks > 0: + audio, is_last, sr, error = decode_pipeline.get_decoded_chunk(timeout=0.01) + if error is not None: + raise error + if audio is not None: + pending_chunks -= 1 + yield audio, is_last, sr + if is_last: + return + else: + # No decoded chunk available yet, continue generating + break + + # Drain remaining decoded chunks + while pending_chunks > 0: + audio, is_last, sr, error = decode_pipeline.get_decoded_chunk(timeout=1.0) + if error is not None: + raise error + if audio is not None: + pending_chunks -= 1 + yield audio, is_last, sr + + def _prepare_talker_inputs( + self, + input_ids: list[torch.Tensor], + instruct_ids: list[torch.Tensor] | None = None, + ref_ids: list[torch.Tensor] | None = None, + voice_clone_prompt: dict | None = None, + languages: list[str] | None = None, + speakers: list[str] | None = None, + non_streaming_mode: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare talker input embeddings for a single sample (batch size 1 for streaming). + + Returns: + tuple: (talker_input_embeds, trailing_text_hiddens, tts_pad_embed) + """ + # Only support single sample for streaming + assert len(input_ids) == 1, "Streaming only supports batch size 1" + + input_id = input_ids[0] + language = languages[0] if languages else "auto" + speaker = speakers[0] if speakers else None + + # Get special embeddings + tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection( + self.talker.get_text_embeddings()( + torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + ).chunk(3, dim=1) # 3 * [1 1 d] + + # Voice clone speaker prompt + voice_clone_spk_embeds = None + if voice_clone_prompt is not None: + voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt) + + # Determine speaker embedding + if voice_clone_spk_embeds is None: + if speaker == "" or speaker is None: + speaker_embed = None + else: + if speaker.lower() not in self.config.talker_config.spk_id: + raise NotImplementedError(f"Speaker {speaker} not implemented") + spk_id = self.config.talker_config.spk_id[speaker.lower()] + speaker_embed = self.talker.get_input_embeddings()( + torch.tensor(spk_id, device=self.talker.device, dtype=input_id.dtype) + ) + else: + if voice_clone_prompt["x_vector_only_mode"][0] or voice_clone_prompt["icl_mode"][0]: + speaker_embed = voice_clone_spk_embeds[0] + else: + speaker_embed = None + + # Language ID + if language.lower() == "auto": + language_id = None + else: + if language.lower() not in self.config.talker_config.codec_language_id: + raise NotImplementedError(f"Language {language} not implemented") + language_id = self.config.talker_config.codec_language_id[language.lower()] + + # Build codec prefill + if language_id is None: + codec_prefill_list = [ + [ + self.config.talker_config.codec_nothink_id, + self.config.talker_config.codec_think_bos_id, + self.config.talker_config.codec_think_eos_id, + ] + ] + else: + codec_prefill_list = [ + [ + self.config.talker_config.codec_think_id, + self.config.talker_config.codec_think_bos_id, + language_id, + self.config.talker_config.codec_think_eos_id, + ] + ] + + codec_input_emebdding_0 = self.talker.get_input_embeddings()( + torch.tensor(codec_prefill_list, device=self.talker.device, dtype=input_id.dtype) + ) + codec_input_emebdding_1 = self.talker.get_input_embeddings()( + torch.tensor( + [[self.config.talker_config.codec_pad_id, self.config.talker_config.codec_bos_id]], + device=self.talker.device, + dtype=input_id.dtype, + ) + ) + if speaker_embed is None: + codec_input_emebdding = torch.cat([codec_input_emebdding_0, codec_input_emebdding_1], dim=1) + else: + codec_input_emebdding = torch.cat( + [codec_input_emebdding_0, speaker_embed.view(1, 1, -1), codec_input_emebdding_1], dim=1 + ) + + # Build talker input embed + talker_input_embed_parts = [] + + # Instruct embed + if instruct_ids is not None and instruct_ids[0] is not None: + talker_input_embed_parts.append( + self.talker.text_projection(self.talker.get_text_embeddings()(instruct_ids[0])) + ) + + # Role embed: <|im_start|>assistant\n + _talker_input_embed_role = self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, :3])) + + # tts_pad * N + tts_bos + _talker_input_embed = ( + torch.cat( + ( + tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1), + tts_bos_embed, + ), + dim=1, + ) + + codec_input_emebdding[:, :-1] + ) + + talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1) + + # Handle ICL mode or regular mode + if ( + voice_clone_prompt is not None + and voice_clone_prompt["ref_code"] is not None + and voice_clone_prompt["icl_mode"][0] + ): + icl_input_embed, trailing_text_hidden = self.generate_icl_prompt( + text_id=input_id[:, 3:-5], + ref_id=ref_ids[0][:, 3:-2], + ref_code=voice_clone_prompt["ref_code"][0].to(self.talker.device), + tts_pad_embed=tts_pad_embed, + tts_eos_embed=tts_eos_embed, + non_streaming_mode=non_streaming_mode, + ) + talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1) + else: + # tts_text_first_token + talker_input_embed = torch.cat( + [ + talker_input_embed, + self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4])) + + codec_input_emebdding[:, -1:], + ], + dim=1, + ) + if non_streaming_mode: + talker_input_embed = talker_input_embed[:, :-1] + talker_input_embed = torch.cat( + [ + talker_input_embed, + torch.cat( + ( + self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:-5])), + tts_eos_embed, + ), + dim=1, + ) + + self.talker.get_input_embeddings()( + torch.tensor( + [[self.config.talker_config.codec_pad_id] * (input_id[:, 3:-5].shape[1] + 1)], + device=self.talker.device, + dtype=input_id.dtype, + ) + ), + tts_pad_embed + + self.talker.get_input_embeddings()( + torch.tensor( + [[self.config.talker_config.codec_bos_id]], + device=self.talker.device, + dtype=input_id.dtype, + ) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + trailing_text_hidden = torch.cat( + ( + self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + + # Prepend any additional parts + if talker_input_embed_parts: + talker_input_embed = torch.cat(talker_input_embed_parts + [talker_input_embed], dim=1) + + return talker_input_embed, trailing_text_hidden, tts_pad_embed + __all__ = [ "Qwen3TTSForConditionalGeneration", "Qwen3TTSTalkerForConditionalGeneration", "Qwen3TTSPreTrainedModel", "Qwen3TTSTalkerModel", + "StreamingChunkOutput", + "AsyncDecodingPipeline", ] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py index 8514a725d4b..b1ea60f27ee 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py @@ -15,7 +15,7 @@ import base64 import io import urllib.request -from collections.abc import Iterable +from collections.abc import Generator, Iterable from dataclasses import dataclass from typing import Any from urllib.parse import urlparse @@ -87,6 +87,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Store vllm_config for potential future use self.vllm_config = vllm_config + # Track streaming state for accumulating audio chunks + self._streaming_state = {} def forward( self, @@ -104,7 +106,7 @@ def forward( positions: Position IDs (not used for TTS, but required by runner) intermediate_tensors: Intermediate tensors for pipeline parallelism (not used) inputs_embeds: Input embeddings (not used for TTS, but required by runner) - **kwargs: Additional arguments including task_type, sampling_metadata, etc. + **kwargs: Additional arguments including task_type, sampling_metadata, stream, etc. Returns: OmniOutput: Contains multimodal outputs with audio tensors @@ -121,6 +123,22 @@ def forward( speaker = runtime_additional_information.pop("speaker", ["uncle_fu"])[0] language = runtime_additional_information.pop("language", ["Auto"])[0] instruct = runtime_additional_information.pop("instruct", [""])[0] + # Check if streaming mode is requested + stream = ( + runtime_additional_information.pop("stream", [False])[0] + if "stream" in runtime_additional_information + else False + ) + chunk_size = ( + runtime_additional_information.pop("chunk_size", [25])[0] + if "chunk_size" in runtime_additional_information + else 25 + ) + left_context_size = ( + runtime_additional_information.pop("left_context_size", [25])[0] + if "left_context_size" in runtime_additional_information + else 25 + ) for key, value in runtime_additional_information.items(): if isinstance(value, list) and len(value) > 0: runtime_additional_information[key] = value[0] @@ -132,6 +150,19 @@ def forward( if not text: logger.info("Profile run detected (empty text). Capping max_new_tokens to 2.") runtime_additional_information["max_new_tokens"] = 2 + # Check if this is a streaming request + if stream: + return self.forward_streaming( + text=text, + task_type=task_type, + speaker=speaker, + language=language, + instruct=instruct, + chunk_size=chunk_size, + left_context_size=left_context_size, + **runtime_additional_information, + **kwargs, + ) # Call the appropriate generation method based on task_type if task_type == "CustomVoice": @@ -150,6 +181,161 @@ def forward( # Convert result to OmniOutput format return self.make_omni_output(result, **kwargs) + def forward_streaming( + self, + text: str, + task_type: str, + speaker: str = "uncle_fu", + language: str = "Auto", + instruct: str = "", + chunk_size: int = 25, + left_context_size: int = 25, + **kwargs: Any, + ) -> OmniOutput: + """ + Forward pass for streaming TTS generation. + + This method handles streaming by maintaining state and returning one audio chunk at a time. + The caller should call this repeatedly until is_finished is True. + + Args: + text: Text to synthesize + task_type: Type of TTS task (CustomVoice, VoiceDesign, Base) + speaker: Speaker name for CustomVoice + language: Language code + instruct: Instruction text + chunk_size: Number of codec frames per chunk + left_context_size: Context frames for smooth boundaries + **kwargs: Additional generation options + + Returns: + OmniOutput: Contains audio chunk and streaming status + """ + request_id = kwargs.get("request_id", "default") + + # Initialize streaming state if not exists + if request_id not in self._streaming_state: + # Start streaming generation + if task_type == "CustomVoice": + generator = self.model.generate_custom_voice_streaming( + text, + speaker=speaker, + language=language, + instruct=instruct, + chunk_size=chunk_size, + left_context_size=left_context_size, + **{ + k: v + for k, v in kwargs.items() + if k not in ["request_id", "sampling_metadata", "logits_index", "sampler"] + }, + ) + elif task_type == "VoiceDesign": + generator = self.model.generate_voice_design_streaming( + text, + instruct=instruct, + language=language, + chunk_size=chunk_size, + left_context_size=left_context_size, + **{ + k: v + for k, v in kwargs.items() + if k not in ["request_id", "sampling_metadata", "logits_index", "sampler"] + }, + ) + elif task_type == "Base": + generator = self.model.generate_voice_clone_streaming( + text, + language=language, + chunk_size=chunk_size, + left_context_size=left_context_size, + **{ + k: v + for k, v in kwargs.items() + if k not in ["request_id", "sampling_metadata", "logits_index", "sampler"] + }, + ) + else: + raise ValueError(f"Invalid task type: {task_type}") + + self._streaming_state[request_id] = { + "generator": generator, + "audio_chunks": [], + "is_finished": False, + "sample_rate": None, + } + + state = self._streaming_state[request_id] + + # If already finished, return final output + if state["is_finished"]: + # Clean up state and return accumulated audio + audio_chunks = state["audio_chunks"] + sr = state["sample_rate"] + del self._streaming_state[request_id] + + if audio_chunks: + full_audio = np.concatenate(audio_chunks) + # Use .clone().contiguous() to ensure safe serialization (avoid stride-0 issues) + audio_tensor = torch.from_numpy(full_audio).float().clone().contiguous() + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": audio_tensor, + "sr": torch.tensor(sr, dtype=torch.int), + "finished": torch.tensor(True), + }, + ) + return OmniOutput(text_hidden_states=None, multimodal_outputs={"finished": torch.tensor(True)}) + + # Get next chunk from generator + try: + audio_chunk, is_finished, sr = next(state["generator"]) + state["audio_chunks"].append( + audio_chunk if isinstance(audio_chunk, np.ndarray) else audio_chunk.cpu().numpy() + ) + state["sample_rate"] = sr + state["is_finished"] = is_finished + + # Convert chunk to tensor + # Use .clone().contiguous() to ensure safe serialization (avoid stride-0 issues) + if isinstance(audio_chunk, np.ndarray): + audio_tensor = torch.from_numpy(audio_chunk).float().clone().contiguous() + else: + audio_tensor = audio_chunk.float().clone().contiguous() + + # Clean up state immediately when finished to prevent memory leaks + if is_finished: + del self._streaming_state[request_id] + + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": audio_tensor, + "sr": torch.tensor(sr, dtype=torch.int), + "finished": torch.tensor(is_finished), + }, + ) + except StopIteration: + # Generator exhausted + audio_chunks = state["audio_chunks"] + sr = state["sample_rate"] + del self._streaming_state[request_id] + + if audio_chunks: + full_audio = np.concatenate(audio_chunks) + # Use .clone().contiguous() to ensure safe serialization (avoid stride-0 issues) + audio_tensor = torch.from_numpy(full_audio).float().clone().contiguous() + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": audio_tensor, + "sr": torch.tensor(sr, dtype=torch.int), + "finished": torch.tensor(True), + }, + ) + return OmniOutput(text_hidden_states=None, multimodal_outputs={"finished": torch.tensor(True)}) + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput | tuple, **kwargs: Any) -> OmniOutput: """ Make an OmniOutput object from model outputs. @@ -1086,3 +1272,279 @@ def get_supported_languages(self) -> list[str] | None: if supported is None: return None return sorted(supported) + + # ==================== Streaming Generation Methods ==================== + + @torch.inference_mode() + def generate_custom_voice_streaming( + self, + text: str | list[str], + speaker: str | list[str], + language: str | list[str] = None, + instruct: str | list[str] | None = None, + chunk_size: int = 25, + left_context_size: int = 25, + **kwargs: Any, + ) -> Generator[tuple[np.ndarray, bool, int], None, None]: + """ + Streaming version of generate_custom_voice. Yields audio chunks as they are generated. + + Args: + text: Text(s) to synthesize. + speaker: Speaker name(s). + language: Language(s) for each sample. + instruct: Optional instruction(s). + chunk_size: Number of codec frames per audio chunk (default 25). + left_context_size: Context frames for smooth chunk boundaries (default 25). + **kwargs: Additional generation options. + + Yields: + tuple[np.ndarray, bool, int]: (audio_chunk, is_finished, sample_rate) + """ + if self.model.tts_model_type != "custom_voice": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_custom_voice_streaming" + ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + speakers = self._ensure_list(speaker) + if self.model.tts_model_size in "0b6": + instruct = None + instructs = ( + self._ensure_list(instruct) + if isinstance(instruct, list) + else ([instruct] * len(texts) if instruct is not None else [""] * len(texts)) + ) + + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(speakers) == 1 and len(texts) > 1: + speakers = speakers * len(texts) + if len(instructs) == 1 and len(texts) > 1: + instructs = instructs * len(texts) + + if not (len(texts) == len(languages) == len(speakers) == len(instructs)): + raise ValueError( + f"Batch size mismatch: text={len(texts)}, " + f"language={len(languages)}, speaker={len(speakers)}, " + f"instruct={len(instructs)}" + ) + + self._validate_languages(languages) + self._validate_speakers(speakers) + + input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts]) + + instruct_ids: list[torch.Tensor | None] = [] + for ins in instructs: + if ins is None or ins == "": + instruct_ids.append(None) + else: + instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0]) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + # Use streaming generation + for audio_chunk, is_finished, sr in self.model.generate_streaming( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=languages, + speakers=speakers, + chunk_size=chunk_size, + left_context_size=left_context_size, + **gen_kwargs, + ): + # Convert tensor to numpy + if isinstance(audio_chunk, torch.Tensor): + audio_chunk = audio_chunk.cpu().numpy() + yield audio_chunk, is_finished, sr + + @torch.inference_mode() + def generate_voice_design_streaming( + self, + text: str | list[str], + instruct: str | list[str], + language: str | list[str] = None, + chunk_size: int = 25, + left_context_size: int = 25, + **kwargs: Any, + ) -> Generator[tuple[np.ndarray, bool, int], None, None]: + """ + Streaming version of generate_voice_design. Yields audio chunks as they are generated. + + Args: + text: Text(s) to synthesize. + instruct: Instruction text(s) describing the desired voice. + language: Language(s) for each sample. + chunk_size: Number of codec frames per audio chunk (default 25). + left_context_size: Context frames for smooth chunk boundaries (default 25). + **kwargs: Additional generation options. + + Yields: + tuple[np.ndarray, bool, int]: (audio_chunk, is_finished, sample_rate) + """ + if self.model.tts_model_type != "voice_design": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_voice_design_streaming" + ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + instructs = self._ensure_list(instruct) + + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(instructs) == 1 and len(texts) > 1: + instructs = instructs * len(texts) + + if not (len(texts) == len(languages) == len(instructs)): + raise ValueError( + f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}" + ) + + self._validate_languages(languages) + + input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts]) + + instruct_ids = [] + for ins in instructs: + if ins is None or ins == "": + instruct_ids.append(None) + else: + instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0]) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + for audio_chunk, is_finished, sr in self.model.generate_streaming( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=languages, + speakers=[None] * len(texts), + chunk_size=chunk_size, + left_context_size=left_context_size, + **gen_kwargs, + ): + if isinstance(audio_chunk, torch.Tensor): + audio_chunk = audio_chunk.cpu().numpy() + yield audio_chunk, is_finished, sr + + @torch.inference_mode() + def generate_voice_clone_streaming( + self, + text: str | list[str], + language: str | list[str] = None, + ref_audio: AudioLike | list[AudioLike] | None = None, + ref_text: str | list[str | None] | None = None, + x_vector_only_mode: bool | list[bool] = False, + voice_clone_prompt: dict[str, Any] | list[VoiceClonePromptItem] | None = None, + chunk_size: int = 25, + left_context_size: int = 25, + **kwargs: Any, + ) -> Generator[tuple[np.ndarray, bool, int], None, None]: + """ + Streaming version of generate_voice_clone. Yields audio chunks as they are generated. + + Args: + text: Text(s) to synthesize. + language: Language(s) for each sample. + ref_audio: Reference audio(s) for voice cloning. + ref_text: Reference text(s) for ICL mode. + x_vector_only_mode: If True, only speaker embedding is used. + voice_clone_prompt: Pre-computed voice clone prompt items. + chunk_size: Number of codec frames per audio chunk (default 25). + left_context_size: Context frames for smooth chunk boundaries (default 25). + **kwargs: Additional generation options. + + Yields: + tuple[np.ndarray, bool, int]: (audio_chunk, is_finished, sample_rate) + """ + if self.model.tts_model_type != "base": + raise ValueError( + f"model with \ntokenizer_type: {self.model.tokenizer_type}\n" + f"tts_model_size: {self.model.tts_model_size}\n" + f"tts_model_type: {self.model.tts_model_type}\n" + "does not support generate_voice_clone_streaming" + ) + + texts = self._ensure_list(text) + languages = ( + self._ensure_list(language) + if isinstance(language, list) + else ([language] * len(texts) if language is not None else ["Auto"] * len(texts)) + ) + if len(languages) == 1 and len(texts) > 1: + languages = languages * len(texts) + if len(texts) != len(languages): + raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}") + + self._validate_languages(languages) + + if voice_clone_prompt is None: + if ref_audio is None: + sample_rate = int(self.model.speaker_encoder_sample_rate) + ref_audio = (np.zeros(sample_rate, dtype=np.float32), sample_rate) + logger.warning("ref_audio is not provided. Using silent clip.") + prompt_items = self.create_voice_clone_prompt( + ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode + ) + if len(prompt_items) == 1 and len(texts) > 1: + prompt_items = prompt_items * len(texts) + if len(prompt_items) != len(texts): + raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}") + voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items) + ref_texts_for_ids = [it.ref_text for it in prompt_items] + else: + if isinstance(voice_clone_prompt, list): + prompt_items = voice_clone_prompt + if len(prompt_items) == 1 and len(texts) > 1: + prompt_items = prompt_items * len(texts) + if len(prompt_items) != len(texts): + raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}") + voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items) + ref_texts_for_ids = [it.ref_text for it in prompt_items] + else: + voice_clone_prompt_dict = voice_clone_prompt + ref_texts_for_ids = None + + input_texts = [self._build_assistant_text(t) for t in texts] + input_ids = self._tokenize_texts(input_texts) + + ref_ids = None + if ref_texts_for_ids is not None: + ref_ids = [] + for i, rt in enumerate(ref_texts_for_ids): + if rt is None or rt == "": + ref_ids.append(None) + else: + ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0] + ref_ids.append(ref_tok) + + gen_kwargs = self._merge_generate_kwargs(**kwargs) + + for audio_chunk, is_finished, sr in self.model.generate_streaming( + input_ids=input_ids, + ref_ids=ref_ids, + voice_clone_prompt=voice_clone_prompt_dict, + languages=languages, + chunk_size=chunk_size, + left_context_size=left_context_size, + **gen_kwargs, + ): + if isinstance(audio_chunk, torch.Tensor): + audio_chunk = audio_chunk.cpu().numpy() + yield audio_chunk, is_finished, sr diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index 78fc965fe32..b951e9ab9d6 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -855,7 +855,7 @@ def forward(self, codes): hidden = self.quantizer.decode(codes) hidden = self.pre_conv(hidden).transpose(1, 2) - hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state + hidden = self.pre_transformer(inputs_embeds=hidden, use_cache=False).last_hidden_state hidden = hidden.permute(0, 2, 1) for blocks in self.upsample: for block in blocks: @@ -865,7 +865,37 @@ def forward(self, codes): wav = block(wav) return wav.clamp(min=-1, max=1) - def chunked_decode(self, codes, chunk_size=300, left_context_size=25): + def chunked_decode(self, codes: torch.Tensor, chunk_size: int = 300, left_context_size: int = 25) -> torch.Tensor: + """ + Decode codec tokens to audio waveform in chunks. + + Args: + codes: Codec tokens, shape [B, num_quantizers, T] + chunk_size: Number of frames per chunk + left_context_size: Context frames from previous chunk + + Returns: + torch.Tensor: Decoded audio waveform + """ + # Filter out invalid frames: decoder codebook only has `codebook_size` entries + # Any token >= codebook_size (including EOS tokens) should be truncated + codec_valid_max = self.config.codebook_size + + # Check if any layer has invalid token at each time step + invalid_mask = (codes >= codec_valid_max).any(dim=1) # [B, T] + + # Truncate at first invalid position + if invalid_mask.any(): + for b in range(codes.shape[0]): + if invalid_mask[b].any(): + first_invalid = invalid_mask[b].nonzero(as_tuple=True)[0][0].item() + codes = codes[:, :, :first_invalid] + break # Assuming batch size 1 for now + + if codes.shape[-1] == 0: + # All tokens were invalid, return empty audio + return torch.zeros((codes.shape[0], 1, 0), device=codes.device) + wavs = [] start_index = 0 while start_index < codes.shape[-1]: diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 64569846a00..a8d6f25eb9c 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -1,11 +1,11 @@ from dataclasses import dataclass, field from typing import Any - +import queue, threading import torch from PIL import Image from vllm.outputs import RequestOutput from vllm.v1.outputs import ModelRunnerOutput - +from collections.abc import Iterator from vllm_omni.inputs.data import OmniPromptType @@ -128,6 +128,15 @@ def multimodal_output(self) -> dict[str, Any]: For diffusion outputs, this returns the local _multimodal_output field. """ if self.request_output is not None: + # Handle case where request_output is a list (e.g., from batched generation) + if isinstance(self.request_output, list): + for req_out in self.request_output: + if hasattr(req_out, "outputs") and req_out.outputs: + for output in req_out.outputs: + mm = getattr(output, "multimodal_output", None) + if mm: + return mm + return {} # Check completion outputs first (where multimodal_output is attached) if self.request_output.outputs: for output in self.request_output.outputs: @@ -250,3 +259,164 @@ def __repr__(self) -> str: ] return f"OmniRequestOutput({', '.join(parts)})" + + +@dataclass +class StreamingChunkOutput: + """Output for each streaming chunk during TTS generation.""" + + codec_codes: torch.Tensor # [chunk_size, num_quantizers] codec tokens for this chunk + hidden_states: torch.Tensor | None = None # corresponding hidden states + chunk_idx: int = 0 # chunk index + is_finished: bool = False # whether generation is complete + total_generated: int = 0 # total tokens generated so far + + +class AsyncDecodingPipeline: + """ + Asynchronous decoding pipeline that runs audio decoding in a background thread + while generation continues in the main thread. + """ + + def __init__( + self, + speech_tokenizer, + ref_code: torch.Tensor | None = None, + left_context_size: int = 25, + max_queue_size: int = 10, + ): + self.speech_tokenizer = speech_tokenizer + self.ref_code = ref_code + self.left_context_size = left_context_size + + # Queue for codec chunks to be decoded + # Each item is (codes_with_context, is_last, context_frames_to_remove) + self._input_queue: queue.Queue = queue.Queue(maxsize=max_queue_size) + # Queue for decoded audio chunks + self._output_queue: queue.Queue = queue.Queue() + + self._decode_thread: threading.Thread | None = None + self._stop_event = threading.Event() + self._started = False + self._all_codes: list[torch.Tensor] = [] + self._sample_rate: int | None = None + + def start(self): + """Start the background decoding thread.""" + if self._started: + return + self._stop_event.clear() + self._decode_thread = threading.Thread(target=self._decode_worker, daemon=True) + self._decode_thread.start() + self._started = True + + def _decode_worker(self): + """Background worker that decodes codec chunks to audio.""" + chunk_idx = 0 + + while not self._stop_event.is_set(): + try: + item = self._input_queue.get(timeout=0.1) + except queue.Empty: + continue + + if item is None: # Sentinel to stop + break + + codes_chunk, is_last, context_frames = item + + # Decode the chunk + try: + # codes shape: [seq_len, num_quantizers] -> [1, seq_len, num_quantizers] + # model.decode expects [B, T, K] and internally transposes to [B, K, T] + codes_for_decode = codes_chunk.unsqueeze(0) + wavs, sr = self.speech_tokenizer.decode({"audio_codes": codes_for_decode}) + audio_chunk = wavs[0] # numpy array + self._sample_rate = sr + + # Remove context samples from the beginning of the audio + if context_frames > 0: + upsample_rate = getattr(self.speech_tokenizer.model, "decode_upsample_rate", 2000) + context_samples = context_frames * upsample_rate + if context_samples < len(audio_chunk): + audio_chunk = audio_chunk[context_samples:] + + self._output_queue.put((audio_chunk, is_last, sr, None)) + except Exception as e: + self._output_queue.put((None, is_last, None, e)) + + chunk_idx += 1 + + def submit_chunk(self, codec_codes: torch.Tensor, is_last: bool = False): + """Submit a chunk of codec codes for decoding.""" + self._all_codes.append(codec_codes) + + # Prepare chunk with context and track how many context frames were added + context_frames = 0 + + if len(self._all_codes) == 1: + # First chunk - prepend ref_code if available + if self.ref_code is not None: + codes_with_context = torch.cat([self.ref_code, codec_codes], dim=0) + context_frames = self.ref_code.shape[0] + else: + codes_with_context = codec_codes + context_frames = 0 + else: + # Subsequent chunks - add left context from previously generated codes + context_codes = torch.cat(self._all_codes[:-1], dim=0) + context_frames = min(self.left_context_size, context_codes.shape[0]) + context_start = context_codes.shape[0] - context_frames + context = context_codes[context_start:] + codes_with_context = torch.cat([context, codec_codes], dim=0) + + self._input_queue.put((codes_with_context, is_last, context_frames)) + + # Limit memory usage: only keep enough codes for left_context_size + # Merge old codes if we have too many chunks + if len(self._all_codes) > 10: + # Merge all codes and keep only the last left_context_size frames + all_merged = torch.cat(self._all_codes, dim=0) + if all_merged.shape[0] > self.left_context_size: + self._all_codes = [all_merged[-self.left_context_size :]] + else: + self._all_codes = [all_merged] + + def get_decoded_chunk(self, timeout: float | None = None) -> tuple[Any, bool, int | None, Exception | None]: + """ + Get the next decoded audio chunk. + + Returns: + tuple: (audio_chunk, is_last, sample_rate, error) + """ + try: + return self._output_queue.get(timeout=timeout) + except queue.Empty: + return None, False, None, None + + def iter_decoded_chunks(self) -> Iterator[tuple[Any, bool, int]]: + """Iterate over decoded audio chunks as they become available.""" + while True: + audio, is_last, sr, error = self.get_decoded_chunk(timeout=1.0) + if error is not None: + raise error + if audio is not None: + yield audio, is_last, sr + if is_last: + break + + def stop(self): + """Stop the decoding pipeline.""" + self._stop_event.set() + self._input_queue.put(None) # Sentinel + if self._decode_thread is not None: + self._decode_thread.join(timeout=2.0) + self._started = False + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + return False