-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Qwen3TTS][Feat] Code2Wav batched decoding #1426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
hsliuustc0106
merged 16 commits into
vllm-project:main
from
JuanPZuluaga:feat/code2wav-batched-decode
Feb 24, 2026
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b1f669d
[Qwen3TTS][feat] Code2Wav batched decoding
da84873
move to forward pass instead of helper
9dddbb8
update to the benchmark scripts
d1a4bd9
added batched decoding stage config
a6dfea0
lint
5c5cd1a
fix logic in e2e.py
f67f932
change split req_ids and support UBatchSlice
5cbc214
guard for wavs returned; e2e assert
27e3126
log and logger improv
f242d25
log and assert
1a0d5fc
revert lint
8a9589b
revert lint2
157363d
assert
a4fd379
lint
9e1cfcc
fix boundaries issues
c399f51
Merge branch 'main' into feat/code2wav-batched-decode
JuanPZuluaga File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
examples/offline_inference/qwen3_tts/benchmark_prompts.txt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| Hello, welcome to the voice synthesis benchmark test. | ||
| She said she would be here by noon, but nobody showed up. | ||
| The quick brown fox jumps over the lazy dog near the riverbank. | ||
| I can't believe how beautiful the sunset looks from up here on the mountain. | ||
| Please remember to bring your identification documents to the appointment tomorrow morning. | ||
| Have you ever wondered what it would be like to travel through time and visit ancient civilizations? | ||
| The restaurant on the corner serves the best pasta I have ever tasted in my entire life. | ||
| After the meeting, we should discuss the quarterly results and plan for the next phase. | ||
| Learning a new language takes patience, practice, and a genuine curiosity about other cultures. | ||
| The train leaves at half past seven, so we need to arrive at the station before then. | ||
| Could you please turn down the music a little bit, I'm trying to concentrate on my work. | ||
| It was a dark and stormy night when the old lighthouse keeper heard a knock at the door. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -248,6 +248,13 @@ def main(args): | |
| Args: | ||
| args: Parsed CLI args from parse_args(). | ||
| """ | ||
| if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0: | ||
| raise ValueError( | ||
| f"--batch-size must be a power of two (got {args.batch_size}); " | ||
| "non-power-of-two values do not align with CUDA graph capture sizes " | ||
| "of Code2Wav." | ||
| ) | ||
|
|
||
| query_func = query_map[args.query_type] | ||
| if args.query_type in {"CustomVoice", "VoiceDesign"}: | ||
| query_result = query_func(use_batch_sample=args.use_batch_sample) | ||
|
|
@@ -260,39 +267,69 @@ def main(args): | |
| query_result = query_func() | ||
|
|
||
| model_name = query_result.model_name | ||
|
|
||
| # Load prompts from text file if provided. | ||
| # Use the default query as a template so task-specific fields | ||
| # (e.g. ref_audio for Base) are preserved; only override text. | ||
| if args.txt_prompts: | ||
| with open(args.txt_prompts) as f: | ||
| lines = [line.strip() for line in f if line.strip()] | ||
| if not lines: | ||
| raise ValueError(f"No valid prompts found in {args.txt_prompts}") | ||
| template = query_result.inputs | ||
| if isinstance(template, list): | ||
| template = template[0] | ||
| template_info = template["additional_information"] | ||
| inputs = [] | ||
| for text in lines: | ||
| additional_information = {**template_info, "text": [text]} | ||
| inputs.append( | ||
| { | ||
| "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), | ||
| "additional_information": additional_information, | ||
| } | ||
| ) | ||
| else: | ||
| inputs = query_result.inputs | ||
| if not isinstance(inputs, list): | ||
| inputs = [inputs] | ||
|
|
||
| omni = Omni( | ||
| model=model_name, | ||
| stage_configs_path=args.stage_configs_path, | ||
| log_stats=args.log_stats, | ||
| stage_init_timeout=args.stage_init_timeout, | ||
| ) | ||
|
|
||
| output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav | ||
| output_dir = args.output_dir | ||
| os.makedirs(output_dir, exist_ok=True) | ||
|
|
||
| omni_generator = omni.generate(query_result.inputs, sampling_params_list=None) | ||
| for stage_outputs in omni_generator: | ||
| for output in stage_outputs.request_output: | ||
| request_id = output.request_id | ||
| audio_data = output.outputs[0].multimodal_output["audio"] | ||
| # async_chunk mode returns a list of chunks; concatenate them. | ||
| if isinstance(audio_data, list): | ||
| audio_tensor = torch.cat(audio_data, dim=-1) | ||
| else: | ||
| audio_tensor = audio_data | ||
| output_wav = os.path.join(output_dir, f"output_{request_id}.wav") | ||
| sr_val = output.outputs[0].multimodal_output["sr"] | ||
| audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) | ||
| # Convert to numpy array and ensure correct format | ||
| audio_numpy = audio_tensor.float().detach().cpu().numpy() | ||
|
|
||
| # Ensure audio is 1D (flatten if needed) | ||
| if audio_numpy.ndim > 1: | ||
| audio_numpy = audio_numpy.flatten() | ||
|
|
||
| # Save audio file with explicit WAV format | ||
| sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") | ||
| print(f"Request ID: {request_id}, Saved audio to {output_wav}") | ||
| batch_size = args.batch_size | ||
| for batch_start in range(0, len(inputs), batch_size): | ||
| batch = inputs[batch_start : batch_start + batch_size] | ||
| omni_generator = omni.generate(batch, sampling_params_list=None) | ||
| for stage_outputs in omni_generator: | ||
| for output in stage_outputs.request_output: | ||
| request_id = output.request_id | ||
| audio_data = output.outputs[0].multimodal_output["audio"] | ||
| # async_chunk mode returns a list of chunks; concatenate them. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason not to always pass a list here? The conditional unwrapping means
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
| if isinstance(audio_data, list): | ||
| audio_tensor = torch.cat(audio_data, dim=-1) | ||
| else: | ||
| audio_tensor = audio_data | ||
| output_wav = os.path.join(output_dir, f"output_{request_id}.wav") | ||
| sr_val = output.outputs[0].multimodal_output["sr"] | ||
| audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) | ||
| # Convert to numpy array and ensure correct format | ||
| audio_numpy = audio_tensor.float().detach().cpu().numpy() | ||
|
|
||
| # Ensure audio is 1D (flatten if needed) | ||
| if audio_numpy.ndim > 1: | ||
| audio_numpy = audio_numpy.flatten() | ||
|
|
||
| # Save audio file with explicit WAV format | ||
| sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") | ||
| print(f"Request ID: {request_id}, Saved audio to {output_wav}") | ||
|
|
||
|
|
||
| def parse_args(): | ||
|
|
@@ -341,9 +378,9 @@ def parse_args(): | |
| help="Threshold for using shared memory in bytes (default: 65536)", | ||
| ) | ||
| parser.add_argument( | ||
| "--output-wav", | ||
| "--output-dir", | ||
| default="output_audio", | ||
| help="[Deprecated] Output wav directory (use --output-dir).", | ||
| help="Output directory for generated wav files (default: output_audio).", | ||
| ) | ||
| parser.add_argument( | ||
| "--num-prompts", | ||
|
|
@@ -401,6 +438,12 @@ def parse_args(): | |
| choices=["icl", "xvec_only"], | ||
| help="Mode tag for Base query x_vector_only_mode (default: icl).", | ||
| ) | ||
| parser.add_argument( | ||
| "--batch-size", | ||
| type=int, | ||
| default=1, | ||
| help="Number of prompts per batch (default: 1, sequential).", | ||
| ) | ||
|
|
||
| return parser.parse_args() | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
--batch-sizemust match a CUDA graph capture size, a runtime power-of-two check would save users from a cryptic CUDA graph error.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed