From 03b2f7f1f7de0cdab21eddf1804a5e043f3f3d37 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sat, 4 Apr 2026 02:01:33 -0700 Subject: [PATCH] Fix Qwen3-TTS streaming chunk-boundary artifacts Signed-off-by: Sy03 <1370724210@qq.com> --- tests/dfx/perf/stage_configs/qwen3_tts.yaml | 2 +- .../qwen3_tts/test_qwen3_tts_code2wav.py | 65 +++++++++++++++++++ .../models/qwen3_tts/pipeline.yaml | 3 +- .../models/qwen3_tts/qwen3_tts_code2wav.py | 42 +++++++----- .../stage_configs/qwen3_tts.yaml | 4 +- .../stage_configs/qwen3_tts_batch.yaml | 4 +- .../npu/stage_configs/qwen3_tts.yaml | 2 +- 7 files changed, 100 insertions(+), 22 deletions(-) create mode 100644 tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py diff --git a/tests/dfx/perf/stage_configs/qwen3_tts.yaml b/tests/dfx/perf/stage_configs/qwen3_tts.yaml index dd69b248d1a..97b30905603 100644 --- a/tests/dfx/perf/stage_configs/qwen3_tts.yaml +++ b/tests/dfx/perf/stage_configs/qwen3_tts.yaml @@ -88,7 +88,7 @@ runtime: connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 codec_chunk_frames: 25 - codec_left_context_frames: 25 + codec_left_context_frames: 72 edges: - from: 0 diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py new file mode 100644 index 00000000000..3f4e5b3ada4 --- /dev/null +++ b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import Qwen3TTSCode2Wav + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class _FakeDecoder(nn.Module): + def __init__(self, total_upsample: int = 4): + super().__init__() + self.total_upsample = total_upsample + + def chunked_decode(self, codes: torch.Tensor) -> torch.Tensor: + frames = codes.shape[-1] + wav_len = frames * self.total_upsample + 6 + wav = torch.arange(wav_len, dtype=torch.float32) + return wav.view(1, 1, -1) + + +def _make_model() -> Qwen3TTSCode2Wav: + model = Qwen3TTSCode2Wav( + vllm_config=SimpleNamespace( + model_config=SimpleNamespace(model="unused"), + device_config=SimpleNamespace(device=torch.device("cpu")), + ) + ) + model._decoder = _FakeDecoder() + model._num_quantizers = 2 + model._output_sample_rate = 24000 + model._total_upsample = 4 + model._ensure_speech_tokenizer_loaded = lambda: None + return model + + +def test_forward_trims_context_on_exact_frame_boundaries(): + model = _make_model() + + out = model.forward( + input_ids=torch.arange(12, dtype=torch.long), + runtime_additional_information=[{"left_context_size": 2}], + ) + + audio = out.multimodal_outputs["model_outputs"][0] + expected = torch.arange(8, 24, dtype=torch.float32) + torch.testing.assert_close(audio, expected) + + +def test_forward_trims_trailing_padding_without_context(): + model = _make_model() + + out = model.forward( + input_ids=torch.arange(12, dtype=torch.long), + runtime_additional_information=[{"left_context_size": 0}], + ) + + audio = out.multimodal_outputs["model_outputs"][0] + expected = torch.arange(24, dtype=torch.float32) + torch.testing.assert_close(audio, expected) diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml index 6e3c78ff934..fd8ea3a3f4e 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml +++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml @@ -84,7 +84,8 @@ connectors: connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 codec_chunk_frames: 25 - codec_left_context_frames: 25 + # Match the decoder sliding attention window to avoid chunk-boundary noise. + codec_left_context_frames: 72 edges: - from: 0 diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py index f6ac91a994f..79f0f4a8def 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py @@ -41,6 +41,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._num_quantizers: int | None = None self._output_sample_rate: int | None = None self._total_upsample: int | None = None + self._decoder_sliding_window: int | None = None self._logged_codec_stats = False @staticmethod @@ -106,6 +107,7 @@ def _ensure_speech_tokenizer_loaded(self) -> None: self._num_quantizers = num_q self._output_sample_rate = out_sr self._total_upsample = int(decoder.total_upsample) + self._decoder_sliding_window = int(getattr(dec_cfg, "sliding_window", 0) or 0) # Precompute SnakeBeta exp caches (benefits both Triton and eager paths) if hasattr(decoder, "precompute_snake_caches"): @@ -128,6 +130,20 @@ def _ensure_speech_tokenizer_loaded(self) -> None: if isinstance(extra_cfg, dict): chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0) left_frames = int(extra_cfg.get("codec_left_context_frames") or 0) + if ( + chunk_frames > 0 + and left_frames > 0 + and self._decoder_sliding_window + and left_frames < self._decoder_sliding_window + ): + logger.warning( + "Qwen3-TTS streaming codec_left_context_frames=%d is smaller than " + "decoder sliding_window=%d; chunk-boundary distortion may occur. " + "Increase codec_left_context_frames to at least %d for streaming.", + left_frames, + self._decoder_sliding_window, + self._decoder_sliding_window, + ) decoder.enable_cudagraph( device=device, @@ -289,21 +305,17 @@ def forward( for j, idx in enumerate(valid_indices): ctx_frames, actual_frames = parsed[idx] wav = wav_tensors[j] - # Drop the ref_code prefix from the decoded waveform, keeping only newly generated audio. - if ctx_frames <= 0: - expected_len = actual_frames * upsample - if wav.shape[0] > expected_len: - wav = wav[:expected_len] - else: - cut = int(ctx_frames / max(actual_frames, 1) * wav.shape[0]) - if cut >= wav.shape[0]: - logger.warning( - "Context trim %d >= decoded length %d; returning empty audio.", - cut, - wav.shape[0], - ) - continue - wav = wav[cut:] + # Slice on exact codec-frame boundaries instead of proportionally. + start = max(0, ctx_frames * upsample) + end = max(start, actual_frames * upsample) + if start >= wav.shape[0]: + logger.warning( + "Context trim start %d >= decoded length %d; returning empty audio.", + start, + wav.shape[0], + ) + continue + wav = wav[start : min(end, wav.shape[0])] if wav.shape[0] > 0: audios[idx] = wav.to(dtype=torch.float32).reshape(-1) diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml index 2c5f0a54744..a0d38eb4b9f 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml @@ -89,9 +89,9 @@ runtime: connector_get_sleep_s: 0.01 connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 - # Align with Omni: small chunks with sufficient context overlap. + # Match the decoder sliding attention window to avoid chunk-boundary noise. codec_chunk_frames: 25 - codec_left_context_frames: 25 + codec_left_context_frames: 72 edges: - from: 0 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml index a3509bb3305..75b2bab3a27 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml @@ -90,9 +90,9 @@ runtime: connector_get_sleep_s: 0.01 connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 - # Align with Omni: small chunks with sufficient context overlap. + # Match the decoder sliding attention window to avoid chunk-boundary noise. codec_chunk_frames: 25 - codec_left_context_frames: 25 + codec_left_context_frames: 72 edges: - from: 0 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml index a741f819a2b..cd82d91b715 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml @@ -88,7 +88,7 @@ runtime: connector_get_max_wait: 300 # Align with Omni: small chunks with sufficient context overlap. codec_chunk_frames: 25 - codec_left_context_frames: 25 + codec_left_context_frames: 72 edges: - from: 0