From 4f6899a8dea36a0e4d29c54edf6d254f301f6fa3 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 2 Feb 2026 11:41:17 -0800 Subject: [PATCH 01/28] [+] Feat: Support disaggregated inference pipeline for Talker and SpeechTokenizer Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/config/model.py | 17 + .../core/sched/omni_generation_scheduler.py | 93 +- .../distributed/omni_connectors/adapter.py | 29 +- .../chunk_transfer_adapter.py | 3 +- vllm_omni/inputs/preprocess.py | 88 +- .../models/qwen3_tts/qwen3_tts_code2wav.py | 361 +++++ .../qwen3_tts_code_predictor_vllm.py | 407 +++++ .../qwen3_tts/qwen3_tts_disaggregated.py | 228 +++ .../models/qwen3_tts/qwen3_tts_talker_ar.py | 1374 +++++++++++++++++ .../models/qwen3_tts/qwen3_tts_tokenizer.py | 79 +- vllm_omni/model_executor/models/registry.py | 19 +- .../stage_configs/qwen3_tts.yaml | 89 +- .../stage_input_processors/qwen3_tts.py | 84 + .../worker/gpu_generation_model_runner.py | 23 +- vllm_omni/worker/gpu_model_runner.py | 107 +- 15 files changed, 2854 insertions(+), 147 deletions(-) create mode 100644 vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py create mode 100644 vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py create mode 100644 vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py create mode 100644 vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py create mode 100644 vllm_omni/model_executor/stage_input_processors/qwen3_tts.py diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index a9ffa015fe1..7f915dc56e7 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -58,6 +58,8 @@ class OmniModelConfig(ModelConfig): } ) omni_kv_config: dict | None = None + # Codec frame rate (frames/sec) for prompt length estimation. + codec_frame_rate_hz: float | None = None @property def registry(self): @@ -128,6 +130,21 @@ def __post_init__( video_pruning_rate=video_pruning_rate, ) + # Qwen3-TTS: infer codec frame rate from the model config for online serving. + if self.codec_frame_rate_hz is None and self.model_arch == "Qwen3TTSTalkerForConditionalGenerationARVLLM": + talker_cfg = getattr(self.hf_config, "talker_config", None) + if isinstance(talker_cfg, dict): + pos_per_sec = talker_cfg.get("position_id_per_seconds") + else: + pos_per_sec = getattr(talker_cfg, "position_id_per_seconds", None) + if pos_per_sec is not None: + try: + fps = float(pos_per_sec) + except Exception: + fps = None + if fps is not None and fps > 0: + self.codec_frame_rate_hz = fps + # Override hf_text_config with omni-specific logic for multi-stage models # (e.g., thinker_config, talker_config) new_hf_text_config = self.draw_hf_text_config() diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 684aab9ce20..0b88e1ee521 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -31,12 +31,8 @@ def __init__(self, *args, **kwargs): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) def schedule(self) -> SchedulerOutput: - """Diffusion fast path: - - Feed all input tokens of the request at once - (if 0, allocate 1 placeholder token). - - If the token budget cannot be satisfied at once, fall back to the - default vLLM scheduling. - """ + """Diffusion fast path: schedule all prompt tokens at once (use 1 placeholder if empty). + Fall back to vLLM scheduling if the token budget cannot be satisfied.""" token_budget = self.max_num_scheduled_tokens scheduled_timestamp = time.monotonic() @@ -50,7 +46,7 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs: dict[str, list[int]] = {} cached_prompt_token_ids: dict[str, list[int]] = {} - # Temporary queue: preserve waiting order, do not disturb non-diffusion requests + # Temporary queue to preserve waiting order for non-diffusion requests. skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 if self.chunk_transfer_adapter: @@ -62,16 +58,32 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] # OMNI: Skip requests that are not in self.requests - # This can happen when connector marks request as finished and it's removed from requests if request.request_id not in self.requests or ( self.chunk_transfer_adapter is None and request.status == RequestStatus.FINISHED_STOPPED ): already_finished_reqs.add(request) req_index += 1 continue - num_computed_tokens = request.num_computed_tokens - required_tokens = max(len(request.prompt_token_ids) - num_computed_tokens, 1) + required_tokens = len(request.prompt_token_ids) - num_computed_tokens + # async_chunk: don't schedule placeholder tokens when no new chunk is available. + if required_tokens <= 0: + if ( + self.chunk_transfer_adapter is not None + and request.request_id in self.chunk_transfer_adapter.finished_requests + ): + request.status = RequestStatus.FINISHED_STOPPED + # Upstream may finish with no terminal tokens; append one pad token so we can emit FINISHED. + if len(request.prompt_token_ids) <= num_computed_tokens: + request.prompt_token_ids.append(0) + try: + request._all_token_ids.append(0) # type: ignore[attr-defined] + except Exception: + pass + required_tokens = len(request.prompt_token_ids) - num_computed_tokens + else: + req_index += 1 + continue num_new_tokens = min(required_tokens, token_budget) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -109,11 +121,23 @@ def schedule(self) -> SchedulerOutput: self.waiting.pop_request() continue - # Uniformly treat as diffusion. A feature flag can be added later - # via config or request tag. + # async_chunk: wait for the first upstream chunk (don't start with placeholders). + if self.chunk_transfer_adapter is not None and len(request.prompt_token_ids) == 0: + if request.request_id in self.chunk_transfer_adapter.finished_requests: + request.status = RequestStatus.FINISHED_STOPPED + _ai = getattr(request, "additional_information", None) or {} + _pad = _ai.get("prompt_placeholder_pad_id", [0])[0] + request.prompt_token_ids.append(_pad) + try: + request._all_token_ids.append(_pad) # type: ignore[attr-defined] + except Exception: + pass + else: + break + + # Treat all requests as diffusion here (feature flag can be added later). - # Allocate all input tokens for the request in one shot - # (allocate 1 placeholder if zero) + # Allocate all prompt tokens at once (use 1 placeholder if empty). required_tokens = max(len(request.prompt_token_ids), 1) num_new_tokens = min(required_tokens, token_budget) new_blocks = self.kv_cache_manager.allocate_slots( @@ -192,6 +216,17 @@ def schedule(self) -> SchedulerOutput: num_output_tokens=cached_reqs_data.num_output_tokens, prompt_token_ids=cached_prompt_token_ids, ) + # async_chunk: forward per-step additional_information updates for cached requests. + try: + cached_ai: dict[str, object] = {} + for req in scheduled_running_reqs: + ai = getattr(req, "additional_information", None) + if isinstance(ai, dict) and ai: + cached_ai[req.request_id] = ai + if cached_ai: + setattr(cached_reqs_data, "additional_information", cached_ai) + except Exception: + pass total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) scheduler_output = SchedulerOutput( @@ -225,8 +260,7 @@ def schedule(self) -> SchedulerOutput: self._update_after_schedule(scheduler_output) try: - # Rewrap base NewRequestData entries with OmniNewRequestData, - # enriching with request-level payloads + # Wrap base NewRequestData as OmniNewRequestData and attach request-level payloads. new_list = [] for nr in scheduler_output.scheduled_new_reqs: req_id = getattr(nr, "req_id", None) @@ -259,25 +293,14 @@ def schedule(self) -> SchedulerOutput: init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData") return scheduler_output - - """ - Scheduler for the diffusion model. - This scheduler is modified to stop the request immediately for the diffusion model. - This is because the diffusion model can generate the final image/audio in one step. - Note: This is just a minimal modification to the original scheduler, - and there should be some further efforts to optimize the scheduler. - The original scheduler is still used for the AR model. - """ + # Diffusion scheduler: stop requests immediately after one step (AR uses the original vLLM scheduler). def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: OmniModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: - """Update the scheduler state based on the model runner output. - - This method is modified to stop the request immediately for the diffusion model. - """ + """Update scheduler state from model_runner_output (diffusion requests stop immediately).""" sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict @@ -306,9 +329,7 @@ def update_from_output( if kv_connector_output and getattr(kv_connector_output, "invalid_block_ids", None): failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids) - # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, - # the below loop can be a performance bottleneck. We should do our best - # to avoid expensive operations inside the loop. + # NOTE: keep loop body cheap (len(num_scheduled_tokens) can be 1K+). stopped_running_reqs: set[Request] = set() stopped_preempted_reqs: set[Request] = set() for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): @@ -318,9 +339,7 @@ def update_from_output( continue request = self.requests.get(req_id) if request is None or request.is_finished(): - # The request is already finished. This can happen if the - # request is aborted while the model is executing it (e.g., - # in pipeline parallelism or async scheduling). + # Request may already be finished (e.g., aborted during execution / pipeline parallelism / async scheduling). continue req_index = model_runner_output.req_id_to_index[req_id] @@ -390,9 +409,7 @@ def update_from_output( new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) if new_token_ids and self.structured_output_manager.should_advance(request): - # NOTE: structured_output_request should not be None if - # use_structured_output, we have check above, so safe to ignore - # type warning + # NOTE: structured_output_request is guaranteed when structured output is enabled (ignore type warning). request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] # noqa: E501 req_id, new_token_ids ) diff --git a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py index 1d1dd6f0f27..3bc05eeb3bc 100644 --- a/vllm_omni/distributed/omni_connectors/adapter.py +++ b/vllm_omni/distributed/omni_connectors/adapter.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# temporary for compatibility with vllm_omni.entrypoints.omni_stage.py -# and vllm_omni.entrypoints.omni_llm.py +# Temporary compatibility shim for vllm_omni.entrypoints.omni_stage.py / omni_llm.py. import time from collections.abc import Callable @@ -26,12 +25,7 @@ def try_send_via_connector( next_stage_queue_submit_fn: Callable[[dict[str, Any]], None], metrics: OrchestratorAggregator, ) -> bool: - """ - Attempts to send data via OmniConnector. - Returns True if successful, False otherwise. - Encapsulates the logic of preparing payload, sending via connector, - sending notification, and recording metrics. - """ + """Send payload via OmniConnector and enqueue notification/metrics; return True on success.""" try: t0 = time.time() @@ -96,10 +90,7 @@ def try_recv_via_connector( connectors: dict[Any, Any], stage_id: int, ) -> tuple[Any, dict[str, Any] | None]: - """ - Attempts to resolve input data from either connector or IPC. - Returns (engine_inputs, rx_metrics) or (None, None) if failed/skipped. - """ + """Resolve engine_inputs from connector/IPC payload; returns (engine_inputs, rx_metrics) or (None, None).""" rid = task["request_id"] if task.get("from_connector"): @@ -154,10 +145,7 @@ def try_recv_via_connector( ) return None, None else: - # Data comes from queue as usual (e.g. seed request for Stage-0) - # Since fallback logic is deprecated, we assume this is a direct inputs payload. - # We still need to decode it if it used SHM (via legacy stage_utils logic, or new shm_connector format) - # For Stage-0 specifically, 'engine_inputs' is often directly in the task dict. + # Queue path (e.g. Stage-0 seed): task should carry direct inputs, but still decode SHM/IPC if present. # Try to use the new stage_utils which uses OmniSerializer from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc_with_metrics @@ -174,14 +162,7 @@ def try_recv_via_connector( def compute_talker_prompt_ids_length(prompt_ids: list[int]) -> int: - """Compute the length of the talker prompt ids. - - Args: - prompt_ids: The prompt ids tensor. - - Returns: - The length of the talker prompt ids. - """ + """Compute talker prompt length for chat-style prompt ids (system/user/assistant).""" im_start_token_id = 151644 system_token_id = 8948 user_token_id = 872 diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 70d38a9d687..2cb4e20e59d 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -179,10 +179,10 @@ def _poll_single_request(self, req_id: str): else: if payload_data.get("finished"): self.finished_requests.add(req_id) - req.status = RequestStatus.FINISHED_STOPPED req.prompt_token_ids = payload_data.get("code_predictor_codes", []) req.num_computed_tokens = 0 + req.additional_information = payload_data # Mark as finished for consumption with self.lock: @@ -308,7 +308,6 @@ def _process_chunk_queue( # of schedule, but have not scheduled continue if request.request_id in self.finished_requests: - request.additional_information = {} continue # Requests that waiting for chunk self.load_async(request) diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index 09b215bf98a..f6e490567af 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -19,12 +19,72 @@ class OmniInputPreprocessor(InputPreprocessor): - """Input preprocessor for omni models. - - Extends the base InputPreprocessor to handle omni-specific input - types including prompt embeddings and additional information payloads. - Supports processing tokens, embeddings, text, and multimodal inputs. - """ + """Input preprocessor for omni models (tokens/embeds/multimodal + additional_information).""" + + def _is_qwen3_tts_talker_ar(self) -> bool: + archs = getattr(self.model_config, "architectures", None) + return bool(archs) and "Qwen3TTSTalkerForConditionalGenerationARVLLM" in archs + + def _get_qwen3_tts_codec_pad_id(self) -> int: + hf_config = getattr(self.model_config, "hf_config", None) + talker_config = getattr(hf_config, "talker_config", None) + pad = getattr(talker_config, "codec_pad_id", None) + try: + pad_id = int(pad) + except Exception: + pad_id = 0 + return max(0, pad_id) + + def _get_qwen3_tts_prompt_len_tokenizer(self): + # Qwen3-TTS talker prompt length must match HF AutoTokenizer (fix_mistral_regex). + tok = getattr(self, "_qwen3_tts_prompt_len_tokenizer", None) + if tok is not None: + return tok + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained( + self.model_config.model, + trust_remote_code=True, + fix_mistral_regex=True, + use_fast=True, + ) + tok.padding_side = "left" + self._qwen3_tts_prompt_len_tokenizer = tok + return tok + + def _estimate_qwen3_tts_talker_prompt_len(self, additional_information: dict[str, Any] | None) -> int: + """Estimate Qwen3-TTS talker placeholder prompt length for vLLM scheduling. + Real conditioning is carried in additional_information.""" + info = additional_information if isinstance(additional_information, dict) else {} + + def _first(x: object, default: object = "") -> object: + if isinstance(x, list): + return x[0] if x else default + return x if x is not None else default + + task_type = str(_first(info.get("task_type"), "CustomVoice") or "CustomVoice") + hf_config = getattr(self.model_config, "hf_config", None) + talker_config = getattr(hf_config, "talker_config", None) + codec_language_id = getattr(talker_config, "codec_language_id", None) + spk_is_dialect = getattr(talker_config, "spk_is_dialect", None) + + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker_ar import ( + Qwen3TTSTalkerForConditionalGenerationARVLLM, + ) + + tok = self._get_qwen3_tts_prompt_len_tokenizer() + + def _hf_tokenize_len(s: str) -> list[int]: + return tok(s, padding=False)["input_ids"] + + return Qwen3TTSTalkerForConditionalGenerationARVLLM.estimate_prompt_len_from_additional_information( + info, + task_type=task_type, + tokenize_prompt=_hf_tokenize_len, + codec_language_id=codec_language_id, + spk_is_dialect=spk_is_dialect, + estimate_ref_code_len=None, + ) def _process_text( self, @@ -51,10 +111,18 @@ def _process_text( if additional_information is not None: inputs["additional_information"] = additional_information else: - prompt_token_ids = self._tokenize_prompt( - prompt_text, - tokenization_kwargs=tokenization_kwargs, - ) + if self._is_qwen3_tts_talker_ar(): + # Qwen3-TTS talker uses a small codec vocab; text token ids are OOV. + # Use in-vocab pad placeholders for scheduling. + additional_information = parsed_content.get("additional_information") + prompt_len = self._estimate_qwen3_tts_talker_prompt_len(additional_information) + pad_id = self._get_qwen3_tts_codec_pad_id() + prompt_token_ids = [pad_id] * prompt_len + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + tokenization_kwargs=tokenization_kwargs, + ) inputs = token_inputs_omni( prompt_token_ids, prompt_embeds=parsed_content.get("prompt_embeds"), diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py new file mode 100644 index 00000000000..cb949d0c0c4 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +import os +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from transformers.utils.hub import cached_file +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .qwen3_tts_tokenizer import Qwen3TTSTokenizer + +logger = init_logger(__name__) + + +class Qwen3TTSCode2Wav(nn.Module): + """Stage-1 code2wav model for Qwen3-TTS (GenerationModelRunner). + Consumes frame-aligned codec tokens from input_ids and decodes waveform via SpeechTokenizer.""" + + input_modalities = "audio" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + # Generation-only stage (no logits / sampling). + self.requires_raw_input_tokens = True + + self._speech_tokenizer: Qwen3TTSTokenizer | None = None + self._num_quantizers: int | None = None + self._decode_upsample_rate: int | None = None + self._output_sample_rate: int | None = None + + # Default streaming window (must match connector config by convention). + self._stream_chunk_frames = 25 + self._stream_left_context_frames = 25 + self._logged_codec_stats = False + + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: + if self._speech_tokenizer is not None: + return self._speech_tokenizer + + # Locate speech_tokenizer dir from HF cache (or local path). + cfg_path = cached_file(self.model_path, "speech_tokenizer/config.json") + if cfg_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") + speech_tokenizer_dir = os.path.dirname(cfg_path) + + # Stage-1 only needs decode; skip HF feature extractor to avoid heavy optional deps. + # Still require preprocessor_config.json (use cached_file so online runs can fetch it). + prep_cfg = cached_file(self.model_path, "speech_tokenizer/preprocessor_config.json") + if prep_cfg is None: + raise ValueError( + f"{self.model_path}/speech_tokenizer/preprocessor_config.json not found. " + "Please make sure the checkpoint contains the required HF preprocessing files." + ) + + tok = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + torch_dtype=torch.bfloat16, + load_feature_extractor=False, + ) + + # Align device with vLLM worker. + device = getattr(self.vllm_config.device_config, "device", None) + if device is None: + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + try: + if tok.model is not None: + tok.model.to(device=device) + tok.device = device + except Exception as e: + raise RuntimeError(f"Failed to move SpeechTokenizer to device={device}: {e}") from e + + # Derive codec group count and rates from tokenizer config if possible. + num_q = None + try: + dec_cfg = getattr(tok.model.config, "decoder_config", None) + if dec_cfg is not None: + num_q = getattr(dec_cfg, "num_quantizers", None) + except Exception: + num_q = None + if num_q is None: + # Fallback: many code2wav stages use 16 quantizers. + num_q = 16 + num_q = int(num_q) + if num_q <= 0: + raise ValueError(f"Invalid speech_tokenizer num_quantizers={num_q}") + + try: + upsample = int(tok.get_decode_upsample_rate()) + except Exception as e: + raise ValueError(f"Failed to get decode upsample rate: {e}") from e + if upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + + try: + out_sr = int(tok.get_output_sample_rate()) + except Exception: + out_sr = 24000 + + self._speech_tokenizer = tok + self._num_quantizers = num_q + self._decode_upsample_rate = upsample + self._output_sample_rate = out_sr + return tok + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + # This stage ignores token embeddings. Keep a stable dummy embedding for vLLM runner. + if input_ids.numel() == 0: + return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32) + return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32) + + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: + return None + + @staticmethod + def _reconstruct_window_codes_fq( + *, + chunk_ids: torch.Tensor, + q: int, + chunk_frames: int, + codec_streaming: bool, + ctx_frames: int, + ctx_codes: list[int] | None, + ) -> torch.Tensor: + """Reconstruct [F, Q] codes from codebook-major flattened chunk ids (and optional left-context).""" + if q <= 0: + raise ValueError(f"Invalid q={q} (must be >0).") + if chunk_frames <= 0: + raise ValueError(f"Invalid chunk_frames={chunk_frames} (must be >0).") + + if int(chunk_ids.numel()) != int(q) * int(chunk_frames): + raise ValueError( + "Invalid chunk_ids length for Qwen3TTSCode2Wav: " + f"got={int(chunk_ids.numel())} expected={int(q) * int(chunk_frames)} " + f"(q={q} chunk_frames={chunk_frames})." + ) + + chunk_qf = chunk_ids.reshape(int(q), int(chunk_frames)) + if codec_streaming and ctx_frames > 0: + if ctx_codes is None: + raise ValueError("Missing ctx_codes for streaming decode window reconstruction.") + expected_ctx_tokens = int(q) * int(ctx_frames) + if len(ctx_codes) != expected_ctx_tokens: + raise ValueError( + "Invalid ctx_codes length for streaming decode window reconstruction: " + f"got={len(ctx_codes)} expected={expected_ctx_tokens} (q={q} ctx_frames={ctx_frames})." + ) + ctx_tensor = torch.tensor(ctx_codes, dtype=torch.long, device=chunk_ids.device) + ctx_qf = ctx_tensor.reshape(int(q), int(ctx_frames)) + window_qf = torch.cat([ctx_qf, chunk_qf], dim=1) + else: + window_qf = chunk_qf + + return window_qf.transpose(0, 1).contiguous() # [F, Q] + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + # ModelOutput is (audio_tensor, sr_tensor). + tok = self._ensure_speech_tokenizer_loaded() + assert self._num_quantizers is not None + assert self._output_sample_rate is not None + + if input_ids is None: + # Profile run / placeholder schedule: return empty audio. + empty = torch.zeros((0,), dtype=torch.float32) + return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) + + ids = input_ids.reshape(-1).to(dtype=torch.long) + q = int(self._num_quantizers) + + if ids.numel() == 0 or ids.numel() < q: + empty = torch.zeros((0,), dtype=torch.float32) + return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) + + # Contract: connector provides codec_streaming + codec_context_frames (left-context frames to trim). + # Assumes max_batch_size=1 for code2wav (vLLM provides a flattened per-step token stream). + ctx_frames: int | None = None + codec_streaming: bool | None = None + ctx_codes: list[int] | None = None + chunk_frames: int | None = None + rt_info = kwargs.get("runtime_additional_information") + if isinstance(rt_info, list) and len(rt_info) == 1 and isinstance(rt_info[0], dict): + v = rt_info[0].get("codec_streaming") + if v is not None: + try: + codec_streaming = bool(v) if not isinstance(v, torch.Tensor) else bool(v.item()) + except Exception: + codec_streaming = None + v = rt_info[0].get("codec_context_frames") + if v is not None: + try: + ctx_frames = int(v) + except Exception as e: + raise ValueError(f"Invalid codec_context_frames={v!r}: {e}") from e + v = rt_info[0].get("codec_context_codes") + if v is not None: + if isinstance(v, list): + ctx_codes = [int(x) for x in v] + elif isinstance(v, torch.Tensor): + ctx_codes = v.detach().to("cpu").reshape(-1).to(dtype=torch.long).tolist() + v = rt_info[0].get("codec_chunk_frames") + if v is not None: + try: + chunk_frames = int(v) + except Exception as e: + raise ValueError(f"Invalid codec_chunk_frames={v!r}: {e}") from e + + if codec_streaming is None: + raise ValueError( + "Missing codec_streaming in runtime_additional_information for Qwen3TTSCode2Wav. " + "This indicates the async_chunk connector/adapter contract was not applied." + ) + + if codec_streaming is False: + ctx_frames = 0 + else: + if ctx_frames is None: + raise ValueError( + "Missing codec_context_frames in runtime_additional_information for streaming Qwen3TTSCode2Wav. " + "This indicates the async_chunk connector/adapter contract was not applied." + ) + if ctx_frames < 0: + raise ValueError(f"Invalid codec_context_frames={ctx_frames} (must be >=0).") + + # input_ids may be padded; use codec_chunk_frames to slice the exact chunk (chunk_frames * q) and ignore padding. + if chunk_frames is None: + raise ValueError( + "Missing codec_chunk_frames in runtime_additional_information for Qwen3TTSCode2Wav. " + "This indicates the async_chunk connector/adapter contract was not applied." + ) + if chunk_frames < 0: + raise ValueError(f"Invalid codec_chunk_frames={chunk_frames} (must be >=0).") + expected_chunk_tokens = int(chunk_frames) * q + if expected_chunk_tokens == 0: + empty = torch.zeros((0,), dtype=torch.float32) + return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) + if ids.numel() < expected_chunk_tokens: + raise ValueError( + "Code2Wav received fewer tokens than expected for this chunk: " + f"got={int(ids.numel())} expected={expected_chunk_tokens} " + f"(chunk_frames={int(chunk_frames)} q={q}). " + "This indicates vLLM split the chunk across multiple forward calls; " + "the code2wav stage requires per-step frame-aligned chunks." + ) + if ids.numel() > expected_chunk_tokens: + # Extra non-padding tokens beyond expected_chunk_tokens indicate a scheduler/adapter contract violation. + extra = ids[expected_chunk_tokens:] + if extra.numel() > 0 and bool((extra != 0).any().item()): + raise ValueError( + "Code2Wav received extra non-padding tokens beyond the expected chunk length: " + f"got={int(ids.numel())} expected={expected_chunk_tokens} " + f"(chunk_frames={int(chunk_frames)} q={q}). " + "This indicates multiple codec chunks were scheduled in a single forward, " + "which breaks streaming trim/paste semantics." + ) + ids = ids[:expected_chunk_tokens] + + chunk_ids = ids + ctx_frames_i = int(ctx_frames or 0) + frames = int((ctx_frames_i if codec_streaming else 0) + int(chunk_frames)) + codes_fq = self._reconstruct_window_codes_fq( + chunk_ids=chunk_ids, + q=q, + chunk_frames=int(chunk_frames), + codec_streaming=bool(codec_streaming), + ctx_frames=ctx_frames_i, + ctx_codes=ctx_codes, + ) + if not self._logged_codec_stats and frames > 1: + self._logged_codec_stats = True + try: + uniq = int(torch.unique(codes_fq).numel()) + cmin = int(codes_fq.min().item()) + cmax = int(codes_fq.max().item()) + head = codes_fq[: min(2, frames), : min(8, q)].detach().to("cpu").tolist() + logger.info( + "Qwen3TTSCode2Wav received codec codes: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", + frames, + q, + uniq, + cmin, + cmax, + head, + ) + except Exception: + pass + + wavs, sr = tok.decode({"audio_codes": codes_fq}) + if not wavs: + raise ValueError("SpeechTokenizer code2wav produced empty waveform list.") + audio_np = wavs[0].astype(np.float32, copy=False) + + if ctx_frames > 0: + # Trim waveform samples corresponding to left-context frames in the sliding window. + upsample = self._decode_upsample_rate + if upsample is None: + try: + upsample = int(tok.get_decode_upsample_rate()) + except Exception as e: + raise ValueError(f"Failed to get decode upsample rate: {e}") from e + if upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + self._decode_upsample_rate = upsample + + ctx_frames_i = int(ctx_frames) + if ctx_frames_i > frames: + raise ValueError(f"codec_context_frames={ctx_frames_i} exceeds frames={frames}") + + decoded = int(audio_np.shape[0]) + cut = int(ctx_frames_i) * int(upsample) + if cut > decoded: + raise ValueError( + "Streaming decode context trim exceeds decoded length: " + f"cut={cut} decoded={decoded} ctx_frames={ctx_frames_i} frames={frames}" + ) + audio_np = audio_np[cut:] + + # Return 1D waveform per chunk so the output processor can concatenate along time. + # Returning [1, T] would stack chunks as channels. + audio_tensor = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) + sr_tensor = torch.tensor(int(sr), dtype=torch.int32) + return audio_tensor, sr_tensor + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + if isinstance(model_outputs, OmniOutput): + return model_outputs + + if not (isinstance(model_outputs, tuple) and len(model_outputs) == 2): + raise TypeError(f"Qwen3TTSCode2Wav expected (audio_tensor, sr) outputs, got {type(model_outputs)}") + + audio_tensor, sr = model_outputs + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": audio_tensor, + "sr": sr, + }, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # SpeechTokenizer weights live under `speech_tokenizer/` and are loaded + # lazily from that directory. Ignore main checkpoint weights. + return set() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py new file mode 100644 index 00000000000..044bc5292df --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.config.vllm import set_current_vllm_config +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor +from vllm.v1.worker.gpu import attn_utils + +from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig + +logger = init_logger(__name__) + + +class _LocalPredictorKVCache: + """Minimal local KV cache + attention metadata for running code_predictor inside one worker (independent of engine KV).""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + max_seq_len: int, + max_batch_size: int, + device: torch.device, + ) -> None: + self.vllm_config = vllm_config + self.device = device + + # Collect attention layers registered in this vllm_config. + kv_cache_spec_by_layer = attn_utils.get_kv_cache_spec(vllm_config) + if not kv_cache_spec_by_layer: + raise RuntimeError("Local predictor KVCache requires vLLM Attention layers to be registered.") + + # We only need enough blocks for a tiny per-frame sequence (<= max_seq_len). + any_spec = next(iter(kv_cache_spec_by_layer.values())) + block_size = int(any_spec.block_size) + blocks_per_seq = (int(max_seq_len) + block_size - 1) // block_size + num_blocks = max(1, int(max_batch_size) * int(blocks_per_seq)) + + # Allocate per-layer KV caches (small, independent). + kv_cache_tensors: list[KVCacheTensor] = [] + for layer_name, spec in kv_cache_spec_by_layer.items(): + kv_cache_tensors.append(KVCacheTensor(size=int(spec.page_size_bytes) * num_blocks, shared_by=[layer_name])) + + merged_spec: KVCacheSpec = KVCacheSpec.merge(list(kv_cache_spec_by_layer.values())) + self.kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=list(kv_cache_spec_by_layer.keys()), kv_cache_spec=merged_spec) + ], + ) + + # Init backend + bind KV cache tensors to attention modules. + self.attn_backends, self.attn_metadata_builders = attn_utils.init_attn_backend( + self.kv_cache_config, vllm_config, device + ) + self.runner_kv_caches: list[torch.Tensor] = [] + attn_utils.init_kv_cache( + self.runner_kv_caches, + vllm_config.compilation_config.static_forward_context, + self.kv_cache_config, + self.attn_backends, + device, + ) + + # Precompute a fixed block table mapping for the maximum batch. + self.block_size = block_size + self.blocks_per_seq = blocks_per_seq + self.max_batch_size = int(max_batch_size) + + bt = torch.full((self.max_batch_size, self.blocks_per_seq), -1, dtype=torch.int32, device=device) + for i in range(self.max_batch_size): + for j in range(self.blocks_per_seq): + bt[i, j] = i * self.blocks_per_seq + j + self._block_table = bt + + def build_attn_metadata( + self, + *, + num_reqs: int, + query_lens: torch.Tensor, # (num_reqs,) int32 on cpu + seq_lens: torch.Tensor, # (num_reqs,) int32 on cpu + ) -> tuple[dict[str, Any], torch.Tensor]: + """Build attention metadata and return (attn_metadata, positions).""" + num_reqs = int(num_reqs) + if num_reqs <= 0: + return {}, torch.empty((0,), dtype=torch.int64, device=self.device) + if num_reqs > self.max_batch_size: + raise ValueError(f"num_reqs={num_reqs} exceeds local predictor max_batch_size={self.max_batch_size}") + + query_lens_i32 = query_lens.to(dtype=torch.int32, device="cpu") + seq_lens_i32 = seq_lens.to(dtype=torch.int32, device="cpu") + + # query_start_loc: prefix sums of query_lens. + qsl = torch.zeros((num_reqs + 1,), dtype=torch.int32, device="cpu") + qsl[1:] = torch.cumsum(query_lens_i32, dim=0) + num_tokens = int(qsl[-1].item()) + if num_tokens <= 0: + return {}, torch.empty((0,), dtype=torch.int64, device=self.device) + + # positions: for each request i, emit positions [seq_len-query_len .. seq_len-1] + pos_list: list[torch.Tensor] = [] + for i in range(num_reqs): + ql = int(query_lens_i32[i].item()) + sl = int(seq_lens_i32[i].item()) + start = sl - ql + pos_list.append(torch.arange(start, sl, dtype=torch.int64)) + positions_cpu = torch.cat(pos_list, dim=0) + + # slot_mapping: map each query token to a physical slot in the paged KV cache. + # We allocate per-request contiguous blocks; slot = base + position. + slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu") + cursor = 0 + for i in range(num_reqs): + ql = int(query_lens_i32[i].item()) + sl = int(seq_lens_i32[i].item()) + start = sl - ql + for p in range(start, sl): + block_idx = p // self.block_size + offset = p % self.block_size + block_id = int(self._block_table[i, block_idx].item()) + slot_mapping[cursor] = block_id * self.block_size + offset + cursor += 1 + + max_seq_len = int(seq_lens_i32[:num_reqs].max().item()) + query_start_loc_gpu = qsl.to(device=self.device) + seq_lens_gpu = seq_lens_i32.to(device=self.device) + block_table = self._block_table[:num_reqs].contiguous() + slot_mapping_gpu = slot_mapping.to(device=self.device) + + attn_metadata = attn_utils.build_attn_metadata( + self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_cpu=qsl, + seq_lens=seq_lens_gpu, + max_seq_len=max_seq_len, + block_tables=[block_table], + slot_mappings=[slot_mapping_gpu], + kv_cache_config=self.kv_cache_config, + ) + return attn_metadata, positions_cpu.to(device=self.device) + + +class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module): + def __init__( + self, + config: Qwen3TTSTalkerCodePredictorConfig, + *, + talker_hidden_size: int | None = None, + cache_config=None, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Official code_predictor uses one embedding table per residual group. + # Some Qwen3-TTS checkpoints store codec embeddings in the talker hidden + # space, even when `code_predictor_config.hidden_size` is smaller. + # We keep the embedding dim aligned with the checkpoint and project down + # via `small_to_mtp_projection` in the wrapper module. + emb_dim = int(talker_hidden_size) if talker_hidden_size is not None else int(config.hidden_size) + self.codec_embedding = nn.ModuleList( + [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] + ) + + def get_input_embeddings(self) -> nn.ModuleList: + return self.codec_embedding + + def forward(self, positions: torch.Tensor, inputs_embeds: torch.Tensor) -> torch.Tensor: + # Token-major: [num_tokens, hidden] + hidden_states = inputs_embeds + residual = None + for layer in self.layers: + hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Match vLLM Qwen2/Qwen3 packing conventions: q_proj/k_proj/v_proj -> qkv_proj, + # gate_proj/up_proj -> gate_up_proj. + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name)): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + mapped = name.replace(weight_name, param_name) + if mapped.endswith(".bias") and mapped not in params_dict: + continue + if is_pp_missing_parameter(mapped, self): + continue + if mapped.endswith("scale"): + mapped = maybe_remap_kv_scale_name(mapped, params_dict) + if mapped is None: + continue + param = params_dict.get(mapped) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(mapped) + break + else: + mapped = maybe_remap_kv_scale_name(name, params_dict) + if mapped is None: + continue + if name.endswith(".bias") and mapped not in params_dict: + continue + if is_pp_missing_parameter(mapped, self): + continue + param = params_dict.get(mapped) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(mapped) + return loaded_params + + +class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(nn.Module): + """vLLM-native code_predictor used by the AR talker (residual codebooks).""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + config: Qwen3TTSTalkerCodePredictorConfig, + talker_config: Qwen3TTSTalkerConfig, + prefix: str = "code_predictor", + ) -> None: + super().__init__() + self._vllm_config = vllm_config + self.config = config + self.talker_config = talker_config + + # Keep module/weight names aligned with official checkpoint (talker.code_predictor.model.*). + self.model = Qwen3TTSTalkerCodePredictorModelVLLM( + config, + talker_hidden_size=int(talker_config.hidden_size), + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.model", + ) + + # One head per residual group. + self.lm_head = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] + ) + + if config.hidden_size != talker_config.hidden_size: + self.small_to_mtp_projection = nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True) + else: + self.small_to_mtp_projection = nn.Identity() + + self._kv_cache: _LocalPredictorKVCache | None = None + + def get_input_embeddings(self) -> nn.ModuleList: + return self.model.get_input_embeddings() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Ensure all vLLM custom layers consult the predictor vllm_config + # (esp. for Attention static_forward_context). + with set_current_vllm_config(self._vllm_config): + loaded: set[str] = set() + model_weights: list[tuple[str, torch.Tensor]] = [] + other_weights: list[tuple[str, torch.Tensor]] = [] + for name, w in weights: + if name.startswith("model."): + model_weights.append((name[len("model.") :], w)) + else: + other_weights.append((name, w)) + + loaded_model = self.model.load_weights(model_weights) + loaded |= {f"model.{n}" for n in loaded_model} + + params = dict(self.named_parameters(remove_duplicate=False)) + for name, w in other_weights: + if name not in params: + continue + default_weight_loader(params[name], w) + loaded.add(name) + return loaded + + def _maybe_init_kv_cache(self, device: torch.device) -> None: + if self._kv_cache is not None: + return + max_seq_len = int(getattr(self.config, "num_code_groups", 16) or 16) + # Upper bound on batch size: vLLM scheduler max_num_seqs (fallback 8). + max_batch = int(getattr(self._vllm_config.scheduler_config, "max_num_seqs", 8) or 8) + max_batch = max(1, max_batch) + self._kv_cache = _LocalPredictorKVCache( + vllm_config=self._vllm_config, + max_seq_len=max_seq_len, + max_batch_size=max_batch, + device=device, + ) + + @torch.inference_mode() + def reset_cache(self) -> None: + # We reuse a fixed kv cache buffer and overwrite starting at slot 0. + # No action required here (seq_lens controls what is read). + return + + @torch.inference_mode() + def prefill_logits(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + """Prefill with 2 tokens: [past_hidden, layer0_embed]. Returns logits for residual group 0.""" + self._maybe_init_kv_cache(inputs_embeds.device) + assert self._kv_cache is not None + + bsz = int(inputs_embeds.shape[0]) + qlen = 2 + # Flatten to token-major. + hs = inputs_embeds.to(dtype=torch.bfloat16).reshape(bsz * qlen, -1) + hs = self.small_to_mtp_projection(hs) + + query_lens = torch.full((bsz,), qlen, dtype=torch.int32) + seq_lens = query_lens.clone() + attn_metadata, positions = self._kv_cache.build_attn_metadata( + num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens + ) + + with ( + set_current_vllm_config(self._vllm_config), + set_forward_context(attn_metadata, self._vllm_config, num_tokens=int(hs.shape[0])), + ): + out = self.model(positions=positions, inputs_embeds=hs) + + # Gather last token per request. + last_idx = torch.arange(qlen - 1, bsz * qlen, step=qlen, device=out.device, dtype=torch.long) + last_h = out.index_select(0, last_idx) + logits = self.lm_head[0](last_h) + return logits + + @torch.inference_mode() + def decode_logits(self, input_ids: torch.Tensor, *, generation_step: int, past_seq_len: int) -> torch.Tensor: + """Decode one new token for residual group `generation_step` (1..Q-1).""" + self._maybe_init_kv_cache(input_ids.device) + assert self._kv_cache is not None + bsz = int(input_ids.shape[0]) + if generation_step <= 0: + raise ValueError("generation_step must be >= 1 for decode_logits") + + embed_idx = generation_step - 1 + hs = self.model.get_input_embeddings()[embed_idx](input_ids.to(dtype=torch.long).reshape(bsz, 1)) + hs = self.small_to_mtp_projection(hs.reshape(bsz, -1)) + + query_lens = torch.ones((bsz,), dtype=torch.int32) + seq_lens = torch.full((bsz,), int(past_seq_len) + 1, dtype=torch.int32) + attn_metadata, positions = self._kv_cache.build_attn_metadata( + num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens + ) + + with ( + set_current_vllm_config(self._vllm_config), + set_forward_context(attn_metadata, self._vllm_config, num_tokens=int(hs.shape[0])), + ): + out = self.model(positions=positions, inputs_embeds=hs) + + logits = self.lm_head[generation_step](out) + return logits diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py new file mode 100644 index 00000000000..25329b8851c --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py @@ -0,0 +1,228 @@ +import os +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from transformers.utils.hub import cached_file +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .qwen3_tts import Qwen3TTSModel +from .qwen3_tts_tokenizer import Qwen3TTSTokenizer + +logger = init_logger(__name__) + +_VALID_TASK_TYPES = ("CustomVoice", "VoiceDesign", "Base") +_VALID_STAGES = ("talker", "speech_tokenizer") + + +class Qwen3TTSForConditionalGenerationDisaggregatedVLLM(nn.Module): + """Stage-aware wrapper for disaggregated Qwen3-TTS (selects stage via model_stage). + SpeechTokenizer stage decodes codec->waveform; talker is handled by the AR talker model.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + self.model_stage = getattr(vllm_config.model_config, "model_stage", None) + self._async_chunk = bool(getattr(vllm_config.model_config, "async_chunk", False)) + + if self.model_stage not in _VALID_STAGES: + raise ValueError(f"Invalid model_stage for Qwen3-TTS disaggregated model: {self.model_stage}") + + if self.model_stage == "talker": + # Avoid accidental fallback to the HF generate() path. + raise ValueError( + "Qwen3-TTS disaggregated wrapper no longer supports model_stage='talker'. " + "Use model_arch=Qwen3TTSTalkerForConditionalGenerationARVLLM for Stage-0." + ) + + self.have_multimodal_outputs = True + # Only speech_tokenizer needs preprocess in async_chunk (treat prompt_token_ids as codec codes). + self.has_preprocess = bool(self.model_stage == "speech_tokenizer" and self._async_chunk) + if self.model_stage == "speech_tokenizer" and not self._async_chunk: + raise ValueError( + "Qwen3-TTS SpeechTokenizer stage no longer supports serial " + "`additional_information['audio_codes']` mode. Use async_chunk " + "stage config so Stage-1 consumes codec codes via prompt_token_ids." + ) + + self._talker: Qwen3TTSModel | None = None + self._speech_tokenizer: Qwen3TTSTokenizer | None = None + # Only required for Stage-1 streaming decode (to reframe flattened codes). + self._num_code_groups = 0 + if self.model_stage == "speech_tokenizer": + try: + self._num_code_groups = int(vllm_config.model_config.hf_config.talker_config.num_code_groups) + except Exception as e: + raise ValueError(f"Failed to read talker_config.num_code_groups from hf_config: {e}") from e + if self._num_code_groups <= 0: + raise ValueError(f"Invalid num_code_groups={self._num_code_groups} for Qwen3-TTS.") + + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: + if self._speech_tokenizer is not None: + return self._speech_tokenizer + + # Locate speech_tokenizer dir from HF cache (or local path). + speech_tokenizer_path = cached_file(self.model_path, "speech_tokenizer/config.json") + if speech_tokenizer_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") + speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) + self._speech_tokenizer = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + torch_dtype=torch.bfloat16, + load_feature_extractor=False, + ) + # Run decode on the vLLM worker device (fallback to best-effort CUDA/CPU). + device = getattr(self.vllm_config.device_config, "device", None) + if device is None: + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + try: + if hasattr(self._speech_tokenizer, "model") and self._speech_tokenizer.model is not None: + self._speech_tokenizer.model.to(device=device) + self._speech_tokenizer.device = device + except Exception as e: + raise RuntimeError(f"Failed to move SpeechTokenizer to device={device}: {e}") from e + return self._speech_tokenizer + + def preprocess( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor | None, + **info_dict: Any, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + # Only used in async_chunk speech_tokenizer stage. + if self.model_stage != "speech_tokenizer" or not self._async_chunk: + return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), {} + + if self._num_code_groups <= 0: + raise ValueError(f"Invalid talker_config.num_code_groups={self._num_code_groups} for streaming decode.") + + # Optional request id for debugging only (streaming decode keeps no per-request state). + req_id = str(info_dict.get("_omni_request_id") or "") + + q = int(self._num_code_groups) + if input_ids.numel() <= 0: + update = {"model_outputs": None, "sr": None} + return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update + + tokens = input_ids.reshape(-1).to(torch.long) + if int(tokens.numel()) % q != 0: + # Finished requests may still get placeholder tokens; treat as a no-op instead of crashing. + if bool(info_dict.get("finished", False)) or int(tokens.numel()) <= 1: + update = {"model_outputs": None, "sr": None} + return ( + input_ids, + (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), + update, + ) + raise ValueError( + f"Streaming codec token length must be divisible by num_code_groups={q}. " + f"got={int(tokens.numel())} request_id={req_id or ''}" + ) + + frames = int(tokens.numel()) // q + if frames <= 0: + update = {"model_outputs": None, "sr": None} + return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update + + # tokens are codebook-major flattened: [Q, F] flattened row-major. + codes_qf = tokens.reshape(q, frames) + codes_fq = codes_qf.transpose(0, 1).contiguous() # [F, Q] + + ctx_frames = int(info_dict.get("codec_context_frames") or 0) + if ctx_frames < 0 or ctx_frames > frames: + raise ValueError( + f"Invalid codec_context_frames={ctx_frames} for frames={frames} request_id={req_id or ''}" + ) + + tok = self._ensure_speech_tokenizer_loaded() + device = getattr(tok, "device", None) or torch.device("cpu") + codes_chunk = codes_fq.to(device=device) + + wavs, sr = tok.decode({"audio_codes": codes_chunk}) + if not wavs: + raise ValueError("SpeechTokenizer streaming decode produced empty waveform list.") + audio_np = wavs[0].astype(np.float32, copy=False) + + if ctx_frames > 0: + try: + upsample = int(tok.get_decode_upsample_rate()) + except Exception as e: + raise ValueError(f"Failed to get decode upsample rate for streaming trim: {e}") from e + if upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + cut = ctx_frames * upsample + if cut >= audio_np.shape[0]: + raise ValueError( + f"Streaming decode context trim exceeds decoded length: cut={cut} decoded={audio_np.shape[0]}" + ) + audio_np = audio_np[cut:] + + update: dict[str, Any] = { + "model_outputs": torch.from_numpy(audio_np).to(dtype=torch.float32), + "sr": torch.tensor(int(sr), dtype=torch.int), + } + return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> OmniOutput: + runtime_info = kwargs.get("runtime_additional_information", [{}]) + if isinstance(runtime_info, list) and runtime_info: + runtime_info = runtime_info[0] + if not isinstance(runtime_info, dict): + runtime_info = {} + + # speech_tokenizer stage: decode in preprocess(); forward returns a dummy tensor for span slicing. + device = input_ids.device if isinstance(input_ids, torch.Tensor) else torch.device("cpu") + n = int(input_ids.shape[0]) if isinstance(input_ids, torch.Tensor) else 1 + if n <= 0: + n = 1 + return torch.zeros((n, 1), dtype=torch.float32, device=device) + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + if isinstance(model_outputs, OmniOutput): + return model_outputs + + # async_chunk speech_tokenizer: emit the latest decoded chunk from runtime_additional_information. + if self.model_stage != "speech_tokenizer" or not self._async_chunk: + return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={}) + + runtime_info = kwargs.get("runtime_additional_information", [{}]) + if isinstance(runtime_info, list) and runtime_info: + runtime_info = runtime_info[0] + if not isinstance(runtime_info, dict): + runtime_info = {} + + mo = runtime_info.get("model_outputs") + sr = runtime_info.get("sr") + if isinstance(mo, torch.Tensor) and isinstance(sr, torch.Tensor): + return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={"model_outputs": mo, "sr": sr}) + return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={}) + + def compute_logits( + self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None + ) -> torch.Tensor | None: + return None + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + # SpeechTokenizer ignores token embeddings, but vLLM requires embed_input_ids to select the runner type. + if input_ids.numel() == 0: + return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32) + return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Talker loads weights elsewhere; speech_tokenizer loads `speech_tokenizer/` lazily. + # Return empty set without consuming weights to avoid vLLM re-loading. + return set() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py new file mode 100644 index 00000000000..65b6c4e154a --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py @@ -0,0 +1,1374 @@ +from __future__ import annotations + +import base64 +import dataclasses +import io +import os +from collections.abc import Callable, Iterable, Mapping +from typing import Any +from urllib.parse import urlparse +from urllib.request import urlopen + +import numpy as np +import soundfile as sf +import torch +import torch.nn as nn +from transformers import AutoTokenizer +from transformers.utils.hub import cached_file +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.qwen3 import Qwen3Model +from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSTalkerConfig +from .modeling_qwen3_tts import ( + Qwen3TTSSpeakerEncoder, + Qwen3TTSTalkerResizeMLP, + mel_spectrogram, +) +from .qwen3_tts_code_predictor_vllm import Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM +from .qwen3_tts_tokenizer import Qwen3TTSTokenizer + +logger = init_logger(__name__) + + +class Qwen3TTSTalkerForConditionalGenerationARVLLM(nn.Module): + """vLLM-AR talker: step-wise layer-0 codec decoding. + Predicts residual codebooks (1..Q-1) into `audio_codes` and streams text via `tailing_text_hidden`.""" + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Talker backbone (Qwen3 decoder-only). + "talker.model.layers.": "model.layers.", + "talker.model.norm.": "model.norm.", + "talker.model.codec_embedding.": "model.embed_tokens.", + # Heads / side modules. + "talker.codec_head.": "lm_head.", + "talker.model.text_embedding.": "text_embedding.", + "talker.text_projection.": "text_projection.", + "talker.code_predictor.": "code_predictor.", + # Speaker encoder (Base only). + "speaker_encoder.": "speaker_encoder.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + self.config: Qwen3TTSConfig = vllm_config.model_config.hf_config # type: ignore[assignment] + self.talker_config: Qwen3TTSTalkerConfig = self.config.talker_config + + # Codec ids: only [0, codebook_vocab_size) are real code indices (layer-0 is sampled from talker vocab). + # codec_eos_token_id is a special stop token and must not be decoded by SpeechTokenizer. + self._codebook_vocab_size = int(getattr(self.talker_config.code_predictor_config, "vocab_size", 0) or 0) + if self._codebook_vocab_size <= 0: + raise ValueError( + f"Invalid talker_config.code_predictor_config.vocab_size={self._codebook_vocab_size}; " + "cannot restrict codec logits safely." + ) + self._codec_eos_token_id = int(getattr(self.talker_config, "codec_eos_token_id", -1)) + + self.have_multimodal_outputs = True + self.has_preprocess = True + self.has_postprocess = True + + # Used by OmniGPUModelRunner for the GPU-side MTP fast-path. + self.mtp_hidden_size = int(self.talker_config.hidden_size) + # OmniGPUModelRunner will store talker_mtp output under this key in + # per-request additional_information. + self.talker_mtp_output_key = "audio_codes" + + self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.talker_config.vocab_size, + self.talker_config.hidden_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.talker_config.vocab_size) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + # Text embedding is a separate table in the official implementation. + self.text_embedding = nn.Embedding(self.talker_config.text_vocab_size, self.talker_config.text_hidden_size) + self.text_projection = Qwen3TTSTalkerResizeMLP( + self.talker_config.text_hidden_size, + self.talker_config.text_hidden_size, + self.talker_config.hidden_size, + self.talker_config.hidden_act, + bias=True, + ) + + # Speaker encoder is only needed for Base voice cloning and may be missing in some checkpoints. + # Keep it optional to avoid strict weight-loading failures. + self.speaker_encoder: Qwen3TTSSpeakerEncoder | None = None + + # Residual code predictor (1..Q-1) uses a dedicated vLLM config to build its own KV cache. + # This avoids polluting the main engine's static forward context. + predictor_compilation = dataclasses.replace(vllm_config.compilation_config) + self._code_predictor_vllm_config = dataclasses.replace(vllm_config, compilation_config=predictor_compilation) + from vllm.config.vllm import set_current_vllm_config as _set_cfg + + with _set_cfg(self._code_predictor_vllm_config): + self.code_predictor = Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM( + vllm_config=self._code_predictor_vllm_config, + config=self.talker_config.code_predictor_config, + talker_config=self.talker_config, + prefix="code_predictor", + ) + + # Tokenizer for prompt building. + self._tokenizer = None + self._speech_tokenizer: Qwen3TTSTokenizer | None = None + + # -------------------- vLLM required hooks -------------------- + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **_: Any, + ) -> torch.Tensor | IntermediateTensors: + return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + + def compute_logits( + self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None + ) -> torch.Tensor | None: + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + if hidden_states is None: + return None + logits = self.logits_processor(self.lm_head, hidden_states) + if logits is None: + return None + + # Allow only real codec ids (1..codebook_vocab_size-1) plus codec EOS; specials can crash SpeechTokenizer. + # Also, id 0 is padding for the 12Hz decoder. + vocab = int(logits.shape[-1]) + allowed = torch.zeros((vocab,), dtype=torch.bool, device=logits.device) + lo = 1 + hi = min(self._codebook_vocab_size, vocab) + if hi > lo: + allowed[lo:hi] = True + if 0 <= self._codec_eos_token_id < vocab: + allowed[self._codec_eos_token_id] = True + logits = logits.masked_fill(~allowed, float("-inf")) + return logits + + # -------------------- Omni multimodal output plumbing -------------------- + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + if isinstance(model_outputs, OmniOutput): + return model_outputs + + hidden = model_outputs + info_dicts = kwargs.get("runtime_additional_information") or [] + audio_codes_list: list[torch.Tensor] = [] + ref_code_len_list: list[torch.Tensor] = [] + codec_streaming_list: list[torch.Tensor] = [] + for info in info_dicts: + if not isinstance(info, dict): + continue + ac = info.get("audio_codes") + if isinstance(ac, torch.Tensor): + audio_codes_list.append(ac) + cs = info.get("codec_streaming") + if isinstance(cs, bool): + codec_streaming_list.append( + torch.full((int(ac.shape[0]),), int(cs), dtype=torch.int8, device=ac.device) + ) + ref_len = info.get("ref_code_len") + if ref_len is None: + continue + if isinstance(ref_len, torch.Tensor): + if ref_len.numel() == 0: + raise ValueError("ref_code_len is an empty tensor") + ref_len_val = int(ref_len.reshape(-1)[-1].item()) + elif isinstance(ref_len, list): + if len(ref_len) != 1: + raise ValueError(f"ref_code_len must be scalar or 1-element list, got len={len(ref_len)}") + ref_len_val = int(ref_len[0]) + else: + ref_len_val = int(ref_len) + if isinstance(ac, torch.Tensor): + # Emit ref_code_len per-token span for runner slicing (consumer takes the last value). + ref_code_len_list.append( + torch.full((int(ac.shape[0]),), ref_len_val, dtype=torch.int32, device=ac.device) + ) + + if not audio_codes_list: + return OmniOutput(text_hidden_states=hidden, multimodal_outputs={}) + + audio_codes = torch.cat(audio_codes_list, dim=0) + span_len = int(audio_codes.shape[0]) + hidden = hidden[:span_len] + mm: dict[str, torch.Tensor] = {"audio_codes": audio_codes} + if ref_code_len_list: + mm["ref_code_len"] = torch.cat(ref_code_len_list, dim=0)[:span_len] + if codec_streaming_list: + mm["codec_streaming"] = torch.cat(codec_streaming_list, dim=0)[:span_len] + return OmniOutput(text_hidden_states=hidden, multimodal_outputs=mm) + + # -------------------- preprocess / postprocess -------------------- + + def preprocess( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor | None, + **info_dict: Any, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + # Metadata may be passed flattened or under `additional_information`; normalize to flattened keys. + additional_information = info_dict.get("additional_information") + if isinstance(additional_information, dict): + merged: dict[str, Any] = {k: v for k, v in info_dict.items() if k != "additional_information"} + for k, v in additional_information.items(): + merged.setdefault(k, v) + info_dict = merged + + span_len = int(input_ids.shape[0]) + if span_len <= 0: + return input_ids, input_embeds if input_embeds is not None else self.embed_input_ids(input_ids), {} + + text_list = info_dict.get("text") + if not isinstance(text_list, list) or not text_list or not text_list[0]: + raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.") + + task_type = (info_dict.get("task_type") or ["CustomVoice"])[0] + non_streaming_mode_val = info_dict.get("non_streaming_mode") + if isinstance(non_streaming_mode_val, list): + non_streaming_mode_raw = non_streaming_mode_val[0] if non_streaming_mode_val else None + else: + non_streaming_mode_raw = non_streaming_mode_val + if isinstance(non_streaming_mode_raw, bool): + non_streaming_mode = non_streaming_mode_raw + else: + non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") + codec_streaming_val = info_dict.get("codec_streaming") + if isinstance(codec_streaming_val, list): + codec_streaming_raw = codec_streaming_val[0] if codec_streaming_val else None + else: + codec_streaming_raw = codec_streaming_val + if isinstance(codec_streaming_raw, bool): + codec_streaming = codec_streaming_raw + else: + codec_streaming = task_type == "Base" + + if span_len > 1: + # Prefill (prompt embeddings) + prompt_embeds_cpu = info_dict.get("talker_prompt_embeds") + prompt_embeds = None + tts_pad_embed_cpu = info_dict.get("tts_pad_embed") + tts_pad_embed = None + if isinstance(tts_pad_embed_cpu, torch.Tensor) and tts_pad_embed_cpu.numel() > 0: + tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + + if prompt_embeds is None: + full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len = self._build_prompt_embeds( + task_type=task_type, info_dict=info_dict + ) + # Store full prompt embeddings + trailing queue on CPU for later chunks/steps. + prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous() + info_update: dict[str, Any] = { + "talker_prompt_embeds": prompt_embeds_cpu, + "tailing_text_hidden": tailing_text_hidden.detach().to("cpu").contiguous(), + "tts_pad_embed": tts_pad_embed.detach().to("cpu").contiguous(), + "talker_prefill_offset": 0, + "codec_streaming": codec_streaming, + } + if ref_code_len is not None: + info_update["ref_code_len"] = int(ref_code_len) + # Always return a span_len slice; if the scheduled placeholder is longer, pad with tts_pad_embed. + # This preserves placeholder/embedding alignment. + offset = 0 + s = 0 + e = span_len + take = prompt_embeds_cpu[s:e] + if int(take.shape[0]) < span_len: + pad_n = int(span_len - int(take.shape[0])) + pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1) + take = torch.cat([take, pad_rows], dim=0) + prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) + info_update["talker_prefill_offset"] = int(offset + span_len) + else: + # Subsequent prefill chunk: slice from our own running offset. + if not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2: + raise RuntimeError("Invalid talker_prompt_embeds in additional_information.") + if tts_pad_embed is None: + raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.") + offset = int(info_dict.get("talker_prefill_offset", 0) or 0) + if offset < 0: + offset = 0 + s = max(0, min(offset, int(prompt_embeds_cpu.shape[0]))) + e = max(0, min(offset + span_len, int(prompt_embeds_cpu.shape[0]))) + take = prompt_embeds_cpu[s:e] + if int(take.shape[0]) < span_len: + pad_n = int(span_len - int(take.shape[0])) + pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1) + take = torch.cat([take, pad_rows], dim=0) + prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) + info_update = {"talker_prefill_offset": int(offset + span_len)} + info_update["codec_streaming"] = codec_streaming + + # When inputs_embeds is set, token ids are ignored by the model but must stay in-vocab for vLLM bookkeeping. + input_ids_out = input_ids.clone() + input_ids_out[:] = int(self.talker_config.codec_pad_id) + + zeros = torch.zeros( + (prompt_embeds.shape[0], int(self.talker_config.num_code_groups)), + device=input_ids.device, + dtype=torch.long, + ) + info_update["audio_codes"] = zeros + return input_ids_out, prompt_embeds, info_update + + # Decode: span_len == 1 + # Pop one text-step vector from tailing_text_hidden queue. + tts_pad_embed_cpu = info_dict.get("tts_pad_embed") + if not isinstance(tts_pad_embed_cpu, torch.Tensor): + raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") + tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + + tail_cpu = info_dict.get("tailing_text_hidden") + if isinstance(tail_cpu, torch.Tensor) and tail_cpu.ndim == 2 and tail_cpu.shape[0] > 0: + text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0] + else: + text_step = tts_pad_embed + new_tail = tail_cpu if isinstance(tail_cpu, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1])) + + last_hidden_cpu = info_dict.get("last_talker_hidden") + if not isinstance(last_hidden_cpu, torch.Tensor): + raise RuntimeError("Missing `last_talker_hidden` in additional_information; postprocess must run.") + past_hidden = last_hidden_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + + # Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update. + last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to( + device=input_ids.device, dtype=torch.bfloat16 + ) + inputs_embeds_out = last_id_hidden.reshape(1, -1) + + info_update = { + "tailing_text_hidden": new_tail, + "mtp_inputs": (past_hidden, text_step), + "codec_streaming": codec_streaming, + } + return input_ids, inputs_embeds_out, info_update + + def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: + # Keep the last token hidden for the next decode step's code predictor. + if hidden_states.numel() == 0: + return {} + last = hidden_states[-1, :].detach().to("cpu").contiguous() + return {"last_talker_hidden": last} + + # -------------------- prompt construction helpers -------------------- + + def _get_tokenizer(self): + if self._tokenizer is None: + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=True, + fix_mistral_regex=True, + use_fast=True, + ) + self._tokenizer.padding_side = "left" + return self._tokenizer + + @staticmethod + def _build_assistant_text(text: str) -> str: + return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + @staticmethod + def _build_ref_text(text: str) -> str: + return f"<|im_start|>assistant\n{text}<|im_end|>\n" + + @staticmethod + def _build_instruct_text(instruct: str) -> str: + return f"<|im_start|>user\n{instruct}<|im_end|>\n" + + @staticmethod + def estimate_prompt_len_from_additional_information( + additional_information: dict[str, Any] | None, + *, + task_type: str, + tokenize_prompt: Callable[[str], list[int]], + codec_language_id: Mapping[str, int] | None, + spk_is_dialect: Mapping[str, object] | None, + estimate_ref_code_len: Callable[[object], int | None] | None = None, + ) -> int: + """Compute Stage-0 placeholder prompt length (length-only mirror of `_build_prompt_embeds()`). + It must match the model-side `inputs_embeds` length to avoid extra padding and quality drop.""" + + def _first(x: object, default: object) -> object: + if isinstance(x, list): + return x[0] if x else default + return x if x is not None else default + + info: dict[str, Any] = additional_information or {} + text = _first(info.get("text"), "") + language = _first(info.get("language"), "Auto") + speaker = _first(info.get("speaker"), "") + instruct = _first(info.get("instruct"), "") + non_streaming_mode_raw = _first(info.get("non_streaming_mode"), None) + + if isinstance(non_streaming_mode_raw, bool): + non_streaming_mode = non_streaming_mode_raw + else: + # Official defaults: CustomVoice/VoiceDesign -> non_streaming_mode=True; Base -> False. + non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") + + if not isinstance(text, str): + text = "" + if not isinstance(instruct, str): + instruct = "" + if not isinstance(language, str): + language = "Auto" + + instruct_len = 0 + if instruct.strip(): + instruct_text = Qwen3TTSTalkerForConditionalGenerationARVLLM._build_instruct_text(instruct) + instruct_len = len(tokenize_prompt(instruct_text)) + + # ---- codec prefix portion (matches _build_prompt_embeds) ---- + language_id = None + if language.lower() != "auto" and codec_language_id: + language_id = codec_language_id.get(language.lower()) + if ( + language_id is None + and codec_language_id + and spk_is_dialect + and isinstance(language, str) + and language.lower() in ("chinese", "auto") + and isinstance(speaker, str) + and speaker.strip() + ): + dialect = spk_is_dialect.get(speaker.lower()) + if isinstance(dialect, str) and dialect: + language_id = codec_language_id.get(dialect) + prefill_len = 3 if language_id is None else 4 + + speaker_len = 1 if task_type in ("CustomVoice", "Base") else 0 + codec_input_len = prefill_len + speaker_len + 2 # + [codec_pad, codec_bos] + codec_prefix_len = codec_input_len - 1 # codec_input[:-1] + tts_bos + + # Role header: input_ids[:, :3] in model. + role_len = 3 + prompt_len = instruct_len + role_len + codec_prefix_len + + # ---- text conditioning portion (matches _build_prompt_embeds) ---- + assistant_text = Qwen3TTSTalkerForConditionalGenerationARVLLM._build_assistant_text(text) + assistant_len = len(tokenize_prompt(assistant_text)) + if assistant_len < 8: + raise ValueError(f"Unexpected assistant prompt length: {assistant_len}") + + if task_type in ("CustomVoice", "VoiceDesign"): + if non_streaming_mode: + # model: full text ids (input_ids[:, 3:-5]) + eos + codec_bos step + prompt_len += assistant_len - 6 + else: + # model: only first text token in prefill + prompt_len += 1 + + if task_type == "Base": + xvec_only = bool(_first(info.get("x_vector_only_mode"), False)) + in_context_mode = not xvec_only + + voice_clone_prompt = _first(info.get("voice_clone_prompt"), None) + if isinstance(voice_clone_prompt, dict): + icl_flag = _first(voice_clone_prompt.get("icl_mode"), None) + if isinstance(icl_flag, bool): + in_context_mode = icl_flag + + if in_context_mode: + ref_code = None + if isinstance(voice_clone_prompt, dict): + ref_code = _first(voice_clone_prompt.get("ref_code"), None) + + ref_code_len: int | None = None + if isinstance(ref_code, list): + if ref_code and isinstance(ref_code[0], list): + ref_code_len = len(ref_code) + elif ref_code: + ref_code_len = len(ref_code) + elif hasattr(ref_code, "shape"): + try: + shape = getattr(ref_code, "shape") + if shape and len(shape) >= 1: + ref_code_len = int(shape[0]) + except Exception: + ref_code_len = None + + if ref_code_len is None and estimate_ref_code_len is not None: + ref_code_len = estimate_ref_code_len(info.get("ref_audio")) + + if ref_code_len is None: + raise ValueError( + "Base in-context voice cloning requires either `voice_clone_prompt.ref_code` " + "or a readable `ref_audio` that can be mapped to a codec frame length." + ) + + codec_lens = 1 + int(ref_code_len) # codec_bos + ref_code + if non_streaming_mode: + # _generate_icl_prompt(non_streaming_mode=True): + # text_embed = ref_ids + text_ids + eos. + ref_ids = _first(info.get("ref_ids"), None) + if isinstance(voice_clone_prompt, dict) and ref_ids is None: + ref_ids = _first(voice_clone_prompt.get("ref_ids") or voice_clone_prompt.get("ref_id"), None) + + if ref_ids is None: + ref_text = _first(info.get("ref_text"), "") + if not isinstance(ref_text, str) or not ref_text.strip(): + raise ValueError( + "Base in-context non-streaming requires `ref_text` or tokenized `ref_ids`." + ) + ref_text_ids = tokenize_prompt( + Qwen3TTSTalkerForConditionalGenerationARVLLM._build_ref_text(ref_text) + ) + ref_ids_len = len(ref_text_ids) + elif hasattr(ref_ids, "shape"): + shape = getattr(ref_ids, "shape", None) + ref_ids_len = int(shape[-1]) if shape else 0 + elif isinstance(ref_ids, list): + ref_ids_len = len(ref_ids) + else: + ref_ids_len = 0 + + # model uses ref_ids[:, 3:-2] (strip 5 tokens) and text_id=input_ids[:, 3:-5] (strip 8). + ref_id_len = max(0, int(ref_ids_len) - 5) + text_id_len = max(0, int(assistant_len) - 8) + text_embed_len = ref_id_len + text_id_len + 1 # + eos + prompt_len += text_embed_len + codec_lens + else: + # _generate_icl_prompt(non_streaming_mode=False): aligned to codec_lens. + prompt_len += codec_lens + else: + # Base without ICL behaves like CustomVoice. + if non_streaming_mode: + prompt_len += assistant_len - 6 + else: + prompt_len += 1 + + return max(2, int(prompt_len)) + + def _is_probably_base64(self, s: str) -> bool: + if s.startswith("data:audio"): + return True + if ("/" not in s and "\\" not in s) and len(s) > 256: + return True + return False + + def _is_url(self, s: str) -> bool: + try: + u = urlparse(s) + return u.scheme in ("http", "https") and bool(u.netloc) + except Exception: + return False + + def _decode_base64_to_wav_bytes(self, b64: str) -> bytes: + if "," in b64 and b64.strip().startswith("data:"): + b64 = b64.split(",", 1)[1] + return base64.b64decode(b64) + + def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]: + import librosa + + if self._is_url(x): + with urlopen(x) as resp: + audio_bytes = resp.read() + with io.BytesIO(audio_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + elif self._is_probably_base64(x): + wav_bytes = self._decode_base64_to_wav_bytes(x) + with io.BytesIO(wav_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + else: + audio, sr = librosa.load(x, sr=None, mono=True) + + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = np.mean(audio, axis=-1) + + return np.asarray(audio, dtype=np.float32), int(sr) + + def _normalize_ref_audio(self, ref_audio: object) -> tuple[np.ndarray, int]: + # NOTE: additional_information may serialize (wav, sr) into (nested) lists across processes; be tolerant. + if isinstance(ref_audio, str): + return self._load_audio_to_np(ref_audio) + + def _is_sr(x: object) -> bool: + try: + v = int(x) # type: ignore[arg-type] + except Exception: + return False + return 1_000 <= v <= 200_000 + + def _is_number_sequence(xs: list[object]) -> bool: + if not xs: + return False + for v in xs[:8]: + if not isinstance(v, (int, float, np.number)): + return False + return True + + wav_candidates: list[object] = [] + sr_candidates: list[int] = [] + + def _summarize(obj: object, depth: int = 0) -> str: + if depth > 2: + if isinstance(obj, (int, np.integer)): + return f"int({int(obj)})" + return type(obj).__name__ + if obj is None: + return "None" + if isinstance(obj, str): + if len(obj) <= 16: + return f"str({obj!r})" + return f"str(len={len(obj)})" + if isinstance(obj, (int, float, np.number)): + return f"{type(obj).__name__}({obj})" + if isinstance(obj, np.ndarray): + return f"ndarray(shape={obj.shape}, dtype={obj.dtype})" + if isinstance(obj, torch.Tensor): + return f"Tensor(shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device})" + if isinstance(obj, dict): + keys = list(obj.keys()) + return f"dict(keys={keys[:8]})" + if isinstance(obj, (tuple, list)): + items = list(obj) + head = ", ".join(_summarize(x, depth + 1) for x in items[:3]) + return f"{type(obj).__name__}(len={len(items)}; head=[{head}])" + return f"{type(obj).__name__}" + + def _scan(obj: object, depth: int = 0) -> None: + if depth > 4: + return + if obj is None: + return + if _is_sr(obj): + sr_candidates.append(int(obj)) # type: ignore[arg-type] + return + if isinstance(obj, np.ndarray) and obj.size > 0: + wav_candidates.append(obj) + return + if isinstance(obj, torch.Tensor) and obj.numel() > 0: + wav_candidates.append(obj) + return + if isinstance(obj, dict): + # Inlined ndarray/tensor payloads from OmniInputProcessor. + if obj.get("__ndarray__") and "data" in obj and "dtype" in obj and "shape" in obj: + try: + data = obj["data"] + dtype = obj["dtype"] + shape = obj["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + arr = np.frombuffer(data, dtype=dtype).reshape(shape) + if arr.size > 0: + wav_candidates.append(arr) + return + except Exception: + pass + if obj.get("__tensor__") and "data" in obj and "dtype" in obj and "shape" in obj: + try: + data = obj["data"] + dtype = obj["dtype"] + shape = obj["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + # Stored as raw CPU bytes; interpret as numpy for audio. + np_dtype = np.dtype(dtype) + arr = np.frombuffer(data, dtype=np_dtype).reshape(shape) + if arr.size > 0: + wav_candidates.append(arr) + return + except Exception: + pass + wav_obj = obj.get("array") or obj.get("wav") or obj.get("audio") + sr_obj = obj.get("sampling_rate") or obj.get("sr") or obj.get("sample_rate") + if wav_obj is not None: + _scan(wav_obj, depth + 1) + if sr_obj is not None: + _scan(sr_obj, depth + 1) + return + if isinstance(obj, (tuple, list)): + obj_list = list(obj) + # Unwrap singleton nesting ([[wav, sr]]). + while isinstance(obj_list, list) and len(obj_list) == 1: + inner = obj_list[0] + if isinstance(inner, np.ndarray) and inner.size > 0: + wav_candidates.append(inner) + return + if isinstance(inner, torch.Tensor) and inner.numel() > 0: + wav_candidates.append(inner) + return + if isinstance(inner, dict): + _scan(inner, depth + 1) + return + if isinstance(inner, (tuple, list)): + obj_list = list(inner) # type: ignore[list-item] + continue + break + + # If the *unwrapped* list is a long list of numbers, treat it as waveform. + if len(obj_list) >= 512 and _is_number_sequence(obj_list): + wav_candidates.append(obj_list) + return + + # If this is a long list of numbers, treat it as waveform and stop. + if isinstance(obj, list) and len(obj) >= 512 and _is_number_sequence(obj_list): # type: ignore[arg-type] + wav_candidates.append(obj) + return + + # Otherwise, recurse into elements (but avoid descending into huge numeric lists). + for item in obj_list: + if isinstance(item, list) and len(item) >= 512 and _is_number_sequence(item): # type: ignore[arg-type] + wav_candidates.append(item) + continue + _scan(item, depth + 1) + return + + _scan(ref_audio) + if not sr_candidates: + raise TypeError(f"ref_audio missing sample_rate: {_summarize(ref_audio)}") + sr = int(sr_candidates[0]) + + def _wav_len(x: object) -> int: + try: + if isinstance(x, np.ndarray): + return int(x.size) + if isinstance(x, torch.Tensor): + return int(x.numel()) + if isinstance(x, list): + return int(len(x)) + except Exception: + pass + return 0 + + if not wav_candidates: + raise TypeError(f"ref_audio missing waveform: {_summarize(ref_audio)}") + wav_obj = max(wav_candidates, key=_wav_len) + + def _to_np(x: object) -> np.ndarray: + if isinstance(x, np.ndarray): + return x.astype(np.float32).reshape(-1) + if isinstance(x, torch.Tensor): + return x.detach().to("cpu").float().contiguous().numpy().reshape(-1) + if isinstance(x, dict) and x.get("__ndarray__") and "data" in x and "dtype" in x and "shape" in x: + data = x["data"] + dtype = x["dtype"] + shape = x["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + return np.frombuffer(data, dtype=dtype).reshape(shape).astype(np.float32).reshape(-1) + if isinstance(x, list): + # list of numbers + if len(x) >= 2 and _is_number_sequence(x): # type: ignore[arg-type] + return np.asarray(x, dtype=np.float32).reshape(-1) + # list of chunks + parts: list[np.ndarray] = [] + for part in x: + if isinstance(part, (np.ndarray, torch.Tensor, list)): + parts.append(_to_np(part)) + if parts: + return np.concatenate(parts, axis=0) + raise TypeError(f"Unsupported waveform type: {type(x)}") + + wav_np = _to_np(wav_obj) + if wav_np.size < 1024: + raise ValueError(f"ref_audio waveform too short: {wav_np.size} samples") + return wav_np, sr + raise TypeError(f"Unsupported ref_audio type: {type(ref_audio)}") + + def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: + if self.speaker_encoder is None: + raise ValueError( + "This checkpoint does not provide `speaker_encoder` weights; " + "cannot compute ref_spk_embedding from ref_audio." + ) + # vLLM workers do not automatically move arbitrary torch.nn.Modules to + # CUDA. Ensure the speaker encoder is on the same device/dtype as the + # main model before running it. + dev = next(self.parameters()).device + try: + spk_param = next(self.speaker_encoder.parameters()) + if spk_param.device != dev or spk_param.dtype != torch.bfloat16: + self.speaker_encoder.to(device=dev, dtype=torch.bfloat16) + except StopIteration: + pass + # Resample to 24kHz for speaker encoder. + target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000)) + if sr != target_sr: + import librosa + + wav = librosa.resample(y=wav.astype(np.float32), orig_sr=int(sr), target_sr=target_sr) + sr = target_sr + + # Follow official implementation: mel_spectrogram expects 24kHz. + mels = mel_spectrogram( + torch.from_numpy(wav).unsqueeze(0), + n_fft=1024, + num_mels=128, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=12000, + ).transpose(1, 2) + spk = self.speaker_encoder(mels.to(dev, dtype=torch.bfloat16))[0] + return spk.to(dtype=torch.bfloat16) + + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: + if self._speech_tokenizer is not None: + return self._speech_tokenizer + speech_tokenizer_path = cached_file(self.model_path, "speech_tokenizer/config.json") + if speech_tokenizer_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") + # Ensure the HF feature extractor config is present. Transformers' + # AutoFeatureExtractor does not proactively fetch this file. + preprocessor_config_path = cached_file(self.model_path, "speech_tokenizer/preprocessor_config.json") + if preprocessor_config_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/preprocessor_config.json not found") + speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) + tok = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + torch_dtype=torch.bfloat16, + ) + # Prefer GPU for encoder if available; otherwise keep CPU. + dev = next(self.parameters()).device + if getattr(dev, "type", None) == "cuda": + try: + tok.model.to(dev) + tok.device = dev + except Exception as e: + raise RuntimeError(f"Failed to move speech tokenizer to {dev}: {e}") from e + else: + tok.device = dev + self._speech_tokenizer = tok + return tok + + def _encode_ref_audio_to_code(self, wav: np.ndarray, sr: int) -> torch.Tensor: + tok = self._ensure_speech_tokenizer_loaded() + enc = tok.encode(wav, sr=int(sr), return_dict=True) + ref_code = getattr(enc, "audio_codes", None) + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if isinstance(ref_code, torch.Tensor): + # 12Hz: likely [T, Q] or [B, T, Q] + if ref_code.ndim == 3: + ref_code = ref_code[0] + return ref_code.to(device=next(self.parameters()).device, dtype=torch.long) + raise ValueError("SpeechTokenizer.encode did not return audio_codes tensor") + + def _generate_icl_prompt( + self, + *, + text_id: torch.Tensor, + ref_id: torch.Tensor, + ref_code: torch.Tensor, + tts_pad_embed: torch.Tensor, + tts_eos_embed: torch.Tensor, + non_streaming_mode: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Ported from official `generate_icl_prompt` in modeling_qwen3_tts.py + text_embed = self.text_projection(self.text_embedding(torch.cat([ref_id, text_id], dim=-1))) + text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) + + # codec embed (codec bos + codec) 1 T2 D + codec_embed: list[torch.Tensor] = [] + for i in range(int(self.talker_config.num_code_groups)): + if i == 0: + codec_embed.append(self.embed_input_ids(ref_code[:, :1])) + else: + codec_embed.append(self.code_predictor.get_input_embeddings()[i - 1](ref_code[:, i : i + 1])) + codec_embed_sum = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0) # [1,T,H] + codec_embed_sum = torch.cat( + [ + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=codec_embed_sum.device, dtype=torch.long) + ), + codec_embed_sum, + ], + dim=1, + ) + + text_lens = int(text_embed.shape[1]) + codec_lens = int(codec_embed_sum.shape[1]) + if non_streaming_mode: + # Official non-streaming mode: append the full text conditioning in + # prefill, and use PAD in decode steps. + icl_input_embed = text_embed + self.embed_input_ids( + torch.tensor( + [[self.talker_config.codec_pad_id] * text_lens], + device=codec_embed_sum.device, + dtype=torch.long, + ) + ) + icl_input_embed = torch.cat([icl_input_embed, codec_embed_sum + tts_pad_embed], dim=1) + return icl_input_embed, tts_pad_embed + if text_lens > codec_lens: + return text_embed[:, :codec_lens] + codec_embed_sum, text_embed[:, codec_lens:] + text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1) + return text_embed + codec_embed_sum, tts_pad_embed + + def _build_prompt_embeds( + self, + *, + task_type: str, + info_dict: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None]: + text = (info_dict.get("text") or [""])[0] + language = (info_dict.get("language") or ["Auto"])[0] + non_streaming_mode_val = info_dict.get("non_streaming_mode") + if isinstance(non_streaming_mode_val, list): + non_streaming_mode_raw = non_streaming_mode_val[0] if non_streaming_mode_val else None + else: + non_streaming_mode_raw = non_streaming_mode_val + if isinstance(non_streaming_mode_raw, bool): + non_streaming_mode = non_streaming_mode_raw + else: + # Match official inference defaults: + # - CustomVoice/VoiceDesign: non_streaming_mode=True + # - Base: non_streaming_mode=False + non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") + + # Text ids for assistant template (always). + tok = self._get_tokenizer() + input_ids = tok(self._build_assistant_text(text), return_tensors="pt", padding=False)["input_ids"].to( + device=next(self.parameters()).device + ) + + # Optional instruct prefix. + instruct = (info_dict.get("instruct") or [""])[0] + instruct_embed = None + if isinstance(instruct, str) and instruct.strip(): + instruct_ids = tok(self._build_instruct_text(instruct), return_tensors="pt", padding=False)["input_ids"].to( + device=input_ids.device + ) + instruct_embed = self.text_projection(self.text_embedding(instruct_ids)) + + # tts special token embeds (projected into talker hidden). + tts_tokens = torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=input_ids.device, + dtype=input_ids.dtype, + ) + tts_bos_embed, tts_eos_embed, tts_pad_embed = self.text_projection(self.text_embedding(tts_tokens)).chunk( + 3, dim=1 + ) + + # Codec prefill tags. + language_id = None + if isinstance(language, str) and language.lower() != "auto": + language_id = self.talker_config.codec_language_id.get(language.lower()) + # Match official dialect override: + # If language is Chinese/Auto and the selected speaker is a dialect voice, + # set language_id to that dialect to improve code generation stability. + if language_id is None and isinstance(language, str) and language.lower() in ("chinese", "auto"): + speaker_for_dialect = None + if task_type == "CustomVoice": + speaker_for_dialect = (info_dict.get("speaker") or [""])[0] + if isinstance(speaker_for_dialect, str) and speaker_for_dialect.strip(): + spk_is_dialect = getattr(self.talker_config, "spk_is_dialect", None) or {} + dialect = spk_is_dialect.get(speaker_for_dialect.lower()) + if isinstance(dialect, str) and dialect: + language_id = self.talker_config.codec_language_id.get(dialect) + if language_id is None: + codec_prefill_list = [ + [ + self.talker_config.codec_nothink_id, + self.talker_config.codec_think_bos_id, + self.talker_config.codec_think_eos_id, + ] + ] + else: + codec_prefill_list = [ + [ + self.talker_config.codec_think_id, + self.talker_config.codec_think_bos_id, + int(language_id), + self.talker_config.codec_think_eos_id, + ] + ] + + codec_input_0 = self.embed_input_ids( + torch.tensor(codec_prefill_list, device=input_ids.device, dtype=torch.long) + ) + codec_input_1 = self.embed_input_ids( + torch.tensor([[self.talker_config.codec_pad_id, self.talker_config.codec_bos_id]], device=input_ids.device) + ) + + # Speaker embedding/token (task-dependent) + speaker_embed = None + ref_code_len: int | None = None + + def _as_singleton(x: object) -> object: + if isinstance(x, list): + return x[0] if x else None + return x + + def _to_long_tensor(x: object, *, device: torch.device) -> torch.Tensor | None: + x = _as_singleton(x) + if x is None: + return None + if isinstance(x, torch.Tensor): + t = x + elif isinstance(x, np.ndarray): + t = torch.from_numpy(x) + elif isinstance(x, list) and x and all(isinstance(v, (int, np.integer)) for v in x): + t = torch.tensor(x, dtype=torch.long) + else: + return None + if t.ndim == 1: + t = t.unsqueeze(0) + return t.to(device=device, dtype=torch.long) + + def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: + raw = _as_singleton(raw) + if raw is None: + return None + if isinstance(raw, dict): + return raw + # Some callers may pass list[dict] directly. + if isinstance(raw, list) and raw and isinstance(raw[0], dict): + return raw[0] + return None + + if task_type == "Base": + # Base supports voice clone prompt with in-context mode. + xvec_only = bool((info_dict.get("x_vector_only_mode") or [False])[0]) + in_context_mode = not xvec_only + voice_clone_prompt = _normalize_voice_clone_prompt(info_dict.get("voice_clone_prompt")) + # Official implementation may pass `voice_clone_prompt.icl_mode`. + if voice_clone_prompt is not None and "icl_mode" in voice_clone_prompt: + icl_flag = _as_singleton(voice_clone_prompt.get("icl_mode")) + if isinstance(icl_flag, bool): + in_context_mode = icl_flag + xvec_only = not in_context_mode + ref_code = None + if voice_clone_prompt is not None: + ref_code = _as_singleton(voice_clone_prompt.get("ref_code")) + ref_code_t = None + if isinstance(ref_code, torch.Tensor): + ref_code_t = ref_code + elif isinstance(ref_code, np.ndarray): + ref_code_t = torch.from_numpy(ref_code) + if isinstance(ref_code_t, torch.Tensor): + if ref_code_t.ndim == 3: + ref_code_t = ref_code_t[0] + ref_code_t = ref_code_t.to(device=input_ids.device, dtype=torch.long) + ref_code_len = int(ref_code_t.shape[0]) + elif in_context_mode: + # Compute ref_code from ref_audio if not provided. + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + raise ValueError("Base requires `ref_audio`.") + wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + ref_code_t = self._encode_ref_audio_to_code(wav_np, sr).to(device=input_ids.device) + ref_code_len = int(ref_code_t.shape[0]) + + # Speaker embedding: use prompt embed if provided; otherwise extract from audio. + spk = None + if voice_clone_prompt is not None: + spk = _as_singleton(voice_clone_prompt.get("ref_spk_embedding")) + if isinstance(spk, torch.Tensor): + speaker_embed = spk.to(device=input_ids.device, dtype=torch.bfloat16).view(1, 1, -1) + else: + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + raise ValueError("Base requires `ref_audio`.") + wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + speaker_embed = self._extract_speaker_embedding(wav_np, sr).view(1, 1, -1) + + codec_input = torch.cat([codec_input_0, speaker_embed, codec_input_1], dim=1) + + # Role header (<|im_start|>assistant\n) -> projected text embeds. + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if in_context_mode: + # Prefer explicit tokenized `ref_ids` if provided (matches official signature). + ref_ids = _to_long_tensor(info_dict.get("ref_ids"), device=input_ids.device) + if ref_ids is None and voice_clone_prompt is not None: + ref_ids = _to_long_tensor( + voice_clone_prompt.get("ref_ids") or voice_clone_prompt.get("ref_id"), device=input_ids.device + ) + if ref_ids is None: + ref_text = _as_singleton(info_dict.get("ref_text")) + if not isinstance(ref_text, str) or not ref_text.strip(): + raise ValueError("Base in-context voice cloning requires `ref_text` or tokenized `ref_ids`.") + ref_ids = tok(self._build_ref_text(ref_text), return_tensors="pt", padding=False)["input_ids"].to( + device=input_ids.device + ) + icl_input_embed, trailing_text_hidden = self._generate_icl_prompt( + text_id=input_ids[:, 3:-5], + ref_id=ref_ids[:, 3:-2], + ref_code=ref_code_t, # type: ignore[arg-type] + tts_pad_embed=tts_pad_embed, + tts_eos_embed=tts_eos_embed, + non_streaming_mode=non_streaming_mode, + ) + talker_prompt = torch.cat([talker_prompt, icl_input_embed], dim=1) + else: + # First text token (+ codec_bos). + if non_streaming_mode: + # Official non-streaming mode: put the full text into the + # prefill prompt and use PAD for decode steps. + text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + + elif task_type == "CustomVoice": + speaker = (info_dict.get("speaker") or [""])[0] + if not isinstance(speaker, str) or not speaker.strip(): + raise ValueError("CustomVoice requires additional_information.speaker.") + spk_id_map = getattr(self.talker_config, "spk_id", None) or {} + if speaker.lower() not in spk_id_map: + raise ValueError(f"Unsupported speaker: {speaker}") + spk_id = spk_id_map[speaker.lower()] + # Keep it at least 1D; embedding on a 0-d tensor can return 1D. + spk_tensor = torch.tensor([spk_id], device=input_ids.device, dtype=torch.long) + spk_embed = self.embed_input_ids(spk_tensor) + if spk_embed.ndim == 1: + spk_embed = spk_embed.view(1, 1, -1) + elif spk_embed.ndim == 2: + spk_embed = spk_embed.view(1, 1, -1) + speaker_embed = spk_embed + codec_input = torch.cat([codec_input_0, speaker_embed, codec_input_1], dim=1) + + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if non_streaming_mode: + text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + + elif task_type == "VoiceDesign": + # No known speaker identity; only codec tags + text. + codec_input = torch.cat([codec_input_0, codec_input_1], dim=1) + + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if non_streaming_mode: + text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + else: + raise ValueError(f"Unsupported task_type={task_type}") + + if instruct_embed is not None: + talker_prompt = torch.cat([instruct_embed, talker_prompt], dim=1) + + return ( + talker_prompt.squeeze(0), # [prompt_len, H] + trailing_text_hidden.squeeze(0), # [T, H] + tts_pad_embed.squeeze(0), # [1, H] + ref_code_len, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Consume talker weights, and conditionally consume speaker encoder + # weights only if they are present in the checkpoint. + speaker_weights: list[tuple[str, torch.Tensor]] = [] + + def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]): + for k, v in ws: + if k.startswith("speaker_encoder."): + speaker_weights.append((k, v)) + continue + if k.startswith("talker."): + yield k, v + + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(_talker_and_collect_speaker(weights), mapper=self.hf_to_vllm_mapper) + + if speaker_weights: + if self.speaker_encoder is None: + self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) + loaded |= loader.load_weights(speaker_weights, mapper=self.hf_to_vllm_mapper) + logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGenerationARVLLM", len(loaded)) + return loaded + + # -------------------- GPU-side MTP fast-path -------------------- + + @torch.inference_mode() + def talker_mtp( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1). + Returns (inputs_embeds, audio_codes) for the current step.""" + bsz = int(input_ids.shape[0]) + q = int(self.talker_config.num_code_groups) + dev = input_embeds.device + + input_ids = input_ids.reshape(bsz, 1).to(dtype=torch.long, device=dev) + last_id_hidden = input_embeds.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + past_hidden = last_talker_hidden.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + text_step = text_step.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + + # Residual predictor runs fixed-length (Q-1) steps via the vLLM-native code_predictor. + max_steps = q - 1 + if max_steps <= 0: + audio_codes = input_ids.reshape(bsz, 1) + return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes + + # Subtalker sampling defaults (match official defaults). + do_sample = True + top_k = 50 + top_p = 1.0 + temperature = 0.9 + + def _sample_next(logits: torch.Tensor) -> torch.Tensor: + # logits: [B,V] + if temperature and float(temperature) > 0: + logits = logits / float(temperature) + if top_k and int(top_k) > 0 and int(top_k) < logits.shape[-1]: + v, _ = torch.topk(logits, int(top_k), dim=-1) + min_keep = v[:, -1].unsqueeze(-1) + logits = torch.where(logits < min_keep, torch.tensor(float("-inf"), device=logits.device), logits) + if top_p is not None and 0.0 < float(top_p) < 1.0: + sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) + probs = torch.softmax(sorted_logits, dim=-1) + cum = torch.cumsum(probs, dim=-1) + remove = cum > float(top_p) + remove[:, 0] = False + sorted_logits = torch.where(remove, torch.tensor(float("-inf"), device=logits.device), sorted_logits) + logits = torch.empty_like(logits).scatter(-1, sorted_idx, sorted_logits) + if not do_sample: + return torch.argmax(logits, dim=-1, keepdim=True) + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1) + + predictor_inputs = torch.cat([past_hidden, last_id_hidden], dim=1) # [B,2,H] + self.code_predictor.reset_cache() + tok = _sample_next(self.code_predictor.prefill_logits(predictor_inputs)) + residual_ids = [tok] + past_seq_len = 2 + for step in range(1, max_steps): + logits = self.code_predictor.decode_logits(tok, generation_step=step, past_seq_len=past_seq_len) + tok = _sample_next(logits) + residual_ids.append(tok) + past_seq_len += 1 + + residual_ids_t = torch.cat(residual_ids, dim=1).to(dtype=torch.long, device=dev) # [B, Q-1] + audio_codes = torch.cat([input_ids, residual_ids_t], dim=1) # [B,Q] + + # Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes. + # vLLM still uses EOS for stopping. + layer0 = audio_codes[:, :1] + invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size)) + if invalid0.any(): + audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) + + # Sum embeddings of all code groups, then add the current text step. + embeds: list[torch.Tensor] = [last_id_hidden] + for i in range(max_steps): + embeds.append(self.code_predictor.get_input_embeddings()[i](residual_ids_t[:, i : i + 1])) + summed = torch.cat(embeds, dim=1).sum(1, keepdim=True) # [B,1,H] + inputs_embeds_out = (summed + text_step).reshape(bsz, -1) + return inputs_embeds_out, audio_codes.to(dtype=torch.long) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py index e6e50211988..30c0d832d59 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -14,8 +14,11 @@ # limitations under the License. import base64 import io +import json +from pathlib import Path import urllib.request from urllib.parse import urlparse +from typing import Any import librosa import numpy as np @@ -23,17 +26,13 @@ import torch from torch.nn.utils.rnn import pad_sequence from transformers import AutoConfig, AutoFeatureExtractor, AutoModel +from transformers.utils.hub import cached_file from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( Qwen3TTSTokenizerV2EncoderOutput, Qwen3TTSTokenizerV2Model, ) -from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config -from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import ( - Qwen3TTSTokenizerV1EncoderOutput, - Qwen3TTSTokenizerV1Model, -) AudioInput = ( str # wav path, or base64 string @@ -62,6 +61,18 @@ def __init__(self): self.config = None self.device = None + @staticmethod + def _resolve_local_config_path(pretrained_model_name_or_path: str) -> str: + p = Path(pretrained_model_name_or_path) + if p.is_dir(): + cfg = p / "config.json" + if cfg.exists(): + return str(cfg) + cfg = cached_file(pretrained_model_name_or_path, "config.json") + if cfg is None: + raise ValueError(f"config.json not found under {pretrained_model_name_or_path!r}") + return str(cfg) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer": """ @@ -80,16 +91,45 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3 """ inst = cls() - AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) - AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) + load_feature_extractor = bool(kwargs.pop("load_feature_extractor", True)) + # Register 12Hz tokenizer (no optional deps). AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config) AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model) - inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) + # Register 25Hz tokenizer only when needed (avoids importing optional 25Hz backends for 12Hz models). + cfg_path = cls._resolve_local_config_path(pretrained_model_name_or_path) + try: + cfg_json = json.loads(Path(cfg_path).read_text()) + except Exception as e: + raise ValueError(f"Failed to parse tokenizer config.json at {cfg_path!r}: {e}") from e + + model_type = str(cfg_json.get("model_type") or "") + if model_type == "qwen3_tts_tokenizer_25hz": + from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config + from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model + + AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) + AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) + elif model_type != "qwen3_tts_tokenizer_12hz": + raise ValueError(f"Unsupported Qwen3-TTS tokenizer model_type={model_type!r} at {cfg_path!r}") + inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) inst.config = inst.model.config + if load_feature_extractor: + try: + inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) + except Exception as e: + raise ValueError( + "Failed to load Qwen3-TTS speech tokenizer feature extractor. " + "Please make sure the checkpoint contains the required HF " + "preprocessing files (e.g. preprocessor_config.json) under " + "the speech_tokenizer directory." + ) from e + else: + inst.feature_extractor = None + inst.device = getattr(inst.model, "device", None) if inst.device is None: # fallback: infer from first parameter device @@ -180,6 +220,8 @@ def _normalize_audio_inputs( List[np.ndarray]: List of float32 waveforms resampled to model input SR. """ + if self.feature_extractor is None: + raise ValueError("Speech tokenizer feature extractor is not loaded; audio encode is not available.") target_sr = int(self.feature_extractor.sampling_rate) if isinstance(audios, (str, np.ndarray)): @@ -212,12 +254,7 @@ def encode( audios: AudioInput, sr: int | None = None, return_dict: bool = True, - ) -> ( - Qwen3TTSTokenizerV1EncoderOutput - | Qwen3TTSTokenizerV2EncoderOutput - | tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None] - | tuple[list[torch.Tensor]] - ): + ) -> Any: """ Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz). @@ -234,7 +271,7 @@ def encode( Forwarded to model.encode(...). If True, returns ModelOutput. Returns: - Qwen3TTSTokenizerV1EncoderOutput | Qwen3TTSTokenizerV2EncoderOutput | tuple: + Any: Encoder output or tuple returned by model.encode. If return_dict=True, returns a model-specific encoder output. For 25Hz models, this includes audio_codes/xvectors/ref_mels; for 12Hz models, this includes audio_codes. @@ -242,18 +279,24 @@ def encode( """ wavs = self._normalize_audio_inputs(audios, sr=sr) + if self.feature_extractor is None: + raise ValueError("Speech tokenizer feature extractor is not loaded; audio encode is not available.") inputs = self.feature_extractor( raw_audio=wavs, sampling_rate=int(self.feature_extractor.sampling_rate), return_tensors="pt", ) - inputs = inputs.to(self.device).to(self.model.dtype) + # Normalize to tensors and keep padding_mask integer (tokenizer expects 0/1). + input_values = inputs["input_values"].squeeze(1).to(self.device).to(self.model.dtype) + padding_mask = inputs["padding_mask"].squeeze(1).to(self.device) + if padding_mask.dtype == torch.bool: + padding_mask = padding_mask.to(torch.long) with torch.inference_mode(): # model.encode expects (B, T) and (B, T) enc = self.model.encode( - inputs["input_values"].squeeze(1), - inputs["padding_mask"].squeeze(1), + input_values, + padding_mask, return_dict=return_dict, ) return enc diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 747ca8f0cdd..b794ad73167 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -48,10 +48,25 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), - "Qwen3TTSForConditionalGeneration": ( + "Qwen3TTSTalkerForConditionalGenerationARVLLM": ( + "qwen3_tts", + "qwen3_tts_talker_ar", + "Qwen3TTSTalkerForConditionalGenerationARVLLM", + ), + "Qwen3TTSCode2Wav": ( "qwen3_tts", + "qwen3_tts_code2wav", + "Qwen3TTSCode2Wav", + ), + "Qwen3TTSForConditionalGenerationDisaggregatedVLLM": ( + "qwen3_tts", + "qwen3_tts_disaggregated", + "Qwen3TTSForConditionalGenerationDisaggregatedVLLM", + ), + "Qwen3TTSForConditionalGeneration": ( "qwen3_tts", - "Qwen3TTSModelForGeneration", + "qwen3_tts_disaggregated", + "Qwen3TTSForConditionalGenerationDisaggregatedVLLM", ), } diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index d408dbab91e..1db64fda791 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -1,22 +1,97 @@ +async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm runtime: devices: "0" max_batch_size: 1 engine_args: - model_stage: qwen3_tts - model_arch: Qwen3TTSForConditionalGeneration + model_stage: talker + model_arch: Qwen3TTSTalkerForConditionalGenerationARVLLM + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGenerationARVLLM] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2speech_tokenizer_async_chunk + # Use named connector to apply runtime.connectors.extra. + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSCode2Wav] worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true async_scheduling: false enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 + engine_output_type: audio + gpu_memory_utilization: 0.2 distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - + # Must be divisible by num_code_groups and cover (left_context + chunk). + max_num_batched_tokens: 8192 + # async_chunk appends windows per step; max_model_len must cover accumulated stream. + max_model_len: 32768 + engine_input_source: [0] final_output: true final_output_type: audio + # Distributed connector configuration + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Frame-aligned codec streaming transport. + codec_streaming: true + # Match official chunked_decode defaults. + codec_chunk_frames: 300 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py new file mode 100644 index 00000000000..0ebf9abdbba --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -0,0 +1,84 @@ +"""Stage input processor for Qwen3-TTS: Talker → SpeechTokenizer transition.""" + +from typing import Any + +import torch + + +def talker2speech_tokenizer_async_chunk( + pooling_output: dict[str, Any], + request: Any, +) -> dict[str, Any] | None: + """Async-chunk payload extractor for Qwen3-TTS Talker → SpeechTokenizer. + + Stage-0 emits per-step codec codes; they are sent via connector and consumed by Stage-1 as `prompt_token_ids`. + Returns: `code_predictor_codes` (List[int]) / `codec_streaming` (bool) / `finished` (torch.bool). + """ + if not isinstance(pooling_output, dict): + return None + + # `codec_streaming` is the cross-stage streaming toggle (not the official `non_streaming_mode`). + # It can be overridden per request. + info = getattr(request, "additional_information_cpu", None) + if info is None: + info = getattr(request, "additional_information", None) + # vLLM may pass additional information as a list for batched requests; Qwen3-TTS typically uses batch=1. + if isinstance(info, list) and info and isinstance(info[0], dict): + info = info[0] + if not isinstance(info, dict): + info = {} + + def _first(x: object, default: object) -> object: + if isinstance(x, list): + return x[0] if x else default + return x if x is not None else default + + # In async_chunk, Stage-1 consumes only newly scheduled tokens per step; Stage-0 must stream frame-aligned windows. + # Stage-1 trims left-context each step. + codec_streaming_val = _first(info.get("codec_streaming"), True) + codec_streaming = bool(codec_streaming_val) if isinstance(codec_streaming_val, bool) else True + # Do not override from `pooling_output`: this is a pipeline contract. + # Mis-overrides can break Stage-1 trim/paste rules. + + # The stop-token step is not a decodable frame; only notify Stage-1 via `finished`. + finished = False + try: + finished = bool(request.is_finished()) + except Exception: + finished = False + + if finished: + return { + "code_predictor_codes": [], + "codec_streaming": codec_streaming, + "finished": torch.tensor(True, dtype=torch.bool), + } + + # Talker AR stage exposes per-step codes as `audio_codes` (shape [T, Q]). + audio_codes = pooling_output.get("audio_codes") + if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: + # Nothing to send for this step. + return None + + # `audio_codes` may include prefill/placeholder frames (shape [T,Q]); take only the last frame and skip if all-zero. + if audio_codes.ndim == 2: + frame = audio_codes[-1] + try: + if frame.numel() == 0 or not bool(frame.any().item()): + return None + except Exception: + # If `.any()` is unreliable, prefer sending the last frame and let Stage-1 fail-fast on misalignment. + pass + elif audio_codes.ndim == 1: + frame = audio_codes + else: + raise ValueError(f"Invalid audio_codes shape for Qwen3-TTS async_chunk: {tuple(audio_codes.shape)}") + + frame = frame.to(torch.long).reshape(-1) + codec_codes = frame.cpu().tolist() + + return { + "code_predictor_codes": codec_codes, + "codec_streaming": codec_streaming, + "finished": torch.tensor(bool(finished), dtype=torch.bool), + } diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 0747db3ea57..580b2738265 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -1,8 +1,4 @@ -"""Code2Wav GPU Model Runner for vLLM-Omni. - -Handles direct conversion from codec codes to audio waveforms for Qwen3 Omni MoE Code2Wav. -This is a non-autoregressive model that doesn't require sampling or logits computation. -""" +"""Code2Wav GPU Model Runner for vLLM-Omni (non-autoregressive codec->waveform).""" from __future__ import annotations @@ -42,12 +38,7 @@ class GPUGenerationModelRunner(OmniGPUModelRunner): - """Generation model runner for vLLM-Omni (non-autoregressive). - - - Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue. - - Does not compute logits or perform token sampling. - - Executes generation process and returns tensors via `pooler_output`. - """ + """Non-autoregressive generation runner that skips logits/sampling and returns waveforms via pooler_output.""" def _update_request_states(self, scheduler_output: SchedulerOutput): # remove requests @@ -415,15 +406,7 @@ def _run_generation_model( model_kwargs: dict, logits_indices: torch.Tensor, ) -> torch.Tensor | list[torch.Tensor]: - """Run generation from codec codes to waveforms. - - Args: - scheduler_output: Contains codec codes in input_ids or additional info - intermediate_tensors: PP intermediate tensors if applicable - - Returns: - Audio waveforms: [batch, 1, waveform_len] or list of tensors - """ + """Run codec->waveform generation and return waveforms (tensor or list).""" # Keep inputs identical to AR runner kwargs = dict( input_ids=input_ids, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 09a792bc802..ac9930d0ccc 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -70,16 +70,18 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): @instrument(span_name="Loading (GPU)") def load_model(self, *args, **kwargs) -> None: super().load_model(*args, **kwargs) + # TODO move this model specific logic to a separate class - if hasattr(self.model, "talker_mtp") and self.model.talker is not None: - self.talker_mtp = self.model.talker_mtp + talker_mtp = getattr(self.model, "talker_mtp", None) + if talker_mtp is not None: + self.talker_mtp = talker_mtp # type: ignore[assignment] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None if cudagraph_mode.has_full_cudagraphs(): - self.talker_mtp = CUDAGraphWrapper( - self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL - ) - hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size + self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + hidden_size = int( + getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size") + ) max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) self.talker_mtp_inputs_embeds = self._make_buffer( @@ -306,6 +308,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] + # async_chunk: keep per-step additional_information_cpu in sync (e.g. codec window metadata). + cached_infos = getattr(req_data, "additional_information", None) + if isinstance(cached_infos, dict): + info = cached_infos.get(req_id) + if isinstance(info, dict) and info: + self._merge_additional_information_update(req_id, info) num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_id in req_data.resumed_req_ids @@ -313,19 +321,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index = self.input_batch.req_id_to_index.get(req_id) if req_state.prev_num_draft_len and self.use_async_scheduling: - # prev_num_draft_len is used in async scheduling mode with - # spec decode. it indicates if need to update num_computed_tokens - # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], - # prev_num_draft_len = 0. - # second step: num_computed_tokens = 100(prompt length), - # spec_tokens = [a,b], prev_num_draft_len = 0. - # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], - # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain - # the spec tokens length, but in third step it contains the - # spec tokens length. we only need to update num_computed_tokens - # when prev_num_draft_len > 0. + # Async scheduling + spec decode: adjust num_computed_tokens only when prev_num_draft_len > 0. + # This accounts for rejected draft tokens from the previous step. if req_index is None: req_state.prev_num_draft_len = 0 else: @@ -767,8 +764,6 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput" This version avoids hard dependency on payload classes by duck-typing.""" try: new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) - if not new_reqs: - return for nr in new_reqs: req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) if req_id is None: @@ -812,6 +807,55 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput" info_dict[k] = getattr(entry, "list_data", None) if info_dict and req_id in self.requests: setattr(self.requests[req_id], "additional_information_cpu", info_dict) + + # async_chunk: refresh additional_information_cpu for cached/running requests too (metadata can change per step). + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + cached_infos = getattr(cached_reqs, "additional_information", None) if cached_reqs is not None else None + if isinstance(cached_infos, dict) and cached_infos: + for req_id, payload_info in cached_infos.items(): + if req_id not in self.requests: + continue + if payload_info is None: + continue + info_dict: dict[str, object] | None = None + if isinstance(payload_info, dict): + info_dict = payload_info + else: + entries = getattr(payload_info, "entries", None) + if isinstance(entries, dict): + decoded: dict[str, object] = {} + for k, entry in entries.items(): + tensor_data = getattr(entry, "tensor_data", None) + if tensor_data is not None: + dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) + arr = np.frombuffer(tensor_data, dtype=dt) + arr = arr.reshape(getattr(entry, "tensor_shape", ())) + decoded[k] = torch.from_numpy(arr.copy()) + else: + decoded[k] = getattr(entry, "list_data", None) + info_dict = decoded + + if not info_dict: + continue + + req_state = self.requests[req_id] + existing = getattr(req_state, "additional_information_cpu", None) + if not isinstance(existing, dict) or not existing: + setattr(req_state, "additional_information_cpu", info_dict) + continue + + merged = dict(existing) + for k, v in info_dict.items(): + if isinstance(v, torch.Tensor): + merged[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, list): + merged[k] = [ + (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) + for item in v + ] + else: + merged[k] = v + setattr(req_state, "additional_information_cpu", merged) except Exception as e: logger.error(f"Error decoding prompt_embeds / additional_information: {e}") @@ -929,6 +973,9 @@ def _preprocess( intermediate_tensors: IntermediateTensors | None = None, ): """Align with v0.14.0 preprocess and omni's additional information handling.""" + # Decode prompt_embeds/additional_information payloads before model.preprocess() uses them. + self._decode_and_store_request_payloads(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens is_first_rank = get_pp_group().is_first_rank is_encoder_decoder = self.model_config.is_encoder_decoder @@ -1047,9 +1094,16 @@ def _preprocess( span_len = int(e) - int(s) # call the custom process function + embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos ) + if inputs_embeds is None: + inputs_embeds = torch.empty( + (input_ids.shape[0], req_embeds.shape[-1]), + device=req_embeds.device, + dtype=req_embeds.dtype, + ) if hasattr(self.model, "talker_mtp") and span_len == 1: last_talker_hidden, text_step = update_dict.pop("mtp_inputs") @@ -1101,14 +1155,15 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te with set_forward_context( None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): - req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # update the inputs_embeds and code_predictor_codes - code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + req_embeds, audio_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # Store per-step codec codes in additional_information_cpu under talker_mtp_output_key for make_omni_output. + audio_codes_cpu = audio_codes.detach().to("cpu").contiguous() + out_key = getattr(self.model, "talker_mtp_output_key", "audio_codes") for idx, req_id in enumerate(decode_req_ids): req_index = self.input_batch.req_ids.index(req_id) start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + update_dict = {out_key: audio_codes_cpu[idx : idx + 1]} self._merge_additional_information_update(req_id, update_dict) def _model_forward( From b07bdf2265ad3e078c376fe6a19252ae5e03c3a1 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 04:31:26 -0800 Subject: [PATCH 02/28] [~] Fix: Enhance compatibility in OmniConnector adapter Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/inputs/preprocess.py | 107 ++++++------------ .../worker/gpu_generation_model_runner.py | 23 +++- vllm_omni/worker/gpu_model_runner.py | 58 +++------- 3 files changed, 72 insertions(+), 116 deletions(-) diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index f6e490567af..b50c7123e51 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -19,72 +19,39 @@ class OmniInputPreprocessor(InputPreprocessor): - """Input preprocessor for omni models (tokens/embeds/multimodal + additional_information).""" + """Input preprocessor for omni models. - def _is_qwen3_tts_talker_ar(self) -> bool: - archs = getattr(self.model_config, "architectures", None) - return bool(archs) and "Qwen3TTSTalkerForConditionalGenerationARVLLM" in archs + Extends the base InputPreprocessor to handle omni-specific input + types including prompt embeddings and additional information payloads. + Supports processing tokens, embeddings, text, and multimodal inputs. + """ - def _get_qwen3_tts_codec_pad_id(self) -> int: - hf_config = getattr(self.model_config, "hf_config", None) - talker_config = getattr(hf_config, "talker_config", None) - pad = getattr(talker_config, "codec_pad_id", None) + @staticmethod + def _get_prompt_placeholder(additional_information: dict[str, Any] | None) -> tuple[int, int] | None: + """Extract generic placeholder length and pad_id from additional_information. + + Returns (prompt_placeholder_len, prompt_placeholder_pad_id) if the + upstream serving layer pre-computed them, else None. + """ + if not isinstance(additional_information, dict): + return None + raw_len = additional_information.get("prompt_placeholder_len") + raw_pad = additional_information.get("prompt_placeholder_pad_id") + if raw_len is None: + return None + # Values are wrapped in lists by the serving layer. + if isinstance(raw_len, list): + raw_len = raw_len[0] if raw_len else None + if isinstance(raw_pad, list): + raw_pad = raw_pad[0] if raw_pad else 0 try: - pad_id = int(pad) - except Exception: - pad_id = 0 - return max(0, pad_id) - - def _get_qwen3_tts_prompt_len_tokenizer(self): - # Qwen3-TTS talker prompt length must match HF AutoTokenizer (fix_mistral_regex). - tok = getattr(self, "_qwen3_tts_prompt_len_tokenizer", None) - if tok is not None: - return tok - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained( - self.model_config.model, - trust_remote_code=True, - fix_mistral_regex=True, - use_fast=True, - ) - tok.padding_side = "left" - self._qwen3_tts_prompt_len_tokenizer = tok - return tok - - def _estimate_qwen3_tts_talker_prompt_len(self, additional_information: dict[str, Any] | None) -> int: - """Estimate Qwen3-TTS talker placeholder prompt length for vLLM scheduling. - Real conditioning is carried in additional_information.""" - info = additional_information if isinstance(additional_information, dict) else {} - - def _first(x: object, default: object = "") -> object: - if isinstance(x, list): - return x[0] if x else default - return x if x is not None else default - - task_type = str(_first(info.get("task_type"), "CustomVoice") or "CustomVoice") - hf_config = getattr(self.model_config, "hf_config", None) - talker_config = getattr(hf_config, "talker_config", None) - codec_language_id = getattr(talker_config, "codec_language_id", None) - spk_is_dialect = getattr(talker_config, "spk_is_dialect", None) - - from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker_ar import ( - Qwen3TTSTalkerForConditionalGenerationARVLLM, - ) - - tok = self._get_qwen3_tts_prompt_len_tokenizer() - - def _hf_tokenize_len(s: str) -> list[int]: - return tok(s, padding=False)["input_ids"] - - return Qwen3TTSTalkerForConditionalGenerationARVLLM.estimate_prompt_len_from_additional_information( - info, - task_type=task_type, - tokenize_prompt=_hf_tokenize_len, - codec_language_id=codec_language_id, - spk_is_dialect=spk_is_dialect, - estimate_ref_code_len=None, - ) + ph_len = int(raw_len) + ph_pad = int(raw_pad) if raw_pad is not None else 0 + except (TypeError, ValueError): + return None + if ph_len <= 0: + return None + return ph_len, max(0, ph_pad) def _process_text( self, @@ -111,13 +78,13 @@ def _process_text( if additional_information is not None: inputs["additional_information"] = additional_information else: - if self._is_qwen3_tts_talker_ar(): - # Qwen3-TTS talker uses a small codec vocab; text token ids are OOV. - # Use in-vocab pad placeholders for scheduling. - additional_information = parsed_content.get("additional_information") - prompt_len = self._estimate_qwen3_tts_talker_prompt_len(additional_information) - pad_id = self._get_qwen3_tts_codec_pad_id() - prompt_token_ids = [pad_id] * prompt_len + additional_information = parsed_content.get("additional_information") + placeholder = self._get_prompt_placeholder(additional_information) + if placeholder is not None: + # Upstream serving layer pre-computed placeholder length/pad_id + # (e.g. TTS models whose text tokens are OOV in the codec vocab). + ph_len, ph_pad = placeholder + prompt_token_ids = [ph_pad] * ph_len else: prompt_token_ids = self._tokenize_prompt( prompt_text, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 580b2738265..0747db3ea57 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -1,4 +1,8 @@ -"""Code2Wav GPU Model Runner for vLLM-Omni (non-autoregressive codec->waveform).""" +"""Code2Wav GPU Model Runner for vLLM-Omni. + +Handles direct conversion from codec codes to audio waveforms for Qwen3 Omni MoE Code2Wav. +This is a non-autoregressive model that doesn't require sampling or logits computation. +""" from __future__ import annotations @@ -38,7 +42,12 @@ class GPUGenerationModelRunner(OmniGPUModelRunner): - """Non-autoregressive generation runner that skips logits/sampling and returns waveforms via pooler_output.""" + """Generation model runner for vLLM-Omni (non-autoregressive). + + - Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue. + - Does not compute logits or perform token sampling. + - Executes generation process and returns tensors via `pooler_output`. + """ def _update_request_states(self, scheduler_output: SchedulerOutput): # remove requests @@ -406,7 +415,15 @@ def _run_generation_model( model_kwargs: dict, logits_indices: torch.Tensor, ) -> torch.Tensor | list[torch.Tensor]: - """Run codec->waveform generation and return waveforms (tensor or list).""" + """Run generation from codec codes to waveforms. + + Args: + scheduler_output: Contains codec codes in input_ids or additional info + intermediate_tensors: PP intermediate tensors if applicable + + Returns: + Audio waveforms: [batch, 1, waveform_len] or list of tensors + """ # Keep inputs identical to AR runner kwargs = dict( input_ids=input_ids, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index ac9930d0ccc..3bfa30b889a 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -321,8 +321,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index = self.input_batch.req_id_to_index.get(req_id) if req_state.prev_num_draft_len and self.use_async_scheduling: - # Async scheduling + spec decode: adjust num_computed_tokens only when prev_num_draft_len > 0. - # This accounts for rejected draft tokens from the previous step. + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # fist step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step does't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. if req_index is None: req_state.prev_num_draft_len = 0 else: @@ -815,47 +826,8 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput" for req_id, payload_info in cached_infos.items(): if req_id not in self.requests: continue - if payload_info is None: - continue - info_dict: dict[str, object] | None = None - if isinstance(payload_info, dict): - info_dict = payload_info - else: - entries = getattr(payload_info, "entries", None) - if isinstance(entries, dict): - decoded: dict[str, object] = {} - for k, entry in entries.items(): - tensor_data = getattr(entry, "tensor_data", None) - if tensor_data is not None: - dt = np.dtype(getattr(entry, "tensor_dtype", "float32")) - arr = np.frombuffer(tensor_data, dtype=dt) - arr = arr.reshape(getattr(entry, "tensor_shape", ())) - decoded[k] = torch.from_numpy(arr.copy()) - else: - decoded[k] = getattr(entry, "list_data", None) - info_dict = decoded - - if not info_dict: - continue - - req_state = self.requests[req_id] - existing = getattr(req_state, "additional_information_cpu", None) - if not isinstance(existing, dict) or not existing: - setattr(req_state, "additional_information_cpu", info_dict) - continue - - merged = dict(existing) - for k, v in info_dict.items(): - if isinstance(v, torch.Tensor): - merged[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, list): - merged[k] = [ - (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) - for item in v - ] - else: - merged[k] = v - setattr(req_state, "additional_information_cpu", merged) + if isinstance(payload_info, dict) and payload_info: + self._merge_additional_information_update(req_id, payload_info) except Exception as e: logger.error(f"Error decoding prompt_embeds / additional_information: {e}") From fbd0a92cef6cbaf3affdeb11ee0656630121d31e Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 06:01:08 -0800 Subject: [PATCH 03/28] [~] Fix: Enhance OmniGenerationScheduler and Qwen3TTSTokenizer functionality Signed-off-by: Sy03 <1370724210@qq.com> --- .../core/sched/omni_generation_scheduler.py | 68 ++++++++++------- .../models/qwen3_tts/qwen3_tts_tokenizer.py | 73 ++++--------------- vllm_omni/worker/gpu_model_runner.py | 2 +- 3 files changed, 60 insertions(+), 83 deletions(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 0b88e1ee521..f0d0527b068 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -31,8 +31,12 @@ def __init__(self, *args, **kwargs): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) def schedule(self) -> SchedulerOutput: - """Diffusion fast path: schedule all prompt tokens at once (use 1 placeholder if empty). - Fall back to vLLM scheduling if the token budget cannot be satisfied.""" + """Diffusion fast path: + - Feed all input tokens of the request at once + (if 0, allocate 1 placeholder token). + - If the token budget cannot be satisfied at once, fall back to the + default vLLM scheduling. + """ token_budget = self.max_num_scheduled_tokens scheduled_timestamp = time.monotonic() @@ -46,7 +50,7 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs: dict[str, list[int]] = {} cached_prompt_token_ids: dict[str, list[int]] = {} - # Temporary queue to preserve waiting order for non-diffusion requests. + # Temporary queue: preserve waiting order, do not disturb non-diffusion requests skipped_waiting_requests = create_request_queue(self.policy) req_index = 0 if self.chunk_transfer_adapter: @@ -64,6 +68,7 @@ def schedule(self) -> SchedulerOutput: already_finished_reqs.add(request) req_index += 1 continue + num_computed_tokens = request.num_computed_tokens required_tokens = len(request.prompt_token_ids) - num_computed_tokens # async_chunk: don't schedule placeholder tokens when no new chunk is available. @@ -125,19 +130,19 @@ def schedule(self) -> SchedulerOutput: if self.chunk_transfer_adapter is not None and len(request.prompt_token_ids) == 0: if request.request_id in self.chunk_transfer_adapter.finished_requests: request.status = RequestStatus.FINISHED_STOPPED - _ai = getattr(request, "additional_information", None) or {} - _pad = _ai.get("prompt_placeholder_pad_id", [0])[0] - request.prompt_token_ids.append(_pad) + request.prompt_token_ids.append(0) try: - request._all_token_ids.append(_pad) # type: ignore[attr-defined] + request._all_token_ids.append(0) # type: ignore[attr-defined] except Exception: pass else: break - # Treat all requests as diffusion here (feature flag can be added later). + # Uniformly treat as diffusion. A feature flag can be added later + # via config or request tag. - # Allocate all prompt tokens at once (use 1 placeholder if empty). + # Allocate all input tokens for the request in one shot + # (allocate 1 placeholder if zero) required_tokens = max(len(request.prompt_token_ids), 1) num_new_tokens = min(required_tokens, token_budget) new_blocks = self.kv_cache_manager.allocate_slots( @@ -206,6 +211,13 @@ def schedule(self) -> SchedulerOutput: req_to_new_blocks=req_to_new_blocks, ) + # async_chunk: forward per-step additional_information updates for cached requests. + cached_ai: dict[str, object] = {} + for req in scheduled_running_reqs: + ai = getattr(req, "additional_information", None) + if isinstance(ai, dict) and ai: + cached_ai[req.request_id] = ai + cached_reqs_data = OmniCachedRequestData( req_ids=cached_reqs_data.req_ids, resumed_req_ids=cached_reqs_data.resumed_req_ids, @@ -215,18 +227,8 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens=cached_reqs_data.num_computed_tokens, num_output_tokens=cached_reqs_data.num_output_tokens, prompt_token_ids=cached_prompt_token_ids, + additional_information=cached_ai, ) - # async_chunk: forward per-step additional_information updates for cached requests. - try: - cached_ai: dict[str, object] = {} - for req in scheduled_running_reqs: - ai = getattr(req, "additional_information", None) - if isinstance(ai, dict) and ai: - cached_ai[req.request_id] = ai - if cached_ai: - setattr(cached_reqs_data, "additional_information", cached_ai) - except Exception: - pass total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) scheduler_output = SchedulerOutput( @@ -260,7 +262,8 @@ def schedule(self) -> SchedulerOutput: self._update_after_schedule(scheduler_output) try: - # Wrap base NewRequestData as OmniNewRequestData and attach request-level payloads. + # Rewrap base NewRequestData entries with OmniNewRequestData, + # enriching with request-level payloads new_list = [] for nr in scheduler_output.scheduled_new_reqs: req_id = getattr(nr, "req_id", None) @@ -293,14 +296,25 @@ def schedule(self) -> SchedulerOutput: init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData") return scheduler_output - # Diffusion scheduler: stop requests immediately after one step (AR uses the original vLLM scheduler). + + """ + Scheduler for the diffusion model. + This scheduler is modified to stop the request immediately for the diffusion model. + This is because the diffusion model can generate the final image/audio in one step. + Note: This is just a minimal modification to the original scheduler, + and there should be some further efforts to optimize the scheduler. + The original scheduler is still used for the AR model. + """ def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: OmniModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: - """Update scheduler state from model_runner_output (diffusion requests stop immediately).""" + """Update the scheduler state based on the model runner output. + + This method is modified to stop the request immediately for the diffusion model. + """ sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict @@ -329,7 +343,9 @@ def update_from_output( if kv_connector_output and getattr(kv_connector_output, "invalid_block_ids", None): failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids) - # NOTE: keep loop body cheap (len(num_scheduled_tokens) can be 1K+). + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. stopped_running_reqs: set[Request] = set() stopped_preempted_reqs: set[Request] = set() for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): @@ -409,7 +425,9 @@ def update_from_output( new_logprobs = logprobs.slice_request(req_index, len(new_token_ids)) if new_token_ids and self.structured_output_manager.should_advance(request): - # NOTE: structured_output_request is guaranteed when structured output is enabled (ignore type warning). + # NOTE: structured_output_request should not be None if + # use_structured_output, we have check above, so safe to ignore + # type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] # noqa: E501 req_id, new_token_ids ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py index 30c0d832d59..20f58f62500 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -14,8 +14,6 @@ # limitations under the License. import base64 import io -import json -from pathlib import Path import urllib.request from urllib.parse import urlparse from typing import Any @@ -26,13 +24,17 @@ import torch from torch.nn.utils.rnn import pad_sequence from transformers import AutoConfig, AutoFeatureExtractor, AutoModel -from transformers.utils.hub import cached_file from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( Qwen3TTSTokenizerV2EncoderOutput, Qwen3TTSTokenizerV2Model, ) +from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config +from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import ( + Qwen3TTSTokenizerV1EncoderOutput, + Qwen3TTSTokenizerV1Model, +) AudioInput = ( str # wav path, or base64 string @@ -61,18 +63,6 @@ def __init__(self): self.config = None self.device = None - @staticmethod - def _resolve_local_config_path(pretrained_model_name_or_path: str) -> str: - p = Path(pretrained_model_name_or_path) - if p.is_dir(): - cfg = p / "config.json" - if cfg.exists(): - return str(cfg) - cfg = cached_file(pretrained_model_name_or_path, "config.json") - if cfg is None: - raise ValueError(f"config.json not found under {pretrained_model_name_or_path!r}") - return str(cfg) - @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer": """ @@ -93,42 +83,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3 load_feature_extractor = bool(kwargs.pop("load_feature_extractor", True)) - # Register 12Hz tokenizer (no optional deps). AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config) AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model) - # Register 25Hz tokenizer only when needed (avoids importing optional 25Hz backends for 12Hz models). - cfg_path = cls._resolve_local_config_path(pretrained_model_name_or_path) - try: - cfg_json = json.loads(Path(cfg_path).read_text()) - except Exception as e: - raise ValueError(f"Failed to parse tokenizer config.json at {cfg_path!r}: {e}") from e - - model_type = str(cfg_json.get("model_type") or "") - if model_type == "qwen3_tts_tokenizer_25hz": - from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config - from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model - - AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) - AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) - elif model_type != "qwen3_tts_tokenizer_12hz": - raise ValueError(f"Unsupported Qwen3-TTS tokenizer model_type={model_type!r} at {cfg_path!r}") + AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) + AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) inst.config = inst.model.config - if load_feature_extractor: - try: - inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) - except Exception as e: - raise ValueError( - "Failed to load Qwen3-TTS speech tokenizer feature extractor. " - "Please make sure the checkpoint contains the required HF " - "preprocessing files (e.g. preprocessor_config.json) under " - "the speech_tokenizer directory." - ) from e - else: - inst.feature_extractor = None + inst.feature_extractor = ( + AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) + if load_feature_extractor else None + ) inst.device = getattr(inst.model, "device", None) if inst.device is None: @@ -220,8 +187,6 @@ def _normalize_audio_inputs( List[np.ndarray]: List of float32 waveforms resampled to model input SR. """ - if self.feature_extractor is None: - raise ValueError("Speech tokenizer feature extractor is not loaded; audio encode is not available.") target_sr = int(self.feature_extractor.sampling_rate) if isinstance(audios, (str, np.ndarray)): @@ -254,7 +219,7 @@ def encode( audios: AudioInput, sr: int | None = None, return_dict: bool = True, - ) -> Any: + ) -> Qwen3TTSTokenizerV1EncoderOutput | Qwen3TTSTokenizerV2EncoderOutput | tuple: """ Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz). @@ -271,7 +236,7 @@ def encode( Forwarded to model.encode(...). If True, returns ModelOutput. Returns: - Any: + Qwen3TTSTokenizerV1EncoderOutput | Qwen3TTSTokenizerV2EncoderOutput | tuple: Encoder output or tuple returned by model.encode. If return_dict=True, returns a model-specific encoder output. For 25Hz models, this includes audio_codes/xvectors/ref_mels; for 12Hz models, this includes audio_codes. @@ -279,24 +244,18 @@ def encode( """ wavs = self._normalize_audio_inputs(audios, sr=sr) - if self.feature_extractor is None: - raise ValueError("Speech tokenizer feature extractor is not loaded; audio encode is not available.") inputs = self.feature_extractor( raw_audio=wavs, sampling_rate=int(self.feature_extractor.sampling_rate), return_tensors="pt", ) - # Normalize to tensors and keep padding_mask integer (tokenizer expects 0/1). - input_values = inputs["input_values"].squeeze(1).to(self.device).to(self.model.dtype) - padding_mask = inputs["padding_mask"].squeeze(1).to(self.device) - if padding_mask.dtype == torch.bool: - padding_mask = padding_mask.to(torch.long) + inputs = inputs.to(self.device).to(self.model.dtype) with torch.inference_mode(): # model.encode expects (B, T) and (B, T) enc = self.model.encode( - input_values, - padding_mask, + inputs["input_values"].squeeze(1), + inputs["padding_mask"].squeeze(1), return_dict=return_dict, ) return enc diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 3bfa30b889a..831697dd2b1 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1128,7 +1128,7 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): req_embeds, audio_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # Store per-step codec codes in additional_information_cpu under talker_mtp_output_key for make_omni_output. + # update the inputs_embeds and audio_codes audio_codes_cpu = audio_codes.detach().to("cpu").contiguous() out_key = getattr(self.model, "talker_mtp_output_key", "audio_codes") for idx, req_id in enumerate(decode_req_ids): From fc6afaee9b0bda216b1fd9296d0e55964979e06c Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 06:22:41 -0800 Subject: [PATCH 04/28] [~] Refactor: codec frame rate handling in OmniModelConfig and Qwen3TTSConfig Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/config/model.py | 1 - .../models/qwen3_tts/configuration_qwen3_tts.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 7f915dc56e7..f13a90bb7f0 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -58,7 +58,6 @@ class OmniModelConfig(ModelConfig): } ) omni_kv_config: dict | None = None - # Codec frame rate (frames/sec) for prompt length estimation. codec_frame_rate_hz: float | None = None @property diff --git a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py index dde69006865..1b441e39ac5 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py @@ -511,6 +511,17 @@ def __init__( self.vision_config = PretrainedConfig() # dummy vision config self.vision_config.spatial_merge_size = 1 + @property + def codec_frame_rate_hz(self) -> float | None: + pos_per_sec = getattr(self.talker_config, "position_id_per_seconds", None) + if pos_per_sec is None: + return None + try: + fps = float(pos_per_sec) + except (TypeError, ValueError): + return None + return fps if fps > 0 else None + def get_text_config(self, **kwargs): # vLLM expects text config to expose hidden_size/num_attention_heads. # For Qwen3 TTS, the talker config is the text model config. From 05e9cff70464f7d2e6550c07ecd14a22e1dd7408 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 19:04:00 -0800 Subject: [PATCH 05/28] [~] Refactor: Improve device handling and additional information management in Qwen3 TTS models and OmniGenerationScheduler Signed-off-by: Sy03 <1370724210@qq.com> --- .../core/sched/omni_generation_scheduler.py | 15 ++++--- .../models/qwen3_tts/qwen3_tts_code2wav.py | 41 +++++++++---------- .../qwen3_tts/qwen3_tts_disaggregated.py | 23 ++++++----- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index f0d0527b068..ba1b1f9d0c0 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -136,7 +136,9 @@ def schedule(self) -> SchedulerOutput: except Exception: pass else: - break + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue # Uniformly treat as diffusion. A feature flag can be added later # via config or request tag. @@ -212,11 +214,11 @@ def schedule(self) -> SchedulerOutput: ) # async_chunk: forward per-step additional_information updates for cached requests. - cached_ai: dict[str, object] = {} + per_req_additional_info: dict[str, object] = {} for req in scheduled_running_reqs: - ai = getattr(req, "additional_information", None) - if isinstance(ai, dict) and ai: - cached_ai[req.request_id] = ai + req_info = getattr(req, "additional_information", None) + if isinstance(req_info, dict) and req_info: + per_req_additional_info[req.request_id] = req_info cached_reqs_data = OmniCachedRequestData( req_ids=cached_reqs_data.req_ids, @@ -227,8 +229,9 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens=cached_reqs_data.num_computed_tokens, num_output_tokens=cached_reqs_data.num_output_tokens, prompt_token_ids=cached_prompt_token_ids, - additional_information=cached_ai, ) + if per_req_additional_info: + cached_reqs_data.additional_information = per_req_additional_info total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) scheduler_output = SchedulerOutput( diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index cb949d0c0c4..c96c8b7be2d 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -45,6 +45,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._stream_left_context_frames = 25 self._logged_codec_stats = False + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: if self._speech_tokenizer is not None: return self._speech_tokenizer @@ -70,28 +79,16 @@ def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: load_feature_extractor=False, ) - # Align device with vLLM worker. - device = getattr(self.vllm_config.device_config, "device", None) - if device is None: - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - try: - if tok.model is not None: - tok.model.to(device=device) - tok.device = device - except Exception as e: - raise RuntimeError(f"Failed to move SpeechTokenizer to device={device}: {e}") from e + # Align device with vLLM worker, then read back from module. + if tok.model is not None: + tok.model.to(device=self.vllm_config.device_config.device) + tok.device = self._module_device(tok.model) - # Derive codec group count and rates from tokenizer config if possible. - num_q = None - try: - dec_cfg = getattr(tok.model.config, "decoder_config", None) - if dec_cfg is not None: - num_q = getattr(dec_cfg, "num_quantizers", None) - except Exception: - num_q = None + # Derive codec group count and rates from tokenizer config. + dec_cfg = getattr(tok.model.config, "decoder_config", None) + num_q = getattr(dec_cfg, "num_quantizers", None) if dec_cfg is not None else None if num_q is None: - # Fallback: many code2wav stages use 16 quantizers. - num_q = 16 + raise ValueError("speech_tokenizer decoder_config.num_quantizers not found") num_q = int(num_q) if num_q <= 0: raise ValueError(f"Invalid speech_tokenizer num_quantizers={num_q}") @@ -105,8 +102,8 @@ def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: try: out_sr = int(tok.get_output_sample_rate()) - except Exception: - out_sr = 24000 + except Exception as e: + raise ValueError(f"Failed to get output sample rate: {e}") from e self._speech_tokenizer = tok self._num_quantizers = num_q diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py index 25329b8851c..46e2fdecba8 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py @@ -63,6 +63,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self._num_code_groups <= 0: raise ValueError(f"Invalid num_code_groups={self._num_code_groups} for Qwen3-TTS.") + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: if self._speech_tokenizer is not None: return self._speech_tokenizer @@ -77,16 +86,10 @@ def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: torch_dtype=torch.bfloat16, load_feature_extractor=False, ) - # Run decode on the vLLM worker device (fallback to best-effort CUDA/CPU). - device = getattr(self.vllm_config.device_config, "device", None) - if device is None: - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - try: - if hasattr(self._speech_tokenizer, "model") and self._speech_tokenizer.model is not None: - self._speech_tokenizer.model.to(device=device) - self._speech_tokenizer.device = device - except Exception as e: - raise RuntimeError(f"Failed to move SpeechTokenizer to device={device}: {e}") from e + # Run decode on the vLLM worker device, then read back from module. + if self._speech_tokenizer.model is not None: + self._speech_tokenizer.model.to(device=self.vllm_config.device_config.device) + self._speech_tokenizer.device = self._module_device(self._speech_tokenizer.model) return self._speech_tokenizer def preprocess( From 1852a23fbe28facc3fb78378a5621d7b213e165a Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 19:58:13 -0800 Subject: [PATCH 06/28] [~] Refactor: Streamline prompt handling and additional information extraction in OmniInputPreprocessor and Qwen3-TTS Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/inputs/preprocess.py | 43 +---- .../stage_input_processors/qwen3_tts.py | 163 +++++++++++------- vllm_omni/worker/gpu_model_runner.py | 26 +-- 3 files changed, 115 insertions(+), 117 deletions(-) diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index b50c7123e51..09b215bf98a 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -26,33 +26,6 @@ class OmniInputPreprocessor(InputPreprocessor): Supports processing tokens, embeddings, text, and multimodal inputs. """ - @staticmethod - def _get_prompt_placeholder(additional_information: dict[str, Any] | None) -> tuple[int, int] | None: - """Extract generic placeholder length and pad_id from additional_information. - - Returns (prompt_placeholder_len, prompt_placeholder_pad_id) if the - upstream serving layer pre-computed them, else None. - """ - if not isinstance(additional_information, dict): - return None - raw_len = additional_information.get("prompt_placeholder_len") - raw_pad = additional_information.get("prompt_placeholder_pad_id") - if raw_len is None: - return None - # Values are wrapped in lists by the serving layer. - if isinstance(raw_len, list): - raw_len = raw_len[0] if raw_len else None - if isinstance(raw_pad, list): - raw_pad = raw_pad[0] if raw_pad else 0 - try: - ph_len = int(raw_len) - ph_pad = int(raw_pad) if raw_pad is not None else 0 - except (TypeError, ValueError): - return None - if ph_len <= 0: - return None - return ph_len, max(0, ph_pad) - def _process_text( self, parsed_content: OmniTextPrompt, @@ -78,18 +51,10 @@ def _process_text( if additional_information is not None: inputs["additional_information"] = additional_information else: - additional_information = parsed_content.get("additional_information") - placeholder = self._get_prompt_placeholder(additional_information) - if placeholder is not None: - # Upstream serving layer pre-computed placeholder length/pad_id - # (e.g. TTS models whose text tokens are OOV in the codec vocab). - ph_len, ph_pad = placeholder - prompt_token_ids = [ph_pad] * ph_len - else: - prompt_token_ids = self._tokenize_prompt( - prompt_text, - tokenization_kwargs=tokenization_kwargs, - ) + prompt_token_ids = self._tokenize_prompt( + prompt_text, + tokenization_kwargs=tokenization_kwargs, + ) inputs = token_inputs_omni( prompt_token_ids, prompt_embeds=parsed_content.get("prompt_embeds"), diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 0ebf9abdbba..d8c06ee35d3 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -1,84 +1,131 @@ -"""Stage input processor for Qwen3-TTS: Talker → SpeechTokenizer transition.""" +"""Stage input processor for Qwen3-TTS: Talker -> Code2Wav.""" from typing import Any import torch -def talker2speech_tokenizer_async_chunk( +def _get_request_info(request: Any) -> dict[str, Any]: + info = getattr(request, "additional_information_cpu", None) + if info is None: + info = getattr(request, "additional_information", None) + if isinstance(info, list) and info and isinstance(info[0], dict): + info = info[0] + return info if isinstance(info, dict) else {} + + +def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: + audio_codes = pooling_output.get("audio_codes") + if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: + return None + if audio_codes.ndim == 2: + frame = audio_codes[-1] + if frame.numel() == 0 or not bool(frame.any().item()): + return None + return frame.to(torch.long).reshape(-1) + if audio_codes.ndim == 1: + return audio_codes.to(torch.long).reshape(-1) + raise ValueError(f"Invalid audio_codes shape for Qwen3-TTS async_chunk: {tuple(audio_codes.shape)}") + + +def talker2code2wav_async_chunk( + connector: Any, pooling_output: dict[str, Any], request: Any, ) -> dict[str, Any] | None: - """Async-chunk payload extractor for Qwen3-TTS Talker → SpeechTokenizer. - - Stage-0 emits per-step codec codes; they are sent via connector and consumed by Stage-1 as `prompt_token_ids`. - Returns: `code_predictor_codes` (List[int]) / `codec_streaming` (bool) / `finished` (torch.bool). - """ if not isinstance(pooling_output, dict): return None - # `codec_streaming` is the cross-stage streaming toggle (not the official `non_streaming_mode`). - # It can be overridden per request. - info = getattr(request, "additional_information_cpu", None) - if info is None: - info = getattr(request, "additional_information", None) - # vLLM may pass additional information as a list for batched requests; Qwen3-TTS typically uses batch=1. - if isinstance(info, list) and info and isinstance(info[0], dict): - info = info[0] - if not isinstance(info, dict): - info = {} - - def _first(x: object, default: object) -> object: - if isinstance(x, list): - return x[0] if x else default - return x if x is not None else default - - # In async_chunk, Stage-1 consumes only newly scheduled tokens per step; Stage-0 must stream frame-aligned windows. - # Stage-1 trims left-context each step. - codec_streaming_val = _first(info.get("codec_streaming"), True) - codec_streaming = bool(codec_streaming_val) if isinstance(codec_streaming_val, bool) else True - # Do not override from `pooling_output`: this is a pipeline contract. - # Mis-overrides can break Stage-1 trim/paste rules. - - # The stop-token step is not a decodable frame; only notify Stage-1 via `finished`. - finished = False - try: - finished = bool(request.is_finished()) - except Exception: - finished = False - - if finished: + info = _get_request_info(request) + request_id = request.external_req_id + + codec_streaming_raw = info.get("codec_streaming", True) + if isinstance(codec_streaming_raw, list): + codec_streaming_raw = codec_streaming_raw[0] if codec_streaming_raw else True + codec_streaming = codec_streaming_raw if isinstance(codec_streaming_raw, bool) else True + + raw_cfg = getattr(connector, "config", {}) or {} + cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} + chunk_size = int(cfg.get("codec_chunk_frames", 25)) + left_context_size = int(cfg.get("codec_left_context_frames", 25)) + if chunk_size <= 0 or left_context_size < 0: + raise ValueError( + f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, " + f"codec_left_context_frames={left_context_size}" + ) + + finished = bool(request.is_finished()) + + appended_frame = False + if not finished: + frame = _extract_last_frame(pooling_output) + if frame is None: + return None + codec_codes = frame.cpu().tolist() + connector.code_prompt_token_ids[request_id].append(codec_codes) + appended_frame = True + + length = len(connector.code_prompt_token_ids[request_id]) + chunk_length = length % chunk_size + + if chunk_length != 0 and not finished: + return None + + context_length = chunk_length if chunk_length != 0 else chunk_size + + if finished and (not appended_frame) and chunk_length == 0: return { "code_predictor_codes": [], "codec_streaming": codec_streaming, + "codec_context_codes": [], + "codec_context_frames": 0, + "codec_total_frames": 0, + "codec_chunk_frames": 0, + "codec_num_code_groups": 0, + "codec_layout": "codebook_major", "finished": torch.tensor(True, dtype=torch.bool), } - # Talker AR stage exposes per-step codes as `audio_codes` (shape [T, Q]). - audio_codes = pooling_output.get("audio_codes") - if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: - # Nothing to send for this step. - return None + if length <= 0: + return { + "code_predictor_codes": [], + "codec_streaming": codec_streaming, + "codec_context_codes": [], + "codec_context_frames": 0, + "codec_total_frames": 0, + "codec_chunk_frames": 0, + "codec_num_code_groups": 0, + "codec_layout": "codebook_major", + "finished": torch.tensor(bool(finished), dtype=torch.bool), + } - # `audio_codes` may include prefill/placeholder frames (shape [T,Q]); take only the last frame and skip if all-zero. - if audio_codes.ndim == 2: - frame = audio_codes[-1] - try: - if frame.numel() == 0 or not bool(frame.any().item()): - return None - except Exception: - # If `.any()` is unreliable, prefer sending the last frame and let Stage-1 fail-fast on misalignment. - pass - elif audio_codes.ndim == 1: - frame = audio_codes + end_index = min(length, left_context_size + context_length) + ctx_frames = max(0, int(end_index - context_length)) + window_frames = connector.code_prompt_token_ids[request_id][-end_index:] + + if ctx_frames > 0: + ctx_part = window_frames[:ctx_frames] + codec_context_codes = torch.tensor(ctx_part).transpose(0, 1).reshape(-1).tolist() else: - raise ValueError(f"Invalid audio_codes shape for Qwen3-TTS async_chunk: {tuple(audio_codes.shape)}") + codec_context_codes = [] + + chunk_part = window_frames[ctx_frames:] + code_predictor_codes = torch.tensor(chunk_part).transpose(0, 1).reshape(-1).tolist() - frame = frame.to(torch.long).reshape(-1) - codec_codes = frame.cpu().tolist() + num_code_groups = int( + len(connector.code_prompt_token_ids[request_id][-1]) + if connector.code_prompt_token_ids[request_id] + else 0 + ) return { - "code_predictor_codes": codec_codes, + "code_predictor_codes": code_predictor_codes, "codec_streaming": codec_streaming, + "codec_context_codes": codec_context_codes, + "codec_context_frames": int(ctx_frames), + "codec_total_frames": int(end_index), + "codec_chunk_frames": int(context_length), + "codec_num_code_groups": num_code_groups, + "codec_layout": "codebook_major", "finished": torch.tensor(bool(finished), dtype=torch.bool), } diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 831697dd2b1..628275e4986 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -72,6 +72,7 @@ def load_model(self, *args, **kwargs) -> None: super().load_model(*args, **kwargs) # TODO move this model specific logic to a separate class + # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS. talker_mtp = getattr(self.model, "talker_mtp", None) if talker_mtp is not None: self.talker_mtp = talker_mtp # type: ignore[assignment] @@ -79,6 +80,7 @@ def load_model(self, *args, **kwargs) -> None: assert cudagraph_mode is not None if cudagraph_mode.has_full_cudagraphs(): self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. hidden_size = int( getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size") ) @@ -308,12 +310,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] - # async_chunk: keep per-step additional_information_cpu in sync (e.g. codec window metadata). - cached_infos = getattr(req_data, "additional_information", None) - if isinstance(cached_infos, dict): - info = cached_infos.get(req_id) - if isinstance(info, dict) and info: - self._merge_additional_information_update(req_id, info) num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_id in req_data.resumed_req_ids @@ -818,16 +814,6 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput" info_dict[k] = getattr(entry, "list_data", None) if info_dict and req_id in self.requests: setattr(self.requests[req_id], "additional_information_cpu", info_dict) - - # async_chunk: refresh additional_information_cpu for cached/running requests too (metadata can change per step). - cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) - cached_infos = getattr(cached_reqs, "additional_information", None) if cached_reqs is not None else None - if isinstance(cached_infos, dict) and cached_infos: - for req_id, payload_info in cached_infos.items(): - if req_id not in self.requests: - continue - if isinstance(payload_info, dict) and payload_info: - self._merge_additional_information_update(req_id, payload_info) except Exception as e: logger.error(f"Error decoding prompt_embeds / additional_information: {e}") @@ -945,9 +931,6 @@ def _preprocess( intermediate_tensors: IntermediateTensors | None = None, ): """Align with v0.14.0 preprocess and omni's additional information handling.""" - # Decode prompt_embeds/additional_information payloads before model.preprocess() uses them. - self._decode_and_store_request_payloads(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens is_first_rank = get_pp_group().is_first_rank is_encoder_decoder = self.model_config.is_encoder_decoder @@ -1169,7 +1152,10 @@ def _model_forward( self._omni_last_model_output = model_output return model_output - def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None: + # Guard: _update_additional_information may pass None when additional_information is absent. + if not isinstance(upd, dict): + return req_state = self.requests.get(req_id) if req_state is None: return From 8fb1c617c4241f6b2bd9480c7e0dbe348fc8bdb3 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 23:03:12 -0800 Subject: [PATCH 07/28] [~] Refactor: Simplify payload handling and enhance metadata management in Qwen3 TTS models and OmniGPUModelRunner Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/entrypoints/async_omni.py | 3 +- .../qwen3_tts/configuration_qwen3_tts.py | 12 ++++- .../models/qwen3_tts/qwen3_tts_code2wav.py | 3 +- .../qwen3_tts_code_predictor_vllm.py | 46 +++++++++++++++---- .../models/qwen3_tts/qwen3_tts_talker_ar.py | 9 ---- .../models/qwen3_tts/qwen3_tts_tokenizer.py | 1 - .../stage_configs/qwen3_tts.yaml | 8 +++- vllm_omni/worker/gpu_model_runner.py | 13 +++--- 8 files changed, 63 insertions(+), 32 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 369713d7b68..d00652a658b 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -390,7 +390,8 @@ async def _process_async_results( submit_flag = False prompt_token_ids = engine_outputs.prompt_token_ids engine_input = copy.deepcopy(prompt) - engine_input["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids) + next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids)) + engine_input["prompt_token_ids"] = [0] * next_prompt_len engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None for i in range(1, len(self.stage_list)): task = { diff --git a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py index 1b441e39ac5..01dc6bbe45e 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py @@ -526,8 +526,16 @@ def get_text_config(self, **kwargs): # vLLM expects text config to expose hidden_size/num_attention_heads. # For Qwen3 TTS, the talker config is the text model config. config = self.talker_config - # if hasattr(config, "rope_parameters"): - # delattr(config, "rope_parameters") + # Code2Wav is a pure convolutional waveform decoder; it does NOT use + # rotary position embeddings. When hf_overrides sets architectures + # to [Qwen3TTSCode2Wav], strip rope_parameters so that the model + # runner sees uses_mrope == False and skips mrope position computation + # on codec tokens. Each stage loads its own config instance, so this + # in-place mutation does not affect the Talker stage. + archs = getattr(self, "architectures", []) or [] + if any("Code2Wav" in str(a) for a in archs): + if hasattr(config, "rope_parameters"): + delattr(config, "rope_parameters") return config diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index c96c8b7be2d..0008c40231f 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -237,7 +237,8 @@ def forward( if ctx_frames < 0: raise ValueError(f"Invalid codec_context_frames={ctx_frames} (must be >=0).") - # input_ids may be padded; use codec_chunk_frames to slice the exact chunk (chunk_frames * q) and ignore padding. + # input_ids may be padded; use codec_chunk_frames to slice the + # exact chunk (chunk_frames * q) and ignore padding. if chunk_frames is None: raise ValueError( "Missing codec_chunk_frames in runtime_additional_information for Qwen3TTSCode2Wav. " diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 044bc5292df..dbcdbe0be3d 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -25,7 +25,8 @@ class _LocalPredictorKVCache: - """Minimal local KV cache + attention metadata for running code_predictor inside one worker (independent of engine KV).""" + """Minimal local KV cache + attention metadata for running + code_predictor inside one worker (independent of engine KV).""" def __init__( self, @@ -93,11 +94,20 @@ def build_attn_metadata( num_reqs: int, query_lens: torch.Tensor, # (num_reqs,) int32 on cpu seq_lens: torch.Tensor, # (num_reqs,) int32 on cpu - ) -> tuple[dict[str, Any], torch.Tensor]: - """Build attention metadata and return (attn_metadata, positions).""" + ) -> tuple[dict[str, Any], torch.Tensor, dict[str, torch.Tensor]]: + """Build attention metadata, positions, and slot_mapping dict. + + Returns: + (attn_metadata, positions, slot_mappings_by_layer) + - attn_metadata: per-layer attention metadata for attn backends. + - positions: (num_tokens,) position IDs on device. + - slot_mappings_by_layer: {layer_name: slot_mapping_tensor} for + set_forward_context so that unified_kv_cache_update can write + the KV cache correctly. + """ num_reqs = int(num_reqs) if num_reqs <= 0: - return {}, torch.empty((0,), dtype=torch.int64, device=self.device) + return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} if num_reqs > self.max_batch_size: raise ValueError(f"num_reqs={num_reqs} exceeds local predictor max_batch_size={self.max_batch_size}") @@ -109,7 +119,7 @@ def build_attn_metadata( qsl[1:] = torch.cumsum(query_lens_i32, dim=0) num_tokens = int(qsl[-1].item()) if num_tokens <= 0: - return {}, torch.empty((0,), dtype=torch.int64, device=self.device) + return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} # positions: for each request i, emit positions [seq_len-query_len .. seq_len-1] pos_list: list[torch.Tensor] = [] @@ -153,7 +163,15 @@ def build_attn_metadata( slot_mappings=[slot_mapping_gpu], kv_cache_config=self.kv_cache_config, ) - return attn_metadata, positions_cpu.to(device=self.device) + + # Build slot_mappings_by_layer for set_forward_context. + # Fix for vllm 0.15.0 + slot_mappings_by_layer: dict[str, torch.Tensor] = {} + for kv_cache_group in self.kv_cache_config.kv_cache_groups: + for layer_name in kv_cache_group.layer_names: + slot_mappings_by_layer[layer_name] = slot_mapping_gpu + + return attn_metadata, positions_cpu.to(device=self.device), slot_mappings_by_layer class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module): @@ -362,13 +380,17 @@ def prefill_logits(self, inputs_embeds: torch.Tensor) -> torch.Tensor: query_lens = torch.full((bsz,), qlen, dtype=torch.int32) seq_lens = query_lens.clone() - attn_metadata, positions = self._kv_cache.build_attn_metadata( + attn_metadata, positions, slot_mappings = self._kv_cache.build_attn_metadata( num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens ) with ( set_current_vllm_config(self._vllm_config), - set_forward_context(attn_metadata, self._vllm_config, num_tokens=int(hs.shape[0])), + set_forward_context( + attn_metadata, self._vllm_config, + num_tokens=int(hs.shape[0]), + slot_mapping=slot_mappings, + ), ): out = self.model(positions=positions, inputs_embeds=hs) @@ -393,13 +415,17 @@ def decode_logits(self, input_ids: torch.Tensor, *, generation_step: int, past_s query_lens = torch.ones((bsz,), dtype=torch.int32) seq_lens = torch.full((bsz,), int(past_seq_len) + 1, dtype=torch.int32) - attn_metadata, positions = self._kv_cache.build_attn_metadata( + attn_metadata, positions, slot_mappings = self._kv_cache.build_attn_metadata( num_reqs=bsz, query_lens=query_lens, seq_lens=seq_lens ) with ( set_current_vllm_config(self._vllm_config), - set_forward_context(attn_metadata, self._vllm_config, num_tokens=int(hs.shape[0])), + set_forward_context( + attn_metadata, self._vllm_config, + num_tokens=int(hs.shape[0]), + slot_mapping=slot_mappings, + ), ): out = self.model(positions=positions, inputs_embeds=hs) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py index 65b6c4e154a..daf557ae598 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py @@ -250,15 +250,6 @@ def preprocess( raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.") task_type = (info_dict.get("task_type") or ["CustomVoice"])[0] - non_streaming_mode_val = info_dict.get("non_streaming_mode") - if isinstance(non_streaming_mode_val, list): - non_streaming_mode_raw = non_streaming_mode_val[0] if non_streaming_mode_val else None - else: - non_streaming_mode_raw = non_streaming_mode_val - if isinstance(non_streaming_mode_raw, bool): - non_streaming_mode = non_streaming_mode_raw - else: - non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") codec_streaming_val = info_dict.get("codec_streaming") if isinstance(codec_streaming_val, list): codec_streaming_raw = codec_streaming_val[0] if codec_streaming_val else None diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py index 20f58f62500..01d9abd95c3 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -16,7 +16,6 @@ import io import urllib.request from urllib.parse import urlparse -from typing import Any import librosa import numpy as np diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index 1db64fda791..f8a0b75a533 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -6,7 +6,7 @@ stage_args: devices: "0" max_batch_size: 1 engine_args: - model_stage: talker + model_stage: qwen3_tts model_arch: Qwen3TTSTalkerForConditionalGenerationARVLLM # Force stage-specific registered architecture. hf_overrides: @@ -22,7 +22,7 @@ stage_args: distributed_executor_backend: "mp" max_num_batched_tokens: 512 max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2speech_tokenizer_async_chunk + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk # Use named connector to apply runtime.connectors.extra. output_connectors: to_stage_1: connector_of_shared_memory @@ -87,6 +87,10 @@ runtime: shm_threshold_bytes: 65536 # Frame-aligned codec streaming transport. codec_streaming: true + # Connector polling / timeout (unit: loop count, sleep interval in seconds). + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 # Match official chunked_decode defaults. codec_chunk_frames: 300 codec_left_context_frames: 25 diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 628275e4986..f1513e95031 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -771,6 +771,8 @@ def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput" This version avoids hard dependency on payload classes by duck-typing.""" try: new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + if not new_reqs: + return for nr in new_reqs: req_id = getattr(nr, "req_id", None) or getattr(nr, "request_id", None) if req_id is None: @@ -1110,15 +1112,15 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te with set_forward_context( None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): - req_embeds, audio_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # update the inputs_embeds and audio_codes - audio_codes_cpu = audio_codes.detach().to("cpu").contiguous() - out_key = getattr(self.model, "talker_mtp_output_key", "audio_codes") + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes") for idx, req_id in enumerate(decode_req_ids): req_index = self.input_batch.req_ids.index(req_id) start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {out_key: audio_codes_cpu[idx : idx + 1]} + update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]} self._merge_additional_information_update(req_id, update_dict) def _model_forward( @@ -1153,7 +1155,6 @@ def _model_forward( return model_output def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None: - # Guard: _update_additional_information may pass None when additional_information is absent. if not isinstance(upd, dict): return req_state = self.requests.get(req_id) From 82b2640cbfd38b803194c82ccaaf84169ebdc00b Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 9 Feb 2026 23:27:49 -0800 Subject: [PATCH 08/28] [~] Refactor: Remove unused logger initialization and streamline code in Qwen3 TTS models Signed-off-by: Sy03 <1370724210@qq.com> --- .../model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | 4 ---- .../models/qwen3_tts/qwen3_tts_code_predictor_vllm.py | 3 --- .../models/qwen3_tts/qwen3_tts_disaggregated.py | 4 ---- .../model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py | 6 ------ 4 files changed, 17 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 0008c40231f..f4d4a4b1037 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -39,10 +39,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._num_quantizers: int | None = None self._decode_upsample_rate: int | None = None self._output_sample_rate: int | None = None - - # Default streaming window (must match connector config by convention). - self._stream_chunk_frames = 25 - self._stream_left_context_frames = 25 self._logged_codec_stats = False @staticmethod diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index dbcdbe0be3d..f131ae4170c 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -8,7 +8,6 @@ from vllm.config import VllmConfig from vllm.config.vllm import set_current_vllm_config from vllm.forward_context import set_forward_context -from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, @@ -21,8 +20,6 @@ from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig -logger = init_logger(__name__) - class _LocalPredictorKVCache: """Minimal local KV cache + attention metadata for running diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py index 46e2fdecba8..3972d59bb54 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py @@ -7,16 +7,12 @@ import torch.nn as nn from transformers.utils.hub import cached_file from vllm.config import VllmConfig -from vllm.logger import init_logger from vllm_omni.model_executor.models.output_templates import OmniOutput from .qwen3_tts import Qwen3TTSModel from .qwen3_tts_tokenizer import Qwen3TTSTokenizer -logger = init_logger(__name__) - -_VALID_TASK_TYPES = ("CustomVoice", "VoiceDesign", "Base") _VALID_STAGES = ("talker", "speech_tokenizer") diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py index daf557ae598..8bfbddc4b38 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py @@ -718,11 +718,6 @@ def _scan(obj: object, depth: int = 0) -> None: wav_candidates.append(obj_list) return - # If this is a long list of numbers, treat it as waveform and stop. - if isinstance(obj, list) and len(obj) >= 512 and _is_number_sequence(obj_list): # type: ignore[arg-type] - wav_candidates.append(obj) - return - # Otherwise, recurse into elements (but avoid descending into huge numeric lists). for item in obj_list: if isinstance(item, list) and len(item) >= 512 and _is_number_sequence(item): # type: ignore[arg-type] @@ -780,7 +775,6 @@ def _to_np(x: object) -> np.ndarray: if wav_np.size < 1024: raise ValueError(f"ref_audio waveform too short: {wav_np.size} samples") return wav_np, sr - raise TypeError(f"Unsupported ref_audio type: {type(ref_audio)}") def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: if self.speaker_encoder is None: From 885ee3d06c71779ba131803eaa0146d245c455d5 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 00:46:49 -0800 Subject: [PATCH 09/28] [~] Refactor: Update TTS prompt handling and introduce new configuration for Qwen3 TTS async chunk processing Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 11 +-- .../qwen3_tts/configuration_qwen3_tts.py | 9 +- ...s_talker_speech_tokenizer_async_chunk.yaml | 92 +++++++++++++++++++ 3 files changed, 101 insertions(+), 11 deletions(-) create mode 100644 vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index a8bae9e9932..d0c54b6f65a 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -125,10 +125,6 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return None - def _build_tts_prompt(self, text: str) -> str: - """Build TTS prompt from input text.""" - return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" - def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: """Build TTS parameters from request. @@ -221,11 +217,12 @@ async def create_speech( if validation_error: return self.create_error_response(validation_error) - # Build TTS parameters and prompt + # Must use prompt_token_ids (not text prompt): the AR Talker + # operates on codec tokens; text token IDs exceed codec vocab. + # model.preprocess replaces all embeddings, so value 0 is fine. tts_params = self._build_tts_params(request) - prompt_text = self._build_tts_prompt(request.input) prompt = { - "prompt": prompt_text, + "prompt_token_ids": [1] * 2048, "additional_information": tts_params, } else: diff --git a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py index 01dc6bbe45e..8e751413767 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py @@ -504,10 +504,11 @@ def __init__( self.tts_bos_token_id = tts_bos_token_id self.tts_eos_token_id = tts_eos_token_id - # TODO: remove these dummy values after - self.image_token_id = 0 # dummy image token id - self.video_token_id = 0 # dummy video token id - self.vision_start_token_id = 0 # dummy vision start token id + # Dummy vision token IDs that must never collide with real codec tokens. + # mrope scans prompt_token_ids for these; using -1 ensures no false match. + self.image_token_id = -1 + self.video_token_id = -1 + self.vision_start_token_id = -1 self.vision_config = PretrainedConfig() # dummy vision config self.vision_config.spatial_merge_size = 1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml new file mode 100644 index 00000000000..c5e282e4d80 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml @@ -0,0 +1,92 @@ +async_chunk: true +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: qwen3_tts + model_arch: Qwen3TTSTalkerForConditionalGenerationARVLLM + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGenerationARVLLM] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: false + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + # Stage-0 emits flattened codec codes via async_chunk connector. + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + hf_overrides: + architectures: [Qwen3TTSCode2Wav] + # Stage-1 has no main checkpoint weights (SpeechTokenizer is loaded from + # `speech_tokenizer/` lazily). Avoid probing for model.safetensors. + load_format: dummy + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.2 + distributed_executor_backend: "mp" + # Must be divisible by num_code_groups (typically 16). + # Must be >= num_code_groups * (codec_left_context_frames + codec_chunk_frames). + max_num_batched_tokens: 1024 + max_model_len: 4096 + engine_input_source: [0] + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Qwen3-TTS codec streaming (frame-aligned tokenized transport). + codec_streaming: true + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 From ad6b6764f12c15122521e44ab6f81cf89d78a231 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 04:26:19 -0800 Subject: [PATCH 10/28] [~] Fix: Enhance TTS processing to fix audio overlap issues Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 37 ++- .../models/qwen3_tts/qwen3_tts_code2wav.py | 225 ++++-------------- .../stage_configs/qwen3_tts.yaml | 4 +- ...s_talker_speech_tokenizer_async_chunk.yaml | 6 +- .../stage_input_processors/qwen3_tts.py | 49 +--- 5 files changed, 94 insertions(+), 227 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index d0c54b6f65a..63d94ed4528 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -43,6 +43,7 @@ def __init__(self, *args, **kwargs): # Load supported speakers self.supported_speakers = self._load_supported_speakers() logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + self._tts_tokenizer = None def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" @@ -62,6 +63,32 @@ def _load_supported_speakers(self) -> set[str]: return set() + def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: + """Estimate prompt length so the placeholder matches model-side embeddings.""" + try: + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker_ar import ( + Qwen3TTSTalkerForConditionalGenerationARVLLM, + ) + if self._tts_tokenizer is None: + from transformers import AutoTokenizer + model_name = self.engine_client.model_config.model + self._tts_tokenizer = AutoTokenizer.from_pretrained( + model_name, trust_remote_code=True, padding_side="left", + ) + hf_config = self.engine_client.model_config.hf_config + talker_config = hf_config.talker_config + task_type = (tts_params.get("task_type") or ["CustomVoice"])[0] + return Qwen3TTSTalkerForConditionalGenerationARVLLM.estimate_prompt_len_from_additional_information( + additional_information=tts_params, + task_type=task_type, + tokenize_prompt=lambda t: self._tts_tokenizer(t, padding=False)["input_ids"], + codec_language_id=getattr(talker_config, "codec_language_id", None), + spk_is_dialect=getattr(talker_config, "spk_is_dialect", None), + ) + except Exception as e: + logger.warning("Failed to estimate TTS prompt length, using fallback 2048: %s", e) + return 2048 + def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" stage_list = getattr(self.engine_client, "stage_list", None) @@ -219,10 +246,12 @@ async def create_speech( # Must use prompt_token_ids (not text prompt): the AR Talker # operates on codec tokens; text token IDs exceed codec vocab. - # model.preprocess replaces all embeddings, so value 0 is fine. + # model.preprocess replaces all embeddings, so placeholder value + # is irrelevant -- but length must match to avoid excess padding. tts_params = self._build_tts_params(request) + ph_len = self._estimate_prompt_len(tts_params) prompt = { - "prompt_token_ids": [1] * 2048, + "prompt_token_ids": [1] * ph_len, "additional_information": tts_params, } else: @@ -279,6 +308,10 @@ async def create_speech( if hasattr(sample_rate, "item"): sample_rate = sample_rate.item() + # Streaming accumulates chunks as a list; concat first. + if isinstance(audio_tensor, list): + import torch + audio_tensor = torch.cat(audio_tensor, dim=-1) # Convert tensor to numpy if hasattr(audio_tensor, "float"): audio_tensor = audio_tensor.float().detach().cpu().numpy() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index f4d4a4b1037..300ce0dff8e 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -116,47 +116,6 @@ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: return None - @staticmethod - def _reconstruct_window_codes_fq( - *, - chunk_ids: torch.Tensor, - q: int, - chunk_frames: int, - codec_streaming: bool, - ctx_frames: int, - ctx_codes: list[int] | None, - ) -> torch.Tensor: - """Reconstruct [F, Q] codes from codebook-major flattened chunk ids (and optional left-context).""" - if q <= 0: - raise ValueError(f"Invalid q={q} (must be >0).") - if chunk_frames <= 0: - raise ValueError(f"Invalid chunk_frames={chunk_frames} (must be >0).") - - if int(chunk_ids.numel()) != int(q) * int(chunk_frames): - raise ValueError( - "Invalid chunk_ids length for Qwen3TTSCode2Wav: " - f"got={int(chunk_ids.numel())} expected={int(q) * int(chunk_frames)} " - f"(q={q} chunk_frames={chunk_frames})." - ) - - chunk_qf = chunk_ids.reshape(int(q), int(chunk_frames)) - if codec_streaming and ctx_frames > 0: - if ctx_codes is None: - raise ValueError("Missing ctx_codes for streaming decode window reconstruction.") - expected_ctx_tokens = int(q) * int(ctx_frames) - if len(ctx_codes) != expected_ctx_tokens: - raise ValueError( - "Invalid ctx_codes length for streaming decode window reconstruction: " - f"got={len(ctx_codes)} expected={expected_ctx_tokens} (q={q} ctx_frames={ctx_frames})." - ) - ctx_tensor = torch.tensor(ctx_codes, dtype=torch.long, device=chunk_ids.device) - ctx_qf = ctx_tensor.reshape(int(q), int(ctx_frames)) - window_qf = torch.cat([ctx_qf, chunk_qf], dim=1) - else: - window_qf = chunk_qf - - return window_qf.transpose(0, 1).contiguous() # [F, Q] - @torch.no_grad() def forward( self, @@ -166,133 +125,63 @@ def forward( inputs_embeds: torch.Tensor | None = None, **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: - # ModelOutput is (audio_tensor, sr_tensor). + """Decode codec codes into audio waveform. + + input_ids layout: [codec_context_frames, *flat_codes] + where flat_codes is codebook-major [q*F]. + """ tok = self._ensure_speech_tokenizer_loaded() assert self._num_quantizers is not None assert self._output_sample_rate is not None + sr_val = self._output_sample_rate + empty_ret = ( + torch.zeros((0,), dtype=torch.float32), + torch.tensor(sr_val, dtype=torch.int32), + ) + if input_ids is None: - # Profile run / placeholder schedule: return empty audio. - empty = torch.zeros((0,), dtype=torch.float32) - return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) + return empty_ret - ids = input_ids.reshape(-1).to(dtype=torch.long) q = int(self._num_quantizers) + ids = input_ids.reshape(-1).to(dtype=torch.long) + n_tokens = ids.numel() - if ids.numel() == 0 or ids.numel() < q: - empty = torch.zeros((0,), dtype=torch.float32) - return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) - - # Contract: connector provides codec_streaming + codec_context_frames (left-context frames to trim). - # Assumes max_batch_size=1 for code2wav (vLLM provides a flattened per-step token stream). - ctx_frames: int | None = None - codec_streaming: bool | None = None - ctx_codes: list[int] | None = None - chunk_frames: int | None = None - rt_info = kwargs.get("runtime_additional_information") - if isinstance(rt_info, list) and len(rt_info) == 1 and isinstance(rt_info[0], dict): - v = rt_info[0].get("codec_streaming") - if v is not None: - try: - codec_streaming = bool(v) if not isinstance(v, torch.Tensor) else bool(v.item()) - except Exception: - codec_streaming = None - v = rt_info[0].get("codec_context_frames") - if v is not None: - try: - ctx_frames = int(v) - except Exception as e: - raise ValueError(f"Invalid codec_context_frames={v!r}: {e}") from e - v = rt_info[0].get("codec_context_codes") - if v is not None: - if isinstance(v, list): - ctx_codes = [int(x) for x in v] - elif isinstance(v, torch.Tensor): - ctx_codes = v.detach().to("cpu").reshape(-1).to(dtype=torch.long).tolist() - v = rt_info[0].get("codec_chunk_frames") - if v is not None: - try: - chunk_frames = int(v) - except Exception as e: - raise ValueError(f"Invalid codec_chunk_frames={v!r}: {e}") from e - - if codec_streaming is None: - raise ValueError( - "Missing codec_streaming in runtime_additional_information for Qwen3TTSCode2Wav. " - "This indicates the async_chunk connector/adapter contract was not applied." - ) + if n_tokens == 0: + return empty_ret - if codec_streaming is False: - ctx_frames = 0 - else: - if ctx_frames is None: - raise ValueError( - "Missing codec_context_frames in runtime_additional_information for streaming Qwen3TTSCode2Wav. " - "This indicates the async_chunk connector/adapter contract was not applied." - ) - if ctx_frames < 0: - raise ValueError(f"Invalid codec_context_frames={ctx_frames} (must be >=0).") + # input_ids[0] = codec_context_frames (prepended by adapter). + ctx_frames = int(ids[0].item()) + ids = ids[1:] + n_tokens = ids.numel() - # input_ids may be padded; use codec_chunk_frames to slice the - # exact chunk (chunk_frames * q) and ignore padding. - if chunk_frames is None: - raise ValueError( - "Missing codec_chunk_frames in runtime_additional_information for Qwen3TTSCode2Wav. " - "This indicates the async_chunk connector/adapter contract was not applied." - ) - if chunk_frames < 0: - raise ValueError(f"Invalid codec_chunk_frames={chunk_frames} (must be >=0).") - expected_chunk_tokens = int(chunk_frames) * q - if expected_chunk_tokens == 0: - empty = torch.zeros((0,), dtype=torch.float32) - return empty, torch.tensor(self._output_sample_rate, dtype=torch.int32) - if ids.numel() < expected_chunk_tokens: - raise ValueError( - "Code2Wav received fewer tokens than expected for this chunk: " - f"got={int(ids.numel())} expected={expected_chunk_tokens} " - f"(chunk_frames={int(chunk_frames)} q={q}). " - "This indicates vLLM split the chunk across multiple forward calls; " - "the code2wav stage requires per-step frame-aligned chunks." + if n_tokens == 0: + return empty_ret + + # Warmup / dummy_run: not divisible by num_quantizers. + if n_tokens % q != 0: + logger.warning( + "Code2Wav input_ids length %d not divisible by num_quantizers %d, " + "likely a warmup run; returning empty audio.", + n_tokens, q, ) - if ids.numel() > expected_chunk_tokens: - # Extra non-padding tokens beyond expected_chunk_tokens indicate a scheduler/adapter contract violation. - extra = ids[expected_chunk_tokens:] - if extra.numel() > 0 and bool((extra != 0).any().item()): - raise ValueError( - "Code2Wav received extra non-padding tokens beyond the expected chunk length: " - f"got={int(ids.numel())} expected={expected_chunk_tokens} " - f"(chunk_frames={int(chunk_frames)} q={q}). " - "This indicates multiple codec chunks were scheduled in a single forward, " - "which breaks streaming trim/paste semantics." - ) - ids = ids[:expected_chunk_tokens] - - chunk_ids = ids - ctx_frames_i = int(ctx_frames or 0) - frames = int((ctx_frames_i if codec_streaming else 0) + int(chunk_frames)) - codes_fq = self._reconstruct_window_codes_fq( - chunk_ids=chunk_ids, - q=q, - chunk_frames=int(chunk_frames), - codec_streaming=bool(codec_streaming), - ctx_frames=ctx_frames_i, - ctx_codes=ctx_codes, - ) - if not self._logged_codec_stats and frames > 1: + return empty_ret + + total_frames = n_tokens // q + + # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. + codes_fq = ids.reshape(q, total_frames).transpose(0, 1).contiguous() + + if not self._logged_codec_stats and total_frames > 1: self._logged_codec_stats = True try: uniq = int(torch.unique(codes_fq).numel()) cmin = int(codes_fq.min().item()) cmax = int(codes_fq.max().item()) - head = codes_fq[: min(2, frames), : min(8, q)].detach().to("cpu").tolist() + head = codes_fq[:min(2, total_frames), :min(8, q)].cpu().tolist() logger.info( - "Qwen3TTSCode2Wav received codec codes: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", - frames, - q, - uniq, - cmin, - cmax, - head, + "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", + total_frames, q, uniq, cmin, cmax, head, ) except Exception: pass @@ -302,33 +191,21 @@ def forward( raise ValueError("SpeechTokenizer code2wav produced empty waveform list.") audio_np = wavs[0].astype(np.float32, copy=False) + # Trim left-context waveform samples (streaming sliding window). if ctx_frames > 0: - # Trim waveform samples corresponding to left-context frames in the sliding window. upsample = self._decode_upsample_rate - if upsample is None: - try: - upsample = int(tok.get_decode_upsample_rate()) - except Exception as e: - raise ValueError(f"Failed to get decode upsample rate: {e}") from e - if upsample <= 0: - raise ValueError(f"Invalid decode upsample rate: {upsample}") - self._decode_upsample_rate = upsample - - ctx_frames_i = int(ctx_frames) - if ctx_frames_i > frames: - raise ValueError(f"codec_context_frames={ctx_frames_i} exceeds frames={frames}") - - decoded = int(audio_np.shape[0]) - cut = int(ctx_frames_i) * int(upsample) - if cut > decoded: - raise ValueError( - "Streaming decode context trim exceeds decoded length: " - f"cut={cut} decoded={decoded} ctx_frames={ctx_frames_i} frames={frames}" + if upsample is None or upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + cut = ctx_frames * upsample + if cut < audio_np.shape[0]: + audio_np = audio_np[cut:] + else: + logger.warning( + "Context trim %d >= decoded length %d; returning empty audio.", + cut, audio_np.shape[0], ) - audio_np = audio_np[cut:] + return empty_ret - # Return 1D waveform per chunk so the output processor can concatenate along time. - # Returning [1, T] would stack chunks as channels. audio_tensor = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) sr_tensor = torch.tensor(int(sr), dtype=torch.int32) return audio_tensor, sr_tensor diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index f8a0b75a533..538473cf5e6 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -91,8 +91,8 @@ runtime: connector_get_sleep_s: 0.01 connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 - # Match official chunked_decode defaults. - codec_chunk_frames: 300 + # Align with Omni: small chunks with sufficient context overlap. + codec_chunk_frames: 25 codec_left_context_frames: 25 edges: diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml index c5e282e4d80..93fd1c22e04 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml @@ -54,10 +54,10 @@ stage_args: engine_output_type: audio gpu_memory_utilization: 0.2 distributed_executor_backend: "mp" - # Must be divisible by num_code_groups (typically 16). # Must be >= num_code_groups * (codec_left_context_frames + codec_chunk_frames). - max_num_batched_tokens: 1024 - max_model_len: 4096 + max_num_batched_tokens: 8192 + # async_chunk appends windows per step; max_model_len must cover accumulated stream. + max_model_len: 32768 engine_input_source: [0] final_output: true final_output_type: audio diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index d8c06ee35d3..0180137f394 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -5,15 +5,6 @@ import torch -def _get_request_info(request: Any) -> dict[str, Any]: - info = getattr(request, "additional_information_cpu", None) - if info is None: - info = getattr(request, "additional_information", None) - if isinstance(info, list) and info and isinstance(info[0], dict): - info = info[0] - return info if isinstance(info, dict) else {} - - def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: audio_codes = pooling_output.get("audio_codes") if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: @@ -36,14 +27,8 @@ def talker2code2wav_async_chunk( if not isinstance(pooling_output, dict): return None - info = _get_request_info(request) request_id = request.external_req_id - codec_streaming_raw = info.get("codec_streaming", True) - if isinstance(codec_streaming_raw, list): - codec_streaming_raw = codec_streaming_raw[0] if codec_streaming_raw else True - codec_streaming = codec_streaming_raw if isinstance(codec_streaming_raw, bool) else True - raw_cfg = getattr(connector, "config", {}) or {} cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} chunk_size = int(cfg.get("codec_chunk_frames", 25)) @@ -76,26 +61,14 @@ def talker2code2wav_async_chunk( if finished and (not appended_frame) and chunk_length == 0: return { "code_predictor_codes": [], - "codec_streaming": codec_streaming, - "codec_context_codes": [], "codec_context_frames": 0, - "codec_total_frames": 0, - "codec_chunk_frames": 0, - "codec_num_code_groups": 0, - "codec_layout": "codebook_major", "finished": torch.tensor(True, dtype=torch.bool), } if length <= 0: return { "code_predictor_codes": [], - "codec_streaming": codec_streaming, - "codec_context_codes": [], "codec_context_frames": 0, - "codec_total_frames": 0, - "codec_chunk_frames": 0, - "codec_num_code_groups": 0, - "codec_layout": "codebook_major", "finished": torch.tensor(bool(finished), dtype=torch.bool), } @@ -103,29 +76,13 @@ def talker2code2wav_async_chunk( ctx_frames = max(0, int(end_index - context_length)) window_frames = connector.code_prompt_token_ids[request_id][-end_index:] - if ctx_frames > 0: - ctx_part = window_frames[:ctx_frames] - codec_context_codes = torch.tensor(ctx_part).transpose(0, 1).reshape(-1).tolist() - else: - codec_context_codes = [] - - chunk_part = window_frames[ctx_frames:] - code_predictor_codes = torch.tensor(chunk_part).transpose(0, 1).reshape(-1).tolist() - - num_code_groups = int( - len(connector.code_prompt_token_ids[request_id][-1]) - if connector.code_prompt_token_ids[request_id] - else 0 + # Pack context + chunk into codebook-major flat codes for adapter. + code_predictor_codes = ( + torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() ) return { "code_predictor_codes": code_predictor_codes, - "codec_streaming": codec_streaming, - "codec_context_codes": codec_context_codes, "codec_context_frames": int(ctx_frames), - "codec_total_frames": int(end_index), - "codec_chunk_frames": int(context_length), - "codec_num_code_groups": num_code_groups, - "codec_layout": "codebook_major", "finished": torch.tensor(bool(finished), dtype=torch.bool), } From a304e9e0f7223a22648d41c07d4dbfaf586c7177 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 04:30:09 -0800 Subject: [PATCH 11/28] [~] Style: Fix code format errors of pre-commit Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/entrypoints/openai/serving_speech.py | 1 + .../models/qwen3_tts/qwen3_tts_code2wav.py | 15 +++++++++++---- .../qwen3_tts/qwen3_tts_code_predictor_vllm.py | 6 ++++-- .../models/qwen3_tts/qwen3_tts_tokenizer.py | 3 +-- .../stage_input_processors/qwen3_tts.py | 4 +--- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 63d94ed4528..97251f8ba98 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -311,6 +311,7 @@ async def create_speech( # Streaming accumulates chunks as a list; concat first. if isinstance(audio_tensor, list): import torch + audio_tensor = torch.cat(audio_tensor, dim=-1) # Convert tensor to numpy if hasattr(audio_tensor, "float"): diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 300ce0dff8e..4e381c20143 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -163,7 +163,8 @@ def forward( logger.warning( "Code2Wav input_ids length %d not divisible by num_quantizers %d, " "likely a warmup run; returning empty audio.", - n_tokens, q, + n_tokens, + q, ) return empty_ret @@ -178,10 +179,15 @@ def forward( uniq = int(torch.unique(codes_fq).numel()) cmin = int(codes_fq.min().item()) cmax = int(codes_fq.max().item()) - head = codes_fq[:min(2, total_frames), :min(8, q)].cpu().tolist() + head = codes_fq[: min(2, total_frames), : min(8, q)].cpu().tolist() logger.info( "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", - total_frames, q, uniq, cmin, cmax, head, + total_frames, + q, + uniq, + cmin, + cmax, + head, ) except Exception: pass @@ -202,7 +208,8 @@ def forward( else: logger.warning( "Context trim %d >= decoded length %d; returning empty audio.", - cut, audio_np.shape[0], + cut, + audio_np.shape[0], ) return empty_ret diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index f131ae4170c..68d4baf51c5 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -384,7 +384,8 @@ def prefill_logits(self, inputs_embeds: torch.Tensor) -> torch.Tensor: with ( set_current_vllm_config(self._vllm_config), set_forward_context( - attn_metadata, self._vllm_config, + attn_metadata, + self._vllm_config, num_tokens=int(hs.shape[0]), slot_mapping=slot_mappings, ), @@ -419,7 +420,8 @@ def decode_logits(self, input_ids: torch.Tensor, *, generation_step: int, past_s with ( set_current_vllm_config(self._vllm_config), set_forward_context( - attn_metadata, self._vllm_config, + attn_metadata, + self._vllm_config, num_tokens=int(hs.shape[0]), slot_mapping=slot_mappings, ), diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py index 01d9abd95c3..785ddedab50 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -92,8 +92,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3 inst.config = inst.model.config inst.feature_extractor = ( - AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) - if load_feature_extractor else None + AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) if load_feature_extractor else None ) inst.device = getattr(inst.model, "device", None) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 0180137f394..2930b17d068 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -77,9 +77,7 @@ def talker2code2wav_async_chunk( window_frames = connector.code_prompt_token_ids[request_id][-end_index:] # Pack context + chunk into codebook-major flat codes for adapter. - code_predictor_codes = ( - torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() - ) + code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() return { "code_predictor_codes": code_predictor_codes, From 6eea932f4dea89c5e96e932a8bc5a22516753400 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 04:33:21 -0800 Subject: [PATCH 12/28] [~] Style: Re-fix code format errors of pre-commit Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/entrypoints/openai/serving_speech.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 97251f8ba98..0a5470e4b7d 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -69,11 +69,15 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker_ar import ( Qwen3TTSTalkerForConditionalGenerationARVLLM, ) + if self._tts_tokenizer is None: from transformers import AutoTokenizer + model_name = self.engine_client.model_config.model self._tts_tokenizer = AutoTokenizer.from_pretrained( - model_name, trust_remote_code=True, padding_side="left", + model_name, + trust_remote_code=True, + padding_side="left", ) hf_config = self.engine_client.model_config.hf_config talker_config = hf_config.talker_config From 3a061199acd4dbe3d6de45b5cb49360323586d84 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 09:56:25 -0800 Subject: [PATCH 13/28] [~] Refactor: Remove Qwen3 TTS model files and update registry to reflect changes Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 6 +- .../models/qwen3_tts/modeling_qwen3_tts.py | 2326 ----------------- .../models/qwen3_tts/processing_qwen3_tts.py | 102 - .../qwen3_tts/qwen3_tts_disaggregated.py | 227 -- ...3_tts_talker_ar.py => qwen3_tts_talker.py} | 262 +- vllm_omni/model_executor/models/registry.py | 16 +- .../stage_configs/qwen3_tts.yaml | 4 +- ...s_talker_speech_tokenizer_async_chunk.yaml | 4 +- 8 files changed, 260 insertions(+), 2687 deletions(-) delete mode 100644 vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py delete mode 100644 vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py delete mode 100644 vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py rename vllm_omni/model_executor/models/qwen3_tts/{qwen3_tts_talker_ar.py => qwen3_tts_talker.py} (85%) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 0a5470e4b7d..f6d642bd49f 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -66,8 +66,8 @@ def _load_supported_speakers(self) -> set[str]: def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: """Estimate prompt length so the placeholder matches model-side embeddings.""" try: - from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker_ar import ( - Qwen3TTSTalkerForConditionalGenerationARVLLM, + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + Qwen3TTSTalkerForConditionalGeneration, ) if self._tts_tokenizer is None: @@ -82,7 +82,7 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: hf_config = self.engine_client.model_config.hf_config talker_config = hf_config.talker_config task_type = (tts_params.get("task_type") or ["CustomVoice"])[0] - return Qwen3TTSTalkerForConditionalGenerationARVLLM.estimate_prompt_len_from_additional_information( + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( additional_information=tts_params, task_type=task_type, tokenize_prompt=lambda t: self._tts_tokenizer(t, padding=False)["input_ids"], diff --git a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py deleted file mode 100644 index 1e759a8d2b4..00000000000 --- a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py +++ /dev/null @@ -1,2326 +0,0 @@ -# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen3TTS model.""" - -import json -import os -from collections.abc import Callable -from dataclasses import dataclass - -import torch -from librosa.filters import mel as librosa_mel_fn -from torch import nn -from torch.nn import functional as F -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.generation import GenerationMixin -from transformers.integrations import use_kernel_forward_from_hub -from transformers.masking_utils import ( - create_causal_mask, - create_sliding_window_causal_mask, -) -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_layers import GradientCheckpointingLayer -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - ModelOutput, -) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.utils import can_return_tuple, logging -from transformers.utils.hub import cached_file - -from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific - -from .configuration_qwen3_tts import ( - Qwen3TTSConfig, - Qwen3TTSSpeakerEncoderConfig, - Qwen3TTSTalkerCodePredictorConfig, - Qwen3TTSTalkerConfig, -) -from .qwen3_tts_tokenizer import Qwen3TTSTokenizer - -logger = logging.get_logger(__name__) - - -class Res2NetBlock(torch.nn.Module): - def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): - super().__init__() - - in_channel = in_channels // scale - hidden_channel = out_channels // scale - - self.blocks = nn.ModuleList( - [ - TimeDelayNetBlock( - in_channel, - hidden_channel, - kernel_size=kernel_size, - dilation=dilation, - ) - for i in range(scale - 1) - ] - ) - self.scale = scale - - def forward(self, hidden_states): - outputs = [] - for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): - if i == 0: - output_part = hidden_part - elif i == 1: - output_part = self.blocks[i - 1](hidden_part) - else: - output_part = self.blocks[i - 1](hidden_part + output_part) - outputs.append(output_part) - output = torch.cat(outputs, dim=1) - return output - - -class SqueezeExcitationBlock(nn.Module): - def __init__(self, in_channels, se_channels, out_channels): - super().__init__() - - self.conv1 = nn.Conv1d( - in_channels=in_channels, - out_channels=se_channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv1d( - in_channels=se_channels, - out_channels=out_channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - self.sigmoid = nn.Sigmoid() - - def forward(self, hidden_states): - hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) - - hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) - hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) - - return hidden_states * hidden_states_mean - - -class AttentiveStatisticsPooling(nn.Module): - """This class implements an attentive statistic pooling layer for each channel. - It returns the concatenated mean and std of the input tensor. - """ - - def __init__(self, channels, attention_channels=128): - super().__init__() - - self.eps = 1e-12 - self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) - self.tanh = nn.Tanh() - self.conv = nn.Conv1d( - in_channels=attention_channels, - out_channels=channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - - def _length_to_mask(self, length, max_len=None, dtype=None, device=None): - """Creates a binary mask for each sequence. - - Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 - - Arguments - --------- - length : torch.LongTensor - Containing the length of each sequence in the batch. Must be 1D. - max_len : int - Max length for the mask, also the size of the second dimension. - dtype : torch.dtype, default: None - The dtype of the generated mask. - device: torch.device, default: None - The device to put the mask variable. - - Returns - ------- - mask : tensor - The binary mask. - """ - - if max_len is None: - max_len = length.max().long().item() # using arange to generate mask - mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( - len(length), max_len - ) < length.unsqueeze(1) - - mask = torch.as_tensor(mask, dtype=dtype, device=device) - return mask - - def _compute_statistics(self, x, m, dim=2): - mean = (m * x).sum(dim) - std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) - return mean, std - - def forward(self, hidden_states): - seq_length = hidden_states.shape[-1] - lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) - - # Make binary mask of shape [N, 1, L] - mask = self._length_to_mask( - lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device - ) - mask = mask.unsqueeze(1) - - # Expand the temporal context of the pooling layer by allowing the - # self-attention to look at global properties of the utterance. - total = mask.sum(dim=2, keepdim=True) - - mean, std = self._compute_statistics(hidden_states, mask / total) - mean = mean.unsqueeze(2).repeat(1, 1, seq_length) - std = std.unsqueeze(2).repeat(1, 1, seq_length) - attention = torch.cat([hidden_states, mean, std], dim=1) - - # Apply layers - attention = self.conv(self.tanh(self.tdnn(attention))) - - # Filter out zero-paddings - attention = attention.masked_fill(mask == 0, float("-inf")) - - attention = F.softmax(attention, dim=2) - mean, std = self._compute_statistics(hidden_states, attention) - # Append mean and std of the batch - pooled_stats = torch.cat((mean, std), dim=1) - pooled_stats = pooled_stats.unsqueeze(2) - - return pooled_stats - - -class TimeDelayNetBlock(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - dilation, - ): - super().__init__() - self.conv = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - padding="same", - padding_mode="reflect", - ) - self.activation = nn.ReLU() - - def forward(self, hidden_states: torch.Tensor): - return self.activation(self.conv(hidden_states)) - - -class SqueezeExcitationRes2NetBlock(nn.Module): - """An implementation of building block in ECAPA-TDNN, i.e., - TDNN-Res2Net-TDNN-SqueezeExcitationBlock. - """ - - def __init__( - self, - in_channels, - out_channels, - res2net_scale=8, - se_channels=128, - kernel_size=1, - dilation=1, - ): - super().__init__() - self.out_channels = out_channels - self.tdnn1 = TimeDelayNetBlock( - in_channels, - out_channels, - kernel_size=1, - dilation=1, - ) - self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) - self.tdnn2 = TimeDelayNetBlock( - out_channels, - out_channels, - kernel_size=1, - dilation=1, - ) - self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) - - def forward(self, hidden_state): - residual = hidden_state - - hidden_state = self.tdnn1(hidden_state) - hidden_state = self.res2net_block(hidden_state) - hidden_state = self.tdnn2(hidden_state) - hidden_state = self.se_block(hidden_state) - - return hidden_state + residual - - -class Qwen3TTSSpeakerEncoder(torch.nn.Module): - """An implementation of the speaker embedding model in a paper. - "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in - TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). - Use for Qwen3TTS extract speaker embedding. - """ - - def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): - super().__init__() - if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( - config.enc_dilations - ): - raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") - self.channels = config.enc_channels - self.blocks = nn.ModuleList() - - # The initial TDNN layer - self.blocks.append( - TimeDelayNetBlock( - config.mel_dim, - config.enc_channels[0], - config.enc_kernel_sizes[0], - config.enc_dilations[0], - ) - ) - - # SE-Res2Net layers - for i in range(1, len(config.enc_channels) - 1): - self.blocks.append( - SqueezeExcitationRes2NetBlock( - config.enc_channels[i - 1], - config.enc_channels[i], - res2net_scale=config.enc_res2net_scale, - se_channels=config.enc_se_channels, - kernel_size=config.enc_kernel_sizes[i], - dilation=config.enc_dilations[i], - ) - ) - - # Multi-layer feature aggregation - self.mfa = TimeDelayNetBlock( - config.enc_channels[-1], - config.enc_channels[-1], - config.enc_kernel_sizes[-1], - config.enc_dilations[-1], - ) - - # Attentive Statistical Pooling - self.asp = AttentiveStatisticsPooling( - config.enc_channels[-1], - attention_channels=config.enc_attention_channels, - ) - - # Final linear transformation - self.fc = nn.Conv1d( - in_channels=config.enc_channels[-1] * 2, - out_channels=config.enc_dim, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - - def forward(self, hidden_states): - # Minimize transpose for efficiency - hidden_states = hidden_states.transpose(1, 2) - - hidden_states_list = [] - for layer in self.blocks: - hidden_states = layer(hidden_states) - hidden_states_list.append(hidden_states) - - # Multi-layer feature aggregation - hidden_states = torch.cat(hidden_states_list[1:], dim=1) - hidden_states = self.mfa(hidden_states) - - # Attentive Statistical Pooling - hidden_states = self.asp(hidden_states) - - # Final linear transformation - hidden_states = self.fc(hidden_states) - - hidden_states = hidden_states.squeeze(-1) - return hidden_states - - -def dynamic_range_compression_torch(x, c=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * c) - - -def mel_spectrogram( - y: torch.Tensor, - n_fft: int, - num_mels: int, - sampling_rate: int, - hop_size: int, - win_size: int, - fmin: int, - fmax: int = None, - center: bool = False, -) -> torch.Tensor: - """ - Calculate the mel spectrogram of an input signal. - This function uses slaney norm for the librosa mel filterbank - (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). - - Args: - y (torch.Tensor): Input signal. - n_fft (int): FFT size. - num_mels (int): Number of mel bins. - sampling_rate (int): Sampling rate of the input signal. - hop_size (int): Hop size for STFT. - win_size (int): Window size for STFT. - fmin (int): Minimum frequency for mel filterbank. - fmax (int): Maximum frequency for mel filterbank. - If None, defaults to half the sampling rate (fmax = sr / 2.0) - inside librosa_mel_fn - center (bool): Whether to pad the input to center the frames. Default is False. - - Returns: - torch.Tensor: Mel spectrogram. - """ - if torch.min(y) < -1.0: - print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") - if torch.max(y) > 1.0: - print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") - - device = y.device - - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - - mel_basis = torch.from_numpy(mel).float().to(device) - hann_window = torch.hann_window(win_size).to(device) - - padding = (n_fft - hop_size) // 2 - y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) - - mel_spec = torch.matmul(mel_basis, spec) - mel_spec = dynamic_range_compression_torch(mel_spec) - - return mel_spec - - -def _compute_default_rope_parameters( - config, - device, -): - base = config.rope_theta - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - return inv_freq, attention_factor - - -class Qwen3TTSPreTrainedModel(PreTrainedModel): - config_class = Qwen3TTSConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen3TTSDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = False - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.weight is not None: - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - - -class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = [] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_3 = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = False - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen3TTSRMSNorm): - module.weight.data.fill_(1.0) - - -class Qwen3TTSTalkerRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen3TTSTalkerConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn: Callable = _compute_default_rope_parameters - if self.rope_type != "default": - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # In contrast to other models, Qwen3TTSThinkerText has different position ids for the grids - # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Qwen3TTSRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen3TTSConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn: Callable = _compute_default_rope_parameters - if self.rope_type != "default": - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -@use_kernel_forward_from_hub("RMSNorm") -class Qwen3TTSRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen3TTSRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, mrope_interleaved=False, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - if mrope_interleaved: - - def apply_interleaved_rope(x, modality_num): - x_t = x[0].clone() - index_ranges = [] - for i, n in enumerate(mrope_section[1:], 1): - beg_idx = i - end_idx = n * modality_num - index_ranges.append((beg_idx, end_idx)) - for beg_idx, end_idx in index_ranges: - x_t[..., beg_idx:end_idx:modality_num] = x[beg_idx, ..., beg_idx:end_idx:modality_num] - return x_t - - dim = cos.shape[-1] - modality_num = len(mrope_section) - cos = torch.cat([apply_interleaved_rope(cos[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([apply_interleaved_rope(sin[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( - unsqueeze_dim - ) - else: - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen3TTSTalkerAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! - self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = getattr(config, "sliding_window", None) - self.rope_scaling = config.rope_scaling - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], self.rope_scaling["interleaved"] - ) - - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Qwen3TTSTalkerResizeMLP(nn.Module): - def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False): - super().__init__() - self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias) - self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias) - self.act_fn = ACT2FN[act] - - def forward(self, hidden_state): - return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) - - -@dataclass -class Qwen3TTSTalkerCodePredictorOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head - (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, - returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor = None - past_key_values: list[torch.FloatTensor] | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - generation_steps: int | None = None - - -class Qwen3TTSTalkerTextMLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen3TTSAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Qwen3TTSConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! - self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Qwen3TTSDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Qwen3TTSConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = Qwen3TTSAttention(config=config, layer_idx=layer_idx) - - self.mlp = Qwen3TTSTalkerTextMLP(config) - self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class Qwen3TTSTalkerCodePredictorModel(Qwen3TTSPreTrainedModel): - config_class = Qwen3TTSTalkerCodePredictorConfig - base_model_prefix = "talker.code_predictor.model" - - def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, embedding_dim: int): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.layers = nn.ModuleList( - [Qwen3TTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3TTSRotaryEmbedding(config=config) - self.gradient_checkpointing = False - self.has_sliding_layers = "sliding_attention" in self.config.layer_types - self.codec_embedding = nn.ModuleList( - [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)] - ) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.codec_embedding - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **flash_attn_kwargs, - ) -> BaseModelOutputWithPast: - if input_ids is not None: - raise ValueError("`input_ids` is expected to be `None`") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache - if not isinstance(past_key_values, (type(None), Cache)): - raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # It may already have been prepared by e.g. `generate` - if not isinstance(causal_mask_mapping := attention_mask, dict): - # Prepare mask arguments - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_key_values, - } - # Create the masks - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - } - # The sliding window alternating layers are not always activated depending on the config - if self.has_sliding_layers: - causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = Qwen3TTSTalkerCodePredictorConfig - base_model_prefix = "talker.code_predictor" - - def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, talker_config: Qwen3TTSTalkerConfig): - super().__init__(config) - self.model = Qwen3TTSTalkerCodePredictorModel(config, talker_config.hidden_size) - self.vocab_size = config.vocab_size - self.lm_head = nn.ModuleList( - [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] - ) - - if config.hidden_size != talker_config.hidden_size: - self.small_to_mtp_projection = torch.nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True) - else: - self.small_to_mtp_projection = torch.nn.Identity() - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward_finetune( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **kwargs, - ) -> CausalLMOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - logits = [] - for i in range(1, self.config.num_code_groups): - logits.append(self.lm_head[i - 1](hidden_states[:, i])) - logits = torch.stack(logits, dim=1) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerCodePredictorOutputWithPast(loss=loss, logits=logits) - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **kwargs, - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # Prefill stage - if inputs_embeds is not None and inputs_embeds.shape[1] > 1: - generation_steps = inputs_embeds.shape[1] - 2 # hidden & layer 0 - # Generation stage - else: - inputs_embeds = self.model.get_input_embeddings()[generation_steps - 1](input_ids) - inputs_embeds = self.small_to_mtp_projection(inputs_embeds) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - logits = self.lm_head[generation_steps](hidden_states) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerCodePredictorOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - generation_steps=generation_steps + 1, - ) - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder, num_new_tokens - ) - model_kwargs["generation_steps"] = outputs.generation_steps - return model_kwargs - - -@dataclass -class Qwen3TTSTalkerOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head - (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, - returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - past_key_values: list[torch.FloatTensor] | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - past_hidden: torch.FloatTensor | None = None - generation_step: int | None = None - trailing_text_hidden: torch.FloatTensor | None = None - tts_pad_embed: torch.FloatTensor | None = None - - -class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config, layer_idx): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Qwen3TTSTalkerAttention(config, layer_idx) - - self.mlp = Qwen3TTSTalkerTextMLP(config, intermediate_size=config.intermediate_size) - - self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: tuple[torch.Tensor] | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - hidden_states = self.mlp(hidden_states) - - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class Qwen3TTSTalkerModel(Qwen3TTSTalkerTextPreTrainedModel): - config_class = Qwen3TTSTalkerConfig - base_model_prefix = "talker.model" - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.layers = nn.ModuleList( - [Qwen3TTSTalkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3TTSTalkerRotaryEmbedding(config) - self.gradient_checkpointing = False - self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size) - self.text_embedding = nn.Embedding(config.text_vocab_size, config.text_hidden_size) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.codec_embedding - - def get_text_embeddings(self): - return self.text_embedding - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # the hard coded `3` is for temporal, height and width. - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - if position_ids.ndim == 3 and position_ids.shape[0] == 4: - text_position_ids = position_ids[0] - position_ids = position_ids[1:] - else: - text_position_ids = position_ids[0] - - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask - causal_mask = mask_function( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=text_position_ids, - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=text_position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen3TTSTalkerForConditionalGeneration(Qwen3TTSTalkerTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = Qwen3TTSTalkerConfig - base_model_prefix = "talker" - - def __init__(self, config: Qwen3TTSTalkerConfig): - super().__init__(config) - self.model = Qwen3TTSTalkerModel(config) - self.vocab_size = config.vocab_size - self.text_projection = Qwen3TTSTalkerResizeMLP( - config.text_hidden_size, config.text_hidden_size, config.hidden_size, config.hidden_act, bias=True - ) - - self.codec_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.code_predictor = Qwen3TTSTalkerCodePredictorModelForConditionalGeneration( - config=config.code_predictor_config, talker_config=config - ) - self.rope_deltas = None - - # Initialize weights and apply final processing - self.post_init() - - # TODO: hack, modular cannot inherit multiple classes - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def get_text_embeddings(self): - return self.model.get_text_embeddings() - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward_sub_talker_finetune(self, codec_ids, talker_hidden_states): - assert len(codec_ids.shape) == 2 - assert len(talker_hidden_states.shape) == 2 - assert codec_ids.shape[0] == talker_hidden_states.shape[0] - assert talker_hidden_states.shape[1] == self.config.hidden_size - assert codec_ids.shape[1] == self.config.num_code_groups - - sub_talker_inputs_embeds = [talker_hidden_states.unsqueeze(1)] - - for i in range(self.config.num_code_groups - 1): - if i == 0: - sub_talker_inputs_embeds.append(self.get_input_embeddings()(codec_ids[:, :1])) - else: - sub_talker_inputs_embeds.append( - self.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, i : i + 1]) - ) - sub_talker_inputs_embeds = torch.cat(sub_talker_inputs_embeds, dim=1) - - sub_talker_outputs = self.code_predictor.forward_finetune( - inputs_embeds=sub_talker_inputs_embeds, labels=codec_ids[:, 1:] - ) - - sub_talker_logits = sub_talker_outputs.logits - sub_talker_loss = sub_talker_outputs.loss - return sub_talker_logits, sub_talker_loss - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - past_hidden=None, - trailing_text_hidden=None, - tts_pad_embed=None, - generation_step=None, - subtalker_dosample=None, - subtalker_top_p=None, - subtalker_top_k=None, - subtalker_temperature=None, - **kwargs, - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - ```""" - # Prefill - if inputs_embeds is not None and inputs_embeds.shape[1] > 1: - generation_step = -1 - codec_ids = None - # Generate - else: - last_id_hidden = self.get_input_embeddings()(input_ids) - predictor_result = self.code_predictor.generate( - inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1), - max_new_tokens=self.config.num_code_groups - 1, - do_sample=subtalker_dosample, - top_p=subtalker_top_p, - top_k=subtalker_top_k, - temperature=subtalker_temperature, - output_hidden_states=True, - return_dict_in_generate=True, - ) - codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1) - codec_hiddens = torch.cat( - [last_id_hidden] - + [ - self.code_predictor.get_input_embeddings()[i](predictor_result.sequences[..., i : i + 1]) - for i in range(self.config.num_code_groups - 1) - ], - dim=1, - ) - inputs_embeds = codec_hiddens.sum(1, keepdim=True) - - if generation_step < trailing_text_hidden.shape[1]: - inputs_embeds = inputs_embeds + trailing_text_hidden[:, generation_step].unsqueeze(1) - else: - inputs_embeds = inputs_embeds + tts_pad_embed - if attention_mask is not None: - if ( - cache_position is None - or (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - ): - delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) - position_ids, rope_deltas = self.get_rope_index( - attention_mask, - ) - rope_deltas = rope_deltas - delta0 - self.rope_deltas = rope_deltas - else: - batch_size, seq_length = input_ids.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - logits = self.codec_head(hidden_states) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=(outputs.hidden_states, codec_ids), - attentions=outputs.attentions, - past_hidden=hidden_states[:, -1:, :], - generation_step=generation_step + 1, - trailing_text_hidden=trailing_text_hidden, - tts_pad_embed=tts_pad_embed, - ) - - def get_rope_index( - self, - attention_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. - Examples: - input_ids: [T T T T T], here T is for text. - temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embedding for text part. - Examples: - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. - This means one frame is processed each second. - interval: The step size for the temporal position IDs, - calculated as tokens_per_second * temporal_patch_size / fps. - In this case, 25 * 2 / 1 = 50. This means that each temporal - patch will be have a difference of 50 in the temporal position IDs. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] - Here we calculate the text start position_ids as the max vision position_ids plus 1. - - Args: - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) - """ - mrope_position_deltas = [] - - position_ids = attention_mask.float().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) - - return position_ids, mrope_position_deltas - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder, num_new_tokens - ) - model_kwargs["past_hidden"] = outputs.past_hidden - model_kwargs["generation_step"] = outputs.generation_step - model_kwargs["trailing_text_hidden"] = outputs.trailing_text_hidden - model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed - return model_kwargs - - -class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): - config_class = Qwen3TTSConfig - - def __init__(self, config: Qwen3TTSConfig): - super().__init__(config) - self.config = config - - self.talker = Qwen3TTSTalkerForConditionalGeneration(self.config.talker_config) - - if config.tts_model_type == "base": - self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) - else: - self.speaker_encoder = None - - self.speech_tokenizer = None - self.generate_config = None - - self.supported_speakers = self.config.talker_config.spk_id.keys() - self.supported_languages = ["auto"] - for language_id in self.config.talker_config.codec_language_id.keys(): - if "dialect" not in language_id: - self.supported_languages.append(language_id) - - self.speaker_encoder_sample_rate = self.config.speaker_encoder_config.sample_rate - self.tokenizer_type = self.config.tokenizer_type - self.tts_model_size = self.config.tts_model_size - self.tts_model_type = self.config.tts_model_type - - self.post_init() - - def load_speech_tokenizer(self, speech_tokenizer): - self.speech_tokenizer = speech_tokenizer - - def load_generate_config(self, generate_config): - self.generate_config = generate_config - - def get_supported_speakers(self): - return self.supported_speakers - - def get_supported_languages(self): - return self.supported_languages - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path, - *model_args, - config=None, - cache_dir=None, - ignore_mismatched_sizes=False, - force_download=False, - local_files_only=False, - token=None, - revision="main", - use_safetensors=None, - weights_only=True, - **kwargs, - ): - model = super().from_pretrained( - pretrained_model_name_or_path, - *model_args, - config=config, - cache_dir=cache_dir, - ignore_mismatched_sizes=ignore_mismatched_sizes, - force_download=force_download, - local_files_only=local_files_only, - token=token, - revision=revision, - use_safetensors=use_safetensors, - weights_only=weights_only, - **kwargs, - ) - if not local_files_only and not os.path.isdir(pretrained_model_name_or_path): - download_cache_dir = kwargs.get("cache_dir", cache_dir) - download_revision = kwargs.get("revision", revision) - download_weights_from_hf_specific( - pretrained_model_name_or_path, - cache_dir=download_cache_dir, - allow_patterns=["speech_tokenizer/*"], - revision=download_revision, - ) - speech_tokenizer_path = cached_file( - pretrained_model_name_or_path, - "speech_tokenizer/config.json", - subfolder=kwargs.pop("subfolder", None), - cache_dir=kwargs.pop("cache_dir", None), - force_download=kwargs.pop("force_download", False), - proxies=kwargs.pop("proxies", None), - resume_download=kwargs.pop("resume_download", None), - local_files_only=kwargs.pop("local_files_only", False), - token=kwargs.pop("use_auth_token", None), - revision=kwargs.pop("revision", None), - ) - if speech_tokenizer_path is None: - raise ValueError(f"""{pretrained_model_name_or_path}/{speech_tokenizer_path} not exists""") - speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) - speech_tokenizer = Qwen3TTSTokenizer.from_pretrained( - speech_tokenizer_dir, - *model_args, - **kwargs, - ) - model.load_speech_tokenizer(speech_tokenizer) - - generate_config_path = cached_file( - pretrained_model_name_or_path, - "generation_config.json", - subfolder=kwargs.pop("subfolder", None), - cache_dir=kwargs.pop("cache_dir", None), - force_download=kwargs.pop("force_download", False), - proxies=kwargs.pop("proxies", None), - resume_download=kwargs.pop("resume_download", None), - local_files_only=kwargs.pop("local_files_only", False), - token=kwargs.pop("use_auth_token", None), - revision=kwargs.pop("revision", None), - ) - with open(generate_config_path, encoding="utf-8") as f: - generate_config = json.load(f) - model.load_generate_config(generate_config) - - return model - - @torch.inference_mode() - def extract_speaker_embedding(self, audio, sr): - assert sr == 24000, "Only support 24kHz audio" - mels = mel_spectrogram( - torch.from_numpy(audio).unsqueeze(0), - n_fft=1024, - num_mels=128, - sampling_rate=24000, - hop_size=256, - win_size=1024, - fmin=0, - fmax=12000, - ).transpose(1, 2) - speaker_embedding = self.speaker_encoder(mels.to(self.device).to(self.dtype))[0] - return speaker_embedding - - @torch.inference_mode() - def generate_speaker_prompt(self, voice_clone_prompt: list[dict]): - voice_clone_spk_embeds = [] - for index in range(len(voice_clone_prompt["ref_spk_embedding"])): - ref_spk_embedding = ( - voice_clone_prompt["ref_spk_embedding"][index].to(self.talker.device).to(self.talker.dtype) - ) - voice_clone_spk_embeds.append(ref_spk_embedding) - - return voice_clone_spk_embeds - - def generate_icl_prompt( - self, - text_id: torch.Tensor, - ref_id: torch.Tensor, - ref_code: torch.Tensor, - tts_pad_embed: torch.Tensor, - tts_eos_embed: torch.Tensor, - non_streaming_mode: bool, - ): - # text embed (ref id + text id + eos) 1 T1 D - text_embed = self.talker.text_projection( - self.talker.get_text_embeddings()(torch.cat([ref_id, text_id], dim=-1)) - ) - text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) - # codec embed (codec bos + codec) 1 T2 D - codec_embed = [] - for i in range(self.talker.config.num_code_groups): - if i == 0: - codec_embed.append(self.talker.get_input_embeddings()(ref_code[:, :1])) - else: - codec_embed.append(self.talker.code_predictor.get_input_embeddings()[i - 1](ref_code[:, i : i + 1])) - codec_embed = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0) - codec_embed = torch.cat( - [ - self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=text_id.dtype, - ) - ), - codec_embed, - ], - dim=1, - ) - # compute lens - text_lens = text_embed.shape[1] - codec_lens = codec_embed.shape[1] - if non_streaming_mode: - icl_input_embed = text_embed + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_pad_id, - ] - * text_lens - ], - device=self.talker.device, - dtype=text_id.dtype, - ) - ) - icl_input_embed = torch.cat([icl_input_embed, codec_embed + tts_pad_embed], dim=1) - return icl_input_embed, tts_pad_embed - else: - if text_lens > codec_lens: - return text_embed[:, :codec_lens] + codec_embed, text_embed[:, codec_lens:] - else: - text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1) - return text_embed + codec_embed, tts_pad_embed - - @torch.no_grad() - def generate( - self, - input_ids: list[torch.Tensor] | None = None, - instruct_ids: list[torch.Tensor] | None = None, - ref_ids: list[torch.Tensor] | None = None, - voice_clone_prompt: list[dict] = None, - languages: list[str] = None, - speakers: list[str] = None, - non_streaming_mode=False, - max_new_tokens: int = 4096, - do_sample: bool = True, - top_k: int = 50, - top_p: float = 1.0, - temperature: float = 0.9, - subtalker_dosample: bool = True, - subtalker_top_k: int = 50, - subtalker_top_p: float = 1.0, - subtalker_temperature: float = 0.9, - eos_token_id: int | None = None, - repetition_penalty: float = 1.05, - **kwargs, - ): - talker_kwargs = { - "max_new_tokens": max_new_tokens, - "min_new_tokens": 2, - "do_sample": do_sample, - "top_k": top_k, - "top_p": top_p, - "temperature": temperature, - "subtalker_dosample": subtalker_dosample, - "subtalker_top_k": subtalker_top_k, - "subtalker_top_p": subtalker_top_p, - "subtalker_temperature": subtalker_temperature, - "eos_token_id": eos_token_id if eos_token_id is not None else self.config.talker_config.codec_eos_token_id, - "repetition_penalty": repetition_penalty, - "suppress_tokens": [ - i - for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size) - if i not in (self.config.talker_config.codec_eos_token_id,) - ], - "output_hidden_states": getattr(kwargs, "output_hidden_states", True), - "return_dict_in_generate": getattr(kwargs, "return_dict_in_generate", True), - } - - talker_input_embeds = [[] for _ in range(len(input_ids))] - - voice_clone_spk_embeds = None - # voice clone speaker prompt generate - if voice_clone_prompt is not None: - voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt) - - # instruct text prompt generate - if instruct_ids is not None: - for index, instruct_id in enumerate(instruct_ids): - if instruct_id is not None: - talker_input_embeds[index].append( - self.talker.text_projection(self.talker.get_text_embeddings()(instruct_id)) - ) - - # tts text prompt generate - trailing_text_hiddens = [] - if speakers is None: - speakers = [None] * len(input_ids) - for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)): - if voice_clone_spk_embeds is None: - if speaker == "" or speaker is None: # Instruct create speaker - speaker_embed = None - else: - if speaker.lower() not in self.config.talker_config.spk_id: - raise NotImplementedError(f"Speaker {speaker} not implemented") - else: - spk_id = self.config.talker_config.spk_id[speaker.lower()] - speaker_embed = self.talker.get_input_embeddings()( - torch.tensor( - spk_id, - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - else: - if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]: - speaker_embed = voice_clone_spk_embeds[index] - else: - speaker_embed = None - - assert language is not None - - if language.lower() == "auto": - language_id = None - else: - if language.lower() not in self.config.talker_config.codec_language_id: - raise NotImplementedError(f"Language {language} not implemented") - else: - language_id = self.config.talker_config.codec_language_id[language.lower()] - - if ( - language.lower() in ["chinese", "auto"] - and speaker != "" - and speaker is not None - and self.config.talker_config.spk_is_dialect[speaker.lower()] is not False - ): - dialect = self.config.talker_config.spk_is_dialect[speaker.lower()] - language_id = self.config.talker_config.codec_language_id[dialect] - - tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection( - self.talker.get_text_embeddings()( - torch.tensor( - [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - ).chunk(3, dim=1) # 3 * [1 1 d] - - # codec: tag and speaker - if language_id is None: - codec_prefill_list = [ - [ - self.config.talker_config.codec_nothink_id, - self.config.talker_config.codec_think_bos_id, - self.config.talker_config.codec_think_eos_id, - ] - ] - else: - codec_prefill_list = [ - [ - self.config.talker_config.codec_think_id, - self.config.talker_config.codec_think_bos_id, - language_id, - self.config.talker_config.codec_think_eos_id, - ] - ] - - codec_input_emebdding_0 = self.talker.get_input_embeddings()( - torch.tensor( - codec_prefill_list, - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - codec_input_emebdding_1 = self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_pad_id, - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - if speaker_embed is None: - codec_input_emebdding = torch.cat([codec_input_emebdding_0, codec_input_emebdding_1], dim=1) - else: - codec_input_emebdding = torch.cat( - [codec_input_emebdding_0, speaker_embed.view(1, 1, -1), codec_input_emebdding_1], dim=1 - ) - - # '<|im_start|>assistant\n我叫通义千问,是阿里云的开源大模型。<|im_end|>\n<|im_start|>assistant\n' - - # <|im_start|>assistant\n - _talker_input_embed_role = self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, :3])) - - # tts_pad * 4 + tts_bos - _talker_input_embed = ( - torch.cat( - ( - tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1), - tts_bos_embed, - ), - dim=1, - ) - + codec_input_emebdding[:, :-1] - ) - - talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1) - - if ( - voice_clone_prompt is not None - and voice_clone_prompt["ref_code"] is not None - and voice_clone_prompt["icl_mode"][index] - ): - icl_input_embed, trailing_text_hidden = self.generate_icl_prompt( - text_id=input_id[:, 3:-5], - ref_id=ref_ids[index][:, 3:-2], - ref_code=voice_clone_prompt["ref_code"][index].to(self.talker.device), - tts_pad_embed=tts_pad_embed, - tts_eos_embed=tts_eos_embed, - non_streaming_mode=non_streaming_mode, - ) - talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1) - else: - # tts_text_first_token - talker_input_embed = torch.cat( - [ - talker_input_embed, - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4])) - + codec_input_emebdding[:, -1:], - ], - dim=1, - ) - if non_streaming_mode: - talker_input_embed = talker_input_embed[:, :-1] # 去掉原本放进去的text - talker_input_embed = torch.cat( - [ - talker_input_embed, - torch.cat( - ( - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:-5])), - tts_eos_embed, - ), - dim=1, - ) - + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_pad_id, - ] - * (input_id[:, 3:-5].shape[1] + 1) - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ), - tts_pad_embed - + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ), - ], - dim=1, - ) - trailing_text_hidden = tts_pad_embed - else: - # 叫通义千问,是阿里云的开源大模型。 - trailing_text_hidden = torch.cat( - ( - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 4:-5])), - tts_eos_embed, - ), - dim=1, - ) - talker_input_embeds[index].append(talker_input_embed) - trailing_text_hiddens.append(trailing_text_hidden) - - for index, talker_input_embed in enumerate(talker_input_embeds): - talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not None], dim=1) - - # for batch inferquence - original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds]) - # left padding for talker input embeds - sequences = [t.squeeze(0) for t in talker_input_embeds] - sequences_reversed = [t.flip(dims=[0]) for t in sequences] - padded_reversed = torch.nn.utils.rnn.pad_sequence(sequences_reversed, batch_first=True, padding_value=0.0) - talker_input_embeds = padded_reversed.flip(dims=[1]) - # generate mask - batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1] - indices = torch.arange(max_len).expand(batch_size, -1) - num_pads = max_len - original_lengths - talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device) - # padding trailing text hiddens - pad_embedding_vector = tts_pad_embed.squeeze() - sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens] - trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad] - padded_hiddens = torch.nn.utils.rnn.pad_sequence(sequences_to_pad, batch_first=True, padding_value=0.0) - arange_tensor = torch.arange(max(trailing_text_original_lengths), device=padded_hiddens.device).expand( - len(trailing_text_original_lengths), -1 - ) - lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1) - padding_mask = arange_tensor >= lengths_tensor - padded_hiddens[padding_mask] = pad_embedding_vector - trailing_text_hiddens = padded_hiddens - - # forward - talker_result = self.talker.generate( - inputs_embeds=talker_input_embeds, - attention_mask=talker_attention_mask, - trailing_text_hidden=trailing_text_hiddens, - tts_pad_embed=tts_pad_embed, - **talker_kwargs, - ) - - talker_codes = torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1) - talker_hidden_states = torch.cat([hid[0][-1][:, -1:] for hid in talker_result.hidden_states], dim=1)[:, :-1] - - first_codebook = talker_codes[:, :, 0] - is_stop_token = first_codebook == self.config.talker_config.codec_eos_token_id - stop_indices = torch.argmax(is_stop_token.int(), dim=1) - has_stop_token = is_stop_token.any(dim=1) - effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1]) - - talker_codes_list = [ - talker_codes[ - i, - :length, - ] - for i, length in enumerate(effective_lengths) - ] - talker_hidden_states_list = [talker_hidden_states[i, :length, :] for i, length in enumerate(effective_lengths)] - - return talker_codes_list, talker_hidden_states_list - - -__all__ = [ - "Qwen3TTSForConditionalGeneration", - "Qwen3TTSTalkerForConditionalGeneration", - "Qwen3TTSPreTrainedModel", - "Qwen3TTSTalkerModel", -] diff --git a/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py deleted file mode 100644 index 5643a857cdb..00000000000 --- a/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from transformers.feature_extraction_utils import BatchFeature -from transformers.processing_utils import ProcessingKwargs, ProcessorMixin - - -class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False): - _defaults = { - "text_kwargs": { - "padding": False, - "padding_side": "left", - } - } - - -class Qwen3TTSProcessor(ProcessorMixin): - r""" - Constructs a Qwen3TTS processor. - - Args: - tokenizer ([`Qwen2TokenizerFast`], *optional*): - The text tokenizer. - chat_template (`Optional[str]`, *optional*): - The Jinja template to use for formatting the conversation. - If not provided, the default chat template is used. - """ - - attributes = ["tokenizer"] - tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") - - def __init__(self, tokenizer=None, chat_template=None): - super().__init__(tokenizer, chat_template=chat_template) - - def __call__(self, text=None, **kwargs) -> BatchFeature: - """ - Main method to prepare for the model one or several sequences(s) and audio(s). - This method forwards the `text` and `kwargs` arguments to - Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` - to encode the text. - - Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - """ - - if text is None: - raise ValueError("You need to specify either a `text` input to process.") - - output_kwargs = self._merge_kwargs( - Qwen3TTSProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) - if not isinstance(text, list): - text = [text] - - texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) - - return BatchFeature( - data={**texts_inputs}, - tensor_type=kwargs.get("return_tensors"), - ) - - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to - the docstring of this method for more information. - """ - return self.tokenizer.decode(*args, **kwargs) - - def apply_chat_template(self, conversations, chat_template=None, **kwargs): - if isinstance(conversations[0], dict): - conversations = [conversations] - return super().apply_chat_template(conversations, chat_template, **kwargs) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names - return list(dict.fromkeys(tokenizer_input_names)) - - -__all__ = ["Qwen3TTSProcessor"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py deleted file mode 100644 index 3972d59bb54..00000000000 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_disaggregated.py +++ /dev/null @@ -1,227 +0,0 @@ -import os -from collections.abc import Iterable -from typing import Any - -import numpy as np -import torch -import torch.nn as nn -from transformers.utils.hub import cached_file -from vllm.config import VllmConfig - -from vllm_omni.model_executor.models.output_templates import OmniOutput - -from .qwen3_tts import Qwen3TTSModel -from .qwen3_tts_tokenizer import Qwen3TTSTokenizer - -_VALID_STAGES = ("talker", "speech_tokenizer") - - -class Qwen3TTSForConditionalGenerationDisaggregatedVLLM(nn.Module): - """Stage-aware wrapper for disaggregated Qwen3-TTS (selects stage via model_stage). - SpeechTokenizer stage decodes codec->waveform; talker is handled by the AR talker model.""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - self.vllm_config = vllm_config - self.model_path = vllm_config.model_config.model - self.model_stage = getattr(vllm_config.model_config, "model_stage", None) - self._async_chunk = bool(getattr(vllm_config.model_config, "async_chunk", False)) - - if self.model_stage not in _VALID_STAGES: - raise ValueError(f"Invalid model_stage for Qwen3-TTS disaggregated model: {self.model_stage}") - - if self.model_stage == "talker": - # Avoid accidental fallback to the HF generate() path. - raise ValueError( - "Qwen3-TTS disaggregated wrapper no longer supports model_stage='talker'. " - "Use model_arch=Qwen3TTSTalkerForConditionalGenerationARVLLM for Stage-0." - ) - - self.have_multimodal_outputs = True - # Only speech_tokenizer needs preprocess in async_chunk (treat prompt_token_ids as codec codes). - self.has_preprocess = bool(self.model_stage == "speech_tokenizer" and self._async_chunk) - if self.model_stage == "speech_tokenizer" and not self._async_chunk: - raise ValueError( - "Qwen3-TTS SpeechTokenizer stage no longer supports serial " - "`additional_information['audio_codes']` mode. Use async_chunk " - "stage config so Stage-1 consumes codec codes via prompt_token_ids." - ) - - self._talker: Qwen3TTSModel | None = None - self._speech_tokenizer: Qwen3TTSTokenizer | None = None - # Only required for Stage-1 streaming decode (to reframe flattened codes). - self._num_code_groups = 0 - if self.model_stage == "speech_tokenizer": - try: - self._num_code_groups = int(vllm_config.model_config.hf_config.talker_config.num_code_groups) - except Exception as e: - raise ValueError(f"Failed to read talker_config.num_code_groups from hf_config: {e}") from e - if self._num_code_groups <= 0: - raise ValueError(f"Invalid num_code_groups={self._num_code_groups} for Qwen3-TTS.") - - @staticmethod - def _module_device(module: nn.Module) -> torch.device: - try: - return next(module.parameters()).device - except StopIteration: - for _, buf in module.named_buffers(recurse=True): - return buf.device - return torch.device("cpu") - - def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: - if self._speech_tokenizer is not None: - return self._speech_tokenizer - - # Locate speech_tokenizer dir from HF cache (or local path). - speech_tokenizer_path = cached_file(self.model_path, "speech_tokenizer/config.json") - if speech_tokenizer_path is None: - raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") - speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) - self._speech_tokenizer = Qwen3TTSTokenizer.from_pretrained( - speech_tokenizer_dir, - torch_dtype=torch.bfloat16, - load_feature_extractor=False, - ) - # Run decode on the vLLM worker device, then read back from module. - if self._speech_tokenizer.model is not None: - self._speech_tokenizer.model.to(device=self.vllm_config.device_config.device) - self._speech_tokenizer.device = self._module_device(self._speech_tokenizer.model) - return self._speech_tokenizer - - def preprocess( - self, - input_ids: torch.Tensor, - input_embeds: torch.Tensor | None, - **info_dict: Any, - ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - # Only used in async_chunk speech_tokenizer stage. - if self.model_stage != "speech_tokenizer" or not self._async_chunk: - return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), {} - - if self._num_code_groups <= 0: - raise ValueError(f"Invalid talker_config.num_code_groups={self._num_code_groups} for streaming decode.") - - # Optional request id for debugging only (streaming decode keeps no per-request state). - req_id = str(info_dict.get("_omni_request_id") or "") - - q = int(self._num_code_groups) - if input_ids.numel() <= 0: - update = {"model_outputs": None, "sr": None} - return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update - - tokens = input_ids.reshape(-1).to(torch.long) - if int(tokens.numel()) % q != 0: - # Finished requests may still get placeholder tokens; treat as a no-op instead of crashing. - if bool(info_dict.get("finished", False)) or int(tokens.numel()) <= 1: - update = {"model_outputs": None, "sr": None} - return ( - input_ids, - (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), - update, - ) - raise ValueError( - f"Streaming codec token length must be divisible by num_code_groups={q}. " - f"got={int(tokens.numel())} request_id={req_id or ''}" - ) - - frames = int(tokens.numel()) // q - if frames <= 0: - update = {"model_outputs": None, "sr": None} - return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update - - # tokens are codebook-major flattened: [Q, F] flattened row-major. - codes_qf = tokens.reshape(q, frames) - codes_fq = codes_qf.transpose(0, 1).contiguous() # [F, Q] - - ctx_frames = int(info_dict.get("codec_context_frames") or 0) - if ctx_frames < 0 or ctx_frames > frames: - raise ValueError( - f"Invalid codec_context_frames={ctx_frames} for frames={frames} request_id={req_id or ''}" - ) - - tok = self._ensure_speech_tokenizer_loaded() - device = getattr(tok, "device", None) or torch.device("cpu") - codes_chunk = codes_fq.to(device=device) - - wavs, sr = tok.decode({"audio_codes": codes_chunk}) - if not wavs: - raise ValueError("SpeechTokenizer streaming decode produced empty waveform list.") - audio_np = wavs[0].astype(np.float32, copy=False) - - if ctx_frames > 0: - try: - upsample = int(tok.get_decode_upsample_rate()) - except Exception as e: - raise ValueError(f"Failed to get decode upsample rate for streaming trim: {e}") from e - if upsample <= 0: - raise ValueError(f"Invalid decode upsample rate: {upsample}") - cut = ctx_frames * upsample - if cut >= audio_np.shape[0]: - raise ValueError( - f"Streaming decode context trim exceeds decoded length: cut={cut} decoded={audio_np.shape[0]}" - ) - audio_np = audio_np[cut:] - - update: dict[str, Any] = { - "model_outputs": torch.from_numpy(audio_np).to(dtype=torch.float32), - "sr": torch.tensor(int(sr), dtype=torch.int), - } - return input_ids, (input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)), update - - @torch.no_grad() - def forward( - self, - input_ids: torch.Tensor | None = None, - positions: torch.Tensor | None = None, - intermediate_tensors: Any = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: Any, - ) -> OmniOutput: - runtime_info = kwargs.get("runtime_additional_information", [{}]) - if isinstance(runtime_info, list) and runtime_info: - runtime_info = runtime_info[0] - if not isinstance(runtime_info, dict): - runtime_info = {} - - # speech_tokenizer stage: decode in preprocess(); forward returns a dummy tensor for span slicing. - device = input_ids.device if isinstance(input_ids, torch.Tensor) else torch.device("cpu") - n = int(input_ids.shape[0]) if isinstance(input_ids, torch.Tensor) else 1 - if n <= 0: - n = 1 - return torch.zeros((n, 1), dtype=torch.float32, device=device) - - def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: - if isinstance(model_outputs, OmniOutput): - return model_outputs - - # async_chunk speech_tokenizer: emit the latest decoded chunk from runtime_additional_information. - if self.model_stage != "speech_tokenizer" or not self._async_chunk: - return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={}) - - runtime_info = kwargs.get("runtime_additional_information", [{}]) - if isinstance(runtime_info, list) and runtime_info: - runtime_info = runtime_info[0] - if not isinstance(runtime_info, dict): - runtime_info = {} - - mo = runtime_info.get("model_outputs") - sr = runtime_info.get("sr") - if isinstance(mo, torch.Tensor) and isinstance(sr, torch.Tensor): - return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={"model_outputs": mo, "sr": sr}) - return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs={}) - - def compute_logits( - self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None - ) -> torch.Tensor | None: - return None - - def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: - # SpeechTokenizer ignores token embeddings, but vLLM requires embed_input_ids to select the runner type. - if input_ids.numel() == 0: - return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32) - return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # Talker loads weights elsewhere; speech_tokenizer loads `speech_tokenizer/` lazily. - # Return empty set without consuming weights to avoid vLLM re-loading. - return set() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py similarity index 85% rename from vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py rename to vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index 8bfbddc4b38..e8e557100a7 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker_ar.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -13,7 +13,10 @@ import soundfile as sf import torch import torch.nn as nn +import torch.nn.functional as F +from librosa.filters import mel as librosa_mel_fn from transformers import AutoTokenizer +from transformers.activations import ACT2FN from transformers.utils.hub import cached_file from vllm.config import VllmConfig from vllm.distributed import get_pp_group @@ -26,19 +29,254 @@ from vllm_omni.model_executor.models.output_templates import OmniOutput -from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSTalkerConfig -from .modeling_qwen3_tts import ( - Qwen3TTSSpeakerEncoder, - Qwen3TTSTalkerResizeMLP, - mel_spectrogram, -) +from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSSpeakerEncoderConfig, Qwen3TTSTalkerConfig from .qwen3_tts_code_predictor_vllm import Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM from .qwen3_tts_tokenizer import Qwen3TTSTokenizer logger = init_logger(__name__) -class Qwen3TTSTalkerForConditionalGenerationARVLLM(nn.Module): +# --------------------------------------------------------------------------- +# Components ported from the HuggingFace Qwen3-TTS reference implementation. +# Only the classes actually needed by the vLLM AR Talker are kept here. +# --------------------------------------------------------------------------- + + +class Qwen3TTSTalkerResizeMLP(nn.Module): + """Two-layer MLP that maps between hidden sizes with an activation in between.""" + + def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False): + super().__init__() + self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias) + self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias) + self.act_fn = ACT2FN[act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +# ---- Speaker encoder (ECAPA-TDNN) and helpers ---- + + +class TimeDelayNetBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + in_channel = in_channels // scale + hidden_channel = out_channels // scale + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock(in_channel, hidden_channel, kernel_size=kernel_size, dilation=dilation) + for _ in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + return torch.cat(outputs, dim=1) + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv1d(in_channels, se_channels, kernel_size=1, padding="same", padding_mode="reflect") + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d(se_channels, out_channels, kernel_size=1, padding="same", padding_mode="reflect") + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + return hidden_states * hidden_states_mean + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """TDNN-Res2Net-TDNN-SE building block used in ECAPA-TDNN.""" + + def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + return hidden_state + residual + + +class AttentiveStatisticsPooling(nn.Module): + """Attentive statistic pooling layer: returns concatenated mean and std.""" + + def __init__(self, channels, attention_channels=128): + super().__init__() + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d(attention_channels, channels, kernel_size=1, padding="same", padding_mode="reflect") + + @staticmethod + def _length_to_mask(length, max_len=None, dtype=None, device=None): + if max_len is None: + max_len = length.max().long().item() + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + return torch.as_tensor(mask, dtype=dtype, device=device) + + @staticmethod + def _compute_statistics(x, m, dim=2, eps=1e-12): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + total = mask.sum(dim=2, keepdim=True) + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + attention = self.conv(self.tanh(self.tdnn(attention))) + attention = attention.masked_fill(mask == 0, float("-inf")) + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + pooled_stats = torch.cat((mean, std), dim=1) + return pooled_stats.unsqueeze(2) + + +class Qwen3TTSSpeakerEncoder(torch.nn.Module): + """ECAPA-TDNN speaker encoder. + + Reference: "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). + """ + + def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, config.enc_channels[0], config.enc_kernel_sizes[0], config.enc_dilations[0], + ) + ) + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], config.enc_channels[-1], config.enc_kernel_sizes[-1], config.enc_dilations[-1] + ) + self.asp = AttentiveStatisticsPooling(config.enc_channels[-1], attention_channels=config.enc_attention_channels) + self.fc = nn.Conv1d( + config.enc_channels[-1] * 2, config.enc_dim, kernel_size=1, padding="same", padding_mode="reflect", + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + hidden_states = self.asp(hidden_states) + hidden_states = self.fc(hidden_states) + return hidden_states.squeeze(-1) + + +# ---- Audio utilities ---- + + +def _dynamic_range_compression(x, c=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * c) + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int | None = None, + center: bool = False, +) -> torch.Tensor: + """Calculate mel spectrogram of an input signal using librosa mel filterbank and torch STFT.""" + if torch.min(y) < -1.0: + logger.warning("Min value of input waveform signal is %s", torch.min(y)) + if torch.max(y) > 1.0: + logger.warning("Max value of input waveform signal is %s", torch.max(y)) + device = y.device + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis = torch.from_numpy(mel).float().to(device) + hann_window = torch.hann_window(win_size).to(device) + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + spec = torch.stft( + y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, + center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + mel_spec = torch.matmul(mel_basis, spec) + return _dynamic_range_compression(mel_spec) + + +# --------------------------------------------------------------------------- +# Main AR Talker model +# --------------------------------------------------------------------------- + + +class Qwen3TTSTalkerForConditionalGeneration(nn.Module): """vLLM-AR talker: step-wise layer-0 codec decoding. Predicts residual codebooks (1..Q-1) into `audio_codes` and streams text via `tailing_text_hidden`.""" @@ -433,7 +671,7 @@ def _first(x: object, default: object) -> object: instruct_len = 0 if instruct.strip(): - instruct_text = Qwen3TTSTalkerForConditionalGenerationARVLLM._build_instruct_text(instruct) + instruct_text = Qwen3TTSTalkerForConditionalGeneration._build_instruct_text(instruct) instruct_len = len(tokenize_prompt(instruct_text)) # ---- codec prefix portion (matches _build_prompt_embeds) ---- @@ -463,7 +701,7 @@ def _first(x: object, default: object) -> object: prompt_len = instruct_len + role_len + codec_prefix_len # ---- text conditioning portion (matches _build_prompt_embeds) ---- - assistant_text = Qwen3TTSTalkerForConditionalGenerationARVLLM._build_assistant_text(text) + assistant_text = Qwen3TTSTalkerForConditionalGeneration._build_assistant_text(text) assistant_len = len(tokenize_prompt(assistant_text)) if assistant_len < 8: raise ValueError(f"Unexpected assistant prompt length: {assistant_len}") @@ -529,7 +767,7 @@ def _first(x: object, default: object) -> object: "Base in-context non-streaming requires `ref_text` or tokenized `ref_ids`." ) ref_text_ids = tokenize_prompt( - Qwen3TTSTalkerForConditionalGenerationARVLLM._build_ref_text(ref_text) + Qwen3TTSTalkerForConditionalGeneration._build_ref_text(ref_text) ) ref_ids_len = len(ref_text_ids) elif hasattr(ref_ids, "shape"): @@ -866,7 +1104,7 @@ def _generate_icl_prompt( tts_eos_embed: torch.Tensor, non_streaming_mode: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - # Ported from official `generate_icl_prompt` in modeling_qwen3_tts.py + # Ported from official Qwen3TTSForConditionalGeneration.generate_icl_prompt text_embed = self.text_projection(self.text_embedding(torch.cat([ref_id, text_id], dim=-1))) text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) @@ -1272,7 +1510,7 @@ def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]): if self.speaker_encoder is None: self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) loaded |= loader.load_weights(speaker_weights, mapper=self.hf_to_vllm_mapper) - logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGenerationARVLLM", len(loaded)) + logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGeneration", len(loaded)) return loaded # -------------------- GPU-side MTP fast-path -------------------- diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index b794ad73167..4cf1d22349d 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -48,26 +48,16 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), - "Qwen3TTSTalkerForConditionalGenerationARVLLM": ( + "Qwen3TTSTalkerForConditionalGeneration": ( "qwen3_tts", - "qwen3_tts_talker_ar", - "Qwen3TTSTalkerForConditionalGenerationARVLLM", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", ), "Qwen3TTSCode2Wav": ( "qwen3_tts", "qwen3_tts_code2wav", "Qwen3TTSCode2Wav", ), - "Qwen3TTSForConditionalGenerationDisaggregatedVLLM": ( - "qwen3_tts", - "qwen3_tts_disaggregated", - "Qwen3TTSForConditionalGenerationDisaggregatedVLLM", - ), - "Qwen3TTSForConditionalGeneration": ( - "qwen3_tts", - "qwen3_tts_disaggregated", - "Qwen3TTSForConditionalGenerationDisaggregatedVLLM", - ), } diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index 538473cf5e6..0306a04269f 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -7,10 +7,10 @@ stage_args: max_batch_size: 1 engine_args: model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGenerationARVLLM + model_arch: Qwen3TTSTalkerForConditionalGeneration # Force stage-specific registered architecture. hf_overrides: - architectures: [Qwen3TTSTalkerForConditionalGenerationARVLLM] + architectures: [Qwen3TTSTalkerForConditionalGeneration] worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler enforce_eager: true diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml index 93fd1c22e04..8f3a2dfb72c 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml @@ -7,9 +7,9 @@ stage_args: max_batch_size: 1 engine_args: model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGenerationARVLLM + model_arch: Qwen3TTSTalkerForConditionalGeneration hf_overrides: - architectures: [Qwen3TTSTalkerForConditionalGenerationARVLLM] + architectures: [Qwen3TTSTalkerForConditionalGeneration] worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler enforce_eager: false From 14a9ddcce4806a5dcffc1311765db25144cc9dd5 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Tue, 10 Feb 2026 19:58:29 -0800 Subject: [PATCH 14/28] [~] Refactor: Optimize TTS model initialization and enhance configuration management for improved performance Signed-off-by: Sy03 <1370724210@qq.com> --- .../model_executor/models/qwen3_tts/qwen3_tts_code2wav.py | 2 +- .../model_executor/stage_input_processors/qwen3_tts.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index 4e381c20143..339268f34f0 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -150,7 +150,7 @@ def forward( if n_tokens == 0: return empty_ret - # input_ids[0] = codec_context_frames (prepended by adapter). + # input_ids[0] = codec_context_frames (prepended by stage_input_processor). ctx_frames = int(ids[0].item()) ids = ids[1:] n_tokens = ids.numel() diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 2930b17d068..08c89ef3acd 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -61,14 +61,12 @@ def talker2code2wav_async_chunk( if finished and (not appended_frame) and chunk_length == 0: return { "code_predictor_codes": [], - "codec_context_frames": 0, "finished": torch.tensor(True, dtype=torch.bool), } if length <= 0: return { "code_predictor_codes": [], - "codec_context_frames": 0, "finished": torch.tensor(bool(finished), dtype=torch.bool), } @@ -79,8 +77,9 @@ def talker2code2wav_async_chunk( # Pack context + chunk into codebook-major flat codes for adapter. code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() + # Build final prompt_token_ids with ctx_frames header for Qwen3-TTS Code2Wav. + # The model expects input_ids layout: [ctx_frames, *flat_codes]. return { - "code_predictor_codes": code_predictor_codes, - "codec_context_frames": int(ctx_frames), + "code_predictor_codes": [int(ctx_frames)] + code_predictor_codes, "finished": torch.tensor(bool(finished), dtype=torch.bool), } From 4965b8490ff18617c14edae97d165427ace4ecbd Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Wed, 11 Feb 2026 00:25:08 -0800 Subject: [PATCH 15/28] [~] Feat: Implement ref_audio resolution for TTS processing to solve SSRF issue and enhance model registry for Qwen3 TTS - Added functionality to resolve reference audio from URLs or base64 strings in serving_speech.py - Updated model registry to include Qwen3TTSForConditionalGeneration. - Enhanced configuration for Qwen3 TTS async chunk processing in stage_configs. Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 46 ++++++++++- .../models/qwen3_tts/qwen3_tts_talker.py | 36 +++++--- vllm_omni/model_executor/models/registry.py | 5 ++ .../npu/stage_configs/qwen3_tts.yaml | 82 +++++++++++++++++-- 4 files changed, 147 insertions(+), 22 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index f6d642bd49f..1467b16d65e 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,6 +1,12 @@ import asyncio +import base64 +import io from typing import Any +from urllib.parse import urlparse +from urllib.request import urlopen +import numpy as np +import soundfile as sf from fastapi import Request from fastapi.responses import Response from vllm.entrypoints.openai.engine.serving import OpenAIServing @@ -17,6 +23,9 @@ logger = init_logger(__name__) +_REF_AUDIO_TIMEOUT_S = 15 +_REF_AUDIO_MAX_BYTES = 50 * 1024 * 1024 # 50 MB + # TTS Configuration (currently supports Qwen3-TTS) _TTS_MODEL_STAGES: set[str] = {"qwen3_tts"} _TTS_LANGUAGES: set[str] = { @@ -156,6 +165,34 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return None + @staticmethod + async def _resolve_ref_audio(ref_audio_str: str) -> tuple[list[float], int]: + """Resolve ref_audio URL/base64 to (wav_samples, sample_rate).""" + parsed = urlparse(ref_audio_str) + + def _fetch_sync() -> tuple[np.ndarray, int]: + if parsed.scheme in ("http", "https"): + with urlopen(ref_audio_str, timeout=_REF_AUDIO_TIMEOUT_S) as resp: + data = resp.read(_REF_AUDIO_MAX_BYTES + 1) + if len(data) > _REF_AUDIO_MAX_BYTES: + raise ValueError(f"ref_audio URL exceeds {_REF_AUDIO_MAX_BYTES} bytes") + buf = io.BytesIO(data) + elif ref_audio_str.startswith("data:"): + b64 = ref_audio_str + if "," in b64: + b64 = b64.split(",", 1)[1] + buf = io.BytesIO(base64.b64decode(b64)) + else: + raise ValueError("ref_audio must be an http(s) URL or data: base64 URI") + audio, sr = sf.read(buf, dtype="float32", always_2d=False) + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = np.mean(audio, axis=-1) + return np.asarray(audio, dtype=np.float32), int(sr) + + loop = asyncio.get_running_loop() + wav_np, sr = await loop.run_in_executor(None, _fetch_sync) + return wav_np.tolist(), sr + def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: """Build TTS parameters from request. @@ -191,9 +228,7 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any else: params["instruct"] = [""] - # Voice clone parameters (used with Base task) - if request.ref_audio is not None: - params["ref_audio"] = [request.ref_audio] + # Voice clone: ref_audio resolved in create_speech(), not here. if request.ref_text is not None: params["ref_text"] = [request.ref_text] if request.x_vector_only_mode is not None: @@ -253,6 +288,11 @@ async def create_speech( # model.preprocess replaces all embeddings, so placeholder value # is irrelevant -- but length must match to avoid excess padding. tts_params = self._build_tts_params(request) + + if request.ref_audio is not None: + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + tts_params["ref_audio"] = [[wav_list, sr]] + ph_len = self._estimate_prompt_len(tts_params) prompt = { "prompt_token_ids": [1] * ph_len, diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index e8e557100a7..bc9c6fd59da 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Iterable, Mapping from typing import Any from urllib.parse import urlparse -from urllib.request import urlopen import numpy as np import soundfile as sf @@ -198,7 +197,10 @@ def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): self.blocks = nn.ModuleList() self.blocks.append( TimeDelayNetBlock( - config.mel_dim, config.enc_channels[0], config.enc_kernel_sizes[0], config.enc_dilations[0], + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], ) ) for i in range(1, len(config.enc_channels) - 1): @@ -217,7 +219,11 @@ def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): ) self.asp = AttentiveStatisticsPooling(config.enc_channels[-1], attention_channels=config.enc_attention_channels) self.fc = nn.Conv1d( - config.enc_channels[-1] * 2, config.enc_dim, kernel_size=1, padding="same", padding_mode="reflect", + config.enc_channels[-1] * 2, + config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", ) def forward(self, hidden_states): @@ -263,8 +269,16 @@ def mel_spectrogram( padding = (n_fft - hop_size) // 2 y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) spec = torch.stft( - y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, - center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True, + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, ) spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) mel_spec = torch.matmul(mel_basis, spec) @@ -766,9 +780,7 @@ def _first(x: object, default: object) -> object: raise ValueError( "Base in-context non-streaming requires `ref_text` or tokenized `ref_ids`." ) - ref_text_ids = tokenize_prompt( - Qwen3TTSTalkerForConditionalGeneration._build_ref_text(ref_text) - ) + ref_text_ids = tokenize_prompt(Qwen3TTSTalkerForConditionalGeneration._build_ref_text(ref_text)) ref_ids_len = len(ref_text_ids) elif hasattr(ref_ids, "shape"): shape = getattr(ref_ids, "shape", None) @@ -815,14 +827,12 @@ def _decode_base64_to_wav_bytes(self, b64: str) -> bytes: return base64.b64decode(b64) def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]: + """Load audio from local path or base64 data URI (no network I/O).""" import librosa if self._is_url(x): - with urlopen(x) as resp: - audio_bytes = resp.read() - with io.BytesIO(audio_bytes) as f: - audio, sr = sf.read(f, dtype="float32", always_2d=False) - elif self._is_probably_base64(x): + raise ValueError("ref_audio URLs must be resolved by the serving layer before reaching the model worker.") + if self._is_probably_base64(x): wav_bytes = self._decode_base64_to_wav_bytes(x) with io.BytesIO(wav_bytes) as f: audio, sr = sf.read(f, dtype="float32", always_2d=False) diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 4cf1d22349d..2a66632e796 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -48,6 +48,11 @@ "qwen3_omni_code2wav", "Qwen3OmniMoeCode2Wav", ), + "Qwen3TTSForConditionalGeneration": ( + "qwen3_tts", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", + ), "Qwen3TTSTalkerForConditionalGeneration": ( "qwen3_tts", "qwen3_tts_talker", diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml index d408dbab91e..71ca44ace22 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -1,22 +1,92 @@ +async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm runtime: devices: "0" max_batch_size: 1 engine_args: model_stage: qwen3_tts - model_arch: Qwen3TTSForConditionalGeneration + model_arch: Qwen3TTSTalkerForConditionalGeneration + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + hf_overrides: + architectures: [Qwen3TTSCode2Wav] worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true async_scheduling: false enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 + engine_output_type: audio + gpu_memory_utilization: 0.2 distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - + max_num_batched_tokens: 8192 + max_model_len: 32768 + engine_input_source: [0] final_output: true final_output_type: audio + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 \ No newline at end of file From 6fa6afc96a94a25eb9fe49b23fbb883165db741d Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Wed, 11 Feb 2026 02:51:45 -0800 Subject: [PATCH 16/28] [~] Feat: Improve codec mask handling in Qwen3 TTS Signed-off-by: Sy03 <1370724210@qq.com> --- .../openai_api/test_serving_speech.py | 10 ++- .../models/qwen3_tts/qwen3_tts_talker.py | 85 ++++++------------- 2 files changed, 34 insertions(+), 61 deletions(-) diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index 2db98c06869..e55ff6812df 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -310,10 +310,12 @@ def test_is_tts_model(self, speech_server): speech_server.engine_client.stage_list = [mock_stage] assert speech_server._is_tts_model() is True - def test_build_tts_prompt(self, speech_server): - """Test TTS prompt format.""" - prompt = speech_server._build_tts_prompt("Hello") - assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n" + def test_estimate_prompt_len_fallback(self, speech_server): + """Test prompt length estimation falls back to 2048 when model is unavailable.""" + tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]} + result = speech_server._estimate_prompt_len(tts_params) + # Without a real model, it should fall back to 2048. + assert result == 2048 def test_validate_tts_request_basic(self, speech_server): """Test basic validation cases.""" diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index bc9c6fd59da..d3937bfe8e0 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -380,6 +380,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix="code_predictor", ) + # Constant logit mask: allow only codec ids [1, codebook_vocab_size) plus codec EOS. + vocab = int(self.talker_config.vocab_size) + codec_mask = torch.zeros((vocab,), dtype=torch.bool) + lo, hi = 1, min(self._codebook_vocab_size, vocab) + if hi > lo: + codec_mask[lo:hi] = True + if 0 <= self._codec_eos_token_id < vocab: + codec_mask[self._codec_eos_token_id] = True + self.register_buffer("_codec_allowed_mask", codec_mask, persistent=False) + # Tokenizer for prompt building. self._tokenizer = None self._speech_tokenizer: Qwen3TTSTokenizer | None = None @@ -410,17 +420,8 @@ def compute_logits( if logits is None: return None - # Allow only real codec ids (1..codebook_vocab_size-1) plus codec EOS; specials can crash SpeechTokenizer. - # Also, id 0 is padding for the 12Hz decoder. - vocab = int(logits.shape[-1]) - allowed = torch.zeros((vocab,), dtype=torch.bool, device=logits.device) - lo = 1 - hi = min(self._codebook_vocab_size, vocab) - if hi > lo: - allowed[lo:hi] = True - if 0 <= self._codec_eos_token_id < vocab: - allowed[self._codec_eos_token_id] = True - logits = logits.masked_fill(~allowed, float("-inf")) + # Mask out invalid codec ids using the pre-built constant buffer. + logits = logits.masked_fill(~self._codec_allowed_mask, float("-inf")) return logits # -------------------- Omni multimodal output plumbing -------------------- @@ -515,13 +516,15 @@ def preprocess( if span_len > 1: # Prefill (prompt embeddings) prompt_embeds_cpu = info_dict.get("talker_prompt_embeds") - prompt_embeds = None tts_pad_embed_cpu = info_dict.get("tts_pad_embed") tts_pad_embed = None if isinstance(tts_pad_embed_cpu, torch.Tensor) and tts_pad_embed_cpu.numel() > 0: tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) - if prompt_embeds is None: + # First prefill round: prompt_embeds_cpu is not yet populated. + # Subsequent prefill rounds (multi-chunk): prompt_embeds_cpu is a Tensor stored by the first round. + is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2 + if is_first_prefill: full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len = self._build_prompt_embeds( task_type=task_type, info_dict=info_dict ) @@ -549,9 +552,7 @@ def preprocess( prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16) info_update["talker_prefill_offset"] = int(offset + span_len) else: - # Subsequent prefill chunk: slice from our own running offset. - if not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2: - raise RuntimeError("Invalid talker_prompt_embeds in additional_information.") + # Subsequent prefill chunk: slice from stored embeddings at running offset. if tts_pad_embed is None: raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.") offset = int(info_dict.get("talker_prefill_offset", 0) or 0) @@ -1550,55 +1551,25 @@ def talker_mtp( audio_codes = input_ids.reshape(bsz, 1) return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes - # Subtalker sampling defaults (match official defaults). - do_sample = True - top_k = 50 - top_p = 1.0 - temperature = 0.9 - - def _sample_next(logits: torch.Tensor) -> torch.Tensor: - # logits: [B,V] - if temperature and float(temperature) > 0: - logits = logits / float(temperature) - if top_k and int(top_k) > 0 and int(top_k) < logits.shape[-1]: - v, _ = torch.topk(logits, int(top_k), dim=-1) - min_keep = v[:, -1].unsqueeze(-1) - logits = torch.where(logits < min_keep, torch.tensor(float("-inf"), device=logits.device), logits) - if top_p is not None and 0.0 < float(top_p) < 1.0: - sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) - probs = torch.softmax(sorted_logits, dim=-1) - cum = torch.cumsum(probs, dim=-1) - remove = cum > float(top_p) - remove[:, 0] = False - sorted_logits = torch.where(remove, torch.tensor(float("-inf"), device=logits.device), sorted_logits) - logits = torch.empty_like(logits).scatter(-1, sorted_idx, sorted_logits) - if not do_sample: - return torch.argmax(logits, dim=-1, keepdim=True) - probs = torch.softmax(logits, dim=-1) - return torch.multinomial(probs, num_samples=1) - - predictor_inputs = torch.cat([past_hidden, last_id_hidden], dim=1) # [B,2,H] - self.code_predictor.reset_cache() - tok = _sample_next(self.code_predictor.prefill_logits(predictor_inputs)) - residual_ids = [tok] - past_seq_len = 2 - for step in range(1, max_steps): - logits = self.code_predictor.decode_logits(tok, generation_step=step, past_seq_len=past_seq_len) - tok = _sample_next(logits) - residual_ids.append(tok) - past_seq_len += 1 - - residual_ids_t = torch.cat(residual_ids, dim=1).to(dtype=torch.long, device=dev) # [B, Q-1] - audio_codes = torch.cat([input_ids, residual_ids_t], dim=1) # [B,Q] + # Single forward call: predicts all residual codes (1..Q-1) autoregressively. + audio_codes = self.code_predictor( + layer0_code=input_ids.reshape(bsz, 1), + layer0_embed=last_id_hidden, + last_talker_hidden=past_hidden, + do_sample=True, + temperature=0.9, + top_k=50, + top_p=1.0, + ) # [B, Q] # Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes. - # vLLM still uses EOS for stopping. layer0 = audio_codes[:, :1] invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size)) if invalid0.any(): audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) # Sum embeddings of all code groups, then add the current text step. + residual_ids_t = audio_codes[:, 1:] embeds: list[torch.Tensor] = [last_id_hidden] for i in range(max_steps): embeds.append(self.code_predictor.get_input_embeddings()[i](residual_ids_t[:, i : i + 1])) From 44693a05d68a40cf8ca816f76935e073ead0c1a1 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 12 Feb 2026 05:33:14 -0800 Subject: [PATCH 17/28] [~] Refactor: Clean up additional information handling in OmniGenerationScheduler and update compatibility comments in OmniConnectors - Removed unused per-request additional information handling in OmniGenerationScheduler. - Updated comments in adapter.py for clarity on compatibility with entry points. - Cleaned up additional information assignment in chunk_transfer_adapter.py. Signed-off-by: Sy03 <1370724210@qq.com> --- .../core/sched/omni_generation_scheduler.py | 9 --- .../distributed/omni_connectors/adapter.py | 29 +++++++-- .../chunk_transfer_adapter.py | 1 - .../qwen3_tts_code_predictor_vllm.py | 63 +++++++++++++++++++ 4 files changed, 87 insertions(+), 15 deletions(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index ba1b1f9d0c0..9faf74df52c 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -213,13 +213,6 @@ def schedule(self) -> SchedulerOutput: req_to_new_blocks=req_to_new_blocks, ) - # async_chunk: forward per-step additional_information updates for cached requests. - per_req_additional_info: dict[str, object] = {} - for req in scheduled_running_reqs: - req_info = getattr(req, "additional_information", None) - if isinstance(req_info, dict) and req_info: - per_req_additional_info[req.request_id] = req_info - cached_reqs_data = OmniCachedRequestData( req_ids=cached_reqs_data.req_ids, resumed_req_ids=cached_reqs_data.resumed_req_ids, @@ -230,8 +223,6 @@ def schedule(self) -> SchedulerOutput: num_output_tokens=cached_reqs_data.num_output_tokens, prompt_token_ids=cached_prompt_token_ids, ) - if per_req_additional_info: - cached_reqs_data.additional_information = per_req_additional_info total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) scheduler_output = SchedulerOutput( diff --git a/vllm_omni/distributed/omni_connectors/adapter.py b/vllm_omni/distributed/omni_connectors/adapter.py index 3bc05eeb3bc..1d1dd6f0f27 100644 --- a/vllm_omni/distributed/omni_connectors/adapter.py +++ b/vllm_omni/distributed/omni_connectors/adapter.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Temporary compatibility shim for vllm_omni.entrypoints.omni_stage.py / omni_llm.py. +# temporary for compatibility with vllm_omni.entrypoints.omni_stage.py +# and vllm_omni.entrypoints.omni_llm.py import time from collections.abc import Callable @@ -25,7 +26,12 @@ def try_send_via_connector( next_stage_queue_submit_fn: Callable[[dict[str, Any]], None], metrics: OrchestratorAggregator, ) -> bool: - """Send payload via OmniConnector and enqueue notification/metrics; return True on success.""" + """ + Attempts to send data via OmniConnector. + Returns True if successful, False otherwise. + Encapsulates the logic of preparing payload, sending via connector, + sending notification, and recording metrics. + """ try: t0 = time.time() @@ -90,7 +96,10 @@ def try_recv_via_connector( connectors: dict[Any, Any], stage_id: int, ) -> tuple[Any, dict[str, Any] | None]: - """Resolve engine_inputs from connector/IPC payload; returns (engine_inputs, rx_metrics) or (None, None).""" + """ + Attempts to resolve input data from either connector or IPC. + Returns (engine_inputs, rx_metrics) or (None, None) if failed/skipped. + """ rid = task["request_id"] if task.get("from_connector"): @@ -145,7 +154,10 @@ def try_recv_via_connector( ) return None, None else: - # Queue path (e.g. Stage-0 seed): task should carry direct inputs, but still decode SHM/IPC if present. + # Data comes from queue as usual (e.g. seed request for Stage-0) + # Since fallback logic is deprecated, we assume this is a direct inputs payload. + # We still need to decode it if it used SHM (via legacy stage_utils logic, or new shm_connector format) + # For Stage-0 specifically, 'engine_inputs' is often directly in the task dict. # Try to use the new stage_utils which uses OmniSerializer from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc_with_metrics @@ -162,7 +174,14 @@ def try_recv_via_connector( def compute_talker_prompt_ids_length(prompt_ids: list[int]) -> int: - """Compute talker prompt length for chat-style prompt ids (system/user/assistant).""" + """Compute the length of the talker prompt ids. + + Args: + prompt_ids: The prompt ids tensor. + + Returns: + The length of the talker prompt ids. + """ im_start_token_id = 151644 system_token_id = 8948 user_token_id = 872 diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 2cb4e20e59d..c7d72a6ba89 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -182,7 +182,6 @@ def _poll_single_request(self, req_id: str): req.prompt_token_ids = payload_data.get("code_predictor_codes", []) req.num_computed_tokens = 0 - req.additional_information = payload_data # Mark as finished for consumption with self.lock: diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 68d4baf51c5..b8b8f6bed49 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -430,3 +430,66 @@ def decode_logits(self, input_ids: torch.Tensor, *, generation_step: int, past_s logits = self.lm_head[generation_step](out) return logits + + @torch.inference_mode() + def forward( + self, + layer0_code: torch.Tensor, + layer0_embed: torch.Tensor, + last_talker_hidden: torch.Tensor, + do_sample: bool = True, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 1.0, + ) -> torch.Tensor: + """Full autoregressive prediction of residual codebooks 1..Q-1. + + Args: + layer0_code: [B, 1] first-layer codec token ids. + layer0_embed: [B, 1, H] embedding of layer0_code. + last_talker_hidden: [B, 1, H] hidden state from the talker. + do_sample: whether to sample or take argmax. + temperature: sampling temperature. + top_k: top-k filtering. + top_p: top-p (nucleus) filtering. + + Returns: + audio_codes: [B, Q] all codebook tokens (layer0 + residuals). + """ + bsz = int(layer0_code.shape[0]) + num_groups = int(self.config.num_code_groups) + max_steps = num_groups - 1 + + # Reset KV cache for a fresh sequence. + self.reset_cache() + + # Prefill: feed [last_talker_hidden, layer0_embed] → logits for group 1. + prefill_input = torch.cat([last_talker_hidden, layer0_embed], dim=1) # [B, 2, H] + logits = self.prefill_logits(prefill_input) # [B, vocab] + + all_codes = [layer0_code.reshape(bsz, 1)] + past_seq_len = 2 + + for step in range(1, num_groups): + # Sample or argmax from logits. + if do_sample and temperature > 0: + scaled = logits / temperature + if top_k > 0: + topk_vals, _ = scaled.topk(top_k, dim=-1) + scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf")) + probs = torch.softmax(scaled, dim=-1) + next_ids = torch.multinomial(probs, num_samples=1) # [B, 1] + else: + next_ids = logits.argmax(dim=-1, keepdim=True) # [B, 1] + all_codes.append(next_ids) + + # If not the last step, decode one more token. + if step < max_steps: + logits = self.decode_logits( + next_ids.reshape(bsz), + generation_step=step, + past_seq_len=past_seq_len, + ) + past_seq_len += 1 + + return torch.cat(all_codes, dim=1) # [B, Q] From 03f66bd9c5649a9c0e4959618a1ce1745c99609e Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 12 Feb 2026 05:43:54 -0800 Subject: [PATCH 18/28] [~] Style: Format error fixed Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml index 71ca44ace22..fbfbf10a49e 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -89,4 +89,4 @@ runtime: edges: - from: 0 to: 1 - window_size: -1 \ No newline at end of file + window_size: -1 From 3316f49e4ff59d3db119d888545f0ed91dd4f653 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 13 Feb 2026 14:03:17 -0800 Subject: [PATCH 19/28] [+] Feat: Enhance SSRF protection and improve TTS processing for cuda_graph Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 23 +++++++++++++++++++ .../models/qwen3_tts/qwen3_tts_talker.py | 19 +++++++++++---- .../stage_configs/qwen3_tts.yaml | 2 +- .../stage_input_processors/qwen3_tts.py | 23 ++++++------------- vllm_omni/worker/gpu_model_runner.py | 19 ++++++++++++++- 5 files changed, 63 insertions(+), 23 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 1467b16d65e..201be69dae0 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,6 +1,8 @@ import asyncio import base64 import io +import ipaddress +import socket from typing import Any from urllib.parse import urlparse from urllib.request import urlopen @@ -25,6 +27,16 @@ _REF_AUDIO_TIMEOUT_S = 15 _REF_AUDIO_MAX_BYTES = 50 * 1024 * 1024 # 50 MB +_REF_AUDIO_BLOCKED_NETWORKS = [ + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network("fe80::/10"), +] # TTS Configuration (currently supports Qwen3-TTS) _TTS_MODEL_STAGES: set[str] = {"qwen3_tts"} @@ -170,8 +182,19 @@ async def _resolve_ref_audio(ref_audio_str: str) -> tuple[list[float], int]: """Resolve ref_audio URL/base64 to (wav_samples, sample_rate).""" parsed = urlparse(ref_audio_str) + def _check_ssrf(url: str) -> None: + host = urlparse(url).hostname + if not host: + raise ValueError("ref_audio URL must include a hostname") + for info in socket.getaddrinfo(host, None): + ip_str = str(info[4][0]).split("%", 1)[0] + addr = ipaddress.ip_address(ip_str) + if any(addr in net for net in _REF_AUDIO_BLOCKED_NETWORKS): + raise ValueError(f"ref_audio URL resolves to blocked address: {addr}") + def _fetch_sync() -> tuple[np.ndarray, int]: if parsed.scheme in ("http", "https"): + _check_ssrf(ref_audio_str) with urlopen(ref_audio_str, timeout=_REF_AUDIO_TIMEOUT_S) as resp: data = resp.read(_REF_AUDIO_MAX_BYTES + 1) if len(data) > _REF_AUDIO_MAX_BYTES: diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index d3937bfe8e0..a39eded3aa6 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -327,6 +327,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self._codec_eos_token_id = int(getattr(self.talker_config, "codec_eos_token_id", -1)) + self._eos_logit_bias: float = 0.0 + self.have_multimodal_outputs = True self.has_preprocess = True self.has_postprocess = True @@ -366,9 +368,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Keep it optional to avoid strict weight-loading failures. self.speaker_encoder: Qwen3TTSSpeakerEncoder | None = None - # Residual code predictor (1..Q-1) uses a dedicated vLLM config to build its own KV cache. - # This avoids polluting the main engine's static forward context. + # Code predictor uses an isolated vLLM config so its KV cache doesn't + # pollute the main engine's static_forward_context (shallow-copy shares + # the dict by reference — must assign a fresh one). predictor_compilation = dataclasses.replace(vllm_config.compilation_config) + predictor_compilation.static_forward_context = {} self._code_predictor_vllm_config = dataclasses.replace(vllm_config, compilation_config=predictor_compilation) from vllm.config.vllm import set_current_vllm_config as _set_cfg @@ -422,6 +426,12 @@ def compute_logits( # Mask out invalid codec ids using the pre-built constant buffer. logits = logits.masked_fill(~self._codec_allowed_mask, float("-inf")) + + if self._eos_logit_bias != 0.0: + eos_id = self._codec_eos_token_id + if 0 <= eos_id < logits.shape[-1]: + logits[:, eos_id] = logits[:, eos_id] + self._eos_logit_bias + return logits # -------------------- Omni multimodal output plumbing -------------------- @@ -1551,7 +1561,7 @@ def talker_mtp( audio_codes = input_ids.reshape(bsz, 1) return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes - # Single forward call: predicts all residual codes (1..Q-1) autoregressively. + # Predict residual codes (1..Q-1) with HF reference sampling params. audio_codes = self.code_predictor( layer0_code=input_ids.reshape(bsz, 1), layer0_embed=last_id_hidden, @@ -1565,8 +1575,7 @@ def talker_mtp( # Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes. layer0 = audio_codes[:, :1] invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size)) - if invalid0.any(): - audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) + audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) # Sum embeddings of all code groups, then add the current text step. residual_ids_t = audio_codes[:, 1:] diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index 0306a04269f..1f29f0796ed 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -13,7 +13,7 @@ stage_args: architectures: [Qwen3TTSTalkerForConditionalGeneration] worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: true + enforce_eager: false trust_remote_code: true async_scheduling: false enable_prefix_caching: false diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 08c89ef3acd..8599ea2e3e8 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -20,7 +20,7 @@ def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: def talker2code2wav_async_chunk( - connector: Any, + transfer_manager: Any, pooling_output: dict[str, Any], request: Any, ) -> dict[str, Any] | None: @@ -29,6 +29,7 @@ def talker2code2wav_async_chunk( request_id = request.external_req_id + connector = getattr(transfer_manager, "connector", None) raw_cfg = getattr(connector, "config", {}) or {} cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} chunk_size = int(cfg.get("codec_chunk_frames", 25)) @@ -41,16 +42,12 @@ def talker2code2wav_async_chunk( finished = bool(request.is_finished()) - appended_frame = False - if not finished: - frame = _extract_last_frame(pooling_output) - if frame is None: - return None + frame = _extract_last_frame(pooling_output) + if frame is not None: codec_codes = frame.cpu().tolist() - connector.code_prompt_token_ids[request_id].append(codec_codes) - appended_frame = True + transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) - length = len(connector.code_prompt_token_ids[request_id]) + length = len(transfer_manager.code_prompt_token_ids[request_id]) chunk_length = length % chunk_size if chunk_length != 0 and not finished: @@ -58,12 +55,6 @@ def talker2code2wav_async_chunk( context_length = chunk_length if chunk_length != 0 else chunk_size - if finished and (not appended_frame) and chunk_length == 0: - return { - "code_predictor_codes": [], - "finished": torch.tensor(True, dtype=torch.bool), - } - if length <= 0: return { "code_predictor_codes": [], @@ -72,7 +63,7 @@ def talker2code2wav_async_chunk( end_index = min(length, left_context_size + context_length) ctx_frames = max(0, int(end_index - context_length)) - window_frames = connector.code_prompt_token_ids[request_id][-end_index:] + window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:] # Pack context + chunk into codebook-major flat codes for adapter. code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index f1513e95031..a01ad113aeb 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -78,7 +78,11 @@ def load_model(self, *args, **kwargs) -> None: self.talker_mtp = talker_mtp # type: ignore[assignment] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs(): + # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that + # have a separate .talker sub-module. TTS models' code predictor + # has internal AR loops / torch.multinomial — not graph-safe. + has_separate_talker = getattr(self.model, "talker", None) is not None + if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. hidden_size = int( @@ -648,6 +652,11 @@ def _dummy_run( input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = self._init_model_kwargs() + elif getattr(getattr(self, "model", None), "has_preprocess", False): + # Capture CUDA graph with inputs_embeds path so replay reads + # from the same buffer that _preprocess writes into. + input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -990,6 +999,11 @@ def _preprocess( inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] model_kwargs = self._init_model_kwargs() input_ids = self.input_ids.gpu[:num_input_tokens] + elif getattr(self.model, "has_preprocess", False): + # Use pre-allocated buffer for CUDA graph compatibility. + input_ids = self.input_ids.gpu[:num_input_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs() else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the @@ -1104,6 +1118,9 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te max_num_scheduled_tokens=1, use_cascade_attn=False, ) + # Force eager for unwrapped code predictors (AR loops / multinomial). + if not isinstance(self.talker_mtp, CUDAGraphWrapper): + _cudagraph_mode = CUDAGraphMode.NONE num_tokens_padded = batch_desc.num_tokens req_input_ids = self.talker_mtp_input_ids.gpu[:num_tokens_padded] req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded] From 1cad15db4d37efbf05b5aad57fcaa66e2834db1b Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 13 Feb 2026 23:34:36 -0800 Subject: [PATCH 20/28] [~] Style: Fix pre-commit issue Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/core/sched/omni_generation_scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 9faf74df52c..bef340d9684 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -349,7 +349,8 @@ def update_from_output( continue request = self.requests.get(req_id) if request is None or request.is_finished(): - # Request may already be finished (e.g., aborted during execution / pipeline parallelism / async scheduling). + # Request may already be finished (e.g., aborted during + # execution / pipeline parallelism / async scheduling). continue req_index = model_runner_output.req_id_to_index[req_id] From b6c1928c4f198a04debbfdd3ecbf4a406cc4a65a Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sun, 15 Feb 2026 07:17:08 -0800 Subject: [PATCH 21/28] [+] Fix: Enhance scheduling logic and chunk processing in OmniGenerationScheduler and ChunkTransferAdapter to solve v0.16.0 errors Signed-off-by: Sy03 <1370724210@qq.com> --- .../core/sched/omni_generation_scheduler.py | 20 ++++++++++--------- .../chunk_transfer_adapter.py | 9 ++++++++- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index bef340d9684..10d52dfd4e1 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -176,11 +176,13 @@ def schedule(self) -> SchedulerOutput: # If fast path scheduled none, fall back to the original scheduling if not num_scheduled_tokens: - res = super().schedule() if self.chunk_transfer_adapter: + # Don't fall back: base scheduler doesn't handle async_chunk + # requests with empty prompt_token_ids. self.chunk_transfer_adapter.restore_queues(self.waiting, self.running) - self.chunk_transfer_adapter.postprocess_scheduler_output(res) - return res + else: + res = super().schedule() + return res # Compute common prefix blocks (aligned with v1) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) @@ -392,6 +394,10 @@ def update_from_output( # Diffusion request: completes in one step; mark finished and free resources if request.status == RequestStatus.FINISHED_STOPPED or ( self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens + ) or ( + self.chunk_transfer_adapter is not None + and request.request_id in self.chunk_transfer_adapter.finished_requests + and request.num_computed_tokens >= len(request.prompt_token_ids) ): request.status = RequestStatus.FINISHED_STOPPED # Optional: set a stop_reason for front-end clarity @@ -405,15 +411,11 @@ def update_from_output( finished = self._handle_stopped_request(request) if finished: kv_transfer_params = self._free_request(request) - if status_before_stop == RequestStatus.RUNNING: - stopped_running_reqs.add(request) - elif status_before_stop == RequestStatus.WAITING_FOR_CHUNK: - # In async chunk mode, request may be in either queue. - # Remove from both to avoid stale queue entries. + if status_before_stop == RequestStatus.WAITING_FOR_CHUNK: stopped_running_reqs.add(request) stopped_preempted_reqs.add(request) else: - stopped_preempted_reqs.add(request) + stopped_running_reqs.add(request) # Extract sample logprobs if needed. if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs: diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index c7d72a6ba89..a6afb97bd4c 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -180,9 +180,16 @@ def _poll_single_request(self, req_id: str): if payload_data.get("finished"): self.finished_requests.add(req_id) - req.prompt_token_ids = payload_data.get("code_predictor_codes", []) + # req.prompt_token_ids = payload_data.get("code_predictor_codes", []) + # req.num_computed_tokens = 0 + new_ids = payload_data.get("code_predictor_codes", []) + req.prompt_token_ids = new_ids req.num_computed_tokens = 0 + # Empty chunk with more data expected: keep polling. + if not new_ids and not payload_data.get("finished"): + return + # Mark as finished for consumption with self.lock: self._finished_load_reqs.add(req_id) From 8fbf458d856547e3b4c387c85e4425a2534f13fe Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sun, 15 Feb 2026 07:24:30 -0800 Subject: [PATCH 22/28] [~] Style: Fix ruff format error Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/core/sched/omni_generation_scheduler.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 10d52dfd4e1..ef1c4c7c901 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -392,12 +392,14 @@ def update_from_output( routed_experts = None # Diffusion request: completes in one step; mark finished and free resources - if request.status == RequestStatus.FINISHED_STOPPED or ( - self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens - ) or ( - self.chunk_transfer_adapter is not None - and request.request_id in self.chunk_transfer_adapter.finished_requests - and request.num_computed_tokens >= len(request.prompt_token_ids) + if ( + request.status == RequestStatus.FINISHED_STOPPED + or (self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens) + or ( + self.chunk_transfer_adapter is not None + and request.request_id in self.chunk_transfer_adapter.finished_requests + and request.num_computed_tokens >= len(request.prompt_token_ids) + ) ): request.status = RequestStatus.FINISHED_STOPPED # Optional: set a stop_reason for front-end clarity From 5530aaa325cd93f2a64cfe6346a49be2d8a50277 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Mon, 16 Feb 2026 00:58:50 -0800 Subject: [PATCH 23/28] [+] CI: Add prompt length estimation for Talker stage and refactor input handling in offline test Signed-off-by: Sy03 <1370724210@qq.com> --- .../offline_inference/qwen3_tts/end2end.py | 204 ++++++++++-------- 1 file changed, 120 insertions(+), 84 deletions(-) diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 93aeba3ca5f..12e5e193542 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -4,18 +4,21 @@ tasks, then runs Omni generation and saves output wav files. """ +import logging import os -from typing import NamedTuple +from typing import Any, NamedTuple import soundfile as sf +import torch os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -from vllm import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm_omni import Omni +logger = logging.getLogger(__name__) + class QueryResult(NamedTuple): """Container for a prepared Omni request.""" @@ -24,6 +27,44 @@ class QueryResult(NamedTuple): model_name: str +def _estimate_prompt_len( + additional_information: dict[str, Any], + model_name: str, + _cache: dict[str, Any] = {}, +) -> int: + """Estimate prompt_token_ids placeholder length for the Talker stage. + + The AR Talker replaces all input embeddings via ``preprocess``, so the + placeholder values are irrelevant but the **length** must match the + embeddings that ``preprocess`` will produce. + """ + try: + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + Qwen3TTSTalkerForConditionalGeneration, + ) + + if model_name not in _cache: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") + cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) + _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) + + tok, tcfg = _cache[model_name] + task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( + additional_information=additional_information, + task_type=task_type, + tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], + codec_language_id=getattr(tcfg, "codec_language_id", None), + spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), + ) + except Exception as exc: + logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) + return 2048 + + def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: """Build CustomVoice sample inputs. @@ -34,6 +75,7 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: QueryResult with Omni inputs and the CustomVoice model path. """ task_type = "CustomVoice" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" if use_batch_sample: texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."] instructs = ["", "Very happy."] @@ -41,18 +83,18 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: speakers = ["Vivian", "Ryan"] inputs = [] for text, instruct, language, speaker in zip(texts, instructs, languages, speakers): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "instruct": [instruct], + "language": [language], + "speaker": [speaker], + "max_new_tokens": [2048], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "instruct": [instruct], - "language": [language], - "speaker": [speaker], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: @@ -60,21 +102,21 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: language = "Chinese" speaker = "Vivian" instruct = "用特别愤怒的语气说" - prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "speaker": [speaker], + "instruct": [instruct], + "max_new_tokens": [2048], + } inputs = { - "prompt": prompts, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "speaker": [speaker], - "instruct": [instruct], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + model_name=model_name, ) @@ -88,6 +130,7 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: QueryResult with Omni inputs and the VoiceDesign model path. """ task_type = "VoiceDesign" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" if use_batch_sample: texts = [ "哥哥,你回来啦,人家等了你好久好久了,要抱抱!", @@ -100,39 +143,39 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: languages = ["Chinese", "English"] inputs = [] for text, instruct, language in zip(texts, instructs, languages): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "instruct": [instruct], - "max_new_tokens": [2048], - "non_streaming_mode": [True], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!" instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。" language = "Chinese" - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + } inputs = { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "instruct": [instruct], - "max_new_tokens": [2048], - "non_streaming_mode": [True], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", + model_name=model_name, ) @@ -147,6 +190,7 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que QueryResult with Omni inputs and the Base model path. """ task_type = "Base" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" ref_audio_single = ref_audio_path_1 ref_text_single = ( @@ -163,38 +207,38 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que syn_lang_batch = ["Chinese", "English"] inputs = [] for text, language in zip(syn_text_batch, syn_lang_batch): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [text], + "language": [language], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "ref_audio": [ref_audio_single], - "ref_text": [ref_text_single], - "text": [text], - "language": [language], - "x_vector_only_mode": [x_vector_only_mode], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: - prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [syn_text_single], + "language": [syn_lang_single], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + } inputs = { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "ref_audio": [ref_audio_single], - "ref_text": [ref_text_single], - "text": [syn_text_single], - "language": [syn_lang_single], - "x_vector_only_mode": [x_vector_only_mode], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base", + model_name=model_name, ) @@ -223,30 +267,22 @@ def main(args): stage_init_timeout=args.stage_init_timeout, ) - sampling_params = SamplingParams( - temperature=0.9, - top_p=1.0, - top_k=50, - max_tokens=2048, - seed=42, - detokenize=False, - repetition_penalty=1.05, - ) - - sampling_params_list = [ - sampling_params, - ] - output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) - omni_generator = omni.generate(query_result.inputs, sampling_params_list) + omni_generator = omni.generate(query_result.inputs, sampling_params_list=None) for stage_outputs in omni_generator: for output in stage_outputs.request_output: request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output["audio"] + audio_data = output.outputs[0].multimodal_output["audio"] + # async_chunk mode returns a list of chunks; concatenate them. + if isinstance(audio_data, list): + audio_tensor = torch.cat(audio_data, dim=-1) + else: + audio_tensor = audio_data output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - audio_samplerate = output.outputs[0].multimodal_output["sr"].item() + sr_val = output.outputs[0].multimodal_output["sr"] + audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) # Convert to numpy array and ensure correct format audio_numpy = audio_tensor.float().detach().cpu().numpy() From a108e624ca0949df2ed2c5a6f5670ecc31847780 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Thu, 19 Feb 2026 12:16:00 -0800 Subject: [PATCH 24/28] [~] CI: Fix unit-test for "talker_mtp_output_key" caller update for qwen3_tts Signed-off-by: Sy03 <1370724210@qq.com> --- tests/worker/test_omni_gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index c7836123a64..9b5052b464a 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -69,6 +69,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4): runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) runner.talker_mtp = DummyTalkerMTP() + runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes") runner.vllm_config = object() # Provide a minimal implementation that returns the expected 4-tuple. From 60b66036239a05a7b3908b50a9c4698a775fe60a Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 20 Feb 2026 03:55:43 -0800 Subject: [PATCH 25/28] [~] CI: Increase timeout for Omni Model Test step from 15 to 20 minutes Signed-off-by: Sy03 <1370724210@qq.com> --- .buildkite/pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d0566e18b00..c50bbcfb7d3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -176,7 +176,7 @@ steps: # type: DirectoryOrCreate - label: "Omni Model Test" - timeout_in_minutes: 15 + timeout_in_minutes: 20 depends_on: image-build commands: - export VLLM_LOGGING_LEVEL=DEBUG From b6e69729545bdbf4e51c942b16761bebaca31a74 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 20 Feb 2026 07:55:37 -0800 Subject: [PATCH 26/28] [~] Fix: Copy input batch request IDs and indices in NPU and GPU model runners to prevent mutation Signed-off-by: Sy03 <1370724210@qq.com> --- .../platforms/npu/worker/npu_generation_model_runner.py | 6 ++++-- vllm_omni/worker/gpu_generation_model_runner.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py index e8559bb463c..abc3a92e691 100644 --- a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -238,9 +238,11 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 0747db3ea57..da3112b61f3 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -380,9 +380,11 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, From f9d656f6effbb1191ebe95ca32a2a5e4704faf15 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 20 Feb 2026 10:07:51 -0800 Subject: [PATCH 27/28] [+] Style: Add comments to clarify copying operation Signed-off-by: Sy03 <1370724210@qq.com> --- vllm_omni/platforms/npu/worker/npu_generation_model_runner.py | 1 + vllm_omni/worker/gpu_generation_model_runner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py index abc3a92e691..d263fb0d386 100644 --- a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -238,6 +238,7 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + # [Omni] Copy req_id mappings to avoid async scheduling mutation. req_ids_output_copy = self.input_batch.req_ids.copy() req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index da3112b61f3..aa75d201ccb 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -380,6 +380,7 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + # [Omni] Copy req_id mappings to avoid async scheduling mutation. req_ids_output_copy = self.input_batch.req_ids.copy() req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( From 33fbc3b7d196fcba21d0f25d7c28be6c0a2e53e7 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Fri, 20 Feb 2026 11:33:58 -0800 Subject: [PATCH 28/28] [-] Build: Remove deprecated configuration file for Qwen3 TTS talker speech tokenizer async chunk. Signed-off-by: Sy03 <1370724210@qq.com> --- ...s_talker_speech_tokenizer_async_chunk.yaml | 92 ------------------- 1 file changed, 92 deletions(-) delete mode 100644 vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml deleted file mode 100644 index 8f3a2dfb72c..00000000000 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_talker_speech_tokenizer_async_chunk.yaml +++ /dev/null @@ -1,92 +0,0 @@ -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm - runtime: - devices: "0" - max_batch_size: 1 - engine_args: - model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGeneration - hf_overrides: - architectures: [Qwen3TTSTalkerForConditionalGeneration] - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: latent - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 512 - max_model_len: 4096 - # Stage-0 emits flattened codec codes via async_chunk connector. - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 1 - stage_type: llm - runtime: - devices: "0" - max_batch_size: 1 - engine_args: - model_stage: code2wav - model_arch: Qwen3TTSCode2Wav - hf_overrides: - architectures: [Qwen3TTSCode2Wav] - # Stage-1 has no main checkpoint weights (SpeechTokenizer is loaded from - # `speech_tokenizer/` lazily). Avoid probing for model.safetensors. - load_format: dummy - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: audio - gpu_memory_utilization: 0.2 - distributed_executor_backend: "mp" - # Must be >= num_code_groups * (codec_left_context_frames + codec_chunk_frames). - max_num_batched_tokens: 8192 - # async_chunk appends windows per step; max_model_len must cover accumulated stream. - max_model_len: 32768 - engine_input_source: [0] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: true - repetition_penalty: 1.0 - -runtime: - enabled: true - defaults: - window_size: -1 - max_inflight: 1 - - connectors: - connector_of_shared_memory: - name: SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 - # Qwen3-TTS codec streaming (frame-aligned tokenized transport). - codec_streaming: true - codec_chunk_frames: 25 - codec_left_context_frames: 25 - - edges: - - from: 0 - to: 1 - window_size: -1