Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Parity test for the scalar / batched decode-preprocess paths.

The talker exposes a batched ``preprocess_decode_batch`` plus a scalar
fast-path that loops to the existing single-request ``preprocess()`` when
the decode batch is small or has no ``task_type=Base`` requests. This test
asserts the two paths produce identical outputs so the fast-path is a true
byte-equivalent shortcut, not an approximation.
"""

from types import SimpleNamespace

import pytest
import torch

from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
_DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD,
Qwen3TTSTalkerForConditionalGeneration,
)


def _make_minimal_talker(*, threshold: int | None = None, compact_min: int = 256):
model = Qwen3TTSTalkerForConditionalGeneration.__new__(Qwen3TTSTalkerForConditionalGeneration)
model.talker_config = SimpleNamespace(codec_pad_id=7, num_code_groups=16)
model._scalar_decode_preprocess_threshold = (
threshold if threshold is not None else _DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD
)
model._trailing_text_compact_min_frames = compact_min

def fake_embed_input_ids(input_ids):
return input_ids.to(torch.float32).reshape(-1, 1, 1).expand(-1, 1, 4)

model.embed_input_ids = fake_embed_input_ids
return model


def _build_req_info(*, task_type: str, text_offset: int, seed: int):
"""Build one request payload with a predictable trailing-text tensor."""
trailing = torch.arange(seed, seed + 12, dtype=torch.float32).reshape(3, 4)
last_hidden = torch.full((4,), float(seed % 7), dtype=torch.float32)
tts_pad = torch.full((1, 4), float(-seed), dtype=torch.float32)
return {
"text": ["hello"],
"task_type": [task_type],
"hidden_states": {"trailing_text": trailing, "last": last_hidden},
"embed": {"tts_pad": tts_pad},
"meta": {"talker_text_offset": text_offset},
}


@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("task_type", ["Base", "CustomVoice"])
def test_scalar_and_batched_paths_agree(batch_size: int, task_type: str) -> None:
"""Same inputs → identical (out_ids, out_embeds, past_hidden, text_step, updates)."""
req_infos = [_build_req_info(task_type=task_type, text_offset=i % 3, seed=10 + i) for i in range(batch_size)]
input_ids = torch.arange(100, 100 + batch_size, dtype=torch.long)

scalar_model = _make_minimal_talker(threshold=batch_size + 1)
batched_model = _make_minimal_talker(threshold=0)

scalar_out = scalar_model.preprocess_decode_batch(
input_ids=input_ids,
req_infos=[dict(info) for info in req_infos],
)
batched_out = batched_model.preprocess_decode_batch(
input_ids=input_ids,
req_infos=[dict(info) for info in req_infos],
)

s_ids, s_embeds, s_past, s_step, s_updates = scalar_out
b_ids, b_embeds, b_past, b_step, b_updates = batched_out

assert s_ids.tolist() == b_ids.tolist()
assert torch.equal(s_embeds, b_embeds)
assert torch.equal(s_past, b_past)
assert torch.equal(s_step, b_step)
assert len(s_updates) == len(b_updates)
for s_u, b_u in zip(s_updates, b_updates):
assert s_u["meta"]["talker_text_offset"] == b_u["meta"]["talker_text_offset"]
assert s_u["meta"]["codec_streaming"] == b_u["meta"]["codec_streaming"]
s_has_hs = "hidden_states" in s_u
b_has_hs = "hidden_states" in b_u
assert s_has_hs == b_has_hs
if s_has_hs:
assert torch.equal(
s_u["hidden_states"]["trailing_text"],
b_u["hidden_states"]["trailing_text"],
)


def test_routing_uses_scalar_for_small_batch() -> None:
model = _make_minimal_talker(threshold=4)
req_infos = [_build_req_info(task_type="Base", text_offset=0, seed=1) for _ in range(4)]
assert model._should_use_scalar_decode_preprocess(req_infos) is True


def test_routing_uses_batched_for_large_base_batch() -> None:
model = _make_minimal_talker(threshold=4)
req_infos = [_build_req_info(task_type="Base", text_offset=0, seed=1) for _ in range(8)]
assert model._should_use_scalar_decode_preprocess(req_infos) is False


def test_routing_uses_scalar_when_no_base_request() -> None:
model = _make_minimal_talker(threshold=4)
req_infos = [_build_req_info(task_type="CustomVoice", text_offset=0, seed=i) for i in range(8)]
assert model._should_use_scalar_decode_preprocess(req_infos) is True


def test_routing_threshold_zero_means_size_check_disabled() -> None:
model = _make_minimal_talker(threshold=0)
base_batch = [_build_req_info(task_type="Base", text_offset=0, seed=i) for i in range(2)]
custom_batch = [_build_req_info(task_type="CustomVoice", text_offset=0, seed=i) for i in range(2)]
assert model._should_use_scalar_decode_preprocess(base_batch) is False
assert model._should_use_scalar_decode_preprocess(custom_batch) is True
12 changes: 8 additions & 4 deletions vllm_omni/deploy/qwen3_tts_high_concurrency.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ connectors:
connector_get_max_wait: 300
codec_chunk_frames: 25
codec_left_context_frames: 72
# Stage0 code-predictor prefix CUDA graphs for the c64 hot path.
# These keys are consumed by Qwen3-TTS talker and ignored by Code2Wav.
code_predictor_prefix_graphs: true
# Stage0 code-predictor prefix CUDA graphs. Off by default; the path
# currently regresses default_voice c=64 TTFP. Voice_clone deployments
# that want the captured prefix graphs can flip this back to true in a
# downstream yaml. Keys are consumed by the talker and ignored by Code2Wav.
code_predictor_prefix_graphs: false
code_predictor_prefix_graph_buckets: [64]
code_predictor_prefix_graph_seq_lens: [2, 3, 4, 5, 6, 7, 8]
# Keep voice-clone reference context bounded so Stage1 chunk lengths are
Expand All @@ -36,7 +38,9 @@ connectors:
# no-ref first/steady chunks: 25 / 97 frames
# Base ref-context first/steady chunks: 73 / 169 frames
# decoder internal non-streaming chunks: 325 frames
decode_cudagraph_capture_sizes: [25, 73, 97, 169, 325]
# 49, 145 cover the default_voice shapes that v021 hit but PR #3662 left
# uncaptured. The full set is 7 shapes - well under Stage1's 12-shape cap.
decode_cudagraph_capture_sizes: [25, 49, 73, 97, 145, 169, 325]
# Keep B>1 captures opt-in; c64 e2e validation did not show a stable win.
decode_cudagraph_batch_sizes: [1]
decode_compile_shapes: []
Expand Down
112 changes: 109 additions & 3 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@

logger = init_logger(__name__)

_TRAILING_TEXT_COMPACT_MIN_FRAMES = 64
_TRAILING_TEXT_COMPACT_MIN_FRAMES = 64 # legacy default (overridable via connector_extra)
_DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD = 8
_DEFAULT_TRAILING_TEXT_COMPACT_MIN_FRAMES = 256
_PRECOMPUTED_REF_CODE_KEY = "precomputed_ref"
_NORMALIZED_REF_AUDIO_KEY = "_qwen3_tts_normalized_ref_audio"
_PRECOMPUTED_TEXT_IDS_KEY = "_qwen3_tts_text_ids"
Expand Down Expand Up @@ -448,6 +450,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {}
)

extra_cfg = self._stage_connector_extra_config(vllm_config)
self._scalar_decode_preprocess_threshold = self._parse_non_negative_int(
extra_cfg.get("scalar_decode_preprocess_threshold"),
_DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD,
)
self._trailing_text_compact_min_frames = self._parse_non_negative_int(
extra_cfg.get("trailing_text_compact_min_frames"),
_DEFAULT_TRAILING_TEXT_COMPACT_MIN_FRAMES,
)

@staticmethod
def _stage_connector_extra_config(vllm_config: VllmConfig) -> dict[str, Any]:
model_cfg = getattr(vllm_config, "model_config", None)
connector_cfg = getattr(model_cfg, "stage_connector_config", None)
if isinstance(connector_cfg, dict):
extra_cfg = connector_cfg.get("extra", connector_cfg)
else:
extra_cfg = getattr(connector_cfg, "extra", None)
return extra_cfg if isinstance(extra_cfg, dict) else {}

@staticmethod
def _parse_non_negative_int(value: object, default: int) -> int:
if value is None:
return default
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return parsed if parsed >= 0 else default

# -------------------- vLLM required hooks --------------------

def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
Expand Down Expand Up @@ -702,7 +734,7 @@ def preprocess(
)
next_text_offset = text_offset + 1
should_compact_tail = next_text_offset >= tail_len or (
next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len
next_text_offset >= self._trailing_text_compact_min_frames and next_text_offset * 2 >= tail_len
)
if should_compact_tail:
if next_text_offset >= tail_len:
Expand Down Expand Up @@ -746,6 +778,80 @@ def preprocess_decode_batch(
*,
input_ids: torch.Tensor,
req_infos: list[dict[str, Any]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]:
if self._should_use_scalar_decode_preprocess(req_infos):
return self._preprocess_decode_batch_scalar(input_ids=input_ids, req_infos=req_infos)
return self._preprocess_decode_batch_impl(input_ids=input_ids, req_infos=req_infos)

def _should_use_scalar_decode_preprocess(self, req_infos: list[dict[str, Any]]) -> bool:
threshold = self._scalar_decode_preprocess_threshold
if threshold > 0 and len(req_infos) <= threshold:
return True
# No task_type=Base request -> batched path saves nothing.
for info in req_infos:
extra = info.get("additional_information")
if isinstance(extra, dict):
task_field = extra.get("task_type")
else:
task_field = info.get("task_type")
if isinstance(task_field, list):
task_type = task_field[0] if task_field else None
else:
task_type = task_field
if task_type == "Base":
return False
return True

def _preprocess_decode_batch_scalar(
self,
*,
input_ids: torch.Tensor,
req_infos: list[dict[str, Any]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]:
"""Loop ``preprocess`` per request and stack the outputs."""
input_ids_flat = input_ids.reshape(-1)
if int(input_ids_flat.numel()) != len(req_infos):
raise ValueError(
f"preprocess_decode_batch expected {len(req_infos)} input ids, got {int(input_ids_flat.numel())}"
)

inputs_embeds_list: list[torch.Tensor] = []
past_hidden_list: list[torch.Tensor] = []
text_step_list: list[torch.Tensor] = []
updates: list[dict[str, Any]] = []

for i, info_dict in enumerate(req_infos):
single_input_ids = input_ids_flat[i : i + 1]
_, single_inputs_embeds, single_update = self.preprocess(
single_input_ids,
None,
**info_dict,
)
mtp_inputs = single_update.pop("mtp_inputs", None)
if mtp_inputs is None:
raise RuntimeError("scalar decode preprocess: missing mtp_inputs in update")
past_hidden, text_step = mtp_inputs
inputs_embeds_list.append(single_inputs_embeds.reshape(1, -1))
past_hidden_list.append(past_hidden.reshape(1, -1))
text_step_list.append(text_step.reshape(1, -1))
updates.append(single_update)

inputs_embeds_out = torch.cat(inputs_embeds_list, dim=0)
past_hidden_out = torch.cat(past_hidden_list, dim=0)
text_step_out = torch.cat(text_step_list, dim=0)
return (
input_ids_flat,
inputs_embeds_out,
past_hidden_out,
text_step_out,
updates,
)

def _preprocess_decode_batch_impl(
self,
*,
input_ids: torch.Tensor,
req_infos: list[dict[str, Any]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]:
"""Batch the decode-only preprocess path for Qwen3-TTS.

Expand Down Expand Up @@ -806,7 +912,7 @@ def preprocess_decode_batch(
text_step = tail[text_offset : text_offset + 1].to(device=device, dtype=dtype).reshape(1, -1)
next_text_offset = text_offset + 1
should_compact_tail = next_text_offset >= tail_len or (
next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len
next_text_offset >= self._trailing_text_compact_min_frames and next_text_offset * 2 >= tail_len
)
if should_compact_tail:
if next_text_offset >= tail_len:
Expand Down
Loading