From 9999b8f86370408cce0d1a5fbdd37e310b684a18 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 14 May 2026 18:08:56 +0800 Subject: [PATCH 1/6] docs: record qwen3 tts ws1 baseline Signed-off-by: Sy03 <1370724210@qq.com> --- artifacts/qwen3_tts_ws1_baseline/README.md | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 artifacts/qwen3_tts_ws1_baseline/README.md diff --git a/artifacts/qwen3_tts_ws1_baseline/README.md b/artifacts/qwen3_tts_ws1_baseline/README.md new file mode 100644 index 00000000000..4275dcbd0dc --- /dev/null +++ b/artifacts/qwen3_tts_ws1_baseline/README.md @@ -0,0 +1,23 @@ +# Qwen3-TTS WS1 Baseline + +Baseline scope: +- Config fixed at Stage0 max_num_seqs=64 and Stage1 max_num_seqs=10. +- Existing initial_codec_chunk_frames=1 is kept. +- Existing Code2Wav exact-length batching is kept. +- No WS1 Stage0 slot runner is enabled. + +Primary workload: +- Model: Qwen/Qwen3-TTS-12Hz-1.7B-Base +- Task: voice_clone +- Concurrency: 64 +- Num prompts: 256 for stable run, 128 for quick A/B +- Warmups: 2, excluded from steady-state SLA + +Metrics: +- median / p99 TTFT +- median / p99 audio TTFP +- median / p99 E2EL +- median / p99 audio RTF +- audio throughput +- request throughput +- failed request count From 7c98bad3d92c9ee10753ea0d59a2f0516277706f Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 14 May 2026 18:33:52 +0800 Subject: [PATCH 2/6] feat: add qwen3 tts stage0 slot runner Signed-off-by: Sy03 <1370724210@qq.com> --- tests/worker/test_omni_gpu_model_runner.py | 44 +++ .../test_qwen3_tts_stage0_step_runner.py | 149 ++++++++ .../models/qwen3_tts/qwen3_tts_talker.py | 2 + vllm_omni/worker/gpu_model_runner.py | 97 ++++- vllm_omni/worker/omni_step_runner.py | 66 ++++ .../worker/qwen3_tts_stage0_step_runner.py | 345 ++++++++++++++++++ 6 files changed, 687 insertions(+), 16 deletions(-) create mode 100644 tests/worker/test_qwen3_tts_stage0_step_runner.py create mode 100644 vllm_omni/worker/omni_step_runner.py create mode 100644 vllm_omni/worker/qwen3_tts_stage0_step_runner.py diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index c2f3f3622df..75a87835f9b 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -173,6 +173,12 @@ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_ monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner))) + class NoIndexList(list): + def index(self, *args, **kwargs): + raise AssertionError("_talker_mtp_forward should not linearly search req_ids per request") + + runner.input_batch.req_ids = NoIndexList(runner.input_batch.req_ids) + # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) @@ -330,6 +336,44 @@ def test_update_intermediate_buffer_skips_unknown_req_id(): assert "unknown_req" not in runner.model_intermediate_buffer +def test_update_talker_mtp_output_writes_single_nested_value(): + runner = _make_runner(req_ids=("r1",), hidden_size=4) + runner.model.gpu_resident_buffer_keys = {("codes", "audio")} + src = torch.tensor([[1, 2, 3]], dtype=torch.long) + + OmniGPUModelRunner._update_talker_mtp_output(runner, "r1", ("codes", "audio"), src) + + stored = runner.model_intermediate_buffer["r1"]["codes"]["audio"] + assert torch.equal(stored, src) + assert stored.data_ptr() != src.data_ptr() + assert runner.requests["r1"].additional_information_cpu is runner.model_intermediate_buffer["r1"] + + +def test_optional_omni_step_runner_cleanup_is_called(): + runner = object.__new__(OmniGPUModelRunner) + freed = [] + + class DummyStepRunner: + def free_request(self, req_id): + freed.append(req_id) + + runner.omni_step_runner = DummyStepRunner() + runner.requests = {"r1": object()} + runner.model_intermediate_buffer = {"r1": {"codes": {"audio": 1}}} + runner.num_prompt_logprobs = {"r1": 0} + runner._downstream_payload_cache = {"r1": object()} + runner._talker_mtp_generators = {"r1": object()} + + OmniGPUModelRunner._free_omni_request_state(runner, "r1") + + assert freed == ["r1"] + assert "r1" not in runner.requests + assert "r1" not in runner.model_intermediate_buffer + assert "r1" not in runner.num_prompt_logprobs + assert "r1" not in runner._downstream_payload_cache + assert "r1" not in runner._talker_mtp_generators + + def test_maybe_attach_mimo_audio_req_infos_enriches_dict(): runner = _make_runner_for_mimo() req_id = "r_mimo" diff --git a/tests/worker/test_qwen3_tts_stage0_step_runner.py b/tests/worker/test_qwen3_tts_stage0_step_runner.py new file mode 100644 index 00000000000..d6b565fc840 --- /dev/null +++ b/tests/worker/test_qwen3_tts_stage0_step_runner.py @@ -0,0 +1,149 @@ +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.worker.qwen3_tts_stage0_step_runner import ( + Qwen3TTSSlotTable, + Qwen3TTSStage0StepRunner, + flatten_codec_frames_for_code2wav, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_qwen3_tts_slot_table_allocates_and_reuses_slots(): + table = Qwen3TTSSlotTable( + max_slots=2, + hidden_size=4, + num_quantizers=16, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + s0 = table.allocate("r1") + s1 = table.allocate("r2") + assert s0 != s1 + assert table.allocate("r1") == s0 + + table.free("r1") + s2 = table.allocate("r3") + assert s2 == s0 + assert "r1" not in table.req_to_slot + assert table.slots[s2].req_id == "r3" + + +def test_qwen3_tts_slot_table_exhaustion_is_explicit(): + table = Qwen3TTSSlotTable( + max_slots=1, + hidden_size=4, + num_quantizers=16, + device=torch.device("cpu"), + dtype=torch.float32, + ) + table.allocate("r1") + + with pytest.raises(RuntimeError, match="slot table exhausted"): + table.allocate("r2") + + +def _runner_config(async_chunk=True, model_stage="qwen3_tts", has_talker_mtp=True): + return SimpleNamespace( + vllm_config=SimpleNamespace( + model_config=SimpleNamespace( + async_chunk=async_chunk, + model_stage=model_stage, + ) + ), + has_talker_mtp=has_talker_mtp, + ) + + +def test_stage0_step_runner_supports_decode_only_qwen3_tts_async_chunk(): + step_runner = Qwen3TTSStage0StepRunner(max_slots=4, hidden_size=8, num_quantizers=16) + + assert step_runner.supports_step( + runner=_runner_config(), + request_ids=["r1", "r2"], + num_scheduled_tokens=[1, 1], + is_prefill_by_req={"r1": False, "r2": False}, + ) + + +def test_stage0_step_runner_rejects_prefill_or_wrong_stage(): + step_runner = Qwen3TTSStage0StepRunner(max_slots=4, hidden_size=8, num_quantizers=16) + + assert not step_runner.supports_step( + runner=_runner_config(model_stage="code2wav"), + request_ids=["r1"], + num_scheduled_tokens=[1], + is_prefill_by_req={"r1": False}, + ) + assert not step_runner.supports_step( + runner=_runner_config(), + request_ids=["r1"], + num_scheduled_tokens=[1], + is_prefill_by_req={"r1": True}, + ) + assert not step_runner.supports_step( + runner=_runner_config(async_chunk=False), + request_ids=["r1"], + num_scheduled_tokens=[1], + is_prefill_by_req={"r1": False}, + ) + + +def test_stage0_step_runner_commits_next_embeds_and_codes(): + class FakeTalkerMTP: + def __call__(self, input_ids, req_embeds, last_hidden, text_step, **kwargs): + codes = torch.arange(input_ids.shape[0] * 16, dtype=torch.long).reshape(input_ids.shape[0], 16) + return req_embeds + 10, codes + + runner = SimpleNamespace( + talker_mtp=FakeTalkerMTP(), + input_batch=SimpleNamespace(req_ids=["r1", "r2"]), + query_start_loc=SimpleNamespace(cpu=torch.tensor([0, 1], dtype=torch.int32)), + model_intermediate_buffer={}, + requests={ + "r1": SimpleNamespace(additional_information_cpu=None), + "r2": SimpleNamespace(additional_information_cpu=None), + }, + model=SimpleNamespace(talker_mtp_output_key=("codes", "audio"), gpu_resident_buffer_keys=set()), + vllm_config=SimpleNamespace(model_config=SimpleNamespace(subtalker_sampling_params={})), + ) + inputs_embeds = torch.zeros((2, 4), dtype=torch.float32) + + step_runner = Qwen3TTSStage0StepRunner(max_slots=2, hidden_size=4, num_quantizers=16) + prepared = step_runner.prepare_step( + request_ids=["r1", "r2"], + runner=runner, + input_ids=torch.tensor([101, 102], dtype=torch.long), + req_embeds=torch.ones((2, 4), dtype=torch.float32), + last_talker_hidden=torch.ones((2, 4), dtype=torch.float32) * 2, + text_step=torch.ones((2, 4), dtype=torch.float32) * 3, + ) + step_runner.run_step(prepared=prepared, runner=runner) + step_runner.commit_step(prepared=prepared, runner=runner, inputs_embeds=inputs_embeds) + + assert torch.equal(inputs_embeds, torch.ones((2, 4), dtype=torch.float32) * 11) + assert torch.equal(runner.model_intermediate_buffer["r1"]["codes"]["audio"], torch.arange(16).reshape(1, 16)) + assert torch.equal(runner.model_intermediate_buffer["r2"]["codes"]["audio"], torch.arange(16, 32).reshape(1, 16)) + assert runner.requests["r1"].additional_information_cpu is runner.model_intermediate_buffer["r1"] + + +def test_stage0_step_runner_records_fast_path_and_fallback_counts(): + step_runner = Qwen3TTSStage0StepRunner(max_slots=2, hidden_size=4, num_quantizers=16) + + step_runner.record_fast_path(batch_size=2) + step_runner.record_fallback("prefill") + + assert step_runner.stats.fast_path_steps == 1 + assert step_runner.stats.fast_path_requests == 2 + assert step_runner.stats.fallback_reasons["prefill"] == 1 + + +def test_qwen3_tts_slot_codec_frames_match_legacy_flattening(): + frames_fq = torch.arange(4 * 16, dtype=torch.long).reshape(4, 16) + legacy_flat = frames_fq.transpose(0, 1).contiguous().reshape(-1) + + assert torch.equal(flatten_codec_frames_for_code2wav(frames_fq), legacy_flat) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index 5703377daa7..b01a2aa287f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -302,6 +302,8 @@ class Qwen3TTSTalkerForConditionalGeneration(nn.Module): """vLLM-AR talker: step-wise layer-0 codec decoding. Predicts residual codebooks (1..Q-1) into `audio_codes` and streams text via `tailing_text_hidden`.""" + omni_step_runner_cls = "vllm_omni.worker.qwen3_tts_stage0_step_runner.Qwen3TTSStage0StepRunner" + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # Talker backbone (Qwen3 decoder-only). diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index f4eadedc046..3974c64ab63 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +import importlib from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -45,6 +46,7 @@ def __init__(self, *args, **kwargs): self.model_intermediate_buffer: dict[str, dict[str, Any]] = {} self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None + self.omni_step_runner = None # The Omni tensor prefix cache will be allocated # when we initialize the metadata builders if enabled self.omni_prefix_cache = None @@ -113,6 +115,17 @@ def load_model(self, *args, **kwargs) -> None: ) self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + self._init_omni_step_runner() + + def _init_omni_step_runner(self) -> None: + cls_path = getattr(self.model, "omni_step_runner_cls", None) + if not cls_path: + self.omni_step_runner = None + return + module_name, class_name = cls_path.rsplit(".", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + self.omni_step_runner = cls.from_runner(self) def _init_mrope_positions(self, req_state: CachedRequestState): """Initialize M-RoPE positions for multimodal inputs. @@ -265,13 +278,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.model_intermediate_buffer.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - if hasattr(self, "_downstream_payload_cache"): - self._downstream_payload_cache.pop(req_id, None) - if hasattr(self, "_talker_mtp_generators"): - self._talker_mtp_generators.pop(req_id, None) + self._free_omni_request_state(req_id) if hasattr(self, "late_interaction_runner"): self.late_interaction_runner.on_requests_finished(scheduler_output.finished_req_ids) @@ -571,6 +578,18 @@ def correct_spec_decode_token_counts(): else: return None + def _free_omni_request_state(self, req_id: str) -> None: + self.requests.pop(req_id, None) + self.model_intermediate_buffer.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + if hasattr(self, "_downstream_payload_cache"): + self._downstream_payload_cache.pop(req_id, None) + if hasattr(self, "_talker_mtp_generators"): + self._talker_mtp_generators.pop(req_id, None) + step_runner = getattr(self, "omni_step_runner", None) + if step_runner is not None: + step_runner.free_request(req_id) + @torch.inference_mode() def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput) -> dict: if ( @@ -1304,6 +1323,7 @@ def _preprocess( # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] + decode_is_prefill_by_req: dict[str, bool] = {} for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) @@ -1337,9 +1357,13 @@ def _preprocess( self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) + prompt_token_ids = getattr(req_state, "prompt_token_ids", ()) if req_state is not None else () + prompt_len = len(prompt_token_ids or ()) + decode_is_prefill_by_req[req_id] = num_computed_tokens < prompt_len # TODO(Peiqi): the merge stage could move out from the critical path - self._merge_additional_information_update(req_id, update_dict) + self._update_intermediate_buffer(req_id, update_dict) # update the inputs_embeds and input_ids seg_len = min(span_len, req_embeds.shape[0]) @@ -1348,8 +1372,35 @@ def _preprocess( input_ids[s : s + seg_len] = req_input_ids # run talker mtp decode - if self.has_talker_mtp: - self._talker_mtp_forward(decode_req_ids, inputs_embeds) + if self.has_talker_mtp and decode_req_ids: + batch_size = len(decode_req_ids) + mtp_input_ids = self.talker_mtp_input_ids.gpu[:batch_size] + mtp_req_embeds = self.talker_mtp_inputs_embeds.gpu[:batch_size] + mtp_last_hidden = self.last_talker_hidden.gpu[:batch_size] + mtp_text_step = self.text_step.gpu[:batch_size] + step_runner = getattr(self, "omni_step_runner", None) + use_fast_path = step_runner is not None and step_runner.supports_step( + runner=self, + request_ids=decode_req_ids, + num_scheduled_tokens=[1] * len(decode_req_ids), + is_prefill_by_req=decode_is_prefill_by_req, + ) + if use_fast_path: + prepared = step_runner.prepare_step( + request_ids=decode_req_ids, + runner=self, + input_ids=mtp_input_ids, + req_embeds=mtp_req_embeds, + last_talker_hidden=mtp_last_hidden, + text_step=mtp_text_step, + ) + step_runner.run_step(prepared=prepared, runner=self) + step_runner.commit_step(prepared=prepared, runner=self, inputs_embeds=inputs_embeds) + else: + if step_runner is not None: + reason = getattr(step_runner, "_last_fallback_reason", None) or "unsupported" + step_runner.record_fallback(reason) + self._talker_mtp_forward(decode_req_ids, inputs_embeds) return ( input_ids, @@ -1433,14 +1484,28 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te ) # update the inputs_embeds and code_predictor_codes out_key = getattr(self.model, "talker_mtp_output_key", ("codes", "audio")) - if not isinstance(out_key, tuple) or len(out_key) != 2: - raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}") + req_index_by_id = {req_id: idx for idx, req_id in enumerate(self.input_batch.req_ids)} + query_start_loc_cpu = self.query_start_loc.cpu for idx, req_id in enumerate(decode_req_ids): - req_index = self.input_batch.req_ids.index(req_id) - start_offset = int(self.query_start_loc.cpu[req_index]) + req_index = req_index_by_id[req_id] + start_offset = int(query_start_loc_cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {out_key[0]: {out_key[1]: code_predictor_codes[idx : idx + 1]}} - self._merge_additional_information_update(req_id, update_dict) + self._update_talker_mtp_output(req_id, out_key, code_predictor_codes[idx : idx + 1]) + + def _update_talker_mtp_output(self, req_id: str, out_key: tuple[str, str], value: torch.Tensor) -> None: + if not isinstance(out_key, tuple) or len(out_key) != 2: + raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}") + req_state = self.requests.get(req_id) + if req_state is None: + return + type_key, qual = out_key + gpu_keys: set[tuple[str, str]] = set() + if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): + gpu_keys = self.model.gpu_resident_buffer_keys + existing = self.model_intermediate_buffer.setdefault(req_id, {}) + existing_sub = existing.setdefault(type_key, {}) + self._store_value(existing_sub, qual, value, {q for tk, q in gpu_keys if tk == type_key}) + setattr(req_state, "additional_information_cpu", existing) def _model_forward( self, diff --git a/vllm_omni/worker/omni_step_runner.py b/vllm_omni/worker/omni_step_runner.py new file mode 100644 index 00000000000..f85df9171f6 --- /dev/null +++ b/vllm_omni/worker/omni_step_runner.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any, Protocol + +import torch + + +@dataclass(slots=True) +class OmniPreparedStep: + request_ids: list[str] + token_slices: list[slice] = field(default_factory=list) + fallback_request_ids: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +class OmniStepRunner(Protocol): + @classmethod + def from_runner(cls, runner: Any) -> OmniStepRunner: + ... + + def supports_step( + self, + *, + runner: Any, + request_ids: list[str], + num_scheduled_tokens: Sequence[int], + is_prefill_by_req: Mapping[str, bool], + ) -> bool: + ... + + def prepare_step( + self, + *, + request_ids: list[str], + runner: Any, + input_ids: torch.Tensor, + req_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, + ) -> OmniPreparedStep: + ... + + def run_step( + self, + *, + prepared: OmniPreparedStep, + runner: Any, + ) -> None: + ... + + def commit_step( + self, + *, + prepared: OmniPreparedStep, + runner: Any, + inputs_embeds: torch.Tensor, + ) -> None: + ... + + def free_request(self, request_id: str) -> None: + ... diff --git a/vllm_omni/worker/qwen3_tts_stage0_step_runner.py b/vllm_omni/worker/qwen3_tts_stage0_step_runner.py new file mode 100644 index 00000000000..ae747cc8ecc --- /dev/null +++ b/vllm_omni/worker/qwen3_tts_stage0_step_runner.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import torch +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.config import CUDAGraphMode +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger + +from vllm_omni.worker.omni_step_runner import OmniPreparedStep + +logger = init_logger(__name__) + + +@dataclass(slots=True) +class Qwen3TTSSlot: + req_id: str | None + prompt_len: int + num_computed_tokens: int + text_offset: int + codec_len: int + emitted_chunks: int + finished: bool + + +class Qwen3TTSSlotTable: + def __init__( + self, + *, + max_slots: int, + hidden_size: int, + num_quantizers: int, + device: torch.device, + dtype: torch.dtype, + ) -> None: + self.max_slots = max_slots + self.hidden_size = hidden_size + self.num_quantizers = num_quantizers + self.slots = [Qwen3TTSSlot(None, 0, 0, 0, 0, 0, False) for _ in range(max_slots)] + self.req_to_slot: dict[str, int] = {} + self.free_slots = list(range(max_slots - 1, -1, -1)) + + self.input_ids = torch.empty((max_slots,), dtype=torch.long, device=device) + self.inputs_embeds = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) + self.last_talker_hidden = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) + self.text_step = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) + self.next_embeds = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) + self.sampled_codes = torch.empty((max_slots, num_quantizers), dtype=torch.long, device=device) + self.codec_frames: dict[str, list[torch.Tensor]] = defaultdict(list) + + def allocate(self, req_id: str) -> int: + existing = self.req_to_slot.get(req_id) + if existing is not None: + return existing + if not self.free_slots: + raise RuntimeError("Qwen3-TTS Stage0 slot table exhausted") + slot = self.free_slots.pop() + self.req_to_slot[req_id] = slot + self.slots[slot] = Qwen3TTSSlot(req_id, 0, 0, 0, 0, 0, False) + return slot + + def free(self, req_id: str) -> None: + slot = self.req_to_slot.pop(req_id, None) + if slot is None: + return + self.slots[slot] = Qwen3TTSSlot(None, 0, 0, 0, 0, 0, False) + self.codec_frames.pop(req_id, None) + self.free_slots.append(slot) + + +@dataclass(slots=True) +class Qwen3TTSPreparedStep(OmniPreparedStep): + slot_indices: list[int] = field(default_factory=list) + query_offsets: list[int] = field(default_factory=list) + next_embeds: torch.Tensor | None = None + sampled_codes: torch.Tensor | None = None + + +@dataclass(slots=True) +class Qwen3TTSStage0StepStats: + fast_path_steps: int = 0 + fast_path_requests: int = 0 + fallback_reasons: dict[str, int] = field(default_factory=dict) + + +def flatten_codec_frames_for_code2wav(frames_fq: torch.Tensor) -> torch.Tensor: + if frames_fq.ndim != 2: + raise ValueError(f"expected [frames, quantizers], got {tuple(frames_fq.shape)}") + return frames_fq.transpose(0, 1).contiguous().reshape(-1) + + +class Qwen3TTSStage0StepRunner: + def __init__( + self, + *, + max_slots: int, + hidden_size: int, + num_quantizers: int = 16, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + log_every_n_steps: int = 1000, + ) -> None: + if max_slots <= 0: + raise ValueError(f"max_slots must be positive, got {max_slots}") + if hidden_size <= 0: + raise ValueError(f"hidden_size must be positive, got {hidden_size}") + if num_quantizers <= 0: + raise ValueError(f"num_quantizers must be positive, got {num_quantizers}") + device = device or torch.device("cpu") + self.table = Qwen3TTSSlotTable( + max_slots=max_slots, + hidden_size=hidden_size, + num_quantizers=num_quantizers, + device=device, + dtype=dtype, + ) + self.max_slots = max_slots + self.hidden_size = hidden_size + self.num_quantizers = num_quantizers + self.stats = Qwen3TTSStage0StepStats() + self.log_every_n_steps = log_every_n_steps + self._last_fallback_reason: str | None = None + + @classmethod + def from_runner(cls, runner: Any) -> Qwen3TTSStage0StepRunner: + mtp_buffer = runner.talker_mtp_inputs_embeds.gpu + model = getattr(runner, "model", None) + talker_config = getattr(model, "talker_config", None) + num_quantizers = int(getattr(talker_config, "num_code_groups", 0) or 16) + step_runner = cls( + max_slots=int(mtp_buffer.shape[0]), + hidden_size=int(mtp_buffer.shape[-1]), + num_quantizers=num_quantizers, + device=mtp_buffer.device, + dtype=mtp_buffer.dtype, + ) + step_runner.table.input_ids = runner.talker_mtp_input_ids.gpu + step_runner.table.inputs_embeds = runner.talker_mtp_inputs_embeds.gpu + step_runner.table.last_talker_hidden = runner.last_talker_hidden.gpu + step_runner.table.text_step = runner.text_step.gpu + return step_runner + + def _reject(self, reason: str) -> bool: + self._last_fallback_reason = reason + return False + + def supports_step( + self, + *, + runner: Any, + request_ids: list[str], + num_scheduled_tokens: Sequence[int], + is_prefill_by_req: Mapping[str, bool], + ) -> bool: + if not request_ids: + return self._reject("empty") + model_config = getattr(getattr(runner, "vllm_config", None), "model_config", None) + if not bool(getattr(model_config, "async_chunk", False)): + return self._reject("async_chunk_disabled") + if getattr(model_config, "model_stage", None) != "qwen3_tts": + return self._reject("wrong_stage") + if not bool(getattr(runner, "has_talker_mtp", False)): + return self._reject("no_talker_mtp") + if len(num_scheduled_tokens) != len(request_ids): + return self._reject("shape_mismatch") + if any(int(n) != 1 for n in num_scheduled_tokens): + return self._reject("non_decode_step") + if any(bool(is_prefill_by_req.get(req_id, True)) for req_id in request_ids): + return self._reject("prefill") + self._last_fallback_reason = None + return True + + def prepare_step( + self, + *, + request_ids: list[str], + runner: Any, + input_ids: torch.Tensor, + req_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, + ) -> Qwen3TTSPreparedStep: + batch_size = len(request_ids) + if batch_size > self.max_slots: + raise RuntimeError(f"Qwen3-TTS Stage0 slot batch too large: {batch_size} > {self.max_slots}") + if req_embeds.shape[-1] != self.hidden_size: + raise ValueError(f"expected hidden_size={self.hidden_size}, got {req_embeds.shape[-1]}") + + req_index_by_id = {req_id: idx for idx, req_id in enumerate(runner.input_batch.req_ids)} + slot_indices: list[int] = [] + query_offsets: list[int] = [] + if self.table.input_ids[:batch_size].data_ptr() != input_ids[:batch_size].data_ptr(): + self.table.input_ids[:batch_size].copy_(input_ids[:batch_size].to(device=self.table.input_ids.device)) + if self.table.inputs_embeds[:batch_size].data_ptr() != req_embeds[:batch_size].data_ptr(): + self.table.inputs_embeds[:batch_size].copy_( + req_embeds[:batch_size].to(device=self.table.inputs_embeds.device) + ) + if self.table.last_talker_hidden[:batch_size].data_ptr() != last_talker_hidden[:batch_size].data_ptr(): + self.table.last_talker_hidden[:batch_size].copy_( + last_talker_hidden[:batch_size].to(device=self.table.last_talker_hidden.device) + ) + if self.table.text_step[:batch_size].data_ptr() != text_step[:batch_size].data_ptr(): + self.table.text_step[:batch_size].copy_(text_step[:batch_size].to(device=self.table.text_step.device)) + for batch_idx, req_id in enumerate(request_ids): + slot_indices.append(self.table.allocate(req_id)) + req_index = req_index_by_id[req_id] + query_offsets.append(int(runner.query_start_loc.cpu[req_index])) + + return Qwen3TTSPreparedStep( + request_ids=list(request_ids), + slot_indices=slot_indices, + query_offsets=query_offsets, + metadata={"batch_size": batch_size}, + ) + + def _talker_kwargs(self, runner: Any, request_ids: list[str], device: torch.device) -> dict[str, Any]: + subtalker_params = getattr(runner.vllm_config.model_config, "subtalker_sampling_params", None) + if not isinstance(subtalker_params, dict): + subtalker_params = {} + talker_kwargs: dict[str, Any] = { + "do_sample": subtalker_params.get("do_sample"), + "temperature": subtalker_params.get("temperature"), + "top_k": subtalker_params.get("top_k"), + "top_p": subtalker_params.get("top_p"), + } + if not request_ids: + return talker_kwargs + first_req_id = request_ids[0] + first_sp = getattr(runner.requests[first_req_id], "sampling_params", None) + extra_args = getattr(first_sp, "extra_args", None) if first_sp is not None else None + seed = extra_args.get("qwen3_tts_request_seed") if isinstance(extra_args, dict) else None + if seed is None: + return talker_kwargs + generators = getattr(runner, "_talker_mtp_generators", None) + if generators is None: + generators = {} + runner._talker_mtp_generators = generators + generator = generators.get(first_req_id) + if generator is None or generator.device != device: + generator = torch.Generator(device=device) + generator.manual_seed(int(seed)) + generators[first_req_id] = generator + talker_kwargs["generator"] = generator + return talker_kwargs + + def run_step( + self, + *, + prepared: Qwen3TTSPreparedStep, + runner: Any, + ) -> None: + batch_size = int(prepared.metadata["batch_size"]) + if batch_size == 0: + return + if hasattr(runner, "_determine_batch_execution_and_padding"): + cudagraph_mode, batch_desc, _, _, _ = runner._determine_batch_execution_and_padding( + num_tokens=batch_size, + num_reqs=batch_size, + num_scheduled_tokens_np=np.ones(batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + if not isinstance(runner.talker_mtp, CUDAGraphWrapper): + cudagraph_mode = CUDAGraphMode.NONE + num_tokens_padded = batch_size + else: + num_tokens_padded = batch_desc.num_tokens + talker_kwargs = self._talker_kwargs(runner, prepared.request_ids, self.table.input_ids.device) + with set_forward_context( + None, + runner.vllm_config, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + ): + next_embeds, sampled_codes = runner.talker_mtp( + self.table.input_ids[:num_tokens_padded], + self.table.inputs_embeds[:num_tokens_padded], + self.table.last_talker_hidden[:num_tokens_padded], + self.table.text_step[:num_tokens_padded], + **talker_kwargs, + ) + else: + next_embeds, sampled_codes = runner.talker_mtp( + self.table.input_ids[:batch_size], + self.table.inputs_embeds[:batch_size], + self.table.last_talker_hidden[:batch_size], + self.table.text_step[:batch_size], + ) + prepared.next_embeds = next_embeds[:batch_size] + prepared.sampled_codes = sampled_codes[:batch_size] + + def commit_step( + self, + *, + prepared: Qwen3TTSPreparedStep, + runner: Any, + inputs_embeds: torch.Tensor, + ) -> None: + if prepared.next_embeds is None or prepared.sampled_codes is None: + raise RuntimeError("run_step must be called before commit_step") + out_key = getattr(runner.model, "talker_mtp_output_key", ("codes", "audio")) + for idx, req_id in enumerate(prepared.request_ids): + start_offset = prepared.query_offsets[idx] + inputs_embeds[start_offset : start_offset + 1] = prepared.next_embeds[idx : idx + 1] + codes = prepared.sampled_codes[idx : idx + 1] + update = getattr(runner, "_update_talker_mtp_output", None) + if update is None: + self._update_runner_buffer(runner, req_id, out_key, codes) + else: + update(req_id, out_key, codes) + self.record_fast_path(batch_size=len(prepared.request_ids)) + + def _update_runner_buffer(self, runner: Any, req_id: str, out_key: tuple[str, str], value: torch.Tensor) -> None: + type_key, qual = out_key + existing = runner.model_intermediate_buffer.setdefault(req_id, {}) + existing_sub = existing.setdefault(type_key, {}) + existing_sub[qual] = value.detach().clone() + req_state = runner.requests.get(req_id) + if req_state is not None: + setattr(req_state, "additional_information_cpu", existing) + + def record_fast_path(self, *, batch_size: int) -> None: + self.stats.fast_path_steps += 1 + self.stats.fast_path_requests += int(batch_size) + if self.log_every_n_steps > 0 and self.stats.fast_path_steps % self.log_every_n_steps == 0: + logger.info( + "Qwen3-TTS Stage0 fast path stats: steps=%d, requests=%d, fallback=%s", + self.stats.fast_path_steps, + self.stats.fast_path_requests, + self.stats.fallback_reasons, + ) + + def record_fallback(self, reason: str) -> None: + self.stats.fallback_reasons[reason] = self.stats.fallback_reasons.get(reason, 0) + 1 + + def free_request(self, request_id: str) -> None: + self.table.free(request_id) From 4beba32b161a2d792f0ac19dc0c41113de1237f1 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 14 May 2026 19:05:59 +0800 Subject: [PATCH 3/6] feat: batch qwen3 tts stage0 preprocessing Signed-off-by: Sy03 <1370724210@qq.com> --- artifacts/qwen3_tts_ws1_baseline/README.md | 20 + .../test_qwen3_tts_talker_preprocess.py | 442 ++++++++++++++++++ tests/worker/test_omni_gpu_model_runner.py | 50 +- .../test_qwen3_tts_stage0_step_runner.py | 149 ------ .../models/qwen3_tts/qwen3_tts_talker.py | 315 ++++++++++++- vllm_omni/worker/gpu_model_runner.py | 125 ++--- vllm_omni/worker/omni_step_runner.py | 66 --- .../worker/qwen3_tts_stage0_step_runner.py | 345 -------------- 8 files changed, 811 insertions(+), 701 deletions(-) create mode 100644 tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py delete mode 100644 tests/worker/test_qwen3_tts_stage0_step_runner.py delete mode 100644 vllm_omni/worker/omni_step_runner.py delete mode 100644 vllm_omni/worker/qwen3_tts_stage0_step_runner.py diff --git a/artifacts/qwen3_tts_ws1_baseline/README.md b/artifacts/qwen3_tts_ws1_baseline/README.md index 4275dcbd0dc..e424d1f5d3d 100644 --- a/artifacts/qwen3_tts_ws1_baseline/README.md +++ b/artifacts/qwen3_tts_ws1_baseline/README.md @@ -21,3 +21,23 @@ Metrics: - audio throughput - request throughput - failed request count + +Validated WS1 result: +- Change: batched Stage0 Base voice_clone preprocessing for tokenizer ids, + ref-audio normalization, and same-sample-rate ref_code encoding. +- Remote result: + `/home/admin/workspace/remote_workspace/qwen3_stage0_slot_runner_ab_20260514_1840/results_20260514_190000/ab_summary.json` +- Workload: 2x H20, GPU pair `0,1`, concurrency 64, prompts 256, + warmups 2, Stage0 `max_num_seqs=64`, Stage1 `max_num_seqs=10`. +- Correctness smoke: new and old both completed 256 requests, failed requests + 0 in the benchmark log, with nonzero audio output (`1078.00s` new, + `1076.96s` old). + +| Metric | New | Old | Delta | +| --- | ---: | ---: | ---: | +| Audio throughput | 28.8289 | 25.5383 | +12.89% | +| Request throughput | 6.8462 | 6.0706 | +12.78% | +| Median audio RTF | 2.1888 | 2.4626 | -11.12% | +| Median audio TTFP ms | 1573.48 | 1634.63 | -3.74% | +| P99 audio TTFP ms | 5032.01 | 7319.04 | -31.25% | +| Median E2EL ms | 8746.86 | 9949.19 | -12.08% | diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py new file mode 100644 index 00000000000..826551bdc26 --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py @@ -0,0 +1,442 @@ +from types import SimpleNamespace + +import numpy as np +import torch + +from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + _NORMALIZED_REF_AUDIO_KEY, + _PRECOMPUTED_REF_CODE_KEY, + _PRECOMPUTED_REF_IDS_KEY, + _PRECOMPUTED_TEXT_IDS_KEY, + Qwen3TTSTalkerForConditionalGeneration, +) + + +def _make_minimal_talker(): + model = Qwen3TTSTalkerForConditionalGeneration.__new__(Qwen3TTSTalkerForConditionalGeneration) + model.talker_config = SimpleNamespace(codec_pad_id=7, num_code_groups=16) + return model + + +def test_single_token_prefill_uses_prefill_path(): + model = _make_minimal_talker() + full_prompt_embeds = torch.arange(12, dtype=torch.float32).reshape(3, 4) + trailing_text = torch.ones((2, 4), dtype=torch.float32) + tts_pad = torch.full((1, 4), 0.5, dtype=torch.float32) + ref_code = torch.arange(32, dtype=torch.long).reshape(2, 16) + + def fake_build_prompt_embeds(*, task_type, info_dict): + return full_prompt_embeds, trailing_text, tts_pad, 2, ref_code + + model._build_prompt_embeds = fake_build_prompt_embeds + + input_ids = torch.tensor([123], dtype=torch.long) + out_ids, out_embeds, update = model.preprocess( + input_ids=input_ids, + input_embeds=None, + text=["hello"], + task_type=["CustomVoice"], + _omni_is_prefill=True, + _omni_num_computed_tokens=0, + _omni_prompt_len=3, + ) + + assert out_ids.tolist() == [7] + assert torch.equal(out_embeds.cpu(), full_prompt_embeds[:1].to(torch.bfloat16)) + assert update["meta"]["talker_prefill_offset"] == 1 + assert update["meta"]["talker_text_offset"] == 0 + assert update["meta"]["ref_code_len"] == 2 + assert torch.equal(update["embed"]["prefill"], full_prompt_embeds) + assert torch.equal(update["embed"]["tts_pad"], tts_pad) + assert torch.equal(update["hidden_states"]["trailing_text"], trailing_text) + assert torch.equal(update["codes"]["ref"], ref_code) + assert update["codes"]["audio"].shape == (1, 16) + + +def test_single_token_prefill_can_be_inferred_from_token_progress(): + model = _make_minimal_talker() + full_prompt_embeds = torch.arange(8, dtype=torch.float32).reshape(2, 4) + trailing_text = torch.ones((1, 4), dtype=torch.float32) + tts_pad = torch.zeros((1, 4), dtype=torch.float32) + + def fake_build_prompt_embeds(*, task_type, info_dict): + return full_prompt_embeds, trailing_text, tts_pad, None, None + + model._build_prompt_embeds = fake_build_prompt_embeds + + out_ids, out_embeds, update = model.preprocess( + input_ids=torch.tensor([123], dtype=torch.long), + input_embeds=None, + text=["hello"], + task_type=["CustomVoice"], + _omni_num_computed_tokens=0, + _omni_prompt_len=2, + ) + + assert out_ids.tolist() == [7] + assert torch.equal(out_embeds.cpu(), full_prompt_embeds[:1].to(torch.bfloat16)) + assert update["meta"]["talker_prefill_offset"] == 1 + + +def test_decode_advances_trailing_text_by_offset_without_rewriting_tail(): + model = _make_minimal_talker() + + def fake_embed_input_ids(input_ids): + return input_ids.to(torch.float32).reshape(1, 1, 1).expand(1, 1, 4) + + model.embed_input_ids = fake_embed_input_ids + trailing_text = torch.arange(12, dtype=torch.float32).reshape(3, 4) + last_hidden = torch.full((4,), 2.0, dtype=torch.float32) + tts_pad = torch.full((1, 4), -1.0, dtype=torch.float32) + + out_ids, out_embeds, update = model.preprocess( + input_ids=torch.tensor([123], dtype=torch.long), + input_embeds=None, + text=["hello"], + task_type=["CustomVoice"], + hidden_states={"trailing_text": trailing_text, "last": last_hidden}, + embed={"tts_pad": tts_pad}, + meta={"talker_text_offset": 1}, + _omni_is_prefill=False, + _omni_num_computed_tokens=2, + _omni_prompt_len=2, + ) + + assert out_ids.tolist() == [123] + assert torch.equal(out_embeds.cpu(), torch.full((1, 4), 123.0, dtype=torch.bfloat16)) + assert "hidden_states" not in update + assert update["meta"]["talker_text_offset"] == 2 + past_hidden, text_step = update["mtp_inputs"] + assert torch.equal(past_hidden.cpu(), last_hidden.reshape(1, -1).to(torch.bfloat16)) + assert torch.equal(text_step.cpu(), trailing_text[1:2].to(torch.bfloat16)) + + +def test_decode_advances_trailing_text_offset_across_multiple_steps(): + model = _make_minimal_talker() + + def fake_embed_input_ids(input_ids): + return input_ids.to(torch.float32).reshape(1, 1, 1).expand(1, 1, 4) + + model.embed_input_ids = fake_embed_input_ids + trailing_text = torch.arange(8, dtype=torch.float32).reshape(2, 4) + state_tail = trailing_text + last_hidden = torch.full((4,), 2.0, dtype=torch.float32) + tts_pad = torch.full((1, 4), -1.0, dtype=torch.float32) + meta = {"talker_text_offset": 0} + seen_steps = [] + + for _ in range(3): + _, _, update = model.preprocess( + input_ids=torch.tensor([123], dtype=torch.long), + input_embeds=None, + text=["hello"], + task_type=["CustomVoice"], + hidden_states={"trailing_text": state_tail, "last": last_hidden}, + embed={"tts_pad": tts_pad}, + meta=meta, + _omni_is_prefill=False, + _omni_num_computed_tokens=2, + _omni_prompt_len=2, + ) + seen_steps.append(update["mtp_inputs"][1].cpu()) + if "hidden_states" in update and "trailing_text" in update["hidden_states"]: + state_tail = update["hidden_states"]["trailing_text"] + meta = update["meta"] + + assert torch.equal(seen_steps[0], trailing_text[0:1].to(torch.bfloat16)) + assert torch.equal(seen_steps[1], trailing_text[1:2].to(torch.bfloat16)) + assert torch.equal(seen_steps[2], tts_pad.to(torch.bfloat16)) + assert meta["talker_text_offset"] == 0 + assert state_tail.numel() == 0 + + +def test_decode_compacts_long_trailing_text_after_large_offset(): + model = _make_minimal_talker() + + def fake_embed_input_ids(input_ids): + return input_ids.to(torch.float32).reshape(1, 1, 1).expand(1, 1, 4) + + model.embed_input_ids = fake_embed_input_ids + trailing_text = torch.arange(130 * 4, dtype=torch.float32).reshape(130, 4) + last_hidden = torch.full((4,), 2.0, dtype=torch.float32) + tts_pad = torch.full((1, 4), -1.0, dtype=torch.float32) + + _, _, update = model.preprocess( + input_ids=torch.tensor([123], dtype=torch.long), + input_embeds=None, + text=["hello"], + task_type=["CustomVoice"], + hidden_states={"trailing_text": trailing_text, "last": last_hidden}, + embed={"tts_pad": tts_pad}, + meta={"talker_text_offset": 64}, + _omni_is_prefill=False, + _omni_num_computed_tokens=2, + _omni_prompt_len=2, + ) + + assert torch.equal(update["mtp_inputs"][1].cpu(), trailing_text[64:65].to(torch.bfloat16)) + assert update["meta"]["talker_text_offset"] == 0 + assert torch.equal(update["hidden_states"]["trailing_text"], trailing_text[65:]) + + +def test_base_voice_clone_normalizes_ref_audio_once_for_ref_code_and_speaker(): + model = _make_minimal_talker() + device_param = torch.nn.Parameter(torch.empty(0)) + model.parameters = lambda: iter([device_param]) + model.config = SimpleNamespace( + tts_bos_token_id=100, + tts_eos_token_id=101, + tts_pad_token_id=102, + ) + model.talker_config = SimpleNamespace( + codec_nothink_id=10, + codec_think_bos_id=11, + codec_think_eos_id=12, + codec_think_id=13, + codec_language_id={}, + codec_pad_id=7, + codec_bos_id=8, + num_code_groups=2, + spk_is_dialect={}, + ) + + class FakeTokenizer: + def __call__(self, *_args, **_kwargs): + return {"input_ids": torch.arange(8, dtype=torch.long).reshape(1, -1)} + + model._get_tokenizer = lambda: FakeTokenizer() + model.text_embedding = lambda ids: torch.ones((*ids.shape, 4), device=ids.device) + model.text_projection = lambda embeds: embeds + model.embed_input_ids = lambda ids: torch.zeros((*ids.shape, 4), device=ids.device) + model._generate_icl_prompt = lambda **kwargs: ( + torch.ones((1, 2, 4), device=kwargs["ref_code"].device), + torch.ones((1, 4), device=kwargs["ref_code"].device), + ) + + normalize_calls = [] + ref_audio = np.arange(1024, dtype=np.float32) + model._normalize_ref_audio = lambda raw: normalize_calls.append(raw) or (ref_audio, 16000) + + ref_audio_ids = [] + model._encode_ref_audio_to_code = lambda wav, _sr: ref_audio_ids.append(id(wav)) or torch.ones( + (2, 2), dtype=torch.long + ) + model._extract_speaker_embedding = lambda wav, _sr: ref_audio_ids.append(id(wav)) or torch.ones( + 4, dtype=torch.bfloat16 + ) + + _prompt, _trailing, _pad, ref_code_len, ref_code = model._build_prompt_embeds( + task_type="Base", + info_dict={ + "text": ["hello"], + "ref_audio": ["ref.wav"], + "ref_ids": torch.arange(8, dtype=torch.long).reshape(1, -1), + "non_streaming_mode": [False], + }, + ) + + assert normalize_calls == ["ref.wav"] + assert ref_audio_ids == [id(ref_audio), id(ref_audio)] + assert ref_code_len == 2 + assert torch.equal(ref_code, torch.ones((2, 2), dtype=torch.long)) + + +def test_base_voice_clone_batch_preprocess_encodes_ref_code_by_sample_rate(): + model = _make_minimal_talker() + wav1 = np.arange(2048, dtype=np.float32) + wav2 = np.arange(3072, dtype=np.float32) + normalize_calls = [] + model._normalize_ref_audio = lambda raw: normalize_calls.append(raw) or ( + wav1 if raw == "a.wav" else wav2, + 16000, + ) + + class FakeSpeechTokenizer: + def __init__(self): + self.calls = [] + + def encode(self, audios, *, sr, return_dict): + self.calls.append((audios, sr, return_dict)) + return SimpleNamespace( + audio_codes=[ + torch.full((2, 2), 11, dtype=torch.long), + torch.full((3, 2), 22, dtype=torch.long), + ] + ) + + tok = FakeSpeechTokenizer() + model._ensure_speech_tokenizer_loaded = lambda: tok + + class FakeTextTokenizer: + def __init__(self): + self.calls = [] + + def __call__(self, texts, *, padding=False): + self.calls.append((texts, padding)) + return {"input_ids": [[idx + 1, idx + 2, idx + 3] for idx, _ in enumerate(texts)]} + + text_tok = FakeTextTokenizer() + model._get_tokenizer = lambda: text_tok + buf = { + "r1": { + "task_type": ["Base"], + "text": ["one"], + "ref_audio": ["a.wav"], + "ref_text": ["hello"], + "x_vector_only_mode": [False], + }, + "r2": { + "task_type": ["Base"], + "text": ["two"], + "ref_audio": ["b.wav"], + "ref_text": ["world"], + "x_vector_only_mode": [False], + }, + } + + model.preprocess_batch( + req_ids=["r1", "r2"], + model_intermediate_buffer=buf, + device=torch.device("cpu"), + ) + + assert normalize_calls == ["a.wav", "b.wav"] + assert len(tok.calls) == 1 + audios, sr, return_dict = tok.calls[0] + assert audios[0] is wav1 + assert audios[1] is wav2 + assert sr == 16000 + assert return_dict is True + assert torch.equal(buf["r1"]["codes"][_PRECOMPUTED_REF_CODE_KEY], torch.full((2, 2), 11)) + assert torch.equal(buf["r2"]["codes"][_PRECOMPUTED_REF_CODE_KEY], torch.full((3, 2), 22)) + assert buf["r1"][_NORMALIZED_REF_AUDIO_KEY][0] is wav1 + assert buf["r2"][_NORMALIZED_REF_AUDIO_KEY][0] is wav2 + assert len(text_tok.calls) == 2 + assert torch.equal(buf["r1"][_PRECOMPUTED_TEXT_IDS_KEY], torch.tensor([1, 2, 3])) + assert torch.equal(buf["r2"][_PRECOMPUTED_TEXT_IDS_KEY], torch.tensor([2, 3, 4])) + assert torch.equal(buf["r1"][_PRECOMPUTED_REF_IDS_KEY], torch.tensor([1, 2, 3])) + assert torch.equal(buf["r2"][_PRECOMPUTED_REF_IDS_KEY], torch.tensor([2, 3, 4])) + + +def test_base_voice_clone_batch_preprocess_reuses_singleton_normalized_audio_without_speech_tokenizer(): + model = _make_minimal_talker() + wav = np.arange(2048, dtype=np.float32) + model._normalize_ref_audio = lambda raw: (wav, 16000) + model._ensure_speech_tokenizer_loaded = lambda: (_ for _ in ()).throw( + AssertionError("singleton should not load speech tokenizer") + ) + + class FakeTextTokenizer: + def __call__(self, texts, *, padding=False): + return {"input_ids": [[7, 8, 9] for _ in texts]} + + model._get_tokenizer = lambda: FakeTextTokenizer() + buf = { + "r1": { + "task_type": ["Base"], + "text": ["one"], + "ref_audio": ["a.wav"], + "ref_text": ["hello"], + "x_vector_only_mode": [False], + } + } + + model.preprocess_batch( + req_ids=["r1"], + model_intermediate_buffer=buf, + device=torch.device("cpu"), + ) + + assert buf["r1"][_NORMALIZED_REF_AUDIO_KEY][0] is wav + assert torch.equal(buf["r1"][_PRECOMPUTED_TEXT_IDS_KEY], torch.tensor([7, 8, 9])) + assert torch.equal(buf["r1"][_PRECOMPUTED_REF_IDS_KEY], torch.tensor([7, 8, 9])) + + +def test_base_voice_clone_batch_preprocess_skips_after_initial_prefill_state_exists(): + model = _make_minimal_talker() + model._normalize_ref_audio = lambda _raw: (_ for _ in ()).throw(AssertionError("normalize not expected")) + model._get_tokenizer = lambda: (_ for _ in ()).throw(AssertionError("tokenizer not expected")) + model._ensure_speech_tokenizer_loaded = lambda: (_ for _ in ()).throw( + AssertionError("speech tokenizer not expected") + ) + buf = { + "r1": { + "task_type": ["Base"], + "text": ["one"], + "ref_audio": ["a.wav"], + "ref_text": ["hello"], + "x_vector_only_mode": [False], + "embed": {"prefill": torch.ones((1, 4))}, + } + } + + model.preprocess_batch( + req_ids=["r1"], + model_intermediate_buffer=buf, + device=torch.device("cpu"), + ) + + assert _PRECOMPUTED_TEXT_IDS_KEY not in buf["r1"] + assert _NORMALIZED_REF_AUDIO_KEY not in buf["r1"] + + +def test_base_voice_clone_uses_batched_ref_code_without_serial_encode(): + model = _make_minimal_talker() + device_param = torch.nn.Parameter(torch.empty(0)) + model.parameters = lambda: iter([device_param]) + model.config = SimpleNamespace( + tts_bos_token_id=100, + tts_eos_token_id=101, + tts_pad_token_id=102, + ) + model.talker_config = SimpleNamespace( + codec_nothink_id=10, + codec_think_bos_id=11, + codec_think_eos_id=12, + codec_think_id=13, + codec_language_id={}, + codec_pad_id=7, + codec_bos_id=8, + num_code_groups=2, + spk_is_dialect={}, + ) + + class FakeTokenizer: + def __call__(self, *_args, **_kwargs): + return {"input_ids": torch.arange(8, dtype=torch.long).reshape(1, -1)} + + model._get_tokenizer = lambda: FakeTokenizer() + model.text_embedding = lambda ids: torch.ones((*ids.shape, 4), device=ids.device) + model.text_projection = lambda embeds: embeds + model.embed_input_ids = lambda ids: torch.zeros((*ids.shape, 4), device=ids.device) + model._generate_icl_prompt = lambda **kwargs: ( + torch.ones((1, 2, 4), device=kwargs["ref_code"].device), + torch.ones((1, 4), device=kwargs["ref_code"].device), + ) + + ref_audio = np.arange(2048, dtype=np.float32) + ref_code = torch.arange(4, dtype=torch.long).reshape(2, 2) + model._normalize_ref_audio = lambda _raw: (_ for _ in ()).throw(AssertionError("serial normalize not expected")) + model._encode_ref_audio_to_code = lambda _wav, _sr: (_ for _ in ()).throw( + AssertionError("serial encode not expected") + ) + speaker_wav_ids = [] + model._extract_speaker_embedding = lambda wav, _sr: speaker_wav_ids.append(id(wav)) or torch.ones( + 4, dtype=torch.bfloat16 + ) + + _prompt, _trailing, _pad, ref_code_len, out_ref_code = model._build_prompt_embeds( + task_type="Base", + info_dict={ + "text": ["hello"], + "ref_audio": ["ref.wav"], + "ref_ids": torch.arange(8, dtype=torch.long).reshape(1, -1), + "non_streaming_mode": [False], + "codes": {_PRECOMPUTED_REF_CODE_KEY: ref_code}, + _NORMALIZED_REF_AUDIO_KEY: (ref_audio, 16000), + }, + ) + + assert speaker_wav_ids == [id(ref_audio)] + assert ref_code_len == 2 + assert torch.equal(out_ref_code, ref_code) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 75a87835f9b..2d6f11e7af8 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -173,12 +173,6 @@ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_ monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner))) - class NoIndexList(list): - def index(self, *args, **kwargs): - raise AssertionError("_talker_mtp_forward should not linearly search req_ids per request") - - runner.input_batch.req_ids = NoIndexList(runner.input_batch.req_ids) - # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) @@ -336,42 +330,28 @@ def test_update_intermediate_buffer_skips_unknown_req_id(): assert "unknown_req" not in runner.model_intermediate_buffer -def test_update_talker_mtp_output_writes_single_nested_value(): - runner = _make_runner(req_ids=("r1",), hidden_size=4) - runner.model.gpu_resident_buffer_keys = {("codes", "audio")} - src = torch.tensor([[1, 2, 3]], dtype=torch.long) +def test_maybe_run_batch_preprocess_calls_model_hook(): + runner = object.__new__(OmniGPUModelRunner) + runner.model_intermediate_buffer = {"r1": {"text": ["hello"]}} + calls = [] - OmniGPUModelRunner._update_talker_mtp_output(runner, "r1", ("codes", "audio"), src) + class DummyModel: + def preprocess_batch(self, *, req_ids, model_intermediate_buffer, device): + calls.append((req_ids, model_intermediate_buffer, device)) - stored = runner.model_intermediate_buffer["r1"]["codes"]["audio"] - assert torch.equal(stored, src) - assert stored.data_ptr() != src.data_ptr() - assert runner.requests["r1"].additional_information_cpu is runner.model_intermediate_buffer["r1"] + runner.model = DummyModel() + OmniGPUModelRunner._maybe_run_batch_preprocess(runner, ["r1"], torch.device("cpu")) -def test_optional_omni_step_runner_cleanup_is_called(): - runner = object.__new__(OmniGPUModelRunner) - freed = [] + assert calls == [(["r1"], runner.model_intermediate_buffer, torch.device("cpu"))] - class DummyStepRunner: - def free_request(self, req_id): - freed.append(req_id) - runner.omni_step_runner = DummyStepRunner() - runner.requests = {"r1": object()} - runner.model_intermediate_buffer = {"r1": {"codes": {"audio": 1}}} - runner.num_prompt_logprobs = {"r1": 0} - runner._downstream_payload_cache = {"r1": object()} - runner._talker_mtp_generators = {"r1": object()} - - OmniGPUModelRunner._free_omni_request_state(runner, "r1") +def test_maybe_run_batch_preprocess_skips_missing_hook(): + runner = object.__new__(OmniGPUModelRunner) + runner.model_intermediate_buffer = {} + runner.model = object() - assert freed == ["r1"] - assert "r1" not in runner.requests - assert "r1" not in runner.model_intermediate_buffer - assert "r1" not in runner.num_prompt_logprobs - assert "r1" not in runner._downstream_payload_cache - assert "r1" not in runner._talker_mtp_generators + OmniGPUModelRunner._maybe_run_batch_preprocess(runner, ["r1"], torch.device("cpu")) def test_maybe_attach_mimo_audio_req_infos_enriches_dict(): diff --git a/tests/worker/test_qwen3_tts_stage0_step_runner.py b/tests/worker/test_qwen3_tts_stage0_step_runner.py deleted file mode 100644 index d6b565fc840..00000000000 --- a/tests/worker/test_qwen3_tts_stage0_step_runner.py +++ /dev/null @@ -1,149 +0,0 @@ -from types import SimpleNamespace - -import pytest -import torch - -from vllm_omni.worker.qwen3_tts_stage0_step_runner import ( - Qwen3TTSSlotTable, - Qwen3TTSStage0StepRunner, - flatten_codec_frames_for_code2wav, -) - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def test_qwen3_tts_slot_table_allocates_and_reuses_slots(): - table = Qwen3TTSSlotTable( - max_slots=2, - hidden_size=4, - num_quantizers=16, - device=torch.device("cpu"), - dtype=torch.float32, - ) - - s0 = table.allocate("r1") - s1 = table.allocate("r2") - assert s0 != s1 - assert table.allocate("r1") == s0 - - table.free("r1") - s2 = table.allocate("r3") - assert s2 == s0 - assert "r1" not in table.req_to_slot - assert table.slots[s2].req_id == "r3" - - -def test_qwen3_tts_slot_table_exhaustion_is_explicit(): - table = Qwen3TTSSlotTable( - max_slots=1, - hidden_size=4, - num_quantizers=16, - device=torch.device("cpu"), - dtype=torch.float32, - ) - table.allocate("r1") - - with pytest.raises(RuntimeError, match="slot table exhausted"): - table.allocate("r2") - - -def _runner_config(async_chunk=True, model_stage="qwen3_tts", has_talker_mtp=True): - return SimpleNamespace( - vllm_config=SimpleNamespace( - model_config=SimpleNamespace( - async_chunk=async_chunk, - model_stage=model_stage, - ) - ), - has_talker_mtp=has_talker_mtp, - ) - - -def test_stage0_step_runner_supports_decode_only_qwen3_tts_async_chunk(): - step_runner = Qwen3TTSStage0StepRunner(max_slots=4, hidden_size=8, num_quantizers=16) - - assert step_runner.supports_step( - runner=_runner_config(), - request_ids=["r1", "r2"], - num_scheduled_tokens=[1, 1], - is_prefill_by_req={"r1": False, "r2": False}, - ) - - -def test_stage0_step_runner_rejects_prefill_or_wrong_stage(): - step_runner = Qwen3TTSStage0StepRunner(max_slots=4, hidden_size=8, num_quantizers=16) - - assert not step_runner.supports_step( - runner=_runner_config(model_stage="code2wav"), - request_ids=["r1"], - num_scheduled_tokens=[1], - is_prefill_by_req={"r1": False}, - ) - assert not step_runner.supports_step( - runner=_runner_config(), - request_ids=["r1"], - num_scheduled_tokens=[1], - is_prefill_by_req={"r1": True}, - ) - assert not step_runner.supports_step( - runner=_runner_config(async_chunk=False), - request_ids=["r1"], - num_scheduled_tokens=[1], - is_prefill_by_req={"r1": False}, - ) - - -def test_stage0_step_runner_commits_next_embeds_and_codes(): - class FakeTalkerMTP: - def __call__(self, input_ids, req_embeds, last_hidden, text_step, **kwargs): - codes = torch.arange(input_ids.shape[0] * 16, dtype=torch.long).reshape(input_ids.shape[0], 16) - return req_embeds + 10, codes - - runner = SimpleNamespace( - talker_mtp=FakeTalkerMTP(), - input_batch=SimpleNamespace(req_ids=["r1", "r2"]), - query_start_loc=SimpleNamespace(cpu=torch.tensor([0, 1], dtype=torch.int32)), - model_intermediate_buffer={}, - requests={ - "r1": SimpleNamespace(additional_information_cpu=None), - "r2": SimpleNamespace(additional_information_cpu=None), - }, - model=SimpleNamespace(talker_mtp_output_key=("codes", "audio"), gpu_resident_buffer_keys=set()), - vllm_config=SimpleNamespace(model_config=SimpleNamespace(subtalker_sampling_params={})), - ) - inputs_embeds = torch.zeros((2, 4), dtype=torch.float32) - - step_runner = Qwen3TTSStage0StepRunner(max_slots=2, hidden_size=4, num_quantizers=16) - prepared = step_runner.prepare_step( - request_ids=["r1", "r2"], - runner=runner, - input_ids=torch.tensor([101, 102], dtype=torch.long), - req_embeds=torch.ones((2, 4), dtype=torch.float32), - last_talker_hidden=torch.ones((2, 4), dtype=torch.float32) * 2, - text_step=torch.ones((2, 4), dtype=torch.float32) * 3, - ) - step_runner.run_step(prepared=prepared, runner=runner) - step_runner.commit_step(prepared=prepared, runner=runner, inputs_embeds=inputs_embeds) - - assert torch.equal(inputs_embeds, torch.ones((2, 4), dtype=torch.float32) * 11) - assert torch.equal(runner.model_intermediate_buffer["r1"]["codes"]["audio"], torch.arange(16).reshape(1, 16)) - assert torch.equal(runner.model_intermediate_buffer["r2"]["codes"]["audio"], torch.arange(16, 32).reshape(1, 16)) - assert runner.requests["r1"].additional_information_cpu is runner.model_intermediate_buffer["r1"] - - -def test_stage0_step_runner_records_fast_path_and_fallback_counts(): - step_runner = Qwen3TTSStage0StepRunner(max_slots=2, hidden_size=4, num_quantizers=16) - - step_runner.record_fast_path(batch_size=2) - step_runner.record_fallback("prefill") - - assert step_runner.stats.fast_path_steps == 1 - assert step_runner.stats.fast_path_requests == 2 - assert step_runner.stats.fallback_reasons["prefill"] == 1 - - -def test_qwen3_tts_slot_codec_frames_match_legacy_flattening(): - frames_fq = torch.arange(4 * 16, dtype=torch.long).reshape(4, 16) - legacy_flat = frames_fq.transpose(0, 1).contiguous().reshape(-1) - - assert torch.equal(flatten_codec_frames_for_code2wav(frames_fq), legacy_flat) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index b01a2aa287f..bdf702631b1 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -37,6 +37,12 @@ logger = init_logger(__name__) +_TRAILING_TEXT_COMPACT_MIN_FRAMES = 64 +_PRECOMPUTED_REF_CODE_KEY = "precomputed_ref" +_NORMALIZED_REF_AUDIO_KEY = "_qwen3_tts_normalized_ref_audio" +_PRECOMPUTED_TEXT_IDS_KEY = "_qwen3_tts_text_ids" +_PRECOMPUTED_REF_IDS_KEY = "_qwen3_tts_ref_ids" + # --------------------------------------------------------------------------- # Components ported from the HuggingFace Qwen3-TTS reference implementation. @@ -302,8 +308,6 @@ class Qwen3TTSTalkerForConditionalGeneration(nn.Module): """vLLM-AR talker: step-wise layer-0 codec decoding. Predicts residual codebooks (1..Q-1) into `audio_codes` and streams text via `tailing_text_hidden`.""" - omni_step_runner_cls = "vllm_omni.worker.qwen3_tts_stage0_step_runner.Qwen3TTSStage0StepRunner" - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # Talker backbone (Qwen3 decoder-only). @@ -562,6 +566,14 @@ def preprocess( span_len = int(input_ids.shape[0]) if span_len <= 0: return input_ids, input_embeds if input_embeds is not None else self.embed_input_ids(input_ids), {} + is_prefill_raw = info_dict.get("_omni_is_prefill") + if isinstance(is_prefill_raw, bool): + is_prefill = is_prefill_raw + else: + try: + is_prefill = int(info_dict["_omni_num_computed_tokens"]) < int(info_dict["_omni_prompt_len"]) + except Exception: + is_prefill = span_len > 1 text_list = info_dict.get("text") if not isinstance(text_list, list) or not text_list or not text_list[0]: @@ -578,7 +590,7 @@ def preprocess( else: codec_streaming = task_type == "Base" - if span_len > 1: + if is_prefill: # Prefill (prompt embeddings) prompt_embeds_cpu = embed.get("prefill") tts_pad_embed_cpu = embed.get("tts_pad") @@ -602,7 +614,11 @@ def preprocess( "tts_pad": tts_pad_embed.detach(), }, "hidden_states": {"trailing_text": tailing_text_hidden.detach()}, - "meta": {"talker_prefill_offset": 0, "codec_streaming": codec_streaming}, + "meta": { + "talker_prefill_offset": 0, + "talker_text_offset": 0, + "codec_streaming": codec_streaming, + }, } if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: info_update.setdefault("codes", {})["ref"] = ref_code.detach().to("cpu").contiguous() @@ -636,7 +652,10 @@ def preprocess( take = torch.cat([take, pad_rows], dim=0) prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) info_update = { - "meta": {"talker_prefill_offset": int(offset + span_len), "codec_streaming": codec_streaming} + "meta": { + "talker_prefill_offset": int(offset + span_len), + "codec_streaming": codec_streaming, + } } # When inputs_embeds is set, token ids are ignored by the model but must stay in-vocab for vLLM bookkeeping. @@ -660,12 +679,37 @@ def preprocess( tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) tail = hs.get("trailing_text") - if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0: - text_step = tail[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) - new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0] + text_offset = max(0, int(meta.get("talker_text_offset", 0) or 0)) + trailing_text_update = None + if isinstance(tail, torch.Tensor) and tail.ndim == 2: + tail_len = int(tail.shape[0]) + if text_offset < tail_len: + text_step = ( + tail[text_offset : text_offset + 1] + .to( + device=input_ids.device, + dtype=torch.bfloat16, + ) + .reshape(1, -1) + ) + next_text_offset = text_offset + 1 + should_compact_tail = next_text_offset >= tail_len or ( + next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len + ) + if should_compact_tail: + if next_text_offset >= tail_len: + trailing_text_update = torch.empty((0, tail.shape[1]), device=tail.device, dtype=tail.dtype) + else: + trailing_text_update = tail[next_text_offset:].contiguous() + next_text_offset = 0 + else: + text_step = tts_pad_embed + next_text_offset = 0 + if tail.numel() > 0: + trailing_text_update = torch.empty((0, tail.shape[1]), device=tail.device, dtype=tail.dtype) else: text_step = tts_pad_embed - new_tail = tail if isinstance(tail, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1])) + next_text_offset = text_offset last_hidden = hs.get("last") if not isinstance(last_hidden, torch.Tensor): @@ -679,10 +723,14 @@ def preprocess( inputs_embeds_out = last_id_hidden.reshape(1, -1) info_update = { - "hidden_states": {"trailing_text": new_tail}, "mtp_inputs": (past_hidden, text_step), - "meta": {"codec_streaming": codec_streaming}, + "meta": { + "talker_text_offset": int(next_text_offset), + "codec_streaming": codec_streaming, + }, } + if trailing_text_update is not None: + info_update["hidden_states"] = {"trailing_text": trailing_text_update.detach()} return input_ids, inputs_embeds_out, info_update def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: @@ -1184,6 +1232,198 @@ def _encode_ref_audio_to_code(self, wav: np.ndarray, sr: int) -> torch.Tensor: return ref_code.to(device=next(self.parameters()).device, dtype=torch.long) raise ValueError("SpeechTokenizer.encode did not return audio_codes tensor") + @staticmethod + def _first_value(value: object, default: object = None) -> object: + if isinstance(value, list): + return value[0] if value else default + return value if value is not None else default + + @staticmethod + def _coerce_ref_code_tensor(value: object, *, device: torch.device) -> torch.Tensor | None: + value = Qwen3TTSTalkerForConditionalGeneration._first_value(value) + if value is None: + return None + if isinstance(value, torch.Tensor): + ref_code = value + elif isinstance(value, np.ndarray): + ref_code = torch.from_numpy(value) + elif isinstance(value, list) and value: + ref_code = torch.as_tensor(value, dtype=torch.long) + else: + return None + if ref_code.ndim == 3: + ref_code = ref_code[0] + if ref_code.ndim != 2 or ref_code.numel() == 0: + return None + return ref_code.to(device=device, dtype=torch.long).contiguous() + + @staticmethod + def _coerce_token_ids(value: object, *, device: torch.device) -> torch.Tensor | None: + value = Qwen3TTSTalkerForConditionalGeneration._first_value(value) + if value is None: + return None + if isinstance(value, torch.Tensor): + ids = value + elif isinstance(value, np.ndarray): + ids = torch.from_numpy(value) + elif isinstance(value, list) and value and all(isinstance(v, (int, np.integer)) for v in value): + ids = torch.tensor(value, dtype=torch.long) + else: + return None + if ids.ndim == 1: + ids = ids.unsqueeze(0) + if ids.ndim != 2 or ids.numel() == 0: + return None + return ids.to(device=device, dtype=torch.long).contiguous() + + @staticmethod + def _split_ref_code_batch(ref_codes: object, expected: int) -> list[object] | None: + if isinstance(ref_codes, torch.Tensor): + if ref_codes.ndim == 3 and int(ref_codes.shape[0]) == expected: + return [ref_codes[i] for i in range(expected)] + if expected == 1: + return [ref_codes] + return None + if isinstance(ref_codes, (list, tuple)) and len(ref_codes) == expected: + return list(ref_codes) + return None + + @staticmethod + def _voice_clone_prompt_dict(raw: object) -> dict[str, object] | None: + raw = Qwen3TTSTalkerForConditionalGeneration._first_value(raw) + if isinstance(raw, dict): + return raw + if isinstance(raw, list) and raw and isinstance(raw[0], dict): + return raw[0] + return None + + @staticmethod + def _has_ref_code_like(value: object) -> bool: + value = Qwen3TTSTalkerForConditionalGeneration._first_value(value) + if isinstance(value, torch.Tensor): + return value.numel() > 0 and value.ndim in (2, 3) + if isinstance(value, np.ndarray): + return value.size > 0 and value.ndim in (2, 3) + if isinstance(value, list): + return bool(value) + return False + + @staticmethod + def _needs_initial_prompt_preprocess(info_dict: dict[str, Any]) -> bool: + embed = info_dict.get("embed") + return not (isinstance(embed, dict) and "prefill" in embed) + + def _needs_batched_ref_code(self, info_dict: dict[str, Any]) -> bool: + if not self._needs_initial_prompt_preprocess(info_dict): + return False + if self._first_value(info_dict.get("task_type"), "CustomVoice") != "Base": + return False + codes = info_dict.get("codes") + if isinstance(codes, dict) and self._has_ref_code_like(codes.get(_PRECOMPUTED_REF_CODE_KEY)): + return False + + voice_clone_prompt = self._voice_clone_prompt_dict(info_dict.get("voice_clone_prompt")) + if voice_clone_prompt is not None: + if self._has_ref_code_like(voice_clone_prompt.get("ref_code")): + return False + icl_flag = self._first_value(voice_clone_prompt.get("icl_mode")) + if isinstance(icl_flag, bool) and not icl_flag: + return False + + xvec_only = self._first_value(info_dict.get("x_vector_only_mode"), False) + if isinstance(xvec_only, bool) and xvec_only: + return False + + ref_audio_list = info_dict.get("ref_audio") + return isinstance(ref_audio_list, list) and bool(ref_audio_list) + + @torch.inference_mode() + def preprocess_batch( + self, + *, + req_ids: list[str], + model_intermediate_buffer: dict[str, dict[str, Any]], + device: torch.device, + ) -> None: + """Batch Base voice-clone ref-audio codec extraction for current prefill requests.""" + pending_text: list[tuple[dict[str, Any], str]] = [] + pending_ref_text: list[tuple[dict[str, Any], str]] = [] + groups: dict[int, list[tuple[dict[str, Any], np.ndarray, int]]] = {} + for req_id in req_ids: + info_dict = model_intermediate_buffer.get(req_id) + if not isinstance(info_dict, dict): + continue + + if ( + self._needs_initial_prompt_preprocess(info_dict) + and self._first_value(info_dict.get("task_type"), "CustomVoice") == "Base" + ): + if _PRECOMPUTED_TEXT_IDS_KEY not in info_dict: + text = self._first_value(info_dict.get("text"), "") + if isinstance(text, str): + pending_text.append((info_dict, self._build_assistant_text(text))) + + if _PRECOMPUTED_REF_IDS_KEY not in info_dict: + ref_text = self._first_value(info_dict.get("ref_text")) + if isinstance(ref_text, str) and ref_text.strip(): + pending_ref_text.append((info_dict, self._build_ref_text(ref_text))) + + if not self._needs_batched_ref_code(info_dict): + continue + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + continue + try: + wav, sr = self._normalize_ref_audio(ref_audio_list[0]) + except Exception: + # Keep the original per-request path responsible for surfacing + # invalid ref_audio errors with its existing messages. + continue + info_dict[_NORMALIZED_REF_AUDIO_KEY] = (wav, sr) + groups.setdefault(int(sr), []).append((info_dict, wav, int(sr))) + + if pending_text or pending_ref_text: + tok_text = self._get_tokenizer() + for items, key in ( + (pending_text, _PRECOMPUTED_TEXT_IDS_KEY), + (pending_ref_text, _PRECOMPUTED_REF_IDS_KEY), + ): + if not items: + continue + try: + tokenized = tok_text([text for _, text in items], padding=False) + except Exception as exc: + logger.debug("Qwen3-TTS batched text tokenization failed; falling back to serial path: %s", exc) + continue + input_ids = tokenized.get("input_ids") if isinstance(tokenized, dict) else None + if not isinstance(input_ids, list) or len(input_ids) != len(items): + continue + for (info_dict, _), ids in zip(items, input_ids, strict=True): + if isinstance(ids, list) and ids: + info_dict[key] = torch.tensor(ids, dtype=torch.long) + + groups = {sr: items for sr, items in groups.items() if len(items) >= 2} + if not groups: + return + + tok = self._ensure_speech_tokenizer_loaded() + for sr, items in groups.items(): + wavs = [wav for _, wav, _ in items] + try: + enc = tok.encode(wavs, sr=int(sr), return_dict=True) + ref_codes = self._split_ref_code_batch(getattr(enc, "audio_codes", None), len(items)) + except Exception as exc: + logger.debug("Qwen3-TTS batched ref_code encode failed; falling back to serial path: %s", exc) + continue + if ref_codes is None: + continue + for (info_dict, wav, item_sr), ref_code in zip(items, ref_codes, strict=True): + ref_code_t = self._coerce_ref_code_tensor(ref_code, device=device) + if ref_code_t is None: + continue + info_dict.setdefault("codes", {})[_PRECOMPUTED_REF_CODE_KEY] = ref_code_t + info_dict[_NORMALIZED_REF_AUDIO_KEY] = (wav, item_sr) + def _generate_icl_prompt( self, *, @@ -1258,9 +1498,12 @@ def _build_prompt_embeds( # Text ids for assistant template (always). tok = self._get_tokenizer() - input_ids = tok(self._build_assistant_text(text), return_tensors="pt", padding=False)["input_ids"].to( - device=next(self.parameters()).device - ) + dev = next(self.parameters()).device + input_ids = self._coerce_token_ids(info_dict.pop(_PRECOMPUTED_TEXT_IDS_KEY, None), device=dev) + if input_ids is None: + input_ids = tok(self._build_assistant_text(text), return_tensors="pt", padding=False)["input_ids"].to( + device=dev + ) # Optional instruct prefix. instruct = (info_dict.get("instruct") or [""])[0] @@ -1364,6 +1607,26 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: xvec_only = bool((info_dict.get("x_vector_only_mode") or [False])[0]) in_context_mode = not xvec_only voice_clone_prompt = _normalize_voice_clone_prompt(info_dict.get("voice_clone_prompt")) + ref_audio_wav: np.ndarray | None = None + ref_audio_sr: int | None = None + + def _get_ref_audio() -> tuple[np.ndarray, int]: + nonlocal ref_audio_wav, ref_audio_sr + if ref_audio_wav is None or ref_audio_sr is None: + normalized_ref_audio = info_dict.pop(_NORMALIZED_REF_AUDIO_KEY, None) + if ( + isinstance(normalized_ref_audio, tuple) + and len(normalized_ref_audio) == 2 + and isinstance(normalized_ref_audio[0], np.ndarray) + ): + ref_audio_wav = normalized_ref_audio[0] + ref_audio_sr = int(normalized_ref_audio[1]) + return ref_audio_wav, ref_audio_sr + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + raise ValueError("Base requires `ref_audio`.") + ref_audio_wav, ref_audio_sr = self._normalize_ref_audio(ref_audio_list[0]) + return ref_audio_wav, ref_audio_sr # Speaker cache: only for uploaded (named) speakers _speaker_cache_key = None @@ -1415,12 +1678,17 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: ref_code_t = ref_code_t[0] ref_code_t = ref_code_t.to(device=input_ids.device, dtype=torch.long) ref_code_len = int(ref_code_t.shape[0]) - elif in_context_mode: + else: + codes = info_dict.get("codes") + precomputed_ref_code = ( + codes.get(_PRECOMPUTED_REF_CODE_KEY) if isinstance(codes, dict) else None + ) + ref_code_t = self._coerce_ref_code_tensor(precomputed_ref_code, device=input_ids.device) + if isinstance(ref_code_t, torch.Tensor): + ref_code_len = int(ref_code_t.shape[0]) + if ref_code_t is None and in_context_mode: # Compute ref_code from ref_audio if not provided. - ref_audio_list = info_dict.get("ref_audio") - if not isinstance(ref_audio_list, list) or not ref_audio_list: - raise ValueError("Base requires `ref_audio`.") - wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + wav_np, sr = _get_ref_audio() ref_code_t = self._encode_ref_audio_to_code(wav_np, sr).to(device=input_ids.device) ref_code_len = int(ref_code_t.shape[0]) if isinstance(ref_code_t, torch.Tensor): @@ -1439,10 +1707,7 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: # Plain list/array from API (survived msgspec IPC serialization). speaker_embed = torch.tensor(spk, dtype=torch.bfloat16, device=input_ids.device).view(1, 1, -1) else: - ref_audio_list = info_dict.get("ref_audio") - if not isinstance(ref_audio_list, list) or not ref_audio_list: - raise ValueError("Base requires `ref_audio`.") - wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + wav_np, sr = _get_ref_audio() speaker_embed = self._extract_speaker_embedding(wav_np, sr).view(1, 1, -1) # Cache miss: store extraction result @@ -1469,7 +1734,9 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: if in_context_mode: # Prefer explicit tokenized `ref_ids` if provided (matches official signature). - ref_ids = _to_long_tensor(info_dict.get("ref_ids"), device=input_ids.device) + ref_ids = self._coerce_token_ids(info_dict.pop(_PRECOMPUTED_REF_IDS_KEY, None), device=input_ids.device) + if ref_ids is None: + ref_ids = _to_long_tensor(info_dict.get("ref_ids"), device=input_ids.device) if ref_ids is None and voice_clone_prompt is not None: ref_ids = _to_long_tensor( voice_clone_prompt.get("ref_ids") or voice_clone_prompt.get("ref_id"), device=input_ids.device diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 3974c64ab63..1cc6a1c7019 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1,4 +1,3 @@ -import importlib from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -46,7 +45,6 @@ def __init__(self, *args, **kwargs): self.model_intermediate_buffer: dict[str, dict[str, Any]] = {} self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None - self.omni_step_runner = None # The Omni tensor prefix cache will be allocated # when we initialize the metadata builders if enabled self.omni_prefix_cache = None @@ -115,17 +113,6 @@ def load_model(self, *args, **kwargs) -> None: ) self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) - self._init_omni_step_runner() - - def _init_omni_step_runner(self) -> None: - cls_path = getattr(self.model, "omni_step_runner_cls", None) - if not cls_path: - self.omni_step_runner = None - return - module_name, class_name = cls_path.rsplit(".", 1) - module = importlib.import_module(module_name) - cls = getattr(module, class_name) - self.omni_step_runner = cls.from_runner(self) def _init_mrope_positions(self, req_state: CachedRequestState): """Initialize M-RoPE positions for multimodal inputs. @@ -278,7 +265,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: - self._free_omni_request_state(req_id) + self.requests.pop(req_id, None) + self.model_intermediate_buffer.pop(req_id, None) + self.num_prompt_logprobs.pop(req_id, None) + if hasattr(self, "_downstream_payload_cache"): + self._downstream_payload_cache.pop(req_id, None) + if hasattr(self, "_talker_mtp_generators"): + self._talker_mtp_generators.pop(req_id, None) if hasattr(self, "late_interaction_runner"): self.late_interaction_runner.on_requests_finished(scheduler_output.finished_req_ids) @@ -578,18 +571,6 @@ def correct_spec_decode_token_counts(): else: return None - def _free_omni_request_state(self, req_id: str) -> None: - self.requests.pop(req_id, None) - self.model_intermediate_buffer.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - if hasattr(self, "_downstream_payload_cache"): - self._downstream_payload_cache.pop(req_id, None) - if hasattr(self, "_talker_mtp_generators"): - self._talker_mtp_generators.pop(req_id, None) - step_runner = getattr(self, "omni_step_runner", None) - if step_runner is not None: - step_runner.free_request(req_id) - @torch.inference_mode() def extract_multimodal_outputs(self, hidden_states: torch.Tensor | list[torch.Tensor] | OmniOutput) -> dict: if ( @@ -1194,6 +1175,22 @@ def _maybe_attach_mimo_audio_req_infos( return req_infos + def _maybe_run_batch_preprocess(self, req_ids: list[str], device: torch.device) -> None: + """Run an optional model-specific batch preprocess hook. + + The generic runner only supplies current request ids and the runner-owned + intermediate buffer; model-specific code decides whether there is any + batchable work. + """ + preprocess_batch = getattr(self.model, "preprocess_batch", None) + if not callable(preprocess_batch): + return + preprocess_batch( + req_ids=req_ids, + model_intermediate_buffer=self.model_intermediate_buffer, + device=device, + ) + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -1320,10 +1317,12 @@ def _preprocess( self._update_additional_information(scheduler_output) if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: + preprocess_device = input_ids.device if input_ids is not None else inputs_embeds.device + self._maybe_run_batch_preprocess(self.input_batch.req_ids, preprocess_device) + # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] - decode_is_prefill_by_req: dict[str, bool] = {} for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) @@ -1338,6 +1337,13 @@ def _preprocess( # call the custom process function req_infos["request_id"] = req_id + prompt_token_ids = getattr(req_state, "prompt_token_ids", ()) if req_state is not None else () + prompt_len = len(prompt_token_ids or ()) + num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) + is_prefill = num_computed_tokens < prompt_len + req_infos["_omni_prompt_len"] = prompt_len + req_infos["_omni_num_computed_tokens"] = num_computed_tokens + req_infos["_omni_is_prefill"] = is_prefill embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos @@ -1349,7 +1355,7 @@ def _preprocess( dtype=req_embeds.dtype, ) - if self.has_talker_mtp and span_len == 1: + if self.has_talker_mtp and span_len == 1 and not is_prefill: last_talker_hidden, text_step = update_dict.pop("mtp_inputs") decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) @@ -1357,13 +1363,9 @@ def _preprocess( self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) - num_computed_tokens = int(self.input_batch.num_computed_tokens_cpu[req_index]) - prompt_token_ids = getattr(req_state, "prompt_token_ids", ()) if req_state is not None else () - prompt_len = len(prompt_token_ids or ()) - decode_is_prefill_by_req[req_id] = num_computed_tokens < prompt_len # TODO(Peiqi): the merge stage could move out from the critical path - self._update_intermediate_buffer(req_id, update_dict) + self._merge_additional_information_update(req_id, update_dict) # update the inputs_embeds and input_ids seg_len = min(span_len, req_embeds.shape[0]) @@ -1372,35 +1374,8 @@ def _preprocess( input_ids[s : s + seg_len] = req_input_ids # run talker mtp decode - if self.has_talker_mtp and decode_req_ids: - batch_size = len(decode_req_ids) - mtp_input_ids = self.talker_mtp_input_ids.gpu[:batch_size] - mtp_req_embeds = self.talker_mtp_inputs_embeds.gpu[:batch_size] - mtp_last_hidden = self.last_talker_hidden.gpu[:batch_size] - mtp_text_step = self.text_step.gpu[:batch_size] - step_runner = getattr(self, "omni_step_runner", None) - use_fast_path = step_runner is not None and step_runner.supports_step( - runner=self, - request_ids=decode_req_ids, - num_scheduled_tokens=[1] * len(decode_req_ids), - is_prefill_by_req=decode_is_prefill_by_req, - ) - if use_fast_path: - prepared = step_runner.prepare_step( - request_ids=decode_req_ids, - runner=self, - input_ids=mtp_input_ids, - req_embeds=mtp_req_embeds, - last_talker_hidden=mtp_last_hidden, - text_step=mtp_text_step, - ) - step_runner.run_step(prepared=prepared, runner=self) - step_runner.commit_step(prepared=prepared, runner=self, inputs_embeds=inputs_embeds) - else: - if step_runner is not None: - reason = getattr(step_runner, "_last_fallback_reason", None) or "unsupported" - step_runner.record_fallback(reason) - self._talker_mtp_forward(decode_req_ids, inputs_embeds) + if self.has_talker_mtp: + self._talker_mtp_forward(decode_req_ids, inputs_embeds) return ( input_ids, @@ -1484,28 +1459,14 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te ) # update the inputs_embeds and code_predictor_codes out_key = getattr(self.model, "talker_mtp_output_key", ("codes", "audio")) - req_index_by_id = {req_id: idx for idx, req_id in enumerate(self.input_batch.req_ids)} - query_start_loc_cpu = self.query_start_loc.cpu - for idx, req_id in enumerate(decode_req_ids): - req_index = req_index_by_id[req_id] - start_offset = int(query_start_loc_cpu[req_index]) - inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - self._update_talker_mtp_output(req_id, out_key, code_predictor_codes[idx : idx + 1]) - - def _update_talker_mtp_output(self, req_id: str, out_key: tuple[str, str], value: torch.Tensor) -> None: if not isinstance(out_key, tuple) or len(out_key) != 2: raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}") - req_state = self.requests.get(req_id) - if req_state is None: - return - type_key, qual = out_key - gpu_keys: set[tuple[str, str]] = set() - if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"): - gpu_keys = self.model.gpu_resident_buffer_keys - existing = self.model_intermediate_buffer.setdefault(req_id, {}) - existing_sub = existing.setdefault(type_key, {}) - self._store_value(existing_sub, qual, value, {q for tk, q in gpu_keys if tk == type_key}) - setattr(req_state, "additional_information_cpu", existing) + for idx, req_id in enumerate(decode_req_ids): + req_index = self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {out_key[0]: {out_key[1]: code_predictor_codes[idx : idx + 1]}} + self._merge_additional_information_update(req_id, update_dict) def _model_forward( self, diff --git a/vllm_omni/worker/omni_step_runner.py b/vllm_omni/worker/omni_step_runner.py deleted file mode 100644 index f85df9171f6..00000000000 --- a/vllm_omni/worker/omni_step_runner.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field -from typing import Any, Protocol - -import torch - - -@dataclass(slots=True) -class OmniPreparedStep: - request_ids: list[str] - token_slices: list[slice] = field(default_factory=list) - fallback_request_ids: list[str] = field(default_factory=list) - metadata: dict[str, Any] = field(default_factory=dict) - - -class OmniStepRunner(Protocol): - @classmethod - def from_runner(cls, runner: Any) -> OmniStepRunner: - ... - - def supports_step( - self, - *, - runner: Any, - request_ids: list[str], - num_scheduled_tokens: Sequence[int], - is_prefill_by_req: Mapping[str, bool], - ) -> bool: - ... - - def prepare_step( - self, - *, - request_ids: list[str], - runner: Any, - input_ids: torch.Tensor, - req_embeds: torch.Tensor, - last_talker_hidden: torch.Tensor, - text_step: torch.Tensor, - ) -> OmniPreparedStep: - ... - - def run_step( - self, - *, - prepared: OmniPreparedStep, - runner: Any, - ) -> None: - ... - - def commit_step( - self, - *, - prepared: OmniPreparedStep, - runner: Any, - inputs_embeds: torch.Tensor, - ) -> None: - ... - - def free_request(self, request_id: str) -> None: - ... diff --git a/vllm_omni/worker/qwen3_tts_stage0_step_runner.py b/vllm_omni/worker/qwen3_tts_stage0_step_runner.py deleted file mode 100644 index ae747cc8ecc..00000000000 --- a/vllm_omni/worker/qwen3_tts_stage0_step_runner.py +++ /dev/null @@ -1,345 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - -from collections import defaultdict -from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field -from typing import Any - -import numpy as np -import torch -from vllm.compilation.cuda_graph import CUDAGraphWrapper -from vllm.config import CUDAGraphMode -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger - -from vllm_omni.worker.omni_step_runner import OmniPreparedStep - -logger = init_logger(__name__) - - -@dataclass(slots=True) -class Qwen3TTSSlot: - req_id: str | None - prompt_len: int - num_computed_tokens: int - text_offset: int - codec_len: int - emitted_chunks: int - finished: bool - - -class Qwen3TTSSlotTable: - def __init__( - self, - *, - max_slots: int, - hidden_size: int, - num_quantizers: int, - device: torch.device, - dtype: torch.dtype, - ) -> None: - self.max_slots = max_slots - self.hidden_size = hidden_size - self.num_quantizers = num_quantizers - self.slots = [Qwen3TTSSlot(None, 0, 0, 0, 0, 0, False) for _ in range(max_slots)] - self.req_to_slot: dict[str, int] = {} - self.free_slots = list(range(max_slots - 1, -1, -1)) - - self.input_ids = torch.empty((max_slots,), dtype=torch.long, device=device) - self.inputs_embeds = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) - self.last_talker_hidden = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) - self.text_step = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) - self.next_embeds = torch.empty((max_slots, hidden_size), dtype=dtype, device=device) - self.sampled_codes = torch.empty((max_slots, num_quantizers), dtype=torch.long, device=device) - self.codec_frames: dict[str, list[torch.Tensor]] = defaultdict(list) - - def allocate(self, req_id: str) -> int: - existing = self.req_to_slot.get(req_id) - if existing is not None: - return existing - if not self.free_slots: - raise RuntimeError("Qwen3-TTS Stage0 slot table exhausted") - slot = self.free_slots.pop() - self.req_to_slot[req_id] = slot - self.slots[slot] = Qwen3TTSSlot(req_id, 0, 0, 0, 0, 0, False) - return slot - - def free(self, req_id: str) -> None: - slot = self.req_to_slot.pop(req_id, None) - if slot is None: - return - self.slots[slot] = Qwen3TTSSlot(None, 0, 0, 0, 0, 0, False) - self.codec_frames.pop(req_id, None) - self.free_slots.append(slot) - - -@dataclass(slots=True) -class Qwen3TTSPreparedStep(OmniPreparedStep): - slot_indices: list[int] = field(default_factory=list) - query_offsets: list[int] = field(default_factory=list) - next_embeds: torch.Tensor | None = None - sampled_codes: torch.Tensor | None = None - - -@dataclass(slots=True) -class Qwen3TTSStage0StepStats: - fast_path_steps: int = 0 - fast_path_requests: int = 0 - fallback_reasons: dict[str, int] = field(default_factory=dict) - - -def flatten_codec_frames_for_code2wav(frames_fq: torch.Tensor) -> torch.Tensor: - if frames_fq.ndim != 2: - raise ValueError(f"expected [frames, quantizers], got {tuple(frames_fq.shape)}") - return frames_fq.transpose(0, 1).contiguous().reshape(-1) - - -class Qwen3TTSStage0StepRunner: - def __init__( - self, - *, - max_slots: int, - hidden_size: int, - num_quantizers: int = 16, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - log_every_n_steps: int = 1000, - ) -> None: - if max_slots <= 0: - raise ValueError(f"max_slots must be positive, got {max_slots}") - if hidden_size <= 0: - raise ValueError(f"hidden_size must be positive, got {hidden_size}") - if num_quantizers <= 0: - raise ValueError(f"num_quantizers must be positive, got {num_quantizers}") - device = device or torch.device("cpu") - self.table = Qwen3TTSSlotTable( - max_slots=max_slots, - hidden_size=hidden_size, - num_quantizers=num_quantizers, - device=device, - dtype=dtype, - ) - self.max_slots = max_slots - self.hidden_size = hidden_size - self.num_quantizers = num_quantizers - self.stats = Qwen3TTSStage0StepStats() - self.log_every_n_steps = log_every_n_steps - self._last_fallback_reason: str | None = None - - @classmethod - def from_runner(cls, runner: Any) -> Qwen3TTSStage0StepRunner: - mtp_buffer = runner.talker_mtp_inputs_embeds.gpu - model = getattr(runner, "model", None) - talker_config = getattr(model, "talker_config", None) - num_quantizers = int(getattr(talker_config, "num_code_groups", 0) or 16) - step_runner = cls( - max_slots=int(mtp_buffer.shape[0]), - hidden_size=int(mtp_buffer.shape[-1]), - num_quantizers=num_quantizers, - device=mtp_buffer.device, - dtype=mtp_buffer.dtype, - ) - step_runner.table.input_ids = runner.talker_mtp_input_ids.gpu - step_runner.table.inputs_embeds = runner.talker_mtp_inputs_embeds.gpu - step_runner.table.last_talker_hidden = runner.last_talker_hidden.gpu - step_runner.table.text_step = runner.text_step.gpu - return step_runner - - def _reject(self, reason: str) -> bool: - self._last_fallback_reason = reason - return False - - def supports_step( - self, - *, - runner: Any, - request_ids: list[str], - num_scheduled_tokens: Sequence[int], - is_prefill_by_req: Mapping[str, bool], - ) -> bool: - if not request_ids: - return self._reject("empty") - model_config = getattr(getattr(runner, "vllm_config", None), "model_config", None) - if not bool(getattr(model_config, "async_chunk", False)): - return self._reject("async_chunk_disabled") - if getattr(model_config, "model_stage", None) != "qwen3_tts": - return self._reject("wrong_stage") - if not bool(getattr(runner, "has_talker_mtp", False)): - return self._reject("no_talker_mtp") - if len(num_scheduled_tokens) != len(request_ids): - return self._reject("shape_mismatch") - if any(int(n) != 1 for n in num_scheduled_tokens): - return self._reject("non_decode_step") - if any(bool(is_prefill_by_req.get(req_id, True)) for req_id in request_ids): - return self._reject("prefill") - self._last_fallback_reason = None - return True - - def prepare_step( - self, - *, - request_ids: list[str], - runner: Any, - input_ids: torch.Tensor, - req_embeds: torch.Tensor, - last_talker_hidden: torch.Tensor, - text_step: torch.Tensor, - ) -> Qwen3TTSPreparedStep: - batch_size = len(request_ids) - if batch_size > self.max_slots: - raise RuntimeError(f"Qwen3-TTS Stage0 slot batch too large: {batch_size} > {self.max_slots}") - if req_embeds.shape[-1] != self.hidden_size: - raise ValueError(f"expected hidden_size={self.hidden_size}, got {req_embeds.shape[-1]}") - - req_index_by_id = {req_id: idx for idx, req_id in enumerate(runner.input_batch.req_ids)} - slot_indices: list[int] = [] - query_offsets: list[int] = [] - if self.table.input_ids[:batch_size].data_ptr() != input_ids[:batch_size].data_ptr(): - self.table.input_ids[:batch_size].copy_(input_ids[:batch_size].to(device=self.table.input_ids.device)) - if self.table.inputs_embeds[:batch_size].data_ptr() != req_embeds[:batch_size].data_ptr(): - self.table.inputs_embeds[:batch_size].copy_( - req_embeds[:batch_size].to(device=self.table.inputs_embeds.device) - ) - if self.table.last_talker_hidden[:batch_size].data_ptr() != last_talker_hidden[:batch_size].data_ptr(): - self.table.last_talker_hidden[:batch_size].copy_( - last_talker_hidden[:batch_size].to(device=self.table.last_talker_hidden.device) - ) - if self.table.text_step[:batch_size].data_ptr() != text_step[:batch_size].data_ptr(): - self.table.text_step[:batch_size].copy_(text_step[:batch_size].to(device=self.table.text_step.device)) - for batch_idx, req_id in enumerate(request_ids): - slot_indices.append(self.table.allocate(req_id)) - req_index = req_index_by_id[req_id] - query_offsets.append(int(runner.query_start_loc.cpu[req_index])) - - return Qwen3TTSPreparedStep( - request_ids=list(request_ids), - slot_indices=slot_indices, - query_offsets=query_offsets, - metadata={"batch_size": batch_size}, - ) - - def _talker_kwargs(self, runner: Any, request_ids: list[str], device: torch.device) -> dict[str, Any]: - subtalker_params = getattr(runner.vllm_config.model_config, "subtalker_sampling_params", None) - if not isinstance(subtalker_params, dict): - subtalker_params = {} - talker_kwargs: dict[str, Any] = { - "do_sample": subtalker_params.get("do_sample"), - "temperature": subtalker_params.get("temperature"), - "top_k": subtalker_params.get("top_k"), - "top_p": subtalker_params.get("top_p"), - } - if not request_ids: - return talker_kwargs - first_req_id = request_ids[0] - first_sp = getattr(runner.requests[first_req_id], "sampling_params", None) - extra_args = getattr(first_sp, "extra_args", None) if first_sp is not None else None - seed = extra_args.get("qwen3_tts_request_seed") if isinstance(extra_args, dict) else None - if seed is None: - return talker_kwargs - generators = getattr(runner, "_talker_mtp_generators", None) - if generators is None: - generators = {} - runner._talker_mtp_generators = generators - generator = generators.get(first_req_id) - if generator is None or generator.device != device: - generator = torch.Generator(device=device) - generator.manual_seed(int(seed)) - generators[first_req_id] = generator - talker_kwargs["generator"] = generator - return talker_kwargs - - def run_step( - self, - *, - prepared: Qwen3TTSPreparedStep, - runner: Any, - ) -> None: - batch_size = int(prepared.metadata["batch_size"]) - if batch_size == 0: - return - if hasattr(runner, "_determine_batch_execution_and_padding"): - cudagraph_mode, batch_desc, _, _, _ = runner._determine_batch_execution_and_padding( - num_tokens=batch_size, - num_reqs=batch_size, - num_scheduled_tokens_np=np.ones(batch_size, dtype=np.int32), - max_num_scheduled_tokens=1, - use_cascade_attn=False, - ) - if not isinstance(runner.talker_mtp, CUDAGraphWrapper): - cudagraph_mode = CUDAGraphMode.NONE - num_tokens_padded = batch_size - else: - num_tokens_padded = batch_desc.num_tokens - talker_kwargs = self._talker_kwargs(runner, prepared.request_ids, self.table.input_ids.device) - with set_forward_context( - None, - runner.vllm_config, - cudagraph_runtime_mode=cudagraph_mode, - batch_descriptor=batch_desc, - ): - next_embeds, sampled_codes = runner.talker_mtp( - self.table.input_ids[:num_tokens_padded], - self.table.inputs_embeds[:num_tokens_padded], - self.table.last_talker_hidden[:num_tokens_padded], - self.table.text_step[:num_tokens_padded], - **talker_kwargs, - ) - else: - next_embeds, sampled_codes = runner.talker_mtp( - self.table.input_ids[:batch_size], - self.table.inputs_embeds[:batch_size], - self.table.last_talker_hidden[:batch_size], - self.table.text_step[:batch_size], - ) - prepared.next_embeds = next_embeds[:batch_size] - prepared.sampled_codes = sampled_codes[:batch_size] - - def commit_step( - self, - *, - prepared: Qwen3TTSPreparedStep, - runner: Any, - inputs_embeds: torch.Tensor, - ) -> None: - if prepared.next_embeds is None or prepared.sampled_codes is None: - raise RuntimeError("run_step must be called before commit_step") - out_key = getattr(runner.model, "talker_mtp_output_key", ("codes", "audio")) - for idx, req_id in enumerate(prepared.request_ids): - start_offset = prepared.query_offsets[idx] - inputs_embeds[start_offset : start_offset + 1] = prepared.next_embeds[idx : idx + 1] - codes = prepared.sampled_codes[idx : idx + 1] - update = getattr(runner, "_update_talker_mtp_output", None) - if update is None: - self._update_runner_buffer(runner, req_id, out_key, codes) - else: - update(req_id, out_key, codes) - self.record_fast_path(batch_size=len(prepared.request_ids)) - - def _update_runner_buffer(self, runner: Any, req_id: str, out_key: tuple[str, str], value: torch.Tensor) -> None: - type_key, qual = out_key - existing = runner.model_intermediate_buffer.setdefault(req_id, {}) - existing_sub = existing.setdefault(type_key, {}) - existing_sub[qual] = value.detach().clone() - req_state = runner.requests.get(req_id) - if req_state is not None: - setattr(req_state, "additional_information_cpu", existing) - - def record_fast_path(self, *, batch_size: int) -> None: - self.stats.fast_path_steps += 1 - self.stats.fast_path_requests += int(batch_size) - if self.log_every_n_steps > 0 and self.stats.fast_path_steps % self.log_every_n_steps == 0: - logger.info( - "Qwen3-TTS Stage0 fast path stats: steps=%d, requests=%d, fallback=%s", - self.stats.fast_path_steps, - self.stats.fast_path_requests, - self.stats.fallback_reasons, - ) - - def record_fallback(self, reason: str) -> None: - self.stats.fallback_reasons[reason] = self.stats.fallback_reasons.get(reason, 0) + 1 - - def free_request(self, request_id: str) -> None: - self.table.free(request_id) From 902aefa0010a17ce13f7874e617ad2ddafd2ba9b Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sun, 17 May 2026 02:12:08 +0800 Subject: [PATCH 4/6] feat: optimize qwen3 tts high concurrency serving Signed-off-by: Sy03 <1370724210@qq.com> --- .../qwen3_tts/test_code_predictor_dtype.py | 53 ++ .../qwen3_tts/test_cuda_graph_decoder.py | 129 ++++- .../qwen3_tts/test_qwen3_tts_code2wav.py | 307 +++++++++++- .../test_qwen3_tts_talker_preprocess.py | 67 ++- .../test_qwen3_tts_async_chunk.py | 30 ++ tests/worker/test_omni_gpu_model_runner.py | 41 ++ .../deploy/qwen3_tts_high_concurrency.yaml | 95 ++++ .../models/common/qwen3_code_predictor.py | 190 +++++-- .../qwen3_tts/cuda_graph_decoder_wrapper.py | 472 ++++++++++++++++-- .../models/qwen3_tts/qwen3_tts_code2wav.py | 306 +++++++++++- .../models/qwen3_tts/qwen3_tts_talker.py | 190 ++++++- .../modeling_qwen3_tts_tokenizer_v2.py | 45 +- .../stage_input_processors/qwen3_tts.py | 33 +- vllm_omni/worker/gpu_model_runner.py | 113 ++++- 14 files changed, 1932 insertions(+), 139 deletions(-) create mode 100644 vllm_omni/deploy/qwen3_tts_high_concurrency.yaml diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py index 3c76f421019..42d229c3e6f 100644 --- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py +++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py @@ -48,6 +48,7 @@ def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]: """Build the dict of modules to inject into sys.modules.""" platforms_mock = mocker.MagicMock() platforms_mock.current_omni_platform.supports_torch_inductor.return_value = False + platforms_mock.current_omni_platform.is_npu.return_value = False logger_mock = mocker.MagicMock() logger_mock.init_logger = lambda name: mocker.MagicMock() @@ -395,3 +396,55 @@ def test_tts_config(self, loaded_target_classes) -> None: assert config.use_parallel_embedding is False assert config.return_proj_buf is False assert config.sampling_mode == "per_call" + + def test_prefix_graph_config_helpers(self, loaded_target_classes) -> None: + """Prefix graph helpers parse deploy config values and keep valid seq lens only.""" + _ = loaded_target_classes + common_mod = sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] + wrapper_cls = common_mod.CodePredictorWrapper + + assert wrapper_cls._parse_positive_int_set("64; 128,0,-1") == { + 64, + 128, + } + assert wrapper_cls._parse_positive_int_set([2, "4", 0]) == {2, 4} + + wrapper = object.__new__(wrapper_cls) + wrapper._prefix_graph_seq_lens = {1, 2, 4, 8, 99} + assert wrapper._prefix_seq_lens(6) == [2, 4] + + def test_prefix_graph_env_requires_cuda_graphs( + self, + mocker: MockerFixture, + loaded_target_classes, + ) -> None: + """Avoid prefix warmup on shared code-predictor users that disable CUDA graphs.""" + _ = loaded_target_classes + common_mod = sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] + mocker.patch.object(common_mod.current_omni_platform, "is_npu", return_value=False) + + cp_config, _ = _make_tiny_config(loaded_target_classes) + vllm_config = _make_vllm_config(mocker, max_num_seqs=2) + vllm_config.model_config.stage_connector_config = { + "extra": { + "code_predictor_prefix_graphs": True, + "code_predictor_prefix_graph_buckets": [2], + "code_predictor_prefix_graph_seq_lens": "2,3", + } + } + + no_graph_wrapper = common_mod.CodePredictorWrapper( + vllm_config=vllm_config, + cp_config=cp_config, + wrapper_config=common_mod.CodePredictorWrapperConfig(use_cuda_graphs=False), + ) + assert no_graph_wrapper._prefix_graphs_enabled is False + assert no_graph_wrapper._prefix_graph_buckets == {2} + assert no_graph_wrapper._prefix_graph_seq_lens == {2, 3} + + graph_wrapper = common_mod.CodePredictorWrapper( + vllm_config=vllm_config, + cp_config=cp_config, + wrapper_config=common_mod.CodePredictorWrapperConfig(use_cuda_graphs=True), + ) + assert graph_wrapper._prefix_graphs_enabled is True diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 33b6187e54d..dbd2172464d 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -87,6 +87,7 @@ def wrapper(decoder): w = CUDAGraphDecoderWrapper( decoder=decoder, capture_sizes=[25, 50, 100], + capture_batch_sizes=[1, 2], num_quantizers=NUM_QUANTIZERS, enabled=True, ) @@ -94,8 +95,8 @@ def wrapper(decoder): return w -def _random_codes(seq_len, device=DEVICE): - return torch.randint(0, 100, (1, NUM_QUANTIZERS, seq_len), dtype=torch.long, device=device) +def _random_codes(seq_len, batch_size=1, device=DEVICE): + return torch.randint(0, 100, (batch_size, NUM_QUANTIZERS, seq_len), dtype=torch.long, device=device) # ────────────────────────────────────────────────────────────────── @@ -213,6 +214,44 @@ def test_chunked_decode_exact_size_equivalence(decoder, wrapper, total_len): torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0) +def test_chunked_decode_output_survives_later_replay(wrapper): + """Chunked output must not alias graph static buffers overwritten by later replays.""" + codes = _random_codes(100) + overwrite_codes = _random_codes(100) + + with torch.no_grad(): + graph_out = wrapper.chunked_decode_with_cudagraph(codes, chunk_size=50, left_context_size=0) + expected = graph_out.clone() + _ = wrapper.decode(overwrite_codes[..., :50]) + _ = wrapper.decode(overwrite_codes) + + torch.testing.assert_close(graph_out, expected, atol=0, rtol=0) + + +def test_batched_chunked_decode_variable_lengths_matches_per_request_eager(decoder, wrapper): + """Variable-length chunk batching should match independent chunked decodes.""" + long_codes = _random_codes(100) + short_codes = _random_codes(50) + padded_codes = torch.zeros(2, NUM_QUANTIZERS, 100, dtype=torch.long, device=DEVICE) + padded_codes[0, :, :] = long_codes[0] + padded_codes[1, :, :50] = short_codes[0] + + with torch.no_grad(): + eager_long = _eager_chunked(decoder, long_codes, chunk_size=50, left_context_size=0) + eager_short = _eager_chunked(decoder, short_codes, chunk_size=50, left_context_size=0) + graph_out = wrapper.batched_chunked_decode_with_cudagraph( + padded_codes, + [100, 50], + chunk_size=50, + left_context_size=0, + max_batch_size=2, + ) + + assert graph_out.shape == (2, 1, 100 * TOTAL_UPSAMPLE) + torch.testing.assert_close(graph_out[0:1], eager_long, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(graph_out[1:2, :, : 50 * TOTAL_UPSAMPLE], eager_short, atol=1e-6, rtol=1e-6) + + def _eager_chunked(decoder, codes, chunk_size, left_context_size): """Eager chunked decode matching the real decoder's chunked_decode logic.""" wavs = [] @@ -254,15 +293,95 @@ def test_disabled_wrapper_matches_eager(decoder, wrapper): torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0) -def test_batch_size_gt1_falls_back(decoder, wrapper): - """Batch size > 1 should fall back to eager (bit-identical).""" - codes = torch.randint(0, 100, (2, NUM_QUANTIZERS, 25), dtype=torch.long, device=DEVICE) +def test_batch_size_gt1_uses_matching_graph(decoder, wrapper): + """Captured batch size > 1 should replay a matching graph.""" + assert (2, 25) in wrapper.graphs + codes = _random_codes(25, batch_size=2) with torch.no_grad(): eager_out = decoder(codes) graph_out = wrapper.decode(codes) torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0) +def test_uncaptured_batch_size_falls_back(decoder, wrapper): + """Uncaptured batch sizes should fall back to eager.""" + assert (3, 25) not in wrapper.graphs + codes = _random_codes(25, batch_size=3) + with torch.no_grad(): + eager_out = decoder(codes) + graph_out = wrapper.decode(codes) + torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0) + + +def test_extra_capture_shape_uses_sparse_graph(decoder): + """Extra capture shapes should not expand to a full batch x size product.""" + sparse_wrapper = CUDAGraphDecoderWrapper( + decoder=decoder, + capture_sizes=[25], + capture_batch_sizes=[1], + extra_capture_shapes=[(2, 50)], + num_quantizers=NUM_QUANTIZERS, + enabled=True, + ) + sparse_wrapper.warmup(DEVICE) + + assert (1, 25) in sparse_wrapper.graphs + assert (2, 50) in sparse_wrapper.graphs + assert (2, 25) not in sparse_wrapper.graphs + + codes = _random_codes(50, batch_size=2) + with torch.no_grad(): + eager_out = decoder(codes) + graph_out = sparse_wrapper.decode(codes) + torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0) + + +def test_compile_shape_supports_exact_and_padded_buckets(decoder, monkeypatch): + """Configured torch.compile shapes should replay exact and padded CUDA Graph buckets.""" + + compile_kwargs = {} + + def _fake_compile(model, **_kwargs): + compile_kwargs.update(_kwargs) + + def _compiled(codes): + return model(codes) + 0.125 + + return _compiled + + monkeypatch.setattr(torch, "compile", _fake_compile) + + compiled_wrapper = CUDAGraphDecoderWrapper( + decoder=decoder, + capture_sizes=[25, 50], + capture_batch_sizes=[1], + compile_shapes=[(1, 25), (1, 50)], + num_quantizers=NUM_QUANTIZERS, + enabled=True, + ) + compiled_wrapper.warmup(DEVICE) + + exact_codes = _random_codes(25) + padded_codes = _random_codes(30) + uncaptured_codes = _random_codes(60) + padded_static = torch.zeros(1, NUM_QUANTIZERS, 50, dtype=torch.long, device=DEVICE) + padded_static[:, :, :30] = padded_codes + with torch.no_grad(): + exact_eager = decoder(exact_codes) + exact_out = compiled_wrapper.decode(exact_codes) + padded_graph_expected = decoder(padded_static)[..., : 30 * TOTAL_UPSAMPLE] + padded_out = compiled_wrapper.decode(padded_codes) + uncaptured_eager = decoder(uncaptured_codes) + uncaptured_out = compiled_wrapper.decode(uncaptured_codes) + + torch.testing.assert_close(exact_out, exact_eager + 0.125, atol=0, rtol=0) + torch.testing.assert_close(padded_out, padded_graph_expected + 0.125, atol=0, rtol=0) + torch.testing.assert_close(uncaptured_out, uncaptured_eager, atol=0, rtol=0) + assert compile_kwargs["mode"] == "default" + assert compile_kwargs["fullgraph"] is False + assert compile_kwargs["dynamic"] is False + + def test_deterministic_across_calls(decoder, wrapper): """Same input should produce identical CUDA graph output across calls.""" codes = _random_codes(30) diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py index 88ab16c0d1c..2e91f2a4f43 100644 --- a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py +++ b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py @@ -24,6 +24,7 @@ def __init__(self, total_upsample: int = _TOTAL_UPSAMPLE): super().__init__() self.total_upsample = total_upsample self.decode_calls: list[dict[str, int]] = [] + self.batched_decode_calls: list[dict[str, int]] = [] self.cudagraph_calls: list[dict[str, int | torch.device]] = [] def to(self, *args, **kwargs): @@ -40,12 +41,40 @@ def chunked_decode( { "chunk_size": chunk_size, "left_context_size": left_context_size, + "codes_shape": tuple(codes.shape), } ) + batch = codes.shape[0] frames = codes.shape[-1] wav_len = frames * self.total_upsample + 6 - wav = torch.arange(wav_len, dtype=torch.float32) - return wav.view(1, 1, -1) + wav = torch.arange(wav_len, dtype=torch.float32).view(1, 1, -1) + offsets = torch.arange(batch, dtype=torch.float32).view(batch, 1, 1) * 1000 + return wav.expand(batch, 1, wav_len) + offsets + + def batched_chunked_decode( + self, + codes: torch.Tensor, + lengths: list[int], + *, + chunk_size: int = 300, + left_context_size: int = 25, + max_batch_size: int = 0, + ) -> torch.Tensor: + self.batched_decode_calls.append( + { + "chunk_size": chunk_size, + "left_context_size": left_context_size, + "max_batch_size": max_batch_size, + "codes_shape": tuple(codes.shape), + "lengths": tuple(lengths), + } + ) + batch = codes.shape[0] + frames = codes.shape[-1] + wav_len = frames * self.total_upsample + 6 + wav = torch.arange(wav_len, dtype=torch.float32).view(1, 1, -1) + offsets = torch.arange(batch, dtype=torch.float32).view(batch, 1, 1) * 1000 + return wav.expand(batch, 1, wav_len) + offsets def enable_cudagraph(self, **kwargs): self.cudagraph_calls.append(kwargs) @@ -177,6 +206,7 @@ def test_connector_codec_chunking_does_not_override_decode_chunking(): assert model.decoder.decode_calls[-1] == { "chunk_size": 300, "left_context_size": 25, + "codes_shape": (1, _NUM_QUANTIZERS, 6), } @@ -199,6 +229,175 @@ def test_decode_chunking_can_be_overridden_separately(): assert model._decode_left_context_frames == 17 +def test_malformed_codec_length_warning_is_rate_limited(): + model = _make_model() + + with patch("vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav.logger.warning") as warning: + out1 = model.forward(input_ids=torch.arange(1, dtype=torch.long)) + out2 = model.forward(input_ids=torch.arange(1, dtype=torch.long)) + + assert warning.call_count == 1 + assert "not divisible by num_quantizers" in warning.call_args[0][0] + assert "suppressing repeats" in warning.call_args[0][0] + assert out1.multimodal_outputs["model_outputs"][0].numel() == 0 + assert out2.multimodal_outputs["model_outputs"][0].numel() == 0 + + +def test_forward_batches_equal_length_requests_in_one_decoder_call(): + model = _make_model() + + out = model.forward( + input_ids=torch.arange(12, dtype=torch.long), + seq_token_counts=[6, 6], + runtime_additional_information=[ + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 1}}, + ], + ) + + assert model.decoder.decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (2, _NUM_QUANTIZERS, 3), + } + ] + audios = out.multimodal_outputs["model_outputs"] + torch.testing.assert_close(audios[0], torch.arange(12, dtype=torch.float32)) + torch.testing.assert_close(audios[1], torch.arange(1004, 1012, dtype=torch.float32)) + + +def test_forward_uses_variable_length_chunk_batching_for_long_bucket_group(): + model = _make_model() + model._decode_batch_bucket_frames = [4] + model._decode_variable_chunk_batch_min_frames = 1 + + out = model.forward( + input_ids=torch.arange(12, dtype=torch.long), + seq_token_counts=[4, 8], + runtime_additional_information=[ + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 2}}, + ], + ) + + assert model.decoder.decode_calls == [] + assert model.decoder.batched_decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "max_batch_size": 0, + "codes_shape": (2, _NUM_QUANTIZERS, 4), + "lengths": (2, 4), + } + ] + audios = out.multimodal_outputs["model_outputs"] + torch.testing.assert_close(audios[0], torch.arange(8, dtype=torch.float32)) + torch.testing.assert_close(audios[1], torch.arange(1008, 1016, dtype=torch.float32)) + + +def test_forward_bucket_batches_different_length_requests_and_trims_rows(): + model = _make_model() + model._decode_batch_bucket_frames = [4] + + out = model.forward( + input_ids=torch.arange(18, dtype=torch.long), + seq_token_counts=[4, 6, 8], + runtime_additional_information=[ + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 1}}, + {"meta": {"left_context_size": 2}}, + ], + ) + + assert model.decoder.decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (3, _NUM_QUANTIZERS, 4), + } + ] + assert model.decoder.batched_decode_calls == [] + audios = out.multimodal_outputs["model_outputs"] + torch.testing.assert_close(audios[0], torch.arange(8, dtype=torch.float32)) + torch.testing.assert_close(audios[1], torch.arange(1004, 1012, dtype=torch.float32)) + torch.testing.assert_close(audios[2], torch.arange(2008, 2016, dtype=torch.float32)) + + +def test_forward_bucket_pads_only_to_group_max_frame_length(): + model = _make_model() + model._decode_batch_bucket_frames = [8] + + out = model.forward( + input_ids=torch.arange(10, dtype=torch.long), + seq_token_counts=[4, 6], + runtime_additional_information=[ + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 0}}, + ], + ) + + assert model.decoder.decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (2, _NUM_QUANTIZERS, 3), + } + ] + assert model.decoder.batched_decode_calls == [] + audios = out.multimodal_outputs["model_outputs"] + torch.testing.assert_close(audios[0], torch.arange(8, dtype=torch.float32)) + torch.testing.assert_close(audios[1], torch.arange(1000, 1012, dtype=torch.float32)) + + +def test_forward_splits_bucket_groups_by_configured_max_batch_size(): + model = _make_model() + model._decode_batch_bucket_frames = [4] + model._decode_batch_max_size = 2 + + model.forward( + input_ids=torch.arange(18, dtype=torch.long), + seq_token_counts=[4, 6, 8], + runtime_additional_information=[ + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 0}}, + {"meta": {"left_context_size": 0}}, + ], + ) + + assert model.decoder.decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (2, _NUM_QUANTIZERS, 3), + }, + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (1, _NUM_QUANTIZERS, 4), + }, + ] + assert model.decoder.batched_decode_calls == [] + + +def test_forward_does_not_pad_singleton_bucket_group(): + model = _make_model() + model._decode_batch_bucket_frames = [4] + + model.forward( + input_ids=torch.arange(4, dtype=torch.long), + runtime_additional_information=[{"meta": {"left_context_size": 0}}], + ) + + assert model.decoder.decode_calls == [ + { + "chunk_size": 300, + "left_context_size": 25, + "codes_shape": (1, _NUM_QUANTIZERS, 2), + } + ] + + def test_decode_chunking_override_is_passed_to_cudagraph(): model = _make_model( async_chunk=True, @@ -216,6 +415,10 @@ def test_decode_chunking_override_is_passed_to_cudagraph(): _load_weights_noop(model) assert model.decoder.cudagraph_calls[-1] == { + "capture_sizes": None, + "capture_batch_sizes": None, + "extra_capture_shapes": None, + "compile_shapes": None, "device": torch.device("cuda"), "codec_chunk_frames": 25, "codec_left_context_frames": 72, @@ -224,6 +427,106 @@ def test_decode_chunking_override_is_passed_to_cudagraph(): } +def test_cudagraph_capture_shapes_can_be_configured(): + model = _make_model( + async_chunk=True, + device=torch.device("cuda"), + stage_connector_config={ + "extra": { + "decode_cudagraph_capture_sizes": "97,325", + "decode_cudagraph_batch_sizes": [1, 2, 4, 8], + "decode_cudagraph_extra_capture_shapes": ["3:325", [5, 325]], + } + }, + ) + + _load_weights_noop(model) + + call = model.decoder.cudagraph_calls[-1] + assert call["capture_sizes"] == [97, 325] + assert call["capture_batch_sizes"] == [1, 2, 4, 8] + assert call["extra_capture_shapes"] == [(3, 325), (5, 325)] + + +def test_decode_compile_shapes_can_be_configured(): + model = _make_model( + async_chunk=True, + device=torch.device("cuda"), + stage_connector_config={ + "extra": { + "decode_compile_shapes": ["1:325", [1, 73]], + } + }, + ) + + _load_weights_noop(model) + + call = model.decoder.cudagraph_calls[-1] + assert call["compile_shapes"] == [(1, 73), (1, 325)] + + +def test_decode_tf32_can_be_configured(): + old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32 + old_cudnn_tf32 = torch.backends.cudnn.allow_tf32 + old_matmul_precision = torch.get_float32_matmul_precision() + try: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.set_float32_matmul_precision("highest") + model = _make_model( + async_chunk=True, + device=torch.device("cuda"), + stage_connector_config={ + "extra": { + "decode_enable_tf32": "true", + } + }, + ) + + _load_weights_noop(model) + + assert torch.backends.cuda.matmul.allow_tf32 is True + assert torch.backends.cudnn.allow_tf32 is True + assert torch.get_float32_matmul_precision() == "high" + finally: + torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32 + torch.backends.cudnn.allow_tf32 = old_cudnn_tf32 + torch.set_float32_matmul_precision(old_matmul_precision) + + +def test_decode_batch_bucket_frames_can_be_configured(): + model = _make_model( + async_chunk=True, + stage_connector_config={ + "extra": { + "decode_batch_bucket_frames": "73,169", + "decode_batch_max_size": 10, + "decode_variable_chunk_batch_min_frames": 512, + } + }, + ) + + _load_weights_noop(model) + + assert model._decode_batch_bucket_frames == [73, 169] + assert model._decode_batch_max_size == 10 + assert model._decode_variable_chunk_batch_min_frames == 512 + + +def test_invalid_decode_batch_max_size_is_rejected(): + model = _make_model( + async_chunk=True, + stage_connector_config={ + "extra": { + "decode_batch_max_size": -1, + } + }, + ) + + with pytest.raises(ValueError, match="decode_batch_max_size"): + _load_weights_noop(model) + + def test_invalid_decode_chunking_is_rejected(): model = _make_model( async_chunk=True, diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py index 826551bdc26..fd3fecaa91c 100644 --- a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py +++ b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_talker_preprocess.py @@ -179,6 +179,52 @@ def fake_embed_input_ids(input_ids): assert torch.equal(update["hidden_states"]["trailing_text"], trailing_text[65:]) +def test_decode_batch_preprocess_matches_decode_state_updates(): + model = _make_minimal_talker() + + def fake_embed_input_ids(input_ids): + return input_ids.to(torch.float32).reshape(-1, 1, 1).expand(-1, 1, 4) + + model.embed_input_ids = fake_embed_input_ids + trailing_a = torch.arange(12, dtype=torch.float32).reshape(3, 4) + trailing_b = torch.arange(8, dtype=torch.float32).reshape(2, 4) + 100 + last_a = torch.full((4,), 2.0, dtype=torch.float32) + last_b = torch.full((4,), 3.0, dtype=torch.float32) + tts_pad = torch.full((1, 4), -1.0, dtype=torch.float32) + + out_ids, out_embeds, past_hidden, text_step, updates = model.preprocess_decode_batch( + input_ids=torch.tensor([101, 202], dtype=torch.long), + req_infos=[ + { + "text": ["hello"], + "task_type": ["Base"], + "hidden_states": {"trailing_text": trailing_a, "last": last_a}, + "embed": {"tts_pad": tts_pad}, + "meta": {"talker_text_offset": 1}, + }, + { + "text": ["world"], + "task_type": ["CustomVoice"], + "hidden_states": {"trailing_text": trailing_b, "last": last_b}, + "embed": {"tts_pad": tts_pad}, + "meta": {"talker_text_offset": 2}, + }, + ], + ) + + assert out_ids.tolist() == [101, 202] + assert torch.equal(out_embeds.cpu(), torch.tensor([[101.0] * 4, [202.0] * 4], dtype=torch.bfloat16)) + assert torch.equal(past_hidden.cpu(), torch.stack([last_a, last_b]).to(torch.bfloat16)) + assert torch.equal(text_step[0].cpu(), trailing_a[1].to(torch.bfloat16)) + assert torch.equal(text_step[1].cpu(), tts_pad.reshape(-1).to(torch.bfloat16)) + assert updates[0]["meta"]["talker_text_offset"] == 2 + assert updates[0]["meta"]["codec_streaming"] is True + assert "hidden_states" not in updates[0] + assert updates[1]["meta"]["talker_text_offset"] == 0 + assert updates[1]["meta"]["codec_streaming"] is False + assert updates[1]["hidden_states"]["trailing_text"].numel() == 0 + + def test_base_voice_clone_normalizes_ref_audio_once_for_ref_code_and_speaker(): model = _make_minimal_talker() device_param = torch.nn.Parameter(torch.empty(0)) @@ -218,11 +264,11 @@ def __call__(self, *_args, **_kwargs): model._normalize_ref_audio = lambda raw: normalize_calls.append(raw) or (ref_audio, 16000) ref_audio_ids = [] - model._encode_ref_audio_to_code = lambda wav, _sr: ref_audio_ids.append(id(wav)) or torch.ones( - (2, 2), dtype=torch.long + model._encode_ref_audio_to_code = lambda wav, _sr: ( + ref_audio_ids.append(id(wav)) or torch.ones((2, 2), dtype=torch.long) ) - model._extract_speaker_embedding = lambda wav, _sr: ref_audio_ids.append(id(wav)) or torch.ones( - 4, dtype=torch.bfloat16 + model._extract_speaker_embedding = lambda wav, _sr: ( + ref_audio_ids.append(id(wav)) or torch.ones(4, dtype=torch.bfloat16) ) _prompt, _trailing, _pad, ref_code_len, ref_code = model._build_prompt_embeds( @@ -246,9 +292,12 @@ def test_base_voice_clone_batch_preprocess_encodes_ref_code_by_sample_rate(): wav1 = np.arange(2048, dtype=np.float32) wav2 = np.arange(3072, dtype=np.float32) normalize_calls = [] - model._normalize_ref_audio = lambda raw: normalize_calls.append(raw) or ( - wav1 if raw == "a.wav" else wav2, - 16000, + model._normalize_ref_audio = lambda raw: ( + normalize_calls.append(raw) + or ( + wav1 if raw == "a.wav" else wav2, + 16000, + ) ) class FakeSpeechTokenizer: @@ -421,8 +470,8 @@ def __call__(self, *_args, **_kwargs): AssertionError("serial encode not expected") ) speaker_wav_ids = [] - model._extract_speaker_embedding = lambda wav, _sr: speaker_wav_ids.append(id(wav)) or torch.ones( - 4, dtype=torch.bfloat16 + model._extract_speaker_embedding = lambda wav, _sr: ( + speaker_wav_ids.append(id(wav)) or torch.ones(4, dtype=torch.bfloat16) ) _prompt, _trailing, _pad, ref_code_len, out_ref_code = model._build_prompt_embeds( diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index 950ae213f72..8f74ed094c4 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -294,6 +294,36 @@ def test_ref_code_context_applies_to_all_streaming_chunks(): assert len(payload.codes.audio) == _Q * (35 + 2) +def test_streaming_ref_code_context_is_bounded_for_batchable_shapes(): + tm = _tm(chunk_frames=4, left_context=3, initial_chunk_frames=4) + rid = "r-ref-bounded" + tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(8)] + ref_code = torch.tensor( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [3, 3, 3, 3], + [4, 4, 4, 4], + [5, 5, 5, 5], + ], + dtype=torch.long, + ) + tm.request_payload[rid] = ref_code + + payload = talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"codes": {"audio": torch.zeros((0,)), "ref": ref_code}}, + request=_req(rid, finished=False), + is_finished=False, + ) + + assert payload is not None + assert payload.meta.left_context_size == 3 + 3 + assert len(payload.codes.audio) == _Q * (3 + 3 + 4) + frames = payload.codes.audio.reshape(_Q, -1).transpose(0, 1) + torch.testing.assert_close(frames[:3], ref_code[-3:]) + + def test_ref_code_context_can_be_buffered_before_first_emit(): tm = _tm() rid = "r-ref-buffered" diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index 2d6f11e7af8..4f494488401 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -82,6 +82,7 @@ def forward( ): self.calls.append( { + "batch_size": int(req_embeds.shape[0]), "do_sample": do_sample, "temperature": temperature, "top_k": top_k, @@ -265,6 +266,7 @@ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_ assert runner.talker_mtp.calls == [ { + "batch_size": 1, "do_sample": False, "temperature": 0.2, "top_k": 9, @@ -275,6 +277,45 @@ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_ assert runner.talker_mtp.calls[0]["generator"] is not None +def test_talker_mtp_forward_keeps_explicit_seeded_requests_scalar(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4) + runner.requests["r1"].sampling_params = SimpleNamespace( + seed=11, + extra_args={"qwen3_tts_request_seed": 11}, + ) + runner.requests["r2"].sampling_params = SimpleNamespace( + seed=22, + extra_args={"qwen3_tts_request_seed": 22}, + ) + runner.talker_mtp = CaptureTalkerMTP() + runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace(subtalker_sampling_params={})) + + def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn): + batch_desc = SimpleNamespace(num_tokens=int(num_tokens)) + return (False, batch_desc, None, None, None) + + monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner))) + + runner.talker_mtp_input_ids.gpu[:] = torch.tensor([101, 202], dtype=torch.int64) + runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) + saved_input_ids = runner.talker_mtp_input_ids.gpu.clone() + saved_embeds = runner.talker_mtp_inputs_embeds.gpu.clone() + + inputs_embeds = torch.zeros((6, 4), dtype=torch.float32) + OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds) + + assert [call["batch_size"] for call in runner.talker_mtp.calls] == [1, 1] + assert all(call["generator"] is not None for call in runner.talker_mtp.calls) + assert runner.talker_mtp.calls[0]["generator"] is not runner.talker_mtp.calls[1]["generator"] + assert torch.equal(runner.talker_mtp_input_ids.gpu, saved_input_ids) + assert torch.equal(runner.talker_mtp_inputs_embeds.gpu, saved_embeds) + + def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch): """Validate that _update_intermediate_buffer writes to model_intermediate_buffer (forward path) and mirrors to additional_information_cpu setattr (backward compat).""" diff --git a/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml b/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml new file mode 100644 index 00000000000..9c6ec19fe0d --- /dev/null +++ b/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml @@ -0,0 +1,95 @@ +# Qwen3-TTS high-concurrency deploy profile. +# +# This is intentionally separate from qwen3_tts.yaml. Use it only when running +# sustained high-concurrency serving on two GPUs, for example: +# +# vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base --omni \ +# --deploy-config vllm_omni/deploy/qwen3_tts_high_concurrency.yaml +# +# Profile validated for the c64 / PROMPTS=512 performance experiments: +# Stage 0 talker on GPU 0 with S0=64, Stage 1 Code2Wav on GPU 1 with S1=10. +# +async_chunk: true + +connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + codec_chunk_frames: 25 + codec_left_context_frames: 72 + # Stage0 code-predictor prefix CUDA graphs for the c64 hot path. + # These keys are consumed by Qwen3-TTS talker and ignored by Code2Wav. + code_predictor_prefix_graphs: true + code_predictor_prefix_graph_buckets: [64] + code_predictor_prefix_graph_seq_lens: [2, 3, 4, 5, 6, 7, 8] + # Keep voice-clone reference context bounded so Stage1 chunk lengths are + # stable across different reference-audio durations. + ref_code_context_frames: 72 + # Emit only the first audio chunk early, then return to codec_chunk_frames. + initial_codec_chunk_frames: 1 + # Common Stage1 decode buckets: + # no-ref first/steady chunks: 25 / 97 frames + # Base ref-context first/steady chunks: 73 / 169 frames + # decoder internal non-streaming chunks: 325 frames + decode_cudagraph_capture_sizes: [25, 73, 97, 169, 325] + # Keep B>1 captures opt-in; c64 e2e validation did not show a stable win. + decode_cudagraph_batch_sizes: [1] + decode_compile_shapes: [] + decode_batch_max_size: 1 + decode_batch_bucket_frames: [] + decode_enable_tf32: false + +stages: + - stage_id: 0 + max_num_seqs: 64 + gpu_memory_utilization: 0.3 + trust_remote_code: true + enable_prefix_caching: false + async_scheduling: true + max_num_batched_tokens: 512 + max_model_len: 4096 + devices: "0" + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + repetition_penalty: 1.05 + subtalker_sampling_params: + do_sample: true + temperature: 0.9 + top_k: 50 + top_p: 1.0 + + - stage_id: 1 + max_num_seqs: 10 + gpu_memory_utilization: 0.3 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + async_scheduling: true + max_num_batched_tokens: 65536 + max_model_len: 65536 + devices: "1" + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + repetition_penalty: 1.0 + +platforms: + npu: + stages: + - stage_id: 0 + enforce_eager: true diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py index d36068d4678..8f60c2a108d 100644 --- a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py +++ b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py @@ -352,7 +352,7 @@ def forward( # layers already upcast internally; this extends the same # treatment to attention and MLP. input_dtype = inputs_embeds.dtype - use_fp32 = input_dtype == torch.float16 + use_fp32 = input_dtype == torch.float16 and inputs_embeds.device.type == "cuda" if use_fp32: inputs_embeds = inputs_embeds.float() hidden_states = inputs_embeds @@ -465,10 +465,26 @@ def __init__( self._model_dtype: torch.dtype | None = None self._compiled_model_fwd = None self._bucket_sizes: list[int] = [] - self._bucket_pos_ids: dict[int, torch.Tensor] = {} + self._bucket_pos_ids: dict[int | tuple[int, int], torch.Tensor] = {} self._lm_heads_list: list[nn.Module] | None = None self._codec_embeds_list: list[nn.Module] | None = None - self._device_graphs: dict[int, tuple] = {} # (graph, static_output) per bucket + self._device_graphs: dict[int | tuple[int, int], tuple] = {} # (graph, static_output) per bucket + prefix_graph_cfg = self._stage_connector_extra_config(vllm_config) + prefix_graphs_requested = self._parse_bool_config(prefix_graph_cfg.get("code_predictor_prefix_graphs")) + is_npu = current_omni_platform.is_npu() + self._prefix_graphs_enabled = prefix_graphs_requested and wrapper_config.use_cuda_graphs and not is_npu + if prefix_graphs_requested and not self._prefix_graphs_enabled: + logger.info_once( + "code_predictor: prefix CUDA graphs requested but disabled because use_cuda_graphs=%s is_npu=%s", + wrapper_config.use_cuda_graphs, + is_npu, + ) + self._prefix_graph_buckets = self._parse_positive_int_set( + prefix_graph_cfg.get("code_predictor_prefix_graph_buckets") + ) + self._prefix_graph_seq_lens = self._parse_positive_int_set( + prefix_graph_cfg.get("code_predictor_prefix_graph_seq_lens") + ) def get_input_embeddings(self) -> nn.ModuleList: return self.model.get_input_embeddings() @@ -544,6 +560,50 @@ def _padded_bsz(self, bsz: int) -> int: return bucket return bsz + @staticmethod + def _stage_connector_extra_config(vllm_config: VllmConfig) -> dict: + model_cfg = getattr(vllm_config, "model_config", None) + connector_cfg = getattr(model_cfg, "stage_connector_config", None) + if isinstance(connector_cfg, dict): + extra_cfg = connector_cfg.get("extra", connector_cfg) + else: + extra_cfg = getattr(connector_cfg, "extra", None) + return extra_cfg if isinstance(extra_cfg, dict) else {} + + @staticmethod + def _parse_bool_config(value: object) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in ("1", "true", "yes", "on") + if isinstance(value, int): + return bool(value) + return False + + @staticmethod + def _parse_positive_int_set(value: object) -> set[int]: + if value is None: + return set() + if isinstance(value, str): + raw_values = [item.strip() for item in value.replace(";", ",").split(",") if item.strip()] + elif isinstance(value, int): + raw_values = [value] + else: + raw_values = list(value) + values: set[int] = set() + for item in raw_values: + value = int(item) + if value > 0: + values.add(value) + return values + + def _prefix_seq_lens(self, max_seq: int) -> list[int]: + all_seq_lens = list(range(2, max_seq)) + if not self._prefix_graph_seq_lens: + return all_seq_lens + allowed = set(all_seq_lens) + return sorted(seq_len for seq_len in self._prefix_graph_seq_lens if seq_len in allowed) + def _warmup_buckets(self) -> None: """Warmup power-of-2 batch-size buckets to front-load Inductor compilation.""" max_bsz = self._vllm_config.scheduler_config.max_num_seqs @@ -560,12 +620,44 @@ def _warmup_buckets(self) -> None: self._ensure_buffers(device, self._model_dtype, max(self._bucket_sizes)) proj_buf = self._proj_buf - for bsz in self._bucket_sizes: - pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous() - self._bucket_pos_ids[bsz] = pos_ids - for _ in range(3): - self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) - logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes) + if self._prefix_graphs_enabled: + prefix_seq_lens = self._prefix_seq_lens(max_seq) + needs_full_graph = set(prefix_seq_lens) != set(range(2, max_seq)) + for bsz in self._bucket_sizes: + capture_prefixes = not self._prefix_graph_buckets or bsz in self._prefix_graph_buckets + if not capture_prefixes or needs_full_graph: + pos_ids = ( + torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous() + ) + self._bucket_pos_ids[bsz] = pos_ids + for _ in range(3): + self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) + if capture_prefixes: + for seq_len in prefix_seq_lens: + pos_ids = ( + torch.arange(seq_len, device=device, dtype=torch.long) + .unsqueeze(0) + .expand(bsz, -1) + .contiguous() + ) + self._bucket_pos_ids[(bsz, seq_len)] = pos_ids + for _ in range(2): + self._compiled_model_fwd(proj_buf[:bsz, :seq_len, :], pos_ids) + logger.info( + "code_predictor: prefix warmup done for buckets %s prefix_buckets=%s seq_lens=%s", + self._bucket_sizes, + sorted(self._prefix_graph_buckets) if self._prefix_graph_buckets else "all", + prefix_seq_lens, + ) + else: + for bsz in self._bucket_sizes: + pos_ids = ( + torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous() + ) + self._bucket_pos_ids[bsz] = pos_ids + for _ in range(3): + self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) + logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes) def _capture_cuda_graphs(self) -> None: """Capture a CUDA graph per bucket using vLLM's global graph pool.""" @@ -575,17 +667,50 @@ def _capture_cuda_graphs(self) -> None: max_seq = self._num_groups + 1 proj_buf = self._proj_buf - for bsz in self._bucket_sizes: - static_input = proj_buf[:bsz, :max_seq, :] - pos_ids = self._bucket_pos_ids[bsz] + if self._prefix_graphs_enabled: + prefix_seq_lens = self._prefix_seq_lens(max_seq) + needs_full_graph = set(prefix_seq_lens) != set(range(2, max_seq)) + for bsz in self._bucket_sizes: + capture_prefixes = not self._prefix_graph_buckets or bsz in self._prefix_graph_buckets + if not capture_prefixes or needs_full_graph: + static_input = proj_buf[:bsz, :max_seq, :] + pos_ids = self._bucket_pos_ids[bsz] + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool=pool): + static_output = self._compiled_model_fwd(static_input, pos_ids) + + self._device_graphs[bsz] = (g, static_output) + + if capture_prefixes: + for seq_len in prefix_seq_lens: + static_input = proj_buf[:bsz, :seq_len, :] + pos_ids = self._bucket_pos_ids[(bsz, seq_len)] + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool=pool): + static_output = self._compiled_model_fwd(static_input, pos_ids) + + self._device_graphs[(bsz, seq_len)] = (g, static_output) + + logger.info( + "code_predictor: captured prefix CUDA graphs for buckets %s prefix_buckets=%s seq_lens=%s", + self._bucket_sizes, + sorted(self._prefix_graph_buckets) if self._prefix_graph_buckets else "all", + prefix_seq_lens, + ) + else: + for bsz in self._bucket_sizes: + static_input = proj_buf[:bsz, :max_seq, :] + pos_ids = self._bucket_pos_ids[bsz] - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, pool=pool): - static_output = self._compiled_model_fwd(static_input, pos_ids) + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool=pool): + static_output = self._compiled_model_fwd(static_input, pos_ids) - self._device_graphs[bsz] = (g, static_output) + self._device_graphs[bsz] = (g, static_output) - logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes) + logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes) def _capture_npu_graphs(self) -> None: """Capture an NPU graph per bucket using torch_npu's NPUGraph.""" @@ -648,16 +773,6 @@ def forward( proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) - # Get pre-computed pos_ids for this bucket - full_pos_ids = self._bucket_pos_ids.get(padded_bsz) - if full_pos_ids is None: - full_pos_ids = ( - torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1).contiguous() - ) - - # Use captured device graph if available, otherwise call compiled fn. - device_graph_entry = self._device_graphs.get(padded_bsz) - # Prepare sampling parameters stored_mode = self._wrapper_config.sampling_mode == "stored" if stored_mode: @@ -681,12 +796,31 @@ def forward( # Autoregressive loop: predict layers 1..G-1 for step in range(1, num_groups): + graph_key: int | tuple[int, int] = padded_bsz + seq_len = max_seq + if self._prefix_graphs_enabled: + prefix_key = (padded_bsz, step + 1) + if prefix_key in self._device_graphs: + graph_key = prefix_key + seq_len = step + 1 + pos_ids = self._bucket_pos_ids.get(graph_key) + if pos_ids is None: + pos_ids = ( + torch.arange(seq_len, device=device, dtype=torch.long) + .unsqueeze(0) + .expand(padded_bsz, -1) + .contiguous() + ) + + # Use captured device graph if available, otherwise call compiled fn. + device_graph_entry = self._device_graphs.get(graph_key) + # Run transformer (device graph replay or compiled forward) if device_graph_entry is not None: device_graph_entry[0].replay() hidden_out = device_graph_entry[1] else: - hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids) + hidden_out = model_fwd(proj_buf[:padded_bsz, :seq_len, :], pos_ids) logits = lm_heads[step - 1](hidden_out[:bsz, step, :]) diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index 9993784431b..a9cdc8948b9 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -7,6 +7,11 @@ reducing kernel launch overhead during inference. """ +import os +import time +from collections import Counter +from collections.abc import Callable, Sequence + import torch from torch.cuda import CUDAGraph from vllm.logger import init_logger @@ -15,6 +20,108 @@ logger = init_logger(__name__) +def _normalize_decode_lengths(lengths: Sequence[int], batch_size: int, max_len: int) -> tuple[int, ...]: + if len(lengths) != batch_size: + raise ValueError(f"Expected {batch_size} decode lengths, got {len(lengths)}") + + normalized: list[int] = [] + for length in lengths: + length_int = int(length) + if length_int < 0 or length_int > max_len: + raise ValueError(f"Invalid decode length {length_int}; expected 0 <= length <= {max_len}") + normalized.append(length_int) + return tuple(normalized) + + +def _batched_chunked_decode( + codes: torch.Tensor, + lengths: Sequence[int], + *, + decode_fn: Callable[[torch.Tensor], torch.Tensor], + total_upsample: int, + chunk_size: int = 300, + left_context_size: int = 25, + max_batch_size: int = 0, +) -> torch.Tensor: + """Decode a padded batch by grouping same-round chunks across requests.""" + if codes.dim() < 3: + raise ValueError(f"Expected codes with shape [B, Q, F], got {tuple(codes.shape)}") + if chunk_size <= 0: + raise ValueError(f"chunk_size must be positive, got {chunk_size}") + if left_context_size < 0: + raise ValueError(f"left_context_size must be non-negative, got {left_context_size}") + if max_batch_size < 0: + raise ValueError(f"max_batch_size must be non-negative, got {max_batch_size}") + + batch_size = int(codes.shape[0]) + max_input_len = int(codes.shape[-1]) + length_values = _normalize_decode_lengths(lengths, batch_size, max_input_len) + max_decode_len = max(length_values, default=0) + if max_decode_len == 0: + return torch.empty((batch_size, 1, 0), dtype=torch.float32, device=codes.device) + + total_upsample = int(total_upsample) + wav_out: torch.Tensor | None = None + num_rounds = (max_decode_len + chunk_size - 1) // chunk_size + + for round_index in range(num_rounds): + start_index = round_index * chunk_size + grouped_jobs: dict[int, list[tuple[int, int, int, int, int, int]]] = {} + for req_index, total_len in enumerate(length_values): + if start_index >= total_len: + continue + end_index = min(start_index + chunk_size, total_len) + context_size = left_context_size if start_index - left_context_size > 0 else start_index + input_start = start_index - context_size + input_end = end_index + chunk_len = input_end - input_start + grouped_jobs.setdefault(chunk_len, []).append( + (req_index, input_start, input_end, start_index, end_index, context_size) + ) + + for jobs in grouped_jobs.values(): + job_batches = ( + [jobs] + if max_batch_size <= 0 or len(jobs) <= max_batch_size + else [jobs[start : start + max_batch_size] for start in range(0, len(jobs), max_batch_size)] + ) + for job_batch in job_batches: + chunk_rows = [ + codes[req_index, :, input_start:input_end] + for req_index, input_start, input_end, _, _, _ in job_batch + ] + codes_chunk = torch.stack( + chunk_rows, + dim=0, + ) + wav_chunk = decode_fn(codes_chunk) + if wav_chunk.shape[0] != len(job_batch): + raise ValueError( + f"Decoder returned batch size {wav_chunk.shape[0]} for input batch size {len(job_batch)}" + ) + if wav_out is None: + wav_out = torch.empty( + (batch_size, *wav_chunk.shape[1:-1], max_decode_len * total_upsample), + dtype=wav_chunk.dtype, + device=wav_chunk.device, + ) + + for row, (req_index, _, _, chunk_start, chunk_end, context_size) in enumerate(job_batch): + src_start = context_size * total_upsample + dst_start = chunk_start * total_upsample + dst_end = chunk_end * total_upsample + src_end = src_start + (dst_end - dst_start) + if src_end > wav_chunk.shape[-1]: + raise ValueError( + f"Decoder returned too-short chunk output: need {src_end}, got {wav_chunk.shape[-1]}" + ) + wav_out[req_index, ..., dst_start:dst_end].copy_(wav_chunk[row, ..., src_start:src_end]) + + if wav_out is None: + return torch.empty((batch_size, 1, 0), dtype=torch.float32, device=codes.device) + return wav_out + + class CUDAGraphDecoderWrapper: """ CUDA Graph wrapper for Qwen3TTSTokenizerV2Decoder. @@ -34,21 +141,61 @@ def __init__( self, decoder: torch.nn.Module, capture_sizes: list[int] | None = None, + capture_batch_sizes: list[int] | None = None, + extra_capture_shapes: list[tuple[int, int]] | None = None, + compile_shapes: list[tuple[int, int]] | None = None, num_quantizers: int = 8, enabled: bool = True, ): self.decoder = decoder self._explicit_sizes = capture_sizes is not None self.capture_sizes = sorted(capture_sizes) if capture_sizes else [] + self.capture_batch_sizes = sorted(set(capture_batch_sizes or [1])) + self.extra_capture_shapes = sorted( + { + (int(batch_size), int(size)) + for batch_size, size in extra_capture_shapes or [] + if int(batch_size) > 0 and int(size) > 0 + } + ) + self.compile_shapes = sorted( + { + (int(batch_size), int(size)) + for batch_size, size in compile_shapes or [] + if int(batch_size) > 0 and int(size) > 0 + } + ) + self._bucket_sizes = self.capture_sizes self.num_quantizers = num_quantizers self.enabled = enabled - self.graphs: dict[int, CUDAGraph] = {} - self.static_inputs: dict[int, torch.Tensor] = {} - self.static_outputs: dict[int, torch.Tensor] = {} + self.graphs: dict[tuple[int, int], CUDAGraph] = {} + self.static_inputs: dict[tuple[int, int], torch.Tensor] = {} + self.static_outputs: dict[tuple[int, int], torch.Tensor] = {} + self._compiled_decoder: Callable[[torch.Tensor], torch.Tensor] | None = None + self._compiled_graphs: dict[tuple[int, int], CUDAGraph] = {} + self._compiled_static_inputs: dict[tuple[int, int], torch.Tensor] = {} + self._compiled_static_outputs: dict[tuple[int, int], torch.Tensor] = {} + self._compiled_shapes: set[tuple[int, int]] = set() self._warmed_up = False self._device = None + self._stats_enabled = os.environ.get("VLLM_OMNI_QWEN3_CODE2WAV_CUDAGRAPH_STATS", "").lower() in ( + "1", + "true", + "yes", + "on", + ) + self._stats_log_every = int(os.environ.get("VLLM_OMNI_QWEN3_CODE2WAV_CUDAGRAPH_STATS_LOG_EVERY", "0") or 0) + self._stats_total = 0 + self._stats_hits = 0 + self._stats_compiled_hits = 0 + self._stats_fallbacks = 0 + self._stats_stream_capture_fallbacks = 0 + self._stats_requests: Counter[tuple[int, int]] = Counter() + self._stats_hit_shapes: Counter[tuple[int, int, int]] = Counter() + self._stats_compiled_shapes: Counter[tuple[int, int]] = Counter() + self._stats_fallback_shapes: Counter[tuple[int, int, int]] = Counter() @staticmethod def compute_capture_sizes( @@ -78,11 +225,68 @@ def compute_capture_sizes( return sorted(sizes) def _get_padded_size(self, actual_size: int) -> int | None: - for size in self.capture_sizes: + for size in self._bucket_sizes: if actual_size <= size: return size return None + def _get_capture_shapes(self) -> list[tuple[int, int]]: + shapes = {(batch_size, size) for batch_size in self.capture_batch_sizes for size in self.capture_sizes} + shapes.update(self.extra_capture_shapes) + return sorted(shapes) + + def _record_decode_stats( + self, + *, + hit: bool, + batch_size: int, + actual_size: int, + padded_size: int | None, + stream_capture: bool = False, + compiled: bool = False, + ) -> None: + if not self._stats_enabled: + return + + padded_key = int(padded_size) if padded_size is not None else -1 + self._stats_total += 1 + self._stats_requests[(batch_size, actual_size)] += 1 + if hit: + self._stats_hits += 1 + if compiled: + self._stats_compiled_hits += 1 + self._stats_compiled_shapes[(batch_size, actual_size)] += 1 + else: + self._stats_hit_shapes[(batch_size, actual_size, padded_key)] += 1 + else: + self._stats_fallbacks += 1 + self._stats_fallback_shapes[(batch_size, actual_size, padded_key)] += 1 + if stream_capture: + self._stats_stream_capture_fallbacks += 1 + + if self._stats_log_every > 0 and self._stats_total % self._stats_log_every == 0: + self.log_decode_stats() + + def log_decode_stats(self) -> None: + if not self._stats_enabled or self._stats_total == 0: + return + hit_rate = 100.0 * self._stats_hits / self._stats_total + logger.info( + "Code2Wav CUDA Graph stats: total=%d hits=%d fallbacks=%d " + "compiled_hits=%d stream_capture_fallbacks=%d hit_rate=%.2f%% " + "top_requests=%s top_compiled=%s top_hits=%s top_fallbacks=%s", + self._stats_total, + self._stats_hits, + self._stats_fallbacks, + self._stats_compiled_hits, + self._stats_stream_capture_fallbacks, + hit_rate, + self._stats_requests.most_common(12), + self._stats_compiled_shapes.most_common(12), + self._stats_hit_shapes.most_common(12), + self._stats_fallback_shapes.most_common(12), + ) + def warmup( self, device: torch.device, @@ -106,28 +310,97 @@ def warmup( decode_left_context=decode_left_context, ) - logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) + self.capture_batch_sizes = [bs for bs in self.capture_batch_sizes if bs > 0] + if not self.capture_batch_sizes: + self.capture_batch_sizes = [1] + + self._bucket_sizes = sorted(set(self.capture_sizes) | {size for _, size in self.extra_capture_shapes}) + capture_shapes = self._get_capture_shapes() + + logger.info( + "Starting CUDA Graph warmup for %d shapes: batch_sizes=%s seq_lens=%s extra_shapes=%s", + len(capture_shapes), + self.capture_batch_sizes, + self.capture_sizes, + self.extra_capture_shapes, + ) + warmup_start_s = time.perf_counter() + mem_before = self._get_cuda_memory_stats(device) # Warmup runs to ensure CUDA memory is allocated - for size in self.capture_sizes: - dummy = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) + for batch_size, size in capture_shapes: + dummy = torch.zeros(batch_size, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): _ = self.decoder(dummy) torch.accelerator.synchronize(device) - for size in self.capture_sizes: + for batch_size, size in capture_shapes: try: - self._capture(size, device, dtype) - logger.info(" Captured CUDA Graph for size=%d", size) + self._capture(batch_size, size, device, dtype) + logger.info(" Captured CUDA Graph for batch=%d size=%d", batch_size, size) except Exception: - logger.warning(" Failed to capture graph for size=%d", size, exc_info=True) + logger.warning(" Failed to capture graph for batch=%d size=%d", batch_size, size, exc_info=True) + + if self.compile_shapes: + self._warmup_compile_shapes(device, dtype) self._warmed_up = True - logger.info("CUDA Graph warmup complete: %d/%d captured", len(self.graphs), len(self.capture_sizes)) + warmup_ms = (time.perf_counter() - warmup_start_s) * 1000.0 + mem_after = self._get_cuda_memory_stats(device) + logger.info( + "CUDA Graph warmup complete: %d/%d captured in %.1f ms%s", + len(self.graphs), + len(capture_shapes), + warmup_ms, + self._format_cuda_memory_delta(mem_before, mem_after), + ) + + def _warmup_compile_shapes(self, device: torch.device, dtype: torch.dtype) -> None: + logger.info("Starting torch.compile + CUDA Graph warmup for decoder shapes: %s", self.compile_shapes) + compile_start_s = time.perf_counter() + try: + self._compiled_decoder = torch.compile( + self.decoder.forward, + mode="default", + fullgraph=False, + dynamic=False, + ) + except Exception: + logger.warning("Failed to create torch.compile decoder wrapper", exc_info=True) + self._compiled_decoder = None + return - def _capture(self, size: int, device: torch.device, dtype: torch.dtype): - static_input = torch.zeros(1, self.num_quantizers, size, dtype=dtype, device=device) + assert self._compiled_decoder is not None + for batch_size, size in self.compile_shapes: + shape_start_s = time.perf_counter() + try: + self._capture_compiled(batch_size, size, device, dtype) + self._compiled_shapes.add((batch_size, size)) + logger.info( + " torch.compile + CUDA Graph ready for batch=%d size=%d in %.1f ms", + batch_size, + size, + (time.perf_counter() - shape_start_s) * 1000.0, + ) + except Exception: + logger.warning( + " Failed to capture torch.compile CUDA Graph for batch=%d size=%d; " + "this shape will use CUDA Graph/eager fallback", + batch_size, + size, + exc_info=True, + ) + logger.info( + "torch.compile + CUDA Graph warmup complete: %d/%d shapes ready in %.1f ms", + len(self._compiled_shapes), + len(self.compile_shapes), + (time.perf_counter() - compile_start_s) * 1000.0, + ) + + def _capture(self, batch_size: int, size: int, device: torch.device, dtype: torch.dtype): + key = (batch_size, size) + static_input = torch.zeros(batch_size, self.num_quantizers, size, dtype=dtype, device=device) with torch.no_grad(): _ = self.decoder(static_input) torch.accelerator.synchronize(device) @@ -137,12 +410,64 @@ def _capture(self, size: int, device: torch.device, dtype: torch.dtype): with torch.cuda.graph(graph, pool=current_platform.get_global_graph_pool()): static_output = self.decoder(static_input) - self.graphs[size] = graph - self.static_inputs[size] = static_input - self.static_outputs[size] = static_output + self.graphs[key] = graph + self.static_inputs[key] = static_input + self.static_outputs[key] = static_output - def decode(self, codes: torch.Tensor) -> torch.Tensor: - if not self.enabled or not self._warmed_up or codes.shape[0] != 1: + def _capture_compiled(self, batch_size: int, size: int, device: torch.device, dtype: torch.dtype): + if self._compiled_decoder is None: + raise RuntimeError("Compiled decoder is not initialized") + + key = (batch_size, size) + static_input = torch.zeros(batch_size, self.num_quantizers, size, dtype=dtype, device=device) + with torch.inference_mode(): + for _ in range(5): + _ = self._compiled_decoder(static_input) + torch.accelerator.synchronize(device) + + graph = CUDAGraph() + with torch.inference_mode(): + with torch.cuda.graph(graph, pool=current_platform.get_global_graph_pool()): + static_output = self._compiled_decoder(static_input) + + self._compiled_graphs[key] = graph + self._compiled_static_inputs[key] = static_input + self._compiled_static_outputs[key] = static_output + + @staticmethod + def _get_cuda_memory_stats(device: torch.device) -> tuple[int, int, int] | None: + if device.type != "cuda": + return None + try: + return ( + int(torch.cuda.memory_allocated(device)), + int(torch.cuda.memory_reserved(device)), + int(torch.cuda.max_memory_reserved(device)), + ) + except Exception: + return None + + @staticmethod + def _format_cuda_memory_delta( + before: tuple[int, int, int] | None, + after: tuple[int, int, int] | None, + ) -> str: + if before is None or after is None: + return "" + + def gib(value: int) -> float: + return value / 1024**3 + + alloc_before, reserved_before, max_reserved_before = before + alloc_after, reserved_after, max_reserved_after = after + return ( + f" (cuda_mem allocated {gib(alloc_before):.2f}->{gib(alloc_after):.2f} GiB, " + f"reserved {gib(reserved_before):.2f}->{gib(reserved_after):.2f} GiB, " + f"max_reserved {gib(max_reserved_before):.2f}->{gib(max_reserved_after):.2f} GiB)" + ) + + def _decode(self, codes: torch.Tensor, *, clone_graph_output: bool) -> torch.Tensor: + if not self.enabled or not self._warmed_up: return self.decoder(codes) # Inner CUDA graph replay is illegal while an outer stream capture is @@ -152,20 +477,76 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor: # outside the startup capture window, so normal inference still hits # the graph fast path. if torch.cuda.is_current_stream_capturing(): + self._record_decode_stats( + hit=False, + batch_size=int(codes.shape[0]), + actual_size=int(codes.shape[-1]), + padded_size=None, + stream_capture=True, + ) return self.decoder(codes) - actual_size = codes.shape[-1] + batch_size = int(codes.shape[0]) + actual_size = int(codes.shape[-1]) padded_size = self._get_padded_size(actual_size) - - if padded_size is None or padded_size not in self.graphs: + compile_key = (batch_size, actual_size) + if compile_key not in self._compiled_shapes and padded_size is not None: + compile_key = (batch_size, padded_size) + if compile_key in self._compiled_shapes: + compiled_size = compile_key[1] + self._record_decode_stats( + hit=True, + batch_size=batch_size, + actual_size=actual_size, + padded_size=compiled_size, + compiled=True, + ) + static_input = self._compiled_static_inputs[compile_key] + if actual_size == compiled_size: + static_input.copy_(codes) + else: + static_input.zero_() + static_input[:, :, :actual_size] = codes + self._compiled_graphs[compile_key].replay() + actual_out_len = actual_size * self.decoder.total_upsample + output = self._compiled_static_outputs[compile_key][..., :actual_out_len] + if clone_graph_output: + return output.clone() + return output + + graph_key = (batch_size, padded_size) if padded_size is not None else None + + if graph_key is None or graph_key not in self.graphs: + self._record_decode_stats( + hit=False, + batch_size=batch_size, + actual_size=actual_size, + padded_size=padded_size, + ) return self.decoder(codes) - self.static_inputs[padded_size].zero_() - self.static_inputs[padded_size][:, :, :actual_size] = codes - self.graphs[padded_size].replay() + self._record_decode_stats( + hit=True, + batch_size=batch_size, + actual_size=actual_size, + padded_size=padded_size, + ) + static_input = self.static_inputs[graph_key] + if actual_size == padded_size: + static_input.copy_(codes) + else: + static_input.zero_() + static_input[:, :, :actual_size] = codes + self.graphs[graph_key].replay() actual_out_len = actual_size * self.decoder.total_upsample - return self.static_outputs[padded_size][..., :actual_out_len].clone() + output = self.static_outputs[graph_key][..., :actual_out_len] + if clone_graph_output: + return output.clone() + return output + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + return self._decode(codes, clone_graph_output=True) def chunked_decode_with_cudagraph( self, @@ -173,19 +554,48 @@ def chunked_decode_with_cudagraph( chunk_size: int = 300, left_context_size: int = 25, ) -> torch.Tensor: - wavs = [] start_index = 0 total_len = codes.shape[-1] total_upsample = self.decoder.total_upsample + wav_out = None while start_index < total_len: end_index = min(start_index + chunk_size, total_len) context_size = left_context_size if start_index - left_context_size > 0 else start_index codes_chunk = codes[..., start_index - context_size : end_index] - wav_chunk = self.decode(codes_chunk) - - wavs.append(wav_chunk[..., context_size * total_upsample :]) + wav_chunk = self._decode(codes_chunk, clone_graph_output=False) + + if wav_out is None: + wav_out = torch.empty( + (*wav_chunk.shape[:-1], total_len * total_upsample), + dtype=wav_chunk.dtype, + device=wav_chunk.device, + ) + src_start = context_size * total_upsample + dst_start = start_index * total_upsample + dst_end = end_index * total_upsample + wav_out[..., dst_start:dst_end].copy_(wav_chunk[..., src_start:]) start_index = end_index - return torch.cat(wavs, dim=-1) + if wav_out is None: + return self.decoder(codes) + return wav_out + + def batched_chunked_decode_with_cudagraph( + self, + codes: torch.Tensor, + lengths: Sequence[int], + chunk_size: int = 300, + left_context_size: int = 25, + max_batch_size: int = 0, + ) -> torch.Tensor: + return _batched_chunked_decode( + codes, + lengths, + decode_fn=lambda codes_chunk: self._decode(codes_chunk, clone_graph_output=False), + total_upsample=self.decoder.total_upsample, + chunk_size=chunk_size, + left_context_size=left_context_size, + max_batch_size=max_batch_size, + ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index fb4721ca25b..8c200141aac 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +from collections import Counter from collections.abc import Iterable from typing import Any @@ -43,7 +45,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._decode_chunk_frames = 300 self._decode_left_context_frames = 25 + self._decode_batch_bucket_frames: list[int] = [] + self._decode_batch_max_size = 0 + self._decode_variable_chunk_batch_min_frames = self._decode_chunk_frames + self._decode_left_context_frames + 1 self._logged_codec_stats = False + self._logged_malformed_codec_lengths: set[tuple[int, int]] = set() + self._batch_stats_enabled = os.environ.get("VLLM_OMNI_QWEN3_CODE2WAV_BATCH_STATS", "").lower() in ( + "1", + "true", + "yes", + "on", + ) + self._batch_stats_log_every = int(os.environ.get("VLLM_OMNI_QWEN3_CODE2WAV_BATCH_STATS_LOG_EVERY", "0") or 0) + self._batch_stats_forwards = 0 + self._batch_stats_groups = 0 + self._batch_stats_requests = 0 + self._batch_stats_padded_frames = 0 + self._batch_stats_decoded_frames = 0 + self._batch_stats_actual_frames: Counter[int] = Counter() + self._batch_stats_bucket_groups: Counter[tuple[int, int]] = Counter() # Construct decoder from config so it is visible to vLLM's # memory profiler at startup. Weights are loaded later in @@ -91,6 +111,50 @@ def _split_request_ids(self, ids: torch.Tensor, seq_token_counts: list[int] | No return [ids[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1)] return [ids] + def _get_decode_batch_bucket_frames(self, actual_frames: int) -> int: + for bucket_frames in self._decode_batch_bucket_frames: + if actual_frames <= bucket_frames: + return bucket_frames + return actual_frames + + def _record_decode_batch_stats( + self, + *, + group_size: int, + bucket_frames: int, + actual_frames: list[int], + ) -> None: + if not self._batch_stats_enabled: + return + + self._batch_stats_groups += 1 + self._batch_stats_requests += group_size + self._batch_stats_decoded_frames += group_size * bucket_frames + self._batch_stats_padded_frames += sum(bucket_frames - frames for frames in actual_frames) + self._batch_stats_actual_frames.update(actual_frames) + self._batch_stats_bucket_groups[(group_size, bucket_frames)] += 1 + + def log_decode_batch_stats(self) -> None: + if not self._batch_stats_enabled or self._batch_stats_requests == 0: + return + + avg_group_size = self._batch_stats_requests / max(1, self._batch_stats_groups) + pad_ratio = self._batch_stats_padded_frames / max(1, self._batch_stats_decoded_frames) + logger.info( + "Code2Wav batch stats: forwards=%d groups=%d requests=%d " + "avg_group_size=%.2f padded_frames=%d decoded_frames=%d pad_ratio=%.2f%% " + "top_actual_frames=%s top_bucket_groups=%s", + self._batch_stats_forwards, + self._batch_stats_groups, + self._batch_stats_requests, + avg_group_size, + self._batch_stats_padded_frames, + self._batch_stats_decoded_frames, + 100.0 * pad_ratio, + self._batch_stats_actual_frames.most_common(12), + self._batch_stats_bucket_groups.most_common(12), + ) + @torch.no_grad() def forward( self, @@ -111,6 +175,7 @@ def forward( Length management is done here instead of relying on HF's padding=-1 sentinel logic. """ + self._batch_stats_forwards += 1 decoder = self.decoder q = int(self._num_quantizers) upsample = int(self._total_upsample) @@ -147,11 +212,15 @@ def forward( n = flat.numel() if n == 0 or n % q != 0: if n > 0: - logger.warning( - "Code2Wav input_ids length %d not divisible by num_quantizers %d; skipping malformed request.", - n, - q, - ) + key = (int(n), q) + if key not in self._logged_malformed_codec_lengths: + self._logged_malformed_codec_lengths.add(key) + logger.warning( + "Code2Wav input_ids length %d not divisible by num_quantizers %d; " + "skipping malformed request and suppressing repeats for this length.", + n, + q, + ) parsed.append((0, 0)) continue frames = n // q @@ -187,22 +256,95 @@ def forward( except Exception: pass - # Decode directly via decoder.chunked_decode(), staying entirely on GPU. - # Each request decoded individually with CUDA graph replay at bs=1. - wav_tensors: list[torch.Tensor] = [] - for codes_qf in valid_codes_qf: - codes_bqf = codes_qf.unsqueeze(0) # [1, Q, F] - try: - wav = decoder.chunked_decode( - codes_bqf, - chunk_size=self._decode_chunk_frames, - left_context_size=self._decode_left_context_frames, - ) # [1, 1, wav_len] - except TypeError: - # Unit-test fakes and older decoder shims may not accept the - # explicit chunk kwargs; production Qwen3-TTS decoders do. - wav = decoder.chunked_decode(codes_bqf) # [1, 1, wav_len] - wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len] + wav_tensors: list[torch.Tensor | None] = [None] * len(valid_codes_qf) + + def _decode_group_chunks(group_chunks: list[list[tuple[int, torch.Tensor]]]) -> None: + for group_chunk in group_chunks: + actual_frames = [int(codes_qf.shape[1]) for _, codes_qf in group_chunk] + target_frames = max(actual_frames) + is_equal_length_batch = all(frames == target_frames for frames in actual_frames) + use_variable_length_batch = ( + len(group_chunk) > 1 + and not is_equal_length_batch + and target_frames >= self._decode_variable_chunk_batch_min_frames + and hasattr(decoder, "batched_chunked_decode") + ) + if len(group_chunk) == 1: + codes_bqf = group_chunk[0][1].unsqueeze(0) + elif is_equal_length_batch: + codes_bqf = torch.stack([codes_qf for _, codes_qf in group_chunk], dim=0) + else: + first = group_chunk[0][1] + codes_bqf = first.new_zeros((len(group_chunk), q, target_frames)) + for row, (_, codes_qf) in enumerate(group_chunk): + codes_bqf[row, :, : codes_qf.shape[1]] = codes_qf + self._record_decode_batch_stats( + group_size=len(group_chunk), + bucket_frames=target_frames, + actual_frames=actual_frames, + ) + try: + if use_variable_length_batch: + wav_batch = decoder.batched_chunked_decode( + codes_bqf, + actual_frames, + chunk_size=self._decode_chunk_frames, + left_context_size=self._decode_left_context_frames, + max_batch_size=self._decode_batch_max_size, + ) # [B, 1, wav_len] + else: + wav_batch = decoder.chunked_decode( + codes_bqf, + chunk_size=self._decode_chunk_frames, + left_context_size=self._decode_left_context_frames, + ) # [B, 1, wav_len] + except TypeError: + # Unit-test fakes and older decoder shims may not accept the + # explicit chunk kwargs; production Qwen3-TTS decoders do. + wav_batch = decoder.chunked_decode(codes_bqf) # [B, 1, wav_len] + + if wav_batch.dim() == 3 and wav_batch.shape[1] == 1: + wav_rows = wav_batch[:, 0, :] + elif wav_batch.dim() == 2: + wav_rows = wav_batch + else: + raise ValueError( + "Code2Wav decoder returned unexpected shape " + f"{tuple(wav_batch.shape)} for batch size {len(group_chunk)}" + ) + if wav_rows.shape[0] != len(group_chunk): + raise ValueError( + f"Code2Wav decoder returned batch size {wav_rows.shape[0]} " + f"for input batch size {len(group_chunk)}" + ) + for row, (j, _) in enumerate(group_chunk): + wav_tensors[j] = wav_rows[row] + + # Group by configured frame buckets instead of only exact lengths. + # For ordinary async streaming windows this is the real batching + # opportunity; decoder-internal variable chunk batching is gated to + # longer inputs where repeated full chunks can amortize its overhead. + grouped_codes: dict[int, list[tuple[int, torch.Tensor]]] = {} + for j, codes_qf in enumerate(valid_codes_qf): + frames = int(codes_qf.shape[1]) + grouped_codes.setdefault(self._get_decode_batch_bucket_frames(frames), []).append((j, codes_qf)) + + for _bucket_frames, group in grouped_codes.items(): + if self._decode_batch_max_size > 0 and len(group) > self._decode_batch_max_size: + # Keep each decoder call inside the configured CUDA graph batch + # envelope. Sorting by length lowers right-padding within each + # split while outputs are restored by original request index. + group = sorted(group, key=lambda item: int(item[1].shape[1])) + group_chunks = [ + group[start : start + self._decode_batch_max_size] + for start in range(0, len(group), self._decode_batch_max_size) + ] + else: + group_chunks = [group] + _decode_group_chunks(group_chunks) + + if self._batch_stats_log_every > 0 and self._batch_stats_forwards % self._batch_stats_log_every == 0: + self.log_decode_batch_stats() audios: list[torch.Tensor] = [empty] * num_req srs = [sr_tensor] * num_req @@ -210,6 +352,7 @@ def forward( for j, idx in enumerate(valid_indices): ctx_frames, actual_frames = parsed[idx] wav = wav_tensors[j] + assert wav is not None # Slice on exact codec-frame boundaries instead of proportionally. start = max(0, ctx_frames * upsample) end = max(start, actual_frames * upsample) @@ -302,6 +445,80 @@ def _get_int_config(name: str, default: int) -> int: except (TypeError, ValueError) as exc: raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + def _get_bool_config(name: str, default: bool) -> bool: + value = extra_cfg.get(name, default) + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in ("1", "true", "yes", "on"): + return True + if lowered in ("0", "false", "no", "off"): + return False + if isinstance(value, int): + return bool(value) + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") + + def _get_int_list_config(name: str) -> list[int] | None: + value = extra_cfg.get(name) + if value is None: + return None + if isinstance(value, str): + raw_values = [item.strip() for item in value.split(",") if item.strip()] + elif isinstance(value, int): + raw_values = [value] + else: + try: + raw_values = list(value) + except TypeError as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + values: set[int] = set() + for item in raw_values: + try: + parsed = int(item) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + if parsed > 0: + values.add(parsed) + return sorted(values) + + def _get_int_pair_list_config(name: str) -> list[tuple[int, int]] | None: + value = extra_cfg.get(name) + if value is None: + return None + if isinstance(value, str): + raw_values = [item.strip() for item in value.split(",") if item.strip()] + else: + try: + raw_values = list(value) + except TypeError as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + + pairs: set[tuple[int, int]] = set() + for item in raw_values: + if isinstance(item, str): + if ":" not in item: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") + left, right = item.split(":", 1) + raw_pair = (left.strip(), right.strip()) + else: + try: + raw_pair = tuple(item) + except TypeError as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + if len(raw_pair) != 2: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") + try: + batch_size = int(raw_pair[0]) + seq_len = int(raw_pair[1]) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + if batch_size > 0 and seq_len > 0: + pairs.add((batch_size, seq_len)) + return sorted(pairs) + if isinstance(extra_cfg, dict): codec_chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) codec_left_context_frames = int(extra_cfg.get("codec_left_context_frames") or 0) @@ -318,6 +535,49 @@ def _get_int_config(name: str, default: int) -> int: ) self._decode_chunk_frames = decode_chunk_frames self._decode_left_context_frames = decode_left_context_frames + decode_cudagraph_capture_sizes = _get_int_list_config("decode_cudagraph_capture_sizes") + decode_cudagraph_batch_sizes = _get_int_list_config("decode_cudagraph_batch_sizes") + decode_cudagraph_extra_capture_shapes = _get_int_pair_list_config("decode_cudagraph_extra_capture_shapes") + decode_compile_shapes = _get_int_pair_list_config("decode_compile_shapes") + decode_batch_bucket_frames = _get_int_list_config("decode_batch_bucket_frames") + if decode_batch_bucket_frames is not None: + self._decode_batch_bucket_frames = decode_batch_bucket_frames + decode_batch_max_size = _get_int_config("decode_batch_max_size", self._decode_batch_max_size) + if decode_batch_max_size < 0: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config decode_batch_max_size={decode_batch_max_size}") + self._decode_batch_max_size = decode_batch_max_size + decode_variable_chunk_batch_min_frames = _get_int_config( + "decode_variable_chunk_batch_min_frames", + self._decode_variable_chunk_batch_min_frames, + ) + if decode_variable_chunk_batch_min_frames < 0: + raise ValueError( + "Invalid Qwen3-TTS Code2Wav config " + f"decode_variable_chunk_batch_min_frames={decode_variable_chunk_batch_min_frames}" + ) + self._decode_variable_chunk_batch_min_frames = decode_variable_chunk_batch_min_frames + decode_enable_tf32 = _get_bool_config("decode_enable_tf32", False) + else: + decode_cudagraph_capture_sizes = None + decode_cudagraph_batch_sizes = None + decode_cudagraph_extra_capture_shapes = None + decode_compile_shapes = None + decode_enable_tf32 = False + + if decode_enable_tf32 and device.type == "cuda": + # PyTorch exposes TF32 controls as process-wide CUDA backend + # switches. This opt-in is intended for deployments where + # Code2Wav runs in its own Stage1 worker process. + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.set_float32_matmul_precision("high") + logger.info( + "Qwen3-TTS Code2Wav TF32 enabled process-wide: " + "matmul.allow_tf32=%s cudnn.allow_tf32=%s float32_matmul_precision=%s", + torch.backends.cuda.matmul.allow_tf32, + torch.backends.cudnn.allow_tf32, + torch.get_float32_matmul_precision(), + ) if hasattr(self.decoder, "enable_cudagraph") and device.type == "cuda": try: @@ -339,6 +599,10 @@ def _get_int_config(name: str, default: int) -> int: ) self.decoder.enable_cudagraph( + capture_sizes=decode_cudagraph_capture_sizes, + capture_batch_sizes=decode_cudagraph_batch_sizes, + extra_capture_shapes=decode_cudagraph_extra_capture_shapes, + compile_shapes=decode_compile_shapes, device=device, codec_chunk_frames=codec_chunk_frames, codec_left_context_frames=codec_left_context_frames, diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index bdf702631b1..a705a1cd48d 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -265,6 +265,8 @@ def mel_spectrogram( fmin: int, fmax: int | None = None, center: bool = False, + mel_basis: torch.Tensor | None = None, + hann_window: torch.Tensor | None = None, ) -> torch.Tensor: """Calculate mel spectrogram of an input signal using torchaudio mel filterbank and torch STFT.""" if torch.min(y) < -1.0: @@ -272,14 +274,20 @@ def mel_spectrogram( if torch.max(y) > 1.0: logger.warning("Max value of input waveform signal is %s", torch.max(y)) device = y.device - mel_basis = mel_filter_bank( - sr=sampling_rate, - n_fft=n_fft, - n_mels=num_mels, - fmin=fmin, - fmax=fmax, - ).to(device) - hann_window = torch.hann_window(win_size).to(device) + if mel_basis is None: + mel_basis = mel_filter_bank( + sr=sampling_rate, + n_fft=n_fft, + n_mels=num_mels, + fmin=fmin, + fmax=fmax, + ).to(device) + elif mel_basis.device != device: + mel_basis = mel_basis.to(device) + if hann_window is None: + hann_window = torch.hann_window(win_size, device=device) + elif hann_window.device != device: + hann_window = hann_window.to(device) padding = (n_fft - hop_size) // 2 y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) spec = torch.stft( @@ -733,6 +741,117 @@ def preprocess( info_update["hidden_states"] = {"trailing_text": trailing_text_update.detach()} return input_ids, inputs_embeds_out, info_update + def preprocess_decode_batch( + self, + *, + input_ids: torch.Tensor, + req_infos: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]: + """Batch the decode-only preprocess path for Qwen3-TTS. + + This mirrors the scalar decode branch in ``preprocess()``, but performs + the token embedding lookup once for the whole decode batch. + """ + input_ids_flat = input_ids.reshape(-1) + if int(input_ids_flat.numel()) != len(req_infos): + raise ValueError( + f"preprocess_decode_batch expected {len(req_infos)} input ids, got {int(input_ids_flat.numel())}" + ) + + device = input_ids_flat.device + dtype = torch.bfloat16 + past_hidden_list: list[torch.Tensor] = [] + text_step_list: list[torch.Tensor] = [] + updates: list[dict[str, Any]] = [] + + for info_dict in req_infos: + additional_information = info_dict.get("additional_information") + if isinstance(additional_information, dict): + merged: dict[str, Any] = {k: v for k, v in info_dict.items() if k != "additional_information"} + for k, v in additional_information.items(): + merged.setdefault(k, v) + info_dict = merged + + payload: OmniPayload = info_dict + embed = payload.get("embed", {}) + hs = payload.get("hidden_states", {}) + meta = payload.get("meta", {}) + + text_list = info_dict.get("text") + if not isinstance(text_list, list) or not text_list or not text_list[0]: + raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.") + + task_type = (info_dict.get("task_type") or ["CustomVoice"])[0] + codec_streaming_val = meta.get("codec_streaming") + if isinstance(codec_streaming_val, list): + codec_streaming_raw = codec_streaming_val[0] if codec_streaming_val else None + else: + codec_streaming_raw = codec_streaming_val + if isinstance(codec_streaming_raw, bool): + codec_streaming = codec_streaming_raw + else: + codec_streaming = task_type == "Base" + + tts_pad_embed_buf = embed.get("tts_pad") + if not isinstance(tts_pad_embed_buf, torch.Tensor): + raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") + tts_pad_embed = tts_pad_embed_buf.to(device=device, dtype=dtype).reshape(1, -1) + + tail = hs.get("trailing_text") + text_offset = max(0, int(meta.get("talker_text_offset", 0) or 0)) + trailing_text_update = None + if isinstance(tail, torch.Tensor) and tail.ndim == 2: + tail_len = int(tail.shape[0]) + if text_offset < tail_len: + text_step = tail[text_offset : text_offset + 1].to(device=device, dtype=dtype).reshape(1, -1) + next_text_offset = text_offset + 1 + should_compact_tail = next_text_offset >= tail_len or ( + next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len + ) + if should_compact_tail: + if next_text_offset >= tail_len: + trailing_text_update = torch.empty((0, tail.shape[1]), device=tail.device, dtype=tail.dtype) + else: + trailing_text_update = tail[next_text_offset:].contiguous() + next_text_offset = 0 + else: + text_step = tts_pad_embed + next_text_offset = 0 + if tail.numel() > 0: + trailing_text_update = torch.empty((0, tail.shape[1]), device=tail.device, dtype=tail.dtype) + else: + text_step = tts_pad_embed + next_text_offset = text_offset + + last_hidden = hs.get("last") + if not isinstance(last_hidden, torch.Tensor): + raise RuntimeError("Missing hidden_states['last'] in additional_information; postprocess must run.") + past_hidden_list.append(last_hidden.to(device=device, dtype=dtype).reshape(1, -1)) + text_step_list.append(text_step) + + info_update: dict[str, Any] = { + "meta": { + "talker_text_offset": int(next_text_offset), + "codec_streaming": codec_streaming, + }, + } + if trailing_text_update is not None: + info_update["hidden_states"] = {"trailing_text": trailing_text_update.detach()} + updates.append(info_update) + + inputs_embeds_out = self.embed_input_ids(input_ids_flat.reshape(-1, 1).to(torch.long)).to( + device=device, + dtype=dtype, + ) + inputs_embeds_out = inputs_embeds_out.reshape(len(req_infos), -1) + return ( + input_ids_flat, + inputs_embeds_out, + torch.cat(past_hidden_list, dim=0), + torch.cat(text_step_list, dim=0), + updates, + ) + def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: # Keep the last token hidden for the next decode step's code predictor. # Stays on GPU - gpu_resident_buffer_keys avoids the CPU round-trip. @@ -1171,9 +1290,21 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr)) sr = target_sr - # Follow official implementation: mel_spectrogram expects 24kHz. + # Follow official implementation: mel_spectrogram expects 24kHz. Move + # the waveform first so STFT/mel computation stays on the model device + # instead of materializing a CPU mel tensor and copying it per request. + wav_tensor = torch.from_numpy(wav).to(device=dev, dtype=torch.float32).unsqueeze(0) + mel_basis, hann_window = self._get_speaker_mel_buffers( + device=dev, + sampling_rate=24000, + n_fft=1024, + num_mels=128, + win_size=1024, + fmin=0, + fmax=12000, + ) mels = mel_spectrogram( - torch.from_numpy(wav).unsqueeze(0), + wav_tensor, n_fft=1024, num_mels=128, sampling_rate=24000, @@ -1181,10 +1312,43 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: win_size=1024, fmin=0, fmax=12000, + mel_basis=mel_basis, + hann_window=hann_window, ).transpose(1, 2) - spk = self.speaker_encoder(mels.to(dev, dtype=torch.bfloat16))[0] + spk = self.speaker_encoder(mels.to(dtype=torch.bfloat16))[0] return spk.to(dtype=torch.bfloat16) + def _get_speaker_mel_buffers( + self, + *, + device: torch.device, + sampling_rate: int, + n_fft: int, + num_mels: int, + win_size: int, + fmin: int, + fmax: int | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cache = getattr(self, "_speaker_mel_buffer_cache", None) + if cache is None: + cache = {} + self._speaker_mel_buffer_cache = cache + key = (str(device), int(sampling_rate), int(n_fft), int(num_mels), int(win_size), int(fmin), fmax) + cached = cache.get(key) + if cached is not None: + return cached + mel_basis = mel_filter_bank( + sr=int(sampling_rate), + n_fft=int(n_fft), + n_mels=int(num_mels), + fmin=int(fmin), + fmax=fmax, + ).to(device) + hann_window = torch.hann_window(int(win_size), device=device) + cached = (mel_basis, hann_window) + cache[key] = cached + return cached + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: if self._speech_tokenizer is not None: return self._speech_tokenizer @@ -1680,9 +1844,7 @@ def _get_ref_audio() -> tuple[np.ndarray, int]: ref_code_len = int(ref_code_t.shape[0]) else: codes = info_dict.get("codes") - precomputed_ref_code = ( - codes.get(_PRECOMPUTED_REF_CODE_KEY) if isinstance(codes, dict) else None - ) + precomputed_ref_code = codes.get(_PRECOMPUTED_REF_CODE_KEY) if isinstance(codes, dict) else None ref_code_t = self._coerce_ref_code_tensor(precomputed_ref_code, device=input_ids.device) if isinstance(ref_code_t, torch.Tensor): ref_code_len = int(ref_code_t.shape[0]) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index e71dbc091e8..d11d3889285 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -534,6 +534,11 @@ def forward( cache_position=None, **kwargs, ) -> BaseModelOutputWithPast: + """ + Args: + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ if input_ids is not None: raise ValueError("input_ids is not expected") if (input_ids is None) ^ (inputs_embeds is not None): @@ -848,6 +853,9 @@ def precompute_snake_caches(self): def enable_cudagraph( self, capture_sizes: list[int] | None = None, + capture_batch_sizes: list[int] | None = None, + extra_capture_shapes: list[tuple[int, int]] | None = None, + compile_shapes: list[tuple[int, int]] | None = None, device: torch.device | None = None, codec_chunk_frames: int = 0, codec_left_context_frames: int = 0, @@ -865,6 +873,9 @@ def enable_cudagraph( self._cudagraph_wrapper = CUDAGraphDecoderWrapper( decoder=self, capture_sizes=capture_sizes, + capture_batch_sizes=capture_batch_sizes, + extra_capture_shapes=extra_capture_shapes, + compile_shapes=compile_shapes, num_quantizers=self.config.num_quantizers, enabled=True, ) @@ -878,8 +889,11 @@ def enable_cudagraph( ) self._cudagraph_enabled = True logger.info( - "CUDA Graph enabled for decoder: seq_lens=%s", + "CUDA Graph enabled for decoder: batch_sizes=%s seq_lens=%s extra_shapes=%s compile_shapes=%s", + self._cudagraph_wrapper.capture_batch_sizes, self._cudagraph_wrapper.capture_sizes, + self._cudagraph_wrapper.extra_capture_shapes, + self._cudagraph_wrapper.compile_shapes, ) def disable_cudagraph(self): @@ -921,6 +935,35 @@ def chunked_decode(self, codes, chunk_size=300, left_context_size=25): start_index = end_index return torch.cat(wavs, dim=-1) + def batched_chunked_decode( + self, + codes, + lengths, + chunk_size=300, + left_context_size=25, + max_batch_size=0, + ): + if self._cudagraph_enabled and self._cudagraph_wrapper is not None: + return self._cudagraph_wrapper.batched_chunked_decode_with_cudagraph( + codes, + lengths, + chunk_size=chunk_size, + left_context_size=left_context_size, + max_batch_size=max_batch_size, + ) + + from ..cuda_graph_decoder_wrapper import _batched_chunked_decode + + return _batched_chunked_decode( + codes, + lengths, + decode_fn=self, + total_upsample=self.total_upsample, + chunk_size=chunk_size, + left_context_size=left_context_size, + max_batch_size=max_batch_size, + ) + class Qwen3TTSTokenizerV2Encoder(MimiModel): def __init__(self, config: MimiConfig): diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 35edbfc1e2c..faa7e4cc4d3 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -164,6 +164,7 @@ def talker2code2wav_async_chunk( chunk_size = int(cfg.get("codec_chunk_frames", 25)) left_context_size_config = int(cfg.get("codec_left_context_frames", 25)) configured_initial_chunk_size = int(cfg.get("initial_codec_chunk_frames") or 0) + ref_code_context_frames = int(cfg.get("ref_code_context_frames") or left_context_size_config) # Per-request override takes priority over dynamic IC. fixed_initial_chunk_size = configured_initial_chunk_size > 0 @@ -193,11 +194,18 @@ def talker2code2wav_async_chunk( _ic_cache[request_id] = compute_dynamic_initial_chunk_size(active, capacity, max_ic) initial_chunk_size = _ic_cache[request_id] - if chunk_size <= 0 or left_context_size_config < 0 or configured_initial_chunk_size < 0 or initial_chunk_size < 0: + if ( + chunk_size <= 0 + or left_context_size_config < 0 + or configured_initial_chunk_size < 0 + or initial_chunk_size < 0 + or ref_code_context_frames < 0 + ): raise ValueError( f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, " f"codec_left_context_frames={left_context_size_config}, " - f"initial_codec_chunk_frames={initial_chunk_size}" + f"initial_codec_chunk_frames={initial_chunk_size}, " + f"ref_code_context_frames={ref_code_context_frames}" ) if initial_chunk_size > chunk_size: @@ -237,15 +245,22 @@ def talker2code2wav_async_chunk( left_context_size = max(0, end_index - context_length) window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:] - # Prepend ref_code as decoder context for every chunk so the vocoder - # maintains voice-clone speaker identity throughout the stream. The HF - # reference decodes ref_code + all_codes in one pass; without ref_code - # context on later chunks the decoder loses speaker identity and produces - # distorted audio. Use `.get()` (not `.pop()`) to keep ref_code for - # subsequent chunks. + # Prepend a bounded ref_code tail as decoder context for every chunk so the + # vocoder keeps voice-clone speaker identity without making Stage1 shapes + # depend on full reference-audio length. The decoder is causal with sliding + # attention, so frames older than this context window cannot affect the + # emitted chunk. Use `.get()` (not `.pop()`) to keep ref_code for later chunks. ref_code = request_payload.get(request_id) if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: - ref_frames = ref_code.tolist() + ref_context = ref_code + if ref_code_context_frames > 0 and int(ref_context.shape[0]) > ref_code_context_frames: + logger.info_once( + "Qwen3-TTS async chunk uses the last %d/%d ref_code frames as bounded Code2Wav context.", + ref_code_context_frames, + int(ref_context.shape[0]), + ) + ref_context = ref_context[-ref_code_context_frames:] + ref_frames = ref_context.tolist() window_frames = ref_frames + window_frames left_context_size += len(ref_frames) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 1cc6a1c7019..c1ab5658e6f 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1323,6 +1323,47 @@ def _preprocess( # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only decode_req_ids = [] + decode_start_offsets = [] + decode_batch_items = [] + batch_decode_preprocess = getattr(self.model, "preprocess_decode_batch", None) + + def flush_decode_batch() -> None: + nonlocal inputs_embeds + if not decode_batch_items: + return + + req_ids_b = [item[0] for item in decode_batch_items] + start_offsets_b = [item[1] for item in decode_batch_items] + req_infos_b = [item[2] for item in decode_batch_items] + ids_b = torch.stack([input_ids[offset : offset + 1].reshape(-1)[0] for offset in start_offsets_b]) + req_input_ids, req_embeds, last_talker_hidden, text_step, updates = batch_decode_preprocess( + input_ids=ids_b, + req_infos=req_infos_b, + ) + if inputs_embeds is None: + inputs_embeds = torch.empty( + (input_ids.shape[0], req_embeds.shape[-1]), + device=req_embeds.device, + dtype=req_embeds.dtype, + ) + + offsets_t = torch.tensor(start_offsets_b, device=req_embeds.device, dtype=torch.long) + inputs_embeds.index_copy_(0, offsets_t, req_embeds) + input_ids.index_copy_(0, offsets_t, req_input_ids.reshape(-1).to(dtype=input_ids.dtype)) + + dst = slice(len(decode_req_ids), len(decode_req_ids) + len(req_ids_b)) + self.talker_mtp_input_ids.gpu[dst].copy_(req_input_ids.reshape(-1)) + self.talker_mtp_inputs_embeds.gpu[dst].copy_(req_embeds) + self.last_talker_hidden.gpu[dst].copy_(last_talker_hidden) + self.text_step.gpu[dst].copy_(text_step) + + for req_id_b, update_dict_b in zip(req_ids_b, updates, strict=True): + self._merge_additional_information_update(req_id_b, update_dict_b) + + decode_req_ids.extend(req_ids_b) + decode_start_offsets.extend(start_offsets_b) + decode_batch_items.clear() + for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) @@ -1344,6 +1385,12 @@ def _preprocess( req_infos["_omni_prompt_len"] = prompt_len req_infos["_omni_num_computed_tokens"] = num_computed_tokens req_infos["_omni_is_prefill"] = is_prefill + if callable(batch_decode_preprocess) and self.has_talker_mtp and span_len == 1 and not is_prefill: + decode_batch_items.append((req_id, s, req_infos)) + continue + + flush_decode_batch() + embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos @@ -1363,6 +1410,7 @@ def _preprocess( self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) self.text_step.gpu[decode_slice].copy_(text_step) decode_req_ids.append(req_id) + decode_start_offsets.append(s) # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) @@ -1373,9 +1421,11 @@ def _preprocess( if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: input_ids[s : s + seg_len] = req_input_ids + flush_decode_batch() + # run talker mtp decode if self.has_talker_mtp: - self._talker_mtp_forward(decode_req_ids, inputs_embeds) + self._talker_mtp_forward(decode_req_ids, inputs_embeds, decode_start_offsets) return ( input_ids, @@ -1386,7 +1436,12 @@ def _preprocess( ec_connector_output, ) - def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + def _talker_mtp_forward( + self, + decode_req_ids: list[str], + inputs_embeds: torch.Tensor, + start_offsets: list[int] | None = None, + ) -> None: decode_batch_size = len(decode_req_ids) if decode_batch_size == 0: return @@ -1412,23 +1467,40 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None) if not isinstance(subtalker_params, dict): subtalker_params = {} + + def _explicit_talker_seed(req_id: str) -> int | None: + sampling_params = getattr(self.requests[req_id], "sampling_params", None) + extra_args = getattr(sampling_params, "extra_args", None) if sampling_params is not None else None + seed = extra_args.get("qwen3_tts_request_seed") if isinstance(extra_args, dict) else None + return int(seed) if seed is not None else None + + if decode_batch_size > 1 and any(_explicit_talker_seed(req_id) is not None for req_id in decode_req_ids): + # A torch.Generator is a single stream. Using one generator for a + # multi-row batch would make explicitly-seeded requests depend on + # other rows in the same scheduler step, so keep that path scalar. + saved_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size].clone() + saved_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size].clone() + saved_hidden = self.last_talker_hidden.gpu[:decode_batch_size].clone() + saved_text = self.text_step.gpu[:decode_batch_size].clone() + try: + for row, req_id in enumerate(decode_req_ids): + self.talker_mtp_input_ids.gpu[:1].copy_(saved_input_ids[row : row + 1]) + self.talker_mtp_inputs_embeds.gpu[:1].copy_(saved_embeds[row : row + 1]) + self.last_talker_hidden.gpu[:1].copy_(saved_hidden[row : row + 1]) + self.text_step.gpu[:1].copy_(saved_text[row : row + 1]) + row_offsets = None if start_offsets is None else [start_offsets[row]] + self._talker_mtp_forward([req_id], inputs_embeds, row_offsets) + finally: + self.talker_mtp_input_ids.gpu[:decode_batch_size].copy_(saved_input_ids) + self.talker_mtp_inputs_embeds.gpu[:decode_batch_size].copy_(saved_embeds) + self.last_talker_hidden.gpu[:decode_batch_size].copy_(saved_hidden) + self.text_step.gpu[:decode_batch_size].copy_(saved_text) + return + generator = None if decode_req_ids: first_req_id = decode_req_ids[0] - first_sp = getattr(self.requests[first_req_id], "sampling_params", None) - extra_args = getattr(first_sp, "extra_args", None) if first_sp is not None else None - seed = extra_args.get("qwen3_tts_request_seed") if isinstance(extra_args, dict) else None - if len(decode_req_ids) > 1 and seed is not None: - other_seeds = { - getattr(getattr(self.requests[rid], "sampling_params", None), "seed", None) - for rid in decode_req_ids[1:] - } - if other_seeds != {seed}: - logger.warning( - "Fast AR seed: batch has mixed seeds; using first request's seed=%d for all %d requests.", - seed, - len(decode_req_ids), - ) + seed = _explicit_talker_seed(first_req_id) if seed is not None: generators = getattr(self, "_talker_mtp_generators", None) if generators is None: @@ -1461,9 +1533,12 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te out_key = getattr(self.model, "talker_mtp_output_key", ("codes", "audio")) if not isinstance(out_key, tuple) or len(out_key) != 2: raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}") - for idx, req_id in enumerate(decode_req_ids): - req_index = self.input_batch.req_ids.index(req_id) - start_offset = int(self.query_start_loc.cpu[req_index]) + if start_offsets is None: + start_offsets = [] + for req_id in decode_req_ids: + req_index = self.input_batch.req_ids.index(req_id) + start_offsets.append(int(self.query_start_loc.cpu[req_index])) + for idx, (req_id, start_offset) in enumerate(zip(decode_req_ids, start_offsets, strict=True)): inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] update_dict = {out_key[0]: {out_key[1]: code_predictor_codes[idx : idx + 1]}} self._merge_additional_information_update(req_id, update_dict) From 392b84066e7f0fc73e34d6f75a2c68524bc98ad0 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sun, 17 May 2026 03:21:53 +0800 Subject: [PATCH 5/6] chore: remove qwen3 tts baseline artifact Signed-off-by: Sy03 <1370724210@qq.com> --- artifacts/qwen3_tts_ws1_baseline/README.md | 43 ---------------------- 1 file changed, 43 deletions(-) delete mode 100644 artifacts/qwen3_tts_ws1_baseline/README.md diff --git a/artifacts/qwen3_tts_ws1_baseline/README.md b/artifacts/qwen3_tts_ws1_baseline/README.md deleted file mode 100644 index e424d1f5d3d..00000000000 --- a/artifacts/qwen3_tts_ws1_baseline/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# Qwen3-TTS WS1 Baseline - -Baseline scope: -- Config fixed at Stage0 max_num_seqs=64 and Stage1 max_num_seqs=10. -- Existing initial_codec_chunk_frames=1 is kept. -- Existing Code2Wav exact-length batching is kept. -- No WS1 Stage0 slot runner is enabled. - -Primary workload: -- Model: Qwen/Qwen3-TTS-12Hz-1.7B-Base -- Task: voice_clone -- Concurrency: 64 -- Num prompts: 256 for stable run, 128 for quick A/B -- Warmups: 2, excluded from steady-state SLA - -Metrics: -- median / p99 TTFT -- median / p99 audio TTFP -- median / p99 E2EL -- median / p99 audio RTF -- audio throughput -- request throughput -- failed request count - -Validated WS1 result: -- Change: batched Stage0 Base voice_clone preprocessing for tokenizer ids, - ref-audio normalization, and same-sample-rate ref_code encoding. -- Remote result: - `/home/admin/workspace/remote_workspace/qwen3_stage0_slot_runner_ab_20260514_1840/results_20260514_190000/ab_summary.json` -- Workload: 2x H20, GPU pair `0,1`, concurrency 64, prompts 256, - warmups 2, Stage0 `max_num_seqs=64`, Stage1 `max_num_seqs=10`. -- Correctness smoke: new and old both completed 256 requests, failed requests - 0 in the benchmark log, with nonzero audio output (`1078.00s` new, - `1076.96s` old). - -| Metric | New | Old | Delta | -| --- | ---: | ---: | ---: | -| Audio throughput | 28.8289 | 25.5383 | +12.89% | -| Request throughput | 6.8462 | 6.0706 | +12.78% | -| Median audio RTF | 2.1888 | 2.4626 | -11.12% | -| Median audio TTFP ms | 1573.48 | 1634.63 | -3.74% | -| P99 audio TTFP ms | 5032.01 | 7319.04 | -31.25% | -| Median E2EL ms | 8746.86 | 9949.19 | -12.08% | From 03c0aea7beb7a8194a75b1dade224fd2d0774981 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 18 May 2026 00:00:02 +0800 Subject: [PATCH 6/6] fix: validate qwen3 code predictor graph config Signed-off-by: Sy03 <1370724210@qq.com> --- .../models/qwen3_tts/test_code_predictor_dtype.py | 2 ++ .../models/common/qwen3_code_predictor.py | 14 ++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py index 871f1dfb8c8..87281af82b2 100644 --- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py +++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py @@ -415,6 +415,8 @@ def test_prefix_graph_config_helpers(self, loaded_target_classes) -> None: 128, } assert wrapper_cls._parse_positive_int_set([2, "4", 0]) == {2, 4} + with pytest.raises(ValueError, match="Invalid positive int config value 'bad'"): + wrapper_cls._parse_positive_int_set("2,bad") wrapper = object.__new__(wrapper_cls) wrapper._prefix_graph_seq_lens = {1, 2, 4, 8, 99} diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py index 4105a1e547b..eca5addba79 100644 --- a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py +++ b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py @@ -591,12 +591,18 @@ def _parse_positive_int_set(value: object) -> set[int]: elif isinstance(value, int): raw_values = [value] else: - raw_values = list(value) + try: + raw_values = list(value) + except TypeError as exc: + raise ValueError(f"Invalid positive int config value {value!r}") from exc values: set[int] = set() for item in raw_values: - value = int(item) - if value > 0: - values.add(value) + try: + parsed = int(item) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid positive int config value {item!r}") from exc + if parsed > 0: + values.add(parsed) return values def _prefix_seq_lens(self, max_seq: int) -> list[int]: