diff --git a/tests/model_executor/models/qwen3_tts/test_decode_preprocess_parity.py b/tests/model_executor/models/qwen3_tts/test_decode_preprocess_parity.py new file mode 100644 index 0000000000..e30600c2b0 --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/test_decode_preprocess_parity.py @@ -0,0 +1,113 @@ +"""Parity test for the scalar / batched decode-preprocess paths. + +The talker exposes a batched ``preprocess_decode_batch`` plus a scalar +fast-path that loops to the existing single-request ``preprocess()`` when +the decode batch is small or has no ``task_type=Base`` requests. This test +asserts the two paths produce identical outputs so the fast-path is a true +byte-equivalent shortcut, not an approximation. +""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( + _DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD, + Qwen3TTSTalkerForConditionalGeneration, +) + + +def _make_minimal_talker(*, threshold: int | None = None, compact_min: int = 256): + model = Qwen3TTSTalkerForConditionalGeneration.__new__(Qwen3TTSTalkerForConditionalGeneration) + model.talker_config = SimpleNamespace(codec_pad_id=7, num_code_groups=16) + model._scalar_decode_preprocess_threshold = ( + threshold if threshold is not None else _DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD + ) + model._trailing_text_compact_min_frames = compact_min + + def fake_embed_input_ids(input_ids): + return input_ids.to(torch.float32).reshape(-1, 1, 1).expand(-1, 1, 4) + + model.embed_input_ids = fake_embed_input_ids + return model + + +def _build_req_info(*, task_type: str, text_offset: int, seed: int): + """Build one request payload with a predictable trailing-text tensor.""" + trailing = torch.arange(seed, seed + 12, dtype=torch.float32).reshape(3, 4) + last_hidden = torch.full((4,), float(seed % 7), dtype=torch.float32) + tts_pad = torch.full((1, 4), float(-seed), dtype=torch.float32) + return { + "text": ["hello"], + "task_type": [task_type], + "hidden_states": {"trailing_text": trailing, "last": last_hidden}, + "embed": {"tts_pad": tts_pad}, + "meta": {"talker_text_offset": text_offset}, + } + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("task_type", ["Base", "CustomVoice"]) +def test_scalar_and_batched_paths_agree(batch_size: int, task_type: str) -> None: + """Same inputs → identical (out_ids, out_embeds, past_hidden, text_step, updates).""" + req_infos = [_build_req_info(task_type=task_type, text_offset=i % 3, seed=10 + i) for i in range(batch_size)] + input_ids = torch.arange(100, 100 + batch_size, dtype=torch.long) + + scalar_model = _make_minimal_talker(threshold=batch_size + 1) + batched_model = _make_minimal_talker(threshold=0) + + scalar_out = scalar_model.preprocess_decode_batch( + input_ids=input_ids, + req_infos=[dict(info) for info in req_infos], + ) + batched_out = batched_model.preprocess_decode_batch( + input_ids=input_ids, + req_infos=[dict(info) for info in req_infos], + ) + + s_ids, s_embeds, s_past, s_step, s_updates = scalar_out + b_ids, b_embeds, b_past, b_step, b_updates = batched_out + + assert s_ids.tolist() == b_ids.tolist() + assert torch.equal(s_embeds, b_embeds) + assert torch.equal(s_past, b_past) + assert torch.equal(s_step, b_step) + assert len(s_updates) == len(b_updates) + for s_u, b_u in zip(s_updates, b_updates): + assert s_u["meta"]["talker_text_offset"] == b_u["meta"]["talker_text_offset"] + assert s_u["meta"]["codec_streaming"] == b_u["meta"]["codec_streaming"] + s_has_hs = "hidden_states" in s_u + b_has_hs = "hidden_states" in b_u + assert s_has_hs == b_has_hs + if s_has_hs: + assert torch.equal( + s_u["hidden_states"]["trailing_text"], + b_u["hidden_states"]["trailing_text"], + ) + + +def test_routing_uses_scalar_for_small_batch() -> None: + model = _make_minimal_talker(threshold=4) + req_infos = [_build_req_info(task_type="Base", text_offset=0, seed=1) for _ in range(4)] + assert model._should_use_scalar_decode_preprocess(req_infos) is True + + +def test_routing_uses_batched_for_large_base_batch() -> None: + model = _make_minimal_talker(threshold=4) + req_infos = [_build_req_info(task_type="Base", text_offset=0, seed=1) for _ in range(8)] + assert model._should_use_scalar_decode_preprocess(req_infos) is False + + +def test_routing_uses_scalar_when_no_base_request() -> None: + model = _make_minimal_talker(threshold=4) + req_infos = [_build_req_info(task_type="CustomVoice", text_offset=0, seed=i) for i in range(8)] + assert model._should_use_scalar_decode_preprocess(req_infos) is True + + +def test_routing_threshold_zero_means_size_check_disabled() -> None: + model = _make_minimal_talker(threshold=0) + base_batch = [_build_req_info(task_type="Base", text_offset=0, seed=i) for i in range(2)] + custom_batch = [_build_req_info(task_type="CustomVoice", text_offset=0, seed=i) for i in range(2)] + assert model._should_use_scalar_decode_preprocess(base_batch) is False + assert model._should_use_scalar_decode_preprocess(custom_batch) is True diff --git a/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml b/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml index 9c6ec19fe0..d20d96f17e 100644 --- a/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml +++ b/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml @@ -22,9 +22,11 @@ connectors: connector_get_max_wait: 300 codec_chunk_frames: 25 codec_left_context_frames: 72 - # Stage0 code-predictor prefix CUDA graphs for the c64 hot path. - # These keys are consumed by Qwen3-TTS talker and ignored by Code2Wav. - code_predictor_prefix_graphs: true + # Stage0 code-predictor prefix CUDA graphs. Off by default; the path + # currently regresses default_voice c=64 TTFP. Voice_clone deployments + # that want the captured prefix graphs can flip this back to true in a + # downstream yaml. Keys are consumed by the talker and ignored by Code2Wav. + code_predictor_prefix_graphs: false code_predictor_prefix_graph_buckets: [64] code_predictor_prefix_graph_seq_lens: [2, 3, 4, 5, 6, 7, 8] # Keep voice-clone reference context bounded so Stage1 chunk lengths are @@ -36,7 +38,9 @@ connectors: # no-ref first/steady chunks: 25 / 97 frames # Base ref-context first/steady chunks: 73 / 169 frames # decoder internal non-streaming chunks: 325 frames - decode_cudagraph_capture_sizes: [25, 73, 97, 169, 325] + # 49, 145 cover the default_voice shapes that v021 hit but PR #3662 left + # uncaptured. The full set is 7 shapes - well under Stage1's 12-shape cap. + decode_cudagraph_capture_sizes: [25, 49, 73, 97, 145, 169, 325] # Keep B>1 captures opt-in; c64 e2e validation did not show a stable win. decode_cudagraph_batch_sizes: [1] decode_compile_shapes: [] diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index a705a1cd48..ce915c1de0 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -37,7 +37,9 @@ logger = init_logger(__name__) -_TRAILING_TEXT_COMPACT_MIN_FRAMES = 64 +_TRAILING_TEXT_COMPACT_MIN_FRAMES = 64 # legacy default (overridable via connector_extra) +_DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD = 8 +_DEFAULT_TRAILING_TEXT_COMPACT_MIN_FRAMES = 256 _PRECOMPUTED_REF_CODE_KEY = "precomputed_ref" _NORMALIZED_REF_AUDIO_KEY = "_qwen3_tts_normalized_ref_audio" _PRECOMPUTED_TEXT_IDS_KEY = "_qwen3_tts_text_ids" @@ -448,6 +450,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {} ) + extra_cfg = self._stage_connector_extra_config(vllm_config) + self._scalar_decode_preprocess_threshold = self._parse_non_negative_int( + extra_cfg.get("scalar_decode_preprocess_threshold"), + _DEFAULT_SCALAR_DECODE_PREPROCESS_THRESHOLD, + ) + self._trailing_text_compact_min_frames = self._parse_non_negative_int( + extra_cfg.get("trailing_text_compact_min_frames"), + _DEFAULT_TRAILING_TEXT_COMPACT_MIN_FRAMES, + ) + + @staticmethod + def _stage_connector_extra_config(vllm_config: VllmConfig) -> dict[str, Any]: + model_cfg = getattr(vllm_config, "model_config", None) + connector_cfg = getattr(model_cfg, "stage_connector_config", None) + if isinstance(connector_cfg, dict): + extra_cfg = connector_cfg.get("extra", connector_cfg) + else: + extra_cfg = getattr(connector_cfg, "extra", None) + return extra_cfg if isinstance(extra_cfg, dict) else {} + + @staticmethod + def _parse_non_negative_int(value: object, default: int) -> int: + if value is None: + return default + try: + parsed = int(value) + except (TypeError, ValueError): + return default + return parsed if parsed >= 0 else default + # -------------------- vLLM required hooks -------------------- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: @@ -702,7 +734,7 @@ def preprocess( ) next_text_offset = text_offset + 1 should_compact_tail = next_text_offset >= tail_len or ( - next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len + next_text_offset >= self._trailing_text_compact_min_frames and next_text_offset * 2 >= tail_len ) if should_compact_tail: if next_text_offset >= tail_len: @@ -746,6 +778,80 @@ def preprocess_decode_batch( *, input_ids: torch.Tensor, req_infos: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]: + if self._should_use_scalar_decode_preprocess(req_infos): + return self._preprocess_decode_batch_scalar(input_ids=input_ids, req_infos=req_infos) + return self._preprocess_decode_batch_impl(input_ids=input_ids, req_infos=req_infos) + + def _should_use_scalar_decode_preprocess(self, req_infos: list[dict[str, Any]]) -> bool: + threshold = self._scalar_decode_preprocess_threshold + if threshold > 0 and len(req_infos) <= threshold: + return True + # No task_type=Base request -> batched path saves nothing. + for info in req_infos: + extra = info.get("additional_information") + if isinstance(extra, dict): + task_field = extra.get("task_type") + else: + task_field = info.get("task_type") + if isinstance(task_field, list): + task_type = task_field[0] if task_field else None + else: + task_type = task_field + if task_type == "Base": + return False + return True + + def _preprocess_decode_batch_scalar( + self, + *, + input_ids: torch.Tensor, + req_infos: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]: + """Loop ``preprocess`` per request and stack the outputs.""" + input_ids_flat = input_ids.reshape(-1) + if int(input_ids_flat.numel()) != len(req_infos): + raise ValueError( + f"preprocess_decode_batch expected {len(req_infos)} input ids, got {int(input_ids_flat.numel())}" + ) + + inputs_embeds_list: list[torch.Tensor] = [] + past_hidden_list: list[torch.Tensor] = [] + text_step_list: list[torch.Tensor] = [] + updates: list[dict[str, Any]] = [] + + for i, info_dict in enumerate(req_infos): + single_input_ids = input_ids_flat[i : i + 1] + _, single_inputs_embeds, single_update = self.preprocess( + single_input_ids, + None, + **info_dict, + ) + mtp_inputs = single_update.pop("mtp_inputs", None) + if mtp_inputs is None: + raise RuntimeError("scalar decode preprocess: missing mtp_inputs in update") + past_hidden, text_step = mtp_inputs + inputs_embeds_list.append(single_inputs_embeds.reshape(1, -1)) + past_hidden_list.append(past_hidden.reshape(1, -1)) + text_step_list.append(text_step.reshape(1, -1)) + updates.append(single_update) + + inputs_embeds_out = torch.cat(inputs_embeds_list, dim=0) + past_hidden_out = torch.cat(past_hidden_list, dim=0) + text_step_out = torch.cat(text_step_list, dim=0) + return ( + input_ids_flat, + inputs_embeds_out, + past_hidden_out, + text_step_out, + updates, + ) + + def _preprocess_decode_batch_impl( + self, + *, + input_ids: torch.Tensor, + req_infos: list[dict[str, Any]], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[dict[str, Any]]]: """Batch the decode-only preprocess path for Qwen3-TTS. @@ -806,7 +912,7 @@ def preprocess_decode_batch( text_step = tail[text_offset : text_offset + 1].to(device=device, dtype=dtype).reshape(1, -1) next_text_offset = text_offset + 1 should_compact_tail = next_text_offset >= tail_len or ( - next_text_offset >= _TRAILING_TEXT_COMPACT_MIN_FRAMES and next_text_offset * 2 >= tail_len + next_text_offset >= self._trailing_text_compact_min_frames and next_text_offset * 2 >= tail_len ) if should_compact_tail: if next_text_offset >= tail_len: