diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d0566e18b00..c50bbcfb7d3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -176,7 +176,7 @@ steps: # type: DirectoryOrCreate - label: "Omni Model Test" - timeout_in_minutes: 15 + timeout_in_minutes: 20 depends_on: image-build commands: - export VLLM_LOGGING_LEVEL=DEBUG diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 93aeba3ca5f..12e5e193542 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -4,18 +4,21 @@ tasks, then runs Omni generation and saves output wav files. """ +import logging import os -from typing import NamedTuple +from typing import Any, NamedTuple import soundfile as sf +import torch os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -from vllm import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm_omni import Omni +logger = logging.getLogger(__name__) + class QueryResult(NamedTuple): """Container for a prepared Omni request.""" @@ -24,6 +27,44 @@ class QueryResult(NamedTuple): model_name: str +def _estimate_prompt_len( + additional_information: dict[str, Any], + model_name: str, + _cache: dict[str, Any] = {}, +) -> int: + """Estimate prompt_token_ids placeholder length for the Talker stage. + + The AR Talker replaces all input embeddings via ``preprocess``, so the + placeholder values are irrelevant but the **length** must match the + embeddings that ``preprocess`` will produce. 
+ """ + try: + from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + Qwen3TTSTalkerForConditionalGeneration, + ) + + if model_name not in _cache: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") + cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True) + _cache[model_name] = (tok, getattr(cfg, "talker_config", None)) + + tok, tcfg = _cache[model_name] + task_type = (additional_information.get("task_type") or ["CustomVoice"])[0] + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( + additional_information=additional_information, + task_type=task_type, + tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"], + codec_language_id=getattr(tcfg, "codec_language_id", None), + spk_is_dialect=getattr(tcfg, "spk_is_dialect", None), + ) + except Exception as exc: + logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc) + return 2048 + + def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: """Build CustomVoice sample inputs. @@ -34,6 +75,7 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: QueryResult with Omni inputs and the CustomVoice model path. 
""" task_type = "CustomVoice" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" if use_batch_sample: texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."] instructs = ["", "Very happy."] @@ -41,18 +83,18 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: speakers = ["Vivian", "Ryan"] inputs = [] for text, instruct, language, speaker in zip(texts, instructs, languages, speakers): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "instruct": [instruct], + "language": [language], + "speaker": [speaker], + "max_new_tokens": [2048], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "instruct": [instruct], - "language": [language], - "speaker": [speaker], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: @@ -60,21 +102,21 @@ def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult: language = "Chinese" speaker = "Vivian" instruct = "用特别愤怒的语气说" - prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "speaker": [speaker], + "instruct": [instruct], + "max_new_tokens": [2048], + } inputs = { - "prompt": prompts, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "speaker": [speaker], - "instruct": [instruct], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", + model_name=model_name, ) @@ -88,6 +130,7 @@ def 
get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: QueryResult with Omni inputs and the VoiceDesign model path. """ task_type = "VoiceDesign" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign" if use_batch_sample: texts = [ "哥哥,你回来啦,人家等了你好久好久了,要抱抱!", @@ -100,39 +143,39 @@ def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult: languages = ["Chinese", "English"] inputs = [] for text, instruct, language in zip(texts, instructs, languages): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "instruct": [instruct], - "max_new_tokens": [2048], - "non_streaming_mode": [True], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!" 
instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。" language = "Chinese" - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "text": [text], + "language": [language], + "instruct": [instruct], + "max_new_tokens": [2048], + "non_streaming_mode": [True], + } inputs = { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "text": [text], - "language": [language], - "instruct": [instruct], - "max_new_tokens": [2048], - "non_streaming_mode": [True], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", + model_name=model_name, ) @@ -147,6 +190,7 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que QueryResult with Omni inputs and the Base model path. """ task_type = "Base" + model_name = "Qwen/Qwen3-TTS-12Hz-1.7B-Base" ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav" ref_audio_single = ref_audio_path_1 ref_text_single = ( @@ -163,38 +207,38 @@ def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> Que syn_lang_batch = ["Chinese", "English"] inputs = [] for text, language in zip(syn_text_batch, syn_lang_batch): - prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [text], + "language": [language], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + } inputs.append( { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "ref_audio": [ref_audio_single], - "ref_text": [ref_text_single], - "text": [text], - "language": [language], - "x_vector_only_mode": [x_vector_only_mode], - 
"max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } ) else: - prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n" + additional_information = { + "task_type": [task_type], + "ref_audio": [ref_audio_single], + "ref_text": [ref_text_single], + "text": [syn_text_single], + "language": [syn_lang_single], + "x_vector_only_mode": [x_vector_only_mode], + "max_new_tokens": [2048], + } inputs = { - "prompt": prompt, - "additional_information": { - "task_type": [task_type], - "ref_audio": [ref_audio_single], - "ref_text": [ref_text_single], - "text": [syn_text_single], - "language": [syn_lang_single], - "x_vector_only_mode": [x_vector_only_mode], - "max_new_tokens": [2048], - }, + "prompt_token_ids": [0] * _estimate_prompt_len(additional_information, model_name), + "additional_information": additional_information, } return QueryResult( inputs=inputs, - model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base", + model_name=model_name, ) @@ -223,30 +267,22 @@ def main(args): stage_init_timeout=args.stage_init_timeout, ) - sampling_params = SamplingParams( - temperature=0.9, - top_p=1.0, - top_k=50, - max_tokens=2048, - seed=42, - detokenize=False, - repetition_penalty=1.05, - ) - - sampling_params_list = [ - sampling_params, - ] - output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav os.makedirs(output_dir, exist_ok=True) - omni_generator = omni.generate(query_result.inputs, sampling_params_list) + omni_generator = omni.generate(query_result.inputs, sampling_params_list=None) for stage_outputs in omni_generator: for output in stage_outputs.request_output: request_id = output.request_id - audio_tensor = output.outputs[0].multimodal_output["audio"] + audio_data = output.outputs[0].multimodal_output["audio"] + # async_chunk mode returns a list of chunks; concatenate them. 
+ if isinstance(audio_data, list): + audio_tensor = torch.cat(audio_data, dim=-1) + else: + audio_tensor = audio_data output_wav = os.path.join(output_dir, f"output_{request_id}.wav") - audio_samplerate = output.outputs[0].multimodal_output["sr"].item() + sr_val = output.outputs[0].multimodal_output["sr"] + audio_samplerate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val[-1]) # Convert to numpy array and ensure correct format audio_numpy = audio_tensor.float().detach().cpu().numpy() diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index 2db98c06869..e55ff6812df 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -310,10 +310,12 @@ def test_is_tts_model(self, speech_server): speech_server.engine_client.stage_list = [mock_stage] assert speech_server._is_tts_model() is True - def test_build_tts_prompt(self, speech_server): - """Test TTS prompt format.""" - prompt = speech_server._build_tts_prompt("Hello") - assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n" + def test_estimate_prompt_len_fallback(self, speech_server): + """Test prompt length estimation falls back to 2048 when model is unavailable.""" + tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]} + result = speech_server._estimate_prompt_len(tts_params) + # Without a real model, it should fall back to 2048. 
+ assert result == 2048 def test_validate_tts_request_basic(self, speech_server): """Test basic validation cases.""" diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index c7836123a64..9b5052b464a 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -69,6 +69,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4): runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) runner.talker_mtp = DummyTalkerMTP() + runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes") runner.vllm_config = object() # Provide a minimal implementation that returns the expected 4-tuple. diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index a9ffa015fe1..f13a90bb7f0 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -58,6 +58,7 @@ class OmniModelConfig(ModelConfig): } ) omni_kv_config: dict | None = None + codec_frame_rate_hz: float | None = None @property def registry(self): @@ -128,6 +129,21 @@ def __post_init__( video_pruning_rate=video_pruning_rate, ) + # Qwen3-TTS: infer codec frame rate from the model config for online serving. 
+ if self.codec_frame_rate_hz is None and self.model_arch == "Qwen3TTSTalkerForConditionalGenerationARVLLM": + talker_cfg = getattr(self.hf_config, "talker_config", None) + if isinstance(talker_cfg, dict): + pos_per_sec = talker_cfg.get("position_id_per_seconds") + else: + pos_per_sec = getattr(talker_cfg, "position_id_per_seconds", None) + if pos_per_sec is not None: + try: + fps = float(pos_per_sec) + except Exception: + fps = None + if fps is not None and fps > 0: + self.codec_frame_rate_hz = fps + # Override hf_text_config with omni-specific logic for multi-stage models # (e.g., thinker_config, talker_config) new_hf_text_config = self.draw_hf_text_config() diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 684aab9ce20..ef1c4c7c901 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -62,7 +62,6 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] # OMNI: Skip requests that are not in self.requests - # This can happen when connector marks request as finished and it's removed from requests if request.request_id not in self.requests or ( self.chunk_transfer_adapter is None and request.status == RequestStatus.FINISHED_STOPPED ): @@ -71,7 +70,25 @@ def schedule(self) -> SchedulerOutput: continue num_computed_tokens = request.num_computed_tokens - required_tokens = max(len(request.prompt_token_ids) - num_computed_tokens, 1) + required_tokens = len(request.prompt_token_ids) - num_computed_tokens + # async_chunk: don't schedule placeholder tokens when no new chunk is available. 
+ if required_tokens <= 0: + if ( + self.chunk_transfer_adapter is not None + and request.request_id in self.chunk_transfer_adapter.finished_requests + ): + request.status = RequestStatus.FINISHED_STOPPED + # Upstream may finish with no terminal tokens; append one pad token so we can emit FINISHED. + if len(request.prompt_token_ids) <= num_computed_tokens: + request.prompt_token_ids.append(0) + try: + request._all_token_ids.append(0) # type: ignore[attr-defined] + except Exception: + pass + required_tokens = len(request.prompt_token_ids) - num_computed_tokens + else: + req_index += 1 + continue num_new_tokens = min(required_tokens, token_budget) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -109,6 +126,20 @@ def schedule(self) -> SchedulerOutput: self.waiting.pop_request() continue + # async_chunk: wait for the first upstream chunk (don't start with placeholders). + if self.chunk_transfer_adapter is not None and len(request.prompt_token_ids) == 0: + if request.request_id in self.chunk_transfer_adapter.finished_requests: + request.status = RequestStatus.FINISHED_STOPPED + request.prompt_token_ids.append(0) + try: + request._all_token_ids.append(0) # type: ignore[attr-defined] + except Exception: + pass + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + # Uniformly treat as diffusion. A feature flag can be added later # via config or request tag. @@ -145,11 +176,13 @@ def schedule(self) -> SchedulerOutput: # If fast path scheduled none, fall back to the original scheduling if not num_scheduled_tokens: - res = super().schedule() if self.chunk_transfer_adapter: + # Don't fall back: base scheduler doesn't handle async_chunk + # requests with empty prompt_token_ids. 
self.chunk_transfer_adapter.restore_queues(self.waiting, self.running) - self.chunk_transfer_adapter.postprocess_scheduler_output(res) - return res + else: + res = super().schedule() + return res # Compute common prefix blocks (aligned with v1) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) @@ -318,9 +351,8 @@ def update_from_output( continue request = self.requests.get(req_id) if request is None or request.is_finished(): - # The request is already finished. This can happen if the - # request is aborted while the model is executing it (e.g., - # in pipeline parallelism or async scheduling). + # Request may already be finished (e.g., aborted during + # execution / pipeline parallelism / async scheduling). continue req_index = model_runner_output.req_id_to_index[req_id] @@ -360,8 +392,14 @@ def update_from_output( routed_experts = None # Diffusion request: completes in one step; mark finished and free resources - if request.status == RequestStatus.FINISHED_STOPPED or ( - self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens + if ( + request.status == RequestStatus.FINISHED_STOPPED + or (self.chunk_transfer_adapter is None and request.num_computed_tokens >= request.num_prompt_tokens) + or ( + self.chunk_transfer_adapter is not None + and request.request_id in self.chunk_transfer_adapter.finished_requests + and request.num_computed_tokens >= len(request.prompt_token_ids) + ) ): request.status = RequestStatus.FINISHED_STOPPED # Optional: set a stop_reason for front-end clarity @@ -375,15 +413,11 @@ def update_from_output( finished = self._handle_stopped_request(request) if finished: kv_transfer_params = self._free_request(request) - if status_before_stop == RequestStatus.RUNNING: - stopped_running_reqs.add(request) - elif status_before_stop == RequestStatus.WAITING_FOR_CHUNK: - # In async chunk mode, request may be in either queue. - # Remove from both to avoid stale queue entries. 
+ if status_before_stop == RequestStatus.WAITING_FOR_CHUNK: stopped_running_reqs.add(request) stopped_preempted_reqs.add(request) else: - stopped_preempted_reqs.add(request) + stopped_running_reqs.add(request) # Extract sample logprobs if needed. if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs: diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 70d38a9d687..a6afb97bd4c 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -179,11 +179,17 @@ def _poll_single_request(self, req_id: str): else: if payload_data.get("finished"): self.finished_requests.add(req_id) - req.status = RequestStatus.FINISHED_STOPPED - req.prompt_token_ids = payload_data.get("code_predictor_codes", []) + # req.prompt_token_ids = payload_data.get("code_predictor_codes", []) + # req.num_computed_tokens = 0 + new_ids = payload_data.get("code_predictor_codes", []) + req.prompt_token_ids = new_ids req.num_computed_tokens = 0 + # Empty chunk with more data expected: keep polling. 
+ if not new_ids and not payload_data.get("finished"): + return + # Mark as finished for consumption with self.lock: self._finished_load_reqs.add(req_id) @@ -308,7 +314,6 @@ def _process_chunk_queue( # of schedule, but have not scheduled continue if request.request_id in self.finished_requests: - request.additional_information = {} continue # Requests that waiting for chunk self.load_async(request) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 369713d7b68..d00652a658b 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -390,7 +390,8 @@ async def _process_async_results( submit_flag = False prompt_token_ids = engine_outputs.prompt_token_ids engine_input = copy.deepcopy(prompt) - engine_input["prompt_token_ids"] = [0] * compute_talker_prompt_ids_length(prompt_token_ids) + next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids)) + engine_input["prompt_token_ids"] = [0] * next_prompt_len engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None for i in range(1, len(self.stage_list)): task = { diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index a8bae9e9932..201be69dae0 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,6 +1,14 @@ import asyncio +import base64 +import io +import ipaddress +import socket from typing import Any +from urllib.parse import urlparse +from urllib.request import urlopen +import numpy as np +import soundfile as sf from fastapi import Request from fastapi.responses import Response from vllm.entrypoints.openai.engine.serving import OpenAIServing @@ -17,6 +25,19 @@ logger = init_logger(__name__) +_REF_AUDIO_TIMEOUT_S = 15 +_REF_AUDIO_MAX_BYTES = 50 * 1024 * 1024 # 50 MB +_REF_AUDIO_BLOCKED_NETWORKS = [ + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("10.0.0.0/8"), + 
ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network("fe80::/10"), +] + # TTS Configuration (currently supports Qwen3-TTS) _TTS_MODEL_STAGES: set[str] = {"qwen3_tts"} _TTS_LANGUAGES: set[str] = { @@ -43,6 +64,7 @@ def __init__(self, *args, **kwargs): # Load supported speakers self.supported_speakers = self._load_supported_speakers() logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + self._tts_tokenizer = None def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" @@ -62,6 +84,36 @@ def _load_supported_speakers(self) -> set[str]: return set() + def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: + """Estimate prompt length so the placeholder matches model-side embeddings.""" + try: + from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + Qwen3TTSTalkerForConditionalGeneration, + ) + + if self._tts_tokenizer is None: + from transformers import AutoTokenizer + + model_name = self.engine_client.model_config.model + self._tts_tokenizer = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True, + padding_side="left", + ) + hf_config = self.engine_client.model_config.hf_config + talker_config = hf_config.talker_config + task_type = (tts_params.get("task_type") or ["CustomVoice"])[0] + return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information( + additional_information=tts_params, + task_type=task_type, + tokenize_prompt=lambda t: self._tts_tokenizer(t, padding=False)["input_ids"], + codec_language_id=getattr(talker_config, "codec_language_id", None), + spk_is_dialect=getattr(talker_config, "spk_is_dialect", None), + ) + except Exception as e: + logger.warning("Failed to estimate TTS prompt length, 
using fallback 2048: %s", e) + return 2048 + def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" stage_list = getattr(self.engine_client, "stage_list", None) @@ -125,9 +177,44 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return None - def _build_tts_prompt(self, text: str) -> str: - """Build TTS prompt from input text.""" - return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + @staticmethod + async def _resolve_ref_audio(ref_audio_str: str) -> tuple[list[float], int]: + """Resolve ref_audio URL/base64 to (wav_samples, sample_rate).""" + parsed = urlparse(ref_audio_str) + + def _check_ssrf(url: str) -> None: + host = urlparse(url).hostname + if not host: + raise ValueError("ref_audio URL must include a hostname") + for info in socket.getaddrinfo(host, None): + ip_str = str(info[4][0]).split("%", 1)[0] + addr = ipaddress.ip_address(ip_str) + if any(addr in net for net in _REF_AUDIO_BLOCKED_NETWORKS): + raise ValueError(f"ref_audio URL resolves to blocked address: {addr}") + + def _fetch_sync() -> tuple[np.ndarray, int]: + if parsed.scheme in ("http", "https"): + _check_ssrf(ref_audio_str) + with urlopen(ref_audio_str, timeout=_REF_AUDIO_TIMEOUT_S) as resp: + data = resp.read(_REF_AUDIO_MAX_BYTES + 1) + if len(data) > _REF_AUDIO_MAX_BYTES: + raise ValueError(f"ref_audio URL exceeds {_REF_AUDIO_MAX_BYTES} bytes") + buf = io.BytesIO(data) + elif ref_audio_str.startswith("data:"): + b64 = ref_audio_str + if "," in b64: + b64 = b64.split(",", 1)[1] + buf = io.BytesIO(base64.b64decode(b64)) + else: + raise ValueError("ref_audio must be an http(s) URL or data: base64 URI") + audio, sr = sf.read(buf, dtype="float32", always_2d=False) + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = np.mean(audio, axis=-1) + return np.asarray(audio, dtype=np.float32), int(sr) + + loop = asyncio.get_running_loop() + wav_np, sr = await loop.run_in_executor(None, _fetch_sync) + 
return wav_np.tolist(), sr def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: """Build TTS parameters from request. @@ -164,9 +251,7 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any else: params["instruct"] = [""] - # Voice clone parameters (used with Base task) - if request.ref_audio is not None: - params["ref_audio"] = [request.ref_audio] + # Voice clone: ref_audio resolved in create_speech(), not here. if request.ref_text is not None: params["ref_text"] = [request.ref_text] if request.x_vector_only_mode is not None: @@ -221,11 +306,19 @@ async def create_speech( if validation_error: return self.create_error_response(validation_error) - # Build TTS parameters and prompt + # Must use prompt_token_ids (not text prompt): the AR Talker + # operates on codec tokens; text token IDs exceed codec vocab. + # model.preprocess replaces all embeddings, so placeholder value + # is irrelevant -- but length must match to avoid excess padding. tts_params = self._build_tts_params(request) - prompt_text = self._build_tts_prompt(request.input) + + if request.ref_audio is not None: + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + tts_params["ref_audio"] = [[wav_list, sr]] + + ph_len = self._estimate_prompt_len(tts_params) prompt = { - "prompt": prompt_text, + "prompt_token_ids": [1] * ph_len, "additional_information": tts_params, } else: @@ -282,6 +375,11 @@ async def create_speech( if hasattr(sample_rate, "item"): sample_rate = sample_rate.item() + # Streaming accumulates chunks as a list; concat first. 
+ if isinstance(audio_tensor, list): + import torch + + audio_tensor = torch.cat(audio_tensor, dim=-1) # Convert tensor to numpy if hasattr(audio_tensor, "float"): audio_tensor = audio_tensor.float().detach().cpu().numpy() diff --git a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py index dde69006865..8e751413767 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/configuration_qwen3_tts.py @@ -504,19 +504,39 @@ def __init__( self.tts_bos_token_id = tts_bos_token_id self.tts_eos_token_id = tts_eos_token_id - # TODO: remove these dummy values after - self.image_token_id = 0 # dummy image token id - self.video_token_id = 0 # dummy video token id - self.vision_start_token_id = 0 # dummy vision start token id + # Dummy vision token IDs that must never collide with real codec tokens. + # mrope scans prompt_token_ids for these; using -1 ensures no false match. + self.image_token_id = -1 + self.video_token_id = -1 + self.vision_start_token_id = -1 self.vision_config = PretrainedConfig() # dummy vision config self.vision_config.spatial_merge_size = 1 + @property + def codec_frame_rate_hz(self) -> float | None: + pos_per_sec = getattr(self.talker_config, "position_id_per_seconds", None) + if pos_per_sec is None: + return None + try: + fps = float(pos_per_sec) + except (TypeError, ValueError): + return None + return fps if fps > 0 else None + def get_text_config(self, **kwargs): # vLLM expects text config to expose hidden_size/num_attention_heads. # For Qwen3 TTS, the talker config is the text model config. config = self.talker_config - # if hasattr(config, "rope_parameters"): - # delattr(config, "rope_parameters") + # Code2Wav is a pure convolutional waveform decoder; it does NOT use + # rotary position embeddings. 
When hf_overrides sets architectures + # to [Qwen3TTSCode2Wav], strip rope_parameters so that the model + # runner sees uses_mrope == False and skips mrope position computation + # on codec tokens. Each stage loads its own config instance, so this + # in-place mutation does not affect the Talker stage. + archs = getattr(self, "architectures", []) or [] + if any("Code2Wav" in str(a) for a in archs): + if hasattr(config, "rope_parameters"): + delattr(config, "rope_parameters") return config diff --git a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py deleted file mode 100644 index 1e759a8d2b4..00000000000 --- a/vllm_omni/model_executor/models/qwen3_tts/modeling_qwen3_tts.py +++ /dev/null @@ -1,2326 +0,0 @@ -# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Qwen3TTS model.""" - -import json -import os -from collections.abc import Callable -from dataclasses import dataclass - -import torch -from librosa.filters import mel as librosa_mel_fn -from torch import nn -from torch.nn import functional as F -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.generation import GenerationMixin -from transformers.integrations import use_kernel_forward_from_hub -from transformers.masking_utils import ( - create_causal_mask, - create_sliding_window_causal_mask, -) -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_layers import GradientCheckpointingLayer -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - ModelOutput, -) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.utils import can_return_tuple, logging -from transformers.utils.hub import cached_file - -from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific - -from .configuration_qwen3_tts import ( - Qwen3TTSConfig, - Qwen3TTSSpeakerEncoderConfig, - Qwen3TTSTalkerCodePredictorConfig, - Qwen3TTSTalkerConfig, -) -from .qwen3_tts_tokenizer import Qwen3TTSTokenizer - -logger = logging.get_logger(__name__) - - -class Res2NetBlock(torch.nn.Module): - def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): - super().__init__() - - in_channel = in_channels // scale - hidden_channel = out_channels // scale - - self.blocks = nn.ModuleList( - [ - TimeDelayNetBlock( - in_channel, - hidden_channel, - kernel_size=kernel_size, - dilation=dilation, - ) - for i in range(scale - 1) - ] - ) - self.scale = scale - - def forward(self, hidden_states): - outputs = [] 
- for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): - if i == 0: - output_part = hidden_part - elif i == 1: - output_part = self.blocks[i - 1](hidden_part) - else: - output_part = self.blocks[i - 1](hidden_part + output_part) - outputs.append(output_part) - output = torch.cat(outputs, dim=1) - return output - - -class SqueezeExcitationBlock(nn.Module): - def __init__(self, in_channels, se_channels, out_channels): - super().__init__() - - self.conv1 = nn.Conv1d( - in_channels=in_channels, - out_channels=se_channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv1d( - in_channels=se_channels, - out_channels=out_channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - self.sigmoid = nn.Sigmoid() - - def forward(self, hidden_states): - hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) - - hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) - hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) - - return hidden_states * hidden_states_mean - - -class AttentiveStatisticsPooling(nn.Module): - """This class implements an attentive statistic pooling layer for each channel. - It returns the concatenated mean and std of the input tensor. - """ - - def __init__(self, channels, attention_channels=128): - super().__init__() - - self.eps = 1e-12 - self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) - self.tanh = nn.Tanh() - self.conv = nn.Conv1d( - in_channels=attention_channels, - out_channels=channels, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - - def _length_to_mask(self, length, max_len=None, dtype=None, device=None): - """Creates a binary mask for each sequence. - - Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 - - Arguments - --------- - length : torch.LongTensor - Containing the length of each sequence in the batch. Must be 1D. 
- max_len : int - Max length for the mask, also the size of the second dimension. - dtype : torch.dtype, default: None - The dtype of the generated mask. - device: torch.device, default: None - The device to put the mask variable. - - Returns - ------- - mask : tensor - The binary mask. - """ - - if max_len is None: - max_len = length.max().long().item() # using arange to generate mask - mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( - len(length), max_len - ) < length.unsqueeze(1) - - mask = torch.as_tensor(mask, dtype=dtype, device=device) - return mask - - def _compute_statistics(self, x, m, dim=2): - mean = (m * x).sum(dim) - std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) - return mean, std - - def forward(self, hidden_states): - seq_length = hidden_states.shape[-1] - lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) - - # Make binary mask of shape [N, 1, L] - mask = self._length_to_mask( - lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device - ) - mask = mask.unsqueeze(1) - - # Expand the temporal context of the pooling layer by allowing the - # self-attention to look at global properties of the utterance. 
- total = mask.sum(dim=2, keepdim=True) - - mean, std = self._compute_statistics(hidden_states, mask / total) - mean = mean.unsqueeze(2).repeat(1, 1, seq_length) - std = std.unsqueeze(2).repeat(1, 1, seq_length) - attention = torch.cat([hidden_states, mean, std], dim=1) - - # Apply layers - attention = self.conv(self.tanh(self.tdnn(attention))) - - # Filter out zero-paddings - attention = attention.masked_fill(mask == 0, float("-inf")) - - attention = F.softmax(attention, dim=2) - mean, std = self._compute_statistics(hidden_states, attention) - # Append mean and std of the batch - pooled_stats = torch.cat((mean, std), dim=1) - pooled_stats = pooled_stats.unsqueeze(2) - - return pooled_stats - - -class TimeDelayNetBlock(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - dilation, - ): - super().__init__() - self.conv = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - padding="same", - padding_mode="reflect", - ) - self.activation = nn.ReLU() - - def forward(self, hidden_states: torch.Tensor): - return self.activation(self.conv(hidden_states)) - - -class SqueezeExcitationRes2NetBlock(nn.Module): - """An implementation of building block in ECAPA-TDNN, i.e., - TDNN-Res2Net-TDNN-SqueezeExcitationBlock. 
- """ - - def __init__( - self, - in_channels, - out_channels, - res2net_scale=8, - se_channels=128, - kernel_size=1, - dilation=1, - ): - super().__init__() - self.out_channels = out_channels - self.tdnn1 = TimeDelayNetBlock( - in_channels, - out_channels, - kernel_size=1, - dilation=1, - ) - self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) - self.tdnn2 = TimeDelayNetBlock( - out_channels, - out_channels, - kernel_size=1, - dilation=1, - ) - self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) - - def forward(self, hidden_state): - residual = hidden_state - - hidden_state = self.tdnn1(hidden_state) - hidden_state = self.res2net_block(hidden_state) - hidden_state = self.tdnn2(hidden_state) - hidden_state = self.se_block(hidden_state) - - return hidden_state + residual - - -class Qwen3TTSSpeakerEncoder(torch.nn.Module): - """An implementation of the speaker embedding model in a paper. - "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in - TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143). - Use for Qwen3TTS extract speaker embedding. 
- """ - - def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): - super().__init__() - if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( - config.enc_dilations - ): - raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") - self.channels = config.enc_channels - self.blocks = nn.ModuleList() - - # The initial TDNN layer - self.blocks.append( - TimeDelayNetBlock( - config.mel_dim, - config.enc_channels[0], - config.enc_kernel_sizes[0], - config.enc_dilations[0], - ) - ) - - # SE-Res2Net layers - for i in range(1, len(config.enc_channels) - 1): - self.blocks.append( - SqueezeExcitationRes2NetBlock( - config.enc_channels[i - 1], - config.enc_channels[i], - res2net_scale=config.enc_res2net_scale, - se_channels=config.enc_se_channels, - kernel_size=config.enc_kernel_sizes[i], - dilation=config.enc_dilations[i], - ) - ) - - # Multi-layer feature aggregation - self.mfa = TimeDelayNetBlock( - config.enc_channels[-1], - config.enc_channels[-1], - config.enc_kernel_sizes[-1], - config.enc_dilations[-1], - ) - - # Attentive Statistical Pooling - self.asp = AttentiveStatisticsPooling( - config.enc_channels[-1], - attention_channels=config.enc_attention_channels, - ) - - # Final linear transformation - self.fc = nn.Conv1d( - in_channels=config.enc_channels[-1] * 2, - out_channels=config.enc_dim, - kernel_size=1, - padding="same", - padding_mode="reflect", - ) - - def forward(self, hidden_states): - # Minimize transpose for efficiency - hidden_states = hidden_states.transpose(1, 2) - - hidden_states_list = [] - for layer in self.blocks: - hidden_states = layer(hidden_states) - hidden_states_list.append(hidden_states) - - # Multi-layer feature aggregation - hidden_states = torch.cat(hidden_states_list[1:], dim=1) - hidden_states = self.mfa(hidden_states) - - # Attentive Statistical Pooling - hidden_states = self.asp(hidden_states) - - # Final linear transformation - hidden_states = 
self.fc(hidden_states) - - hidden_states = hidden_states.squeeze(-1) - return hidden_states - - -def dynamic_range_compression_torch(x, c=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * c) - - -def mel_spectrogram( - y: torch.Tensor, - n_fft: int, - num_mels: int, - sampling_rate: int, - hop_size: int, - win_size: int, - fmin: int, - fmax: int = None, - center: bool = False, -) -> torch.Tensor: - """ - Calculate the mel spectrogram of an input signal. - This function uses slaney norm for the librosa mel filterbank - (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). - - Args: - y (torch.Tensor): Input signal. - n_fft (int): FFT size. - num_mels (int): Number of mel bins. - sampling_rate (int): Sampling rate of the input signal. - hop_size (int): Hop size for STFT. - win_size (int): Window size for STFT. - fmin (int): Minimum frequency for mel filterbank. - fmax (int): Maximum frequency for mel filterbank. - If None, defaults to half the sampling rate (fmax = sr / 2.0) - inside librosa_mel_fn - center (bool): Whether to pad the input to center the frames. Default is False. - - Returns: - torch.Tensor: Mel spectrogram. 
- """ - if torch.min(y) < -1.0: - print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") - if torch.max(y) > 1.0: - print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") - - device = y.device - - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - - mel_basis = torch.from_numpy(mel).float().to(device) - hann_window = torch.hann_window(win_size).to(device) - - padding = (n_fft - hop_size) // 2 - y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) - - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) - - mel_spec = torch.matmul(mel_basis, spec) - mel_spec = dynamic_range_compression_torch(mel_spec) - - return mel_spec - - -def _compute_default_rope_parameters( - config, - device, -): - base = config.rope_theta - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - return inv_freq, attention_factor - - -class Qwen3TTSPreTrainedModel(PreTrainedModel): - config_class = Qwen3TTSConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen3TTSDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_static_cache = False - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported 
version of Qwen2.5OmniThinker isn't meant for training from scratch - only - # inference and fine-tuning - so the proper init weights code has been removed - std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.weight is not None: - module.weight.data.fill_(1.0) - if module.bias is not None: - module.bias.data.zero_() - - -class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = [] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_3 = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = False - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Qwen3TTSRMSNorm): - module.weight.data.fill_(1.0) - - -class Qwen3TTSTalkerRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen3TTSTalkerConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = 
config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn: Callable = _compute_default_rope_parameters - if self.rope_type != "default": - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # In contrast to other models, Qwen3TTSThinkerText has different position ids for the grids - # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Qwen3TTSRotaryEmbedding(nn.Module): - def __init__(self, config: Qwen3TTSConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = 
config.max_position_embeddings - - self.config = config - self.rope_init_fn: Callable = _compute_default_rope_parameters - if self.rope_type != "default": - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -@use_kernel_forward_from_hub("RMSNorm") -class Qwen3TTSRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen3TTSRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 
:] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, mrope_interleaved=False, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. 
The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - if mrope_interleaved: - - def apply_interleaved_rope(x, modality_num): - x_t = x[0].clone() - index_ranges = [] - for i, n in enumerate(mrope_section[1:], 1): - beg_idx = i - end_idx = n * modality_num - index_ranges.append((beg_idx, end_idx)) - for beg_idx, end_idx in index_ranges: - x_t[..., beg_idx:end_idx:modality_num] = x[beg_idx, ..., beg_idx:end_idx:modality_num] - return x_t - - dim = cos.shape[-1] - modality_num = len(mrope_section) - cos = torch.cat([apply_interleaved_rope(cos[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([apply_interleaved_rope(sin[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze( - unsqueeze_dim - ) - else: - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen3TTSTalkerAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, 
bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! - self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = getattr(config, "sliding_window", None) - self.rope_scaling = config.rope_scaling - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], self.rope_scaling["interleaved"] - ) - - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, 
- value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Qwen3TTSTalkerResizeMLP(nn.Module): - def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False): - super().__init__() - self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias) - self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias) - self.act_fn = ACT2FN[act] - - def forward(self, hidden_state): - return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) - - -@dataclass -class Qwen3TTSTalkerCodePredictorOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head - (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, - returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. 
- """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor = None - past_key_values: list[torch.FloatTensor] | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - generation_steps: int | None = None - - -class Qwen3TTSTalkerTextMLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen3TTSAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Qwen3TTSConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
- self.k_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape - self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, # diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - 
-class Qwen3TTSDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Qwen3TTSConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = Qwen3TTSAttention(config=config, layer_idx=layer_idx) - - self.mlp = Qwen3TTSTalkerTextMLP(config) - self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class Qwen3TTSTalkerCodePredictorModel(Qwen3TTSPreTrainedModel): - config_class = Qwen3TTSTalkerCodePredictorConfig - 
base_model_prefix = "talker.code_predictor.model" - - def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, embedding_dim: int): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.layers = nn.ModuleList( - [Qwen3TTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3TTSRotaryEmbedding(config=config) - self.gradient_checkpointing = False - self.has_sliding_layers = "sliding_attention" in self.config.layer_types - self.codec_embedding = nn.ModuleList( - [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)] - ) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.codec_embedding - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **flash_attn_kwargs, - ) -> BaseModelOutputWithPast: - if input_ids is not None: - raise ValueError("`input_ids` is expected to be `None`") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." - ) - use_cache = False - - # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache - if not isinstance(past_key_values, (type(None), Cache)): - raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - # It may already have been prepared by e.g. `generate` - if not isinstance(causal_mask_mapping := attention_mask, dict): - # Prepare mask arguments - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_key_values, - } - # Create the masks - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - } - # The sliding window alternating layers are not always activated depending on the config - if self.has_sliding_layers: - causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - 
attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = Qwen3TTSTalkerCodePredictorConfig - base_model_prefix = "talker.code_predictor" - - def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, talker_config: Qwen3TTSTalkerConfig): - super().__init__(config) - self.model = Qwen3TTSTalkerCodePredictorModel(config, talker_config.hidden_size) - self.vocab_size = config.vocab_size - self.lm_head = nn.ModuleList( - [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)] - ) - - if config.hidden_size != talker_config.hidden_size: - self.small_to_mtp_projection = torch.nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True) - else: - self.small_to_mtp_projection = torch.nn.Identity() - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return 
self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward_finetune( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **kwargs, - ) -> CausalLMOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - logits = [] - for i in range(1, self.config.num_code_groups): - logits.append(self.lm_head[i - 1](hidden_states[:, i])) - logits = torch.stack(logits, dim=1) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerCodePredictorOutputWithPast(loss=loss, logits=logits) - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - generation_steps=None, - **kwargs, - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of 
shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # Prefill stage - if inputs_embeds is not None and inputs_embeds.shape[1] > 1: - generation_steps = inputs_embeds.shape[1] - 2 # hidden & layer 0 - # Generation stage - else: - inputs_embeds = self.model.get_input_embeddings()[generation_steps - 1](input_ids) - inputs_embeds = self.small_to_mtp_projection(inputs_embeds) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - logits = self.lm_head[generation_steps](hidden_states) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerCodePredictorOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - generation_steps=generation_steps + 1, - ) - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): - model_kwargs = 
super()._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder, num_new_tokens - ) - model_kwargs["generation_steps"] = outputs.generation_steps - return model_kwargs - - -@dataclass -class Qwen3TTSTalkerOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head - (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, - returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. 
- """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - past_key_values: list[torch.FloatTensor] | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - past_hidden: torch.FloatTensor | None = None - generation_step: int | None = None - trailing_text_hidden: torch.FloatTensor | None = None - tts_pad_embed: torch.FloatTensor | None = None - - -class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config, layer_idx): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Qwen3TTSTalkerAttention(config, layer_idx) - - self.mlp = Qwen3TTSTalkerTextMLP(config, intermediate_size=config.intermediate_size) - - self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: tuple[torch.Tensor] | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - hidden_states = self.mlp(hidden_states) - - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -class Qwen3TTSTalkerModel(Qwen3TTSTalkerTextPreTrainedModel): - config_class = Qwen3TTSTalkerConfig - base_model_prefix = "talker.model" - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.layers = nn.ModuleList( - [Qwen3TTSTalkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - 
self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3TTSTalkerRotaryEmbedding(config) - self.gradient_checkpointing = False - self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size) - self.text_embedding = nn.Embedding(config.text_vocab_size, config.text_hidden_size) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.codec_embedding - - def get_text_embeddings(self): - return self.text_embedding - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - **flash_attn_kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - # the hard coded `3` is for temporal, height and width. - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - if position_ids.ndim == 3 and position_ids.shape[0] == 4: - text_position_ids = position_ids[0] - position_ids = position_ids[1:] - else: - text_position_ids = position_ids[0] - - mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask - causal_mask = mask_function( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=text_position_ids, - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=text_position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) - - hidden_states = layer_outputs[0] - 
- if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen3TTSTalkerForConditionalGeneration(Qwen3TTSTalkerTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - config_class = Qwen3TTSTalkerConfig - base_model_prefix = "talker" - - def __init__(self, config: Qwen3TTSTalkerConfig): - super().__init__(config) - self.model = Qwen3TTSTalkerModel(config) - self.vocab_size = config.vocab_size - self.text_projection = Qwen3TTSTalkerResizeMLP( - config.text_hidden_size, config.text_hidden_size, config.hidden_size, config.hidden_act, bias=True - ) - - self.codec_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.code_predictor = Qwen3TTSTalkerCodePredictorModelForConditionalGeneration( - config=config.code_predictor_config, talker_config=config - ) - self.rope_deltas = None - - # Initialize weights and apply final processing - self.post_init() - - # TODO: hack, modular cannot inherit multiple classes - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def get_text_embeddings(self): - return self.model.get_text_embeddings() - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward_sub_talker_finetune(self, codec_ids, talker_hidden_states): - assert 
len(codec_ids.shape) == 2 - assert len(talker_hidden_states.shape) == 2 - assert codec_ids.shape[0] == talker_hidden_states.shape[0] - assert talker_hidden_states.shape[1] == self.config.hidden_size - assert codec_ids.shape[1] == self.config.num_code_groups - - sub_talker_inputs_embeds = [talker_hidden_states.unsqueeze(1)] - - for i in range(self.config.num_code_groups - 1): - if i == 0: - sub_talker_inputs_embeds.append(self.get_input_embeddings()(codec_ids[:, :1])) - else: - sub_talker_inputs_embeds.append( - self.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, i : i + 1]) - ) - sub_talker_inputs_embeds = torch.cat(sub_talker_inputs_embeds, dim=1) - - sub_talker_outputs = self.code_predictor.forward_finetune( - inputs_embeds=sub_talker_inputs_embeds, labels=codec_ids[:, 1:] - ) - - sub_talker_logits = sub_talker_outputs.logits - sub_talker_loss = sub_talker_outputs.loss - return sub_talker_logits, sub_talker_loss - - @can_return_tuple - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - past_hidden=None, - trailing_text_hidden=None, - tts_pad_embed=None, - generation_step=None, - subtalker_dosample=None, - subtalker_top_p=None, - subtalker_top_k=None, - subtalker_temperature=None, - **kwargs, - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- ```""" - # Prefill - if inputs_embeds is not None and inputs_embeds.shape[1] > 1: - generation_step = -1 - codec_ids = None - # Generate - else: - last_id_hidden = self.get_input_embeddings()(input_ids) - predictor_result = self.code_predictor.generate( - inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1), - max_new_tokens=self.config.num_code_groups - 1, - do_sample=subtalker_dosample, - top_p=subtalker_top_p, - top_k=subtalker_top_k, - temperature=subtalker_temperature, - output_hidden_states=True, - return_dict_in_generate=True, - ) - codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1) - codec_hiddens = torch.cat( - [last_id_hidden] - + [ - self.code_predictor.get_input_embeddings()[i](predictor_result.sequences[..., i : i + 1]) - for i in range(self.config.num_code_groups - 1) - ], - dim=1, - ) - inputs_embeds = codec_hiddens.sum(1, keepdim=True) - - if generation_step < trailing_text_hidden.shape[1]: - inputs_embeds = inputs_embeds + trailing_text_hidden[:, generation_step].unsqueeze(1) - else: - inputs_embeds = inputs_embeds + tts_pad_embed - if attention_mask is not None: - if ( - cache_position is None - or (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - ): - delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) - position_ids, rope_deltas = self.get_rope_index( - attention_mask, - ) - rope_deltas = rope_deltas - delta0 - self.rope_deltas = rope_deltas - else: - batch_size, seq_length = input_ids.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - outputs: BaseModelOutputWithPast = self.model( - input_ids=None, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - 
inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - logits = self.codec_head(hidden_states) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return Qwen3TTSTalkerOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=(outputs.hidden_states, codec_ids), - attentions=outputs.attentions, - past_hidden=hidden_states[:, -1:, :], - generation_step=generation_step + 1, - trailing_text_hidden=trailing_text_hidden, - tts_pad_embed=tts_pad_embed, - ) - - def get_rope_index( - self, - attention_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. - Examples: - input_ids: [T T T T T], here T is for text. - temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embedding for text part. - Examples: - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. - This means one frame is processed each second. 
- interval: The step size for the temporal position IDs, - calculated as tokens_per_second * temporal_patch_size / fps. - In this case, 25 * 2 / 1 = 50. This means that each temporal - patch will be have a difference of 50 in the temporal position IDs. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] - Here we calculate the text start position_ids as the max vision position_ids plus 1. - - Args: - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) - """ - mrope_position_deltas = [] - - position_ids = attention_mask.float().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) - - return position_ids, mrope_position_deltas - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder, num_new_tokens - ) - model_kwargs["past_hidden"] = outputs.past_hidden - model_kwargs["generation_step"] = outputs.generation_step - model_kwargs["trailing_text_hidden"] = outputs.trailing_text_hidden - model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed - return model_kwargs - - -class 
Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin): - config_class = Qwen3TTSConfig - - def __init__(self, config: Qwen3TTSConfig): - super().__init__(config) - self.config = config - - self.talker = Qwen3TTSTalkerForConditionalGeneration(self.config.talker_config) - - if config.tts_model_type == "base": - self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) - else: - self.speaker_encoder = None - - self.speech_tokenizer = None - self.generate_config = None - - self.supported_speakers = self.config.talker_config.spk_id.keys() - self.supported_languages = ["auto"] - for language_id in self.config.talker_config.codec_language_id.keys(): - if "dialect" not in language_id: - self.supported_languages.append(language_id) - - self.speaker_encoder_sample_rate = self.config.speaker_encoder_config.sample_rate - self.tokenizer_type = self.config.tokenizer_type - self.tts_model_size = self.config.tts_model_size - self.tts_model_type = self.config.tts_model_type - - self.post_init() - - def load_speech_tokenizer(self, speech_tokenizer): - self.speech_tokenizer = speech_tokenizer - - def load_generate_config(self, generate_config): - self.generate_config = generate_config - - def get_supported_speakers(self): - return self.supported_speakers - - def get_supported_languages(self): - return self.supported_languages - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path, - *model_args, - config=None, - cache_dir=None, - ignore_mismatched_sizes=False, - force_download=False, - local_files_only=False, - token=None, - revision="main", - use_safetensors=None, - weights_only=True, - **kwargs, - ): - model = super().from_pretrained( - pretrained_model_name_or_path, - *model_args, - config=config, - cache_dir=cache_dir, - ignore_mismatched_sizes=ignore_mismatched_sizes, - force_download=force_download, - local_files_only=local_files_only, - token=token, - revision=revision, - use_safetensors=use_safetensors, - 
weights_only=weights_only, - **kwargs, - ) - if not local_files_only and not os.path.isdir(pretrained_model_name_or_path): - download_cache_dir = kwargs.get("cache_dir", cache_dir) - download_revision = kwargs.get("revision", revision) - download_weights_from_hf_specific( - pretrained_model_name_or_path, - cache_dir=download_cache_dir, - allow_patterns=["speech_tokenizer/*"], - revision=download_revision, - ) - speech_tokenizer_path = cached_file( - pretrained_model_name_or_path, - "speech_tokenizer/config.json", - subfolder=kwargs.pop("subfolder", None), - cache_dir=kwargs.pop("cache_dir", None), - force_download=kwargs.pop("force_download", False), - proxies=kwargs.pop("proxies", None), - resume_download=kwargs.pop("resume_download", None), - local_files_only=kwargs.pop("local_files_only", False), - token=kwargs.pop("use_auth_token", None), - revision=kwargs.pop("revision", None), - ) - if speech_tokenizer_path is None: - raise ValueError(f"""{pretrained_model_name_or_path}/{speech_tokenizer_path} not exists""") - speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) - speech_tokenizer = Qwen3TTSTokenizer.from_pretrained( - speech_tokenizer_dir, - *model_args, - **kwargs, - ) - model.load_speech_tokenizer(speech_tokenizer) - - generate_config_path = cached_file( - pretrained_model_name_or_path, - "generation_config.json", - subfolder=kwargs.pop("subfolder", None), - cache_dir=kwargs.pop("cache_dir", None), - force_download=kwargs.pop("force_download", False), - proxies=kwargs.pop("proxies", None), - resume_download=kwargs.pop("resume_download", None), - local_files_only=kwargs.pop("local_files_only", False), - token=kwargs.pop("use_auth_token", None), - revision=kwargs.pop("revision", None), - ) - with open(generate_config_path, encoding="utf-8") as f: - generate_config = json.load(f) - model.load_generate_config(generate_config) - - return model - - @torch.inference_mode() - def extract_speaker_embedding(self, audio, sr): - assert sr == 24000, "Only 
support 24kHz audio" - mels = mel_spectrogram( - torch.from_numpy(audio).unsqueeze(0), - n_fft=1024, - num_mels=128, - sampling_rate=24000, - hop_size=256, - win_size=1024, - fmin=0, - fmax=12000, - ).transpose(1, 2) - speaker_embedding = self.speaker_encoder(mels.to(self.device).to(self.dtype))[0] - return speaker_embedding - - @torch.inference_mode() - def generate_speaker_prompt(self, voice_clone_prompt: list[dict]): - voice_clone_spk_embeds = [] - for index in range(len(voice_clone_prompt["ref_spk_embedding"])): - ref_spk_embedding = ( - voice_clone_prompt["ref_spk_embedding"][index].to(self.talker.device).to(self.talker.dtype) - ) - voice_clone_spk_embeds.append(ref_spk_embedding) - - return voice_clone_spk_embeds - - def generate_icl_prompt( - self, - text_id: torch.Tensor, - ref_id: torch.Tensor, - ref_code: torch.Tensor, - tts_pad_embed: torch.Tensor, - tts_eos_embed: torch.Tensor, - non_streaming_mode: bool, - ): - # text embed (ref id + text id + eos) 1 T1 D - text_embed = self.talker.text_projection( - self.talker.get_text_embeddings()(torch.cat([ref_id, text_id], dim=-1)) - ) - text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) - # codec embed (codec bos + codec) 1 T2 D - codec_embed = [] - for i in range(self.talker.config.num_code_groups): - if i == 0: - codec_embed.append(self.talker.get_input_embeddings()(ref_code[:, :1])) - else: - codec_embed.append(self.talker.code_predictor.get_input_embeddings()[i - 1](ref_code[:, i : i + 1])) - codec_embed = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0) - codec_embed = torch.cat( - [ - self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=text_id.dtype, - ) - ), - codec_embed, - ], - dim=1, - ) - # compute lens - text_lens = text_embed.shape[1] - codec_lens = codec_embed.shape[1] - if non_streaming_mode: - icl_input_embed = text_embed + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - 
self.config.talker_config.codec_pad_id, - ] - * text_lens - ], - device=self.talker.device, - dtype=text_id.dtype, - ) - ) - icl_input_embed = torch.cat([icl_input_embed, codec_embed + tts_pad_embed], dim=1) - return icl_input_embed, tts_pad_embed - else: - if text_lens > codec_lens: - return text_embed[:, :codec_lens] + codec_embed, text_embed[:, codec_lens:] - else: - text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1) - return text_embed + codec_embed, tts_pad_embed - - @torch.no_grad() - def generate( - self, - input_ids: list[torch.Tensor] | None = None, - instruct_ids: list[torch.Tensor] | None = None, - ref_ids: list[torch.Tensor] | None = None, - voice_clone_prompt: list[dict] = None, - languages: list[str] = None, - speakers: list[str] = None, - non_streaming_mode=False, - max_new_tokens: int = 4096, - do_sample: bool = True, - top_k: int = 50, - top_p: float = 1.0, - temperature: float = 0.9, - subtalker_dosample: bool = True, - subtalker_top_k: int = 50, - subtalker_top_p: float = 1.0, - subtalker_temperature: float = 0.9, - eos_token_id: int | None = None, - repetition_penalty: float = 1.05, - **kwargs, - ): - talker_kwargs = { - "max_new_tokens": max_new_tokens, - "min_new_tokens": 2, - "do_sample": do_sample, - "top_k": top_k, - "top_p": top_p, - "temperature": temperature, - "subtalker_dosample": subtalker_dosample, - "subtalker_top_k": subtalker_top_k, - "subtalker_top_p": subtalker_top_p, - "subtalker_temperature": subtalker_temperature, - "eos_token_id": eos_token_id if eos_token_id is not None else self.config.talker_config.codec_eos_token_id, - "repetition_penalty": repetition_penalty, - "suppress_tokens": [ - i - for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size) - if i not in (self.config.talker_config.codec_eos_token_id,) - ], - "output_hidden_states": getattr(kwargs, "output_hidden_states", True), - "return_dict_in_generate": getattr(kwargs, 
"return_dict_in_generate", True), - } - - talker_input_embeds = [[] for _ in range(len(input_ids))] - - voice_clone_spk_embeds = None - # voice clone speaker prompt generate - if voice_clone_prompt is not None: - voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt) - - # instruct text prompt generate - if instruct_ids is not None: - for index, instruct_id in enumerate(instruct_ids): - if instruct_id is not None: - talker_input_embeds[index].append( - self.talker.text_projection(self.talker.get_text_embeddings()(instruct_id)) - ) - - # tts text prompt generate - trailing_text_hiddens = [] - if speakers is None: - speakers = [None] * len(input_ids) - for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)): - if voice_clone_spk_embeds is None: - if speaker == "" or speaker is None: # Instruct create speaker - speaker_embed = None - else: - if speaker.lower() not in self.config.talker_config.spk_id: - raise NotImplementedError(f"Speaker {speaker} not implemented") - else: - spk_id = self.config.talker_config.spk_id[speaker.lower()] - speaker_embed = self.talker.get_input_embeddings()( - torch.tensor( - spk_id, - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - else: - if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]: - speaker_embed = voice_clone_spk_embeds[index] - else: - speaker_embed = None - - assert language is not None - - if language.lower() == "auto": - language_id = None - else: - if language.lower() not in self.config.talker_config.codec_language_id: - raise NotImplementedError(f"Language {language} not implemented") - else: - language_id = self.config.talker_config.codec_language_id[language.lower()] - - if ( - language.lower() in ["chinese", "auto"] - and speaker != "" - and speaker is not None - and self.config.talker_config.spk_is_dialect[speaker.lower()] is not False - ): - dialect = self.config.talker_config.spk_is_dialect[speaker.lower()] - 
language_id = self.config.talker_config.codec_language_id[dialect] - - tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection( - self.talker.get_text_embeddings()( - torch.tensor( - [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - ).chunk(3, dim=1) # 3 * [1 1 d] - - # codec: tag and speaker - if language_id is None: - codec_prefill_list = [ - [ - self.config.talker_config.codec_nothink_id, - self.config.talker_config.codec_think_bos_id, - self.config.talker_config.codec_think_eos_id, - ] - ] - else: - codec_prefill_list = [ - [ - self.config.talker_config.codec_think_id, - self.config.talker_config.codec_think_bos_id, - language_id, - self.config.talker_config.codec_think_eos_id, - ] - ] - - codec_input_emebdding_0 = self.talker.get_input_embeddings()( - torch.tensor( - codec_prefill_list, - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - codec_input_emebdding_1 = self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_pad_id, - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ) - if speaker_embed is None: - codec_input_emebdding = torch.cat([codec_input_emebdding_0, codec_input_emebdding_1], dim=1) - else: - codec_input_emebdding = torch.cat( - [codec_input_emebdding_0, speaker_embed.view(1, 1, -1), codec_input_emebdding_1], dim=1 - ) - - # '<|im_start|>assistant\n我叫通义千问,是阿里云的开源大模型。<|im_end|>\n<|im_start|>assistant\n' - - # <|im_start|>assistant\n - _talker_input_embed_role = self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, :3])) - - # tts_pad * 4 + tts_bos - _talker_input_embed = ( - torch.cat( - ( - tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1), - tts_bos_embed, - ), - dim=1, - ) - + codec_input_emebdding[:, :-1] - ) - - talker_input_embed = torch.cat((_talker_input_embed_role, 
_talker_input_embed), dim=1) - - if ( - voice_clone_prompt is not None - and voice_clone_prompt["ref_code"] is not None - and voice_clone_prompt["icl_mode"][index] - ): - icl_input_embed, trailing_text_hidden = self.generate_icl_prompt( - text_id=input_id[:, 3:-5], - ref_id=ref_ids[index][:, 3:-2], - ref_code=voice_clone_prompt["ref_code"][index].to(self.talker.device), - tts_pad_embed=tts_pad_embed, - tts_eos_embed=tts_eos_embed, - non_streaming_mode=non_streaming_mode, - ) - talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1) - else: - # tts_text_first_token - talker_input_embed = torch.cat( - [ - talker_input_embed, - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4])) - + codec_input_emebdding[:, -1:], - ], - dim=1, - ) - if non_streaming_mode: - talker_input_embed = talker_input_embed[:, :-1] # 去掉原本放进去的text - talker_input_embed = torch.cat( - [ - talker_input_embed, - torch.cat( - ( - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:-5])), - tts_eos_embed, - ), - dim=1, - ) - + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_pad_id, - ] - * (input_id[:, 3:-5].shape[1] + 1) - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ), - tts_pad_embed - + self.talker.get_input_embeddings()( - torch.tensor( - [ - [ - self.config.talker_config.codec_bos_id, - ] - ], - device=self.talker.device, - dtype=input_id.dtype, - ) - ), - ], - dim=1, - ) - trailing_text_hidden = tts_pad_embed - else: - # 叫通义千问,是阿里云的开源大模型。 - trailing_text_hidden = torch.cat( - ( - self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 4:-5])), - tts_eos_embed, - ), - dim=1, - ) - talker_input_embeds[index].append(talker_input_embed) - trailing_text_hiddens.append(trailing_text_hidden) - - for index, talker_input_embed in enumerate(talker_input_embeds): - talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not 
None], dim=1) - - # for batch inferquence - original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds]) - # left padding for talker input embeds - sequences = [t.squeeze(0) for t in talker_input_embeds] - sequences_reversed = [t.flip(dims=[0]) for t in sequences] - padded_reversed = torch.nn.utils.rnn.pad_sequence(sequences_reversed, batch_first=True, padding_value=0.0) - talker_input_embeds = padded_reversed.flip(dims=[1]) - # generate mask - batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1] - indices = torch.arange(max_len).expand(batch_size, -1) - num_pads = max_len - original_lengths - talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device) - # padding trailing text hiddens - pad_embedding_vector = tts_pad_embed.squeeze() - sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens] - trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad] - padded_hiddens = torch.nn.utils.rnn.pad_sequence(sequences_to_pad, batch_first=True, padding_value=0.0) - arange_tensor = torch.arange(max(trailing_text_original_lengths), device=padded_hiddens.device).expand( - len(trailing_text_original_lengths), -1 - ) - lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1) - padding_mask = arange_tensor >= lengths_tensor - padded_hiddens[padding_mask] = pad_embedding_vector - trailing_text_hiddens = padded_hiddens - - # forward - talker_result = self.talker.generate( - inputs_embeds=talker_input_embeds, - attention_mask=talker_attention_mask, - trailing_text_hidden=trailing_text_hiddens, - tts_pad_embed=tts_pad_embed, - **talker_kwargs, - ) - - talker_codes = torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1) - talker_hidden_states = torch.cat([hid[0][-1][:, -1:] for hid in talker_result.hidden_states], dim=1)[:, :-1] - - first_codebook = talker_codes[:, :, 0] - is_stop_token = 
first_codebook == self.config.talker_config.codec_eos_token_id - stop_indices = torch.argmax(is_stop_token.int(), dim=1) - has_stop_token = is_stop_token.any(dim=1) - effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1]) - - talker_codes_list = [ - talker_codes[ - i, - :length, - ] - for i, length in enumerate(effective_lengths) - ] - talker_hidden_states_list = [talker_hidden_states[i, :length, :] for i, length in enumerate(effective_lengths)] - - return talker_codes_list, talker_hidden_states_list - - -__all__ = [ - "Qwen3TTSForConditionalGeneration", - "Qwen3TTSTalkerForConditionalGeneration", - "Qwen3TTSPreTrainedModel", - "Qwen3TTSTalkerModel", -] diff --git a/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py deleted file mode 100644 index 5643a857cdb..00000000000 --- a/vllm_omni/model_executor/models/qwen3_tts/processing_qwen3_tts.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from transformers.feature_extraction_utils import BatchFeature -from transformers.processing_utils import ProcessingKwargs, ProcessorMixin - - -class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False): - _defaults = { - "text_kwargs": { - "padding": False, - "padding_side": "left", - } - } - - -class Qwen3TTSProcessor(ProcessorMixin): - r""" - Constructs a Qwen3TTS processor. - - Args: - tokenizer ([`Qwen2TokenizerFast`], *optional*): - The text tokenizer. - chat_template (`Optional[str]`, *optional*): - The Jinja template to use for formatting the conversation. - If not provided, the default chat template is used. - """ - - attributes = ["tokenizer"] - tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") - - def __init__(self, tokenizer=None, chat_template=None): - super().__init__(tokenizer, chat_template=chat_template) - - def __call__(self, text=None, **kwargs) -> BatchFeature: - """ - Main method to prepare for the model one or several sequences(s) and audio(s). - This method forwards the `text` and `kwargs` arguments to - Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` - to encode the text. - - Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
- """ - - if text is None: - raise ValueError("You need to specify either a `text` input to process.") - - output_kwargs = self._merge_kwargs( - Qwen3TTSProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) - if not isinstance(text, list): - text = [text] - - texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) - - return BatchFeature( - data={**texts_inputs}, - tensor_type=kwargs.get("return_tensors"), - ) - - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to - the docstring of this method for more information. - """ - return self.tokenizer.decode(*args, **kwargs) - - def apply_chat_template(self, conversations, chat_template=None, **kwargs): - if isinstance(conversations[0], dict): - conversations = [conversations] - return super().apply_chat_template(conversations, chat_template, **kwargs) - - @property - def model_input_names(self): - tokenizer_input_names = self.tokenizer.model_input_names - return list(dict.fromkeys(tokenizer_input_names)) - - -__all__ = ["Qwen3TTSProcessor"] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py new file mode 100644 index 00000000000..339268f34f0 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import os +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from transformers.utils.hub import cached_file +from vllm.config import VllmConfig +from 
vllm.logger import init_logger + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from .qwen3_tts_tokenizer import Qwen3TTSTokenizer + +logger = init_logger(__name__) + + +class Qwen3TTSCode2Wav(nn.Module): + """Stage-1 code2wav model for Qwen3-TTS (GenerationModelRunner). + Consumes frame-aligned codec tokens from input_ids and decodes waveform via SpeechTokenizer.""" + + input_modalities = "audio" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + # Generation-only stage (no logits / sampling). + self.requires_raw_input_tokens = True + + self._speech_tokenizer: Qwen3TTSTokenizer | None = None + self._num_quantizers: int | None = None + self._decode_upsample_rate: int | None = None + self._output_sample_rate: int | None = None + self._logged_codec_stats = False + + @staticmethod + def _module_device(module: nn.Module) -> torch.device: + try: + return next(module.parameters()).device + except StopIteration: + for _, buf in module.named_buffers(recurse=True): + return buf.device + return torch.device("cpu") + + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: + if self._speech_tokenizer is not None: + return self._speech_tokenizer + + # Locate speech_tokenizer dir from HF cache (or local path). + cfg_path = cached_file(self.model_path, "speech_tokenizer/config.json") + if cfg_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") + speech_tokenizer_dir = os.path.dirname(cfg_path) + + # Stage-1 only needs decode; skip HF feature extractor to avoid heavy optional deps. + # Still require preprocessor_config.json (use cached_file so online runs can fetch it). 
+ prep_cfg = cached_file(self.model_path, "speech_tokenizer/preprocessor_config.json") + if prep_cfg is None: + raise ValueError( + f"{self.model_path}/speech_tokenizer/preprocessor_config.json not found. " + "Please make sure the checkpoint contains the required HF preprocessing files." + ) + + tok = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + torch_dtype=torch.bfloat16, + load_feature_extractor=False, + ) + + # Align device with vLLM worker, then read back from module. + if tok.model is not None: + tok.model.to(device=self.vllm_config.device_config.device) + tok.device = self._module_device(tok.model) + + # Derive codec group count and rates from tokenizer config. + dec_cfg = getattr(tok.model.config, "decoder_config", None) + num_q = getattr(dec_cfg, "num_quantizers", None) if dec_cfg is not None else None + if num_q is None: + raise ValueError("speech_tokenizer decoder_config.num_quantizers not found") + num_q = int(num_q) + if num_q <= 0: + raise ValueError(f"Invalid speech_tokenizer num_quantizers={num_q}") + + try: + upsample = int(tok.get_decode_upsample_rate()) + except Exception as e: + raise ValueError(f"Failed to get decode upsample rate: {e}") from e + if upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + + try: + out_sr = int(tok.get_output_sample_rate()) + except Exception as e: + raise ValueError(f"Failed to get output sample rate: {e}") from e + + self._speech_tokenizer = tok + self._num_quantizers = num_q + self._decode_upsample_rate = upsample + self._output_sample_rate = out_sr + return tok + + def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: + # This stage ignores token embeddings. Keep a stable dummy embedding for vLLM runner. 
+ if input_ids.numel() == 0: + return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32) + return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32) + + def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> None: + return None + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Decode codec codes into audio waveform. + + input_ids layout: [codec_context_frames, *flat_codes] + where flat_codes is codebook-major [q*F]. + """ + tok = self._ensure_speech_tokenizer_loaded() + assert self._num_quantizers is not None + assert self._output_sample_rate is not None + + sr_val = self._output_sample_rate + empty_ret = ( + torch.zeros((0,), dtype=torch.float32), + torch.tensor(sr_val, dtype=torch.int32), + ) + + if input_ids is None: + return empty_ret + + q = int(self._num_quantizers) + ids = input_ids.reshape(-1).to(dtype=torch.long) + n_tokens = ids.numel() + + if n_tokens == 0: + return empty_ret + + # input_ids[0] = codec_context_frames (prepended by stage_input_processor). + ctx_frames = int(ids[0].item()) + ids = ids[1:] + n_tokens = ids.numel() + + if n_tokens == 0: + return empty_ret + + # Warmup / dummy_run: not divisible by num_quantizers. + if n_tokens % q != 0: + logger.warning( + "Code2Wav input_ids length %d not divisible by num_quantizers %d, " + "likely a warmup run; returning empty audio.", + n_tokens, + q, + ) + return empty_ret + + total_frames = n_tokens // q + + # Reshape codebook-major flat [q*F] -> [q, F] -> [F, q] for SpeechTokenizer. 
+ codes_fq = ids.reshape(q, total_frames).transpose(0, 1).contiguous() + + if not self._logged_codec_stats and total_frames > 1: + self._logged_codec_stats = True + try: + uniq = int(torch.unique(codes_fq).numel()) + cmin = int(codes_fq.min().item()) + cmax = int(codes_fq.max().item()) + head = codes_fq[: min(2, total_frames), : min(8, q)].cpu().tolist() + logger.info( + "Code2Wav codec: frames=%d q=%d uniq=%d range=[%d,%d] head=%s", + total_frames, + q, + uniq, + cmin, + cmax, + head, + ) + except Exception: + pass + + wavs, sr = tok.decode({"audio_codes": codes_fq}) + if not wavs: + raise ValueError("SpeechTokenizer code2wav produced empty waveform list.") + audio_np = wavs[0].astype(np.float32, copy=False) + + # Trim left-context waveform samples (streaming sliding window). + if ctx_frames > 0: + upsample = self._decode_upsample_rate + if upsample is None or upsample <= 0: + raise ValueError(f"Invalid decode upsample rate: {upsample}") + cut = ctx_frames * upsample + if cut < audio_np.shape[0]: + audio_np = audio_np[cut:] + else: + logger.warning( + "Context trim %d >= decoded length %d; returning empty audio.", + cut, + audio_np.shape[0], + ) + return empty_ret + + audio_tensor = torch.from_numpy(audio_np).to(dtype=torch.float32).reshape(-1) + sr_tensor = torch.tensor(int(sr), dtype=torch.int32) + return audio_tensor, sr_tensor + + def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput: + if isinstance(model_outputs, OmniOutput): + return model_outputs + + if not (isinstance(model_outputs, tuple) and len(model_outputs) == 2): + raise TypeError(f"Qwen3TTSCode2Wav expected (audio_tensor, sr) outputs, got {type(model_outputs)}") + + audio_tensor, sr = model_outputs + return OmniOutput( + text_hidden_states=None, + multimodal_outputs={ + "model_outputs": audio_tensor, + "sr": sr, + }, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # SpeechTokenizer weights live under 
`speech_tokenizer/` and are loaded + # lazily from that directory. Ignore main checkpoint weights. + return set() diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py new file mode 100644 index 00000000000..b8b8f6bed49 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.config.vllm import set_current_vllm_config +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor +from vllm.v1.worker.gpu import attn_utils + +from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig + + +class _LocalPredictorKVCache: + """Minimal local KV cache + attention metadata for running + code_predictor inside one worker (independent of engine KV).""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + max_seq_len: int, + max_batch_size: int, + device: torch.device, + ) -> None: + self.vllm_config = vllm_config + self.device = device + + # Collect attention layers registered in this vllm_config. + kv_cache_spec_by_layer = attn_utils.get_kv_cache_spec(vllm_config) + if not kv_cache_spec_by_layer: + raise RuntimeError("Local predictor KVCache requires vLLM Attention layers to be registered.") + + # We only need enough blocks for a tiny per-frame sequence (<= max_seq_len). 
+ any_spec = next(iter(kv_cache_spec_by_layer.values())) + block_size = int(any_spec.block_size) + blocks_per_seq = (int(max_seq_len) + block_size - 1) // block_size + num_blocks = max(1, int(max_batch_size) * int(blocks_per_seq)) + + # Allocate per-layer KV caches (small, independent). + kv_cache_tensors: list[KVCacheTensor] = [] + for layer_name, spec in kv_cache_spec_by_layer.items(): + kv_cache_tensors.append(KVCacheTensor(size=int(spec.page_size_bytes) * num_blocks, shared_by=[layer_name])) + + merged_spec: KVCacheSpec = KVCacheSpec.merge(list(kv_cache_spec_by_layer.values())) + self.kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=list(kv_cache_spec_by_layer.keys()), kv_cache_spec=merged_spec) + ], + ) + + # Init backend + bind KV cache tensors to attention modules. + self.attn_backends, self.attn_metadata_builders = attn_utils.init_attn_backend( + self.kv_cache_config, vllm_config, device + ) + self.runner_kv_caches: list[torch.Tensor] = [] + attn_utils.init_kv_cache( + self.runner_kv_caches, + vllm_config.compilation_config.static_forward_context, + self.kv_cache_config, + self.attn_backends, + device, + ) + + # Precompute a fixed block table mapping for the maximum batch. + self.block_size = block_size + self.blocks_per_seq = blocks_per_seq + self.max_batch_size = int(max_batch_size) + + bt = torch.full((self.max_batch_size, self.blocks_per_seq), -1, dtype=torch.int32, device=device) + for i in range(self.max_batch_size): + for j in range(self.blocks_per_seq): + bt[i, j] = i * self.blocks_per_seq + j + self._block_table = bt + + def build_attn_metadata( + self, + *, + num_reqs: int, + query_lens: torch.Tensor, # (num_reqs,) int32 on cpu + seq_lens: torch.Tensor, # (num_reqs,) int32 on cpu + ) -> tuple[dict[str, Any], torch.Tensor, dict[str, torch.Tensor]]: + """Build attention metadata, positions, and slot_mapping dict. 
+ + Returns: + (attn_metadata, positions, slot_mappings_by_layer) + - attn_metadata: per-layer attention metadata for attn backends. + - positions: (num_tokens,) position IDs on device. + - slot_mappings_by_layer: {layer_name: slot_mapping_tensor} for + set_forward_context so that unified_kv_cache_update can write + the KV cache correctly. + """ + num_reqs = int(num_reqs) + if num_reqs <= 0: + return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} + if num_reqs > self.max_batch_size: + raise ValueError(f"num_reqs={num_reqs} exceeds local predictor max_batch_size={self.max_batch_size}") + + query_lens_i32 = query_lens.to(dtype=torch.int32, device="cpu") + seq_lens_i32 = seq_lens.to(dtype=torch.int32, device="cpu") + + # query_start_loc: prefix sums of query_lens. + qsl = torch.zeros((num_reqs + 1,), dtype=torch.int32, device="cpu") + qsl[1:] = torch.cumsum(query_lens_i32, dim=0) + num_tokens = int(qsl[-1].item()) + if num_tokens <= 0: + return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} + + # positions: for each request i, emit positions [seq_len-query_len .. seq_len-1] + pos_list: list[torch.Tensor] = [] + for i in range(num_reqs): + ql = int(query_lens_i32[i].item()) + sl = int(seq_lens_i32[i].item()) + start = sl - ql + pos_list.append(torch.arange(start, sl, dtype=torch.int64)) + positions_cpu = torch.cat(pos_list, dim=0) + + # slot_mapping: map each query token to a physical slot in the paged KV cache. + # We allocate per-request contiguous blocks; slot = base + position. 
+ slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu") + cursor = 0 + for i in range(num_reqs): + ql = int(query_lens_i32[i].item()) + sl = int(seq_lens_i32[i].item()) + start = sl - ql + for p in range(start, sl): + block_idx = p // self.block_size + offset = p % self.block_size + block_id = int(self._block_table[i, block_idx].item()) + slot_mapping[cursor] = block_id * self.block_size + offset + cursor += 1 + + max_seq_len = int(seq_lens_i32[:num_reqs].max().item()) + query_start_loc_gpu = qsl.to(device=self.device) + seq_lens_gpu = seq_lens_i32.to(device=self.device) + block_table = self._block_table[:num_reqs].contiguous() + slot_mapping_gpu = slot_mapping.to(device=self.device) + + attn_metadata = attn_utils.build_attn_metadata( + self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_cpu=qsl, + seq_lens=seq_lens_gpu, + max_seq_len=max_seq_len, + block_tables=[block_table], + slot_mappings=[slot_mapping_gpu], + kv_cache_config=self.kv_cache_config, + ) + + # Build slot_mappings_by_layer for set_forward_context. 
class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module):
    """Decoder stack of the residual-codebook code predictor.

    Holds ``config.num_hidden_layers`` ``Qwen3DecoderLayer`` modules, a final
    ``RMSNorm`` and one ``nn.Embedding`` table per residual code group
    (``config.num_code_groups - 1`` tables in total).
    """

    def __init__(
        self,
        config: Qwen3TTSTalkerCodePredictorConfig,
        *,
        talker_hidden_size: int | None = None,
        cache_config=None,
        quant_config=None,
        prefix: str = "",
    ) -> None:
        """Build the layers, the final norm and the codec embedding tables.

        Args:
            config: Code predictor configuration (layer count, hidden size,
                vocab size, number of code groups, RMSNorm epsilon).
            talker_hidden_size: Embedding width to use for the codec tables;
                when ``None`` the predictor's own ``config.hidden_size`` is used.
            cache_config: Forwarded unchanged to every decoder layer.
            quant_config: Forwarded to every decoder layer and kept on ``self``
                for use in ``load_weights``.
            prefix: Name prefix used to build each layer's ``prefix`` argument.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.layers = nn.ModuleList(
            [
                Qwen3DecoderLayer(
                    config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.layers.{i}"
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Official code_predictor uses one embedding table per residual group.
        # Some Qwen3-TTS checkpoints store codec embeddings in the talker hidden
        # space, even when `code_predictor_config.hidden_size` is smaller.
        # We keep the embedding dim aligned with the checkpoint and project down
        # via `small_to_mtp_projection` in the wrapper module.
        emb_dim = int(talker_hidden_size) if talker_hidden_size is not None else int(config.hidden_size)
        self.codec_embedding = nn.ModuleList(
            [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
        )

    def get_input_embeddings(self) -> nn.ModuleList:
        """Return the list of per-residual-group codec embedding tables."""
        return self.codec_embedding

    def forward(self, positions: torch.Tensor, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Run the decoder layers and the final norm over already-embedded inputs.

        Args:
            positions: Position tensor handed to every decoder layer.
            inputs_embeds: Input embeddings (token-major per the inline note).

        Returns:
            The normalized hidden states produced by ``self.norm``.
        """
        # Token-major: [num_tokens, hidden]
        hidden_states = inputs_embeds
        # The residual stream starts empty; each layer returns the updated pair.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        # The norm is called with the residual and returns a pair; only the first item is used.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after remapping) that received a weight.
            Names that cannot be resolved to a parameter are skipped silently.
        """
        # Match vLLM Qwen2/Qwen3 packing conventions: q_proj/k_proj/v_proj -> qkv_proj,
        # gate_proj/up_proj -> gate_up_proj.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rotary inverse frequencies are never loaded from the checkpoint.
            if "rotary_emb.inv_freq" in name:
                continue
            # Quantization cache scales: the quant config supplies the target name.
            if self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name)):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scale tensors are reduced to their first element.
                loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Try the packed projections first; the `else` branch below only runs
            # when no mapping entry reached `break`.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                mapped = name.replace(weight_name, param_name)
                if mapped.endswith(".bias") and mapped not in params_dict:
                    continue
                if is_pp_missing_parameter(mapped, self):
                    continue
                if mapped.endswith("scale"):
                    mapped = maybe_remap_kv_scale_name(mapped, params_dict)
                    if mapped is None:
                        continue
                param = params_dict.get(mapped)
                if param is None:
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # The default loader takes no shard id; custom loaders receive it.
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(mapped)
                break
            else:
                # Not a packed projection: load under the (possibly remapped) name.
                mapped = maybe_remap_kv_scale_name(name, params_dict)
                if mapped is None:
                    continue
                if name.endswith(".bias") and mapped not in params_dict:
                    continue
                if is_pp_missing_parameter(mapped, self):
                    continue
                param = params_dict.get(mapped)
                if param is None:
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(mapped)
        return loaded_params
def get_input_embeddings(self) -> nn.ModuleList:
    """Expose the codec embedding tables owned by the inner predictor model."""
    embedding_tables = self.model.get_input_embeddings()
    return embedding_tables


def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into the inner model and this wrapper's own parameters.

    Names starting with ``model.`` are stripped of that prefix and delegated to
    ``self.model.load_weights``; every other name is loaded directly when it
    matches one of this module's parameters and ignored otherwise.

    Returns:
        The set of fully qualified parameter names that were loaded.
    """
    backbone_prefix = "model."
    # Ensure all vLLM custom layers consult the predictor vllm_config
    # (esp. for Attention static_forward_context).
    with set_current_vllm_config(self._vllm_config):
        backbone_weights: list[tuple[str, torch.Tensor]] = []
        wrapper_weights: list[tuple[str, torch.Tensor]] = []
        for weight_name, tensor in weights:
            if weight_name.startswith(backbone_prefix):
                backbone_weights.append((weight_name[len(backbone_prefix) :], tensor))
            else:
                wrapper_weights.append((weight_name, tensor))

        # Re-attach the prefix so reported names match the incoming ones.
        loaded_names: set[str] = set()
        loaded_names |= {f"model.{inner}" for inner in self.model.load_weights(backbone_weights)}

        own_params = dict(self.named_parameters(remove_duplicate=False))
        for weight_name, tensor in wrapper_weights:
            if weight_name in own_params:
                default_weight_loader(own_params[weight_name], tensor)
                loaded_names.add(weight_name)
        return loaded_names


def _maybe_init_kv_cache(self, device: torch.device) -> None:
    """Create the local predictor KV cache on first use; later calls are no-ops."""
    if self._kv_cache is None:
        max_seq_len = int(getattr(self.config, "num_code_groups", 16) or 16)
        # Upper bound on batch size: vLLM scheduler max_num_seqs (fallback 8).
        scheduler_limit = int(getattr(self._vllm_config.scheduler_config, "max_num_seqs", 8) or 8)
        self._kv_cache = _LocalPredictorKVCache(
            vllm_config=self._vllm_config,
            max_seq_len=max_seq_len,
            max_batch_size=max(1, scheduler_limit),
            device=device,
        )


@torch.inference_mode()
def reset_cache(self) -> None:
    """Intentionally do nothing.

    The KV cache buffer is reused and overwritten starting at slot 0, and
    ``seq_lens`` controls what is read, so there is nothing to clear.
    """
    return None
@torch.inference_mode()
def forward(
    self,
    layer0_code: torch.Tensor,
    layer0_embed: torch.Tensor,
    last_talker_hidden: torch.Tensor,
    do_sample: bool = True,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Full autoregressive prediction of residual codebooks 1..Q-1.

    Args:
        layer0_code: [B, 1] first-layer codec token ids.
        layer0_embed: [B, 1, H] embedding of layer0_code.
        last_talker_hidden: [B, 1, H] hidden state from the talker.
        do_sample: whether to sample or take argmax.
        temperature: sampling temperature; a value <= 0 falls back to argmax.
        top_k: keep only the ``top_k`` highest logits; values <= 0 disable the
            filter and values above the vocabulary size are clamped to it.
        top_p: nucleus threshold; applied only when ``0 < top_p < 1``. The most
            likely token is always kept.

    Returns:
        audio_codes: [B, Q] all codebook tokens (layer0 + residuals).
    """
    bsz = int(layer0_code.shape[0])
    num_groups = int(self.config.num_code_groups)
    last_step = num_groups - 1
    neg_inf = float("-inf")

    def _pick_codes(step_logits: torch.Tensor) -> torch.Tensor:
        """Turn [B, vocab] logits into [B, 1] token ids (greedy or sampled)."""
        if not (do_sample and temperature > 0):
            return step_logits.argmax(dim=-1, keepdim=True)

        # Filter in float32 so the cumulative sum used by top-p is not
        # degraded by half-precision rounding.
        scaled = step_logits.float() / temperature
        vocab = int(scaled.shape[-1])

        # Clamp k: `topk` raises when k exceeds the last dimension.
        k = min(int(top_k), vocab)
        if 0 < k < vocab:
            kth_value = scaled.topk(k, dim=-1).values[:, -1:]
            scaled = scaled.masked_fill(scaled < kth_value, neg_inf)

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(scaled, dim=-1, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            # Probability mass strictly before each token; it is 0 for the
            # best token, which therefore always survives.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            drop_sorted = mass_before > top_p
            # Map the drop decisions from sorted order back to vocab order.
            drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
            scaled = scaled.masked_fill(drop, neg_inf)

        probs = torch.softmax(scaled, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    # Reset KV cache for a fresh sequence.
    self.reset_cache()

    # Prefill: feed [last_talker_hidden, layer0_embed] -> logits for group 1.
    prefill_input = torch.cat([last_talker_hidden, layer0_embed], dim=1)  # [B, 2, H]
    logits = self.prefill_logits(prefill_input)  # [B, vocab]

    all_codes = [layer0_code.reshape(bsz, 1)]
    past_seq_len = 2  # the two prefill tokens are already in the cache

    for step in range(1, num_groups):
        next_ids = _pick_codes(logits)  # [B, 1]
        all_codes.append(next_ids)

        # If not the last step, decode one more token.
        if step < last_step:
            logits = self.decode_logits(
                next_ids.reshape(bsz),
                generation_step=step,
                past_seq_len=past_seq_len,
            )
            past_seq_len += 1

    return torch.cat(all_codes, dim=1)  # [B, Q]
# reference implementation.
# Only the classes actually needed by the vLLM AR Talker are kept here.
# ---------------------------------------------------------------------------


class Qwen3TTSTalkerResizeMLP(nn.Module):
    """Resize hidden vectors with two linear layers separated by an activation."""

    def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False):
        super().__init__()
        self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias)
        self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias)
        # `act` is a key into the transformers activation registry.
        self.act_fn = ACT2FN[act]

    def forward(self, hidden_state):
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)


# ---- Speaker encoder (ECAPA-TDNN) and helpers ----


class TimeDelayNetBlock(nn.Module):
    """Length-preserving 1-D convolution (reflect padding) followed by ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        convolved = self.conv(hidden_states)
        return self.activation(convolved)


class Res2NetBlock(torch.nn.Module):
    """Split channels into ``scale`` groups and process them hierarchically.

    The first group passes through unchanged; every later group goes through
    its own ``TimeDelayNetBlock``, and from the third group on the previous
    group's output is added to the input first.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()
        group_in = in_channels // scale
        group_out = out_channels // scale
        # One TDNN per group except the first, which is an identity path.
        self.blocks = nn.ModuleList(
            TimeDelayNetBlock(group_in, group_out, kernel_size=kernel_size, dilation=dilation)
            for _ in range(scale - 1)
        )
        self.scale = scale

    def forward(self, hidden_states):
        groups = torch.chunk(hidden_states, self.scale, dim=1)
        results = [groups[0]]
        previous = None
        for block, group in zip(self.blocks, groups[1:]):
            block_input = group if previous is None else group + previous
            previous = block(block_input)
            results.append(previous)
        return torch.cat(results, dim=1)
class SqueezeExcitationBlock(nn.Module):
    """Channel gating: scale the input by a sigmoid gate computed from its time-mean."""

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, se_channels, kernel_size=1, padding="same", padding_mode="reflect")
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(se_channels, out_channels, kernel_size=1, padding="same", padding_mode="reflect")
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        # Squeeze: average over the last (time) axis, keeping it as size 1.
        squeezed = hidden_states.mean(dim=2, keepdim=True)
        bottleneck = self.relu(self.conv1(squeezed))
        gate = self.sigmoid(self.conv2(bottleneck))
        # Excite: the gate broadcasts over the time axis.
        return hidden_states * gate


class SqueezeExcitationRes2NetBlock(nn.Module):
    """ECAPA-TDNN building block: TDNN, Res2Net, TDNN, SE, plus a skip connection."""

    def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1)
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        transformed = hidden_state
        for stage in (self.tdnn1, self.res2net_block, self.tdnn2, self.se_block):
            transformed = stage(transformed)
        # Skip connection around the whole stack.
        return transformed + hidden_state


class AttentiveStatisticsPooling(nn.Module):
    """Attention-weighted pooling over time that returns concatenated mean and std."""

    def __init__(self, channels, attention_channels=128):
        super().__init__()
        self.eps = 1e-12
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(attention_channels, channels, kernel_size=1, padding="same", padding_mode="reflect")

    @staticmethod
    def _length_to_mask(length, max_len=None, dtype=None, device=None):
        """Build a ``[len(length), max_len]`` mask that is true where position < length."""
        if max_len is None:
            max_len = length.max().long().item()
        positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
        positions = positions.expand(len(length), max_len)
        inside = positions < length.unsqueeze(1)
        return torch.as_tensor(inside, dtype=dtype, device=device)

    @staticmethod
    def _compute_statistics(x, m, dim=2, eps=1e-12):
        """Return the ``m``-weighted mean and standard deviation of ``x`` along ``dim``."""
        mean = (m * x).sum(dim)
        squared_dev = (x - mean.unsqueeze(dim)).pow(2)
        # Clamp before the square root so the result stays finite.
        variance = (m * squared_dev).sum(dim).clamp(eps)
        return mean, torch.sqrt(variance)

    def forward(self, hidden_states):
        batch = hidden_states.shape[0]
        frames = hidden_states.shape[-1]

        # Every sequence is treated as full length, so the mask is all ones.
        full_lengths = torch.ones(batch, device=hidden_states.device) * frames
        frame_mask = self._length_to_mask(
            full_lengths, max_len=frames, dtype=hidden_states.dtype, device=hidden_states.device
        ).unsqueeze(1)

        # Global (uniformly weighted) statistics give the attention its context.
        uniform_weights = frame_mask / frame_mask.sum(dim=2, keepdim=True)
        global_mean, global_std = self._compute_statistics(hidden_states, uniform_weights)
        tiled_mean = global_mean.unsqueeze(2).repeat(1, 1, frames)
        tiled_std = global_std.unsqueeze(2).repeat(1, 1, frames)

        scores = torch.cat([hidden_states, tiled_mean, tiled_std], dim=1)
        scores = self.conv(self.tanh(self.tdnn(scores)))
        scores = scores.masked_fill(frame_mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=2)

        mean, std = self._compute_statistics(hidden_states, weights)
        return torch.cat((mean, std), dim=1).unsqueeze(2)
+ """ + + def __init__(self, config: Qwen3TTSSpeakerEncoderConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], config.enc_channels[-1], config.enc_kernel_sizes[-1], config.enc_dilations[-1] + ) + self.asp = AttentiveStatisticsPooling(config.enc_channels[-1], attention_channels=config.enc_attention_channels) + self.fc = nn.Conv1d( + config.enc_channels[-1] * 2, + config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + hidden_states = self.asp(hidden_states) + hidden_states = self.fc(hidden_states) + return hidden_states.squeeze(-1) + + +# ---- Audio utilities ---- + + +def _dynamic_range_compression(x, c=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * c) + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int | 
None = None, + center: bool = False, +) -> torch.Tensor: + """Calculate mel spectrogram of an input signal using librosa mel filterbank and torch STFT.""" + if torch.min(y) < -1.0: + logger.warning("Min value of input waveform signal is %s", torch.min(y)) + if torch.max(y) > 1.0: + logger.warning("Max value of input waveform signal is %s", torch.max(y)) + device = y.device + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis = torch.from_numpy(mel).float().to(device) + hann_window = torch.hann_window(win_size).to(device) + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + mel_spec = torch.matmul(mel_basis, spec) + return _dynamic_range_compression(mel_spec) + + +# --------------------------------------------------------------------------- +# Main AR Talker model +# --------------------------------------------------------------------------- + + +class Qwen3TTSTalkerForConditionalGeneration(nn.Module): + """vLLM-AR talker: step-wise layer-0 codec decoding. + Predicts residual codebooks (1..Q-1) into `audio_codes` and streams text via `tailing_text_hidden`.""" + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Talker backbone (Qwen3 decoder-only). + "talker.model.layers.": "model.layers.", + "talker.model.norm.": "model.norm.", + "talker.model.codec_embedding.": "model.embed_tokens.", + # Heads / side modules. + "talker.codec_head.": "lm_head.", + "talker.model.text_embedding.": "text_embedding.", + "talker.text_projection.": "text_projection.", + "talker.code_predictor.": "code_predictor.", + # Speaker encoder (Base only). 
+ "speaker_encoder.": "speaker_encoder.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.model_path = vllm_config.model_config.model + self.config: Qwen3TTSConfig = vllm_config.model_config.hf_config # type: ignore[assignment] + self.talker_config: Qwen3TTSTalkerConfig = self.config.talker_config + + # Codec ids: only [0, codebook_vocab_size) are real code indices (layer-0 is sampled from talker vocab). + # codec_eos_token_id is a special stop token and must not be decoded by SpeechTokenizer. + self._codebook_vocab_size = int(getattr(self.talker_config.code_predictor_config, "vocab_size", 0) or 0) + if self._codebook_vocab_size <= 0: + raise ValueError( + f"Invalid talker_config.code_predictor_config.vocab_size={self._codebook_vocab_size}; " + "cannot restrict codec logits safely." + ) + self._codec_eos_token_id = int(getattr(self.talker_config, "codec_eos_token_id", -1)) + + self._eos_logit_bias: float = 0.0 + + self.have_multimodal_outputs = True + self.has_preprocess = True + self.has_postprocess = True + + # Used by OmniGPUModelRunner for the GPU-side MTP fast-path. + self.mtp_hidden_size = int(self.talker_config.hidden_size) + # OmniGPUModelRunner will store talker_mtp output under this key in + # per-request additional_information. + self.talker_mtp_output_key = "audio_codes" + + self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.talker_config.vocab_size, + self.talker_config.hidden_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(self.talker_config.vocab_size) + self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors + + # Text embedding is a separate table in the official implementation. 
+ self.text_embedding = nn.Embedding(self.talker_config.text_vocab_size, self.talker_config.text_hidden_size) + self.text_projection = Qwen3TTSTalkerResizeMLP( + self.talker_config.text_hidden_size, + self.talker_config.text_hidden_size, + self.talker_config.hidden_size, + self.talker_config.hidden_act, + bias=True, + ) + + # Speaker encoder is only needed for Base voice cloning and may be missing in some checkpoints. + # Keep it optional to avoid strict weight-loading failures. + self.speaker_encoder: Qwen3TTSSpeakerEncoder | None = None + + # Code predictor uses an isolated vLLM config so its KV cache doesn't + # pollute the main engine's static_forward_context (shallow-copy shares + # the dict by reference — must assign a fresh one). + predictor_compilation = dataclasses.replace(vllm_config.compilation_config) + predictor_compilation.static_forward_context = {} + self._code_predictor_vllm_config = dataclasses.replace(vllm_config, compilation_config=predictor_compilation) + from vllm.config.vllm import set_current_vllm_config as _set_cfg + + with _set_cfg(self._code_predictor_vllm_config): + self.code_predictor = Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM( + vllm_config=self._code_predictor_vllm_config, + config=self.talker_config.code_predictor_config, + talker_config=self.talker_config, + prefix="code_predictor", + ) + + # Constant logit mask: allow only codec ids [1, codebook_vocab_size) plus codec EOS. + vocab = int(self.talker_config.vocab_size) + codec_mask = torch.zeros((vocab,), dtype=torch.bool) + lo, hi = 1, min(self._codebook_vocab_size, vocab) + if hi > lo: + codec_mask[lo:hi] = True + if 0 <= self._codec_eos_token_id < vocab: + codec_mask[self._codec_eos_token_id] = True + self.register_buffer("_codec_allowed_mask", codec_mask, persistent=False) + + # Tokenizer for prompt building. 
def compute_logits(
    self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None
) -> torch.Tensor | None:
    """Project hidden states to codec logits restricted to the allowed ids.

    An ``OmniOutput`` is unwrapped to its ``text_hidden_states`` first. Ids
    outside the pre-built ``_codec_allowed_mask`` are set to ``-inf``, and a
    non-zero ``_eos_logit_bias`` is added to the codec EOS column when that id
    lies inside the logits' last dimension.

    Returns:
        The masked logits, or ``None`` when there are no hidden states or the
        logits processor returns ``None``.
    """
    states = hidden_states.text_hidden_states if isinstance(hidden_states, OmniOutput) else hidden_states
    if states is None:
        return None

    raw_logits = self.logits_processor(self.lm_head, states)
    if raw_logits is None:
        return None

    # Mask out invalid codec ids using the pre-built constant buffer.
    disallowed = ~self._codec_allowed_mask
    masked_logits = raw_logits.masked_fill(disallowed, float("-inf"))

    bias = self._eos_logit_bias
    if bias != 0.0:
        eos_column = self._codec_eos_token_id
        if 0 <= eos_column < masked_logits.shape[-1]:
            masked_logits[:, eos_column] = masked_logits[:, eos_column] + bias

    return masked_logits
def preprocess(
    self,
    input_ids: torch.Tensor,
    input_embeds: torch.Tensor | None,
    **info_dict: Any,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
    """Build the input embeddings for one scheduled span of a request.

    Three cases are handled, selected by ``input_ids.shape[0]``:

    * ``<= 0``: inputs are returned unchanged with an empty update dict.
    * ``> 1`` (prefill): a slice of the prompt embeddings is returned. On the
      first round they are built via ``_build_prompt_embeds`` and stored in the
      update dict; later rounds slice the stored tensor at
      ``talker_prefill_offset``. Short slices are padded with ``tts_pad_embed``.
    * ``== 1`` (decode): the embedding of the current token id is returned, and
      the next text-step vector plus the last talker hidden state are handed
      over under ``mtp_inputs``.

    Args:
        input_ids: Token ids of the scheduled span.
        input_embeds: Optional embeddings; only used in the empty-span case.
        **info_dict: Per-request metadata, either flattened or nested under
            ``additional_information``.

    Returns:
        ``(input_ids, inputs_embeds, info_update)`` where ``info_update`` holds
        the per-request state to persist for the next call.

    Raises:
        ValueError: if ``text`` is missing or empty for a non-empty span.
        RuntimeError: if state expected from an earlier step is missing.
    """
    # Metadata may be passed flattened or under `additional_information`; normalize to flattened keys.
    additional_information = info_dict.get("additional_information")
    if isinstance(additional_information, dict):
        merged: dict[str, Any] = {k: v for k, v in info_dict.items() if k != "additional_information"}
        # setdefault: flattened keys win over nested ones on conflict.
        for k, v in additional_information.items():
            merged.setdefault(k, v)
        info_dict = merged

    span_len = int(input_ids.shape[0])
    if span_len <= 0:
        return input_ids, input_embeds if input_embeds is not None else self.embed_input_ids(input_ids), {}

    text_list = info_dict.get("text")
    if not isinstance(text_list, list) or not text_list or not text_list[0]:
        raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.")

    # task_type is read as the first element of a sequence; defaults to "CustomVoice".
    task_type = (info_dict.get("task_type") or ["CustomVoice"])[0]
    codec_streaming_val = info_dict.get("codec_streaming")
    if isinstance(codec_streaming_val, list):
        codec_streaming_raw = codec_streaming_val[0] if codec_streaming_val else None
    else:
        codec_streaming_raw = codec_streaming_val
    # An explicit bool wins; otherwise streaming is enabled only for the "Base" task.
    if isinstance(codec_streaming_raw, bool):
        codec_streaming = codec_streaming_raw
    else:
        codec_streaming = task_type == "Base"

    if span_len > 1:
        # Prefill (prompt embeddings)
        prompt_embeds_cpu = info_dict.get("talker_prompt_embeds")
        tts_pad_embed_cpu = info_dict.get("tts_pad_embed")
        tts_pad_embed = None
        if isinstance(tts_pad_embed_cpu, torch.Tensor) and tts_pad_embed_cpu.numel() > 0:
            tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)

        # First prefill round: prompt_embeds_cpu is not yet populated.
        # Subsequent prefill rounds (multi-chunk): prompt_embeds_cpu is a Tensor stored by the first round.
        is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2
        if is_first_prefill:
            full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len = self._build_prompt_embeds(
                task_type=task_type, info_dict=info_dict
            )
            # Store full prompt embeddings + trailing queue on CPU for later chunks/steps.
            prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous()
            info_update: dict[str, Any] = {
                "talker_prompt_embeds": prompt_embeds_cpu,
                "tailing_text_hidden": tailing_text_hidden.detach().to("cpu").contiguous(),
                "tts_pad_embed": tts_pad_embed.detach().to("cpu").contiguous(),
                "talker_prefill_offset": 0,
                "codec_streaming": codec_streaming,
            }
            if ref_code_len is not None:
                info_update["ref_code_len"] = int(ref_code_len)
            # Always return a span_len slice; if the scheduled placeholder is longer, pad with tts_pad_embed.
            # This preserves placeholder/embedding alignment.
            offset = 0
            s = 0
            e = span_len
            take = prompt_embeds_cpu[s:e]
            if int(take.shape[0]) < span_len:
                pad_n = int(span_len - int(take.shape[0]))
                pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1)
                take = torch.cat([take, pad_rows], dim=0)
            prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
            # Record how far into the stored prompt the next chunk starts.
            info_update["talker_prefill_offset"] = int(offset + span_len)
        else:
            # Subsequent prefill chunk: slice from stored embeddings at running offset.
            if tts_pad_embed is None:
                raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.")
            offset = int(info_dict.get("talker_prefill_offset", 0) or 0)
            if offset < 0:
                offset = 0
            # Clamp both ends to the stored length; a short or empty slice is padded below.
            s = max(0, min(offset, int(prompt_embeds_cpu.shape[0])))
            e = max(0, min(offset + span_len, int(prompt_embeds_cpu.shape[0])))
            take = prompt_embeds_cpu[s:e]
            if int(take.shape[0]) < span_len:
                pad_n = int(span_len - int(take.shape[0]))
                pad_rows = tts_pad_embed.detach().to("cpu").contiguous().reshape(1, -1).expand(pad_n, -1)
                take = torch.cat([take, pad_rows], dim=0)
            prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
            info_update = {"talker_prefill_offset": int(offset + span_len)}
            info_update["codec_streaming"] = codec_streaming

        # When inputs_embeds is set, token ids are ignored by the model but must stay in-vocab for vLLM bookkeeping.
        input_ids_out = input_ids.clone()
        input_ids_out[:] = int(self.talker_config.codec_pad_id)

        # All-zero audio_codes with one row per prefill token.
        zeros = torch.zeros(
            (prompt_embeds.shape[0], int(self.talker_config.num_code_groups)),
            device=input_ids.device,
            dtype=torch.long,
        )
        info_update["audio_codes"] = zeros
        return input_ids_out, prompt_embeds, info_update

    # Decode: span_len == 1
    # Pop one text-step vector from tailing_text_hidden queue.
    tts_pad_embed_cpu = info_dict.get("tts_pad_embed")
    if not isinstance(tts_pad_embed_cpu, torch.Tensor):
        raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.")
    tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)

    tail_cpu = info_dict.get("tailing_text_hidden")
    if isinstance(tail_cpu, torch.Tensor) and tail_cpu.ndim == 2 and tail_cpu.shape[0] > 0:
        # Take the first row as this step's text vector and keep the remainder.
        text_step = tail_cpu[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
        new_tail = tail_cpu[1:].detach().to("cpu").contiguous() if tail_cpu.shape[0] > 1 else tail_cpu[:0]
    else:
        # Queue exhausted or missing: feed the pad embedding instead.
        text_step = tts_pad_embed
        new_tail = tail_cpu if isinstance(tail_cpu, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1]))

    last_hidden_cpu = info_dict.get("last_talker_hidden")
    if not isinstance(last_hidden_cpu, torch.Tensor):
        raise RuntimeError("Missing `last_talker_hidden` in additional_information; postprocess must run.")
    past_hidden = last_hidden_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)

    # Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update.
    last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to(
        device=input_ids.device, dtype=torch.bfloat16
    )
    inputs_embeds_out = last_id_hidden.reshape(1, -1)

    info_update = {
        "tailing_text_hidden": new_tail,
        "mtp_inputs": (past_hidden, text_step),
        "codec_streaming": codec_streaming,
    }
    return input_ids, inputs_embeds_out, info_update
+ if hidden_states.numel() == 0: + return {} + last = hidden_states[-1, :].detach().to("cpu").contiguous() + return {"last_talker_hidden": last} + + # -------------------- prompt construction helpers -------------------- + + def _get_tokenizer(self): + if self._tokenizer is None: + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=True, + fix_mistral_regex=True, + use_fast=True, + ) + self._tokenizer.padding_side = "left" + return self._tokenizer + + @staticmethod + def _build_assistant_text(text: str) -> str: + return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + + @staticmethod + def _build_ref_text(text: str) -> str: + return f"<|im_start|>assistant\n{text}<|im_end|>\n" + + @staticmethod + def _build_instruct_text(instruct: str) -> str: + return f"<|im_start|>user\n{instruct}<|im_end|>\n" + + @staticmethod + def estimate_prompt_len_from_additional_information( + additional_information: dict[str, Any] | None, + *, + task_type: str, + tokenize_prompt: Callable[[str], list[int]], + codec_language_id: Mapping[str, int] | None, + spk_is_dialect: Mapping[str, object] | None, + estimate_ref_code_len: Callable[[object], int | None] | None = None, + ) -> int: + """Compute Stage-0 placeholder prompt length (length-only mirror of `_build_prompt_embeds()`). 
+ It must match the model-side `inputs_embeds` length to avoid extra padding and quality drop.""" + + def _first(x: object, default: object) -> object: + if isinstance(x, list): + return x[0] if x else default + return x if x is not None else default + + info: dict[str, Any] = additional_information or {} + text = _first(info.get("text"), "") + language = _first(info.get("language"), "Auto") + speaker = _first(info.get("speaker"), "") + instruct = _first(info.get("instruct"), "") + non_streaming_mode_raw = _first(info.get("non_streaming_mode"), None) + + if isinstance(non_streaming_mode_raw, bool): + non_streaming_mode = non_streaming_mode_raw + else: + # Official defaults: CustomVoice/VoiceDesign -> non_streaming_mode=True; Base -> False. + non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") + + if not isinstance(text, str): + text = "" + if not isinstance(instruct, str): + instruct = "" + if not isinstance(language, str): + language = "Auto" + + instruct_len = 0 + if instruct.strip(): + instruct_text = Qwen3TTSTalkerForConditionalGeneration._build_instruct_text(instruct) + instruct_len = len(tokenize_prompt(instruct_text)) + + # ---- codec prefix portion (matches _build_prompt_embeds) ---- + language_id = None + if language.lower() != "auto" and codec_language_id: + language_id = codec_language_id.get(language.lower()) + if ( + language_id is None + and codec_language_id + and spk_is_dialect + and isinstance(language, str) + and language.lower() in ("chinese", "auto") + and isinstance(speaker, str) + and speaker.strip() + ): + dialect = spk_is_dialect.get(speaker.lower()) + if isinstance(dialect, str) and dialect: + language_id = codec_language_id.get(dialect) + prefill_len = 3 if language_id is None else 4 + + speaker_len = 1 if task_type in ("CustomVoice", "Base") else 0 + codec_input_len = prefill_len + speaker_len + 2 # + [codec_pad, codec_bos] + codec_prefix_len = codec_input_len - 1 # codec_input[:-1] + tts_bos + + # Role header: input_ids[:, 
:3] in model. + role_len = 3 + prompt_len = instruct_len + role_len + codec_prefix_len + + # ---- text conditioning portion (matches _build_prompt_embeds) ---- + assistant_text = Qwen3TTSTalkerForConditionalGeneration._build_assistant_text(text) + assistant_len = len(tokenize_prompt(assistant_text)) + if assistant_len < 8: + raise ValueError(f"Unexpected assistant prompt length: {assistant_len}") + + if task_type in ("CustomVoice", "VoiceDesign"): + if non_streaming_mode: + # model: full text ids (input_ids[:, 3:-5]) + eos + codec_bos step + prompt_len += assistant_len - 6 + else: + # model: only first text token in prefill + prompt_len += 1 + + if task_type == "Base": + xvec_only = bool(_first(info.get("x_vector_only_mode"), False)) + in_context_mode = not xvec_only + + voice_clone_prompt = _first(info.get("voice_clone_prompt"), None) + if isinstance(voice_clone_prompt, dict): + icl_flag = _first(voice_clone_prompt.get("icl_mode"), None) + if isinstance(icl_flag, bool): + in_context_mode = icl_flag + + if in_context_mode: + ref_code = None + if isinstance(voice_clone_prompt, dict): + ref_code = _first(voice_clone_prompt.get("ref_code"), None) + + ref_code_len: int | None = None + if isinstance(ref_code, list): + if ref_code and isinstance(ref_code[0], list): + ref_code_len = len(ref_code) + elif ref_code: + ref_code_len = len(ref_code) + elif hasattr(ref_code, "shape"): + try: + shape = getattr(ref_code, "shape") + if shape and len(shape) >= 1: + ref_code_len = int(shape[0]) + except Exception: + ref_code_len = None + + if ref_code_len is None and estimate_ref_code_len is not None: + ref_code_len = estimate_ref_code_len(info.get("ref_audio")) + + if ref_code_len is None: + raise ValueError( + "Base in-context voice cloning requires either `voice_clone_prompt.ref_code` " + "or a readable `ref_audio` that can be mapped to a codec frame length." 
+ ) + + codec_lens = 1 + int(ref_code_len) # codec_bos + ref_code + if non_streaming_mode: + # _generate_icl_prompt(non_streaming_mode=True): + # text_embed = ref_ids + text_ids + eos. + ref_ids = _first(info.get("ref_ids"), None) + if isinstance(voice_clone_prompt, dict) and ref_ids is None: + ref_ids = _first(voice_clone_prompt.get("ref_ids") or voice_clone_prompt.get("ref_id"), None) + + if ref_ids is None: + ref_text = _first(info.get("ref_text"), "") + if not isinstance(ref_text, str) or not ref_text.strip(): + raise ValueError( + "Base in-context non-streaming requires `ref_text` or tokenized `ref_ids`." + ) + ref_text_ids = tokenize_prompt(Qwen3TTSTalkerForConditionalGeneration._build_ref_text(ref_text)) + ref_ids_len = len(ref_text_ids) + elif hasattr(ref_ids, "shape"): + shape = getattr(ref_ids, "shape", None) + ref_ids_len = int(shape[-1]) if shape else 0 + elif isinstance(ref_ids, list): + ref_ids_len = len(ref_ids) + else: + ref_ids_len = 0 + + # model uses ref_ids[:, 3:-2] (strip 5 tokens) and text_id=input_ids[:, 3:-5] (strip 8). + ref_id_len = max(0, int(ref_ids_len) - 5) + text_id_len = max(0, int(assistant_len) - 8) + text_embed_len = ref_id_len + text_id_len + 1 # + eos + prompt_len += text_embed_len + codec_lens + else: + # _generate_icl_prompt(non_streaming_mode=False): aligned to codec_lens. + prompt_len += codec_lens + else: + # Base without ICL behaves like CustomVoice. 
+ if non_streaming_mode: + prompt_len += assistant_len - 6 + else: + prompt_len += 1 + + return max(2, int(prompt_len)) + + def _is_probably_base64(self, s: str) -> bool: + if s.startswith("data:audio"): + return True + if ("/" not in s and "\\" not in s) and len(s) > 256: + return True + return False + + def _is_url(self, s: str) -> bool: + try: + u = urlparse(s) + return u.scheme in ("http", "https") and bool(u.netloc) + except Exception: + return False + + def _decode_base64_to_wav_bytes(self, b64: str) -> bytes: + if "," in b64 and b64.strip().startswith("data:"): + b64 = b64.split(",", 1)[1] + return base64.b64decode(b64) + + def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]: + """Load audio from local path or base64 data URI (no network I/O).""" + import librosa + + if self._is_url(x): + raise ValueError("ref_audio URLs must be resolved by the serving layer before reaching the model worker.") + if self._is_probably_base64(x): + wav_bytes = self._decode_base64_to_wav_bytes(x) + with io.BytesIO(wav_bytes) as f: + audio, sr = sf.read(f, dtype="float32", always_2d=False) + else: + audio, sr = librosa.load(x, sr=None, mono=True) + + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = np.mean(audio, axis=-1) + + return np.asarray(audio, dtype=np.float32), int(sr) + + def _normalize_ref_audio(self, ref_audio: object) -> tuple[np.ndarray, int]: + # NOTE: additional_information may serialize (wav, sr) into (nested) lists across processes; be tolerant. 
+ if isinstance(ref_audio, str): + return self._load_audio_to_np(ref_audio) + + def _is_sr(x: object) -> bool: + try: + v = int(x) # type: ignore[arg-type] + except Exception: + return False + return 1_000 <= v <= 200_000 + + def _is_number_sequence(xs: list[object]) -> bool: + if not xs: + return False + for v in xs[:8]: + if not isinstance(v, (int, float, np.number)): + return False + return True + + wav_candidates: list[object] = [] + sr_candidates: list[int] = [] + + def _summarize(obj: object, depth: int = 0) -> str: + if depth > 2: + if isinstance(obj, (int, np.integer)): + return f"int({int(obj)})" + return type(obj).__name__ + if obj is None: + return "None" + if isinstance(obj, str): + if len(obj) <= 16: + return f"str({obj!r})" + return f"str(len={len(obj)})" + if isinstance(obj, (int, float, np.number)): + return f"{type(obj).__name__}({obj})" + if isinstance(obj, np.ndarray): + return f"ndarray(shape={obj.shape}, dtype={obj.dtype})" + if isinstance(obj, torch.Tensor): + return f"Tensor(shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device})" + if isinstance(obj, dict): + keys = list(obj.keys()) + return f"dict(keys={keys[:8]})" + if isinstance(obj, (tuple, list)): + items = list(obj) + head = ", ".join(_summarize(x, depth + 1) for x in items[:3]) + return f"{type(obj).__name__}(len={len(items)}; head=[{head}])" + return f"{type(obj).__name__}" + + def _scan(obj: object, depth: int = 0) -> None: + if depth > 4: + return + if obj is None: + return + if _is_sr(obj): + sr_candidates.append(int(obj)) # type: ignore[arg-type] + return + if isinstance(obj, np.ndarray) and obj.size > 0: + wav_candidates.append(obj) + return + if isinstance(obj, torch.Tensor) and obj.numel() > 0: + wav_candidates.append(obj) + return + if isinstance(obj, dict): + # Inlined ndarray/tensor payloads from OmniInputProcessor. 
+ if obj.get("__ndarray__") and "data" in obj and "dtype" in obj and "shape" in obj: + try: + data = obj["data"] + dtype = obj["dtype"] + shape = obj["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + arr = np.frombuffer(data, dtype=dtype).reshape(shape) + if arr.size > 0: + wav_candidates.append(arr) + return + except Exception: + pass + if obj.get("__tensor__") and "data" in obj and "dtype" in obj and "shape" in obj: + try: + data = obj["data"] + dtype = obj["dtype"] + shape = obj["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + # Stored as raw CPU bytes; interpret as numpy for audio. + np_dtype = np.dtype(dtype) + arr = np.frombuffer(data, dtype=np_dtype).reshape(shape) + if arr.size > 0: + wav_candidates.append(arr) + return + except Exception: + pass + wav_obj = obj.get("array") or obj.get("wav") or obj.get("audio") + sr_obj = obj.get("sampling_rate") or obj.get("sr") or obj.get("sample_rate") + if wav_obj is not None: + _scan(wav_obj, depth + 1) + if sr_obj is not None: + _scan(sr_obj, depth + 1) + return + if isinstance(obj, (tuple, list)): + obj_list = list(obj) + # Unwrap singleton nesting ([[wav, sr]]). + while isinstance(obj_list, list) and len(obj_list) == 1: + inner = obj_list[0] + if isinstance(inner, np.ndarray) and inner.size > 0: + wav_candidates.append(inner) + return + if isinstance(inner, torch.Tensor) and inner.numel() > 0: + wav_candidates.append(inner) + return + if isinstance(inner, dict): + _scan(inner, depth + 1) + return + if isinstance(inner, (tuple, list)): + obj_list = list(inner) # type: ignore[list-item] + continue + break + + # If the *unwrapped* list is a long list of numbers, treat it as waveform. + if len(obj_list) >= 512 and _is_number_sequence(obj_list): + wav_candidates.append(obj_list) + return + + # Otherwise, recurse into elements (but avoid descending into huge numeric lists). 
+ for item in obj_list: + if isinstance(item, list) and len(item) >= 512 and _is_number_sequence(item): # type: ignore[arg-type] + wav_candidates.append(item) + continue + _scan(item, depth + 1) + return + + _scan(ref_audio) + if not sr_candidates: + raise TypeError(f"ref_audio missing sample_rate: {_summarize(ref_audio)}") + sr = int(sr_candidates[0]) + + def _wav_len(x: object) -> int: + try: + if isinstance(x, np.ndarray): + return int(x.size) + if isinstance(x, torch.Tensor): + return int(x.numel()) + if isinstance(x, list): + return int(len(x)) + except Exception: + pass + return 0 + + if not wav_candidates: + raise TypeError(f"ref_audio missing waveform: {_summarize(ref_audio)}") + wav_obj = max(wav_candidates, key=_wav_len) + + def _to_np(x: object) -> np.ndarray: + if isinstance(x, np.ndarray): + return x.astype(np.float32).reshape(-1) + if isinstance(x, torch.Tensor): + return x.detach().to("cpu").float().contiguous().numpy().reshape(-1) + if isinstance(x, dict) and x.get("__ndarray__") and "data" in x and "dtype" in x and "shape" in x: + data = x["data"] + dtype = x["dtype"] + shape = x["shape"] + if isinstance(data, (bytes, bytearray, memoryview)): + return np.frombuffer(data, dtype=dtype).reshape(shape).astype(np.float32).reshape(-1) + if isinstance(x, list): + # list of numbers + if len(x) >= 2 and _is_number_sequence(x): # type: ignore[arg-type] + return np.asarray(x, dtype=np.float32).reshape(-1) + # list of chunks + parts: list[np.ndarray] = [] + for part in x: + if isinstance(part, (np.ndarray, torch.Tensor, list)): + parts.append(_to_np(part)) + if parts: + return np.concatenate(parts, axis=0) + raise TypeError(f"Unsupported waveform type: {type(x)}") + + wav_np = _to_np(wav_obj) + if wav_np.size < 1024: + raise ValueError(f"ref_audio waveform too short: {wav_np.size} samples") + return wav_np, sr + + def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: + if self.speaker_encoder is None: + raise ValueError( + "This 
checkpoint does not provide `speaker_encoder` weights; " + "cannot compute ref_spk_embedding from ref_audio." + ) + # vLLM workers do not automatically move arbitrary torch.nn.Modules to + # CUDA. Ensure the speaker encoder is on the same device/dtype as the + # main model before running it. + dev = next(self.parameters()).device + try: + spk_param = next(self.speaker_encoder.parameters()) + if spk_param.device != dev or spk_param.dtype != torch.bfloat16: + self.speaker_encoder.to(device=dev, dtype=torch.bfloat16) + except StopIteration: + pass + # Resample to 24kHz for speaker encoder. + target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000)) + if sr != target_sr: + import librosa + + wav = librosa.resample(y=wav.astype(np.float32), orig_sr=int(sr), target_sr=target_sr) + sr = target_sr + + # Follow official implementation: mel_spectrogram expects 24kHz. + mels = mel_spectrogram( + torch.from_numpy(wav).unsqueeze(0), + n_fft=1024, + num_mels=128, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=12000, + ).transpose(1, 2) + spk = self.speaker_encoder(mels.to(dev, dtype=torch.bfloat16))[0] + return spk.to(dtype=torch.bfloat16) + + def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer: + if self._speech_tokenizer is not None: + return self._speech_tokenizer + speech_tokenizer_path = cached_file(self.model_path, "speech_tokenizer/config.json") + if speech_tokenizer_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/config.json not found") + # Ensure the HF feature extractor config is present. Transformers' + # AutoFeatureExtractor does not proactively fetch this file. 
+ preprocessor_config_path = cached_file(self.model_path, "speech_tokenizer/preprocessor_config.json") + if preprocessor_config_path is None: + raise ValueError(f"{self.model_path}/speech_tokenizer/preprocessor_config.json not found") + speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path) + tok = Qwen3TTSTokenizer.from_pretrained( + speech_tokenizer_dir, + torch_dtype=torch.bfloat16, + ) + # Prefer GPU for encoder if available; otherwise keep CPU. + dev = next(self.parameters()).device + if getattr(dev, "type", None) == "cuda": + try: + tok.model.to(dev) + tok.device = dev + except Exception as e: + raise RuntimeError(f"Failed to move speech tokenizer to {dev}: {e}") from e + else: + tok.device = dev + self._speech_tokenizer = tok + return tok + + def _encode_ref_audio_to_code(self, wav: np.ndarray, sr: int) -> torch.Tensor: + tok = self._ensure_speech_tokenizer_loaded() + enc = tok.encode(wav, sr=int(sr), return_dict=True) + ref_code = getattr(enc, "audio_codes", None) + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if isinstance(ref_code, torch.Tensor): + # 12Hz: likely [T, Q] or [B, T, Q] + if ref_code.ndim == 3: + ref_code = ref_code[0] + return ref_code.to(device=next(self.parameters()).device, dtype=torch.long) + raise ValueError("SpeechTokenizer.encode did not return audio_codes tensor") + + def _generate_icl_prompt( + self, + *, + text_id: torch.Tensor, + ref_id: torch.Tensor, + ref_code: torch.Tensor, + tts_pad_embed: torch.Tensor, + tts_eos_embed: torch.Tensor, + non_streaming_mode: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Ported from official Qwen3TTSForConditionalGeneration.generate_icl_prompt + text_embed = self.text_projection(self.text_embedding(torch.cat([ref_id, text_id], dim=-1))) + text_embed = torch.cat([text_embed, tts_eos_embed], dim=1) + + # codec embed (codec bos + codec) 1 T2 D + codec_embed: list[torch.Tensor] = [] + for i in range(int(self.talker_config.num_code_groups)): + if i == 0: 
+ codec_embed.append(self.embed_input_ids(ref_code[:, :1])) + else: + codec_embed.append(self.code_predictor.get_input_embeddings()[i - 1](ref_code[:, i : i + 1])) + codec_embed_sum = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0) # [1,T,H] + codec_embed_sum = torch.cat( + [ + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=codec_embed_sum.device, dtype=torch.long) + ), + codec_embed_sum, + ], + dim=1, + ) + + text_lens = int(text_embed.shape[1]) + codec_lens = int(codec_embed_sum.shape[1]) + if non_streaming_mode: + # Official non-streaming mode: append the full text conditioning in + # prefill, and use PAD in decode steps. + icl_input_embed = text_embed + self.embed_input_ids( + torch.tensor( + [[self.talker_config.codec_pad_id] * text_lens], + device=codec_embed_sum.device, + dtype=torch.long, + ) + ) + icl_input_embed = torch.cat([icl_input_embed, codec_embed_sum + tts_pad_embed], dim=1) + return icl_input_embed, tts_pad_embed + if text_lens > codec_lens: + return text_embed[:, :codec_lens] + codec_embed_sum, text_embed[:, codec_lens:] + text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1) + return text_embed + codec_embed_sum, tts_pad_embed + + def _build_prompt_embeds( + self, + *, + task_type: str, + info_dict: dict[str, Any], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None]: + text = (info_dict.get("text") or [""])[0] + language = (info_dict.get("language") or ["Auto"])[0] + non_streaming_mode_val = info_dict.get("non_streaming_mode") + if isinstance(non_streaming_mode_val, list): + non_streaming_mode_raw = non_streaming_mode_val[0] if non_streaming_mode_val else None + else: + non_streaming_mode_raw = non_streaming_mode_val + if isinstance(non_streaming_mode_raw, bool): + non_streaming_mode = non_streaming_mode_raw + else: + # Match official inference defaults: + # - CustomVoice/VoiceDesign: non_streaming_mode=True + # - Base: non_streaming_mode=False + 
non_streaming_mode = task_type in ("CustomVoice", "VoiceDesign") + + # Text ids for assistant template (always). + tok = self._get_tokenizer() + input_ids = tok(self._build_assistant_text(text), return_tensors="pt", padding=False)["input_ids"].to( + device=next(self.parameters()).device + ) + + # Optional instruct prefix. + instruct = (info_dict.get("instruct") or [""])[0] + instruct_embed = None + if isinstance(instruct, str) and instruct.strip(): + instruct_ids = tok(self._build_instruct_text(instruct), return_tensors="pt", padding=False)["input_ids"].to( + device=input_ids.device + ) + instruct_embed = self.text_projection(self.text_embedding(instruct_ids)) + + # tts special token embeds (projected into talker hidden). + tts_tokens = torch.tensor( + [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]], + device=input_ids.device, + dtype=input_ids.dtype, + ) + tts_bos_embed, tts_eos_embed, tts_pad_embed = self.text_projection(self.text_embedding(tts_tokens)).chunk( + 3, dim=1 + ) + + # Codec prefill tags. + language_id = None + if isinstance(language, str) and language.lower() != "auto": + language_id = self.talker_config.codec_language_id.get(language.lower()) + # Match official dialect override: + # If language is Chinese/Auto and the selected speaker is a dialect voice, + # set language_id to that dialect to improve code generation stability. 
+ if language_id is None and isinstance(language, str) and language.lower() in ("chinese", "auto"): + speaker_for_dialect = None + if task_type == "CustomVoice": + speaker_for_dialect = (info_dict.get("speaker") or [""])[0] + if isinstance(speaker_for_dialect, str) and speaker_for_dialect.strip(): + spk_is_dialect = getattr(self.talker_config, "spk_is_dialect", None) or {} + dialect = spk_is_dialect.get(speaker_for_dialect.lower()) + if isinstance(dialect, str) and dialect: + language_id = self.talker_config.codec_language_id.get(dialect) + if language_id is None: + codec_prefill_list = [ + [ + self.talker_config.codec_nothink_id, + self.talker_config.codec_think_bos_id, + self.talker_config.codec_think_eos_id, + ] + ] + else: + codec_prefill_list = [ + [ + self.talker_config.codec_think_id, + self.talker_config.codec_think_bos_id, + int(language_id), + self.talker_config.codec_think_eos_id, + ] + ] + + codec_input_0 = self.embed_input_ids( + torch.tensor(codec_prefill_list, device=input_ids.device, dtype=torch.long) + ) + codec_input_1 = self.embed_input_ids( + torch.tensor([[self.talker_config.codec_pad_id, self.talker_config.codec_bos_id]], device=input_ids.device) + ) + + # Speaker embedding/token (task-dependent) + speaker_embed = None + ref_code_len: int | None = None + + def _as_singleton(x: object) -> object: + if isinstance(x, list): + return x[0] if x else None + return x + + def _to_long_tensor(x: object, *, device: torch.device) -> torch.Tensor | None: + x = _as_singleton(x) + if x is None: + return None + if isinstance(x, torch.Tensor): + t = x + elif isinstance(x, np.ndarray): + t = torch.from_numpy(x) + elif isinstance(x, list) and x and all(isinstance(v, (int, np.integer)) for v in x): + t = torch.tensor(x, dtype=torch.long) + else: + return None + if t.ndim == 1: + t = t.unsqueeze(0) + return t.to(device=device, dtype=torch.long) + + def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: + raw = _as_singleton(raw) + if raw is 
None: + return None + if isinstance(raw, dict): + return raw + # Some callers may pass list[dict] directly. + if isinstance(raw, list) and raw and isinstance(raw[0], dict): + return raw[0] + return None + + if task_type == "Base": + # Base supports voice clone prompt with in-context mode. + xvec_only = bool((info_dict.get("x_vector_only_mode") or [False])[0]) + in_context_mode = not xvec_only + voice_clone_prompt = _normalize_voice_clone_prompt(info_dict.get("voice_clone_prompt")) + # Official implementation may pass `voice_clone_prompt.icl_mode`. + if voice_clone_prompt is not None and "icl_mode" in voice_clone_prompt: + icl_flag = _as_singleton(voice_clone_prompt.get("icl_mode")) + if isinstance(icl_flag, bool): + in_context_mode = icl_flag + xvec_only = not in_context_mode + ref_code = None + if voice_clone_prompt is not None: + ref_code = _as_singleton(voice_clone_prompt.get("ref_code")) + ref_code_t = None + if isinstance(ref_code, torch.Tensor): + ref_code_t = ref_code + elif isinstance(ref_code, np.ndarray): + ref_code_t = torch.from_numpy(ref_code) + if isinstance(ref_code_t, torch.Tensor): + if ref_code_t.ndim == 3: + ref_code_t = ref_code_t[0] + ref_code_t = ref_code_t.to(device=input_ids.device, dtype=torch.long) + ref_code_len = int(ref_code_t.shape[0]) + elif in_context_mode: + # Compute ref_code from ref_audio if not provided. + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + raise ValueError("Base requires `ref_audio`.") + wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + ref_code_t = self._encode_ref_audio_to_code(wav_np, sr).to(device=input_ids.device) + ref_code_len = int(ref_code_t.shape[0]) + + # Speaker embedding: use prompt embed if provided; otherwise extract from audio. 
+ spk = None + if voice_clone_prompt is not None: + spk = _as_singleton(voice_clone_prompt.get("ref_spk_embedding")) + if isinstance(spk, torch.Tensor): + speaker_embed = spk.to(device=input_ids.device, dtype=torch.bfloat16).view(1, 1, -1) + else: + ref_audio_list = info_dict.get("ref_audio") + if not isinstance(ref_audio_list, list) or not ref_audio_list: + raise ValueError("Base requires `ref_audio`.") + wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) + speaker_embed = self._extract_speaker_embedding(wav_np, sr).view(1, 1, -1) + + codec_input = torch.cat([codec_input_0, speaker_embed, codec_input_1], dim=1) + + # Role header (<|im_start|>assistant\n) -> projected text embeds. + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if in_context_mode: + # Prefer explicit tokenized `ref_ids` if provided (matches official signature). 
+ ref_ids = _to_long_tensor(info_dict.get("ref_ids"), device=input_ids.device) + if ref_ids is None and voice_clone_prompt is not None: + ref_ids = _to_long_tensor( + voice_clone_prompt.get("ref_ids") or voice_clone_prompt.get("ref_id"), device=input_ids.device + ) + if ref_ids is None: + ref_text = _as_singleton(info_dict.get("ref_text")) + if not isinstance(ref_text, str) or not ref_text.strip(): + raise ValueError("Base in-context voice cloning requires `ref_text` or tokenized `ref_ids`.") + ref_ids = tok(self._build_ref_text(ref_text), return_tensors="pt", padding=False)["input_ids"].to( + device=input_ids.device + ) + icl_input_embed, trailing_text_hidden = self._generate_icl_prompt( + text_id=input_ids[:, 3:-5], + ref_id=ref_ids[:, 3:-2], + ref_code=ref_code_t, # type: ignore[arg-type] + tts_pad_embed=tts_pad_embed, + tts_eos_embed=tts_eos_embed, + non_streaming_mode=non_streaming_mode, + ) + talker_prompt = torch.cat([talker_prompt, icl_input_embed], dim=1) + else: + # First text token (+ codec_bos). + if non_streaming_mode: + # Official non-streaming mode: put the full text into the + # prefill prompt and use PAD for decode steps. 
+ text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + + elif task_type == "CustomVoice": + speaker = (info_dict.get("speaker") or [""])[0] + if not isinstance(speaker, str) or not speaker.strip(): + raise ValueError("CustomVoice requires additional_information.speaker.") + spk_id_map = getattr(self.talker_config, "spk_id", None) or {} + if speaker.lower() not in spk_id_map: + raise ValueError(f"Unsupported speaker: {speaker}") + spk_id = spk_id_map[speaker.lower()] + # Keep it at least 1D; embedding on a 0-d tensor can return 1D. 
+ spk_tensor = torch.tensor([spk_id], device=input_ids.device, dtype=torch.long) + spk_embed = self.embed_input_ids(spk_tensor) + if spk_embed.ndim == 1: + spk_embed = spk_embed.view(1, 1, -1) + elif spk_embed.ndim == 2: + spk_embed = spk_embed.view(1, 1, -1) + speaker_embed = spk_embed + codec_input = torch.cat([codec_input_0, speaker_embed, codec_input_1], dim=1) + + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if non_streaming_mode: + text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + + elif task_type == "VoiceDesign": + # No known speaker identity; only codec tags + text. 
+ codec_input = torch.cat([codec_input_0, codec_input_1], dim=1) + + role_embed = self.text_projection(self.text_embedding(input_ids[:, :3])) + codec_prefix = torch.cat((tts_pad_embed.expand(-1, codec_input.shape[1] - 2, -1), tts_bos_embed), dim=1) + codec_prefix = codec_prefix + codec_input[:, :-1] + talker_prompt = torch.cat((role_embed, codec_prefix), dim=1) + + if non_streaming_mode: + text_all = self.text_projection(self.text_embedding(input_ids[:, 3:-5])) + text_all = torch.cat([text_all, tts_eos_embed], dim=1) + pad_ids = torch.full( + (1, int(text_all.shape[1])), + int(self.talker_config.codec_pad_id), + device=input_ids.device, + dtype=torch.long, + ) + talker_prompt = torch.cat( + [ + talker_prompt, + text_all + self.embed_input_ids(pad_ids), + tts_pad_embed + + self.embed_input_ids( + torch.tensor([[self.talker_config.codec_bos_id]], device=input_ids.device) + ), + ], + dim=1, + ) + trailing_text_hidden = tts_pad_embed + else: + first_text = self.text_projection(self.text_embedding(input_ids[:, 3:4])) + codec_input[:, -1:] + talker_prompt = torch.cat([talker_prompt, first_text], dim=1) + trailing_text_hidden = torch.cat( + ( + self.text_projection(self.text_embedding(input_ids[:, 4:-5])), + tts_eos_embed, + ), + dim=1, + ) + else: + raise ValueError(f"Unsupported task_type={task_type}") + + if instruct_embed is not None: + talker_prompt = torch.cat([instruct_embed, talker_prompt], dim=1) + + return ( + talker_prompt.squeeze(0), # [prompt_len, H] + trailing_text_hidden.squeeze(0), # [T, H] + tts_pad_embed.squeeze(0), # [1, H] + ref_code_len, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Consume talker weights, and conditionally consume speaker encoder + # weights only if they are present in the checkpoint. 
+ speaker_weights: list[tuple[str, torch.Tensor]] = [] + + def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]): + for k, v in ws: + if k.startswith("speaker_encoder."): + speaker_weights.append((k, v)) + continue + if k.startswith("talker."): + yield k, v + + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(_talker_and_collect_speaker(weights), mapper=self.hf_to_vllm_mapper) + + if speaker_weights: + if self.speaker_encoder is None: + self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config) + loaded |= loader.load_weights(speaker_weights, mapper=self.hf_to_vllm_mapper) + logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGeneration", len(loaded)) + return loaded + + # -------------------- GPU-side MTP fast-path -------------------- + + @torch.inference_mode() + def talker_mtp( + self, + input_ids: torch.Tensor, + input_embeds: torch.Tensor, + last_talker_hidden: torch.Tensor, + text_step: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1). + Returns (inputs_embeds, audio_codes) for the current step.""" + bsz = int(input_ids.shape[0]) + q = int(self.talker_config.num_code_groups) + dev = input_embeds.device + + input_ids = input_ids.reshape(bsz, 1).to(dtype=torch.long, device=dev) + last_id_hidden = input_embeds.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + past_hidden = last_talker_hidden.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + text_step = text_step.reshape(bsz, 1, -1).to(dtype=torch.bfloat16, device=dev) + + # Residual predictor runs fixed-length (Q-1) steps via the vLLM-native code_predictor. + max_steps = q - 1 + if max_steps <= 0: + audio_codes = input_ids.reshape(bsz, 1) + return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes + + # Predict residual codes (1..Q-1) with HF reference sampling params. 
+ audio_codes = self.code_predictor( + layer0_code=input_ids.reshape(bsz, 1), + layer0_embed=last_id_hidden, + last_talker_hidden=past_hidden, + do_sample=True, + temperature=0.9, + top_k=50, + top_p=1.0, + ) # [B, Q] + + # Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes. + layer0 = audio_codes[:, :1] + invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size)) + audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) + + # Sum embeddings of all code groups, then add the current text step. + residual_ids_t = audio_codes[:, 1:] + embeds: list[torch.Tensor] = [last_id_hidden] + for i in range(max_steps): + embeds.append(self.code_predictor.get_input_embeddings()[i](residual_ids_t[:, i : i + 1])) + summed = torch.cat(embeds, dim=1).sum(1, keepdim=True) # [B,1,H] + inputs_embeds_out = (summed + text_step).reshape(bsz, -1) + return inputs_embeds_out, audio_codes.to(dtype=torch.long) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py index e6e50211988..785ddedab50 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py @@ -80,16 +80,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3 """ inst = cls() - AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) - AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) + load_feature_extractor = bool(kwargs.pop("load_feature_extractor", True)) AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config) AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model) - inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) + AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config) + 
AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model) + inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) inst.config = inst.model.config + inst.feature_extractor = ( + AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path) if load_feature_extractor else None + ) + inst.device = getattr(inst.model, "device", None) if inst.device is None: # fallback: infer from first parameter device @@ -212,12 +217,7 @@ def encode( audios: AudioInput, sr: int | None = None, return_dict: bool = True, - ) -> ( - Qwen3TTSTokenizerV1EncoderOutput - | Qwen3TTSTokenizerV2EncoderOutput - | tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None] - | tuple[list[torch.Tensor]] - ): + ) -> Qwen3TTSTokenizerV1EncoderOutput | Qwen3TTSTokenizerV2EncoderOutput | tuple: """ Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz). diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py index 747ca8f0cdd..2a66632e796 100644 --- a/vllm_omni/model_executor/models/registry.py +++ b/vllm_omni/model_executor/models/registry.py @@ -50,8 +50,18 @@ ), "Qwen3TTSForConditionalGeneration": ( "qwen3_tts", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", + ), + "Qwen3TTSTalkerForConditionalGeneration": ( + "qwen3_tts", + "qwen3_tts_talker", + "Qwen3TTSTalkerForConditionalGeneration", + ), + "Qwen3TTSCode2Wav": ( "qwen3_tts", - "Qwen3TTSModelForGeneration", + "qwen3_tts_code2wav", + "Qwen3TTSCode2Wav", ), } diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index d408dbab91e..1f29f0796ed 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -1,22 +1,101 @@ +async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm runtime: 
devices: "0" max_batch_size: 1 engine_args: model_stage: qwen3_tts - model_arch: Qwen3TTSForConditionalGeneration + model_arch: Qwen3TTSTalkerForConditionalGeneration + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: false + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + # Use named connector to apply runtime.connectors.extra. + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + # Force stage-specific registered architecture. + hf_overrides: + architectures: [Qwen3TTSCode2Wav] worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true async_scheduling: false enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 + engine_output_type: audio + gpu_memory_utilization: 0.2 distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - + # Must be divisible by num_code_groups and cover (left_context + chunk). + max_num_batched_tokens: 8192 + # async_chunk appends windows per step; max_model_len must cover accumulated stream. 
+ max_model_len: 32768 + engine_input_source: [0] final_output: true final_output_type: audio + # Distributed connector configuration + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + # Frame-aligned codec streaming transport. + codec_streaming: true + # Connector polling / timeout (unit: loop count, sleep interval in seconds). + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + # Align with Omni: small chunks with sufficient context overlap. + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py new file mode 100644 index 00000000000..8599ea2e3e8 --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -0,0 +1,76 @@ +"""Stage input processor for Qwen3-TTS: Talker -> Code2Wav.""" + +from typing import Any + +import torch + + +def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None: + audio_codes = pooling_output.get("audio_codes") + if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0: + return None + if audio_codes.ndim == 2: + frame = audio_codes[-1] + if frame.numel() == 0 or not bool(frame.any().item()): + return None + return frame.to(torch.long).reshape(-1) + if audio_codes.ndim == 1: + return audio_codes.to(torch.long).reshape(-1) + raise ValueError(f"Invalid audio_codes shape for Qwen3-TTS async_chunk: {tuple(audio_codes.shape)}") + + +def talker2code2wav_async_chunk( + transfer_manager: Any, + 
pooling_output: dict[str, Any], + request: Any, +) -> dict[str, Any] | None: + if not isinstance(pooling_output, dict): + return None + + request_id = request.external_req_id + + connector = getattr(transfer_manager, "connector", None) + raw_cfg = getattr(connector, "config", {}) or {} + cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} + chunk_size = int(cfg.get("codec_chunk_frames", 25)) + left_context_size = int(cfg.get("codec_left_context_frames", 25)) + if chunk_size <= 0 or left_context_size < 0: + raise ValueError( + f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, " + f"codec_left_context_frames={left_context_size}" + ) + + finished = bool(request.is_finished()) + + frame = _extract_last_frame(pooling_output) + if frame is not None: + codec_codes = frame.cpu().tolist() + transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) + + length = len(transfer_manager.code_prompt_token_ids[request_id]) + chunk_length = length % chunk_size + + if chunk_length != 0 and not finished: + return None + + context_length = chunk_length if chunk_length != 0 else chunk_size + + if length <= 0: + return { + "code_predictor_codes": [], + "finished": torch.tensor(bool(finished), dtype=torch.bool), + } + + end_index = min(length, left_context_size + context_length) + ctx_frames = max(0, int(end_index - context_length)) + window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:] + + # Pack context + chunk into codebook-major flat codes for adapter. + code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist() + + # Build final prompt_token_ids with ctx_frames header for Qwen3-TTS Code2Wav. + # The model expects input_ids layout: [ctx_frames, *flat_codes]. 
+ return { + "code_predictor_codes": [int(ctx_frames)] + code_predictor_codes, + "finished": torch.tensor(bool(finished), dtype=torch.bool), + } diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml index d408dbab91e..fbfbf10a49e 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -1,22 +1,92 @@ +async_chunk: true stage_args: - stage_id: 0 - stage_type: llm # Use llm stage type to launch OmniLLM + stage_type: llm runtime: devices: "0" max_batch_size: 1 engine_args: model_stage: qwen3_tts - model_arch: Qwen3TTSForConditionalGeneration + model_arch: Qwen3TTSTalkerForConditionalGeneration + hf_overrides: + architectures: [Qwen3TTSTalkerForConditionalGeneration] + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.3 + distributed_executor_backend: "mp" + max_num_batched_tokens: 512 + max_model_len: 4096 + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.05 + stop_token_ids: [2150] + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: code2wav + model_arch: Qwen3TTSCode2Wav + hf_overrides: + architectures: [Qwen3TTSCode2Wav] worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true async_scheduling: false enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 
0.1 + engine_output_type: audio + gpu_memory_utilization: 0.2 distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - + max_num_batched_tokens: 8192 + max_model_len: 32768 + engine_input_source: [0] final_output: true final_output_type: audio + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + detokenize: true + repetition_penalty: 1.0 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + codec_chunk_frames: 25 + codec_left_context_frames: 25 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py index e8559bb463c..d263fb0d386 100644 --- a/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_generation_model_runner.py @@ -238,9 +238,12 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + # [Omni] Copy req_id mappings to avoid async scheduling mutation. 
+ req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py index 0747db3ea57..aa75d201ccb 100644 --- a/vllm_omni/worker/gpu_generation_model_runner.py +++ b/vllm_omni/worker/gpu_generation_model_runner.py @@ -380,9 +380,12 @@ def sample_tokens( pooler_output.append(mm_payload) else: raise RuntimeError("Unsupported diffusion output type") + # [Omni] Copy req_id mappings to avoid async scheduling mutation. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() output = OmniModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 09a792bc802..a01ad113aeb 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -70,16 +70,24 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): @instrument(span_name="Loading (GPU)") def load_model(self, *args, **kwargs) -> None: super().load_model(*args, **kwargs) + # TODO move this model specific logic to a separate class - if hasattr(self.model, "talker_mtp") and self.model.talker is not None: - self.talker_mtp = self.model.talker_mtp + # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS. 
+ talker_mtp = getattr(self.model, "talker_mtp", None) + if talker_mtp is not None: + self.talker_mtp = talker_mtp # type: ignore[assignment] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs(): - self.talker_mtp = CUDAGraphWrapper( - self.model.talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL - ) - hidden_size = self.model_config.hf_config.talker_config.text_config.hidden_size + # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that + # have a separate .talker sub-module. TTS models' code predictor + # has internal AR loops / torch.multinomial — not graph-safe. + has_separate_talker = getattr(self.model, "talker", None) is not None + if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. + hidden_size = int( + getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size") + ) max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) self.talker_mtp_inputs_embeds = self._make_buffer( @@ -644,6 +652,11 @@ def _dummy_run( input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] model_kwargs = self._init_model_kwargs() + elif getattr(getattr(self, "model", None), "has_preprocess", False): + # Capture CUDA graph with inputs_embeds path so replay reads + # from the same buffer that _preprocess writes into. 
+ input_ids = self.input_ids.gpu[:num_tokens_padded] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -986,6 +999,11 @@ def _preprocess( inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] model_kwargs = self._init_model_kwargs() input_ids = self.input_ids.gpu[:num_input_tokens] + elif getattr(self.model, "has_preprocess", False): + # Use pre-allocated buffer for CUDA graph compatibility. + input_ids = self.input_ids.gpu[:num_input_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs() else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the @@ -1047,9 +1065,16 @@ def _preprocess( span_len = int(e) - int(s) # call the custom process function + embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None req_input_ids, req_embeds, update_dict = self.model.preprocess( - input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos + input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos ) + if inputs_embeds is None: + inputs_embeds = torch.empty( + (input_ids.shape[0], req_embeds.shape[-1]), + device=req_embeds.device, + dtype=req_embeds.dtype, + ) if hasattr(self.model, "talker_mtp") and span_len == 1: last_talker_hidden, text_step = update_dict.pop("mtp_inputs") @@ -1093,6 +1118,9 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te max_num_scheduled_tokens=1, use_cascade_attn=False, ) + # Force eager for unwrapped code predictors (AR loops / multinomial). 
+ if not isinstance(self.talker_mtp, CUDAGraphWrapper): + _cudagraph_mode = CUDAGraphMode.NONE num_tokens_padded = batch_desc.num_tokens req_input_ids = self.talker_mtp_input_ids.gpu[:num_tokens_padded] req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded] @@ -1104,11 +1132,12 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) # update the inputs_embeds and code_predictor_codes code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes") for idx, req_id in enumerate(decode_req_ids): req_index = self.input_batch.req_ids.index(req_id) start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]} self._merge_additional_information_update(req_id, update_dict) def _model_forward( @@ -1142,7 +1171,9 @@ def _model_forward( self._omni_last_model_output = model_output return model_output - def _merge_additional_information_update(self, req_id: str, upd: dict) -> None: + def _merge_additional_information_update(self, req_id: str, upd: dict | None) -> None: + if not isinstance(upd, dict): + return req_state = self.requests.get(req_id) if req_state is None: return