Skip to content

feat: add MOSS-TTS-Nano single-stage TTS support#2753

Merged
hsliuustc0106 merged 18 commits into
vllm-project:mainfrom
linyueqian:feat/moss-tts-nano
Apr 27, 2026
Merged

feat: add MOSS-TTS-Nano single-stage TTS support#2753
hsliuustc0106 merged 18 commits into
vllm-project:mainfrom
linyueqian:feat/moss-tts-nano

Conversation

@linyueqian

@linyueqian linyueqian commented Apr 13, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Add single-stage MOSS-TTS-Nano (0.1B) TTS model support with online serving via `/v1/audio/speech`
  • Integrates OpenMOSS-Team/MOSS-TTS-Nano AR LM + MOSS-Audio-Tokenizer-Nano codec
  • 15 built-in voice presets (6 ZH, 4 EN, 5 JA), 48 kHz stereo output
  • Gradio demo with AudioWorklet gapless streaming player and TTFP/RTF metrics
  • Progressive streaming via AR runner — TTFP reduced from ~3.1s to ~0.11s (30x improvement)

Doc/skill updates split into #2806.

Changes

Model integration:

  • `vllm_omni/model_executor/models/moss_tts_nano/` — model wrapper using `trust_remote_code` for upstream HF classes; VoxCPM-style per-request `inference_stream()` generator stored in `_stream_gens`; each `forward()` call pops one chunk via `next(generator)`
  • `vllm_omni/model_executor/stage_configs/moss_tts_nano.yaml` — single-stage config (`worker_type: ar`, `is_comprehension: true`, `async_chunk: false`, `max_num_seqs: 4`)
  • `vllm_omni/model_executor/models/registry.py` — register `MossTTSNanoForCausalLM`

Serving layer:

  • `vllm_omni/entrypoints/openai/serving_speech.py` — add `_MOSS_TTS_MODEL_STAGES` constant, extend `_TTS_MODEL_STAGES` union, add `moss_tts_nano` branch in `_get_tts_model_type()`, add `_validate_moss_tts_request()` and `_build_moss_tts_params()` methods

Streaming (AR runner):

  • Switched from `GPUGenerationWorker` to AR worker with `OmniARScheduler`, using the VoxCPM-style generator pattern
  • `inference_stream()` stored per-request; each `forward()` call yields one audio chunk
  • `compute_logits()` emits EOS only when the last chunk is yielded; AR scheduler loops until EOS for progressive audio output

Examples:

  • `examples/offline_inference/moss_tts_nano/` — CLI end-to-end script + README
  • `examples/online_serving/moss_tts_nano/` — Gradio demo, server/demo launch scripts, README

Tests:

  • `tests/e2e/offline_inference/test_moss_tts_nano.py` — English, Chinese, deterministic, batch, voice presets
  • `tests/e2e/online_serving/test_moss_tts_nano.py` — non-streaming WAV, streaming PCM, Chinese
  • `.buildkite/test-merge.yml` — pre-merge CI entry

Test plan

  • Offline inference verified on H20 (English/Chinese/Japanese voices)
  • Online serving verified via curl (`/v1/audio/speech` returns 48 kHz audio)
  • Gradio demo verified via SSH tunnel (AudioWorklet streaming player)
  • Progressive streaming verified: TTFP ~0.11s (vs ~3.1s before)
  • Pre-commit hooks pass (ruff F841 fix, ruff format applied)
  • DCO sign-off email matches git author (`linyueqian@outlook.com`)
  • Buildkite pre-merge CI (L4 GPU)

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@linyueqian linyueqian force-pushed the feat/moss-tts-nano branch 2 times, most recently from 1d4beaa to 5de0923 Compare April 13, 2026 22:36
@hsliuustc0106

Copy link
Copy Markdown
Collaborator

BLOCKER scan:

This PR has a merge conflict. Please resolve the conflict before proceeding with review.

OVERALL: MERGE CONFLICT

VERDICT: REQUEST_CHANGES

@hsliuustc0106

Copy link
Copy Markdown
Collaborator

I think you can update the skills to allow community users to contribute more effectively

@linyueqian

Copy link
Copy Markdown
Collaborator Author

I think you can update the skills to allow community users to contribute more effectively

will do

@lishunyang12 lishunyang12 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review: feat: add MOSS-TTS-Nano single-stage TTS support

Overall this is a well-structured PR that follows existing patterns in the codebase (VoxCPM-style generator, stage config YAML, registry entry, serving layer extension). The examples, tests, and documentation are thorough. A few issues to address before merging:


Issues

1. _build_moss_tts_params does not handle ref_audio (serving layer bug)
vllm_omni/entrypoints/openai/serving_speech.py_build_moss_tts_params() maps ref_text to prompt_text but completely ignores request.ref_audio. The online serving README documents custom voice cloning via ref_audio as a supported feature, and the model's _create_stream_gen() reads prompt_audio_path from additional_information. Without wiring ref_audio through the serving layer, voice cloning via the /v1/audio/speech endpoint will silently fail to use the reference audio.

You need to resolve ref_audio (download/decode the data URL) and pass the path through as prompt_audio_path in the params dict, similar to how CosyVoice3 or VoxCPM2 handle it.

2. _ar_emit_stop_token is shared mutable state across a batch — race condition
modeling_moss_tts_nano.pyself._ar_emit_stop_token is a single boolean on the model instance. In forward(), it is set to all(last_chunk_flags) — meaning if any request in the batch is still generating, all requests get non-EOS logits from compute_logits(). This means finished requests are kept alive (emitting empty audio) until the slowest request in the batch completes. For max_num_seqs: 4 this could be significant. Consider making this per-request (e.g., a dict keyed by request ID) so the AR scheduler can finish individual requests independently.

3. _create_stream_gen buffers all chunks then yields — "streaming" is misleading
The docstring says "yields one audio chunk per forward() call" for progressive streaming, and the PR description claims "TTFP reduced from ~3.1s to ~0.11s." However, looking at the generator, it calls self._lm.inference_stream() which does yield events progressively, and those are yielded out. This part looks correct on re-read. But the comment at line ~1949 says "We buffer first because inference_stream mixes audio events with a final result event" — this comment is outdated/misleading since you actually yield chunk, False inside the loop. Please clean up the comment to match the actual streaming behavior.

4. torch.manual_seed() in _create_stream_gen sets global RNG state
modeling_moss_tts_nano.py line ~1925 — calling torch.manual_seed(seed) and torch.cuda.manual_seed_all(seed) sets the global RNG state. In a concurrent batch with max_num_seqs: 4, one request's seed will overwrite another's. Consider using a torch.Generator for per-request determinism, or at least document that seeding is best-effort and not safe under concurrency.

5. _stream_gens dict is not thread-safe
self._stream_gens is a plain dict mutated without the existing self._lock. While the AR worker is likely single-threaded, the _lock is already used for load_weights, so if there's any possibility of concurrent forward calls (e.g., from async scheduling), this could corrupt state. Either document the single-thread assumption or protect mutations with the lock.

6. Offline test _collect_audio has a typo: AssertionError
tests/e2e/offline_inference/test_moss_tts_nano.py line 1331 — raise AssertionError is misspelled as AssertionError. This would actually raise a NameError at runtime instead of the intended assertion.


Minor / Nits

  • ## MOSS-TTS-Nano comment in registry.py — other entries don't use ## headers. Use # for consistency.
  • for _ in weights: pass in load_weights() — add a comment explaining this drains the iterator to satisfy the vLLM weight-loading protocol, since it's non-obvious.
  • Gradio demo is 690 lines — the inline AudioWorklet JS and HTML templates are large. Consider extracting them to separate files under the example directory for maintainability (not blocking).
  • _DEFAULT_MODE = "continuation" vs online serving README says default voice is "Junhao" which implies voice_clone mode — the defaults are inconsistent between offline and online paths. The offline example uses voice_clone as default mode while the model code uses continuation. Clarify which is intended.
  • CI step uses gpu_1_queue — confirm this is the right queue for L4 GPUs, as the test decorator specifies res={"cuda": "L4"}.
  • _REPO_ROOT calculation in end2end.py uses Path(__file__).resolve().parents[4] — this is fragile and breaks if the file is moved. Consider using a more robust path resolution or accepting it as a required CLI arg.

