diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md index 11813ae35c3..60d6df71895 100644 --- a/docs/serving/speech_api.md +++ b/docs/serving/speech_api.md @@ -531,16 +531,18 @@ for result in response.json()["results"]: All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed. -For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`: +For best throughput, set both stages' `max_num_seqs` above 1 via `--stage-overrides`. On the current Qwen3-TTS CustomVoice benchmark, stage 1 performed best at `max_num_seqs: 10`: ```bash vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ --omni --port 8091 --trust-remote-code --enforce-eager \ - --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2}, - "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}' + --stage-overrides '{"0":{"max_num_seqs":10,"gpu_memory_utilization":0.2}, + "1":{"max_num_seqs":10,"gpu_memory_utilization":0.2}}' ``` -The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests. +The bundled `qwen3_tts.yaml` uses a multi-request default and lets stage 1 batch chunks across in-flight requests. For latency-sensitive deployments, avoid forcing stage 1 back to `max_num_seqs: 1`; benchmark before reducing it below `10`. + +The bundled config also sets `initial_codec_chunk_frames: 1`. This emits only the first audio chunk early for lower TTFA, then returns to the normal `codec_chunk_frames` window so Code2Wav does not repeatedly decode tiny overlapping chunks. ## Supported Models diff --git a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py index 77568a665e9..33b6187e54d 100644 --- a/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py +++ b/tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py @@ -291,8 +291,18 @@ def test_deterministic_across_calls(decoder, wrapper): [2, 4, 8, 16, 25, 32, 50, 64, 128, 256, 325], [512], ), + ( + { + "codec_chunk_frames": 25, + "codec_left_context_frames": 72, + "decode_chunk_size": 400, + "decode_left_context": 17, + }, + [2, 4, 8, 16, 25, 32, 64, 97, 128, 256, 417], + [325, 512], + ), ], - ids=["default", "streaming_c33", "streaming_c25"], + ids=["default", "streaming_c33", "streaming_c25", "custom_decode_chunk"], ) def test_compute_capture_sizes(kwargs, expected_in, not_expected): """compute_capture_sizes produces expected sizes capped by max useful size.""" diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py index 40e85da5b57..88ab16c0d1c 100644 --- a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py +++ b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py @@ -23,13 +23,33 @@ class _FakeDecoder(nn.Module): def __init__(self, total_upsample: int = _TOTAL_UPSAMPLE): super().__init__() self.total_upsample = total_upsample - - def chunked_decode(self, codes: torch.Tensor) -> torch.Tensor: + self.decode_calls: list[dict[str, int]] = [] + self.cudagraph_calls: list[dict[str, int | torch.device]] = [] + + def to(self, *args, **kwargs): + return self + + def chunked_decode( + self, + codes: torch.Tensor, + *, + chunk_size: int = 300, + left_context_size: int = 25, + ) -> torch.Tensor: + self.decode_calls.append( + { + "chunk_size": chunk_size, + "left_context_size": left_context_size, + } + ) frames = codes.shape[-1] wav_len = frames * self.total_upsample + 6 wav = torch.arange(wav_len, dtype=torch.float32) return wav.view(1, 1, -1) + def enable_cudagraph(self, **kwargs): + self.cudagraph_calls.append(kwargs) + def _fake_dec_config(): return SimpleNamespace( @@ -38,7 +58,12 @@ def _fake_dec_config(): ) -def _make_model() -> Qwen3TTSCode2Wav: +def _make_model( + *, + stage_connector_config=None, + async_chunk: bool = False, + device: torch.device | None = None, +) -> Qwen3TTSCode2Wav: dec_config = _fake_dec_config() tok_config = SimpleNamespace( decoder_config=dec_config, @@ -56,13 +81,51 @@ def _make_model() -> Qwen3TTSCode2Wav: ): model = Qwen3TTSCode2Wav( vllm_config=SimpleNamespace( - model_config=SimpleNamespace(model="unused"), - device_config=SimpleNamespace(device=torch.device("cpu")), + load_config=SimpleNamespace(), + model_config=SimpleNamespace( + model="unused", + revision=None, + stage_connector_config=stage_connector_config, + async_chunk=async_chunk, + ), + device_config=SimpleNamespace(device=device or torch.device("cpu")), ) ) return model +def _load_weights_noop(model: Qwen3TTSCode2Wav) -> set[str]: + class _FakeModelLoader: + class Source: + def __init__(self, **_: object): + pass + + def __init__(self, _load_config: object): + pass + + def _get_weights_iterator(self, _source: object): + return iter(()) + + class _FakeAutoWeightsLoader: + def __init__(self, *_: object, **__: object): + pass + + def load_weights(self, _weights: object) -> set[str]: + return {"decoder.fake_weight"} + + with ( + patch( + "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav.DefaultModelLoader", + _FakeModelLoader, + ), + patch( + "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav.AutoWeightsLoader", + _FakeAutoWeightsLoader, + ), + ): + return model.load_weights(iter(())) + + def test_forward_trims_context_on_exact_frame_boundaries(): model = _make_model() @@ -87,3 +150,89 @@ def test_forward_trims_trailing_padding_without_context(): audio = out.multimodal_outputs["model_outputs"][0] expected = torch.arange(24, dtype=torch.float32) torch.testing.assert_close(audio, expected) + + +def test_connector_codec_chunking_does_not_override_decode_chunking(): + model = _make_model( + async_chunk=True, + stage_connector_config={ + "extra": { + "codec_chunk_frames": 25, + "codec_left_context_frames": 72, + } + }, + ) + + loaded = _load_weights_noop(model) + + assert loaded == {"decoder.fake_weight"} + assert model._decode_chunk_frames == 300 + assert model._decode_left_context_frames == 25 + + model.forward( + input_ids=torch.arange(12, dtype=torch.long), + runtime_additional_information=[{"meta": {"left_context_size": 0}}], + ) + + assert model.decoder.decode_calls[-1] == { + "chunk_size": 300, + "left_context_size": 25, + } + + +def test_decode_chunking_can_be_overridden_separately(): + model = _make_model( + async_chunk=True, + stage_connector_config={ + "extra": { + "codec_chunk_frames": 25, + "codec_left_context_frames": 72, + "decode_chunk_frames": 400, + "decode_left_context_frames": 17, + } + }, + ) + + _load_weights_noop(model) + + assert model._decode_chunk_frames == 400 + assert model._decode_left_context_frames == 17 + + +def test_decode_chunking_override_is_passed_to_cudagraph(): + model = _make_model( + async_chunk=True, + device=torch.device("cuda"), + stage_connector_config={ + "extra": { + "codec_chunk_frames": 25, + "codec_left_context_frames": 72, + "decode_chunk_frames": 400, + "decode_left_context_frames": 17, + } + }, + ) + + _load_weights_noop(model) + + assert model.decoder.cudagraph_calls[-1] == { + "device": torch.device("cuda"), + "codec_chunk_frames": 25, + "codec_left_context_frames": 72, + "decode_chunk_size": 400, + "decode_left_context": 17, + } + + +def test_invalid_decode_chunking_is_rejected(): + model = _make_model( + async_chunk=True, + stage_connector_config={ + "extra": { + "decode_chunk_frames": 0, + } + }, + ) + + with pytest.raises(ValueError, match="decode_chunk_frames=0"): + _load_weights_noop(model) diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py index f5cb7797215..e4045f37fb3 100644 --- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py +++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py @@ -34,7 +34,7 @@ def _req(rid, *, finished, initial_codec_chunk_frames=None): ) -def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1): +def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1, initial_chunk_frames=0): return SimpleNamespace( code_prompt_token_ids=defaultdict(list), scheduler_max_num_seqs=max_num_seqs, @@ -45,6 +45,7 @@ def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1): "extra": { "codec_chunk_frames": chunk_frames, "codec_left_context_frames": left_context, + "initial_codec_chunk_frames": initial_chunk_frames, } } ), @@ -99,48 +100,37 @@ def test_flush_on_finish(): _CASES = [ # ── IC boundary rule ────────────────────────────────────────────── - # IC phase: length <= chunk_size (uses <=, consistent with fish_speech) - # IC emits fill the entire first chunk_size worth of frames, so the - # normal phase always starts at a clean chunk boundary. - # initial_coverage = (chunk_size // initial_chunk_size) * initial_chunk_size + # initial_codec_chunk_frames only controls the first emitted chunk. + # After that, the processor returns to codec_chunk_frames-sized windows + # to avoid flooding Code2Wav with repeated tiny overlapping decodes. # # Dynamic IC=16, cs=25, initial_coverage=16 - # IC does NOT evenly divide cs, so initial_coverage < cs. - # IC emits at 16; frames 17-25 remain in IC phase but 25%16!=0 -> hold. # Normal phase: adjusted = length - 16, emit when adjusted % 25 == 0. - ((25, 25, 0), 24, False, None), # IC: 24<=25, 24%16!=0 -> hold - ((25, 25, 0), 25, False, None), # IC: 25<=25, 25%16!=0 -> hold + ((25, 25, 0), 24, False, None), + ((25, 25, 0), 25, False, None), ((25, 25, 0), 41, False, (16, 41)), # normal: adjusted=25, 25%25==0 -> emit, lc=16 # - # Per-request IC=10, cs=25, initial_coverage=20 - # IC does NOT evenly divide cs; IC emits at 10, 20. - # Frames 21-25 are still IC phase but 21..25 % 10 != 0 -> hold. - ((25, 25, 10), 9, False, None), # IC: 9%10!=0 -> hold - ((25, 25, 10), 10, False, (0, 10)), # IC: 10%10==0 -> emit, lc=0 - ((25, 25, 10), 25, False, None), # IC: 25<=25, 25%10!=0 -> hold - ((25, 25, 10), 45, False, (20, 45)), # normal: adjusted=25, 25%25==0 -> emit, lc=20 + # Per-request IC=10, cs=25: first emit at 10, then 35, 60... + ((25, 25, 10), 9, False, None), + ((25, 25, 10), 10, False, (0, 10)), + ((25, 25, 10), 25, False, None), + ((25, 25, 10), 35, False, (10, 35)), + ((25, 25, 10), 45, False, None), ((25, 25, 10), 5, True, (0, 5)), # finished flushes IC tail - ((25, 25, 10), 33, True, (20, 33)), # finished flushes normal tail + ((25, 25, 10), 33, True, (10, 33)), # finished flushes normal tail # - # IC=8, cs=16: IC evenly divides chunk_size (edge case) - # initial_coverage = (16//8)*8 = 16 == chunk_size. - # IC fills the entire first chunk: emits at 8 and 16. - # Normal phase starts at frame 17; first normal emit at 16+16=32. - ((16, 25, 8), 8, False, (0, 8)), # IC: 8%8==0 -> emit, lc=0 - ((16, 25, 8), 16, False, (8, 16)), # IC: 16<=16, 16%8==0 -> emit, lc=8 - ((16, 25, 8), 24, False, None), # normal: adjusted=8, 8%16!=0 -> hold - ((16, 25, 8), 32, False, (16, 32)), # normal: adjusted=16, 16%16==0 -> first emit, lc=16 + # IC=8, cs=16: first emit at 8, then 24, 40... + ((16, 25, 8), 8, False, (0, 8)), + ((16, 25, 8), 16, False, None), + ((16, 25, 8), 24, False, (8, 24)), + ((16, 25, 8), 32, False, None), # - # IC=5, cs=25: IC evenly divides chunk_size - # initial_coverage = (25//5)*5 = 25 == chunk_size. - # IC fills the entire first chunk: emits at 5, 10, 15, 20, 25. - # Normal phase starts at frame 26; first normal emit at 25+25=50. - # Emit intervals: 5,5,5,5,5,25,25,... — smooth transition, no gap. - ((25, 25, 5), 5, False, (0, 5)), # IC: 5%5==0 -> emit, lc=0 - ((25, 25, 5), 12, False, None), # IC: 12%5!=0 -> hold - ((25, 25, 5), 25, False, (20, 25)), # IC: 25<=25, 25%5==0 -> emit, lc=20 - ((25, 25, 5), 30, False, None), # normal: adjusted=5, 5%25!=0 -> hold - ((25, 25, 5), 50, False, (25, 50)), # normal: adjusted=25, 25%25==0 -> first emit, lc=25 + # IC=5, cs=25: first emit at 5, then 30, 55... + ((25, 25, 5), 5, False, (0, 5)), + ((25, 25, 5), 12, False, None), + ((25, 25, 5), 25, False, None), + ((25, 25, 5), 30, False, (5, 30)), + ((25, 25, 5), 50, False, None), # # Per-request override: IC=15 at n_frames=10 -> 10%15!=0 -> hold ((25, 25, 15), 10, False, None), @@ -172,10 +162,10 @@ def test_dynamic_ic_adapts_to_load(): assert p1 is not None assert len(p1["codes"]["audio"]) == _Q * 2 - # High load: add 4 others -> active=5/8 -> IC=8 -> emit at 8 + # High load on a new request: active=6/8 -> IC=8 -> emit at 8 for i in range(4): tm.code_prompt_token_ids[f"other-{i}"] = [[0]] - p2 = _call(tm, "r", n_frames=8) + p2 = _call(tm, "new-high-load", n_frames=8) assert p2 is not None assert len(p2["codes"]["audio"]) == _Q * 8 @@ -201,13 +191,11 @@ def test_ic_load_change_mid_request(): for i in range(6): tm.code_prompt_token_ids[f"other-{i}"] = [[0]] * 10 - # IC for "r" is still cached as 2. - # initial_coverage = ((25-1)//2)*2 = 24, first normal emit at 24+25=49 + # IC for "r" is still cached as 2. The first normal emit is at 2+25=27. assert _call(tm, "r", n_frames=25) is None - assert _call(tm, "r", n_frames=27) is None - p3 = _call(tm, "r", n_frames=49) + p3 = _call(tm, "r", n_frames=27) assert p3 is not None - assert p3["meta"]["left_context_size"] == 24 + assert p3["meta"]["left_context_size"] == 2 # A *new* request under high load gets IC=16 (not IC=2). # Frame 2 would emit under IC=2 but must hold under IC=16. @@ -216,6 +204,25 @@ def test_ic_load_change_mid_request(): assert p4 is not None +def test_connector_initial_chunk_config_overrides_dynamic_ic(): + tm = _tm(initial_chunk_frames=4, max_num_seqs=8) + + # Under high load dynamic IC would be 16, but connector config pins the + # first chunk to 4 frames. + for i in range(7): + tm.code_prompt_token_ids[f"other-{i}"] = [[0]] + + p1 = _call(tm, "r", n_frames=4) + assert p1 is not None + assert len(p1["codes"]["audio"]) == _Q * 4 + + # Only the first chunk uses the small size; the next emit is 4+25. + assert _call(tm, "r", n_frames=25) is None + p2 = _call(tm, "r", n_frames=29) + assert p2 is not None + assert p2["meta"]["left_context_size"] == 4 + + @pytest.mark.parametrize( "active,max_bs,max_ic,expected", [ @@ -269,7 +276,7 @@ def test_ref_code_context_applies_to_all_streaming_chunks(): """ref_code is prepended as decoder context on every chunk, not just the first.""" tm = _tm() rid = "r-ref2" - tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(20)] + tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(35)] tm.put_req_chunk[rid] = 1 ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long) tm.request_payload[rid] = ref_code @@ -284,7 +291,7 @@ def test_ref_code_context_applies_to_all_streaming_chunks(): assert payload is not None # ref_code (2 frames) prepended as left context on second chunk too assert payload["meta"]["left_context_size"] == 10 + 2 - assert len(payload["codes"]["audio"]) == _Q * (20 + 2) + assert len(payload["codes"]["audio"]) == _Q * (35 + 2) def test_ref_code_context_can_be_buffered_before_first_emit(): diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py index 3187031bf36..3575f62b18c 100644 --- a/tests/worker/test_omni_connector_mixin.py +++ b/tests/worker/test_omni_connector_mixin.py @@ -607,6 +607,7 @@ def test_cleanup_removes_all_state(self): host._get_req_chunk[req_id] = 3 host._send_side_request_payload[ext_id] = {"some": "data"} host._code_prompt_token_ids[ext_id] = [[1, 2, 3]] + host._cached_ic[ext_id] = 16 host._chunk_stream_completed.add(req_id) host._stage_recv_req_ids.add(req_id) host._local_stage_payload_cache[req_id] = {"engine_inputs": {}} @@ -621,6 +622,7 @@ def test_cleanup_removes_all_state(self): self.assertNotIn(req_id, host._get_req_chunk) self.assertNotIn(ext_id, host._send_side_request_payload) self.assertNotIn(ext_id, host._code_prompt_token_ids) + self.assertNotIn(ext_id, host._cached_ic) self.assertNotIn(req_id, host._chunk_stream_completed) self.assertNotIn(req_id, host._stage_recv_req_ids) self.assertNotIn(req_id, host._local_stage_payload_cache) @@ -656,11 +658,34 @@ def test_cleanup_without_mapping(self): # Stage-0 uses req_id directly (no ext_id mapping) host._put_req_chunk[req_id] = 3 host._get_req_chunk[req_id] = 0 + host._cached_ic[req_id] = 4 host.cleanup_finished_request(req_id) self.assertNotIn(req_id, host._put_req_chunk) self.assertNotIn(req_id, host._get_req_chunk) + self.assertNotIn(req_id, host._cached_ic) + + host.shutdown_omni_connectors() + + def test_deferred_cleanup_removes_cached_ic(self): + host = self._make_host(stage_id=1) + req_id = "req-1" + ext_id = "ext-req-1" + + host._request_ids_mapping[req_id] = ext_id + host._pending_save_counts[ext_id] = 1 + host._cached_ic[ext_id] = 8 + + host.cleanup_finished_request(req_id) + + self.assertIn(ext_id, host._deferred_send_cleanup) + self.assertIn(ext_id, host._cached_ic) + + host._decrement_pending_save_count(ext_id) + + self.assertNotIn(ext_id, host._deferred_send_cleanup) + self.assertNotIn(ext_id, host._cached_ic) host.shutdown_omni_connectors() diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml index c2f9735026b..80c3168497d 100644 --- a/vllm_omni/deploy/qwen3_tts.yaml +++ b/vllm_omni/deploy/qwen3_tts.yaml @@ -26,6 +26,8 @@ connectors: # Must match the decoder sliding attention window. codec_chunk_frames: 25 codec_left_context_frames: 72 + # Emit only the first audio chunk early, then return to codec_chunk_frames. + initial_codec_chunk_frames: 1 stages: - stage_id: 0 diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py index fd4b6252098..9993784431b 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py +++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py @@ -89,6 +89,8 @@ def warmup( dtype: torch.dtype = torch.long, codec_chunk_frames: int = 0, codec_left_context_frames: int = 0, + decode_chunk_size: int = 300, + decode_left_context: int = 25, ): if device.type != "cuda" or not self.enabled or self._warmed_up: return @@ -100,6 +102,8 @@ def warmup( self.capture_sizes = self.compute_capture_sizes( codec_chunk_frames=codec_chunk_frames, codec_left_context_frames=codec_left_context_frames, + decode_chunk_size=decode_chunk_size, + decode_left_context=decode_left_context, ) logger.info("Starting CUDA Graph warmup for %d sizes: %s", len(self.capture_sizes), self.capture_sizes) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index aec4f8eecca..31cdecebe85 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -286,9 +286,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if hasattr(self.decoder, "precompute_snake_caches"): self.decoder.precompute_snake_caches() - # Read chunk config from stage connector and update decode params - chunk_frames = 0 - left_frames = 0 + # The connector codec chunk settings control inter-stage streaming + # windows. Keep decoder-internal chunking separate; using the small + # streaming window here causes repeated overlap decode in Code2Wav. + codec_chunk_frames = 0 + codec_left_context_frames = 0 model_cfg = getattr(self.vllm_config, "model_config", None) connector_cfg = getattr(model_cfg, "stage_connector_config", None) extra_cfg = ( @@ -296,20 +298,40 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if isinstance(connector_cfg, dict) else getattr(connector_cfg, "extra", None) ) + + def _get_int_config(name: str, default: int) -> int: + value = extra_cfg.get(name, default) + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid Qwen3-TTS Code2Wav config {name}={value!r}") from exc + if isinstance(extra_cfg, dict): - chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) - left_frames = int(extra_cfg.get("codec_left_context_frames") or 0) - if getattr(model_cfg, "async_chunk", False) and chunk_frames > 0: - self._decode_chunk_frames = chunk_frames - self._decode_left_context_frames = left_frames if left_frames > 0 else 0 + codec_chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) + codec_left_context_frames = int(extra_cfg.get("codec_left_context_frames") or 0) + decode_chunk_frames = _get_int_config("decode_chunk_frames", self._decode_chunk_frames) + decode_left_context_frames = _get_int_config( + "decode_left_context_frames", + self._decode_left_context_frames, + ) + if decode_chunk_frames <= 0 or decode_left_context_frames < 0: + raise ValueError( + "Invalid Qwen3-TTS Code2Wav decode chunk config: " + f"decode_chunk_frames={decode_chunk_frames}, " + f"decode_left_context_frames={decode_left_context_frames}" + ) + self._decode_chunk_frames = decode_chunk_frames + self._decode_left_context_frames = decode_left_context_frames if hasattr(self.decoder, "enable_cudagraph") and device.type == "cuda": try: if ( - chunk_frames > 0 - and left_frames > 0 + codec_chunk_frames > 0 + and codec_left_context_frames > 0 and self._decoder_sliding_window - and left_frames < self._decoder_sliding_window + and codec_left_context_frames < self._decoder_sliding_window ): logger.warning( "Qwen3-TTS streaming codec_left_context_frames=%d " @@ -317,15 +339,17 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: "chunk-boundary distortion may occur. " "Increase codec_left_context_frames to at least " "%d for streaming.", - left_frames, + codec_left_context_frames, self._decoder_sliding_window, self._decoder_sliding_window, ) self.decoder.enable_cudagraph( device=device, - codec_chunk_frames=chunk_frames, - codec_left_context_frames=left_frames, + codec_chunk_frames=codec_chunk_frames, + codec_left_context_frames=codec_left_context_frames, + decode_chunk_size=self._decode_chunk_frames, + decode_left_context=self._decode_left_context_frames, ) logger.info("Code2Wav decoder CUDA Graph enabled") except Exception: diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py index d96505b8db1..e71dbc091e8 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py @@ -851,6 +851,8 @@ def enable_cudagraph( device: torch.device | None = None, codec_chunk_frames: int = 0, codec_left_context_frames: int = 0, + decode_chunk_size: int = 300, + decode_left_context: int = 25, ): from ..cuda_graph_decoder_wrapper import CUDAGraphDecoderWrapper @@ -871,6 +873,8 @@ def enable_cudagraph( dtype=torch.long, codec_chunk_frames=codec_chunk_frames, codec_left_context_frames=codec_left_context_frames, + decode_chunk_size=decode_chunk_size, + decode_left_context=decode_left_context, ) self._cudagraph_enabled = True logger.info( diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml index 4ca8d11ad77..c370646a1ad 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml @@ -38,7 +38,7 @@ stage_args: devices: "0" engine_args: model_stage: code2wav - max_num_seqs: 1 + max_num_seqs: 10 model_arch: Qwen3TTSCode2Wav worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler @@ -87,6 +87,8 @@ runtime: # Match the decoder sliding attention window to avoid chunk-boundary noise. codec_chunk_frames: 25 codec_left_context_frames: 72 + # Emit only the first audio chunk early, then return to codec_chunk_frames. + initial_codec_chunk_frames: 1 edges: - from: 0 diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py index 3488eb5b00f..8badab79f45 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py @@ -161,10 +161,11 @@ def talker2code2wav_async_chunk( cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {} chunk_size = int(cfg.get("codec_chunk_frames", 25)) left_context_size_config = int(cfg.get("codec_left_context_frames", 25)) + configured_initial_chunk_size = int(cfg.get("initial_codec_chunk_frames") or 0) # Per-request override takes priority over dynamic IC. - per_request_override = False - initial_chunk_size = 0 + fixed_initial_chunk_size = configured_initial_chunk_size > 0 + initial_chunk_size = configured_initial_chunk_size additional_information = getattr(request, "additional_information", None) if ( @@ -175,10 +176,10 @@ def talker2code2wav_async_chunk( entry = additional_information.entries["initial_codec_chunk_frames"] if entry.list_data is not None and len(entry.list_data) == 1: initial_chunk_size = int(entry.list_data[0]) - per_request_override = True + fixed_initial_chunk_size = True # Dynamic IC: cache per request so boundaries stay stable for its lifetime. - if not per_request_override: + if not fixed_initial_chunk_size: _ic_cache = getattr(transfer_manager, "_cached_ic", None) if _ic_cache is None: _ic_cache = {} @@ -190,7 +191,7 @@ def talker2code2wav_async_chunk( _ic_cache[request_id] = compute_dynamic_initial_chunk_size(active, capacity, max_ic) initial_chunk_size = _ic_cache[request_id] - if chunk_size <= 0 or left_context_size_config < 0 or initial_chunk_size < 0: + if chunk_size <= 0 or left_context_size_config < 0 or configured_initial_chunk_size < 0 or initial_chunk_size < 0: raise ValueError( f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, " f"codec_left_context_frames={left_context_size_config}, " @@ -214,21 +215,16 @@ def talker2code2wav_async_chunk( } return None - in_initial_phase = initial_chunk_size > 0 and initial_chunk_size < chunk_size and length <= chunk_size + use_first_chunk = initial_chunk_size > 0 and initial_chunk_size < chunk_size - if in_initial_phase: - # IC phase: emit every initial_chunk_size frames with growing left context. - if not finished and length % initial_chunk_size != 0: + if use_first_chunk and length <= initial_chunk_size: + if not finished and length < initial_chunk_size: return None - context_length = ( - length % initial_chunk_size if (finished and length % initial_chunk_size != 0) else initial_chunk_size - ) + context_length = length if finished and length < initial_chunk_size else initial_chunk_size else: - # Normal phase: offset so the first normal emit picks up after IC phase. - # IC is stateless (may change with load); any mismatch is absorbed by left_context. - initial_coverage = ( - (chunk_size // initial_chunk_size) * initial_chunk_size if 0 < initial_chunk_size < chunk_size else 0 - ) + # The initial chunk is only for TTFA. After that, return to the normal + # codec chunk size so Code2Wav is not flooded by repeated tiny windows. + initial_coverage = initial_chunk_size if use_first_chunk else 0 adjusted = length - initial_coverage if not finished and adjusted % chunk_size != 0: return None diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py index 8e8f5741fa6..25f6c040cb5 100644 --- a/vllm_omni/worker/omni_connector_model_runner_mixin.py +++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py @@ -120,6 +120,7 @@ def init_omni_connectors( # ownership lives in ``_local_stage_payload_cache``. self._send_side_request_payload: dict[str, dict[str, Any]] = {} self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list) + self._cached_ic: dict[str, int] = {} self._request_ids_mapping: dict[str, str] = {} # -- async I/O state (shared by chunk + full_payload_mode) -- @@ -220,6 +221,7 @@ def cleanup_finished_request(self, req_id: str) -> None: self._put_req_chunk.pop(send_req_id, None) self._send_side_request_payload.pop(send_req_id, None) self._code_prompt_token_ids.pop(send_req_id, None) + self._cached_ic.pop(send_req_id, None) self._kv_pending_transfers.pop(req_id, None) self._kv_active_transfers.discard(req_id) self._kv_completed_transfers.discard(req_id) @@ -239,7 +241,9 @@ def drop_inactive_request_delivery_state(self, req_id: str) -> None: def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None: if ext_id is not None: self._send_side_request_payload.pop(ext_id, None) + self._cached_ic.pop(ext_id, None) self._send_side_request_payload.pop(req_id, None) + self._cached_ic.pop(req_id, None) def _cleanup_recv_delivery_state(self, req_id: str) -> None: """Clear recv-side delivery-cycle state.""" @@ -1786,6 +1790,7 @@ def _decrement_pending_save_count(self, request_id: str) -> None: self._put_req_chunk.pop(cleanup_req_id, None) self._send_side_request_payload.pop(cleanup_req_id, None) self._code_prompt_token_ids.pop(cleanup_req_id, None) + self._cached_ic.pop(cleanup_req_id, None) # ------------------------------------------------------------------ # # Payload accumulation (ported from OmniChunkTransferAdapter)