Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/dfx/perf/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ runtime:
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
codec_chunk_frames: 25
codec_left_context_frames: 25
codec_left_context_frames: 72

edges:
- from: 0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch
import torch.nn as nn

from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import Qwen3TTSCode2Wav

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


class _FakeDecoder(nn.Module):
def __init__(self, total_upsample: int = 4):
super().__init__()
self.total_upsample = total_upsample

def chunked_decode(self, codes: torch.Tensor) -> torch.Tensor:
frames = codes.shape[-1]
wav_len = frames * self.total_upsample + 6
wav = torch.arange(wav_len, dtype=torch.float32)
return wav.view(1, 1, -1)


def _make_model() -> Qwen3TTSCode2Wav:
model = Qwen3TTSCode2Wav(
vllm_config=SimpleNamespace(
model_config=SimpleNamespace(model="unused"),
device_config=SimpleNamespace(device=torch.device("cpu")),
)
)
model._decoder = _FakeDecoder()
model._num_quantizers = 2
model._output_sample_rate = 24000
model._total_upsample = 4
model._ensure_speech_tokenizer_loaded = lambda: None
return model


def test_forward_trims_context_on_exact_frame_boundaries():
model = _make_model()

out = model.forward(
input_ids=torch.arange(12, dtype=torch.long),
runtime_additional_information=[{"left_context_size": 2}],
)

audio = out.multimodal_outputs["model_outputs"][0]
expected = torch.arange(8, 24, dtype=torch.float32)
torch.testing.assert_close(audio, expected)


def test_forward_trims_trailing_padding_without_context():
model = _make_model()

out = model.forward(
input_ids=torch.arange(12, dtype=torch.long),
runtime_additional_information=[{"left_context_size": 0}],
)

audio = out.multimodal_outputs["model_outputs"][0]
expected = torch.arange(24, dtype=torch.float32)
torch.testing.assert_close(audio, expected)
3 changes: 2 additions & 1 deletion vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ connectors:
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
codec_chunk_frames: 25
codec_left_context_frames: 25
# Match the decoder sliding attention window to avoid chunk-boundary noise.
codec_left_context_frames: 72

edges:
- from: 0
Expand Down
42 changes: 27 additions & 15 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._num_quantizers: int | None = None
self._output_sample_rate: int | None = None
self._total_upsample: int | None = None
self._decoder_sliding_window: int | None = None
self._logged_codec_stats = False

@staticmethod
Expand Down Expand Up @@ -106,6 +107,7 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
self._num_quantizers = num_q
self._output_sample_rate = out_sr
self._total_upsample = int(decoder.total_upsample)
self._decoder_sliding_window = int(getattr(dec_cfg, "sliding_window", 0) or 0)

# Precompute SnakeBeta exp caches (benefits both Triton and eager paths)
if hasattr(decoder, "precompute_snake_caches"):
Expand All @@ -128,6 +130,20 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
if isinstance(extra_cfg, dict):
chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0)
left_frames = int(extra_cfg.get("codec_left_context_frames") or 0)
if (
chunk_frames > 0
and left_frames > 0
and self._decoder_sliding_window
and left_frames < self._decoder_sliding_window
):
logger.warning(
"Qwen3-TTS streaming codec_left_context_frames=%d is smaller than "
"decoder sliding_window=%d; chunk-boundary distortion may occur. "
"Increase codec_left_context_frames to at least %d for streaming.",
left_frames,
self._decoder_sliding_window,
self._decoder_sliding_window,
)

decoder.enable_cudagraph(
device=device,
Expand Down Expand Up @@ -289,21 +305,17 @@ def forward(
for j, idx in enumerate(valid_indices):
ctx_frames, actual_frames = parsed[idx]
wav = wav_tensors[j]
# Drop the ref_code prefix from the decoded waveform, keeping only newly generated audio.
if ctx_frames <= 0:
expected_len = actual_frames * upsample
if wav.shape[0] > expected_len:
wav = wav[:expected_len]
else:
cut = int(ctx_frames / max(actual_frames, 1) * wav.shape[0])
if cut >= wav.shape[0]:
logger.warning(
"Context trim %d >= decoded length %d; returning empty audio.",
cut,
wav.shape[0],
)
continue
wav = wav[cut:]
# Slice on exact codec-frame boundaries instead of proportionally.
start = max(0, ctx_frames * upsample)
end = max(start, actual_frames * upsample)
if start >= wav.shape[0]:
logger.warning(
"Context trim start %d >= decoded length %d; returning empty audio.",
start,
wav.shape[0],
)
continue
wav = wav[start : min(end, wav.shape[0])]
if wav.shape[0] > 0:
audios[idx] = wav.to(dtype=torch.float32).reshape(-1)

Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ runtime:
connector_get_sleep_s: 0.01
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
# Align with Omni: small chunks with sufficient context overlap.
# Match the decoder sliding attention window to avoid chunk-boundary noise.
codec_chunk_frames: 25
codec_left_context_frames: 25
codec_left_context_frames: 72

edges:
- from: 0
Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ runtime:
connector_get_sleep_s: 0.01
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
# Align with Omni: small chunks with sufficient context overlap.
# Match the decoder sliding attention window to avoid chunk-boundary noise.
codec_chunk_frames: 25
codec_left_context_frames: 25
codec_left_context_frames: 72
Comment thread
Sy0307 marked this conversation as resolved.

edges:
- from: 0
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ runtime:
connector_get_max_wait: 300
# Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25
codec_left_context_frames: 25
codec_left_context_frames: 72

edges:
- from: 0
Expand Down
Loading