From b1f669d2971611252c240824e01963755af3ae95 Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 10:57:31 +0000 Subject: [PATCH 01/15] [Qwen3TTS][feat] Code2Wav batched decoding Signed-off-by: pablo --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 200 +++++++++++------- 1 file changed, 120 insertions(+), 80 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 339268f34f0..a45d42b0f40 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -9,6 +9,7 @@ import torch.nn as nn from transformers.utils.hub import cached_file from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger from vllm_omni.model_executor.models.output_templates import OmniOutput @@ -116,6 +117,106 @@ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: return None + def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: + """Split concatenated input_ids into per-request segments using forward context.""" + if is_forward_context_available(): + slices = get_forward_context().ubatch_slices + if slices is not None and len(slices) > 1: + boundaries = [0] + for s in slices: + n = s if isinstance(s, int) else (s.token_slice.stop - s.token_slice.start) + boundaries.append(boundaries[-1] + n) + return [ids[boundaries[i]:boundaries[i + 1]] for i in range(len(boundaries) - 1)] + return [ids] + + def _decode_batch( + self, request_ids_list: list[torch.Tensor], + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Decode request(s) in a single batched forward pass through SpeechTokenizer. + + Returns (audios, srs) lists — one entry per request. + """ + tok = self._speech_tokenizer + q = int(self._num_quantizers) + upsample = int(self._decode_upsample_rate) + sr_val = int(self._output_sample_rate) + sr_tensor = torch.tensor(sr_val, dtype=torch.int32) + empty = torch.zeros((0,), dtype=torch.float32) + + # Parse each request: extract ctx_frames, validate, reshape codes. + # input_ids layout per request: [codec_context_frames, *flat_codes] + # where flat_codes is codebook-major [q*F]. + parsed = [] # (ctx_frames, actual_frames) + valid_codes = [] + valid_indices = [] + for i, ids in enumerate(request_ids_list): + if ids.numel() < 2: + parsed.append((0, 0)) + continue + ctx_frames = int(ids[0].item()) + flat = ids[1:] + n = flat.numel() + # Warmup / dummy_run: not divisible by num_quantizers. + if n == 0 or n % q != 0: + if n > 0: + logger.warning( + "Code2Wav input_ids length %d not divisible by num_quantizers %d, " + "likely a warmup run; returning empty audio.", + n, q, + ) + parsed.append((0, 0)) + continue + frames = n // q + # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. + codes_fq = flat.reshape(q, frames).transpose(0, 1).contiguous() + parsed.append((ctx_frames, frames)) + valid_codes.append({"audio_codes": codes_fq}) + valid_indices.append(i) + + if not valid_codes: + n = len(request_ids_list) + return [empty] * n, [sr_tensor] * n + + if not self._logged_codec_stats: + self._logged_codec_stats = True + try: + c = valid_codes[0]["audio_codes"] + logger.info( + "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s batch=%d", + c.shape[0], q, int(torch.unique(c).numel()), + int(c.min().item()), int(c.max().item()), + c[: min(2, c.shape[0]), : min(8, q)].cpu().tolist(), + len(valid_codes), + ) + except Exception: + pass + + # Batched decode: single forward pass through SpeechTokenizer. + wavs, _ = tok.decode(valid_codes) + + # Build per-request outputs, trimming padding and left-context. + audios = [empty] * len(request_ids_list) + srs = [sr_tensor] * len(request_ids_list) + + for j, idx in enumerate(valid_indices): + ctx_frames, actual_frames = parsed[idx] + audio_np = wavs[j].astype(np.float32, copy=False) + # Trim decoder padding (output may be longer due to batch padding). + expected_len = actual_frames * upsample + if audio_np.shape[0] > expected_len: + audio_np = audio_np[:expected_len] + # Trim left-context waveform samples (streaming sliding window). + if ctx_frames > 0: + cut = ctx_frames * upsample + if cut < audio_np.shape[0]: + audio_np = audio_np[cut:] + else: + continue # context trim >= decoded length + if audio_np.shape[0] > 0: + audios[idx] = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) + + return audios, srs + @torch.no_grad() def forward( self, @@ -124,98 +225,37 @@ def forward( intermediate_tensors: Any = None, inputs_embeds: torch.Tensor | None = None, **kwargs: Any, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor] | OmniOutput: """Decode codec codes into audio waveform. - input_ids layout: [codec_context_frames, *flat_codes] + input_ids layout per request: [codec_context_frames, *flat_codes] where flat_codes is codebook-major [q*F]. + + When batched, uses forward context ubatch_slices to split the + concatenated input_ids and decode via a single batched forward pass. """ - tok = self._ensure_speech_tokenizer_loaded() + self._ensure_speech_tokenizer_loaded() assert self._num_quantizers is not None assert self._output_sample_rate is not None sr_val = self._output_sample_rate - empty_ret = ( - torch.zeros((0,), dtype=torch.float32), - torch.tensor(sr_val, dtype=torch.int32), - ) - - if input_ids is None: - return empty_ret - - q = int(self._num_quantizers) - ids = input_ids.reshape(-1).to(dtype=torch.long) - n_tokens = ids.numel() - - if n_tokens == 0: - return empty_ret - - # input_ids[0] = codec_context_frames (prepended by stage_input_processor). - ctx_frames = int(ids[0].item()) - ids = ids[1:] - n_tokens = ids.numel() - - if n_tokens == 0: - return empty_ret - - # Warmup / dummy_run: not divisible by num_quantizers. - if n_tokens % q != 0: - logger.warning( - "Code2Wav input_ids length %d not divisible by num_quantizers %d, " - "likely a warmup run; returning empty audio.", - n_tokens, - q, + if input_ids is None or input_ids.numel() == 0: + return ( + torch.zeros((0,), dtype=torch.float32), + torch.tensor(sr_val, dtype=torch.int32), ) - return empty_ret - - total_frames = n_tokens // q - # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. - codes_fq = ids.reshape(q, total_frames).transpose(0, 1).contiguous() - - if not self._logged_codec_stats and total_frames > 1: - self._logged_codec_stats = True - try: - uniq = int(torch.unique(codes_fq).numel()) - cmin = int(codes_fq.min().item()) - cmax = int(codes_fq.max().item()) - head = codes_fq[: min(2, total_frames), : min(8, q)].cpu().tolist() - logger.info( - "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", - total_frames, - q, - uniq, - cmin, - cmax, - head, - ) - except Exception: - pass + ids = input_ids.reshape(-1).to(dtype=torch.long) + request_ids_list = self._split_request_ids(ids) + audios, srs = self._decode_batch(request_ids_list) - wavs, sr = tok.decode({"audio_codes": codes_fq}) - if not wavs: - raise ValueError("SpeechTokenizer code2wav produced empty waveform list.") - audio_np = wavs[0].astype(np.float32, copy=False) - - # Trim left-context waveform samples (streaming sliding window). - if ctx_frames > 0: - upsample = self._decode_upsample_rate - if upsample is None or upsample <= 0: - raise ValueError(f"Invalid decode upsample rate: {upsample}") - cut = ctx_frames * upsample - if cut < audio_np.shape[0]: - audio_np = audio_np[cut:] - else: - logger.warning( - "Context trim %d >= decoded length %d; returning empty audio.", - cut, - audio_np.shape[0], - ) - return empty_ret + if len(audios) == 1: + return audios[0], srs[0] - audio_tensor = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) - sr_tensor = torch.tensor(int(sr), dtype=torch.int32) - return audio_tensor, sr_tensor + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": audios, "sr": srs}, + ) def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: if isinstance(model_outputs, OmniOutput): From da84873d868dea7034b460d2af9ba3653022a89b Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 11:08:11 +0000 Subject: [PATCH 02/15] move to forward pass instead of helper Signed-off-by: pablo --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 92 +++++++++---------- 1 file changed, 42 insertions(+), 50 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index a45d42b0f40..2d360254c86 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -129,13 +129,27 @@ def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: return [ids[boundaries[i]:boundaries[i + 1]] for i in range(len(boundaries) - 1)] return [ids] - def _decode_batch( - self, request_ids_list: list[torch.Tensor], - ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Decode request(s) in a single batched forward pass through SpeechTokenizer. + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> OmniOutput: + """Decode codec codes into audio waveform. - Returns (audios, srs) lists — one entry per request. + input_ids layout per request: [codec_context_frames, *flat_codes] + where flat_codes is codebook-major [q*F]. + + When batched, uses forward context ubatch_slices to split the + concatenated input_ids and decode via a single batched forward pass. """ + self._ensure_speech_tokenizer_loaded() + assert self._num_quantizers is not None + assert self._output_sample_rate is not None + tok = self._speech_tokenizer q = int(self._num_quantizers) upsample = int(self._decode_upsample_rate) @@ -143,18 +157,27 @@ def _decode_batch( sr_tensor = torch.tensor(sr_val, dtype=torch.int32) empty = torch.zeros((0,), dtype=torch.float32) + if input_ids is None or input_ids.numel() == 0: + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={"model_outputs": [empty], "sr": [sr_tensor]}, + ) + + ids = input_ids.reshape(-1).to(dtype=torch.long) + request_ids_list = self._split_request_ids(ids) + # Parse each request: extract ctx_frames, validate, reshape codes. # input_ids layout per request: [codec_context_frames, *flat_codes] # where flat_codes is codebook-major [q*F]. parsed = [] # (ctx_frames, actual_frames) valid_codes = [] valid_indices = [] - for i, ids in enumerate(request_ids_list): - if ids.numel() < 2: + for i, req_ids in enumerate(request_ids_list): + if req_ids.numel() < 2: parsed.append((0, 0)) continue - ctx_frames = int(ids[0].item()) - flat = ids[1:] + ctx_frames = int(req_ids[0].item()) + flat = req_ids[1:] n = flat.numel() # Warmup / dummy_run: not divisible by num_quantizers. if n == 0 or n % q != 0: @@ -173,9 +196,15 @@ def _decode_batch( valid_codes.append({"audio_codes": codes_fq}) valid_indices.append(i) + num_req = len(request_ids_list) if not valid_codes: - n = len(request_ids_list) - return [empty] * n, [sr_tensor] * n + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": [empty] * num_req, + "sr": [sr_tensor] * num_req, + }, + ) if not self._logged_codec_stats: self._logged_codec_stats = True @@ -195,8 +224,8 @@ def _decode_batch( wavs, _ = tok.decode(valid_codes) # Build per-request outputs, trimming padding and left-context. - audios = [empty] * len(request_ids_list) - srs = [sr_tensor] * len(request_ids_list) + audios = [empty] * num_req + srs = [sr_tensor] * num_req for j, idx in enumerate(valid_indices): ctx_frames, actual_frames = parsed[idx] @@ -215,43 +244,6 @@ def _decode_batch( if audio_np.shape[0] > 0: audios[idx] = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) - return audios, srs - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor | None = None, - positions: torch.Tensor | None = None, - intermediate_tensors: Any = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, torch.Tensor] | OmniOutput: - """Decode codec codes into audio waveform. - - input_ids layout per request: [codec_context_frames, *flat_codes] - where flat_codes is codebook-major [q*F]. - - When batched, uses forward context ubatch_slices to split the - concatenated input_ids and decode via a single batched forward pass. - """ - self._ensure_speech_tokenizer_loaded() - assert self._num_quantizers is not None - assert self._output_sample_rate is not None - - sr_val = self._output_sample_rate - if input_ids is None or input_ids.numel() == 0: - return ( - torch.zeros((0,), dtype=torch.float32), - torch.tensor(sr_val, dtype=torch.int32), - ) - - ids = input_ids.reshape(-1).to(dtype=torch.long) - request_ids_list = self._split_request_ids(ids) - audios, srs = self._decode_batch(request_ids_list) - - if len(audios) == 1: - return audios[0], srs[0] - return OmniOutput( text_hidden_states=None, multimodal_outputs={"model_outputs": audios, "sr": srs}, From 9dddbb8d0976f231480e710ebcf8840fb177a466 Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 22:35:07 +0000 Subject: [PATCH 03/15] update to the benchmark scripts Signed-off-by: pablo --- .../offline_inference/qwen3_tts/README.md | 15 +++- .../qwen3_tts/benchmark_prompts.txt | 12 +++ .../offline_inference/qwen3_tts/end2end.py | 88 +++++++++++++------ 3 files changed, 88 insertions(+), 27 deletions(-) create mode 100644 examples/offline_inference/qwen3_tts/benchmark_prompts.txt diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md index ddbd3fcbcc8..2b04cad75ab 100644 --- a/examples/offline_inference/qwen3_tts/README.md +++ b/examples/offline_inference/qwen3_tts/README.md @@ -87,7 +87,20 @@ Examples: python end2end.py --query-type Base --mode-tag icl ``` +## Batched Decoding + +The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_batch_size > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`. + +``` +python end2end.py --query-type CustomVoice \ + --txt-prompts benchmark_prompts.txt \ + --batch-size 4 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +``` + +**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_batch_size >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_batch_size`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently. + ## Notes - The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs. -- Use `--output-dir` (preferred) or `--output-wav` to change the output folder. +- Use `--output-dir` to change the output folder. diff --git a/examples/offline_inference/qwen3_tts/benchmark_prompts.txt b/examples/offline_inference/qwen3_tts/benchmark_prompts.txt new file mode 100644 index 00000000000..92577a418cd --- /dev/null +++ b/examples/offline_inference/qwen3_tts/benchmark_prompts.txt @@ -0,0 +1,12 @@ +Hello, welcome to the voice synthesis benchmark test. +She said she would be here by noon, but nobody showed up. +The quick brown fox jumps over the lazy dog near the riverbank. +I can't believe how beautiful the sunset looks from up here on the mountain. +Please remember to bring your identification documents to the appointment tomorrow morning. +Have you ever wondered what it would be like to travel through time and visit ancient civilizations? +The restaurant on the corner serves the best pasta I have ever tasted in my entire life. +After the meeting, we should discuss the quarterly results and plan for the next phase. +Learning a new language takes patience, practice, and a genuine curiosity about other cultures. +The train leaves at half past seven, so we need to arrive at the station before then. +Could you please turn down the music a little bit, I'm trying to concentrate on my work. +It was a dark and stormy night when the old lighthouse keeper heard a knock at the door. diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 12e5e193542..9c08c1ad3de 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -260,6 +260,32 @@ def main(args): query_result = query_func() model_name = query_result.model_name + + # Load prompts from text file if provided. + if args.txt_prompts: + with open(args.txt_prompts) as f: + lines = [line.strip() for line in f if line.strip()] + inputs = [] + for text in lines: + additional_information = { + "task_type": [args.query_type], + "text": [text], + "language": ["Auto"], + "speaker": ["Vivian"], + "instruct": [""], + "max_new_tokens": [2048], + } + inputs.append( + { + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, + } + ) + else: + inputs = query_result.inputs + if not isinstance(inputs, list): + inputs = [inputs] + omni = Omni( model=model_name, stage_configs_path=args.stage_configs_path, @@ -267,32 +293,36 @@ def main(args): stage_init_timeout=args.stage_init_timeout, ) - output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav + output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) - omni_generator = omni.generate(query_result.inputs, sampling_params_list=None) - for stage_outputs in omni_generator: - for output in stage_outputs.request_output: - request_id = output.request_id - audio_data = output.outputs[0].multimodal_output["audio"] - # async_chunk mode returns a list of chunks; concatenate them. - if isinstance(audio_data, list): - audio_tensor = torch.cat(audio_data, dim=-1) - else: - audio_tensor = audio_data - output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - sr_val = output.outputs[0].multimodal_output["sr"] - audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) - # Convert to numpy array and ensure correct format - audio_numpy = audio_tensor.float().detach().cpu().numpy() - - # Ensure audio is 1D (flatten if needed) - if audio_numpy.ndim > 1: - audio_numpy = audio_numpy.flatten() - - # Save audio file with explicit WAV format - sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") - print(f"Request ID: {request_id}, Saved audio to {output_wav}") + batch_size = args.batch_size + for batch_start in range(0, len(inputs), batch_size): + batch = inputs[batch_start : batch_start + batch_size] + batch_input = batch[0] if len(batch) == 1 else batch + omni_generator = omni.generate(batch_input, sampling_params_list=None) + for stage_outputs in omni_generator: + for output in stage_outputs.request_output: + request_id = output.request_id + audio_data = output.outputs[0].multimodal_output["audio"] + # async_chunk mode returns a list of chunks; concatenate them. + if isinstance(audio_data, list): + audio_tensor = torch.cat(audio_data, dim=-1) + else: + audio_tensor = audio_data + output_wav = os.path.join(output_dir, f"output_{request_id}.wav") + sr_val = output.outputs[0].multimodal_output["sr"] + audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) + # Convert to numpy array and ensure correct format + audio_numpy = audio_tensor.float().detach().cpu().numpy() + + # Ensure audio is 1D (flatten if needed) + if audio_numpy.ndim > 1: + audio_numpy = audio_numpy.flatten() + + # Save audio file with explicit WAV format + sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV") + print(f"Request ID: {request_id}, Saved audio to {output_wav}") def parse_args(): @@ -341,9 +371,9 @@ def parse_args(): help="Threshold for using shared memory in bytes (default: 65536)", ) parser.add_argument( - "--output-wav", + "--output-dir", default="output_audio", - help="[Deprecated] Output wav directory (use --output-dir).", + help="Output directory for generated wav files (default: output_audio).", ) parser.add_argument( "--num-prompts", @@ -401,6 +431,12 @@ def parse_args(): choices=["icl", "xvec_only"], help="Mode tag for Base query x_vector_only_mode (default: icl).", ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Number of prompts per batch (default: 1, sequential).", + ) return parser.parse_args() From d1a4bd9ea5971ddf8b1812956f7d839e4bbd6b50 Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 22:35:50 +0000 Subject: [PATCH 04/15] added batched decoding stage config Signed-off-by: pablo --- .../stage_configs/qwen3_tts_batch.yaml | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml new file mode 100644 index 00000000000..955735000f1 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml @@ -0,0 +1,103 @@ +# Same as qwen3_tts.yaml with batched talker and code2wav. +# Stage 0: max_batch_size 4, stage 1: max_batch_size 4. +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 4 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSTalkerForConditionalGeneration + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: false + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + # Use named connector to apply runtime.connectors.extra. + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 4 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSCode2Wav] + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.2 + distributed_executor_backend: "mp" + # Must be divisible by num_code_groups and cover (left_context + chunk). + max_num_batched_tokens: 8192 + # async_chunk appends windows per step; max_model_len must cover accumulated stream. + max_model_len: 32768 + engine_input_source: [0] + final_output: true + final_output_type: audio + # Distributed connector configuration + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Frame-aligned codec streaming transport. + codec_streaming: true + # Connector polling / timeout (unit: loop count, sleep interval in seconds). + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + # Align with Omni: small chunks with sufficient context overlap. + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 From a6dfea0a6b26e70467fc537b85267d43a6be0fbf Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 22:37:19 +0000 Subject: [PATCH 05/15] lint Signed-off-by: pablo --- examples/offline_inference/qwen3_omni/end2end.py | 4 ++-- tests/conftest.py | 2 +- tests/engine/test_async_omni_engine_abort.py | 2 +- tests/entrypoints/openai_api/test_image_server.py | 2 +- tests/entrypoints/test_omni_llm.py | 2 +- vllm_omni/benchmarks/patch/patch.py | 5 +++-- vllm_omni/entrypoints/omni.py | 2 +- vllm_omni/entrypoints/omni_stage.py | 2 +- vllm_omni/entrypoints/openai/api_server.py | 3 ++- .../models/qwen3_tts/qwen3_tts_code2wav.py | 12 ++++++++---- 10 files changed, 21 insertions(+), 15 deletions(-) diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 509886a02bf..89d65e50de3 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -12,15 +12,15 @@ import librosa import numpy as np import soundfile as sf -import vllm from PIL import Image -from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset, video_to_ndarrays from vllm.multimodal.image import convert_image_mode from vllm.utils.argparse_utils import FlexibleArgumentParser +import vllm +from vllm import SamplingParams from vllm_omni.entrypoints.omni import Omni SEED = 42 diff --git a/tests/conftest.py b/tests/conftest.py index 08d26a6d168..bad7600fe88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,11 +27,11 @@ import torch import yaml from openai import OpenAI -from vllm import TextPrompt from vllm.distributed.parallel_state import cleanup_dist_env_and_memory from vllm.logger import init_logger from vllm.utils.network_utils import get_open_port +from vllm import TextPrompt from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py index b5f9bac9914..b81c61d4efe 100644 --- a/tests/engine/test_async_omni_engine_abort.py +++ b/tests/engine/test_async_omni_engine_abort.py @@ -5,10 +5,10 @@ from pathlib import Path import pytest -from vllm import SamplingParams from vllm.inputs import PromptType from tests.utils import hardware_test +from vllm import SamplingParams from vllm_omni.entrypoints.async_omni import AsyncOmni os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 0c6479ccea7..49031ee97f9 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -15,8 +15,8 @@ import pytest from fastapi.testclient import TestClient from PIL import Image -from vllm import SamplingParams +from vllm import SamplingParams from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, parse_size, diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 4f05575ca59..33fd002e73a 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -5,8 +5,8 @@ from unittest.mock import MagicMock import pytest -from vllm import SamplingParams +from vllm import SamplingParams from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK pytestmark = [pytest.mark.core_model, pytest.mark.cpu] diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py index 5a03a8069e2..eb8e8ae626c 100644 --- a/vllm_omni/benchmarks/patch/patch.py +++ b/vllm_omni/benchmarks/patch/patch.py @@ -18,7 +18,6 @@ from pydub import AudioSegment from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -from vllm.benchmarks import datasets from vllm.benchmarks.datasets import SampleRequest from vllm.benchmarks.lib.endpoint_request_func import ( ASYNC_REQUEST_FUNCS, @@ -33,6 +32,8 @@ ) from vllm.logger import init_logger +from vllm.benchmarks import datasets + logger = init_logger(__name__) from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset @@ -204,9 +205,9 @@ async def async_request_openai_chat_omni_completions( # ruff: noqa: E402 # Prevent import order from causing patch failures -from vllm.benchmarks import serve from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint +from vllm.benchmarks import serve from vllm_omni.benchmarks.metrics.metrics import MultiModalsBenchmarkMetrics, calculate_metrics # ruff: noqa: E402 diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index f30cd7d368e..0c3e32bcd5b 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -13,9 +13,9 @@ import huggingface_hub from omegaconf import OmegaConf from tqdm.auto import tqdm -from vllm import SamplingParams from vllm.logger import init_logger +from vllm import SamplingParams from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 6c9723b6b5b..d0b10a136f8 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -20,7 +20,6 @@ from dataclasses import fields from typing import Any, Literal, cast -from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -31,6 +30,7 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.llm_engine import LLMEngine +from vllm import PromptType, RequestOutput from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.distributed.omni_connectors import build_stage_connectors from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 3898da3081e..565cea15aa6 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -24,7 +24,6 @@ from PIL import Image from starlette.datastructures import State from starlette.routing import Route -from vllm import SamplingParams from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.chat_utils import load_chat_template @@ -34,6 +33,8 @@ from vllm.entrypoints.openai.api_server import build_app as build_openai_app from vllm.entrypoints.openai.api_server import setup_server as setup_openai_server +from vllm import SamplingParams + # vLLM moved `base` from openai.basic.api_router to serve.instrumentator.basic. # Keep a fallback for older/newer upstream layouts during rebase windows. try: diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 2d360254c86..8efdf68069b 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -126,7 +126,7 @@ def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: for s in slices: n = s if isinstance(s, int) else (s.token_slice.stop - s.token_slice.start) boundaries.append(boundaries[-1] + n) - return [ids[boundaries[i]:boundaries[i + 1]] for i in range(len(boundaries) - 1)] + return [ids[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1)] return [ids] @torch.no_grad() @@ -185,7 +185,8 @@ def forward( logger.warning( "Code2Wav input_ids length %d not divisible by num_quantizers %d, " "likely a warmup run; returning empty audio.", - n, q, + n, + q, ) parsed.append((0, 0)) continue @@ -212,8 +213,11 @@ def forward( c = valid_codes[0]["audio_codes"] logger.info( "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s batch=%d", - c.shape[0], q, int(torch.unique(c).numel()), - int(c.min().item()), int(c.max().item()), + c.shape[0], + q, + int(torch.unique(c).numel()), + int(c.min().item()), + int(c.max().item()), c[: min(2, c.shape[0]), : min(8, q)].cpu().tolist(), len(valid_codes), ) From 5c5cd1a629bf090c99f4a8157d5beddfd37db3c8 Mon Sep 17 00:00:00 2001 From: pablo Date: Sat, 21 Feb 2026 23:02:48 +0000 Subject: [PATCH 06/15] fix logic in e2e.py Signed-off-by: pablo --- examples/offline_inference/qwen3_tts/end2end.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 9c08c1ad3de..22d8f9cf873 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -262,19 +262,18 @@ def main(args): model_name = query_result.model_name # Load prompts from text file if provided. + # Use the default query as a template so task-specific fields + # (e.g. ref_audio for Base) are preserved; only override text. if args.txt_prompts: with open(args.txt_prompts) as f: lines = [line.strip() for line in f if line.strip()] + template = query_result.inputs + if isinstance(template, list): + template = template[0] + template_info = template["additional_information"] inputs = [] for text in lines: - additional_information = { - "task_type": [args.query_type], - "text": [text], - "language": ["Auto"], - "speaker": ["Vivian"], - "instruct": [""], - "max_new_tokens": [2048], - } + additional_information = {**template_info, "text": [text]} inputs.append( { "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), From f67f93212f278d9f2550cbcf9ab3cf77c58c934a Mon Sep 17 00:00:00 2001 From: pablo Date: Sun, 22 Feb 2026 18:14:29 +0000 Subject: [PATCH 07/15] change split req_ids and support UBatchSlice Signed-off-by: pablo --- .../model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 8efdf68069b..fc74176a0a0 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -121,11 +121,10 @@ def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: """Split concatenated input_ids into per-request segments using forward context.""" if is_forward_context_available(): slices = get_forward_context().ubatch_slices - if slices is not None and len(slices) > 1: + if slices is not None and len(slices) > 1 and not any(hasattr(s, "token_slice") for s in slices): boundaries = [0] for s in slices: - n = s if isinstance(s, int) else (s.token_slice.stop - s.token_slice.start) - boundaries.append(boundaries[-1] + n) + boundaries.append(boundaries[-1] + s) return [ids[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1)] return [ids] From 5cbc21489a66717baca80c46f26edb23efd4b321 Mon Sep 17 00:00:00 2001 From: pablo Date: Sun, 22 Feb 2026 18:22:13 +0000 Subject: [PATCH 08/15] guard for wavs returned; e2e assert Signed-off-by: pablo --- examples/offline_inference/qwen3_tts/end2end.py | 10 ++++++++-- .../models/qwen3_tts/qwen3_tts_code2wav.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 22d8f9cf873..62a61878142 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -248,6 +248,13 @@ def main(args): Args: args: Parsed CLI args from parse_args(). """ + if args.batch_size < 1 or (args.batch_size & (args.batch_size - 1)) != 0: + raise ValueError( + f"--batch-size must be a power of two (got {args.batch_size}); " + "non-power-of-two values do not align with CUDA graph capture sizes " + "of Code2Wav." + ) + query_func = query_map[args.query_type] if args.query_type in {"CustomVoice", "VoiceDesign"}: query_result = query_func(use_batch_sample=args.use_batch_sample) @@ -298,8 +305,7 @@ def main(args): batch_size = args.batch_size for batch_start in range(0, len(inputs), batch_size): batch = inputs[batch_start : batch_start + batch_size] - batch_input = batch[0] if len(batch) == 1 else batch - omni_generator = omni.generate(batch_input, sampling_params_list=None) + omni_generator = omni.generate(batch, sampling_params_list=None) for stage_outputs in omni_generator: for output in stage_outputs.request_output: request_id = output.request_id diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index fc74176a0a0..37ce1051c8f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -225,6 +225,7 @@ def forward( # Batched decode: single forward pass through SpeechTokenizer. wavs, _ = tok.decode(valid_codes) + assert len(wavs) == len(valid_codes), f"Code2Wav returned {len(wavs)} waveforms for {len(valid_codes)} requests" # Build per-request outputs, trimming padding and left-context. audios = [empty] * num_req From 27e31267932d26687ace92064216791a8f9c23ad Mon Sep 17 00:00:00 2001 From: pablo Date: Sun, 22 Feb 2026 18:32:57 +0000 Subject: [PATCH 09/15] log and logger improv Signed-off-by: pablo --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 37ce1051c8f..7dc92d5a7f6 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -118,7 +118,15 @@ def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_meta return None def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: - """Split concatenated input_ids into per-request segments using forward context.""" + """ + Split concatenated input_ids into per-request segments using forward context. + + Uses ubatch_slices from forward context which contains either: + - int: number of tokens in the request + - slice object with token_slice attribute + + Returns list of per-request id tensors, or [ids] if not in batched context. + """ if is_forward_context_available(): slices = get_forward_context().ubatch_slices if slices is not None and len(slices) > 1 and not any(hasattr(s, "token_slice") for s in slices): @@ -244,7 +252,12 @@ def forward( if cut < audio_np.shape[0]: audio_np = audio_np[cut:] else: - continue # context trim >= decoded length + logger.warning( + "Context trim %d >= decoded length %d; returning empty audio.", + cut, + audio_np.shape[0], + ) + continue if audio_np.shape[0] > 0: audios[idx] = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) From f242d25f944782c7d7287ae43521139835dd0cd9 Mon Sep 17 00:00:00 2001 From: pablo Date: Sun, 22 Feb 2026 18:45:26 +0000 Subject: [PATCH 10/15] log and assert Signed-off-by: pablo --- vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | 1 + vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml | 2 ++ 2 files changed, 3 insertions(+) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 7dc92d5a7f6..ac553aeb7d7 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -155,6 +155,7 @@ def forward( """ self._ensure_speech_tokenizer_loaded() assert self._num_quantizers is not None + assert self._decode_upsample_rate is not None assert self._output_sample_rate is not None tok = self._speech_tokenizer diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml index 955735000f1..d737391095b 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml @@ -1,5 +1,7 @@ # Same as qwen3_tts.yaml with batched talker and code2wav. # Stage 0: max_batch_size 4, stage 1: max_batch_size 4. +# max_batch_size must be a power of two to align with CUDA graph capture sizes +# (stage 0) and must match --batch-size in end2end.py / benchmark scripts. async_chunk: true stage_args: - stage_id: 0 From 1a0d5fca778dccd5a7a89f24cc4f2805579d73b0 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 24 Feb 2026 07:03:24 +0000 Subject: [PATCH 11/15] revert lint Signed-off-by: pablo --- examples/offline_inference/qwen3_omni/end2end.py | 4 ++-- tests/conftest.py | 2 +- tests/engine/test_async_omni_engine_abort.py | 2 +- tests/entrypoints/openai_api/test_image_server.py | 1 - tests/entrypoints/test_omni_llm.py | 1 - vllm_omni/entrypoints/omni.py | 2 +- vllm_omni/entrypoints/omni_stage.py | 2 +- vllm_omni/entrypoints/openai/api_server.py | 3 +-- 8 files changed, 7 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 89d65e50de3..509886a02bf 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -12,15 +12,15 @@ import librosa import numpy as np import soundfile as sf +import vllm from PIL import Image +from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset, video_to_ndarrays from vllm.multimodal.image import convert_image_mode from vllm.utils.argparse_utils import FlexibleArgumentParser -import vllm -from vllm import SamplingParams from vllm_omni.entrypoints.omni import Omni SEED = 42 diff --git a/tests/conftest.py b/tests/conftest.py index bad7600fe88..08d26a6d168 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,11 +27,11 @@ import torch import yaml from openai import OpenAI +from vllm import TextPrompt from vllm.distributed.parallel_state import cleanup_dist_env_and_memory from vllm.logger import init_logger from vllm.utils.network_utils import get_open_port -from vllm import TextPrompt from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py index b81c61d4efe..b5f9bac9914 100644 --- a/tests/engine/test_async_omni_engine_abort.py +++ b/tests/engine/test_async_omni_engine_abort.py @@ -5,10 +5,10 @@ from pathlib import Path import pytest +from vllm import SamplingParams from vllm.inputs import PromptType from tests.utils import hardware_test -from vllm import SamplingParams from vllm_omni.entrypoints.async_omni import AsyncOmni os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 49031ee97f9..c863c540957 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -15,7 +15,6 @@ import pytest from fastapi.testclient import TestClient from PIL import Image - from vllm import SamplingParams from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 33fd002e73a..1378b85ffdd 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock import pytest - from vllm import SamplingParams from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 0c3e32bcd5b..f30cd7d368e 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -13,9 +13,9 @@ import huggingface_hub from omegaconf import OmegaConf from tqdm.auto import tqdm +from vllm import SamplingParams from vllm.logger import init_logger -from vllm import SamplingParams from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index d0b10a136f8..6c9723b6b5b 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -20,6 +20,7 @@ from dataclasses import fields from typing import Any, Literal, cast +from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -30,7 +31,6 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.llm_engine import LLMEngine -from vllm import PromptType, RequestOutput from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.distributed.omni_connectors import build_stage_connectors from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 565cea15aa6..3898da3081e 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -24,6 +24,7 @@ from PIL import Image from starlette.datastructures import State from starlette.routing import Route +from vllm import SamplingParams from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.chat_utils import load_chat_template @@ -33,8 +34,6 @@ from vllm.entrypoints.openai.api_server import build_app as build_openai_app from vllm.entrypoints.openai.api_server import setup_server as setup_openai_server -from vllm import SamplingParams - # vLLM moved `base` from openai.basic.api_router to serve.instrumentator.basic. # Keep a fallback for older/newer upstream layouts during rebase windows. try: From 8a9589b73abfd5dae40f850409577f82e70daec7 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 24 Feb 2026 07:13:51 +0000 Subject: [PATCH 12/15] revert lint2 Signed-off-by: pablo --- tests/entrypoints/openai_api/test_image_server.py | 1 + tests/entrypoints/test_omni_llm.py | 1 + vllm_omni/benchmarks/patch/patch.py | 6 ++---- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index c863c540957..0c6479ccea7 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -16,6 +16,7 @@ from fastapi.testclient import TestClient from PIL import Image from vllm import SamplingParams + from vllm_omni.entrypoints.openai.image_api_utils import ( encode_image_base64, parse_size, diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 1378b85ffdd..4f05575ca59 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -6,6 +6,7 @@ import pytest from vllm import SamplingParams + from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK pytestmark = [pytest.mark.core_model, pytest.mark.cpu] diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py index eb8e8ae626c..568e9923b14 100644 --- a/vllm_omni/benchmarks/patch/patch.py +++ b/vllm_omni/benchmarks/patch/patch.py @@ -18,6 +18,7 @@ from pydub import AudioSegment from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from vllm.benchmarks import datasets from vllm.benchmarks.datasets import SampleRequest from vllm.benchmarks.lib.endpoint_request_func import ( ASYNC_REQUEST_FUNCS, @@ -32,8 +33,6 @@ ) from vllm.logger import init_logger -from vllm.benchmarks import datasets - logger = init_logger(__name__) from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset @@ -205,9 +204,8 @@ async def async_request_openai_chat_omni_completions( # ruff: noqa: E402 # Prevent import order from causing patch failures -from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint - from vllm.benchmarks import serve +from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint from vllm_omni.benchmarks.metrics.metrics import MultiModalsBenchmarkMetrics, calculate_metrics # ruff: noqa: E402 From 157363d37b972469621eefba7b8fa72ccb0b0e0e Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 24 Feb 2026 07:22:40 +0000 Subject: [PATCH 13/15] assert Signed-off-by: pablo --- examples/offline_inference/qwen3_tts/end2end.py | 2 ++ .../model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 62a61878142..fec515d303a 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -274,6 +274,8 @@ def main(args): if args.txt_prompts: with open(args.txt_prompts) as f: lines = [line.strip() for line in f if line.strip()] + if not lines: + raise ValueError(f"No valid prompts found in {args.txt_prompts}") template = query_result.inputs if isinstance(template, list): template = template[0] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index ac553aeb7d7..f94ff31e78d 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -234,7 +234,8 @@ def forward( # Batched decode: single forward pass through SpeechTokenizer. wavs, _ = tok.decode(valid_codes) - assert len(wavs) == len(valid_codes), f"Code2Wav returned {len(wavs)} waveforms for {len(valid_codes)} requests" + if len(wavs) != len(valid_codes): + raise RuntimeError(f"Code2Wav returned {len(wavs)} waveforms for {len(valid_codes)} requests") # Build per-request outputs, trimming padding and left-context. audios = [empty] * num_req From a4fd3797ed4867406bfac9b562ef0847b5e7dc4a Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 24 Feb 2026 07:23:51 +0000 Subject: [PATCH 14/15] lint Signed-off-by: pablo --- vllm_omni/benchmarks/patch/patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py index 568e9923b14..5a03a8069e2 100644 --- a/vllm_omni/benchmarks/patch/patch.py +++ b/vllm_omni/benchmarks/patch/patch.py @@ -206,6 +206,7 @@ async def async_request_openai_chat_omni_completions( # Prevent import order from causing patch failures from vllm.benchmarks import serve from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint + from vllm_omni.benchmarks.metrics.metrics import MultiModalsBenchmarkMetrics, calculate_metrics # ruff: noqa: E402 From 9e1cfcc6bfd6ad0482de646644997e9bac6c3dc6 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 24 Feb 2026 11:21:07 +0000 Subject: [PATCH 15/15] fix boundaries issues Signed-off-by: pablo --- .../models/qwen3_tts/qwen3_tts_code2wav.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index f94ff31e78d..a87027de18f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -117,16 +117,19 @@ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: return None - def _split_request_ids(self, ids: torch.Tensor) -> list[torch.Tensor]: - """ - Split concatenated input_ids into per-request segments using forward context. - - Uses ubatch_slices from forward context which contains either: - - int: number of tokens in the request - - slice object with token_slice attribute + def _split_request_ids(self, ids: torch.Tensor, seq_token_counts: list[int] | None = None) -> list[torch.Tensor]: + """Split concatenated input_ids into per-request segments. - Returns list of per-request id tensors, or [ids] if not in batched context. + Uses seq_token_counts (injected by the runner via model_kwargs) when + available, falling back to forward-context ubatch_slices when + micro-batching is active. Returns [ids] for single-request batches. """ + if seq_token_counts is not None and len(seq_token_counts) > 1: + boundaries = [0] + for count in seq_token_counts: + boundaries.append(boundaries[-1] + count) + n = ids.numel() + return [ids[boundaries[i] : min(boundaries[i + 1], n)] for i in range(len(seq_token_counts))] if is_forward_context_available(): slices = get_forward_context().ubatch_slices if slices is not None and len(slices) > 1 and not any(hasattr(s, "token_slice") for s in slices): @@ -172,7 +175,7 @@ def forward( ) ids = input_ids.reshape(-1).to(dtype=torch.long) - request_ids_list = self._split_request_ids(ids) + request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts")) # Parse each request: extract ctx_frames, validate, reshape codes. # input_ids layout per request: [codec_context_frames, *flat_codes]