What looks good

  • Clean single-stage architecture with well-documented YAML config
  • Proper use of the VoxCPM generator pattern for streaming
  • Good test coverage (offline: English, Chinese, deterministic, batch, voice presets; online: WAV, streaming PCM, Chinese)
  • Serving layer integration follows existing patterns cleanly
  • Gradio demo with AudioWorklet streaming is a nice addition

Please address items 1, 2, 4, and 6 before merging. Items 3 and 5 are lower priority but worth fixing.

@linyueqian

Copy link
Copy Markdown
Collaborator Author

@Sy0307 PTAL.

@linyueqian linyueqian added this to the v0.20.0 milestone Apr 22, 2026
@Sy0307

Sy0307 commented Apr 22, 2026

Copy link
Copy Markdown
Collaborator

Overall LGTM, one minor issue:

_stream_gens leaks on abnormal request termination

When a request is cancelled, timed out, or preempted, the AR scheduler notifies the model via on_requests_finished() (gpu_ar_model_runner.py:324, guarded by hasattr). This model doesn't implement that method, so terminated requests leave their generator permanently in _stream_gens — forward() never visits it again, GC can't collect it (dict holds a reference), and the finally block never runs (temp files leak too).

Suggest adding cleanup following the voxcpm2_talker.py pattern (on_requests_finished + _deferred_cleanup_ids):

def on_requests_finished(self, finished_req_ids: set[str] | list[str]) -> None:
    for req_id in finished_req_ids:
        gen = self._stream_gens.pop(str(req_id), None)
        if gen is not None:
            gen.close()

@linyueqian linyueqian added the ready label to trigger buildkite CI label Apr 24, 2026
Integrates OpenMOSS-Team/MOSS-TTS-Nano (0.1B) as a single
GPUGenerationWorker stage.  Both the AR LM and MOSS-Audio-Tokenizer-Nano
codec run inside MossTTSNanoForGeneration.forward(), removing the need for
an inter-stage connector.

Key design choices:
- Weights loaded in load_weights() not __init__ (avoids pre-CUDA alloc)
- trust_remote_code delegates to upstream HF model classes
- codec path read from config.audio_tokenizer_pretrained_name_or_path
- inference_stream() collects progressive audio chunks for low latency
- 48 kHz stereo output; voice clone + continuation modes supported


Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
…codec unavailable

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
…io/speech integration

- Register moss_tts_nano as TTS model stage in serving_speech.py
  (model detection, validation, prompt building)
- Fix registry mod_relname: moss_tts_nano -> modeling_moss_tts_nano
- Fix stage config: is_comprehension=true (required for generate task)
- Fix default mode: voice_clone -> continuation (built-in presets)
- Add compute_logits stub for VllmModelForTextGeneration protocol
- Remove unused _sentinel nn.Parameter
- Add Gradio demo with AudioWorklet streaming player (48kHz stereo)
- Add online/offline serving docs and launch scripts

TODO: Single-stage generation models don't support true streaming
(progressive audio chunks). Current TTFP = full generation time.
Needs multi-step scheduling support in GPUGenerationModelRunner.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
- Add online serving e2e test (non-streaming, streaming, Chinese)
- Add online serving user guide doc (API, voices, Gradio, curl/Python)
- Add offline inference user guide doc

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
- Add MOSS-TTS-Nano E2E test to .buildkite/test-merge.yml
- Remove docs from docs/ (will be synced from examples/ READMEs)

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Switch from GPUGenerationWorker to AR worker with OmniARScheduler,
using the VoxCPM-style generator pattern for streaming:

- inference_stream() stored per-request in _stream_gens dict
- Each forward() call yields one audio chunk via next(generator)
- compute_logits() emits EOS only when last chunk is yielded
- AR scheduler loops until EOS, enabling progressive audio output

TTFP reduced from ~3.1s to ~0.11s (30x improvement).

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
The upstream inference_stream() API has no voice/spk parameter;
voice preset selection is not yet wired into the call. Remove the
dead assignment to silence ruff F841.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
- Wire ref_audio through /v1/audio/speech: _build_moss_tts_params is now
  async, resolves ref_audio via MediaConnector and passes the waveform as
  prompt_audio_array so the model can materialise a temp WAV for upstream
  inference_stream (previously voice cloning via the REST endpoint was a
  no-op). Serving now also requires ref_text when ref_audio is provided.
- Fix per-request EOS in batched decode: replace the shared
  _ar_emit_stop_token bool with a _ar_last_chunk_flags list so
  compute_logits emits EOS per row; finished requests no longer wait for
  the slowest peer in a max_num_seqs=4 batch.
- Snapshot and restore CPU+CUDA RNG state around torch.manual_seed to
  limit global-state bleed; add comment noting deterministic output
  under concurrent batching is best-effort (upstream inference_stream
  uses the global RNG).
- Align _DEFAULT_MODE with the offline example and tests ("voice_clone").
- Clean up outdated "buffer first" comment in _create_stream_gen; document
  single-threaded AR-worker assumption for _stream_gens; add one-liner
  explaining the load_weights iterator drain.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Follow-up to vllm-project#2958. MOSS-TTS-Nano was authored before the schema refactor
and still shipped as a legacy stage_configs yaml; this aligns it with the
same layout the 5 migrated TTS models now use.

- vllm_omni/model_executor/models/moss_tts_nano/pipeline.py declares
  MOSS_TTS_NANO_PIPELINE (single LLM_AR stage, owns_tokenizer, audio
  output, stop_token_ids=[2] as a hard EOS backstop).
- vllm_omni/deploy/moss_tts_nano.yaml holds runtime knobs (max_num_seqs,
  gpu_memory_utilization, enforce_eager, default_sampling_params,
  skip_mm_profiling); trust_remote_code stays at deploy top-level.
- vllm_omni/config/pipeline_registry.py registers the entry so the lazy
  registry can resolve it.
- moss_tts_nano/__init__.py exports MossTTSNanoForGeneration (VoxCPM2 pattern).
- Removed vllm_omni/model_executor/stage_configs/moss_tts_nano.yaml.

Examples, shell scripts, READMEs, and buildkite-invoked tests are
updated to use `vllm serve <model> --omni` / `--deploy-config`
(auto-load kicks in; no --stage-configs-path or --trust-remote-code).

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
forward() only popped _stream_gens on normal completion, so cancelled,
timed-out, or preempted requests leaked their generator and skipped the
finally block that unlinks the temp WAV files. Implement
on_requests_finished to close each finished generator, which raises
GeneratorExit inside it and triggers the existing cleanup.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
linyueqian and others added 5 commits April 26, 2026 00:49
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
…fs clash

Fish Speech's example README also has a `## Model` heading. mkdocs-autorefs
treats both as primary URLs for the symbol `model`, producing a warning per
cross-ref site (>100 warnings). With `--strict` + `fail_on_warning: true` in
.readthedocs.yml, this fails the docs build (RTD #32425202).

Renaming this PR's new heading to `## Model checkpoint` removes the slug
conflict and gets the docs build green.

Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
@hsliuustc0106 hsliuustc0106 merged commit afa2b09 into vllm-project:main Apr 27, 2026
7 of 8 checks passed
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Apr 27, 2026
MOSS-TTS-Nano upstream is voice-cloning-only — there are no built-in
speaker presets. The integration shipped 15 invented voice names
(Junhao/Ava/Adam/...) with no resolution layer mapping name → audio,
so every request defaulted to mode='voice_clone' with prompt_audio_path
unset and the model raised ValueError on serve. The post-merge L4 build
caught this (RTD #8107, MOSS-TTS-Nano E2E Test).

Changes:
- serving_speech: require ref_audio + ref_text in /v1/audio/speech;
  ignore the OpenAI-schema voice field with a clear error message.
- modeling: drop _DEFAULT_VOICE; dummy run no longer carries voice.
- examples/online: rewrite README + gradio_demo around required ref
  audio upload. Drop the 15-row preset table + dropdown + examples.
- examples/offline: --prompt-audio and --prompt-text now required.
  Drop --voice and --batch (no per-voice batch makes sense without
  presets). README points users to upstream assets/audio/zh_1.wav.
- tests: session-scoped fixture downloads upstream zh_1.wav (~50 KB)
  and reuses it across cases. Drop test_moss_tts_nano_voice_presets
  (no presets to test). All paths use XDG_CACHE_HOME or pytest
  tmp_path_factory — no /tmp shared-dir writes.

Refs: vllm-project#2753
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Apr 27, 2026
MOSS-TTS-Nano upstream is voice-cloning-only — there are no built-in
speaker presets. The integration shipped 15 invented voice names
(Junhao/Ava/Adam/...) with no resolution layer mapping name → audio,
so every request defaulted to mode='voice_clone' with prompt_audio_path
unset and the model raised ValueError on serve. The post-merge L4 build
caught this (RTD #8107, MOSS-TTS-Nano E2E Test).

Changes:
- serving_speech: require ref_audio + ref_text in /v1/audio/speech;
  ignore the OpenAI-schema voice field with a clear error message.
- modeling: drop _DEFAULT_VOICE; dummy run no longer carries voice.
- examples/online: rewrite README + gradio_demo around required ref
  audio upload. Drop the 15-row preset table + dropdown + examples.
- examples/offline: --prompt-audio and --prompt-text now required.
  Drop --voice and --batch (no per-voice batch makes sense without
  presets). README points users to upstream assets/audio/zh_1.wav.
- tests: session-scoped fixture downloads upstream zh_1.wav (~50 KB)
  and reuses it across cases. Drop test_moss_tts_nano_voice_presets
  (no presets to test). All paths use XDG_CACHE_HOME or pytest
  tmp_path_factory — no /tmp shared-dir writes.

Refs: vllm-project#2753
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Apr 27, 2026
MOSS-TTS-Nano upstream is voice-cloning-only — there are no built-in
speaker presets. The integration shipped 15 invented voice names
(Junhao/Ava/Adam/...) with no resolution layer mapping name → audio,
so every request defaulted to mode='voice_clone' with prompt_audio_path
unset and the model raised ValueError on serve. The post-merge L4 build
caught this (RTD #8107, MOSS-TTS-Nano E2E Test).

Changes:
- serving_speech: require ref_audio + ref_text in /v1/audio/speech;
  ignore the OpenAI-schema voice field with a clear error message.
- modeling: drop _DEFAULT_VOICE; dummy run no longer carries voice.
- examples/online: rewrite README + gradio_demo around required ref
  audio upload. Drop the 15-row preset table + dropdown + examples.
- examples/offline: --prompt-audio and --prompt-text now required.
  Drop --voice and --batch (no per-voice batch makes sense without
  presets). README points users to upstream assets/audio/zh_1.wav.
- tests: session-scoped fixture downloads upstream zh_1.wav (~50 KB)
  and reuses it across cases. Drop test_moss_tts_nano_voice_presets
  (no presets to test). All paths use XDG_CACHE_HOME or pytest
  tmp_path_factory — no /tmp shared-dir writes.

Refs: vllm-project#2753
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Apr 27, 2026
MOSS-TTS-Nano upstream is voice-cloning-only — there are no built-in
speaker presets. The integration shipped 15 invented voice names
(Junhao/Ava/Adam/...) with no resolution layer mapping name → audio,
so every request defaulted to mode='voice_clone' with prompt_audio_path
unset and the model raised ValueError on serve. The post-merge L4 build
caught this (RTD #8107, MOSS-TTS-Nano E2E Test).

Changes:
- serving_speech: require ref_audio + ref_text in /v1/audio/speech;
  ignore the OpenAI-schema voice field with a clear error message.
- modeling: drop _DEFAULT_VOICE; dummy run no longer carries voice.
- examples/online: rewrite README + gradio_demo around required ref
  audio upload. Drop the 15-row preset table + dropdown + examples.
- examples/offline: --prompt-audio and --prompt-text now required.
  Drop --voice and --batch (no per-voice batch makes sense without
  presets). README points users to upstream assets/audio/zh_1.wav.
- tests: session-scoped fixture downloads upstream zh_1.wav (~50 KB)
  and reuses it across cases. Drop test_moss_tts_nano_voice_presets
  (no presets to test). All paths use XDG_CACHE_HOME or pytest
  tmp_path_factory — no /tmp shared-dir writes.

Refs: vllm-project#2753
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
xiaohajiayou pushed a commit to xiaohajiayou/vllm-omni that referenced this pull request Apr 30, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
BeatSeat pushed a commit to BeatSeat/vllm-omni that referenced this pull request May 2, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
sphinxkkkbc pushed a commit to sphinxkkkbc/vllm-omni that referenced this pull request May 4, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
quyifei23 pushed a commit to quyifei23/vllm-omni that referenced this pull request Jun 6, 2026
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants