Skip to content

[Bugfix] Fix MiMo-Audio voice instability: stochastic local_sampler + codec streaming context#3686

Merged
linyueqian merged 6 commits into
vllm-project:mainfrom
Galleons2029:fix/mimo-audio-voice-instability
May 21, 2026
Merged

[Bugfix] Fix MiMo-Audio voice instability: stochastic local_sampler + codec streaming context#3686
linyueqian merged 6 commits into
vllm-project:mainfrom
Galleons2029:fix/mimo-audio-voice-instability

Conversation

@Galleons2029
Copy link
Copy Markdown
Contributor

@Galleons2029 Galleons2029 commented May 18, 2026

Root Causes Fixed

Supersedes #3548

1. local_sampler forced argmax — flat, monotone speech (mimo_audio_llm.py) local_sampler was initialized with do_sample=False, forcing all audio RVQ code generation to argmax instead of stochastic sampling. This also activated the CUDA-Graph path (MiMoLocalSamplerTensor) which silently enforces argmax even when temperature > 0.

Fix: local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95) in __init__, base_local_forward fallback, and local_forward fallback. global_sampler stays do_sample=False to match vLLM's external greedy sampler and avoid KV-cache state divergence.

2. Codec streaming context too small → multiple voices (yaml + stage_input_processors) codec_left_context_frames=3 was far below the vocoder attention window vocoder_attn_window_size=[40, 10]. At each chunk boundary the vocoder had no meaningful acoustic history, resetting its internal state and producing a new random voice timbre — audible as multi-speaker artifacts.

  • codec_left_context_frames: 3 → 40 (covers the full vocoder attention window)
  • codec_chunk_frames: 3 → 30 (fewer boundaries; context overhead 14× → ~2.3×)
  • Applied to both vllm_omni/deploy/mimo_audio.yaml and the newly tracked vllm_omni/model_executor/stage_configs/mimo_audio.yaml

3. Dead sum(code_list) == 0 guard removed (stage_input_processors/mimo_audio.py) The check included pad tokens (TALKER_CODEC_PAD_TOKEN_ID=151667), making the sum always ≥ 2.4 M — the condition was structurally unreachable. Removed. Dropping mid-stream frames would also break vocoder temporal alignment regardless.

Guardrail: parameter floor validation

Added _MIN_CODEC_CHUNK_FRAMES=3 / _MIN_CODEC_LEFT_CONTEXT_FRAMES=40 / _DEFAULT_* constants in both stage_input_processors/mimo_audio.py and mimo_audio_code2wav.py __init__. Values below the floor emit logger.warning and are clamped to safe defaults, preventing silent regressions on config changes.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

Test Plan

The test case mentioned in #3452

Test Result

English:
weather.wav

Chinese:
30frame.wav


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0dcd4593af

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread vllm_omni/deploy/mimo_audio.yaml Outdated
@Galleons2029 Galleons2029 force-pushed the fix/mimo-audio-voice-instability branch from 0dcd459 to bfd9826 Compare May 18, 2026 07:16
… codec streaming context

`local_sampler` was initialized with `do_sample=False`, forcing all audio RVQ
code generation to argmax instead of stochastic sampling. This also activated
the CUDA-Graph path (`MiMoLocalSamplerTensor`) which silently enforces argmax
even when temperature > 0.

Fix: `local_sampler = MiMoSampler(do_sample=True, temperature=0.9, top_p=0.95)`
in `__init__`, `base_local_forward` fallback, and `local_forward` fallback.
`global_sampler` stays `do_sample=False` to match vLLM's external greedy sampler
and avoid KV-cache state divergence.

`codec_left_context_frames=3` was far below the vocoder attention window
`vocoder_attn_window_size=[40, 10]`. At each chunk boundary the vocoder had no
meaningful acoustic history, resetting its internal state and producing a new
random voice timbre — audible as multi-speaker artifacts.

- `codec_left_context_frames: 3 → 40` (covers the full vocoder attention window)
- `codec_chunk_frames: 3 → 30` (fewer boundaries; context overhead 14× → ~2.3×)
- Applied to both `vllm_omni/deploy/mimo_audio.yaml` and the newly tracked
  `vllm_omni/model_executor/stage_configs/mimo_audio.yaml`

The check included pad tokens (`TALKER_CODEC_PAD_TOKEN_ID=151667`), making the
sum always ≥ 2.4 M — the condition was structurally unreachable. Removed.
Dropping mid-stream frames would also break vocoder temporal alignment regardless.

Added `_MIN_CODEC_CHUNK_FRAMES=3` / `_MIN_CODEC_LEFT_CONTEXT_FRAMES=40` /
`_DEFAULT_*` constants in both `stage_input_processors/mimo_audio.py` and
`mimo_audio_code2wav.py` `__init__`. Values below the floor emit `logger.warning`
and are clamped to safe defaults, preventing silent regressions on config changes.

Signed-off-by: Galleons2029 <Galleons777@gmail.com>
@Galleons2029 Galleons2029 force-pushed the fix/mimo-audio-voice-instability branch from bfd9826 to 85a6fdf Compare May 18, 2026 07:21
Follow the review advice from gpt bot:
"This change makes the bundled MiMo deploy YAML require specific logical GPUs ("1,2" for stage 0 and "3" for stage 1, plus tensor_parallel_size: 2), which breaks common setups where users only expose one device (or fewer than four logical IDs). In that case, stage initialization fails when device mapping cannot resolve those IDs, so vllm serve ... --omni no longer starts with the default config for many environments."

Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
@linyueqian
Copy link
Copy Markdown
Collaborator

@qibaoyuan ptal

@linyueqian
Copy link
Copy Markdown
Collaborator

long_MAIN.wav
long_PR.wav

@linyueqian linyueqian added ready label to trigger buildkite CI tts-test label to trigger buildkite tts models test in nightly CI labels May 20, 2026
Copy link
Copy Markdown
Collaborator

@linyueqian linyueqian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve. Verified on an H20 server with MiMo-Audio-7B-Instruct, per the issue #3452 repro, using a single-GPU deploy config override (so both deploy/mimo_audio.yaml and the new stage_configs/mimo_audio.yaml differ from this run only in codec_chunk_frames/codec_left_context_frames).

Short prompt: "The weather is so nice today." (Whisper-medium ASR, 3 runs each)

Run main @ a3d4ed8 PR #3686
1 6.88s, "...so nice today. The weather is so nice today." (repeats) 3.20s, "...so nice today."
2 4.96s, correct text 5.44s, correct text
3 6.56s, correct text (abnormally long) 2.24s, correct text

Long prompt with --instruct "Speak naturally and clearly in a calm female voice"

  • main: 4.00s, "The kaput brown fox jumps over the lazy dog..." (word substitution)
  • PR: 6.24s, "The quick brown fox jumps over the lazy dog near the riverbank at noon."

PR consistently shorter, no repetition, no word corruption. Both proposed root causes (local_sampler forced argmax + codec_left_context_frames too small for vocoder_attn_window_size=[40,10]) match the symptoms.

Follow-up items (non-blocking, fine as a fixup commit on this PR)

🟡 [important] Two existing tests fail under the new _flush_remaining_codes semantics:

  • tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py::test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size
  • ...::test_flush_remaining_codes_context_window_end_index[6-3-3-6]

Both assert the old behavior where length % chunk_size == 0 returned the full last chunk. The new early-return is the right fix (those frames were already emitted as a complete chunk upstream); the tests just need to expect the finished sentinel instead.

🟢 [nit] See inline on stage_configs/mimo_audio.yaml about the hardcoded GPU IDs in the newly added file.

Performance note (FYI, not blocking)

Setting local_sampler.do_sample=True disables the CUDA-graph path through the use_cg gate in local_forward (use_cg = (do_sample is None or do_sample is False) and ...). Stage-0 per-prompt latency went from ~2.28s/it to ~2.70s/it on this hardware (about 18% slower). A reasonable trade-off for correctness, but worth calling out in the PR description so users know to expect it.

Thanks for tracking this one down.

Comment thread vllm_omni/model_executor/stage_configs/mimo_audio.yaml Outdated
@qibaoyuan
Copy link
Copy Markdown
Contributor

LGTM

@linyueqian
Copy link
Copy Markdown
Collaborator

great! @Galleons2029 please fix as suggested about and i think it is ready to merge

Follow the review advice from Yueqian Lin: "Suggest either deleting this file or aligning it with the single-GPU defaults so it does not contradict the revert and mislead future readers."

Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
…boundary

