15 changes: 14 additions & 1 deletion examples/offline_inference/qwen3_tts/README.md
@@ -87,7 +87,20 @@ Examples:
python end2end.py --query-type Base --mode-tag icl
```

## Batched Decoding

The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_batch_size > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.

```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
--stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
```

@lishunyang12 (Collaborator), Feb 22, 2026:

Since `--batch-size` must match a CUDA graph capture size, a runtime power-of-two check would save users from a cryptic CUDA graph error.

Contributor (PR author):

fixed

**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16, ...): the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_batch_size >= batch_size` in the stage config for batching to take effect. Raising `max_batch_size` on stage 1 alone does not help, because stage 1 can only batch chunks from requests that are in flight at the same time, which requires stage 0 to process multiple requests concurrently as well.
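
The constraint above can be checked before launching a run. The following is a minimal sketch that mirrors the bitwise test this PR adds to `end2end.py`. The helper names `is_power_of_two` and `next_power_of_two` are hypothetical, and treating the capture sizes as exactly the powers of two is an assumption taken from the list (1, 2, 4, 8, 16, ...) in the paragraph above, not a verified engine setting.

```python
def is_power_of_two(n: int) -> bool:
    """Return True when n is a positive power of two (1, 2, 4, 8, ...)."""
    # A power of two has exactly one set bit, so n & (n - 1) clears it to zero.
    return n >= 1 and (n & (n - 1)) == 0


def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n.

    Assumption: capture sizes are powers of two, so this is the size a
    batch of n requests would be padded to.
    """
    if n < 1:
        raise ValueError(f"batch size must be >= 1 (got {n})")
    return 1 << (n - 1).bit_length()


if __name__ == "__main__":
    for size in (1, 3, 4, 6, 8):
        print(size, is_power_of_two(size), next_power_of_two(size))
```

Under that assumption, a batch of 6 would be padded to 8, which is why a non-power-of-two `--batch-size` does not line up with a capture size.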

## Notes

- The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs.
- Use `--output-dir` (preferred) or `--output-wav` to change the output folder.
- Use `--output-dir` to change the output folder.
12 changes: 12 additions & 0 deletions examples/offline_inference/qwen3_tts/benchmark_prompts.txt
@@ -0,0 +1,12 @@
Hello, welcome to the voice synthesis benchmark test.
She said she would be here by noon, but nobody showed up.
The quick brown fox jumps over the lazy dog near the riverbank.
I can't believe how beautiful the sunset looks from up here on the mountain.
Please remember to bring your identification documents to the appointment tomorrow morning.
Have you ever wondered what it would be like to travel through time and visit ancient civilizations?
The restaurant on the corner serves the best pasta I have ever tasted in my entire life.
After the meeting, we should discuss the quarterly results and plan for the next phase.
Learning a new language takes patience, practice, and a genuine curiosity about other cultures.
The train leaves at half past seven, so we need to arrive at the station before then.
Could you please turn down the music a little bit, I'm trying to concentrate on my work.
It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.
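
The prompt file above is consumed one prompt per line, with blank lines skipped. A minimal sketch of that loading step, modeled on the `--txt-prompts` handling in `end2end.py`, follows. The function name `load_prompts` and the explicit `encoding="utf-8"` are assumptions made for illustration and are not part of the PR.

```python
import os
import tempfile


def load_prompts(path: str) -> list[str]:
    """Read one prompt per line, dropping blank lines and surrounding whitespace."""
    # encoding="utf-8" is an assumption; the PR opens the file with the default encoding.
    with open(path, encoding="utf-8") as f:
        prompts = [line.strip() for line in f if line.strip()]
    if not prompts:
        raise ValueError(f"No valid prompts found in {path}")
    return prompts


if __name__ == "__main__":
    # Write a small throwaway prompt file and read it back.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
        tmp.write("Hello, welcome to the benchmark.\n\n  Second prompt.  \n")
    try:
        print(load_prompts(tmp.name))
    finally:
        os.remove(tmp.name)
```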
95 changes: 69 additions & 26 deletions examples/offline_inference/qwen3_tts/end2end.py
@@ -248,6 +248,13 @@ def main(args):
    Args:
        args: Parsed CLI args from parse_args().
    """
    if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0:
        raise ValueError(
            f"--batch-size must be a power of two (got {args.batch_size}); "
            "non-power-of-two values do not align with CUDA graph capture sizes "
            "of Code2Wav."
        )

    query_func = query_map[args.query_type]
    if args.query_type in {"CustomVoice", "VoiceDesign"}:
        query_result = query_func(use_batch_sample=args.use_batch_sample)
@@ -260,39 +267,69 @@ def main(args):
        query_result = query_func()

    model_name = query_result.model_name

    # Load prompts from text file if provided.
    # Use the default query as a template so task-specific fields
    # (e.g. ref_audio for Base) are preserved; only override text.
    if args.txt_prompts:
        with open(args.txt_prompts) as f:
            lines = [line.strip() for line in f if line.strip()]
        if not lines:
            raise ValueError(f"No valid prompts found in {args.txt_prompts}")
        template = query_result.inputs
        if isinstance(template, list):
            template = template[0]
        template_info = template["additional_information"]
        inputs = []
        for text in lines:
            additional_information = {**template_info, "text": [text]}
            inputs.append(
                {
                    "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name),
                    "additional_information": additional_information,
                }
            )
    else:
        inputs = query_result.inputs
        if not isinstance(inputs, list):
            inputs = [inputs]

    omni = Omni(
        model=model_name,
        stage_configs_path=args.stage_configs_path,
        log_stats=args.log_stats,
        stage_init_timeout=args.stage_init_timeout,
    )

    output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    omni_generator = omni.generate(query_result.inputs, sampling_params_list=None)
    for stage_outputs in omni_generator:
        for output in stage_outputs.request_output:
            request_id = output.request_id
            audio_data = output.outputs[0].multimodal_output["audio"]
            # async_chunk mode returns a list of chunks; concatenate them.
            if isinstance(audio_data, list):
                audio_tensor = torch.cat(audio_data, dim=-1)
            else:
                audio_tensor = audio_data
            output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
            sr_val = output.outputs[0].multimodal_output["sr"]
            audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
            # Convert to numpy array and ensure correct format
            audio_numpy = audio_tensor.float().detach().cpu().numpy()

            # Ensure audio is 1D (flatten if needed)
            if audio_numpy.ndim > 1:
                audio_numpy = audio_numpy.flatten()

            # Save audio file with explicit WAV format
            sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
            print(f"Request ID: {request_id}, Saved audio to {output_wav}")
    batch_size = args.batch_size
    for batch_start in range(0, len(inputs), batch_size):
        batch = inputs[batch_start : batch_start + batch_size]
        omni_generator = omni.generate(batch, sampling_params_list=None)
        for stage_outputs in omni_generator:
            for output in stage_outputs.request_output:
                request_id = output.request_id
                audio_data = output.outputs[0].multimodal_output["audio"]
                # async_chunk mode returns a list of chunks; concatenate them.
@lishunyang12 (Collaborator), Feb 22, 2026:

Any reason not to always pass a list here? The conditional unwrapping means `omni.generate()` gets different types depending on batch size.

Contributor (PR author):

fixed

                if isinstance(audio_data, list):
                    audio_tensor = torch.cat(audio_data, dim=-1)
                else:
                    audio_tensor = audio_data
                output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
                sr_val = output.outputs[0].multimodal_output["sr"]
                audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1])
                # Convert to numpy array and ensure correct format
                audio_numpy = audio_tensor.float().detach().cpu().numpy()

                # Ensure audio is 1D (flatten if needed)
                if audio_numpy.ndim > 1:
                    audio_numpy = audio_numpy.flatten()

                # Save audio file with explicit WAV format
                sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
                print(f"Request ID: {request_id}, Saved audio to {output_wav}")
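
The batching loop above slices the prompt list into consecutive groups of `--batch-size` and always hands `omni.generate()` a list. A minimal standalone sketch of that slicing follows. The helper name `iter_batches` is hypothetical and the integers stand in for request dicts.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most batch_size items, each as a list."""
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 (got {batch_size})")
    for start in range(0, len(items), batch_size):
        # Slicing past the end is safe in Python; the last batch is simply shorter.
        yield list(items[start : start + batch_size])


if __name__ == "__main__":
    prompts = list(range(12))  # stand-in for the 12 benchmark prompts
    for batch in iter_batches(prompts, 4):
        print(len(batch), batch)
```

With 12 prompts and a batch size of 4 this yields three full batches. When the prompt count is not a multiple of the batch size, the final batch is shorter; how the engine pads or handles that short batch is not stated in this diff.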


def parse_args():
@@ -341,9 +378,9 @@ def parse_args():
        help="Threshold for using shared memory in bytes (default: 65536)",
    )
    parser.add_argument(
        "--output-wav",
        "--output-dir",
        default="output_audio",
        help="[Deprecated] Output wav directory (use --output-dir).",
        help="Output directory for generated wav files (default: output_audio).",
    )
    parser.add_argument(
        "--num-prompts",
@@ -401,6 +438,12 @@ def parse_args():
        choices=["icl", "xvec_only"],
        help="Mode tag for Base query x_vector_only_mode (default: icl).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Number of prompts per batch (default: 1, sequential).",
    )

    return parser.parse_args()
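
The PR validates `--batch-size` inside `main()`. One possible alternative, shown here only as a sketch and not as what the PR does, is to reject bad values at parse time through an `argparse` type callable, so the error appears in the usual usage message. The names `power_of_two` and `build_parser` are hypothetical.

```python
import argparse


def power_of_two(value: str) -> int:
    """argparse type callable: accept only positive powers of two."""
    try:
        n = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if n < 1 or (n & (n - 1)) != 0:
        raise argparse.ArgumentTypeError(f"must be a power of two (got {n})")
    return n


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser holding only the --batch-size option."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=power_of_two,
        default=1,
        help="Number of prompts per batch; must be a power of two (default: 1).",
    )
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args(["--batch-size", "4"]).batch_size)
```

Passing `--batch-size 3` to this parser makes `argparse` print a usage error and exit before any model is loaded.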
