diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/Dockerfile b/examples/online_serving/text_to_speech/qwen3_tts_nv/Dockerfile new file mode 100644 index 00000000000..96189dd3bc2 --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/Dockerfile @@ -0,0 +1,26 @@ +FROM nvcr.io/nvidia/tritonserver:26.02-py3 + +# 1. System dependencies: git for the source install, SoX for audio processing +RUN apt-get update && \ + apt-get install -y git sox libsox-fmt-all + +# 2. Install upstream vLLM first so it pulls in torch and the core runtime stack +RUN pip install --no-cache-dir "vllm==0.21.0" + +# 3. Install vLLM-Omni from the fork/branch on top of upstream vLLM +RUN git clone --single-branch --branch vklimkov/qwen3tts_nv \ + https://github.com/vklimkov-nvidia/vllm-omni.git /tmp/vllm-omni && \ + cd /tmp/vllm-omni && \ + VLLM_OMNI_TARGET_DEVICE=cuda pip install --no-cache-dir . && \ + cd / && rm -rf /tmp/vllm-omni + +# 4. Extra Python requirements +RUN pip install --no-cache-dir \ + onnxscript \ + librosa \ + sox \ + onnx-graphsurgeon \ + "tritonclient[grpc]" +RUN pip install --no-cache-dir --force-reinstall --no-deps "numpy==2.3.5" + +WORKDIR /workspace diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/README.md b/examples/online_serving/text_to_speech/qwen3_tts_nv/README.md new file mode 100644 index 00000000000..7685a5ff89d --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/README.md @@ -0,0 +1,145 @@ +# Qwen3-TTS Triton serving example + +End-to-end recipe for serving [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) +with NVIDIA Triton Inference Server. + +## Motivation + +Qwen3-TTS has two stages with very different runtime characteristics: + +- **Talker** — autoregressive Transformer that emits discrete audio codes. + Token-by-token decoding benefits from continuous batching and paged + KV-cache, so we serve it with **vLLM-Omni** as a Python Triton backend. 
+- **Codec decoder** — non-autoregressive convolutional model that turns a + chunk of codes into a waveform. We export it to **TensorRT** with a + dynamic batch profile and serve it via Triton's native `tensorrt_plan` + backend with dynamic batching enabled. + +The talker streams chunks of codes into the codec via Triton's +[BLS](https://github.com/triton-inference-server/python_backend#business-logic-scripting) +API, and waveform chunks are streamed back to the client over a decoupled +gRPC stream. + +## 1. One-time setup + +Steps 1.1 and 1.2 only need to be done once per machine. + +### 1.1 Export the codec decoder to ONNX (host) + +The ONNX export must run in an environment that matches the original +Qwen3-TTS repo — see its +[Quickstart](https://github.com/QwenLM/Qwen3-TTS#quickstart). Create a +clean Python 3.12 env on the host: + +```bash +conda create -n qwen3-tts python=3.12 -y +conda activate qwen3-tts +pip install -U qwen-tts onnx onnxruntime + +cd examples/online_serving/text_to_speech/qwen3_tts_nv +python3 scripts/export_codec_onnx.py \ + --tokenizer-path Qwen/Qwen3-TTS-Tokenizer-12Hz \ + --onnx-path codec.onnx +``` + +> We plan to release a pre-exported `codec.onnx` so this step can be +> skipped. + +### 1.2 Build the Triton container and TRT engine + +```bash +cd examples/online_serving/text_to_speech/qwen3_tts_nv +docker build --network=host -t qwen3tts_nv . +docker run --rm -it --gpus all \ + --shm-size=8g \ + --network=host \ + -v "$(pwd):/workspace" \ + -v "${HOME}/.cache/huggingface:/root/.cache/huggingface" \ + -e HF_HOME=/root/.cache/huggingface \ + qwen3tts_nv \ + /bin/bash +``` + +All subsequent commands run inside the container, from `/workspace`. 
+ +Build the TRT engine from the ONNX produced in step 1.1: + +```bash +python3 scripts/export_codec_trt.py \ + --onnx-path codec.onnx \ + --trt-path model_repository/codec_decoder/1/model.plan +``` + +The default Triton config (`model_repository/codec_decoder/config.pbtxt`) +uses dynamic batching with `max_batch_size: 32`, so the same engine +handles batches up to 32. The codec is exported for `codec_chunk_size == 30`. + +## 2. Start the server + +Run from inside the container, in `/workspace`: + +```bash +tritonserver --model-repository=model_repository +``` + +This loads two models: + +- `qwen3_tts` — Python backend running the vLLM-Omni talker (decoupled, + streaming). +- `codec_decoder` — TensorRT backend running the exported engine with + dynamic batching. + +## 3. Usage & benchmarking + +Two scripts are provided: + +- `scripts/benchmark_model.py` — benchmarks the **talker only**, and + doubles as an example of how to drive the vLLM-Omni Qwen3-TTS model + definition directly. Spins up a single-stage `AsyncOmni` engine and + measures throughput, time-to-first-token (TTFT) and inter-token latency (ITL) on raw codec tokens. +- `scripts/benchmark_service.py` — benchmarks the **full Triton service + end to end** over gRPC, and doubles as an example of how to issue + requests against a running `tritonserver`. Text in, streamed waveform + chunks out (talker + codec + BLS plumbing). Measures throughput, + real-time factor (RTF) and time-to-first-audio (TTFA). + +Both read prompts from a tab-separated (`\t`) text file, one utterance per line, and accept a +`--concurrency` (model) or `--num-workers` (service) argument. 
+ +```bash +# Talker-only (model) benchmark +python3 scripts/benchmark_model.py \ + --model Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ + --text-file vctk_subset.txt \ + --num-requests 16 \ + --concurrency 1 + +# End-to-end service benchmark (Triton must be running) +python3 scripts/benchmark_service.py \ + --text-file vctk_subset.txt \ + --num-requests 16 \ + --num-workers 1 +``` + +### Reference results (RTX A6000) + +Single RTX A6000, default `max_num_seqs` / engine config. Latencies are +`mean / p95` in milliseconds. + +**Talker only** (`scripts/benchmark_model.py`, codec tokens only): + +| Concurrency | Throughput (req/s) | TTFT mean / p95 (ms) | ITL mean / p95 (ms) | +| ----------: | -----------------: | -------------------: | ------------------: | +| 1 | 0.99 | 26.81 / 28.19 | 15.04 / 16.54 | +| 4 | 3.41 | 51.88 / 54.31 | 16.72 / 18.84 | +| 8 | 5.60 | 61.03 / 66.51 | 19.45 / 23.05 | +| 32 | 12.79 | 105.49 / 114.59 | 33.45 / 39.13 | + +**End-to-end service** (`scripts/benchmark_service.py`, talker + codec): + +| Concurrency | Throughput (req/s) | RTF | TTFA mean / p95 (ms) | +| ----------: | -----------------: | -----: | -------------------: | +| 1 | 0.81 | 4.79x | 73.5 / 75.5 | +| 4 | 2.60 | 13.75x | 116.7 / 130.2 | +| 8 | 3.74 | 19.80x | 158.8 / 184.7 | +| 32 | 7.78 | 36.62x | 342.2 / 399.2 | diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/codec_decoder/config.pbtxt b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/codec_decoder/config.pbtxt new file mode 100644 index 00000000000..2b8f51c1caf --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/codec_decoder/config.pbtxt @@ -0,0 +1,31 @@ +name: "codec_decoder" +platform: "tensorrt_plan" +max_batch_size: 32 + +input [ + { + name: "audio_codes" + data_type: TYPE_INT64 + dims: [ -1, 16 ] + } +] + +output [ + { + name: "audio_values" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +dynamic_batching { + 
+ max_queue_delay_microseconds: 1000 + preferred_batch_size: [ 32 ] +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/1/model.py b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/1/model.py new file mode 100644 index 00000000000..8942ee6678d --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/1/model.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +"""Triton Python backend for Qwen3-TTS-NV driven by vllm-omni's AsyncOmni engine. + +Wraps :class:`Qwen3TTSTalkerForConditionalGenerationNv` (the NV-flavoured AR +talker, same model used by ``scripts/benchmark_model.py``) — the talker +emits codec tokens directly and we stream them out through the +``codec_decoder`` BLS model. + +Pipeline: + 1. Build an ``additional_information`` dict from ``{task_type, text, + language, speaker}`` and a placeholder ``prompt_token_ids`` of length + ``prompt_len`` (estimated from the same talker class). + 2. Submit one request to ``AsyncOmni.generate()``; stream codec frames out + as they arrive, chunk-decoding them through the ``codec_decoder`` BLS. + 3. Client receives a sequence of audio chunks @ 24 kHz, the last marked + final. + +Notes: + * The NV talker exposes per-step codec rows under the ``"audio_codes"`` + multimodal key. The engine's accumulator passes through three shapes + across one request: (a) the first yield is a single tensor — the + prefill prefix; (b) the middle yields share a growing Python list + (one append per AR step); (c) the final yield is a tensor again + (``_consolidate_multimodal_tensors`` cats the list). 
+ We **skip the first and last (tensor-typed) yields** entirely — the + first is the unused prefill, the last is the already-streamed cat + of everything — and only consume from the in-between list yields + (always skipping list index 0, which is the same prefill tensor). + * ``max_num_batched_tokens`` must be at least the longest expected + ``prompt_len`` (otherwise prefill is chunked across yields, which + breaks the simple "first tensor == prefill" assumption and also + hurts TTFT). It is plumbed straight through to the engine args. +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import json +import logging +import os +import queue +import tempfile +import threading +import time +import uuid + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +import yaml + +logging.basicConfig( + format="%(asctime)s [%(levelname)s]: %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("qwen3_tts_triton") + + +def _require_param(parameters: dict, key: str) -> str: + val = parameters.get(key) + if val is None: + raise KeyError(f"Missing required model parameter: {key!r}") + if isinstance(val, dict): + val = val.get("string_value") + if val is None: + raise KeyError(f"Missing required model parameter: {key!r}") + return str(val) + + +class TritonPythonModel: + def initialize(self, args): + os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + self.model_config = json.loads(args["model_config"]) + params = self.model_config.get("parameters", {}) + + self.vllm_model_path = _require_param(params, "vllm_model_path") + self.default_speaker = _require_param(params, "default_speaker").lower() + self.default_language = _require_param(params, "default_language") + self.task_type = _require_param(params, "task_type") + if self.task_type != "CustomVoice": + raise ValueError(f"Qwen3-TTS NV talker only supports task_type='CustomVoice', got {self.task_type!r}.") + + 
self.max_model_len = int(_require_param(params, "max_model_len")) + self.max_num_seqs = int(_require_param(params, "max_num_seqs")) + self.max_num_batched_tokens = int(_require_param(params, "max_num_batched_tokens")) + self.max_new_tokens = int(_require_param(params, "max_new_tokens")) + self.gpu_memory_utilization = float(_require_param(params, "gpu_memory_utilization")) + + self.codec_chunk_size = int(_require_param(params, "codec_chunk_size")) + self.codec_left_context = int(_require_param(params, "codec_left_context")) + self.first_chunk_frames = int(_require_param(params, "first_chunk_frames")) + self.codec_codebook_size = int(_require_param(params, "codec_codebook_size")) + + self.sampling_temperature = float(_require_param(params, "sampling_temperature")) + self.sampling_top_k = int(_require_param(params, "sampling_top_k")) + self.sampling_repetition_penalty = float(_require_param(params, "sampling_repetition_penalty")) + self.sampling_seed = int(_require_param(params, "sampling_seed")) + self.sampling_stop_token_ids = [ + int(x) for x in _require_param(params, "sampling_stop_token_ids").split(",") if x.strip() + ] + + self._samples_per_frame = int(24000 / 12.5) # 12.5 fps codec + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target=self._loop.run_forever, daemon=True) + self._loop_thread.start() + + # Dedicated pool for per-request codec workers. Each in-flight request + # holds one thread that serializes its own codec decode + response_sender.send + # calls; default size matches max_num_seqs so every request can run a + # codec call concurrently and let Triton dynamic batching kick in on the + # codec_decoder model. 
+ self._codec_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, self.max_num_seqs), + thread_name_prefix="qwen3_tts_codec", + ) + + self._load_prompt_builders() + self._start_omni_engine() + + logger.info("Qwen3-TTS initialized (default_speaker=%s)", self.default_speaker) + + def _load_prompt_builders(self): + from transformers import AutoTokenizer + + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import ( + Qwen3TTSConfig, + ) + from vllm_omni.model_executor.models.qwen3_tts_nv.qwen3_tts_talker_nv import ( + Qwen3TTSTalkerForConditionalGenerationNv, + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + self.vllm_model_path, + trust_remote_code=True, + padding_side="left", + ) + hf_cfg = Qwen3TTSConfig.from_pretrained(self.vllm_model_path, trust_remote_code=True) + talker_cfg = getattr(hf_cfg, "talker_config", None) + self._codec_language_id = getattr(talker_cfg, "codec_language_id", None) + self._spk_is_dialect = getattr(talker_cfg, "spk_is_dialect", None) + self._estimate_prompt_len = ( + Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information + ) + + def _build_stage_config_file(self) -> str: + stage_cfg = { + "stage_args": [ + { + "stage_id": 0, + "stage_type": "llm", + "is_comprehension": True, + "final_output": True, + "final_output_type": "audio", + "runtime": {"devices": "0"}, + "engine_args": { + "model_stage": "qwen3_tts", + "max_num_seqs": self.max_num_seqs, + "model_arch": "Qwen3TTSTalkerForConditionalGenerationNv", + "worker_type": "ar", + "scheduler_cls": "vllm_omni.core.sched.omni_ar_scheduler.OmniARAsyncScheduler", + "enforce_eager": False, + "trust_remote_code": True, + "async_scheduling": True, + "enable_prefix_caching": False, + "engine_output_type": "audio", + "gpu_memory_utilization": self.gpu_memory_utilization, + # uni runs the worker in-process for TP=PP=1; avoids the + # shm_broadcast IPC round-trip that the mp executor pays + # on every execute_model / 
sample_tokens call. + "distributed_executor_backend": "uni", + "max_num_batched_tokens": self.max_num_batched_tokens, + "max_model_len": self.max_model_len, + }, + "default_sampling_params": { + "temperature": self.sampling_temperature, + "top_k": self.sampling_top_k, + "max_tokens": self.max_new_tokens, + "seed": self.sampling_seed, + "detokenize": False, + "repetition_penalty": self.sampling_repetition_penalty, + "stop_token_ids": self.sampling_stop_token_ids, + }, + } + ], + } + tmp = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + prefix="qwen3_tts_triton_", + delete=False, + ) + yaml.dump(stage_cfg, tmp, sort_keys=False) + tmp.close() + return tmp.name + + def _start_omni_engine(self): + from vllm_omni import AsyncOmni + + self._stage_cfg_path = self._build_stage_config_file() + self.omni = AsyncOmni( + model=self.vllm_model_path, + stage_configs_path=self._stage_cfg_path, + log_stats=False, + stage_init_timeout=300, + ) + + def _build_prompt(self, text: str, language: str, speaker: str) -> dict: + """Build the engine input. + + The NV talker only takes ``{task_type, text, language, speaker}`` + in ``additional_information``. ``prompt_token_ids`` is a + placeholder of ``prompt_len`` zeros (the actual prefill embeds are + synthesized inside the talker's :meth:`preprocess`); the only + reason we still compute ``prompt_len`` here is to warn when it + exceeds ``max_num_batched_tokens`` (which would chunk prefill and + hurt TTFT). 
+ """ + additional_information = { + "task_type": [self.task_type], + "text": [text], + "language": [language], + "speaker": [speaker], + } + prompt_len = self._estimate_prompt_len( + additional_information=additional_information, + task_type=self.task_type, + tokenize_prompt=lambda t: self.tokenizer(t, padding=False)["input_ids"], + codec_language_id=self._codec_language_id, + spk_is_dialect=self._spk_is_dialect, + ) + if prompt_len > self.max_num_batched_tokens: + logger.warning( + "prompt_len=%d exceeds max_num_batched_tokens=%d; prefill will be chunked which hurts TTFT.", + prompt_len, + self.max_num_batched_tokens, + ) + return { + "prompt_token_ids": [0] * prompt_len, + "additional_information": additional_information, + } + + def _decode_codec(self, codes: torch.Tensor, left_context_frames: int) -> np.ndarray: + codes_np = codes.numpy().astype(np.int64) + pad = self.codec_chunk_size - codes_np.shape[0] + if pad > 0: + codes_np = np.pad(codes_np, ((0, pad), (0, 0))) + + response = pb_utils.InferenceRequest( + model_name="codec_decoder", + requested_output_names=["audio_values"], + inputs=[pb_utils.Tensor("audio_codes", codes_np[np.newaxis])], + ).exec() + if response.has_error(): + raise RuntimeError(f"Codec decode failed: {response.error().message()}") + + audio_tensor = pb_utils.get_output_tensor_by_name(response, "audio_values") + audio = ( + audio_tensor.as_numpy() + if audio_tensor.is_cpu() + else torch.from_dlpack(audio_tensor.to_dlpack()).cpu().numpy() + ) + if audio.ndim > 1: + audio = audio[0] + + left = left_context_frames * self._samples_per_frame + right = pad * self._samples_per_frame + return audio[left:-right] if right > 0 else audio[left:] + + def _send_audio(self, response_sender, audio: np.ndarray, final: bool): + response_sender.send( + pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("audio", audio.astype(np.float32))]), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL if final else 0, + ) + + def _send_error(self, 
response_sender, err: Exception): + try: + response_sender.send( + pb_utils.InferenceResponse(output_tensors=[], error=pb_utils.TritonError(str(err))), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + except Exception: + pass + + def _codec_worker(self, codec_q: queue.Queue, response_sender, state: dict) -> None: + """Per-request worker. Pops (chunk, ctx, is_final) tuples; ``None`` is a + sentinel meaning "send empty final response and exit". Runs on a thread + from ``self._codec_pool`` so codec decode + sender.send don't block the + asyncio loop, and so it can overlap with vLLM token generation for the + same request.""" + finalized = False + try: + while True: + item = codec_q.get() + if item is None: + self._send_audio( + response_sender, + np.array([], dtype=np.float32), + final=True, + ) + finalized = True + return + chunk, ctx, is_final = item + audio = self._decode_codec(chunk, ctx) + self._send_audio(response_sender, audio, final=is_final) + if state["t_first_audio"] is None: + state["t_first_audio"] = time.perf_counter() + if is_final: + finalized = True + return + except Exception as e: + state["error"] = e + if not finalized: + self._send_error(response_sender, e) + + async def _synthesize(self, text: str, language: str, speaker: str, response_sender): + t_start = time.perf_counter() + request_id = f"tts-{uuid.uuid4().hex[:8]}" + prompt = self._build_prompt(text, language, speaker) + + codec_q: queue.Queue = queue.Queue() + state: dict = {"t_first_audio": None, "error": None} + codec_future = self._codec_pool.submit( + self._codec_worker, + codec_q, + response_sender, + state, + ) + + # ``mm_codes`` holds the latest list-typed ``audio_codes`` payload + # we've seen from the engine. 
The engine's accumulator passes + # through three shapes across one request: the first yield is a + # single tensor (the prefill prefix), the middle yields share the + # same growing Python list (one .append() per AR step), and the + # final yield is the consolidated cat tensor. We skip both tensor + # yields. Because the engine's consolidation reassigns + # ``mm_accumulated["audio_codes"] = torch.cat(list)`` rather than + # mutating the list, our held list reference still sees the EOS + # step's append — so the post-loop ``mm_codes`` has every decode + # row (index 0 is the prefill, indices 1..N are the N decode rows). + sent_frames = 0 + mm_codes: list | None = None + + try: + async for out in self.omni.generate(prompt, request_id=request_id): + if state["error"] is not None: + break + payload = out.multimodal_output.get("audio_codes") + if not isinstance(payload, list): + continue + mm_codes = payload + + decode_count = len(mm_codes) - 1 + threshold = ( + self.first_chunk_frames if sent_frames == 0 else self.codec_chunk_size - self.codec_left_context + ) + while decode_count - sent_frames >= threshold: + ctx = min(sent_frames, self.codec_left_context) + start = 1 + sent_frames - ctx + end = 1 + sent_frames + threshold + chunk = torch.cat(mm_codes[start:end], dim=0) + codec_q.put((chunk, ctx, False)) + sent_frames += threshold + threshold = self.codec_chunk_size - self.codec_left_context + + # Final trailing chunk (or empty-final sentinel). ``mm_codes`` + # here is the latest list reference; the engine has appended + # every decode row up to and including the EOS-sampling step. 
+ if state["error"] is None: + if mm_codes is not None and len(mm_codes) - 1 > sent_frames: + ctx = min(sent_frames, self.codec_left_context) + start = 1 + sent_frames - ctx + end = len(mm_codes) + chunk = torch.cat(mm_codes[start:end], dim=0) + codec_q.put((chunk, ctx, True)) + else: + codec_q.put(None) + + # Wait for the worker to drain and send the final response without + # blocking the asyncio loop thread. + await asyncio.wrap_future(codec_future) + + if state["error"] is not None: + raise state["error"] + + t_end = time.perf_counter() + ttfa_ms = ((state["t_first_audio"] or t_end) - t_start) * 1000 + logger.info( + "rid=%s ttfa=%.1fms total=%.1fms frames=%d speaker=%s text=%r", + request_id, + ttfa_ms, + (t_end - t_start) * 1000, + sent_frames, + speaker, + text[:120], + ) + except Exception as e: + logger.error("rid=%s failed: %s", request_id, e, exc_info=True) + try: + await self.omni.abort(request_id) + except Exception: + pass + # Make sure the worker exits if it's still alive. 
+ if not codec_future.done(): + codec_q.put(None) + try: + await asyncio.wrap_future(codec_future) + except Exception: + pass + self._send_error(response_sender, e) + + def execute(self, requests): + for request in requests: + response_sender = request.get_response_sender() + try: + text = pb_utils.get_input_tensor_by_name(request, "text").as_numpy().flatten()[0].decode("utf-8") + lang_tensor = pb_utils.get_input_tensor_by_name(request, "language") + language = ( + lang_tensor.as_numpy().flatten()[0].decode("utf-8") + if lang_tensor is not None + else self.default_language + ) + spk_tensor = pb_utils.get_input_tensor_by_name(request, "speaker") + speaker = ( + spk_tensor.as_numpy().flatten()[0].decode("utf-8") + if spk_tensor is not None + else self.default_speaker + ).lower() + asyncio.run_coroutine_threadsafe( + self._synthesize(text, language, speaker, response_sender), + self._loop, + ) + except Exception as e: + logger.error("Request parse failed: %s", e, exc_info=True) + self._send_error(response_sender, e) + return None + + def finalize(self): + if hasattr(self, "omni"): + try: + self.omni.shutdown() + except Exception: + pass + if hasattr(self, "_loop") and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + if hasattr(self, "_loop_thread"): + self._loop_thread.join(timeout=10) + if hasattr(self, "_codec_pool"): + self._codec_pool.shutdown(wait=False) + if getattr(self, "_stage_cfg_path", None): + try: + os.unlink(self._stage_cfg_path) + except OSError: + pass diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/config.pbtxt b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/config.pbtxt new file mode 100644 index 00000000000..5e8dd84c622 --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/model_repository/qwen3_tts/config.pbtxt @@ -0,0 +1,116 @@ +name: "qwen3_tts" +backend: "python" +max_batch_size: 32 + +input [ + { + name: "text" + 
data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "language" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + }, + { + name: "speaker" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + } +] + +output [ + { + name: "audio" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +model_transaction_policy { + decoupled: true +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] +parameters { + key: "vllm_model_path" + value: { string_value: "/workspace/Qwen3-TTS-12Hz-1.7B-CustomVoice" } +} +parameters { + key: "default_speaker" + value: { string_value: "aiden" } +} +parameters { + key: "default_language" + value: { string_value: "auto" } +} +parameters { + key: "task_type" + value: { string_value: "CustomVoice" } +} + +parameters { + key: "max_model_len" + value: { string_value: "768" } +} +parameters { + key: "max_num_seqs" + value: { string_value: "32" } +} +parameters { + key: "max_num_batched_tokens" + value: { string_value: "4096" } +} +parameters { + key: "max_new_tokens" + value: { string_value: "768" } +} +parameters { + key: "gpu_memory_utilization" + value: { string_value: "0.5" } +} + +parameters { + key: "codec_chunk_size" + value: { string_value: "30" } +} +parameters { + key: "codec_left_context" + value: { string_value: "25" } +} +parameters { + key: "first_chunk_frames" + value: { string_value: "2" } +} +parameters { + key: "codec_codebook_size" + value: { string_value: "2048" } +} +parameters { + key: "sampling_temperature" + value: { string_value: "0.9" } +} +parameters { + key: "sampling_top_k" + value: { string_value: "50" } +} +parameters { + key: "sampling_repetition_penalty" + value: { string_value: "1.05" } +} +parameters { + key: "sampling_seed" + value: { string_value: "42" } +} +parameters { + key: "sampling_stop_token_ids" + value: { string_value: "2150" } +} diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_model.py 
b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_model.py new file mode 100644 index 00000000000..c73fc523fd9 --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_model.py @@ -0,0 +1,789 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +"""Benchmark Qwen3-TTS NV *talker only* via a single-stage AsyncOmni engine. + +Runs only the NV AR talker (Qwen3TTSTalkerForConditionalGenerationNv) +without code2wav, producing codec tokens as output. Measures TTFT, +per-token inter-token latency (ITL), end-to-end latency, and throughput +under configurable concurrency. + +Reads texts from a file (one utterance per line, optionally tab-separated +with text in the second column) and runs concurrent requests through the +AsyncOmni engine. + +Usage: + # Basic benchmark with default prompts + python benchmark_model.py \\ + --model Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \\ + --num-requests 50 + + # From a text file with concurrency sweep + python benchmark_model.py \\ + --model /path/to/checkpoint \\ + --text-file texts.txt \\ + --num-requests 100 \\ + --concurrency 1 4 8 + + # With torch profiler on the final run + python benchmark_model.py \\ + --model /path/to/checkpoint \\ + --num-requests 20 --concurrency 1 --profile + + # Save JSON results + python benchmark_model.py \\ + --model /path/to/checkpoint \\ + --text-file texts.txt \\ + --num-requests 100 --concurrency 1 4 \\ + --result-dir results/ +""" + +import os + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +import argparse +import asyncio +import json +import logging +import tempfile +import time +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any + +import numpy as np +import yaml + 
+logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + +DEFAULT_PROMPTS = [ + "Hello, welcome to the voice synthesis benchmark test.", + "She said she would be here by noon, but nobody showed up.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "I can't believe how beautiful the sunset looks from up here on the mountain.", + "Please remember to bring your identification documents to the appointment tomorrow morning.", + "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?", + "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.", + "After the meeting, we should discuss the quarterly results and plan for the next phase.", + "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.", + "The train leaves at half past seven, so we need to arrive at the station before then.", + "Could you please turn down the music a little bit, I'm trying to concentrate on my work.", + "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.", +] + + +# --------------------------------------------------------------------------- +# Stage config generation +# --------------------------------------------------------------------------- + + +def _build_talker_only_stage_config( + max_num_seqs: int = 1, + profile: bool = False, + torch_profiler_dir: str = "./profiler_traces", + with_stack: bool = False, + record_shapes: bool = False, + gpu_memory_utilization: float = 0.5, + max_model_len: int = 4096, + max_num_batched_tokens: int = 512, + enforce_eager: bool = False, + max_new_tokens: int = 2048, + distributed_executor_backend: str = "mp", +) -> dict: + """Build a single-stage YAML dict containing only the NV AR talker.""" + engine_args: dict[str, Any] = { + "model_stage": "qwen3_tts", + "max_num_seqs": max_num_seqs, + 
"model_arch": "Qwen3TTSTalkerForConditionalGenerationNv", + "worker_type": "ar", + "scheduler_cls": "vllm_omni.core.sched.omni_ar_scheduler.OmniARAsyncScheduler", + "enforce_eager": enforce_eager, + "trust_remote_code": True, + "async_scheduling": True, + "enable_prefix_caching": False, + "engine_output_type": "audio", + "gpu_memory_utilization": gpu_memory_utilization, + # "uni" runs the worker in-process (no shm_broadcast IPC); use "mp" + # only when TP/PP > 1 or you actually need a separate worker process. + "distributed_executor_backend": distributed_executor_backend, + "max_num_batched_tokens": max_num_batched_tokens, + "max_model_len": max_model_len, + } + + if profile: + engine_args["profiler_config"] = { + "profiler": "torch", + "torch_profiler_dir": os.path.abspath(torch_profiler_dir), + "torch_profiler_with_stack": with_stack, + "torch_profiler_record_shapes": record_shapes, + } + + cfg = { + "stage_args": [ + { + "stage_id": 0, + "stage_type": "llm", + "is_comprehension": True, + "final_output": True, + "final_output_type": "audio", + "runtime": {"devices": "0"}, + "engine_args": engine_args, + "default_sampling_params": { + "temperature": 0.9, + "top_k": 50, + "max_tokens": max_new_tokens, + "seed": 42, + "detokenize": False, + "repetition_penalty": 1.05, + "stop_token_ids": [2150], + }, + } + ], + } + return cfg + + +def _write_temp_stage_config(cfg: dict) -> str: + """Write stage config dict to a temp YAML file, return its path.""" + tmp = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + prefix="talker_nv_bench_", + delete=False, + ) + yaml.dump(cfg, tmp, default_flow_style=False, sort_keys=False) + tmp.close() + logger.info("Wrote single-stage config to %s", tmp.name) + return tmp.name + + +# --------------------------------------------------------------------------- +# Prompt construction +# --------------------------------------------------------------------------- + + +def _estimate_prompt_len( + additional_information: dict[str, Any], 
+ model_name: str, + _cache: dict[str, Any] = {}, +) -> int: + """Estimate prompt_token_ids placeholder length for the NV talker.""" + try: + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import ( + Qwen3TTSConfig, + ) + from vllm_omni.model_executor.models.qwen3_tts_nv.qwen3_tts_talker_nv import ( + Qwen3TTSTalkerForConditionalGenerationNv, + ) + + if model_name not in _cache: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True, + padding_side="left", + ) + hf_cfg = Qwen3TTSConfig.from_pretrained( + model_name, + trust_remote_code=True, + ) + _cache[model_name] = (tok, getattr(hf_cfg, "talker_config", None)) + + tok, tcfg = _cache[model_name] + task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] + + return Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + additional_information=additional_information, + task_type=task_type, + tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], + codec_language_id=getattr(tcfg, "codec_language_id", None), + spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), + ) + except Exception as exc: + logger.warning("Prompt length estimation failed, using 2048: %s", exc) + return 2048 + + +def build_input( + text: str, + speaker: str, + language: str, + model_name: str, +) -> dict: + """Build an engine input dict from text + speaker + language.""" + additional_information = { + "task_type": ["CustomVoice"], + "text": [text], + "language": [language], + "speaker": [speaker], + } + ph_len = _estimate_prompt_len(additional_information, model_name) + return { + "prompt_token_ids": [0] * ph_len, + "additional_information": additional_information, + } + + +# --------------------------------------------------------------------------- +# Result dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class RequestResult: + success: 
bool = False + text: str = "" + prompt_len: int = 0 + num_generated: int = 0 + steps: int = 0 + ttft_s: float = 0.0 + e2e_s: float = 0.0 + inter_token_latencies: list = field(default_factory=list) + error: str = "" + + +@dataclass +class BenchmarkResult: + config_name: str = "" + concurrency: int = 0 + num_requests: int = 0 + completed: int = 0 + failed: int = 0 + duration_s: float = 0.0 + # TTFT + mean_ttft_ms: float = 0.0 + median_ttft_ms: float = 0.0 + p95_ttft_ms: float = 0.0 + p99_ttft_ms: float = 0.0 + # E2E + mean_e2e_ms: float = 0.0 + median_e2e_ms: float = 0.0 + p95_e2e_ms: float = 0.0 + p99_e2e_ms: float = 0.0 + # ITL (inter-token latency, excluding first token) + mean_itl_ms: float = 0.0 + median_itl_ms: float = 0.0 + p95_itl_ms: float = 0.0 + p99_itl_ms: float = 0.0 + # Throughput + total_tokens: int = 0 + mean_tokens_per_request: float = 0.0 + token_throughput: float = 0.0 + request_throughput: float = 0.0 + per_request: list = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Inference +# --------------------------------------------------------------------------- + + +async def run_one_request(omni, prompt: dict, request_id: str) -> RequestResult: + """Submit one TTS request and collect outputs with per-token timing. + + AsyncOmni coerces sampling params to ``RequestOutputKind.DELTA`` when no + explicit ``sampling_params`` are passed (since #2911). In DELTA mode, + ``CompletionOutput.token_ids`` only holds the *new* tokens for the + current step, so ``len(token_ids)`` cannot be used as a cumulative + counter. The omni output processor always stores the cumulative list on + ``cumulative_token_ids``; we use that to detect new tokens and to time + inter-token latencies. 
+ """ + result = RequestResult() + t_start = time.perf_counter() + t_last_token = None + prev_num_tokens = 0 + + try: + async for stage_output in omni.generate(prompt, request_id=request_id): + now = time.perf_counter() + ro = stage_output.request_output + result.steps += 1 + + cur_num_tokens = prev_num_tokens + if hasattr(ro, "outputs") and ro.outputs: + out0 = ro.outputs[0] + cum_ids = getattr(out0, "cumulative_token_ids", None) + if cum_ids is not None: + cur_num_tokens = len(cum_ids) + else: + cur_num_tokens = len(getattr(out0, "token_ids", []) or []) + + if cur_num_tokens > prev_num_tokens: + if t_last_token is None: + result.ttft_s = now - t_start + else: + result.inter_token_latencies.append(now - t_last_token) + t_last_token = now + prev_num_tokens = cur_num_tokens + + t_end = time.perf_counter() + result.e2e_s = t_end - t_start + result.num_generated = prev_num_tokens + result.success = True + + if result.ttft_s == 0.0 and result.steps > 0: + result.ttft_s = t_end - t_start + + except Exception as exc: + result.e2e_s = time.perf_counter() - t_start + result.error = str(exc) + logger.error("Request %s failed: %s", request_id, exc) + + return result + + +# --------------------------------------------------------------------------- +# Worker / concurrency +# --------------------------------------------------------------------------- + + +async def worker( + worker_id: int, + omni, + texts: list[str], + model_name: str, + speaker: str, + language: str, + results: list[RequestResult], + counter: dict, + lock: asyncio.Lock, +): + """Persistent async worker that picks texts until the quota is exhausted.""" + while True: + async with lock: + if counter["remaining"] <= 0: + break + counter["remaining"] -= 1 + idx = counter["issued"] + counter["issued"] += 1 + + text = texts[idx % len(texts)] + request_id = f"bench-nv-w{worker_id}-{uuid.uuid4().hex[:8]}" + + prompt = build_input( + text=text, + speaker=speaker, + language=language, + model_name=model_name, + ) + + 
result = await run_one_request(omni, prompt, request_id) + result.text = text + result.prompt_len = len(prompt["prompt_token_ids"]) + + async with lock: + results.append(result) + done = len(results) + + if done % 10 == 0 or done == counter["total"]: + logger.info(" progress: %d / %d", done, counter["total"]) + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +def _pct(arr, p): + return float(np.percentile(arr, p)) if len(arr) > 0 else 0.0 + + +def compute_and_print_metrics( + results: list[RequestResult], + duration: float, + concurrency: int, + num_requests: int, +) -> BenchmarkResult: + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + bench = BenchmarkResult( + concurrency=concurrency, + num_requests=num_requests, + completed=len(successful), + failed=len(failed), + duration_s=duration, + ) + + if not successful: + print("ERROR: No requests completed successfully.") + return bench + + ttfts = [r.ttft_s * 1000 for r in successful] + e2es = [r.e2e_s * 1000 for r in successful] + all_itls = [] + for r in successful: + all_itls.extend([t * 1000 for t in r.inter_token_latencies]) + gen_tokens = [r.num_generated for r in successful] + + bench.mean_ttft_ms = float(np.mean(ttfts)) + bench.median_ttft_ms = float(np.median(ttfts)) + bench.p95_ttft_ms = _pct(ttfts, 95) + bench.p99_ttft_ms = _pct(ttfts, 99) + + bench.mean_e2e_ms = float(np.mean(e2es)) + bench.median_e2e_ms = float(np.median(e2es)) + bench.p95_e2e_ms = _pct(e2es, 95) + bench.p99_e2e_ms = _pct(e2es, 99) + + if all_itls: + bench.mean_itl_ms = float(np.mean(all_itls)) + bench.median_itl_ms = float(np.median(all_itls)) + bench.p95_itl_ms = _pct(all_itls, 95) + bench.p99_itl_ms = _pct(all_itls, 99) + + bench.total_tokens = sum(gen_tokens) + bench.mean_tokens_per_request = float(np.mean(gen_tokens)) + bench.token_throughput = 
bench.total_tokens / duration if duration > 0 else 0.0 + bench.request_throughput = len(successful) / duration if duration > 0 else 0.0 + + bench.per_request = [ + { + "ttft_ms": r.ttft_s * 1000, + "e2e_ms": r.e2e_s * 1000, + "num_generated": r.num_generated, + "steps": r.steps, + "prompt_len": r.prompt_len, + "mean_itl_ms": float(np.mean([t * 1000 for t in r.inter_token_latencies])) + if r.inter_token_latencies + else 0.0, + "text": r.text, + } + for r in successful + ] + + W = 56 + print(f"\n{'=' * W}") + print(f"{'Benchmark Result':^{W}}") + print(f"{'=' * W}") + print(f"{'Successful requests:':<42}{bench.completed}") + print(f"{'Failed requests:':<42}{bench.failed}") + print(f"{'Concurrency:':<42}{concurrency}") + print(f"{'Wall-clock duration (s):':<42}{duration:.2f}") + print(f"{'Request throughput (req/s):':<42}{bench.request_throughput:.2f}") + + print(f"\n{'-' * W}") + print(f"{'Time to First Token (TTFT)':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean (ms):':<42}{bench.mean_ttft_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_ttft_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_ttft_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_ttft_ms:.2f}") + + print(f"\n{'-' * W}") + print(f"{'End-to-End Latency (E2E)':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean (ms):':<42}{bench.mean_e2e_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_e2e_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_e2e_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_e2e_ms:.2f}") + + print(f"\n{'-' * W}") + print(f"{'Inter-Token Latency (ITL)':^{W}}") + print(f"{'-' * W}") + if all_itls: + print(f"{'Mean (ms):':<42}{bench.mean_itl_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_itl_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_itl_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_itl_ms:.2f}") + else: + print(f"{'(no inter-token data)':^{W}}") + + print(f"\n{'-' * W}") + print(f"{'Token Throughput':^{W}}") + print(f"{'-' * W}") + print(f"{'Total tokens 
generated:':<42}{bench.total_tokens}") + print(f"{'Mean tokens / request:':<42}{bench.mean_tokens_per_request:.1f}") + print(f"{'Token throughput (tok/s):':<42}{bench.token_throughput:.2f}") + print(f"{'=' * W}\n") + + if failed: + print(f" First {min(3, len(failed))} errors:") + for r in failed[:3]: + print(f" {r.error[:200]}") + + return bench + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main(args): + from vllm_omni import AsyncOmni + + model_name = args.model + + # ── Load texts ──────────────────────────────────────────────────────── + if args.text_file: + path = Path(args.text_file) + if not path.exists(): + print(f"ERROR: text file not found: {path}") + return + raw_lines = [line.strip() for line in path.read_text().splitlines() if line.strip()] + texts = [] + for line in raw_lines: + if "\t" in line: + texts.append(line.split("\t", 1)[1].strip()) + else: + texts.append(line) + texts = [t for t in texts if t] + logger.info("Loaded %d texts from %s", len(texts), path) + else: + texts = DEFAULT_PROMPTS + logger.info("Using %d default prompts", len(texts)) + + if not texts: + print("ERROR: no texts available.") + return + + max_concurrency = max(args.concurrency) + + # ── Build stage config ──────────────────────────────────────────────── + stage_cfg = _build_talker_only_stage_config( + max_num_seqs=max_concurrency, + profile=args.profile, + torch_profiler_dir=args.torch_profiler_dir, + with_stack=args.with_stack, + record_shapes=args.record_shapes, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + max_num_batched_tokens=args.max_num_batched_tokens, + enforce_eager=args.enforce_eager, + max_new_tokens=args.max_new_tokens, + distributed_executor_backend=args.distributed_executor_backend, + ) + tmp_config_path = _write_temp_stage_config(stage_cfg) + + try: + logger.info("Creating 
AsyncOmni engine (talker only) for %s ...", model_name) + omni = AsyncOmni( + model=model_name, + stage_configs_path=tmp_config_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + logger.info("Engine ready (single stage: talker).") + + all_bench_results = [] + + for concurrency in args.concurrency: + logger.info( + "═══ concurrency=%d requests=%d ═══", + concurrency, + args.num_requests, + ) + + # ── Warmup ──────────────────────────────────────────────────── + warmup_count = 0 if args.no_warmup else args.num_warmups * concurrency + if warmup_count > 0: + logger.info("Warming up with %d requests (concurrency=%d)...", warmup_count, concurrency) + warmup_results: list[RequestResult] = [] + warmup_counter = { + "remaining": warmup_count, + "issued": 0, + "total": warmup_count, + } + warmup_lock = asyncio.Lock() + warmup_tasks = [ + asyncio.create_task( + worker( + worker_id=i, + omni=omni, + texts=texts, + model_name=model_name, + speaker=args.speaker, + language=args.language, + results=warmup_results, + counter=warmup_counter, + lock=warmup_lock, + ) + ) + for i in range(concurrency) + ] + await asyncio.gather(*warmup_tasks) + warmup_ok = sum(1 for r in warmup_results if r.success) + logger.info("Warmup done: %d / %d succeeded.", warmup_ok, warmup_count) + + # ── Benchmark run ───────────────────────────────────────────── + logger.info("Starting benchmark run (%d requests, concurrency=%d)...", args.num_requests, concurrency) + + bench_results: list[RequestResult] = [] + counter = { + "remaining": args.num_requests, + "issued": 0, + "total": args.num_requests, + } + lock = asyncio.Lock() + + if args.profile: + logger.info("Starting profiler ...") + await omni.start_profile( + profile_prefix=args.profile_prefix, + stages=[0], + ) + + start_time = time.perf_counter() + try: + tasks = [ + asyncio.create_task( + worker( + worker_id=i, + omni=omni, + texts=texts, + model_name=model_name, + speaker=args.speaker, + language=args.language, + 
results=bench_results, + counter=counter, + lock=lock, + ) + ) + for i in range(concurrency) + ] + await asyncio.gather(*tasks) + finally: + if args.profile: + logger.info("Stopping profiler ...") + await omni.stop_profile(stages=[0]) + + duration = time.perf_counter() - start_time + + bench = compute_and_print_metrics( + bench_results, + duration, + concurrency, + args.num_requests, + ) + bench.config_name = args.config_name + all_bench_results.append(asdict(bench)) + + # ── Save results ────────────────────────────────────────────────── + if args.result_dir: + result_dir = Path(args.result_dir) + result_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + result_file = result_dir / f"bench_talker_nv_{args.config_name}_{timestamp}.json" + with open(result_file, "w") as f: + json.dump(all_bench_results, f, indent=2) + logger.info("Results saved to %s", result_file) + + omni.shutdown() + finally: + os.unlink(tmp_config_path) + + logger.info("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark Qwen3-TTS NV talker (AR stage only) via AsyncOmni", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + model = parser.add_argument_group("model / input") + model.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + help="Model name or path", + ) + model.add_argument( + "--text-file", + type=str, + default=None, + help="Path to text file (one utterance per line, optionally tab-separated with text in 2nd column)", + ) + model.add_argument("--speaker", type=str, default="aiden") + model.add_argument("--language", type=str, default="English") + model.add_argument( + "--max-new-tokens", + type=int, + default=2048, + help="Max sampling tokens per request (passed via default_sampling_params.max_tokens)", + ) + + bench = parser.add_argument_group("benchmark") + bench.add_argument( + "-c", + "--concurrency", + type=int, + nargs="+", + 
default=[1], + help="Concurrency levels to test (space-separated, default: 1)", + ) + bench.add_argument( + "-n", + "--num-requests", + type=int, + default=50, + help="Total number of requests per concurrency level (default: 50)", + ) + bench.add_argument( + "--num-warmups", + type=int, + default=3, + help="Warmup rounds per concurrency level (total warmup = concurrency * this, default: 3)", + ) + bench.add_argument("--no-warmup", action="store_true", help="Skip warmup") + bench.add_argument( + "--config-name", + type=str, + default="talker_nv", + help="Label for this run (used in result filenames)", + ) + bench.add_argument( + "--result-dir", + type=str, + default=None, + help="Directory to save JSON results", + ) + + engine = parser.add_argument_group("engine") + engine.add_argument("--gpu-memory-utilization", type=float, default=0.5) + engine.add_argument("--max-model-len", type=int, default=4096) + engine.add_argument("--max-num-batched-tokens", type=int, default=4096) + engine.add_argument("--enforce-eager", action="store_true") + engine.add_argument("--stage-init-timeout", type=int, default=300) + engine.add_argument("--log-stats", action="store_true", default=False) + engine.add_argument( + "--distributed-executor-backend", + type=str, + default="uni", + choices=["uni", "mp", "ray"], + help="vLLM executor backend. 'uni' runs the worker in-process and " + "avoids the shm_broadcast IPC round-trips on every " + "execute_model/sample_tokens call (recommended for TP=1, " + "single GPU). 
Default: uni.", + ) + + prof = parser.add_argument_group("profiling") + prof.add_argument( + "--profile", + action="store_true", + help="Enable torch profiler during the benchmark run", + ) + prof.add_argument("--profile-prefix", type=str, default=None, help="Prefix for profiler trace filenames") + prof.add_argument( + "--torch-profiler-dir", type=str, default="./profiler_traces", help="Directory for torch profiler traces" + ) + prof.add_argument("--with-stack", action="store_true", help="Record Python call stacks in profiler") + prof.add_argument("--record-shapes", action="store_true", help="Record tensor shapes in profiler") + + return parser.parse_args() + + +if __name__ == "__main__": + asyncio.run(main(parse_args())) diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_service.py b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_service.py new file mode 100644 index 00000000000..495af7a118b --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/benchmark_service.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +""" +Benchmark script for Qwen3-TTS Triton server (decoupled mode, gRPC). + +Spawns N concurrent workers that send TTS requests in parallel. +Each line of the text file is parsed as ``<uttid>\\t<text>``. +Texts are randomly sampled for each request. 
+ +Usage: + python benchmark_service.py --text-file texts.txt --num-requests 100 --num-workers 8 + python benchmark_service.py --text-file texts.txt --num-requests 50 \ + --output-dir out_wavs +""" + +import argparse +import queue +import random +import threading +import time +import wave +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import tritonclient.grpc as grpcclient + +SAMPLE_RATE = 24_000 +MODEL_NAME = "qwen3_tts" + + +@dataclass +class RequestResult: + uttid: str + num_samples: int + duration_s: float + ttfa_s: float = 0.0 + error: str | None = None + + +@dataclass +class BenchmarkStats: + lock: threading.Lock = field(default_factory=threading.Lock) + results: list[RequestResult] = field(default_factory=list) + + def add(self, result: RequestResult): + with self.lock: + self.results.append(result) + + +def _save_wav(path: Path, audio: np.ndarray, sample_rate: int = SAMPLE_RATE): + audio = np.clip(audio, -1.0, 1.0) + pcm = (audio * 32767.0).astype(np.int16) + with wave.open(str(path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm.tobytes()) + + +def synthesize( + client: grpcclient.InferenceServerClient, + result_q: queue.Queue, + text: str, + chunk_timeout: float, +): + """Send one TTS request and collect streamed chunks. + + Returns ``(audio, ttfa_s, elapsed_s, error)``. 
+ """ + text_input = grpcclient.InferInput("text", [1, 1], "BYTES") + text_input.set_data_from_numpy(np.array([[text]], dtype=object)) + + t0 = time.perf_counter() + t_first: float | None = None + chunks: list[np.ndarray] = [] + + client.async_stream_infer( + model_name=MODEL_NAME, + inputs=[text_input], + outputs=[grpcclient.InferRequestedOutput("audio")], + ) + + while True: + try: + result, error = result_q.get(timeout=chunk_timeout) + except queue.Empty: + elapsed = time.perf_counter() - t0 + return None, elapsed, elapsed, "no chunk within chunk_timeout" + + if error: + elapsed = time.perf_counter() - t0 + return None, elapsed, elapsed, str(error) + + audio = result.as_numpy("audio").squeeze() + if audio.size > 0: + if t_first is None: + t_first = time.perf_counter() + chunks.append(audio) + + response = result.get_response() + final_param = response.parameters.get("triton_final_response") + if final_param and getattr(final_param, "bool_param", False): + break + + elapsed = time.perf_counter() - t0 + ttfa = (t_first - t0) if t_first else elapsed + audio = np.concatenate(chunks) if chunks else np.array([], dtype=np.float32) + return audio, ttfa, elapsed, None + + +def worker( + worker_id: int, + triton_url: str, + items: list[tuple[str, str]], + task_queue: list[int], + queue_lock: threading.Lock, + stats: BenchmarkStats, + chunk_timeout: float, + output_dir: Path | None, +): + result_q: queue.Queue = queue.Queue() + client = grpcclient.InferenceServerClient(url=triton_url) + client.start_stream(callback=lambda result, error: result_q.put((result, error))) + + try: + while True: + with queue_lock: + if not task_queue: + return + task_idx = task_queue.pop() + + uttid, text = random.choice(items) + audio, ttfa, elapsed, error = synthesize(client, result_q, text, chunk_timeout) + + if error is not None: + # Reset the stream so late chunks don't bleed into the next + # request. 
+ client.stop_stream() + client.start_stream(callback=lambda result, error: result_q.put((result, error))) + stats.add( + RequestResult( + uttid=uttid, + num_samples=0, + duration_s=elapsed, + ttfa_s=ttfa, + error=error, + ) + ) + print(f"[worker {worker_id:02d}] request {task_idx} ({uttid}) FAILED ({elapsed:.1f}s) — {error}") + continue + + num_samples = len(audio) + if output_dir is not None and num_samples > 0: + _save_wav(output_dir / f"{uttid}.wav", audio) + + stats.add( + RequestResult( + uttid=uttid, + num_samples=num_samples, + duration_s=elapsed, + ttfa_s=ttfa, + ) + ) + print( + f"[worker {worker_id:02d}] request {task_idx} ({uttid}) done — " + f"{num_samples / SAMPLE_RATE:.2f}s audio in {elapsed:.2f}s " + f"(TTFA: {ttfa:.3f}s)" + ) + finally: + client.stop_stream() + + +def _load_items(text_file: str) -> list[tuple[str, str]]: + items: list[tuple[str, str]] = [] + with open(text_file) as f: + for line in f: + line = line.rstrip("\n") + if not line.strip(): + continue + parts = line.split("\t", 1) + if len(parts) != 2: + raise ValueError(f"Expected '<uttid>\\t<text>' per line, got: {line!r}") + uttid, text = parts[0].strip(), parts[1].strip() + if not uttid or not text: + raise ValueError(f"Empty uttid or text in line: {line!r}") + items.append((uttid, text)) + return items + + +def _run_workers( + num_workers: int, + triton_url: str, + items: list[tuple[str, str]], + num_tasks: int, + chunk_timeout: float, + output_dir: Path | None, +) -> tuple[BenchmarkStats, float]: + task_queue = list(range(num_tasks)) + queue_lock = threading.Lock() + stats = BenchmarkStats() + + threads = [ + threading.Thread( + target=worker, + args=(i, triton_url, items, task_queue, queue_lock, stats, chunk_timeout, output_dir), + ) + for i in range(num_workers) + ] + wall_start = time.perf_counter() + for t in threads: + t.start() + for t in threads: + t.join() + return stats, time.perf_counter() - wall_start + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Qwen3-TTS 
Triton server") + parser.add_argument("--text-file", required=True, help="Path to file with '<uttid>\\t<text>' per line") + parser.add_argument("--num-requests", type=int, required=True, help="Total number of requests to send") + parser.add_argument("--num-workers", type=int, default=4, help="Number of concurrent workers (default: 4)") + parser.add_argument("--triton-url", default="localhost:8001", help="Triton gRPC endpoint (default: localhost:8001)") + parser.add_argument("--no-warmup", action="store_true", help="Skip warmup phase (3 requests per worker)") + parser.add_argument( + "--chunk-timeout", type=float, default=60, help="Per-chunk receive timeout in seconds (default: 60)" + ) + parser.add_argument( + "--output-dir", default=None, help="If set, write each generated waveform to <output-dir>/<uttid>.wav" + ) + args = parser.parse_args() + + items = _load_items(args.text_file) + if not items: + print(f"ERROR: no usable lines found in {args.text_file}") + return + + output_dir: Path | None = None + if args.output_dir is not None: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loaded {len(items)} utterances from {args.text_file}") + print(f"Sending {args.num_requests} requests with {args.num_workers} workers to {args.triton_url}") + if output_dir is not None: + print(f"Writing WAVs to {output_dir.resolve()}") + print("-" * 70) + + if not args.no_warmup: + total_warmup = args.num_workers * 3 + print(f"Warmup: {total_warmup} requests (3 per worker) ...") + _run_workers( + args.num_workers, + args.triton_url, + items, + total_warmup, + args.chunk_timeout, + output_dir=None, + ) + print("Warmup complete.") + print("-" * 70) + + stats, wall_elapsed = _run_workers( + args.num_workers, + args.triton_url, + items, + args.num_requests, + args.chunk_timeout, + output_dir, + ) + + successes = [r for r in stats.results if r.error is None] + failures = [r for r in stats.results if r.error is not None] + total_audio_seconds = sum(r.num_samples for r in 
successes) / SAMPLE_RATE + + print() + print("=" * 70) + print("BENCHMARK RESULTS") + print("=" * 70) + print(f" Total requests sent: {args.num_requests}") + print(f" Successful: {len(successes)}") + print(f" Failed: {len(failures)}") + print(f" Concurrent workers: {args.num_workers}") + print() + print(f" Wall-clock time: {wall_elapsed:.2f} s") + print(f" Total audio synthesized: {total_audio_seconds:.2f} s") + print(f" Real-time factor (RTF): {total_audio_seconds / wall_elapsed:.2f}x") + print(f" Throughput: {len(successes) / wall_elapsed:.2f} requests/s") + + if successes: + ttfas_ms = sorted(r.ttfa_s * 1000 for r in successes) + mean_ttfa = sum(ttfas_ms) / len(ttfas_ms) + print() + print(" Time to first audio (TTFA):") + print(f" mean: {mean_ttfa:.1f} ms") + print(f" p95: {ttfas_ms[int(len(ttfas_ms) * 0.95)]:.1f} ms") + + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_onnx.py b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_onnx.py new file mode 100644 index 00000000000..681e51c5d6c --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_onnx.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +"""Export the Qwen3-TTS 12Hz codec decoder to ONNX.""" + +import argparse +from pathlib import Path + +import numpy as np +import torch + +# Match ORT's full-FP32 matmul; PyTorch on Ampere+ uses TF32 by default. +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +# Bypass torch.vmap-based mask builders (untraceable by ONNX export). 
+try: + import transformers.masking_utils as _mu + + for _name in ("create_causal_mask", "create_sliding_window_causal_mask"): + if hasattr(_mu, _name): + setattr(_mu, _name, lambda *a, **kw: None) +except ImportError: + pass + +# enable_gqa=True is unsupported by the TorchScript ONNX exporter and is a +# no-op here (num_heads == num_kv_heads). +_orig_sdpa = torch.nn.functional.scaled_dot_product_attention + + +def _sdpa_no_gqa(*args, **kwargs): + kwargs.pop("enable_gqa", None) + return _orig_sdpa(*args, **kwargs) + + +torch.nn.functional.scaled_dot_product_attention = _sdpa_no_gqa + +try: + import onnx +except ImportError as exc: + raise ImportError( + "`onnx` is required on top of the Qwen3-TTS environment. Install with: pip install onnx onnxruntime" + ) from exc + +try: + from qwen_tts import Qwen3TTSTokenizer +except ImportError as exc: + raise ImportError( + "`qwen_tts` not importable; install Qwen3-TTS per https://github.com/QwenLM/Qwen3-TTS#quickstart." + ) from exc + + +class CodecDecoderWrapper(torch.nn.Module): + def __init__(self, decoder: torch.nn.Module): + super().__init__() + self.decoder = decoder + + def forward(self, audio_codes: torch.Tensor) -> torch.Tensor: + return self.decoder(audio_codes.transpose(1, 2)).squeeze(1) + + +def check_onnx_parity(wrapper, onnx_path, audio_codes, device, atol=1e-3): + try: + import onnxruntime as ort + except ImportError: + print("onnxruntime not installed – skipping parity check") + return True + + if device.type == "cuda": + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + sess = ort.InferenceSession(str(onnx_path), providers=providers) + + with torch.inference_mode(): + ref = wrapper(audio_codes).detach().cpu().float().numpy() + ort_out = sess.run(None, {"audio_codes": audio_codes.cpu().numpy()})[0] + max_diff = float(np.abs(ref - ort_out).max()) + ok = max_diff <= atol + print( + f"ONNX parity ({sess.get_providers()[0]}): " + 
f"max_abs_diff={max_diff:.6f} atol={atol} " + f"{'PASSED' if ok else 'FAILED'}" + ) + return ok + + +def parse_args(): + p = argparse.ArgumentParser(description="Export Qwen3-TTS 12Hz codec decoder to ONNX") + p.add_argument("--tokenizer-path", default="Qwen/Qwen3-TTS-Tokenizer-12Hz") + p.add_argument("--onnx-path", default="codec.onnx") + p.add_argument("--frames", type=int, default=30) + p.add_argument("--batch-size", type=int, default=1) + p.add_argument("--opset", type=int, default=18) + p.add_argument("--device", default="cpu", choices=["cpu", "cuda"]) + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + tokenizer = Qwen3TTSTokenizer.from_pretrained( + args.tokenizer_path, + device_map=args.device, + dtype=torch.float32, + attn_implementation="eager", + ) + decoder = tokenizer.model.decoder + wrapper = CodecDecoderWrapper(decoder).to(device).eval() + + nq = int(decoder.config.num_quantizers) + dummy = torch.randint( + 0, + int(decoder.config.codebook_size), + (args.batch_size, args.frames, nq), + dtype=torch.long, + device=device, + ) + + onnx_path = Path(args.onnx_path) + onnx_path.parent.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + torch.onnx.export( + wrapper, + (dummy,), + str(onnx_path), + dynamo=False, + export_params=True, + opset_version=args.opset, + do_constant_folding=True, + input_names=["audio_codes"], + output_names=["audio_values"], + dynamic_axes={ + "audio_codes": {0: "batch"}, + "audio_values": {0: "batch"}, + }, + ) + print(f"ONNX exported to {onnx_path}") + + onnx.checker.check_model(str(onnx_path)) + + ok = check_onnx_parity(wrapper, onnx_path, dummy, device) + if not ok: + raise RuntimeError("ONNX vs PyTorch parity failed — export is broken.") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_trt.py b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_trt.py new file 
mode 100644 index 00000000000..2a71a185515 --- /dev/null +++ b/examples/online_serving/text_to_speech/qwen3_tts_nv/scripts/export_codec_trt.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +"""Build a TensorRT engine from the Qwen3-TTS codec decoder ONNX.""" + +import argparse +import shutil +import subprocess +from pathlib import Path + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +from onnx import shape_inference + + +def _make_runtime_zero(np_type, base_name): + """0-D `zero = seed - seed` (seed=1). Hidden from the constant-folder.""" + seed = gs.Constant(name=f"{base_name}_seed", values=np.array(1, dtype=np_type)) + zero = gs.Variable(name=f"{base_name}_zero", dtype=np_type, shape=()) + sub = gs.Node(op="Sub", inputs=[seed, seed], outputs=[zero], name=f"{base_name}_zero_sub") + return sub, zero + + +def _make_add_zero_barrier(tensor, zero_var, name): + new_tensor = gs.Variable( + name=f"{tensor.name}__{name}", + dtype=tensor.dtype, + shape=tensor.shape, + ) + add = gs.Node(op="Add", inputs=[tensor, zero_var], outputs=[new_tensor], name=name) + return add, new_tensor + + +def apply_trt_fusion_barrier( + onnx_path, + target_tensor_name="/decoder/Transpose_19_output_0", +): + """Wrap the post-transformer permute with `Add(x, runtime_zero)` barriers. + + Works around a TRT 10.15 fused-tactic bug that produces silently wrong + audio at dynamic batch > 1. Patches the ONNX file in place. 
+ """ + model = onnx.load(str(onnx_path)) + try: + model = shape_inference.infer_shapes(model) + except Exception as exc: # noqa: BLE001 + print(f" [warn] onnx shape_inference failed ({exc})") + graph = gs.import_onnx(model) + + tp_node = None + for node in graph.nodes: + for out in node.outputs: + if out.name == target_tensor_name: + tp_node = node + break + if tp_node is not None: + break + if tp_node is None: + raise RuntimeError(f"target tensor {target_tensor_name!r} not found in graph") + if tp_node.op != "Transpose": + print(f" [warn] producer of {target_tensor_name!r} is {tp_node.op!r}, expected Transpose; proceeding anyway") + + in_tensor = tp_node.inputs[0] + out_tensor = tp_node.outputs[0] + if in_tensor.dtype is None: + raise RuntimeError(f"cannot insert barrier on {in_tensor.name!r}: dtype unknown") + np_type = np.dtype(in_tensor.dtype).type + safe_name = tp_node.name.lstrip("/").replace("/", "_") or "Transpose_19" + + zero_sub, zero_var = _make_runtime_zero(np_type, base_name=f"FusionBarrier_{safe_name}") + + pre_add, pre_out = _make_add_zero_barrier(in_tensor, zero_var, name=f"FusionBarrier_pre_{safe_name}") + tp_node.inputs[0] = pre_out + + post_add, post_out = _make_add_zero_barrier(out_tensor, zero_var, name=f"FusionBarrier_post_{safe_name}") + for node in graph.nodes: + if node is post_add: + continue + for i, inp in enumerate(node.inputs): + if inp is out_tensor: + node.inputs[i] = post_out + for i, outp in enumerate(graph.outputs): + if outp is out_tensor: + graph.outputs[i] = post_out + + graph.nodes.extend([zero_sub, pre_add, post_add]) + graph.cleanup().toposort() + + onnx.save(gs.export_onnx(graph), str(onnx_path)) + print(f" wrapped {tp_node.name!r} with Add(x, runtime_zero) barriers") + + +def _infer_num_quantizers(onnx_path): + model = onnx.load(str(onnx_path)) + for inp in model.graph.input: + if inp.name != "audio_codes": + continue + dims = inp.type.tensor_type.shape.dim + if len(dims) >= 3 and dims[2].dim_value > 0: + return 
int(dims[2].dim_value) + raise RuntimeError( + f"could not infer num_quantizers from {onnx_path} (audio_codes dim 2 is not a static positive integer)" + ) + + +def convert_to_trt(onnx_path, trt_path, trtexec_bin, nq, batch_prof, frames_prof, fp32): + exe = shutil.which(trtexec_bin) if "/" not in trtexec_bin else trtexec_bin + if exe is None: + raise FileNotFoundError(f"trtexec not found: {trtexec_bin}") + trt_path.parent.mkdir(parents=True, exist_ok=True) + + def s(b, f): + return f"{b}x{f}x{nq}" + + cmd = [ + exe, + f"--onnx={onnx_path}", + f"--saveEngine={trt_path}", + f"--minShapes=audio_codes:{s(batch_prof[0], frames_prof[0])}", + f"--optShapes=audio_codes:{s(batch_prof[1], frames_prof[1])}", + f"--maxShapes=audio_codes:{s(batch_prof[2], frames_prof[2])}", + ] + if not fp32: + cmd.append("--fp16") + print("Running:", " ".join(cmd)) + subprocess.run(cmd, check=True) + print(f"TensorRT engine saved to {trt_path}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Build a TensorRT engine from the Qwen3-TTS codec ONNX") + p.add_argument("--onnx-path", required=True) + p.add_argument("--trt-path", required=True) + p.add_argument("--trtexec-bin", default="/usr/src/tensorrt/bin/trtexec") + p.add_argument("--batch-profile", nargs=3, type=int, default=[1, 8, 32], metavar=("MIN", "OPT", "MAX")) + p.add_argument("--frames-profile", nargs=3, type=int, default=[30, 30, 30], metavar=("MIN", "OPT", "MAX")) + p.add_argument("--fp32", action="store_true", help="Build pure FP32 engine (default: FP16).") + p.add_argument( + "--no-fusion-barrier", + action="store_true", + help="Skip the TRT-10.15 fusion-barrier ONNX patch (without it the engine is wrong at dynamic batch > 1).", + ) + return p.parse_args() + + +def main(): + args = parse_args() + onnx_path = Path(args.onnx_path) + trt_path = Path(args.trt_path) + + if not onnx_path.is_file(): + raise FileNotFoundError(f"ONNX not found: {onnx_path}") + + if not args.no_fusion_barrier: + 
apply_trt_fusion_barrier(onnx_path) + + nq = _infer_num_quantizers(onnx_path) + print(f"num_quantizers={nq} (from {onnx_path})") + + convert_to_trt( + onnx_path, + trt_path, + args.trtexec_bin, + nq=nq, + batch_prof=tuple(args.batch_profile), + frames_prof=tuple(args.frames_profile), + fp32=args.fp32, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/model_executor/models/qwen3_tts_nv/__init__.py b/tests/model_executor/models/qwen3_tts_nv/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/model_executor/models/qwen3_tts_nv/test_qwen3_tts_talker_nv.py b/tests/model_executor/models/qwen3_tts_nv/test_qwen3_tts_talker_nv.py new file mode 100644 index 00000000000..d8af118854d --- /dev/null +++ b/tests/model_executor/models/qwen3_tts_nv/test_qwen3_tts_talker_nv.py @@ -0,0 +1,586 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +"""High-level unit tests for ``Qwen3TTSTalkerForConditionalGenerationNv``. + +The full model has heavy dependencies (``Qwen3Model``, ``ParallelLMHead``, +``VocabParallelEmbedding``, etc.) that require a distributed init, so these +tests construct the instance via ``object.__new__`` and inject only the +attributes that ``forward`` / ``compute_logits`` / ``make_omni_output`` / +``postprocess`` actually read. + +The interesting behavior under test is the per-step dispatch in +:meth:`Qwen3TTSTalkerForConditionalGenerationNv.forward`: + +* **Decode-only batch** — ``_get_decode_idxs`` returns ``(None, 0)`` and the + code predictor runs on every token. +* **Mixed prefill + decode batch** — only decode positions are routed + through the code predictor; prefill positions keep the prefill embedding + produced by ``preprocess``. +* **All-prefill batch** — code predictor is skipped entirely. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.qwen3_tts_nv import qwen3_tts_talker_nv as nv +from vllm_omni.model_executor.models.qwen3_tts_nv.qwen3_tts_talker_nv import ( + Qwen3TTSTalkerForConditionalGenerationNv, + _dict_to_namespace, + _get_talker_config, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +HIDDEN = 8 +NUM_CODE_GROUPS = 4 +VOCAB_SIZE = 16 +MAX_NUM_TOKENS = 16 + + +# ────────────────────────────────────────────────────────────────────── +# Helpers: build a minimal ``Qwen3TTSTalkerForConditionalGenerationNv`` +# without running the real ``__init__`` (avoids distributed init). +# ────────────────────────────────────────────────────────────────────── + + +class _FakeCodePredictor(nn.Module): + """Stand-in for ``self.code_predictor`` exposing the surface used by forward.""" + + def __init__(self) -> None: + super().__init__() + self.num_code_groups = NUM_CODE_GROUPS + # Per-group embedding tables for groups 1..N-1. + self._group_embeddings = nn.ModuleList([nn.Embedding(VOCAB_SIZE, HIDDEN) for _ in range(NUM_CODE_GROUPS - 1)]) + # Group-0 codec embedding. + self.codec_embedding = nn.Embedding(VOCAB_SIZE, HIDDEN) + self.generate_calls: list[dict[str, torch.Tensor]] = [] + + def get_group_embeddings(self) -> nn.ModuleList: + return self._group_embeddings + + def generate_groups_1_15( + self, + prev_hidden: torch.Tensor, + group0_tokens: torch.Tensor, + ) -> torch.Tensor: + """Record inputs and return a deterministic [seq_len, N-1] code tensor.""" + self.generate_calls.append( + { + "prev_hidden": prev_hidden.detach().clone(), + "group0_tokens": group0_tokens.detach().clone(), + } + ) + seq_len = group0_tokens.shape[0] + # Deterministic codes derived from group0 so we can assert later. 
+ codes = group0_tokens.view(-1, 1).expand(seq_len, NUM_CODE_GROUPS - 1) % VOCAB_SIZE + return codes.contiguous() + + +class _FakeBackbone(nn.Module): + """Stand-in for ``self.model``: returns the input embeds directly.""" + + def __init__(self) -> None: + super().__init__() + self.last_call: dict[str, torch.Tensor] | None = None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + self.last_call = { + "input_ids": input_ids.detach().clone(), + "inputs_embeds": inputs_embeds.detach().clone(), + } + # Returning the embeds preserves whatever forward built into the buffer. + return inputs_embeds.clone() + + +def _make_fake_attn_metadata(query_lens: list[int], device: torch.device = torch.device("cpu")) -> SimpleNamespace: + """Build a fake attn_metadata mimicking the runner's contract.""" + start_loc = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(query_lens), 0).tolist()), + dtype=torch.long, + device=device, + ) + return SimpleNamespace( + max_query_len=int(max(query_lens)), + query_start_loc=start_loc, + ) + + +def _make_fake_forward_context(attn_metadata) -> SimpleNamespace: + return SimpleNamespace( + attn_metadata=attn_metadata, + batch_descriptor=None, + ) + + +def _make_talker_instance() -> Qwen3TTSTalkerForConditionalGenerationNv: + """Construct a Talker without running ``__init__`` and inject the + attributes that the methods under test read.""" + model = object.__new__(Qwen3TTSTalkerForConditionalGenerationNv) + nn.Module.__init__(model) + + # Persistent scratch buffers. 
+ model._combined_embeddings = torch.zeros(MAX_NUM_TOKENS, HIDDEN) + model._out_codes = torch.zeros(MAX_NUM_TOKENS, NUM_CODE_GROUPS, dtype=torch.long) + model._prev_hidden_buffer = torch.zeros(MAX_NUM_TOKENS, HIDDEN) + # tts_pad text embedding (a fixed constant, populated from weights at + # load time; here we set a recognisable bias so we can verify it lands + # in the assembled decode embedding). + model._tts_pad_text_embed = torch.full((1, HIDDEN), 0.5) + + # Submodules. + model.code_predictor = _FakeCodePredictor() + model.model = _FakeBackbone() + + # vllm_config: only ``compilation_config.cudagraph_mode`` is read by + # ``_get_decode_idxs``. We set NONE so no padding kicks in; tests for + # padding cover that branch separately. + from vllm.config import CUDAGraphMode + + model.vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + cudagraph_mode=CUDAGraphMode.NONE, + cudagraph_capture_sizes=[], + ) + ) + + # ``compute_logits`` reads these. + model.codec_head = nn.Linear(HIDDEN, VOCAB_SIZE, bias=False) + model.suppress_mask = nn.Parameter(torch.zeros(VOCAB_SIZE, dtype=torch.bool), requires_grad=False) + model.logits_processor = lambda head, hs: head(hs) + + return model + + +# ────────────────────────────────────────────────────────────────────── +# Forward dispatch: decode-only vs mixed vs all-prefill +# ────────────────────────────────────────────────────────────────────── + + +def test_get_decode_idxs_returns_none_when_no_attn_metadata(monkeypatch): + """Profile / dummy run: code predictor must run on every position.""" + model = _make_talker_instance() + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(None)) + + decode_idx, num_req = model._get_decode_idxs() + + assert decode_idx is None + assert num_req == 0 + + +def test_get_decode_idxs_returns_none_for_decode_only_batch(monkeypatch): + """Decode-only batch (``max_query_len == 1``): apply everywhere.""" + model = _make_talker_instance() + attn_md = 
_make_fake_attn_metadata([1, 1, 1, 1]) + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(attn_md)) + + decode_idx, num_req = model._get_decode_idxs() + + assert decode_idx is None + assert num_req == 0 + + +def test_get_decode_idxs_picks_decode_indices_in_mixed_batch(monkeypatch): + """Mixed batch: decode tokens are at positions 0 (req#0=1 tok) and 5 + (req#2=1 tok). Req#1 is prefill (4 tokens at positions 1..4).""" + model = _make_talker_instance() + attn_md = _make_fake_attn_metadata([1, 4, 1]) + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(attn_md)) + + decode_idx, num_req = model._get_decode_idxs() + + assert num_req == 2 + assert decode_idx.tolist() == [0, 5] + + +def test_get_decode_idxs_returns_empty_for_all_prefill_batch(monkeypatch): + """All-prefill batch (no req with query_len == 1).""" + model = _make_talker_instance() + attn_md = _make_fake_attn_metadata([3, 4]) + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(attn_md)) + + decode_idx, num_req = model._get_decode_idxs() + + assert num_req == 0 + assert decode_idx.numel() == 0 + + +def test_forward_decode_only_runs_code_predictor_everywhere(monkeypatch): + """When ``decode_idx`` is None, the code predictor must be called once + on the full batch and the assembled decode embedding (codec_emb + + tts_pad + sum(group_embs)) must be written to every position. 
+ """ + model = _make_talker_instance() + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(None)) + + num_tokens = 3 + input_ids = torch.tensor([1, 2, 3], dtype=torch.long) + positions = torch.arange(num_tokens, dtype=torch.long) + inputs_embeds = torch.zeros(num_tokens, HIDDEN) + prev_hidden_slot = torch.randn(num_tokens, HIDDEN) + model._prev_hidden_buffer[:num_tokens].copy_(prev_hidden_slot) + + out = model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=inputs_embeds, + ) + + # Code predictor was called on the full batch. + assert len(model.code_predictor.generate_calls) == 1 + call = model.code_predictor.generate_calls[0] + torch.testing.assert_close(call["group0_tokens"], input_ids) + torch.testing.assert_close(call["prev_hidden"], prev_hidden_slot) + + # Output codes: column 0 == input_ids, columns 1..N-1 == fake codes. + expected_codes_1_15 = input_ids.view(-1, 1).expand(num_tokens, NUM_CODE_GROUPS - 1) % VOCAB_SIZE + torch.testing.assert_close(model._out_codes[:num_tokens, 0], input_ids) + torch.testing.assert_close(model._out_codes[:num_tokens, 1:], expected_codes_1_15) + + # Backbone was fed an embedding equal to the analytical decode assembly. + cp = model.code_predictor + expected_emb = cp.codec_embedding(input_ids) + model._tts_pad_text_embed + for i, emb in enumerate(cp.get_group_embeddings()): + expected_emb = expected_emb + emb(expected_codes_1_15[:, i]) + + assert model.model.last_call is not None + torch.testing.assert_close(model.model.last_call["inputs_embeds"], expected_emb) + # And forward returned those hidden states (FakeBackbone is identity). 
+ torch.testing.assert_close(out, expected_emb) + + +def test_forward_mixed_batch_only_runs_code_predictor_on_decode(monkeypatch): + """Mixed batch: code predictor runs on decode positions only and the + prefill positions in ``inputs_embeds`` flow through unchanged.""" + model = _make_talker_instance() + + # 3 reqs: decode (1 tok), prefill (3 tok), decode (1 tok). Decode positions + # in the flat batch are 0 and 4. + query_lens = [1, 3, 1] + num_tokens = sum(query_lens) + attn_md = _make_fake_attn_metadata(query_lens) + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(attn_md)) + + input_ids = torch.tensor([7, 0, 0, 0, 5], dtype=torch.long) + positions = torch.arange(num_tokens, dtype=torch.long) + # Distinct prefill marker so we can assert it survives at prefill positions. + prefill_marker = torch.full((HIDDEN,), -3.0) + inputs_embeds = torch.zeros(num_tokens, HIDDEN) + inputs_embeds[1:4] = prefill_marker + # prev_hidden values for the decode slots. + prev_hidden = torch.zeros(num_tokens, HIDDEN) + prev_hidden[0] = torch.full((HIDDEN,), 0.7) + prev_hidden[4] = torch.full((HIDDEN,), 0.9) + model._prev_hidden_buffer[:num_tokens].copy_(prev_hidden) + + model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=inputs_embeds, + ) + + # Code predictor called exactly once on the decode slice [0, 4]. + assert len(model.code_predictor.generate_calls) == 1 + call = model.code_predictor.generate_calls[0] + torch.testing.assert_close(call["group0_tokens"], torch.tensor([7, 5], dtype=torch.long)) + torch.testing.assert_close(call["prev_hidden"], prev_hidden[[0, 4]]) + + # ``_out_codes`` only has groups 1..N-1 written at the decode rows. + decode_rows = model._out_codes[[0, 4], 1:] + expected_decode_codes = torch.tensor([[7, 7, 7], [5, 5, 5]], dtype=torch.long) + torch.testing.assert_close(decode_rows, expected_decode_codes) + + # Prefill rows for groups 1..N-1 must remain untouched (zero). 
+ torch.testing.assert_close( + model._out_codes[1:4, 1:], + torch.zeros((3, NUM_CODE_GROUPS - 1), dtype=torch.long), + ) + + # input_ids fully written into column 0. + torch.testing.assert_close(model._out_codes[:num_tokens, 0], input_ids) + + # Backbone embeddings: prefill rows preserved, decode rows replaced by + # the assembled decode embedding. + fed = model.model.last_call["inputs_embeds"] + torch.testing.assert_close(fed[1], prefill_marker) + torch.testing.assert_close(fed[2], prefill_marker) + torch.testing.assert_close(fed[3], prefill_marker) + + cp = model.code_predictor + decode_ids = torch.tensor([7, 5], dtype=torch.long) + expected_decode_emb = cp.codec_embedding(decode_ids) + model._tts_pad_text_embed + for i, emb in enumerate(cp.get_group_embeddings()): + expected_decode_emb = expected_decode_emb + emb(expected_decode_codes[:, i]) + torch.testing.assert_close(fed[[0, 4]], expected_decode_emb) + + +def test_forward_all_prefill_skips_code_predictor(monkeypatch): + """All-prefill batch: code predictor must not be called at all.""" + model = _make_talker_instance() + query_lens = [3, 4] + num_tokens = sum(query_lens) + attn_md = _make_fake_attn_metadata(query_lens) + monkeypatch.setattr(nv, "get_forward_context", lambda: _make_fake_forward_context(attn_md)) + + input_ids = torch.zeros(num_tokens, dtype=torch.long) + inputs_embeds = torch.randn(num_tokens, HIDDEN) + expected_passthrough = inputs_embeds.clone() + + model.forward( + input_ids=input_ids, + positions=torch.arange(num_tokens, dtype=torch.long), + intermediate_tensors=None, + inputs_embeds=inputs_embeds, + ) + + # No code predictor call. + assert model.code_predictor.generate_calls == [] + # Backbone was fed exactly the prefill embeddings. + torch.testing.assert_close(model.model.last_call["inputs_embeds"], expected_passthrough) + # Groups 1..N-1 must remain zero (they're never produced for prefill). 
+ torch.testing.assert_close( + model._out_codes[:num_tokens, 1:], + torch.zeros((num_tokens, NUM_CODE_GROUPS - 1), dtype=torch.long), + ) + + +# ────────────────────────────────────────────────────────────────────── +# make_omni_output / postprocess / compute_logits +# ────────────────────────────────────────────────────────────────────── + + +def test_make_omni_output_wraps_hidden_and_codes(): + model = _make_talker_instance() + + num_tokens = 5 + hidden = torch.randn(num_tokens, HIDDEN) + # Pre-populate _out_codes with a recognisable pattern. + model._out_codes[:num_tokens] = torch.arange(num_tokens * NUM_CODE_GROUPS, dtype=torch.long).view( + num_tokens, NUM_CODE_GROUPS + ) + + out = model.make_omni_output(hidden) + + assert isinstance(out, OmniOutput) + torch.testing.assert_close(out.text_hidden_states, hidden) + assert out.multimodal_outputs is not None + audio_codes = out.multimodal_outputs["audio_codes"] + assert audio_codes.shape == (num_tokens, NUM_CODE_GROUPS) + torch.testing.assert_close(audio_codes, model._out_codes[:num_tokens]) + + +def test_make_omni_output_passes_through_existing_omni_output(): + model = _make_talker_instance() + existing = OmniOutput( + text_hidden_states=torch.zeros(1, HIDDEN), + multimodal_outputs={"audio_codes": torch.zeros(1, NUM_CODE_GROUPS)}, + ) + assert model.make_omni_output(existing) is existing + + +def test_postprocess_returns_last_hidden(): + model = _make_talker_instance() + hidden = torch.arange(3 * HIDDEN, dtype=torch.float32).view(3, HIDDEN) + out = model.postprocess(hidden) + assert "last_talker_hidden" in out + torch.testing.assert_close(out["last_talker_hidden"], hidden[-1]) + + +def test_postprocess_empty_hidden_returns_empty_dict(): + model = _make_talker_instance() + assert model.postprocess(torch.empty(0, HIDDEN)) == {} + + +def test_compute_logits_applies_suppress_mask(): + model = _make_talker_instance() + # Set deterministic codec_head weights so we know what logits to expect. 
+ with torch.no_grad(): + model.codec_head.weight.copy_(torch.eye(VOCAB_SIZE, HIDDEN)[:VOCAB_SIZE, :HIDDEN]) + # Suppress two tokens. + mask = torch.zeros(VOCAB_SIZE, dtype=torch.bool) + mask[3] = True + mask[7] = True + model.suppress_mask.data.copy_(mask) + + hidden = torch.randn(2, HIDDEN) + logits = model.compute_logits(hidden) + + assert logits.shape == (2, VOCAB_SIZE) + assert torch.isinf(logits[:, 3]).all() and (logits[:, 3] < 0).all() + assert torch.isinf(logits[:, 7]).all() and (logits[:, 7] < 0).all() + # Other entries are finite. + finite_cols = [i for i in range(VOCAB_SIZE) if i not in (3, 7)] + assert torch.isfinite(logits[:, finite_cols]).all() + + +def test_compute_logits_unwraps_omni_output(): + model = _make_talker_instance() + hidden = torch.randn(1, HIDDEN) + wrapped = OmniOutput(text_hidden_states=hidden) + direct = model.compute_logits(hidden) + via_omni = model.compute_logits(wrapped) + torch.testing.assert_close(via_omni, direct) + + +# ────────────────────────────────────────────────────────────────────── +# Static helpers +# ────────────────────────────────────────────────────────────────────── + + +def test_first_str_handles_lists_scalars_and_none(): + f = Qwen3TTSTalkerForConditionalGenerationNv._first_str + assert f(["hello", "ignored"]) == "hello" + assert f([]) == "" + assert f("plain") == "plain" + assert f(None) == "" + assert f(42) == "42" + + +def test_build_assistant_text_layout(): + text = Qwen3TTSTalkerForConditionalGenerationNv._build_assistant_text("hi") + assert text == "<|im_start|>assistant\nhi<|im_end|>\n<|im_start|>assistant\n" + + +def _make_tokenizer(token_count: int): + """Return a fake tokenizer that returns ``token_count`` ints regardless of input.""" + return lambda s: [0] * token_count + + +def test_estimate_prompt_len_rejects_non_customvoice(): + with pytest.raises(ValueError, match="CustomVoice"): + Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hi", 
"speaker": "alice"}, + task_type="OTHER", + tokenize_prompt=_make_tokenizer(10), + codec_language_id=None, + ) + + +def test_estimate_prompt_len_requires_text_and_speaker(): + with pytest.raises(ValueError, match="text"): + Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"speaker": "alice"}, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(10), + codec_language_id=None, + ) + with pytest.raises(ValueError, match="speaker"): + Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hello"}, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(10), + codec_language_id=None, + ) + + +def test_estimate_prompt_len_no_language_id_uses_prefill_3(): + """No language_id -> prefill_len=3, total = 3 + assistant_len - 1.""" + assistant_len = 12 + out = Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hi", "speaker": "alice", "language": "Auto"}, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(assistant_len), + codec_language_id={"english": 1}, + spk_is_dialect=None, + ) + assert out == 3 + assistant_len - 1 + + +def test_estimate_prompt_len_with_language_id_uses_prefill_4(): + """Resolved language_id -> prefill_len=4.""" + assistant_len = 12 + out = Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hi", "speaker": "alice", "language": "English"}, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(assistant_len), + codec_language_id={"english": 1}, + ) + assert out == 4 + assistant_len - 1 + + +def test_estimate_prompt_len_dialect_fallback_promotes_to_4(): + """Auto language + speaker registered as a dialect resolves a language_id.""" + assistant_len = 12 + out = Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hi", "speaker": "shanghainese_voice", "language": "Auto"}, + 
task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(assistant_len), + codec_language_id={"shanghainese": 7}, + spk_is_dialect={"shanghainese_voice": "shanghainese"}, + ) + assert out == 4 + assistant_len - 1 + + +def test_estimate_prompt_len_unwraps_list_values(): + assistant_len = 12 + out = Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + { + "text": ["hi"], + "speaker": ["alice"], + "language": ["Auto"], + }, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(assistant_len), + codec_language_id={"english": 1}, + ) + assert out == 3 + assistant_len - 1 + + +def test_estimate_prompt_len_short_assistant_raises(): + with pytest.raises(ValueError, match="assistant prompt length"): + Qwen3TTSTalkerForConditionalGenerationNv.estimate_prompt_len_from_additional_information( + {"text": "hi", "speaker": "alice"}, + task_type="CustomVoice", + tokenize_prompt=_make_tokenizer(5), + codec_language_id=None, + ) + + +# ────────────────────────────────────────────────────────────────────── +# Internal helpers _dict_to_namespace / _get_talker_config +# ────────────────────────────────────────────────────────────────────── + + +def test_dict_to_namespace_recurses_but_keeps_rope_scaling_as_dict(): + src = { + "hidden_size": 32, + "rope_scaling": {"rope_type": "yarn", "factor": 4.0}, + "nested": {"a": 1}, + } + ns = _dict_to_namespace(src) + assert ns.hidden_size == 32 + # rope_scaling is preserved as a dict (downstream expects dict-like). + assert isinstance(ns.rope_scaling, dict) + assert ns.rope_scaling == {"rope_type": "yarn", "factor": 4.0} + # Other nested dicts get converted. 
+ assert ns.nested.a == 1 + + +def test_get_talker_config_with_full_config_returns_talker_field(): + talker_dict = {"hidden_size": 32, "vocab_size": 16} + full = SimpleNamespace(talker_config=talker_dict) + out = _get_talker_config(full) + assert out.hidden_size == 32 + assert out.vocab_size == 16 + + +def test_get_talker_config_with_already_talker_config_returns_unchanged(): + talker_cfg = SimpleNamespace(hidden_size=32) + out = _get_talker_config(talker_cfg) + assert out is talker_cfg diff --git a/vllm_omni/model_executor/models/qwen3_tts_nv/__init__.py b/vllm_omni/model_executor/models/qwen3_tts_nv/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_omni/model_executor/models/qwen3_tts_nv/qwen3_tts_talker_nv.py b/vllm_omni/model_executor/models/qwen3_tts_nv/qwen3_tts_talker_nv.py new file mode 100644 index 00000000000..b6de04417a1 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts_nv/qwen3_tts_talker_nv.py @@ -0,0 +1,1520 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# Copyright 2026 The Qwen team, Alibaba Group. +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3TTS Talker model compatible with HuggingFace weights.""" + +import bisect +from collections.abc import Callable, Iterable, Mapping +from types import SimpleNamespace +from typing import Any + +import torch +from torch import nn +from transformers import AutoTokenizer, PretrainedConfig +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.models.qwen3 import Qwen3Model +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +logger = init_logger(__name__) + + +# ── RoPE helpers for the native code predictor ────────────────────── + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input (standard RoPE helper).""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply standard 1-D rotary position embeddings to Q and K. 
+ + Args: + q, k: [batch, num_heads, seq_len, head_dim] + cos, sin: [1, 1, seq_len, head_dim] (broadcastable) + """ + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3TTSNativeRotaryEmbedding(nn.Module): + """Simple 1-D rotary position embedding for the native code predictor. + + Matches the ``Qwen3TTSRotaryEmbedding`` in the original HF code, but + simplified: no dynamic-rope, no MRoPE – just standard RoPE with a + configurable ``rope_theta``. + """ + + def __init__(self, head_dim: int, rope_theta: float = 1_000_000.0) -> None: + super().__init__() + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + # Use nn.Parameter so vLLM natively handles device/dtype casting. + # requires_grad=False because this is deterministic and not trained. + # The weight-loader already skips "rotary_emb.inv_freq". + self.inv_freq = nn.Parameter(inv_freq, requires_grad=False) + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + """Return ``(cos, sin)`` tensors for positions ``[0 .. seq_len)``. + + Returns: + cos: [1, 1, seq_len, head_dim] + sin: [1, 1, seq_len, head_dim] + """ + positions = torch.arange(seq_len, device=device, dtype=torch.float32) + # [seq_len] x [head_dim/2] → [seq_len, head_dim/2] + freqs = torch.outer(positions, self.inv_freq.to(device)) + emb = torch.cat([freqs, freqs], dim=-1) # [seq_len, head_dim] + cos = emb.cos().unsqueeze(0).unsqueeze(0).to(dtype) + sin = emb.sin().unsqueeze(0).unsqueeze(0).to(dtype) + return cos, sin + + +def _gumbel_sample(logits: torch.Tensor) -> torch.Tensor: + """Gumbel-max trick: equivalent to categorical sampling. + + Uses only uniform RNG + log + argmax — all CUDA-graph safe. 
+ Unlike ``torch.multinomial``, this degrades gracefully on degenerate + inputs (all-zero probs / all-``-inf`` logits) instead of triggering + a device-side assert that poisons the CUDA context. Also ~2.5x + faster than multinomial in graph replay benchmarks. + """ + u = torch.empty_like(logits).uniform_(1e-20, 1.0 - 1e-20) + return (logits - torch.log(-torch.log(u))).argmax(dim=-1) + + +def _multinomial_sample(logits: torch.Tensor) -> torch.Tensor: + """Standard softmax + multinomial sampling. + + CUDA-graph capturable on PyTorch >= 2.8, but will crash with a + device-side assert if any row has all-zero probabilities (e.g. + during graph warmup with uninitialised buffers). + """ + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, 1).squeeze(-1) + + +def _sample_from_logits( + logits: torch.Tensor, + do_sample: bool = True, + temperature: float = 1.0, + top_k: int | None = None, + top_p: float | None = None, + repetition_penalty: float = 1.0, + previous_tokens: torch.Tensor | None = None, + use_gumbel: bool = True, +) -> torch.Tensor: + """Sample tokens from logits (CUDA-graph safe). + + All operations are legal inside ``torch.cuda.graph()`` capture on + PyTorch >= 2.8 (``topk``, ``sort``, ``multinomial``, ``uniform_``, + ``argmax``, ``gather``, ``scatter_``, ``masked_fill``). + + The only patterns that remain **unsafe** during capture are + host-to-device copies such as ``torch.tensor(scalar, device=cuda)`` + and ``torch.full_like(t, val)`` for some values — use + ``masked_fill`` or pre-allocated buffers instead. + + Args: + use_gumbel: If ``True`` (default), use the Gumbel-max trick for + the final categorical draw. Gumbel-max is ~2.5x faster + than ``multinomial`` and robust to degenerate warmup data. + Set ``False`` to use ``softmax → multinomial`` instead. 
+ """ + if repetition_penalty != 1.0 and previous_tokens is not None: + score = torch.gather(logits, -1, previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(-1, previous_tokens, score) + + if not do_sample: + return logits.argmax(dim=-1) + + logits = logits / max(temperature, 1e-6) + + # ── Top-k filtering ───────────────────────────────────────────── + if top_k is not None and top_k > 0: + vals, idxs = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1) + + # ── Top-p (nucleus) within the top-k slice ────────────────── + if top_p is not None and 0.0 < top_p < 1.0: + sorted_vals, sort_idx = torch.sort(vals, dim=-1, descending=True) + probs = torch.softmax(sorted_vals, dim=-1) + cum_probs = torch.cumsum(probs, dim=-1) + remove = (cum_probs - probs) > top_p + sorted_vals = sorted_vals.masked_fill(remove, -1e10) + # Unsort back to topk order + unsort_idx = sort_idx.argsort(dim=-1) + vals = sorted_vals.gather(-1, unsort_idx) + + sampled_in_k = _gumbel_sample(vals) if use_gumbel else _multinomial_sample(vals) + return idxs.gather(-1, sampled_in_k.unsqueeze(-1)).squeeze(-1) + + # ── Top-p only (no top-k) ─────────────────────────────────────── + if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) + probs = torch.softmax(sorted_logits, dim=-1) + cum_probs = torch.cumsum(probs, dim=-1) + remove = (cum_probs - probs) > top_p + sorted_logits = sorted_logits.masked_fill(remove, -1e10) + + sampled_sorted = _gumbel_sample(sorted_logits) if use_gumbel else _multinomial_sample(sorted_logits) + return sorted_indices.gather(-1, sampled_sorted.unsqueeze(-1)).squeeze(-1) + + # ── No filtering — sample from full distribution ──────────────── + if use_gumbel: + return _gumbel_sample(logits) + return _multinomial_sample(logits) + + +class Qwen3TTSTalkerResizeMLP(nn.Module): + """Resize MLP for text projection in Qwen3TTS Talker. 
+ + Maps from text_hidden_size to hidden_size with an intermediate layer. + """ + + def __init__( + self, + input_size: int, + intermediate_size: int, + output_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.linear_fc1 = ColumnParallelLinear( + input_size, + intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + ) + self.linear_fc2 = RowParallelLinear( + intermediate_size, + output_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + ) + if hidden_act == "silu": + self.act_fn = nn.SiLU() + elif hidden_act == "gelu": + self.act_fn = nn.GELU() + else: + raise ValueError(f"Unsupported activation: {hidden_act}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.linear_fc1(x) + x = self.act_fn(x) + x, _ = self.linear_fc2(x) + return x + + +class Qwen3TTSNativeAttention(nn.Module): + """Native attention for Qwen3TTS using torch SDPA. + + Used for the code predictor which has deterministic shapes and doesn't + benefit from KV caching. Can be captured in CUDA graphs. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int | None = None, + rms_norm_eps: float = 1e-6, + qkv_bias: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim if head_dim else hidden_size // num_heads + self.num_kv_groups = num_heads // num_kv_heads + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=qkv_bias) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=qkv_bias) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=qkv_bias) + self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=qkv_bias) + + # QK normalization + self.q_norm = nn.RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + """Forward pass using torch SDPA. + + Args: + hidden_states: [batch_size, seq_len, hidden_size] + attention_mask: Optional attention mask + position_embeddings: Optional (cos, sin) tuple from rotary + embedding, each [1, 1, seq_len, head_dim]. 
+ """ + batch_size, seq_len, _ = hidden_states.shape + + # Project Q, K, V + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to [batch, seq, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Apply QK normalization + q = self.q_norm(q) + k = self.k_norm(k) + + # Transpose to [batch, num_heads, seq, head_dim] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Apply rotary position embeddings (standard 1-D RoPE) + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + # Expand KV heads if using GQA + if self.num_kv_groups > 1: + k = k.repeat_interleave(self.num_kv_groups, dim=1) + v = v.repeat_interleave(self.num_kv_groups, dim=1) + + # Apply scaled dot product attention + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + is_causal=attention_mask is None, # Use causal if no mask provided + scale=self.scaling, + ) + + # Reshape back to [batch, seq, hidden] + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, -1) + + output = self.o_proj(attn_output) + return output + + +class Qwen3TTSNativeMLP(nn.Module): + """Native MLP for Qwen3TTS Code Predictor using standard PyTorch layers.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ) -> None: + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return 
self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Qwen3TTSCodePredictorDecoderLayer(nn.Module): + """Native decoder layer for Qwen3TTS Code Predictor. + + Uses native PyTorch attention (SDPA) instead of vLLM attention. + This is more efficient for the code predictor since: + - Shapes are deterministic (fixed 15 steps) + - No KV cache benefit + - Can be captured in CUDA graphs + """ + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3TTSNativeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, "head_dim", None), + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, "attention_bias", False), + ) + + self.mlp = Qwen3TTSNativeMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Self Attention with pre-norm + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, attention_mask, position_embeddings) + hidden_states = residual + hidden_states + + # MLP with pre-norm + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +# Keys whose values must stay as plain dicts (expected by downstream code) +_KEEP_AS_DICT_KEYS = {"rope_scaling"} + + +def _dict_to_namespace(d, _key: str | None = None): + """Recursively convert a dict to 
SimpleNamespace for attribute access. + + Certain keys (e.g. ``rope_scaling``) are kept as plain dicts because + downstream code (``get_rope``, ``"mrope_section" in rope_scaling``, etc.) + expects dict-like objects. + """ + if isinstance(d, dict): + if _key in _KEEP_AS_DICT_KEYS: + return d # keep as plain dict + return SimpleNamespace(**{k: _dict_to_namespace(v, _key=k) for k, v in d.items()}) + return d + + +def _get_talker_config(hf_config: PretrainedConfig): + """Get the talker config from either full TTS config or talker config directly. + + If talker_config is stored as a plain dict (from Qwen3TTSConfig), + convert it to a namespace so attribute access (config.hidden_size etc.) works. + """ + if hasattr(hf_config, "talker_config"): + tc = hf_config.talker_config + if isinstance(tc, dict): + return _dict_to_namespace(tc) + return tc + # Otherwise assume this is already the talker config + return hf_config + + +class Qwen3TTSTalkerCodePredictorModel(nn.Module): + """Native PyTorch code predictor model for Qwen3TTS Talker. 
+ + Uses native attention (SDPA) instead of vLLM attention since: + - Runs for fixed 15 steps per global time step + - Shapes are deterministic + - No benefit from KV caching + - Can be captured in CUDA graphs for efficiency + """ + + def __init__(self, config: PretrainedConfig, embedding_dim: int) -> None: + super().__init__() + + self.config = config + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.num_code_groups = config.num_code_groups + + # Codec embeddings for groups 1 to N-1 (group 0 uses main model embedding) + self.codec_embedding = nn.ModuleList( + [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)] + ) + + # Decoder layers using native attention + self.layers = nn.ModuleList( + [Qwen3TTSCodePredictorDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + # Final layer norm + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Standard 1-D rotary position embeddings (matches HF code predictor). + # Fall back when ``head_dim`` is missing *or* None, mirroring the + # handling in ``Qwen3TTSNativeAttention``. + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.rotary_emb = Qwen3TTSNativeRotaryEmbedding( + head_dim=head_dim, + rope_theta=getattr(config, "rope_theta", 1_000_000.0), + ) + + def get_input_embeddings(self) -> nn.ModuleList: + """Get codec embedding layers for groups 1..N-1.""" + return self.codec_embedding + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass. + + Args: + inputs_embeds: [batch_size, seq_len, hidden_size] + attention_mask: Optional causal mask + + Returns: + hidden_states: [batch_size, seq_len, hidden_size] + """ + hidden_states = inputs_embeds + + # Compute position embeddings shared across all decoder layers. + # Positions are simply [0, 1, ..., seq_len-1] since we + # recompute from scratch each call (no KV cache). 
+ seq_len = hidden_states.shape[1] + position_embeddings = self.rotary_emb(seq_len, hidden_states.device, hidden_states.dtype) + + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask, position_embeddings) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +@support_torch_compile +class Qwen3TTSTalkerCodePredictor(nn.Module): + """Code predictor for Qwen3TTS Talker — groups 1..N-1 only. + + Given the previous step's backbone hidden state and the group-0 token + (sampled by vLLM), autoregressively predicts codec groups 1 through + N-1 using a small native-attention transformer. + + Group-0 prediction (``codec_head``, ``suppress_mask``) is handled by + the outer model's ``compute_logits()`` + vLLM sampler. + + Also owns ``codec_embedding`` (group-0 codebook), shared with the + outer model for input-embedding lookups. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + hf_config = vllm_config.model_config.hf_config + talker_config = _get_talker_config(hf_config) + config = talker_config.code_predictor_config + if isinstance(config, dict): + config = _dict_to_namespace(config) + quant_config = vllm_config.quant_config + + self.config = config + self.num_code_groups = config.num_code_groups + self.hidden_size = config.hidden_size + self.talker_hidden_size = talker_config.hidden_size + + # Group-0 codec embedding (shared with outer model) + self.codec_embedding = VocabParallelEmbedding( + talker_config.vocab_size, + talker_config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.codec_embedding", + ) + + # Code-predictor transformer backbone + self.model = Qwen3TTSTalkerCodePredictorModel(config, self.talker_hidden_size) + + # Projection from talker hidden size to code predictor hidden size + if config.hidden_size != self.talker_hidden_size: + self.small_to_mtp_projection = nn.Linear(self.talker_hidden_size, config.hidden_size, bias=True) + 
else: + self.small_to_mtp_projection = nn.Identity() + + # LM heads for each code group (1 to N-1) + self.lm_head = nn.ModuleList( + [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] + ) + + # Sampling parameters for the internal groups-1..N-1 loop, + # read from code_predictor_config. Fallback defaults match the + # original HF implementation's subtalker_* arguments. + self.do_sample = getattr(config, "do_sample", True) + self.temperature = getattr(config, "temperature", 0.9) + self.top_k = getattr(config, "top_k", 50) + self.top_p = getattr(config, "top_p", 1.0) + self.repetition_penalty = getattr(config, "repetition_penalty", 1.0) + self.use_gumbel = getattr(config, "use_gumbel", True) + + # ── Persistent scratch buffers ────────────────────────────────── + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + N = config.num_code_groups + hidden = talker_config.hidden_size + cp_hidden = config.hidden_size + dtype = vllm_config.model_config.dtype + self._max_cp_len = 1 + N # prev_hidden ctx + group0 + groups 1..N-1 + + self._cp_inputs_embeds = torch.zeros(max_num_tokens, self._max_cp_len, hidden, dtype=dtype) + self._cp_hidden_states = torch.empty(max_num_tokens, self._max_cp_len, cp_hidden, dtype=dtype) + # Only groups 1..N-1 (N-1 columns) + self._cp_all_codecs = torch.empty(max_num_tokens, N - 1, dtype=torch.long) + + def get_group0_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Look up group-0 codec embeddings.""" + return self.codec_embedding(input_ids) + + def get_group_embeddings(self) -> nn.ModuleList: + """Get codec embedding layers for groups 1..N-1.""" + return self.model.get_input_embeddings() + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + """Forward pass through the code predictor transformer.""" + inputs_embeds = self.small_to_mtp_projection(inputs_embeds) + hidden_states = self.model(inputs_embeds) + return hidden_states + + def 
_compute_inner_logits( + self, + hidden_states: torch.Tensor, + generation_step: int, + ) -> torch.Tensor: + """Compute logits for a specific inner code group (1..N-1).""" + if generation_step >= len(self.lm_head): + raise ValueError(f"generation_step {generation_step} exceeds number of code groups {len(self.lm_head)}") + return self.lm_head[generation_step](hidden_states) + + def generate_groups_1_15( + self, + prev_hidden: torch.Tensor, + group0_tokens: torch.Tensor, + ) -> torch.Tensor: + """Generate codec groups 1..N-1 given previous hidden state and group0. + + Args: + prev_hidden: [seq_len, hidden_size] backbone output from previous step + group0_tokens: [seq_len] group-0 tokens (from vLLM sampling) + + Returns: + codes_1_15: [seq_len, num_code_groups - 1] + """ + seq_len = prev_hidden.shape[0] + N = self.num_code_groups + + inputs_embeds = self._cp_inputs_embeds[:seq_len] # Batch x Books x Dim + all_codecs = self._cp_all_codecs[:seq_len] + + inputs_embeds.zero_() + + # Position 0: previous backbone hidden state + inputs_embeds[:, 0, :] = prev_hidden + + # Position 1: group-0 codec embedding + inputs_embeds[:, 1, :] = self.codec_embedding(group0_tokens) + + for step in range(N - 1): + # Somehow it is more efficient to re-run the same graph on the + # full bx16xdim input than to capture 15 graphs for different + # input lengths. + hidden_states = self(inputs_embeds) + + current_len = step + 2 + logits = self._compute_inner_logits(hidden_states[:, current_len - 1, :], step) + + if self.repetition_penalty != 1.0 and step > 0: + current_context = all_codecs[:, :step] + else: + current_context = None + + next_token = _sample_from_logits( + logits, + do_sample=self.do_sample, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p, + repetition_penalty=self.repetition_penalty, + previous_tokens=current_context, + use_gumbel=self.use_gumbel, + ) + all_codecs[:, step] = next_token + + next_embed = self.get_group_embeddings()[step](next_token) + 
inputs_embeds[:, current_len, :] = next_embed + + return all_codecs + + +@ignore_torch_compile +@support_torch_compile +class Qwen3TTSTalkerForConditionalGenerationNv(nn.Module): + """Qwen3TTS Talker for conditional generation. + + Per-step flow: + + 1. **Code predictor** (conditional): given the previous step's backbone + hidden state (``prev_hidden``, custom input) and the group-0 token + (``input_ids``, sampled by vLLM at the previous step), predict codec + groups 1..N-1. Skipped when ``prev_hidden`` is all-zero (prefill). + 2. **Embedding**: text_projection(text_embed) + codec_embed(group0) + + sum of groups-1..N-1 embeddings from the code predictor. + 3. **Backbone**: transformer with vLLM paged attention and KV cache. + 4. **Logits**: ``compute_logits()`` projects backbone output through + ``codec_head`` and applies ``suppress_mask``. vLLM's standard + sampler then samples the next group-0 token. + + Custom I/O: + Inputs: ``text_ids`` (int64), ``prev_hidden`` (float, dim=hidden_size) + Outputs: ``codes`` (int64, dim=N-1), ``hidden`` (float, dim=hidden_size) + """ + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # Maps HuggingFace checkpoint names (raw, unconverted) to the vLLM + # module layout used in this file. Applied to weights with the + # ``talker.`` prefix; ``speaker_encoder.*`` and other top-level + # checkpoint sections are filtered out (the NV variant doesn't use + # them). Order matters: more-specific prefixes first so that e.g. + # ``talker.model.codec_embedding.`` is rerouted before the generic + # ``talker.model.layers.`` rule could match. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Group-0 codec embedding lives inside the code predictor. 
+ "talker.model.codec_embedding.": "code_predictor.codec_embedding.", + # Text embedding lifted to the outer model (vLLM's Qwen3Model + # owns ``embed_tokens`` for codec ids; text tokens use a + # separate top-level table). + "talker.model.text_embedding.": "text_embedding.", + # Talker backbone (transformer + final norm) — uses vLLM's + # ``Qwen3Model`` directly, matching layer/norm names 1:1. + "talker.model.layers.": "model.layers.", + "talker.model.norm.": "model.norm.", + # Side modules. + "talker.codec_head.": "codec_head.", + "talker.text_projection.": "text_projection.", + # Code predictor (groups 1..N-1, native attention). + "talker.code_predictor.": "code_predictor.", + } + ) + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + hf_config = vllm_config.model_config.hf_config + config = _get_talker_config(hf_config) + quant_config = vllm_config.quant_config + + self.hf_config = hf_config + self.config = config + self.quant_config = quant_config + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + + # Omni preprocess/postprocess hooks (consumed by OmniGPUModelRunner). + self.has_preprocess = True + self.has_postprocess = True + # Required so the runner unpacks ``multimodal_outputs`` (audio_codes) + # from the ``OmniOutput`` returned by :meth:`make_omni_output`. + # Without this, ``extract_multimodal_outputs`` discards the codes. + self.have_multimodal_outputs = True + # Keep small per-step buffers GPU-resident (avoids CPU round-trips). + self.gpu_resident_buffer_keys: set[str] = { + "last_talker_hidden", + } + + # HF AutoTokenizer, loaded eagerly so prefill preprocess has no + # first-call latency spike. + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True, padding_side="left") + + # Transformer backbone — vLLM's reusable Qwen3Model. 
The talker + # has Qwen3-style decoder layers, so we delegate the entire + # backbone (decoder layers, final norm, and a per-rank + # ``embed_tokens`` table that we do not actually consume — every + # forward goes through ``inputs_embeds``). + with set_model_tag("talker"): + self.model = Qwen3Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + # Text-token embedding lives outside the backbone (Qwen3Model only + # owns the codec-vocab ``embed_tokens``). + self.text_embedding = VocabParallelEmbedding( + config.text_vocab_size, + config.text_hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_embedding"), + ) + + # Text projection MLP + self.text_projection = Qwen3TTSTalkerResizeMLP( + input_size=config.text_hidden_size, + intermediate_size=config.text_hidden_size, + output_size=config.hidden_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_projection"), + ) + + # Compiled code predictor (groups 1..N-1 only) + with set_model_tag("code_predictor"): + self.code_predictor = Qwen3TTSTalkerCodePredictor( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "code_predictor"), + ) + + # Group-0 prediction head + suppress mask (used by compute_logits + # so vLLM's standard sampler can sample group-0). + self.codec_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "codec_head"), + ) + + self.suppress_mask = nn.Parameter( + torch.zeros(config.vocab_size, dtype=torch.bool), + requires_grad=False, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + # Persistent buffers — addresses must be stable across CUDA graph + # replays. The piecewise CUDAGraphWrapper does NOT copy inputs on + # replay; it expects the same ``data_ptr()`` that was recorded during + # capture. 
Any tensor created transiently in ``forward()`` (like + # ``text_embed + codec_embed``) would have a new address each call, + # causing the replayed graph to read stale memory. + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + dtype = vllm_config.model_config.dtype + self._out_codes = torch.zeros(max_num_tokens, self.code_predictor.num_code_groups, dtype=torch.long) + self._combined_embeddings = torch.zeros(max_num_tokens, config.hidden_size, dtype=dtype) + # Per-token slot for the previous-step backbone hidden state fed to the + # code predictor. Written by ``preprocess`` at the request's offset, + # read by ``forward`` at decode positions. + self._prev_hidden_buffer = torch.zeros(max_num_tokens, config.hidden_size, dtype=dtype) + # ``text_proj(text_emb(tts_pad_token_id))`` — request-independent + # constant added on top of ``codec_emb(group0)`` at every decode + # step. Declared here so the address is stable across CUDA graph + # replays; actual value is populated from weights in ``load_weights``. + self._tts_pad_text_embed = torch.zeros(1, config.hidden_size, dtype=dtype) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Get group-0 codec embeddings for input ids.""" + return self.code_predictor.get_group0_embeddings(input_ids) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.get_input_embeddings(input_ids) + + def _get_decode_idxs(self): + """Return the indices of decode tokens. + + These are exactly the positions where the local transformer + should be applied. 
+ + Returns: + decode_idx: indices of the decode tokens; if ``None`` is returned, + the local transformer should be applied everywhere + num_requests: number of decode requests, before padding + """ + ctx = get_forward_context() + attn_metadata = ctx.attn_metadata + if attn_metadata is None: + # When attention metadata is not provided (capturing, dummy run), + # apply the local transformer everywhere. + return None, 0 + + if isinstance(attn_metadata, dict): + any_layer_meta = next(iter(attn_metadata.values())) + else: + any_layer_meta = attn_metadata + + if any_layer_meta.max_query_len == 1: + # All requests in the batch are decode-only; + # apply the local transformer everywhere. + return None, 0 + + start_loc = any_layer_meta.query_start_loc + tokens_per_req = start_loc[1:] - start_loc[:-1] + is_decode = tokens_per_req == 1 # shape: (num_reqs,) + decode_token_indices = start_loc[:-1][is_decode] + + num_requests = decode_token_indices.shape[0] + padded_num_requests = num_requests + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + idx = bisect.bisect_left(sizes, num_requests) + if idx < len(sizes): + padded_num_requests = sizes[idx] + if padded_num_requests != num_requests: + decode_token_indices = torch.nn.functional.pad( + decode_token_indices, (0, padded_num_requests - num_requests) + ) + return decode_token_indices, num_requests + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Any | None = None, + inputs_embeds: torch.Tensor | None = None, + **_: Any, + ) -> torch.Tensor: + """Forward pass: code predictor -> embedding -> backbone. + + ``inputs_embeds`` is produced by :meth:`preprocess`: + + * **Prefill**: full prefill embedding sequence for the span. 
+ * **Decode**: zeros — the actual decode embedding + (``codec_emb(group0) + sum(group_emb(group1..N-1)) + + text_proj(text_emb(tts_pad))``) is assembled here on decode + positions only. + + ``prev_hidden`` (backbone output of the previous step) is read from + :attr:`_prev_hidden_buffer`, which is populated by :meth:`preprocess` + at each request's token offset. + + Three regimes: + + * **Profile / dummy run** (``attn_metadata is None``): treat every + token as a decode token so the code predictor and decode-side + embedding assembly get captured in the compiled CUDA graph. + * **Decode-only batch**: every token is a decode token — the + compiled / CUDA-graphed path replays directly. + * **Mixed prefill + decode**: only decode-token positions go through + the code predictor (eager); the assembled decode embeddings are + scattered back into the combined-embedding buffer at those + positions, leaving prefill positions as the prefill embeds. + + Returns: + Backbone ``hidden_states`` tensor. The codec groups 1..N-1 + produced inside this forward live in :attr:`_out_codes` and + are exposed to the runner via :meth:`make_omni_output` (key + ``"audio_codes"``). + """ + num_tokens = input_ids.shape[0] + combined_embeddings = self._combined_embeddings[:num_tokens] + combined_embeddings.copy_(inputs_embeds) + + decode_idx, num_req = self._get_decode_idxs() + group_embeddings = self.code_predictor.get_group_embeddings() + if decode_idx is None: + codes_1_15 = self.code_predictor.generate_groups_1_15( + prev_hidden=self._prev_hidden_buffer[:num_tokens], + group0_tokens=input_ids, + ) + self._out_codes[: codes_1_15.shape[0], 1:] = codes_1_15 + # Assemble decode embedding in-place on top of the (zero) + # ``inputs_embeds`` produced by ``preprocess``: group-0 codec + # embedding + tts_pad text embedding + sum of groups 1..N-1. 
+ combined_embeddings.add_(self.code_predictor.codec_embedding(input_ids)) + combined_embeddings.add_(self._tts_pad_text_embed) + for i in range(len(group_embeddings)): + combined_embeddings.add_(group_embeddings[i](codes_1_15[:, i])) + elif num_req > 0: + # need to overwrite the batch descriptor since we are slicing the inputs + ctx = get_forward_context() + orig_batch_descriptor = ctx.batch_descriptor + ctx.batch_descriptor = BatchDescriptor( + # padded number of requests + num_tokens=decode_idx.shape[0], + ) + codes_1_15 = self.code_predictor.generate_groups_1_15( + prev_hidden=self._prev_hidden_buffer[decode_idx], + group0_tokens=input_ids[decode_idx], + ) + # restore original batch descriptor + ctx.batch_descriptor = orig_batch_descriptor + valid_dec_idx = decode_idx[:num_req] + self._out_codes[valid_dec_idx, 1:] = codes_1_15[:num_req] + # Assemble decode embedding only at decode positions; prefill + # positions keep the full prefill embedding produced by + # ``preprocess``. + decode_group0_ids = input_ids[valid_dec_idx] + decode_embed = self.code_predictor.codec_embedding(decode_group0_ids) + self._tts_pad_text_embed + for i in range(len(group_embeddings)): + decode_embed = decode_embed + group_embeddings[i](codes_1_15[:num_req, i]) + combined_embeddings[valid_dec_idx] = decode_embed + + # Qwen3Model.forward(input_ids, positions, intermediate_tensors, + # inputs_embeds): when ``inputs_embeds`` is provided, ``input_ids`` + # is ignored and the embedded sequence is fed directly into the + # decoder layers. + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + combined_embeddings, + ) + + # save input ids to the output codes + self._out_codes[: input_ids.shape[0], 0] = input_ids + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + """Compute group-0 logits for vLLM sampling. 
+ + Projects backbone hidden states through ``codec_head``, applies the + ``suppress_mask`` to block reserved token IDs, and returns logits + of shape ``[batch, vocab_size]``. + """ + if isinstance(hidden_states, OmniOutput): + hidden_states = hidden_states.text_hidden_states + logits = self.logits_processor(self.codec_head, hidden_states) + logits = logits.masked_fill(self.suppress_mask.bool(), float("-inf")) + return logits + + def make_omni_output( + self, + model_outputs: torch.Tensor | OmniOutput, + **_: Any, + ) -> OmniOutput: + """Wrap backbone hidden states with the per-step codec codes. + + The codes written by :meth:`forward` live in :attr:`_out_codes` + (column 0 holds the group-0 ``input_ids``, columns 1..N-1 the + code-predictor outputs); we slice the first ``num_tokens`` rows + here and expose them under the conventional ``"audio_codes"`` + multimodal key consumed by :class:`OmniGPUModelRunner`. + + ``last_talker_hidden`` (state needed by the *next* step's code + predictor) is **not** part of the omni output — it is stashed into + ``model_intermediate_buffer`` by :meth:`postprocess` and read back + by :meth:`preprocess` on the next decode step. 
+ """ + if isinstance(model_outputs, OmniOutput): + return model_outputs + + hidden = model_outputs + num_tokens = int(hidden.shape[0]) + audio_codes = self._out_codes[:num_tokens] + return OmniOutput( + text_hidden_states=hidden, + multimodal_outputs={"audio_codes": audio_codes}, + ) + + # ------------------------------------------------------------------ + # Preprocess / postprocess (CustomVoice, non-streaming text only) + # ------------------------------------------------------------------ + + @staticmethod + def _first_str(value: Any) -> str: + """Return the first element of a list-wrapped scalar, or the scalar itself.""" + if isinstance(value, list): + return str(value[0]) if value else "" + if value is None: + return "" + return str(value) + + @staticmethod + def _build_assistant_text(text: str) -> str: + return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + @staticmethod + def estimate_prompt_len_from_additional_information( + additional_information: dict[str, Any] | None, + *, + task_type: str, + tokenize_prompt: Callable[[str], list[int]], + codec_language_id: Mapping[str, int] | None, + spk_is_dialect: Mapping[str, object] | None = None, + ) -> int: + """Compute the Stage-0 placeholder ``prompt_token_ids`` length. + + This is a length-only mirror of :meth:`_build_prompt_embeds` — it + must match the embedding sequence that will actually be produced by + :meth:`preprocess` during prefill, otherwise the vLLM scheduler will + either truncate or over-schedule the prefill span. 
+ + The NV talker only supports ``task_type="CustomVoice"`` with the + non-streaming text layout, so the formula is just: + + * ``role_len`` = 3 (``<|im_start|>assistant\\n``) + * ``codec_input_len`` = ``prefill_len + 1 (speaker) + 2 (pad, bos)`` + * ``codec_prefix_len`` = ``codec_input_len - 1`` + * ``text_body_len`` = ``assistant_len - 8 + 1`` (full text + ``tts_eos``) + * ``tail_len`` = 1 (``(tts_pad, codec_bos)``) + + where ``prefill_len`` is 4 if the language (or speaker's dialect + fallback) resolves in ``codec_language_id``, else 3. + + Args: + additional_information: Same dict that will be passed through the + request. Only ``text``, ``language`` and ``speaker`` are + inspected (all accepted as either scalars or ``[value]``). + task_type: Must be ``"CustomVoice"`` — anything else raises. + tokenize_prompt: Callable that returns the HF token IDs for a + given string (e.g. ``lambda s: tokenizer(s)["input_ids"]``). + codec_language_id: ``talker_config.codec_language_id`` mapping. + spk_is_dialect: Optional ``talker_config.spk_is_dialect`` mapping + for the Auto/Chinese + dialect-voice fallback. + + Returns: + Exact prefill length (in codec-time positions) that + :meth:`_build_prompt_embeds` will produce for this request. + """ + if task_type != "CustomVoice": + raise ValueError(f"Qwen3-TTS NV talker only supports task_type='CustomVoice', got {task_type!r}.") + + def _first(x: object, default: object) -> object: + if isinstance(x, list): + return x[0] if x else default + return x if x is not None else default + + info: dict[str, Any] = additional_information or {} + text = _first(info.get("text"), "") + language = _first(info.get("language"), "Auto") + speaker = _first(info.get("speaker"), "") + + if not isinstance(text, str) or not text: + raise ValueError( + "estimate_prompt_len_from_additional_information requires non-empty additional_information['text']." 
+ ) + if not isinstance(speaker, str) or not speaker: + raise ValueError( + "estimate_prompt_len_from_additional_information requires " + "additional_information['speaker'] (CustomVoice only)." + ) + if not isinstance(language, str): + language = "Auto" + + lang_map: Mapping[str, int] = codec_language_id or {} + dialect_map: Mapping[str, object] = spk_is_dialect or {} + language_id: int | None = None + if language.lower() != "auto": + language_id = lang_map.get(language.lower()) + if language_id is None and language.lower() in ("auto", "chinese"): + dialect = dialect_map.get(speaker.lower().strip()) + if isinstance(dialect, str) and dialect: + language_id = lang_map.get(dialect) + prefill_len = 4 if language_id is not None else 3 + + # Mirrors _build_prompt_embeds: role (3) + codec_prefix + # (codec_input_len - 1) + (text_body_len = assistant_len - 8 + 1) + # + tail (1) = prefill_len + assistant_len - 1. + assistant_len = len(tokenize_prompt(Qwen3TTSTalkerForConditionalGenerationNv._build_assistant_text(text))) + if assistant_len < 8: + raise ValueError(f"Unexpected assistant prompt length: {assistant_len}") + return prefill_len + assistant_len - 1 + + def _build_prompt_embeds( + self, + *, + text: str, + speaker: str, + language: str | None, + ) -> torch.Tensor: + """Build the full prefill embedding sequence for CustomVoice + non-streaming. + + Mirrors the HuggingFace ``Qwen3TTSForConditionalGeneration.generate`` + layout (``task_type='CustomVoice'``, ``non_streaming_mode=True``): + + 1. Role header ``<|im_start|>assistant\\n`` -> projected text embeds. + 2. Codec control prefix (``codec_think[_eos/_bos/_nothink]`` + optional + language_id + speaker_id + ``codec_pad`` + ``codec_bos``), mixed + with ``tts_pad``/``tts_bos`` on the text side. + 3. Full synthesis text + ``tts_eos`` on the text side, ``codec_pad`` + on the codec side. + 4. Final ``(tts_pad, codec_bos)`` tail position. + + Returns: + ``[prompt_len, hidden]`` bfloat16 tensor on the model's device. 
+ """ + tc = self.config # talker config + hf = self.hf_config # parent Qwen3TTSConfig + device = next(self.parameters()).device + speaker_key = speaker.lower().strip() + + input_ids = self.tokenizer( + self._build_assistant_text(text), + return_tensors="pt", + padding=False, + )["input_ids"].to(device=device) + + # tts special-token projected embeddings (bos / eos / pad). + tts_tokens = torch.tensor( + [[hf.tts_bos_token_id, hf.tts_eos_token_id, hf.tts_pad_token_id]], + device=device, + dtype=torch.long, + ) + tts_bos_embed, tts_eos_embed, tts_pad_embed = self.text_projection(self.text_embedding(tts_tokens)).chunk( + 3, dim=1 + ) + + # Codec control prefix: choose with/without language_id. + language_id: int | None = None + lang_map = getattr(tc, "codec_language_id", None) or {} + if isinstance(language, str) and language.lower() != "auto": + language_id = lang_map.get(language.lower()) + # Dialect fallback (official behavior): if Chinese/Auto + speaker is a + # known dialect voice, promote language_id to that dialect. 
+ if language_id is None and isinstance(language, str) and language.lower() in ("auto", "chinese"): + spk_is_dialect = getattr(tc, "spk_is_dialect", None) or {} + dialect = spk_is_dialect.get(speaker_key) + if isinstance(dialect, str) and dialect: + language_id = lang_map.get(dialect) + + if language_id is None: + codec_prefill = [ + tc.codec_nothink_id, + tc.codec_think_bos_id, + tc.codec_think_eos_id, + ] + else: + codec_prefill = [ + tc.codec_think_id, + tc.codec_think_bos_id, + int(language_id), + tc.codec_think_eos_id, + ] + + codec_input_0 = self.code_predictor.codec_embedding( + torch.tensor([codec_prefill], device=device, dtype=torch.long) + ) + codec_input_1 = self.code_predictor.codec_embedding( + torch.tensor( + [[tc.codec_pad_id, tc.codec_bos_id]], + device=device, + dtype=torch.long, + ) + ) + + spk_map = {k.lower(): v for k, v in (getattr(tc, "spk_id", None) or {}).items()} + if speaker_key not in spk_map: + raise ValueError(f"Unsupported CustomVoice speaker: {speaker!r} (known: {sorted(spk_map) or ''})") + speaker_embed = self.code_predictor.codec_embedding( + torch.tensor([[spk_map[speaker_key]]], device=device, dtype=torch.long) + ) + + codec_input = torch.cat([codec_input_0, speaker_embed, codec_input_1], dim=1) + + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + codec_prefix = torch.cat( + ( + tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), + tts_bos_embed, + ), + dim=1, + ) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + # Non-streaming: full synth text in prefill + (tts_pad, codec_bos) tail. 
+ text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(tc.codec_pad_id), + device=device, + dtype=torch.long, + ) + tail_codec_bos = torch.tensor([[tc.codec_bos_id]], device=device, dtype=torch.long) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.code_predictor.codec_embedding(pad_ids), + tts_pad_embed + self.code_predictor.codec_embedding(tail_codec_bos), + ], + dim=1, + ) + + return talker_prompt.squeeze(0).to(dtype=torch.bfloat16).contiguous() + + def preprocess( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor | None, + *, + start: int = 0, + end: int = 0, + **info_dict: Any, + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + """Build per-request ``(input_ids, inputs_embeds)`` for this step. + + CustomVoice + non-streaming only. + + Args: + input_ids: This request's slice of the flat batch's token ids. + input_embeds: Corresponding slice of the flat-batch + ``inputs_embeds`` if the runner already populated one. + start: This request's start position in the flat batch + (``query_start_loc[req_index]``). Provided by the runner; + used to index :attr:`_prev_hidden_buffer` at decode. + end: This request's end position in the flat batch + (``start + sched_tokens``). Provided by the runner; not + used here directly (``input_ids.shape[0] == end - start``) + but accepted for runner-contract symmetry. + **info_dict: The request's ``additional_information`` plus + runner-provided ``request_id`` and any state previously + stashed by this method (e.g. ``talker_prompt_embeds``, + ``talker_prefill_offset``, ``last_talker_hidden``). + + Prefill (``span_len > 1``): + On the first prefill call, builds the full prompt embedding once + (see :meth:`_build_prompt_embeds`) and stashes it under + ``talker_prompt_embeds`` (CPU). 
On subsequent chunks, slices from + that buffer using ``talker_prefill_offset``. ``input_ids`` are + filled with ``codec_pad`` placeholders since the code predictor + doesn't run during prefill. + + Decode (``span_len == 1``): + Returns ``inputs_embeds`` of zeros — the actual decode + embedding (``codec_emb(group0) + sum(group_emb(group1..N-1)) + + text_proj(text_emb(tts_pad))``) is assembled inside + :meth:`forward` at decode positions only. ``input_ids`` (the + group-0 token sampled by vLLM) is passed through unchanged. + The previous-step backbone hidden (``last_talker_hidden``, + produced by :meth:`postprocess`) is written into + :attr:`_prev_hidden_buffer` at ``start`` for the code predictor + to read. + """ + # Normalize: some runner paths still pass per-request state nested + # under ``additional_information`` instead of flattened. + nested = info_dict.get("additional_information") + if isinstance(nested, dict): + merged = {k: v for k, v in info_dict.items() if k != "additional_information"} + for k, v in nested.items(): + merged.setdefault(k, v) + info_dict = merged + + span_len = int(input_ids.shape[0]) + if span_len <= 0: + base = input_embeds if input_embeds is not None else self.embed_input_ids(input_ids) + return input_ids, base, {} + + tc = self.config + device = input_ids.device + + # ----- Prefill ------------------------------------------------- + if span_len > 1: + text = self._first_str(info_dict.get("text")) + if not text: + raise ValueError("Qwen3-TTS NV talker.preprocess requires additional_information.text for prefill.") + speaker = self._first_str(info_dict.get("speaker")) + if not speaker: + raise ValueError( + "Qwen3-TTS NV talker.preprocess requires additional_information.speaker (CustomVoice only)." 
+ ) + language = self._first_str(info_dict.get("language")) or "Auto" + + prompt_embeds_cpu = info_dict.get("talker_prompt_embeds") + is_first = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2 + if is_first: + full = self._build_prompt_embeds(text=text, speaker=speaker, language=language) + prompt_embeds_cpu = full.detach().to("cpu").contiguous() + offset = 0 + info_update: dict[str, Any] = { + "talker_prompt_embeds": prompt_embeds_cpu, + "talker_prefill_offset": 0, + } + else: + offset = int(info_dict.get("talker_prefill_offset", 0) or 0) + info_update = {} + + # Slice the span out of the stored prefill buffer; pad with the + # last row if the scheduled chunk overshoots (shouldn't happen + # when the placeholder length matches the true prefill length). + s = max(0, min(offset, int(prompt_embeds_cpu.shape[0]))) + e = max(0, min(offset + span_len, int(prompt_embeds_cpu.shape[0]))) + take = prompt_embeds_cpu[s:e] + if int(take.shape[0]) < span_len: + pad_n = span_len - int(take.shape[0]) + if take.shape[0] > 0: + pad_rows = take[-1:].expand(pad_n, -1) + else: + pad_rows = torch.zeros( + (pad_n, prompt_embeds_cpu.shape[-1]), + dtype=prompt_embeds_cpu.dtype, + ) + take = torch.cat([take, pad_rows], dim=0) + prompt_embeds = take.to(device=device, dtype=torch.bfloat16) + info_update["talker_prefill_offset"] = offset + span_len + + # input_ids for prefill: codec_pad placeholder (code predictor + # is skipped for prefill positions, so the exact value doesn't + # matter as long as it's a valid codec token). + input_ids_out = torch.full_like(input_ids, int(tc.codec_pad_id)) + return input_ids_out, prompt_embeds, info_update + + # ----- Decode (span_len == 1) --------------------------------- + # The decode embedding is assembled inside :meth:`forward` (where + # we have visibility of decode-vs-prefill positions and the codes + # produced by the code predictor). 
Here we just return zeros that + # ``forward`` will accumulate the real embedding into. + inputs_embeds_out = torch.zeros( + (1, self.config.hidden_size), + device=device, + dtype=self._combined_embeddings.dtype, + ) + + # prev_hidden for the code predictor: ``last_talker_hidden`` stashed + # by postprocess. When missing, we leave the slot untouched (the first + # decode step after prefill will have it available since postprocess + # runs after every forward, including prefill). + last_hidden = info_dict.get("last_talker_hidden") + if isinstance(last_hidden, torch.Tensor) and last_hidden.numel() > 0: + prev_h = last_hidden.to(device=device, dtype=self._prev_hidden_buffer.dtype).reshape(1, -1) + self._prev_hidden_buffer[start : start + 1].copy_(prev_h) + + return input_ids, inputs_embeds_out, {} + + def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: + """Stash the last backbone hidden as ``last_talker_hidden`` for the next step.""" + if hidden_states.numel() == 0: + return {} + last = hidden_states[-1, :].detach() + return {"last_talker_hidden": last} + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights directly from a raw HuggingFace Qwen3-TTS checkpoint. + + No offline conversion is required: this method renames the + ``talker.*`` weights to the vLLM module layout (see + :attr:`hf_to_vllm_mapper`), drops unrelated checkpoint sections + (``speaker_encoder.*``, etc.), and then computes the two + derived buffers that the runtime needs: + + * :attr:`suppress_mask`: a boolean ``[vocab_size]`` mask that + blocks the highest 1024 token IDs, which are reserved (except + ``codec_eos_token_id``), when sampling group-0. + * :attr:`_tts_pad_text_embed`: the projected text embedding of + ``tts_pad_token_id``, a request-independent constant added + on top of ``codec_emb(group0)`` at every decode step. + """ + # Filter to talker weights only (skip speaker_encoder.* etc). 
+ talker_weights = ((name, w) for name, w in weights if name.startswith("talker.")) + + # ``suppress_mask`` is a Parameter we initialise ourselves below; + # if a converted checkpoint happens to carry one, ignore it. + loader = AutoWeightsLoader(self, skip_prefixes=["suppress_mask"]) + loaded = loader.load_weights(talker_weights, mapper=self.hf_to_vllm_mapper) + + self._init_runtime_buffers() + + # Mark the parameters we initialise without a checkpoint weight + # as "loaded" so the strict-loading check in + # ``DefaultModelLoader.load_weights`` doesn't flag them. These + # are populated either in ``__init__`` (rotary inv_freq) or in + # ``_init_runtime_buffers`` (suppress_mask). + # + # ``model.embed_tokens`` is created by ``Qwen3Model`` but is never + # invoked in this model: every prefill / decode step feeds the + # backbone via ``inputs_embeds`` assembled from + # ``code_predictor.codec_embedding`` and ``text_embedding``. + # Exclude it from the strict-load check as well. + loaded.add("suppress_mask") + loaded.add("model.embed_tokens.weight") + for name, _ in self.named_parameters(): + if name.endswith("rotary_emb.inv_freq"): + loaded.add(name) + + logger.info( + "Loaded %d weights for Qwen3TTSTalkerForConditionalGenerationNv", + len(loaded), + ) + return loaded + + @torch.no_grad() + def _init_runtime_buffers(self) -> None: + """Populate :attr:`suppress_mask` and :attr:`_tts_pad_text_embed`. + + Called from :meth:`load_weights` once the underlying parameters + have been filled, so :attr:`text_embedding` and + :attr:`text_projection` can be evaluated to derive the + constant ``tts_pad`` embedding used on every decode step. + """ + tc = self.config + hf = self.hf_config + + # The highest 1024 token IDs are reserved/invalid; suppress them at + # group-0 sampling time, except for ``codec_eos`` which must + # remain reachable as an end-of-stream signal. 
+ vocab_size = int(tc.vocab_size) + codec_eos = int(getattr(tc, "codec_eos_token_id", -1)) + mask = torch.zeros(vocab_size, dtype=torch.bool, device=self.suppress_mask.device) + suppress_start = vocab_size - 1024 + if suppress_start > 0: + mask[suppress_start:] = True + if suppress_start <= codec_eos < vocab_size: + mask[codec_eos] = False + self.suppress_mask.copy_(mask) + + # Precompute ``text_proj(text_emb(tts_pad_token_id))`` — added + # to ``codec_emb(group0)`` at every decode step; depends only on + # frozen weights so we evaluate it once here. + device = next(self.parameters()).device + pad_id = int(hf.tts_pad_token_id) + pad_ids = torch.tensor([[pad_id]], device=device, dtype=torch.long) + text_emb = self.text_embedding(pad_ids) + pad_proj = self.text_projection(text_emb).reshape(1, -1) + self._tts_pad_text_embed.copy_( + pad_proj.to( + device=self._tts_pad_text_embed.device, + dtype=self._tts_pad_text_embed.dtype, + ) + ) diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index cc900d471b0..c66a4e2557b 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -97,6 +97,11 @@ "qwen3_tts_talker", "Qwen3TTSTalkerForConditionalGeneration", ), + "Qwen3TTSTalkerForConditionalGenerationNv": ( + "qwen3_tts_nv", + "qwen3_tts_talker_nv", + "Qwen3TTSTalkerForConditionalGenerationNv", + ), "Qwen3TTSCode2Wav": ( "qwen3_tts", "qwen3_tts_code2wav", diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index eae7c0a11a3..8d249a16afb 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1483,7 +1483,11 @@ def flush_decode_batch() -> None: embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos + input_ids=input_ids[s:e], + input_embeds=embed_slice, + 
start=s, + end=e, + **req_infos, ) if inputs_embeds is None: inputs_embeds = torch.empty(