
Fix Base voice clone streaming quality and stop-token crash #1945

Merged

hsliuustc0106 merged 3 commits into vllm-project:main from linyueqian:fix/voice-clone-streaming-quality on Mar 17, 2026
Conversation

@linyueqian (Collaborator) commented Mar 17, 2026

Summary

  • Fix voice clone (Base ICL) audio distortion in async_chunk streaming mode
  • Fix CUDA assert crash when using async_chunk=false with voice clone
  • Align context trimming with official HF qwen_tts implementation

Changes

1. Streaming quality (stage_input_processors/qwen3_tts.py)

Prepend ref_code as decoder context on every chunk, not just the first. The HF reference decodes ref_code + all_codes in one pass; without ref_code on later chunks the vocoder loses speaker identity and produces distorted audio. Changed .pop() to .get() to retain ref_code across chunks.
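
A minimal sketch of the change, with illustrative names (`buffers`, `chunk_codes`, `build_decoder_input`); the actual processor in stage_input_processors/qwen3_tts.py carries more state:

```python
import torch

def build_decoder_input(buffers: dict, chunk_codes: torch.Tensor) -> torch.Tensor:
    # Before the fix: buffers.pop("ref_code", None) consumed the reference
    # codes on the first chunk, so later chunks decoded without speaker context.
    # Using .get() keeps ref_code available for every chunk.
    ref_code = buffers.get("ref_code")
    if ref_code is None:
        return chunk_codes
    # Prepend the reference codes as left context, mirroring HF's one-pass
    # decode of ref_code + all_codes.
    return torch.cat([ref_code, chunk_codes], dim=0)
```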

2. Stop-token crash (stage_input_processors/qwen3_tts.py)

Filter frames with codec values >= codebook_size (2048). The talker's stop_token_id (2150) exceeds the codebook and causes a CUDA assert in the decoder's embedding lookup when using the non-streaming path.
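
A sketch of the guard, assuming codes arrive as a [num_frames, num_quantizers] tensor (the function name is illustrative):

```python
import torch

_CODEBOOK_SIZE = 2048  # stop_token_id=2150 lies outside [0, 2048)

def filter_codec_frames(audio_codes: torch.Tensor) -> torch.Tensor:
    # Keep only frames whose every quantizer value is a valid codebook index;
    # any value >= 2048 would crash nn.Embedding(2048, ...) in Code2Wav.
    valid_mask = (audio_codes < _CODEBOOK_SIZE).all(dim=-1)
    return audio_codes[valid_mask]
```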

3. Proportional context trim (qwen3_tts_code2wav.py)

Use the ratio ctx_frames / actual_frames * wav_len (matching HF's qwen_tts implementation) instead of a fixed ctx_frames * upsample, to handle variations in decoder output length.
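
Roughly, one reading of that ratio as a sketch (names are illustrative, not the actual code in qwen3_tts_code2wav.py):

```python
import numpy as np

def trim_context(wav: np.ndarray, ctx_frames: int, actual_frames: int) -> np.ndarray:
    # Scale the trim by the decoder's actual output length instead of
    # assuming a fixed frames-to-samples upsample factor.
    wav_len = wav.shape[-1]
    trim = int(ctx_frames / actual_frames * wav_len)
    return wav[..., trim:]
```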

Test plan

  • Base (voice clone) with async_chunk=true produces clean audio matching HF quality
  • Whisper transcription confirms correct text content
  • CustomVoice and VoiceDesign still work (no regression)
  • Base with async_chunk=false no longer crashes

Fixes #1944

CC @Sy0307

Three fixes for Qwen3-TTS Base (voice clone) mode:

1. Streaming quality: prepend ref_code as decoder context on every chunk
   (not just the first). The HF reference decodes ref_code + all codes in
   one pass; without ref_code on later chunks the vocoder loses speaker
   identity, producing distorted audio.

2. Non-streaming crash: filter frames with codec values >= codebook_size
   (2048). The talker's stop_token_id (2150) exceeds the codebook and
   causes CUDA assert in the decoder's embedding lookup.

3. Proportional context trim matching HF: use ctx_frames/actual_frames
   ratio instead of fixed ctx_frames * upsample to handle decoder output
   length variations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: linyueqian <linyueqian@outlook.com>
@linyueqian (Collaborator, Author) commented Mar 17, 2026

Test Results

Text: "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye."
Model: Qwen/Qwen3-TTS-12Hz-1.7B-Base (voice clone)
Ref audio: clone_2.wav from official Qwen3-TTS repo

A/B Comparison

| Version | Duration | Whisper Transcription | Audio Quality |
|---|---|---|---|
| HF qwen_tts (reference) | 4.80s | Good. Okay, fine. I'm just gonna leave this sock monkey here. Goodbye. | Clean, no distortion |
| vLLM before fix | 5.04s | Good one. Okay, fine. I'm just gonna leave this sock monkey here. Goodbye. | Distorted after first few seconds |
| vLLM after fix | 5.04s | Good one. Okay, fine. I'm just gonna leave this sock monkey here. Goodbye. | Clean, matches HF quality |

Root Cause

Before fix: Only the first streaming chunk received ref_code decoder context (101 frames). Subsequent chunks lost speaker identity → distorted audio.

After fix: Every chunk receives ref_code as left context, maintaining speaker identity throughout the stream — matching HF's one-pass decode behavior.


@hsliuustc0106 (Collaborator) commented

can we add this to the acc test for qwen-tts model

@linyueqian (Collaborator, Author) commented

> can we add this to the acc test for qwen-tts model

voice clone is hard to test: a human can easily tell the timbre apart, but ASR can't tell any difference. @yenuo26 any good suggestions? i noticed you added a gender test, which seems great.

@linyueqian (Collaborator, Author) commented

wespeaker could be one option
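
Sketching what such an assertion might look like: embeddings for the reference and generated audio would come from a speaker model (e.g. wespeaker), with the extraction step left abstract here, and the 0.7 threshold is a placeholder that would need calibration:

```python
import numpy as np

def speaker_similarity(emb_ref: np.ndarray, emb_gen: np.ndarray) -> float:
    # Cosine similarity between two speaker embeddings.
    return float(np.dot(emb_ref, emb_gen) /
                 (np.linalg.norm(emb_ref) * np.linalg.norm(emb_gen) + 1e-10))

def assert_speaker_match(emb_ref: np.ndarray, emb_gen: np.ndarray,
                         min_cos: float = 0.7) -> None:
    # min_cos is a placeholder threshold; like the HNR threshold below,
    # it would need calibration on known-good vs. distorted samples.
    assert speaker_similarity(emb_ref, emb_gen) >= min_cos
```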

@hsliuustc0106 (Collaborator) left a comment

lgtm

@hsliuustc0106 added the `ready` label (label to trigger buildkite CI) Mar 17, 2026
@hsliuustc0106 (Collaborator) commented

> wespeaker could be one option

can we have a minimal assertion to avoid "Distorted after first few seconds"?

@linyueqian (Collaborator, Author) commented

Audio Quality Assertion

Added a minimal HNR (Harmonic-to-Noise Ratio) check that reliably detects the voice clone distortion:

```python
import numpy as np
import soundfile as sf

def assert_voice_clone_quality(wav_path, min_hnr_db=1.2):
    """Distorted voice clone: HNR < 1.0 dB. Clean: HNR > 1.2 dB."""
    data, sr = sf.read(wav_path)
    frame_len = int(0.03 * sr)  # 30 ms analysis frames
    hop = frame_len // 2
    hnr_values = []
    for start in range(0, len(data) - frame_len, hop):
        frame = data[start:start + frame_len]
        if np.max(np.abs(frame)) < 0.01:
            continue  # skip near-silent frames
        # Normalized autocorrelation over non-negative lags.
        ac = np.correlate(frame, frame, mode='full')[len(frame) - 1:]
        ac = ac / (ac[0] + 1e-10)
        # Search the 80-400 Hz pitch range for the strongest periodic peak.
        min_lag, max_lag = int(sr / 400), min(int(sr / 80), len(ac))
        if min_lag >= max_lag:
            continue
        peak = np.max(ac[min_lag:max_lag])
        if 0 < peak < 1:
            hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10)))
    mean_hnr = float(np.mean(hnr_values)) if hnr_values else 0.0
    assert mean_hnr >= min_hnr_db, (
        f"mean HNR {mean_hnr:.2f} dB below threshold {min_hnr_db} dB")
    return mean_hnr
```
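
Hypothetical usage in a pytest-style accuracy test (`run_voice_clone` is a placeholder for whatever generates the audio under test):

```python
def test_base_voice_clone_no_distortion(tmp_path):
    wav_path = tmp_path / "clone_out.wav"
    run_voice_clone(str(wav_path))  # placeholder: generate Base voice clone audio via vLLM
    # Raises AssertionError if mean HNR falls below the calibrated threshold.
    assert_voice_clone_quality(str(wav_path), min_hnr_db=1.2)
```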

Results

| Sample | HNR | Status |
|---|---|---|
| Before fix (distorted) | 0.88 dB | FAIL |
| After fix | 1.62 dB | PASS |
| HF reference | 2.17 dB | PASS |
| After fix sample 1 | 4.38 dB | PASS |
| After fix sample 2 | 4.88 dB | PASS |
| After fix sample 3 | 1.65 dB | PASS |
| After fix sample 4 | 1.45 dB | PASS |
| After fix sample 5 | 2.85 dB | PASS |

A threshold of 1.2 dB cleanly separates distorted from clean voice clone audio across all test cases.

- Update ref_code test: verify ref_code context is prepended on all chunks
  (not just the first) to maintain speaker identity throughout streaming
- Update buffered ref_code test: verify ref_code is retained (.get not .pop)
  for subsequent chunks
- Add stop-token filtering test: verify frames with values >= codebook_size
  (e.g. stop_token_id=2150) are excluded from Code2Wav input

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: linyueqian <linyueqian@outlook.com>
Adds test_base_voice_clone_no_distortion which computes the
Harmonic-to-Noise Ratio (HNR) of the generated audio and asserts
it exceeds 1.2 dB. Distorted voice clone (from lost ref_code decoder
context) drops below 1.0 dB; clean output exceeds 1.2 dB.

Validated on before/after fix outputs and HF reference:
  Before fix: 0.88 dB (FAIL)
  After fix:  1.62 dB (PASS)
  HF ref:     2.17 dB (PASS)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: linyueqian <linyueqian@outlook.com>
@univa-HARRY commented

I think it would resolve my issue #1707

@yenuo26 (Collaborator) commented Mar 17, 2026

> can we add this to the acc test for qwen-tts model
>
> voice clone is hard to test: a human can easily tell the timbre apart, but ASR can't tell any difference. @yenuo26 any good suggestions? i noticed you added a gender test, which seems great.

I have added audio content generation and voice gender validation to the test cases. Currently I'm using a simple F0-based calculation, but that approach feels less stable, so I'm considering switching to a small gender recognition model.
#1911

@Sy0307 (Contributor) commented Mar 17, 2026

LGTM. Tests go well. I will save ref_code in metadata instead of prompt_token_ids in a new PR.

@hsliuustc0106 hsliuustc0106 merged commit fe786f1 into vllm-project:main Mar 17, 2026
7 checks passed
tangbinh pushed a commit to tangbinh/vllm-omni that referenced this pull request Mar 18, 2026
yiliu30 pushed a commit to yiliu30/vllm-omni-fork that referenced this pull request Mar 20, 2026
…ject#1945)

Signed-off-by: linyueqian <linyueqian@outlook.com>

Signed-off-by: yiliu30 <yi4.liu@intel.com>
hsliuustc0106 added a commit to hsliuustc0106/vllm-omni-skills that referenced this pull request Mar 22, 2026
### vllm-omni-audio-tts
- Source: [PR #2059](vllm-project/vllm-omni#2059) - [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool
- Changes:
  - Bug fix: [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool

### vllm-omni-perf
- Source: [PR #2059](vllm-project/vllm-omni#2059) - [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool
- Changes:
  - Bug fix: [BugFix][Qwen3TTS] CodePredictor CudaGraph Pool

### vllm-omni-api
- Source: [PR #2058](vllm-project/vllm-omni#2058) - [Bugfix] Fix Fish Speech and CosyVoice3 online serving - missing is_comprehension and broken model detection
- Changes:
  - Bug fix: [Bugfix] Fix Fish Speech and CosyVoice3 online serving - missing is_comprehension and broken model detection

### vllm-omni-contrib
- Source: [PR #2045](vllm-project/vllm-omni#2045) - [Voxtral] Improve example

### vllm-omni-cicd
- Source: [PR #2045](vllm-project/vllm-omni#2045) - [Voxtral] Improve example

### vllm-omni-api
- Source: [PR #2042](vllm-project/vllm-omni#2042) - [bugfix] /chat/completion doesn't read extra_body for diffusion model
- Changes:
  - Bug fix: [bugfix] /chat/completion doesn't read extra_body for diffusion model

### vllm-omni-perf
- Source: [PR #2042](vllm-project/vllm-omni#2042) - [bugfix] /chat/completion doesn't read extra_body for diffusion model
- Changes:
  - Bug fix: [bugfix] /chat/completion doesn't read extra_body for diffusion model

### vllm-omni-contrib
- Source: [PR #2038](vllm-project/vllm-omni#2038) - [Doc] Update docs and dockerfiles for rebase of vllm v0.18.0

### vllm-omni-serving
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-contrib
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-api
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-cicd
- Source: [PR #2037](vllm-project/vllm-omni#2037) - [Rebase] Rebase to vllm v0.18.0

### vllm-omni-cicd
- Source: [PR #2032](vllm-project/vllm-omni#2032) - [CI] Change Bagel online test environment variable `VLLM_TEST_CLEAN_GPU_MEMORY` to `0`

### vllm-omni-cicd
- Source: [PR #2031](vllm-project/vllm-omni#2031) - [CI] Fix test.
- Changes:
  - Bug fix: [CI] Fix test.

### vllm-omni-cicd
- Source: [PR #2017](vllm-project/vllm-omni#2017) - [CI] [ROCm] Setup `test-ready.yml` and `test-merge.yml`

### vllm-omni-cicd
- Source: [PR #2014](vllm-project/vllm-omni#2014) - [Test] Implement mock HTTP request handling in benchmark CLI tests

### vllm-omni-perf
- Source: [PR #2014](vllm-project/vllm-omni#2014) - [Test] Implement mock HTTP request handling in benchmark CLI tests

### vllm-omni-serving
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-image-gen
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-perf
- Source: [PR #2012](vllm-project/vllm-omni#2012) - [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips
- Changes:
  - Bug fix: [Fixbug][Perf] Qwen3-omni: code predictor with re-prefill + SDPA and eliminate decode hot-path CPU round-trips

### vllm-omni-serving
- Source: [PR #2009](vllm-project/vllm-omni#2009) - [Bugfix] revert PR#1758 which introduced the accuracy problem of qwen3-omni
- Changes:
  - Bug fix: [Bugfix] revert PR#1758 which introduced the accuracy problem of qwen3-omni

### vllm-omni-image-gen
- Source: [PR #2007](vllm-project/vllm-omni#2007) - [Bugfix]Fix bug of online server can not return mutli images
- Changes:
  - Bug fix: [Bugfix]Fix bug of online server can not return mutli images
- Additions:
  - Qwen-Image-Layered

### vllm-omni-api
- Source: [PR #2007](vllm-project/vllm-omni#2007) - [Bugfix]Fix bug of online server can not return mutli images
- Changes:
  - Bug fix: [Bugfix]Fix bug of online server can not return mutli images

### vllm-omni-cicd
- Source: [PR #1998](vllm-project/vllm-omni#1998) - [CI] Split BAGEL tests into dummy/real weight tiers (L2/L3)

### vllm-omni-serving
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-audio-tts
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-perf
- Source: [PR #1985](vllm-project/vllm-omni#1985) - [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls
- Changes:
  - Performance improvement: [Perf] [Qwen3-TTS] Keep audio_codes and last_talker_hidden on GPU to eliminate per-step sync stalls

### vllm-omni-serving
- Source: [PR #1984](vllm-project/vllm-omni#1984) - [CI] [ROCm] Bugfix device environment issue
- Changes:
  - Bug fix: [CI] [ROCm] Bugfix device environment issue

### vllm-omni-api
- Source: [PR #1984](vllm-project/vllm-omni#1984) - [CI] [ROCm] Bugfix device environment issue
- Changes:
  - Bug fix: [CI] [ROCm] Bugfix device environment issue

### vllm-omni-serving
- Source: [PR #1982](vllm-project/vllm-omni#1982) - [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__
- Changes:
  - Bug fix: [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__

### vllm-omni-cicd
- Source: [PR #1982](vllm-project/vllm-omni#1982) - [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__
- Changes:
  - Bug fix: [Fix] Fix slow hasattr in CUDAGraphWrapper.__getattr__

### vllm-omni-api
- Source: [PR #1979](vllm-project/vllm-omni#1979) - [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Changes:
  - Bug fix: [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Additions:
  - `/v1/chat/completions`

### vllm-omni-perf
- Source: [PR #1979](vllm-project/vllm-omni#1979) - [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)
- Changes:
  - Bug fix: [Bugfix] Fix config misalignment between offline and online diffusion inference (Wan2.2, Qwen-Image series)

### vllm-omni-contrib
- Source: [PR #1976](vllm-project/vllm-omni#1976) - [skip ci][Docs] Update WeChat QR code (fix filename case)
- Changes:
  - Bug fix: [skip ci][Docs] Update WeChat QR code (fix filename case)

### vllm-omni-contrib
- Source: [PR #1974](vllm-project/vllm-omni#1974) - [Docs] Update WeChat QR code for community support

### vllm-omni-cicd
- Source: [PR #1945](vllm-project/vllm-omni#1945) - Fix Base voice clone streaming quality and stop-token crash
- Changes:
  - Bug fix: Fix Base voice clone streaming quality and stop-token crash

### vllm-omni-cicd
- Source: [PR #1938](vllm-project/vllm-omni#1938) - [Test] L4 complete diffusion feature test for Bagel models
- Changes:
  - New feature: [Test] L4 complete diffusion feature test for Bagel models

### vllm-omni-perf
- Source: [PR #1938](vllm-project/vllm-omni#1938) - [Test] L4 complete diffusion feature test for Bagel models
- Changes:
  - New feature: [Test] L4 complete diffusion feature test for Bagel models

### vllm-omni-perf
- Source: [PR #1934](vllm-project/vllm-omni#1934) - Fix OmniGen2 transformer config loading for HF models
- Changes:
  - Bug fix: Fix OmniGen2 transformer config loading for HF models

### vllm-omni-audio-tts
- Source: [PR #1930](vllm-project/vllm-omni#1930) - [Bug][Qwen3TTS][Streaming] remove dynamic initial chunk and only compute on initial request

### vllm-omni-perf
- Source: [PR #1930](vllm-project/vllm-omni#1930) - [Bug][Qwen3TTS][Streaming] remove dynamic initial chunk and only compute on initial request

### vllm-omni-audio-tts
- Source: [PR #1926](vllm-project/vllm-omni#1926) - [Misc] removed qwen3_tts.py as it is out-dated

### vllm-omni-contrib
- Source: [PR #1920](vllm-project/vllm-omni#1920) - [Docs] Add Wan2.1-T2V as supported video generation models
- Changes:
  - New feature: [Docs] Add Wan2.1-T2V as supported video generation models

### vllm-omni-video-gen
- Source: [PR #1915](vllm-project/vllm-omni#1915) - [Bugfix] fix helios video generate use cpu device
- Changes:
  - Bug fix: [Bugfix] fix helios video generate use cpu device

### vllm-omni-perf
- Source: [PR #1915](vllm-project/vllm-omni#1915) - [Bugfix] fix helios video generate use cpu device
- Changes:
  - Bug fix: [Bugfix] fix helios video generate use cpu device

### vllm-omni-audio-tts
- Source: [PR #1913](vllm-project/vllm-omni#1913) - [Optim][Qwen3TTS][CodePredictor] support torch.compile with reduce-overhead and dynamic False

### vllm-omni-perf
- Source: [PR #1913](vllm-project/vllm-omni#1913) - [Optim][Qwen3TTS][CodePredictor] support torch.compile with reduce-overhead and dynamic False

### vllm-omni-api
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-perf
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-contrib
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-serving
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-cicd
- Source: [PR #1908](vllm-project/vllm-omni#1908) - [Entrypoint][Refactor] vLLM-Omni Entrypoint Refactoring

### vllm-omni-image-gen
- Source: [PR #1900](vllm-project/vllm-omni#1900) - [Feat] support HSDP for Flux family
- Changes:
  - New feature: [Feat] support HSDP for Flux family

### vllm-omni-contrib
- Source: [PR #1900](vllm-project/vllm-omni#1900) - [Feat] support HSDP for Flux family
- Changes:
  - New feature: [Feat] support HSDP for Flux family

### vllm-omni-distributed
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-quantization
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-cicd
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-perf
- Source: [PR #1898](vllm-project/vllm-omni#1898) - [Feature]: Remove some useless `hf_overrides` in yaml
- Changes:
  - New feature: [Feature]: Remove some useless `hf_overrides` in yaml

### vllm-omni-contrib
- Source: [PR #1890](vllm-project/vllm-omni#1890) - [NPU] Upgrade to v0.17.0

### vllm-omni-contrib
- Source: [PR #1889](vllm-project/vllm-omni#1889) - Add `Governance` section
- Changes:
  - New feature: Add `Governance` section

### vllm-omni-distributed
- Source: [PR #1881](vllm-project/vllm-omni#1881) - [Feat] Support T5 Tensor Parallelism
- Changes:
  - New feature: [Feat] Support T5 Tensor Parallelism

### vllm-omni-cicd
- Source: [PR #1881](vllm-project/vllm-omni#1881) - [Feat] Support T5 Tensor Parallelism
- Changes:
  - New feature: [Feat] Support T5 Tensor Parallelism
MihailoMilenkovic added a commit to MihailoMilenkovic/vllm-omni that referenced this pull request Mar 27, 2026
The async_chunk stage input processor (_extract_last_frame) did not
check for codec values >= codebook_size (2048). When the Talker emits
stop_token_id=2150 as part of a frame, the raw value passes through
to Code2Wav's nn.Embedding(2048, ...) lookup, causing a CUDA
index-out-of-bounds crash or infinite generation.

The non-async path (talker2code2wav) already has this filtering via
valid_mask since PR vllm-project#1945. This commit adds the same guard to
_extract_last_frame, which is used by talker2code2wav_async_chunk.

Fixes Base/clone voice cloning in async_chunk streaming mode.
MihailoMilenkovic added a commit to MihailoMilenkovic/vllm-omni that referenced this pull request Apr 1, 2026
MihailoMilenkovic added a commit to MihailoMilenkovic/vllm-omni that referenced this pull request Apr 5, 2026
…runcation

The talker2code2wav processor aggressively filtered any frame where any
of the 16 quantizer values was >= codebook_size (2048), and also trimmed
audio_codes to the last 'len(token_ids) - 1' rows. Both defenses are
unnecessary and were producing audible artifacts:

1. The <_CODEBOOK_SIZE filter (PR vllm-project#1945): added to prevent a CUDA assert
   when stop_token_id=2150 reached Code2Wav's embedding lookup. But
   talker.compute_logits() already masks all non-codec values except EOS
   (_codec_allowed_mask at qwen3_tts_talker.py:387-395), and talker_mtp()
   zeroes the entire row when EOS is sampled (lines 1655-1658). So this
   filter is redundant in normal operation, but it can drop legitimate
   middle frames if any codebook value accidentally lands in the special
   range, producing audible gaps in the decoded audio.

2. The seq_len=len(token_ids)-1 trim (PR vllm-project#2104): added for a repeated-
   history aggregation bug. In normal operation, however, this trim can
   silently drop frames from the beginning of the audio (off-by-one
   between token_ids and audio_codes row counts).

The combined effect was that short clone requests were often reduced to
shape [1, 1] or similar, producing codec_codes lists of length 1 that
Code2Wav then rejected as 'malformed request' (qwen3_tts_code2wav.py
line 238), resulting in silent or corrupted output for specific spans.

Replace both with HF's exact behavior from modeling_qwen3_tts.py lines
2283-2290: truncate at the first EOS frame. Since talker_mtp zeroes
rows on EOS, this is equivalent to truncating at the first all-zero
row. No other filtering is needed because compute_logits already
prevents any special token except EOS from being sampled.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MihailoMilenkovic added a commit to MihailoMilenkovic/vllm-omni that referenced this pull request Apr 5, 2026
…runcation

The talker2code2wav processor aggressively filtered frames with any
quantizer value >= codebook_size (2048) and trimmed audio_codes to the
last (len(token_ids) - 1) rows. Both defenses are unnecessary and
produce audible artifacts:

1. The <_CODEBOOK_SIZE filter (PR vllm-project#1945) was added to prevent a CUDA
   assert when stop_token_id=2150 reached Code2Wav's embedding lookup.
   But talker.compute_logits() already masks all non-codec values
   except EOS (_codec_allowed_mask at qwen3_tts_talker.py:387-395),
   and talker_mtp() zeroes the entire row when EOS is sampled (lines
   1655-1658). The filter is redundant but can also drop legitimate
   middle frames, producing audible gaps in the decoded audio.

2. The seq_len=len(token_ids)-1 trim (PR vllm-project#2104) was added for a
   repeated-history aggregation bug. In normal operation, however,
   this trim can silently drop frames from the beginning of the
   audio (off-by-one between token_ids and audio_codes row counts).

The combined effect: short clone requests were often reduced to
pathological shapes, producing codec_codes of length 1 that Code2Wav
rejected as 'malformed request' (qwen3_tts_code2wav.py line 238) -
the root cause of silent/corrupted audio for specific spans.

Replace both with HF's exact behavior from modeling_qwen3_tts.py lines
2283-2290: truncate at the first EOS frame. Since talker_mtp zeroes
rows on EOS, this is equivalent to truncating at the first all-zero
row. No other filtering is needed because compute_logits prevents any
special token except EOS from being sampled.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
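
A minimal sketch of the EOS-truncation behavior described in that commit message (assuming, as it states, that talker_mtp zeroes the whole frame when EOS is sampled):

```python
import torch

def truncate_at_eos(audio_codes: torch.Tensor) -> torch.Tensor:
    # audio_codes: [num_frames, num_quantizers]; the first all-zero row
    # marks the EOS frame, so keep everything before it.
    eos_rows = (audio_codes == 0).all(dim=-1)
    idx = torch.nonzero(eos_rows, as_tuple=False)
    if idx.numel() == 0:
        return audio_codes  # no EOS emitted yet
    return audio_codes[: idx[0].item()]
```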
MihailoMilenkovic added a commit to MihailoMilenkovic/vllm-omni that referenced this pull request Apr 6, 2026
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026

Labels

ready (label to trigger buildkite CI)

Successfully merging this pull request may close these issues:

[Bug] Qwen3-TTS Base voice clone: audio distortion in async_chunk streaming mode

5 participants