Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import pytest
import torch

from vllm_omni.model_executor.stage_input_processors.qwen3_tts import talker2code2wav_async_chunk
from vllm_omni.model_executor.stage_input_processors.qwen3_tts import (
talker2code2wav,
talker2code2wav_async_chunk,
)

_FRAME = [1, 2, 3, 4] # 4-codebook frame
_Q = len(_FRAME) # num quantizers
Expand All @@ -29,6 +32,7 @@ def _tm(*, chunk_frames=25, left_context=25, initial_chunk=0):
return SimpleNamespace(
code_prompt_token_ids=defaultdict(list),
put_req_chunk=defaultdict(int),
request_payload={},
connector=SimpleNamespace(
config={
"extra": {
Expand Down Expand Up @@ -152,3 +156,113 @@ def test_per_request_override_wins_over_stage_config():
tm = _tm(initial_chunk=5)
payload = _call(tm, "r-override2", n_frames=10, put_req=0, req_ic=15)
assert payload is None


def test_first_streaming_chunk_prepends_ref_code_context():
    """Check that the first emitted window is extended by the reference codes.

    Ten frames are already queued for the request and no chunk has been sent
    yet; the pooling output supplies a two-frame ``ref_code`` tensor. The
    payload is expected to report two frames of left context and to contain
    twelve frames' worth of flat codes.
    """
    request_id = "r-ref"
    reference_frames = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)

    manager = _tm(initial_chunk=10)
    # Queue ten copies of the canonical frame before the call under test.
    manager.code_prompt_token_ids[request_id] = [list(_FRAME) for _ in range(10)]

    # Empty audio_codes: this step contributes the reference codes only.
    pooled = {"audio_codes": torch.zeros((0,)), "ref_code": reference_frames}
    result = talker2code2wav_async_chunk(
        transfer_manager=manager,
        pooling_output=pooled,
        request=_req(request_id, finished=False, initial_codec_chunk_frames=10),
        is_finished=False,
    )

    assert result is not None
    # 2 reference frames + 10 queued frames.
    assert result["left_context_size"] == 2
    assert len(result["code_predictor_codes"]) == _Q * 12


def test_ref_code_context_only_applies_to_first_streaming_chunk():
    """Check that reference codes are not prepended once a chunk was already sent.

    The request has twenty queued frames and its sent-chunk counter is set to
    one, so the payload is expected to hold exactly those twenty frames with
    ten frames of left context, regardless of the ``ref_code`` supplied.
    """
    request_id = "r-ref2"
    reference_frames = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)

    manager = _tm(initial_chunk=10)
    manager.code_prompt_token_ids[request_id] = [list(_FRAME) for _ in range(20)]
    # Mark the request as having emitted one chunk already.
    manager.put_req_chunk[request_id] = 1

    pooled = {"audio_codes": torch.zeros((0,)), "ref_code": reference_frames}
    result = talker2code2wav_async_chunk(
        transfer_manager=manager,
        pooling_output=pooled,
        request=_req(request_id, finished=False, initial_codec_chunk_frames=10),
        is_finished=False,
    )

    assert result is not None
    # No reference frames are counted: only the regular context remains.
    assert result["left_context_size"] == 10
    assert len(result["code_predictor_codes"]) == _Q * 20


def test_ref_code_context_can_be_buffered_before_first_emit():
    """Check that reference codes seen early are held until the first emit.

    The ``ref_code`` tensor arrives with the very first frame, when nothing is
    emitted yet. It must be stored under the request id in
    ``request_payload``, applied to the first payload that is produced (after
    the tenth frame), and removed from ``request_payload`` afterwards.
    """
    request_id = "r-ref-buffered"
    manager = _tm(initial_chunk=10)
    reference_frames = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)

    def feed(pooled):
        # One processor step for this request; a fresh request object per call.
        return talker2code2wav_async_chunk(
            transfer_manager=manager,
            pooling_output=pooled,
            request=_req(request_id, finished=False, initial_codec_chunk_frames=10),
            is_finished=False,
        )

    # Frame 1 carries the reference codes; nothing is emitted, codes are buffered.
    opening = feed({"audio_codes": torch.tensor([[1, 2, 3, 4]]), "ref_code": reference_frames})
    assert opening is None
    assert request_id in manager.request_payload

    # Frames 2..9: results are intentionally not inspected.
    for _ in range(8):
        feed({"audio_codes": torch.tensor([[1, 2, 3, 4]])})

    # Frame 10 produces the first payload.
    emitted = feed({"audio_codes": torch.tensor([[1, 2, 3, 4]])})

    assert emitted is not None
    assert emitted["left_context_size"] == 2
    assert len(emitted["code_predictor_codes"]) == _Q * 12
    # The buffered reference codes are consumed by the first emit.
    assert request_id not in manager.request_payload


def test_non_async_processor_prepends_ref_code_and_sets_trim_context():
    """Check the non-streaming processor's handling of ``ref_code``.

    A single stage output holds three audio-code rows (one of them all zeros)
    plus a two-frame ``ref_code``. The resulting prompt is expected to carry
    ``left_context_size == [2]`` and a flat token list in which each
    codebook's run starts with the two reference values followed by the two
    non-zero frames.
    """
    reference_frames = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)
    generated = torch.tensor(
        [
            [0, 0, 0, 0],  # all-zero row, absent from the expected tokens below
            [1, 2, 3, 4],
            [5, 6, 7, 8],
        ],
        dtype=torch.long,
    )

    # Minimal stand-ins exposing only the attributes the test populates.
    mm_output = {"audio_codes": generated, "ref_code": reference_frames}
    request_output = SimpleNamespace(outputs=[SimpleNamespace(multimodal_output=mm_output)])
    stage = SimpleNamespace(engine_outputs=[request_output])

    prompts = talker2code2wav(stage_list=[stage], engine_input_source=[0])

    assert len(prompts) == 1
    only_prompt = prompts[0]
    assert only_prompt["additional_information"] == {"left_context_size": [2]}

    # One run per codebook: [ref frame 0, ref frame 1, frame 1, frame 2].
    expected_tokens = [9, 8, 1, 5] + [9, 8, 2, 6] + [9, 8, 3, 7] + [9, 8, 4, 8]
    assert only_prompt["prompt_token_ids"] == expected_tokens
