diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index 15464032b37..1b6fe2e4bcb 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -7,7 +7,10 @@ import pytest import torch -from vllm_omni.model_executor.stage_input_processors.qwen3_tts import talker2code2wav_async_chunk +from vllm_omni.model_executor.stage_input_processors.qwen3_tts import ( + talker2code2wav, + talker2code2wav_async_chunk, +) _FRAME = [1, 2, 3, 4] # 4-codebook frame _Q = len(_FRAME) # num quantizers @@ -29,6 +32,7 @@ def _tm(*, chunk_frames=25, left_context=25, initial_chunk=0): return SimpleNamespace( code_prompt_token_ids=defaultdict(list), put_req_chunk=defaultdict(int), + request_payload={}, connector=SimpleNamespace( config={ "extra": { @@ -152,3 +156,113 @@ def test_per_request_override_wins_over_stage_config(): tm = _tm(initial_chunk=5) payload = _call(tm, "r-override2", n_frames=10, put_req=0, req_ic=15) assert payload is None + + +def test_first_streaming_chunk_prepends_ref_code_context(): + tm = _tm(initial_chunk=10) + rid = "r-ref" + tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(10)] + ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long) + + payload = talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"audio_codes": torch.zeros((0,)), "ref_code": ref_code}, + request=_req(rid, finished=False, initial_codec_chunk_frames=10), + is_finished=False, + ) + + assert payload is not None + assert payload["left_context_size"] == 2 + assert len(payload["code_predictor_codes"]) == _Q * 12 + + +def test_ref_code_context_only_applies_to_first_streaming_chunk(): + tm = _tm(initial_chunk=10) + rid = "r-ref2" + tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(20)] + tm.put_req_chunk[rid] = 1 + ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long) + + payload = talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"audio_codes": torch.zeros((0,)), "ref_code": ref_code}, + request=_req(rid, finished=False, initial_codec_chunk_frames=10), + is_finished=False, + ) + + assert payload is not None + assert payload["left_context_size"] == 10 + assert len(payload["code_predictor_codes"]) == _Q * 20 + + +def test_ref_code_context_can_be_buffered_before_first_emit(): + tm = _tm(initial_chunk=10) + rid = "r-ref-buffered" + ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long) + + first_payload = talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]]), "ref_code": ref_code}, + request=_req(rid, finished=False, initial_codec_chunk_frames=10), + is_finished=False, + ) + assert first_payload is None + assert rid in tm.request_payload + + for _ in range(8): + talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]])}, + request=_req(rid, finished=False, initial_codec_chunk_frames=10), + is_finished=False, + ) + + payload = talker2code2wav_async_chunk( + transfer_manager=tm, + pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]])}, + request=_req(rid, finished=False, initial_codec_chunk_frames=10), + is_finished=False, + ) + + assert payload is not None + assert payload["left_context_size"] == 2 + assert len(payload["code_predictor_codes"]) == _Q * 12 + assert rid not in tm.request_payload + + +def test_non_async_processor_prepends_ref_code_and_sets_trim_context(): + ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long) + audio_codes = torch.tensor( + [ + [0, 0, 0, 0], + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype=torch.long, + ) + output = SimpleNamespace(multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code}) + stage = SimpleNamespace(engine_outputs=[SimpleNamespace(outputs=[output])]) + + prompts = talker2code2wav(stage_list=[stage], engine_input_source=[0]) + + assert len(prompts) == 1 + prompt = prompts[0] + assert prompt["additional_information"] == {"left_context_size": [2]} + assert prompt["prompt_token_ids"] == [ + 9, + 8, + 1, + 5, + 9, + 8, + 2, + 6, + 9, + 8, + 3, + 7, + 9, + 8, + 4, + 8, + ] diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index de9c559ddfb..7bcf75ace9d 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -11,6 +11,7 @@ import numpy as np from fastapi import Request, UploadFile from fastapi.responses import Response, StreamingResponse +from transformers.utils.hub import cached_file from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.logger import init_logger from vllm.multimodal.media import MediaConnector @@ -120,7 +121,9 @@ def _load_codec_frame_rate(self) -> float | None: try: model_path = self.engine_client.model_config.model st_config_path = os.path.join(model_path, "speech_tokenizer", "config.json") - if os.path.exists(st_config_path): + if not os.path.exists(st_config_path): + st_config_path = cached_file(model_path, "speech_tokenizer/config.json") + if st_config_path is not None and os.path.exists(st_config_path): with open(st_config_path) as f: st_config = json.load(f) output_sr = st_config.get("output_sample_rate") diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index a1f21a6a68a..72597249cdf 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -447,6 +447,7 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer") audio_codes_list: list[torch.Tensor] = [] ref_code_len_list: list[torch.Tensor] = [] + ref_code_tensor: torch.Tensor | None = None codec_streaming_list: list[torch.Tensor] = [] for info in info_dicts: if not isinstance(info, dict): @@ -459,6 +460,9 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A codec_streaming_list.append( torch.full((int(ac.shape[0]),), int(cs), dtype=torch.int8, device=ac.device) ) + ref_code = info.get("ref_code") + if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: + ref_code_tensor = ref_code ref_len = info.get("ref_code_len") if ref_len is None: continue @@ -487,6 +491,8 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A mm: dict[str, torch.Tensor] = {"audio_codes": audio_codes} if ref_code_len_list: mm["ref_code_len"] = torch.cat(ref_code_len_list, dim=0)[:span_len] + if ref_code_tensor is not None: + mm["ref_code"] = [ref_code_tensor] if codec_streaming_list: mm["codec_streaming"] = torch.cat(codec_streaming_list, dim=0)[:span_len] return OmniOutput(text_hidden_states=hidden, multimodal_outputs=mm) @@ -538,8 +544,8 @@ def preprocess( # Subsequent prefill rounds (multi-chunk): prompt_embeds_cpu is a Tensor stored by the first round. is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2 if is_first_prefill: - full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len = self._build_prompt_embeds( - task_type=task_type, info_dict=info_dict + full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len, ref_code = ( + self._build_prompt_embeds(task_type=task_type, info_dict=info_dict) ) # Store full prompt embeddings + trailing queue on CPU for later chunks/steps. prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous() @@ -550,6 +556,8 @@ def preprocess( "talker_prefill_offset": 0, "codec_streaming": codec_streaming, } + if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: + info_update["ref_code"] = ref_code.detach().to("cpu").contiguous() if ref_code_len is not None: info_update["ref_code_len"] = int(ref_code_len) # Always return a span_len slice; if the scheduled placeholder is longer, pad with tts_pad_embed. @@ -773,7 +781,6 @@ def _first(x: object, default: object) -> object: if ref_code_len is None and estimate_ref_code_len is not None: ref_code_len = estimate_ref_code_len(info.get("ref_audio")) - if ref_code_len is None: raise ValueError( "Base in-context voice cloning requires either `voice_clone_prompt.ref_code` " @@ -1183,7 +1190,7 @@ def _build_prompt_embeds( *, task_type: str, info_dict: dict[str, Any], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None, torch.Tensor | None]: text = (info_dict.get("text") or [""])[0] language = (info_dict.get("language") or ["Auto"])[0] non_streaming_mode_val = info_dict.get("non_streaming_mode") @@ -1268,6 +1275,7 @@ def _build_prompt_embeds( # Speaker embedding/token (task-dependent) speaker_embed = None ref_code_len: int | None = None + ref_code_prompt: torch.Tensor | None = None def _as_singleton(x: object) -> object: if isinstance(x, list): @@ -1333,6 +1341,8 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: wav_np, sr = self._normalize_ref_audio(ref_audio_list[0]) ref_code_t = self._encode_ref_audio_to_code(wav_np, sr).to(device=input_ids.device) ref_code_len = int(ref_code_t.shape[0]) + if isinstance(ref_code_t, torch.Tensor): + ref_code_prompt = ref_code_t # Speaker embedding: use prompt embed if provided; otherwise extract from audio. spk = None @@ -1521,6 +1531,7 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: trailing_text_hidden.squeeze(0), # [T, H] tts_pad_embed.squeeze(0), # [1, H] ref_code_len, + ref_code_prompt.contiguous() if isinstance(ref_code_prompt, torch.Tensor) else None, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index f91d246c695..eb975326272 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -27,13 +27,24 @@ def talker2code2wav( # Filter zero-padded frames (EOS/invalid steps), matching _extract_last_frame behavior valid_mask = audio_codes.any(dim=1) audio_codes = audio_codes[valid_mask] + ref_code = output.multimodal_output.get("ref_code") + if isinstance(ref_code, list): + ref_code = ref_code[0] if ref_code else None + if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: + ref_code = ref_code.to(torch.long).cpu().contiguous() + ref_code_len = int(ref_code.shape[0]) + audio_codes = torch.cat([ref_code.to(audio_codes.device), audio_codes], dim=0) + else: + ref_code_len = 0 # Code2Wav expects codebook-major flat: [Q*num_frames] codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist() + additional_information = {"left_context_size": [ref_code_len]} if ref_code_len > 0 else None code2wav_inputs.append( OmniTokensPrompt( prompt_token_ids=codec_codes, multi_modal_data=None, mm_processor_kwargs=None, + additional_information=additional_information, ) ) return code2wav_inputs @@ -61,12 +72,19 @@ def talker2code2wav_async_chunk( ) -> dict[str, Any] | None: request_id = request.external_req_id finished = bool(is_finished or request.is_finished()) + request_payload = getattr(transfer_manager, "request_payload", None) + if request_payload is None: + request_payload = {} + transfer_manager.request_payload = request_payload if isinstance(pooling_output, dict): frame = _extract_last_frame(pooling_output) if frame is not None: codec_codes = frame.cpu().tolist() transfer_manager.code_prompt_token_ids[request_id].append(codec_codes) + ref_code = pooling_output.get("ref_code") + if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0 and request_payload.get(request_id) is None: + request_payload[request_id] = ref_code.to(torch.long).cpu().contiguous() elif not finished: # Some steps may not produce pooling_output. Only flush on finish. return None @@ -139,6 +157,14 @@ def talker2code2wav_async_chunk( left_context_size = max(0, int(end_index - context_length)) window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:] + # Match offline Base ICL decoding by prepending ref_code only once to the first decoder window. + if transfer_manager.put_req_chunk[request_id] == 0: + ref_code = request_payload.pop(request_id, None) + if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0: + ref_frames = ref_code.tolist() + window_frames = ref_frames + window_frames + left_context_size += len(ref_frames) + # Pack context + chunk into codebook-major flat codes for adapter. code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()