diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md index ddbd3fcbcc8..2b04cad75ab 100644 --- a/examples/offline_inference/qwen3_tts/README.md +++ b/examples/offline_inference/qwen3_tts/README.md @@ -87,7 +87,20 @@ Examples: python end2end.py --query-type Base --mode-tag icl ``` +## Batched Decoding + +The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_batch_size > 1` and pass multiple prompts via `--txt-prompts` (a text file with one prompt per line; blank lines are skipped) with a matching `--batch-size`. + +``` +python end2end.py --query-type CustomVoice \ + --txt-prompts benchmark_prompts.txt \ + --batch-size 4 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +``` + +**Important:** `--batch-size` must be a power of two (1, 2, 4, 8, 16, ...) so that it matches a CUDA graph capture size; `end2end.py` raises a `ValueError` for any other value. This is required because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_batch_size >= batch_size` in the stage config for batching to take effect. Raising `max_batch_size` on stage 1 alone does not help — stage 1 can only batch chunks from requests that are in flight simultaneously, which requires stage 0 to also process multiple requests concurrently. + ## Notes - The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs. -- Use `--output-dir` (preferred) or `--output-wav` to change the output folder. +- Use `--output-dir` to change the output folder. diff --git a/examples/offline_inference/qwen3_tts/benchmark_prompts.txt b/examples/offline_inference/qwen3_tts/benchmark_prompts.txt new file mode 100644 index 00000000000..92577a418cd --- /dev/null +++ b/examples/offline_inference/qwen3_tts/benchmark_prompts.txt @@ -0,0 +1,12 @@ +Hello, welcome to the voice synthesis benchmark test. 
+She said she would be here by noon, but nobody showed up. +The quick brown fox jumps over the lazy dog near the riverbank. +I can't believe how beautiful the sunset looks from up here on the mountain. +Please remember to bring your identification documents to the appointment tomorrow morning. +Have you ever wondered what it would be like to travel through time and visit ancient civilizations? +The restaurant on the corner serves the best pasta I have ever tasted in my entire life. +After the meeting, we should discuss the quarterly results and plan for the next phase. +Learning a new language takes patience, practice, and a genuine curiosity about other cultures. +The train leaves at half past seven, so we need to arrive at the station before then. +Could you please turn down the music a little bit, I'm trying to concentrate on my work. +It was a dark and stormy night when the old lighthouse keeper heard a knock at the door. diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 12e5e193542..fec515d303a 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -248,6 +248,13 @@ def main(args): Args: args: Parsed CLI args from parse_args(). """ + if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0: + raise ValueError( + f"--batch-size must be a power of two (got {args.batch_size}); " + "non-power-of-two values do not align with CUDA graph capture sizes " + "of Code2Wav." + ) + query_func = query_map[args.query_type] if args.query_type in {"CustomVoice", "VoiceDesign"}: query_result = query_func(use_batch_sample=args.use_batch_sample) @@ -260,6 +267,33 @@ def main(args): query_result = query_func() model_name = query_result.model_name + + # Load prompts from text file if provided. + # Use the default query as a template so task-specific fields + # (e.g. ref_audio for Base) are preserved; only override text. 
+ if args.txt_prompts: + with open(args.txt_prompts) as f: + lines = [line.strip() for line in f if line.strip()] + if not lines: + raise ValueError(f"No valid prompts found in {args.txt_prompts}") + template = query_result.inputs + if isinstance(template, list): + template = template[0] + template_info = template["additional_information"] + inputs = [] + for text in lines: + additional_information = {**template_info, "text": [text]} + inputs.append( + { + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, + } + ) + else: + inputs = query_result.inputs + if not isinstance(inputs, list): + inputs = [inputs] + omni = Omni( model=model_name, stage_configs_path=args.stage_configs_path, @@ -267,32 +301,35 @@ def main(args): stage_init_timeout=args.stage_init_timeout, ) - output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav + output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) - omni_generator = omni.generate(query_result.inputs, sampling_params_list=None) - for stage_outputs in omni_generator: - for output in stage_outputs.request_output: - request_id = output.request_id - audio_data = output.outputs[0].multimodal_output["audio"] - # async_chunk mode returns a list of chunks; concatenate them. 
- if isinstance(audio_data, list): - audio_tensor = torch.cat(audio_data, dim=-1) - else: - audio_tensor = audio_data - output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - sr_val = output.outputs[0].multimodal_output["sr"] - audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) - # Convert to numpy array and ensure correct format - audio_numpy = audio_tensor.float().detach().cpu().numpy() - - # Ensure audio is 1D (flatten if needed) - if audio_numpy.ndim > 1: - audio_numpy = audio_numpy.flatten() - - # Save audio file with explicit WAV format - sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") - print(f"Request ID: {request_id}, Saved audio to {output_wav}") + batch_size = args.batch_size + for batch_start in range(0, len(inputs), batch_size): + batch = inputs[batch_start : batch_start + batch_size] + omni_generator = omni.generate(batch, sampling_params_list=None) + for stage_outputs in omni_generator: + for output in stage_outputs.request_output: + request_id = output.request_id + audio_data = output.outputs[0].multimodal_output["audio"] + # async_chunk mode returns a list of chunks; concatenate them. 
+ if isinstance(audio_data, list): + audio_tensor = torch.cat(audio_data, dim=-1) + else: + audio_tensor = audio_data + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + sr_val = output.outputs[0].multimodal_output["sr"] + audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") + print(f"Request ID: {request_id}, Saved audio to {output_wav}") def parse_args(): @@ -341,9 +378,9 @@ def parse_args(): help="Threshold for using shared memory in bytes (default: 65536)", ) parser.add_argument( - "--output-wav", + "--output-dir", default="output_audio", - help="[Deprecated] Output wav directory (use --output-dir).", + help="Output directory for generated wav files (default: output_audio).", ) parser.add_argument( "--num-prompts", @@ -401,6 +438,12 @@ def parse_args(): choices=["icl", "xvec_only"], help="Mode tag for Base query x_vector_only_mode (default: icl).", ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Number of prompts per batch (default: 1, sequential).", + ) return parser.parse_args() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 339268f34f0..a87027de18f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -9,6 +9,7 @@ import torch.nn as nn from transformers.utils.hub import cached_file from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from 
vllm_omni.model_executor.models.output_templates import OmniOutput @@ -116,6 +117,28 @@ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: return None + def _split_request_ids(self, ids: torch.Tensor, seq_token_counts: list[int] | None = None) -> list[torch.Tensor]: + """Split concatenated input_ids into per-request segments. + + Uses seq_token_counts (injected by the runner via model_kwargs) when + available, falling back to forward-context ubatch_slices when + micro-batching is active. Returns [ids] for single-request batches. + """ + if seq_token_counts is not None and len(seq_token_counts) > 1: + boundaries = [0] + for count in seq_token_counts: + boundaries.append(boundaries[-1] + count) + n = ids.numel() + return [ids[boundaries[i] : min(boundaries[i + 1], n)] for i in range(len(seq_token_counts))] + if is_forward_context_available(): + slices = get_forward_context().ubatch_slices + if slices is not None and len(slices) > 1 and not any(hasattr(s, "token_slice") for s in slices): + boundaries = [0] + for s in slices: + boundaries.append(boundaries[-1] + s) + return [ids[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1)] + return [ids] + @torch.no_grad() def forward( self, @@ -124,98 +147,129 @@ def forward( intermediate_tensors: Any = None, inputs_embeds: torch.Tensor | None = None, **kwargs: Any, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> OmniOutput: """Decode codec codes into audio waveform. - input_ids layout: [codec_context_frames, *flat_codes] + input_ids layout per request: [codec_context_frames, *flat_codes] where flat_codes is codebook-major [q*F]. + + When batched, input_ids are split per request via seq_token_counts (falling + back to forward-context ubatch_slices) and decoded in one batched forward pass. 
""" - tok = self._ensure_speech_tokenizer_loaded() + self._ensure_speech_tokenizer_loaded() assert self._num_quantizers is not None + assert self._decode_upsample_rate is not None assert self._output_sample_rate is not None - sr_val = self._output_sample_rate - empty_ret = ( - torch.zeros((0,), dtype=torch.float32), - torch.tensor(sr_val, dtype=torch.int32), - ) - - if input_ids is None: - return empty_ret - + tok = self._speech_tokenizer q = int(self._num_quantizers) - ids = input_ids.reshape(-1).to(dtype=torch.long) - n_tokens = ids.numel() - - if n_tokens == 0: - return empty_ret - - # input_ids[0] = codec_context_frames (prepended by stage_input_processor). - ctx_frames = int(ids[0].item()) - ids = ids[1:] - n_tokens = ids.numel() - - if n_tokens == 0: - return empty_ret - - # Warmup / dummy_run: not divisible by num_quantizers. - if n_tokens % q != 0: - logger.warning( - "Code2Wav input_ids length %d not divisible by num_quantizers %d, " - "likely a warmup run; returning empty audio.", - n_tokens, - q, + upsample = int(self._decode_upsample_rate) + sr_val = int(self._output_sample_rate) + sr_tensor = torch.tensor(sr_val, dtype=torch.int32) + empty = torch.zeros((0,), dtype=torch.float32) + + if input_ids is None or input_ids.numel() == 0: + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": [empty], "sr": [sr_tensor]}, ) - return empty_ret - total_frames = n_tokens // q - - # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. - codes_fq = ids.reshape(q, total_frames).transpose(0, 1).contiguous() + ids = input_ids.reshape(-1).to(dtype=torch.long) + request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts")) + + # Parse each request: extract ctx_frames, validate, reshape codes. + # input_ids layout per request: [codec_context_frames, *flat_codes] + # where flat_codes is codebook-major [q*F]. 
+ parsed = [] # (ctx_frames, actual_frames) + valid_codes = [] + valid_indices = [] + for i, req_ids in enumerate(request_ids_list): + if req_ids.numel() < 2: + parsed.append((0, 0)) + continue + ctx_frames = int(req_ids[0].item()) + flat = req_ids[1:] + n = flat.numel() + # Warmup / dummy_run: not divisible by num_quantizers. + if n == 0 or n % q != 0: + if n > 0: + logger.warning( + "Code2Wav input_ids length %d not divisible by num_quantizers %d, " + "likely a warmup run; returning empty audio.", + n, + q, + ) + parsed.append((0, 0)) + continue + frames = n // q + # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. + codes_fq = flat.reshape(q, frames).transpose(0, 1).contiguous() + parsed.append((ctx_frames, frames)) + valid_codes.append({"audio_codes": codes_fq}) + valid_indices.append(i) + + num_req = len(request_ids_list) + if not valid_codes: + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": [empty] * num_req, + "sr": [sr_tensor] * num_req, + }, + ) - if not self._logged_codec_stats and total_frames > 1: + if not self._logged_codec_stats: self._logged_codec_stats = True try: - uniq = int(torch.unique(codes_fq).numel()) - cmin = int(codes_fq.min().item()) - cmax = int(codes_fq.max().item()) - head = codes_fq[: min(2, total_frames), : min(8, q)].cpu().tolist() + c = valid_codes[0]["audio_codes"] logger.info( - "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", - total_frames, + "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s batch=%d", + c.shape[0], q, - uniq, - cmin, - cmax, - head, + int(torch.unique(c).numel()), + int(c.min().item()), + int(c.max().item()), + c[: min(2, c.shape[0]), : min(8, q)].cpu().tolist(), + len(valid_codes), ) except Exception: pass - wavs, sr = tok.decode({"audio_codes": codes_fq}) - if not wavs: - raise ValueError("SpeechTokenizer code2wav produced empty waveform list.") - audio_np = wavs[0].astype(np.float32, copy=False) - - # Trim 
left-context waveform samples (streaming sliding window). - if ctx_frames > 0: - upsample = self._decode_upsample_rate - if upsample is None or upsample <= 0: - raise ValueError(f"Invalid decode upsample rate: {upsample}") - cut = ctx_frames * upsample - if cut < audio_np.shape[0]: - audio_np = audio_np[cut:] - else: - logger.warning( - "Context trim %d >= decoded length %d; returning empty audio.", - cut, - audio_np.shape[0], - ) - return empty_ret + # Batched decode: single forward pass through SpeechTokenizer. + wavs, _ = tok.decode(valid_codes) + if len(wavs) != len(valid_codes): + raise RuntimeError(f"Code2Wav returned {len(wavs)} waveforms for {len(valid_codes)} requests") + + # Build per-request outputs, trimming padding and left-context. + audios = [empty] * num_req + srs = [sr_tensor] * num_req + + for j, idx in enumerate(valid_indices): + ctx_frames, actual_frames = parsed[idx] + audio_np = wavs[j].astype(np.float32, copy=False) + # Trim decoder padding (output may be longer due to batch padding). + expected_len = actual_frames * upsample + if audio_np.shape[0] > expected_len: + audio_np = audio_np[:expected_len] + # Trim left-context waveform samples (streaming sliding window). 
+ if ctx_frames > 0: + cut = ctx_frames * upsample + if cut < audio_np.shape[0]: + audio_np = audio_np[cut:] + else: + logger.warning( + "Context trim %d >= decoded length %d; returning empty audio.", + cut, + audio_np.shape[0], + ) + continue + if audio_np.shape[0] > 0: + audios[idx] = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) - audio_tensor = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) - sr_tensor = torch.tensor(int(sr), dtype=torch.int32) - return audio_tensor, sr_tensor + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": audios, "sr": srs}, + ) def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: if isinstance(model_outputs, OmniOutput): diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml new file mode 100644 index 00000000000..d737391095b --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml @@ -0,0 +1,105 @@ +# Same as qwen3_tts.yaml with batched talker and code2wav. +# Stage 0: max_batch_size 4, stage 1: max_batch_size 4. +# max_batch_size must be a power of two to align with CUDA graph capture sizes +# (stage 0) and must match --batch-size in end2end.py / benchmark scripts. +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 4 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSTalkerForConditionalGeneration + # Force stage-specific registered architecture. 
+ hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: false + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + # Use named connector to apply runtime.connectors.extra. + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 4 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSCode2Wav] + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.2 + distributed_executor_backend: "mp" + # Must be divisible by num_code_groups and cover (left_context + chunk). + max_num_batched_tokens: 8192 + # async_chunk appends windows per step; max_model_len must cover accumulated stream. 
+ max_model_len: 32768 + engine_input_source: [0] + final_output: true + final_output_type: audio + # Distributed connector configuration + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Frame-aligned codec streaming transport. + codec_streaming: true + # Connector polling / timeout (unit: loop count, sleep interval in seconds). + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + # Align with Omni: small chunks with sufficient context overlap. + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1