5 changes: 4 additions & 1 deletion vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from fastapi import Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from transformers.utils.hub import cached_file
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.logger import init_logger
from vllm.multimodal.media import MediaConnector
Expand Down Expand Up @@ -120,7 +121,9 @@ def _load_codec_frame_rate(self) -> float | None:
try:
model_path = self.engine_client.model_config.model
st_config_path = os.path.join(model_path, "speech_tokenizer", "config.json")
if os.path.exists(st_config_path):
if not os.path.exists(st_config_path):
st_config_path = cached_file(model_path, "speech_tokenizer/config.json")
if st_config_path is not None and os.path.exists(st_config_path):
with open(st_config_path) as f:
st_config = json.load(f)
output_sr = st_config.get("output_sample_rate")
Expand Down
19 changes: 15 additions & 4 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
audio_codes_list: list[torch.Tensor] = []
ref_code_len_list: list[torch.Tensor] = []
ref_code_tensor: torch.Tensor | None = None
codec_streaming_list: list[torch.Tensor] = []
for info in info_dicts:
if not isinstance(info, dict):
Expand All @@ -459,6 +460,9 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
codec_streaming_list.append(
torch.full((int(ac.shape[0]),), int(cs), dtype=torch.int8, device=ac.device)
)
ref_code = info.get("ref_code")
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
ref_code_tensor = ref_code
ref_len = info.get("ref_code_len")
if ref_len is None:
continue
Expand Down Expand Up @@ -487,6 +491,8 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
mm: dict[str, torch.Tensor] = {"audio_codes": audio_codes}
if ref_code_len_list:
mm["ref_code_len"] = torch.cat(ref_code_len_list, dim=0)[:span_len]
if ref_code_tensor is not None:
mm["ref_code"] = [ref_code_tensor]
if codec_streaming_list:
mm["codec_streaming"] = torch.cat(codec_streaming_list, dim=0)[:span_len]
return OmniOutput(text_hidden_states=hidden, multimodal_outputs=mm)
Expand Down Expand Up @@ -538,8 +544,8 @@ def preprocess(
# Subsequent prefill rounds (multi-chunk): prompt_embeds_cpu is a Tensor stored by the first round.
is_first_prefill = not isinstance(prompt_embeds_cpu, torch.Tensor) or prompt_embeds_cpu.ndim != 2
if is_first_prefill:
full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len = self._build_prompt_embeds(
task_type=task_type, info_dict=info_dict
full_prompt_embeds, tailing_text_hidden, tts_pad_embed, ref_code_len, ref_code = (
self._build_prompt_embeds(task_type=task_type, info_dict=info_dict)
)
# Store full prompt embeddings + trailing queue on CPU for later chunks/steps.
prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous()
Expand All @@ -550,6 +556,8 @@ def preprocess(
"talker_prefill_offset": 0,
"codec_streaming": codec_streaming,
}
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
info_update["ref_code"] = ref_code.detach().to("cpu").contiguous()
if ref_code_len is not None:
info_update["ref_code_len"] = int(ref_code_len)
# Always return a span_len slice; if the scheduled placeholder is longer, pad with tts_pad_embed.
Expand Down Expand Up @@ -773,7 +781,6 @@ def _first(x: object, default: object) -> object:

if ref_code_len is None and estimate_ref_code_len is not None:
ref_code_len = estimate_ref_code_len(info.get("ref_audio"))

if ref_code_len is None:
raise ValueError(
"Base in-context voice cloning requires either `voice_clone_prompt.ref_code` "
Expand Down Expand Up @@ -1183,7 +1190,7 @@ def _build_prompt_embeds(
*,
task_type: str,
info_dict: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int | None, torch.Tensor | None]:
text = (info_dict.get("text") or [""])[0]
language = (info_dict.get("language") or ["Auto"])[0]
non_streaming_mode_val = info_dict.get("non_streaming_mode")
Expand Down Expand Up @@ -1268,6 +1275,7 @@ def _build_prompt_embeds(
# Speaker embedding/token (task-dependent)
speaker_embed = None
ref_code_len: int | None = None
ref_code_prompt: torch.Tensor | None = None

def _as_singleton(x: object) -> object:
if isinstance(x, list):
Expand Down Expand Up @@ -1333,6 +1341,8 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None:
wav_np, sr = self._normalize_ref_audio(ref_audio_list[0])
ref_code_t = self._encode_ref_audio_to_code(wav_np, sr).to(device=input_ids.device)
ref_code_len = int(ref_code_t.shape[0])
if isinstance(ref_code_t, torch.Tensor):
ref_code_prompt = ref_code_t

# Speaker embedding: use prompt embed if provided; otherwise extract from audio.
spk = None
Expand Down Expand Up @@ -1521,6 +1531,7 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None:
trailing_text_hidden.squeeze(0), # [T, H]
tts_pad_embed.squeeze(0), # [1, H]
ref_code_len,
ref_code_prompt.contiguous() if isinstance(ref_code_prompt, torch.Tensor) else None,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
Expand Down
26 changes: 26 additions & 0 deletions vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ def talker2code2wav(
# Filter zero-padded frames (EOS/invalid steps), matching _extract_last_frame behavior
valid_mask = audio_codes.any(dim=1)
audio_codes = audio_codes[valid_mask]
ref_code = output.multimodal_output.get("ref_code")
if isinstance(ref_code, list):
ref_code = ref_code[0] if ref_code else None
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
ref_code = ref_code.to(torch.long).cpu().contiguous()
ref_code_len = int(ref_code.shape[0])
audio_codes = torch.cat([ref_code.to(audio_codes.device), audio_codes], dim=0)
else:
ref_code_len = 0
# Code2Wav expects codebook-major flat: [Q*num_frames]
codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist()
additional_information = {"left_context_size": [ref_code_len]} if ref_code_len > 0 else None
code2wav_inputs.append(
OmniTokensPrompt(
prompt_token_ids=codec_codes,
multi_modal_data=None,
mm_processor_kwargs=None,
additional_information=additional_information,
)
)
return code2wav_inputs
Expand Down Expand Up @@ -61,12 +72,19 @@ def talker2code2wav_async_chunk(
) -> dict[str, Any] | None:
request_id = request.external_req_id
finished = bool(is_finished or request.is_finished())
request_payload = getattr(transfer_manager, "request_payload", None)
if request_payload is None:
request_payload = {}
transfer_manager.request_payload = request_payload

if isinstance(pooling_output, dict):
frame = _extract_last_frame(pooling_output)
if frame is not None:
codec_codes = frame.cpu().tolist()
transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
ref_code = pooling_output.get("ref_code")
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0 and request_payload.get(request_id) is None:
request_payload[request_id] = ref_code.to(torch.long).cpu().contiguous()
elif not finished:
# Some steps may not produce pooling_output. Only flush on finish.
return None
Expand Down Expand Up @@ -139,6 +157,14 @@ def talker2code2wav_async_chunk(
left_context_size = max(0, int(end_index - context_length))
window_frames = transfer_manager.code_prompt_token_ids[request_id][-end_index:]

# Match offline Base ICL decoding by prepending ref_code only once to the first decoder window.
if transfer_manager.put_req_chunk[request_id] == 0:
ref_code = request_payload.pop(request_id, None)
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
ref_frames = ref_code.tolist()
window_frames = ref_frames + window_frames
left_context_size += len(ref_frames)

# Pack context + chunk into codebook-major flat codes for adapter.
code_predictor_codes = torch.tensor(window_frames).transpose(0, 1).reshape(-1).tolist()

Expand Down