When length % chunk_size == 0, _flush_remaining_codes previously returned an
empty finished sentinel, dropping the tail audio frames. The vocoder needs the
final chunk plus left context to produce a stable tail; otherwise voice cuts
off at chunk boundaries. Fall back to chunk_size as the context length in this
case, matching the behavior pinned by the new unit tests in
tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py.

Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
@Galleons2029 Galleons2029 force-pushed the fix/mimo-audio-voice-instability branch 3 times, most recently from b123c6b to e4ca981 Compare May 21, 2026 06:16
@Galleons2029
Copy link
Copy Markdown
Contributor Author

Galleons2029 commented May 21, 2026

great! @Galleons2029 please fix as suggested about and i think it is ready to merge

@linyueqian The suggestion have been resolved. One thing left: the remaining buildkite/vllm-omni-npu-ci failure is an upstream vLLM × NPU image mismatch (ImportError: cannot import name 'split_routed_experts' from vllm.model_executor.layers.fused_moe.routed_experts_capturer) — same import error has been failing on the NPU pipeline across multiple unrelated runs (#1588, #1600, #1619, #1625) and is not introduced by this PR. The PR only touches vllm_omni/model_executor/{models,stage_input_processors,stage_configs}/mimo_audio and vllm_omni/deploy/mimo_audio.yaml (5 files, all under the MiMo-Audio code path). Could you help retrigger / mark this NPU job as non-blocking?

@linyueqian linyueqian merged commit e949ccf into vllm-project:main May 21, 2026
13 of 15 checks passed
@linyueqian
Copy link
Copy Markdown
Collaborator

thanks! i have merged this pr

Nightwing-77 pushed a commit to Nightwing-77/vllm-omni that referenced this pull request May 21, 2026
… codec streaming context (vllm-project#3686)

Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Signed-off-by: Advik <scince5678@gmail.com>
@hsliuustc0106
Copy link
Copy Markdown
Collaborator

https://buildkite.com/vllm/vllm-omni/builds/10167/canvas?sid=019e4af0-54d6-4fb3-9d82-c771a02c5dde&tab=output CI failed when merging to main

Galleons2029 added a commit to Galleons2029/vllm-omni-ljl that referenced this pull request May 22, 2026
…batching

PR vllm-project#3686 set `local_sampler.do_sample=True` (temperature=0.9, top_p=0.95)
intending to fix MiMo-Audio voice instability by avoiding the silent argmax
that `MiMoLocalSamplerTensor` enforces on the CUDA-graph path.  The change
unintentionally destabilises text continuations under concurrent batched
requests, surfaced by buildkite/vllm-omni #10167 on the merge-to-main run
of `tests/e2e/online_serving/test_mimo_audio.py::test_text_to_text_001`:

  AssertionError: The output does not contain any of the keywords.
  response.text_content = 'The capital of China is.'   # missing "Beijing"

Root cause is an audio-embedding feedback loop into the text decode path:

  local_forward (stochastic local_sampler)
    -> next_speech_tokens   # random under temp=0.9 top_p=0.95
    -> new_audio_emb = sum_k speech_embeddings[k](next_speech_tokens[..,k,..])
    -> _cached_new_audio_emb_by_req[req_id] = new_audio_emb
    -> next decode step: _prepare_multimodal_embeddings_with_cache adds
       prev_new_audio_emb back into inputs_embeds
    -> self.model(input_ids, positions, inputs_embeds=inputs_embeds)
    -> compute_logits -> global_sampler.sample (greedy)

`global_sampler` is greedy but its *logits* depend on the random audio
embedding from the previous step, so the greedy argmax flips for some
batch members.  At batch=5 with identical prompt
"What is the capital of China? Answer in 20 words." we observed 5
different continuations, one of which dropped "Beijing" entirely and
emitted "The capital of China is.<eot>" instead.

Reproduction (with --run-level=advanced_model):

  before revert: 4/5 contain "beijing", 1/5 truncates -> FAIL
  after revert : 5/5 contain "beijing"                -> PASS

Setting do_sample=False also restores the CUDA-graph path
(`use_cg = (do_sample is None or do_sample is False) and ...`), undoing
the ~18% stage-0 per-prompt latency regression Codex flagged on vllm-project#3686.

The voice-instability symptoms PR vllm-project#3686 set out to fix are actually
resolved by its other change -- `codec_left_context_frames: 3 -> 40` in
the stage-1 vocoder config, which covers `vocoder_attn_window_size=[40, 10]`
and prevents acoustic-state resets at chunk boundaries.  That change
lives in stage_configs / stage_input_processors and is preserved here.
Voice diversity, if needed, should be reintroduced in the codec/vocoder
path (stage-1) with a per-request seed rather than by randomising the
shared local_sampler whose outputs feed back into stage-0 text logits.

Three sites touched, all in mimo_audio_llm.py:
- __init__ : self.local_sampler do_sample True -> False
- base_local_forward fallback: same
- local_forward fallback: same

The unrelated `pooling_output is None` guard in
stage_input_processors/mimo_audio.py landed earlier on this branch is
retained.  That guard fixes a separate AttributeError in
chunk_transfer_adapter when stage-0 emits None pooling_output on
text-only paths.  It is independent of the truncation bug.

Fixes vllm-project#3815
Follow-up to vllm-project#3686

Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Galleons2029 added a commit to Galleons2029/vllm-omni-ljl that referenced this pull request May 22, 2026
…batching

PR vllm-project#3686 set `local_sampler.do_sample=True` (temperature=0.9, top_p=0.95)
intending to fix MiMo-Audio voice instability by avoiding the silent argmax
that `MiMoLocalSamplerTensor` enforces on the CUDA-graph path.  The change
unintentionally destabilises text continuations under concurrent batched
requests, surfaced by buildkite/vllm-omni #10167 on the merge-to-main run
of `tests/e2e/online_serving/test_mimo_audio.py::test_text_to_text_001`:

  AssertionError: The output does not contain any of the keywords.
  response.text_content = 'The capital of China is.'   # missing "Beijing"

Root cause is an audio-embedding feedback loop into the text decode path:

  local_forward (stochastic local_sampler)
    -> next_speech_tokens   # random under temp=0.9 top_p=0.95
    -> new_audio_emb = sum_k speech_embeddings[k](next_speech_tokens[..,k,..])
    -> _cached_new_audio_emb_by_req[req_id] = new_audio_emb
    -> next decode step: _prepare_multimodal_embeddings_with_cache adds
       prev_new_audio_emb back into inputs_embeds
    -> self.model(input_ids, positions, inputs_embeds=inputs_embeds)
    -> compute_logits -> global_sampler.sample (greedy)

`global_sampler` is greedy but its *logits* depend on the random audio
embedding from the previous step, so the greedy argmax flips for some
batch members.  At batch=5 with identical prompt
"What is the capital of China? Answer in 20 words." we observed 5
different continuations, one of which dropped "Beijing" entirely and
emitted "The capital of China is.<eot>" instead.

Reproduction (with --run-level=advanced_model):

  before revert: 4/5 contain "beijing", 1/5 truncates -> FAIL
  after revert : 5/5 contain "beijing"                -> PASS

Setting do_sample=False also restores the CUDA-graph path
(`use_cg = (do_sample is None or do_sample is False) and ...`), undoing
the ~18% stage-0 per-prompt latency regression Codex flagged on vllm-project#3686.

The voice-instability symptoms PR vllm-project#3686 set out to fix are actually
resolved by its other change -- `codec_left_context_frames: 3 -> 40` in
the stage-1 vocoder config, which covers `vocoder_attn_window_size=[40, 10]`
and prevents acoustic-state resets at chunk boundaries.  That change
lives in stage_configs / stage_input_processors and is preserved here.
Voice diversity, if needed, should be reintroduced in the codec/vocoder
path (stage-1) with a per-request seed rather than by randomising the
shared local_sampler whose outputs feed back into stage-0 text logits.

Three sites touched, all in mimo_audio_llm.py:
- __init__ : self.local_sampler do_sample True -> False
- base_local_forward fallback: same
- local_forward fallback: same

The unrelated `pooling_output is None` guard in
stage_input_processors/mimo_audio.py landed earlier on this branch is
retained.  That guard fixes a separate AttributeError in
chunk_transfer_adapter when stage-0 emits None pooling_output on
text-only paths.  It is independent of the truncation bug.

Fixes vllm-project#3815
Follow-up to vllm-project#3686

Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
hsliuustc0106 added a commit that referenced this pull request May 22, 2026
…tion under concurrent batching (followup to #3686) (#3817)

Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
zengchuang-hw pushed a commit to zengchuang-hw/vllm-omni that referenced this pull request Jun 1, 2026
… codec streaming context (vllm-project#3686)

Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
zengchuang-hw pushed a commit to zengchuang-hw/vllm-omni that referenced this pull request Jun 1, 2026
…tion under concurrent batching (followup to vllm-project#3686) (vllm-project#3817)

Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
Signed-off-by: Galleons2029 <Galleons777@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI tts-test label to trigger buildkite tts models test in nightly CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants