
[New Model] Add support for tencent/Covo-Audio-Chat#2293

Open
Dnoob wants to merge 25 commits into vllm-project:main from Dnoob:feat/covo-audio-chat

Conversation

Contributor

@Dnoob Dnoob commented Mar 28, 2026

The model to consider

Model Weights: https://huggingface.co/tencent/Covo-Audio-Chat
Model Code: https://github.com/Tencent/Covo-Audio

Model description

This PR adds support for Covo-Audio-Chat (Tencent's 7B end-to-end audio language model) with a two-stage pipeline:

Stage 0 (fused_thinker_talker)

  • Whisper encoder + AudioAdapter + Qwen2.5-7B LLM → interleaved text + audio tokens

Stage 1 (code2wav)

  • BigVGAN vocoder → 24kHz audio waveform

Note: This model requires audio input to produce correctly interleaved text + audio tokens. Text-only input causes abnormal interleaving patterns, resulting in incomplete audio output. This is a model design constraint, not a code limitation.
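To illustrate what "interleaved text + audio tokens" means for the stage handoff, here is a toy de-interleaver. The `AUDIO_TOKEN_OFFSET` value and the function name are made up for demonstration; the real model defines its own audio-token id range and interleaving schedule.

```python
# Toy de-interleaver for Stage 0 output. AUDIO_TOKEN_OFFSET is a hypothetical
# boundary between text ids and audio-code ids, chosen only for illustration.
AUDIO_TOKEN_OFFSET = 151_936

def split_interleaved(token_ids):
    """Separate an interleaved token stream into text ids and audio codes."""
    text_ids, audio_codes = [], []
    for tid in token_ids:
        if tid >= AUDIO_TOKEN_OFFSET:
            # Rebase audio ids into the codec's own code range for Stage 1.
            audio_codes.append(tid - AUDIO_TOKEN_OFFSET)
        else:
            text_ids.append(tid)
    return text_ids, audio_codes
```

In this sketch, Stage 1 (code2wav) would consume only the `audio_codes` list, while the text ids are detokenized for the textual part of the response.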

Changes

  • Core model: covo_audio.py (dual-stage router), covo_audio_llm.py (Stage 0), covo_audio_code2wav.py (Stage 1)
  • Config: config_covo_audio.py, stage YAML configs
  • Prompt utilities: prompt_utils.py (centralized prompt templates and construction helpers)
  • Stage input processor: extracts audio codes from Stage 0 output for Stage 1
  • Registry + speech endpoint integration
  • token2wav.py (consolidated from upstream model repo into a single module, import paths adapted)
  • Speaker prompt .npy files
  • Offline inference example + OpenAI-compatible client example
  • Documentation update to supported_models.md

Fixes #2004

Test plan

E2E test (requires GPU with ≥20 GiB VRAM):

CUDA_VISIBLE_DEVICES=0 pytest -s -v tests/e2e/offline_inference/test_covo_audio.py::test_audio_to_audio

Online serving:

CUDA_VISIBLE_DEVICES=0 vllm serve tencent/Covo-Audio-Chat --omni \
    --stage-configs-path vllm_omni/model_executor/stage_configs/covo_audio.yaml \
    --trust-remote-code

python examples/online_serving/covo_audio/openai_chat_completion_client.py

Test result

tests/e2e/offline_inference/test_covo_audio.py::test_audio_to_audio[omni_runner0] PASSED
======================= 1 passed in 62.86s ========================

Environment

  • GPU: 1x A100 (80 GiB)
  • Stage 0 (7B LLM): ~16 GiB VRAM
  • Stage 1 (BigVGAN): ~2 GiB VRAM

@Dnoob Dnoob requested a review from hsliuustc0106 as a code owner on March 28, 2026 05:27
@Dnoob Dnoob force-pushed the feat/covo-audio-chat branch 2 times, most recently from a142631 to e44e30e on March 28, 2026 05:35

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5b064383be

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/model_executor/models/covo_audio/covo_audio_llm.py
Comment thread vllm_omni/model_executor/models/covo_audio/covo_audio.py Outdated
@Dnoob Dnoob force-pushed the feat/covo-audio-chat branch from e44e30e to cf389cb on March 28, 2026 05:43
Contributor Author

Dnoob commented Mar 28, 2026

@codex review


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cf389cbb49


engine_args:
model_stage: fused_thinker_talker
max_num_seqs: 1
model_arch: CovoAudioForConditionalGeneration


P1 Badge Use a registered Covo model_arch in stage config

The new stage config sets model_arch: CovoAudioForConditionalGeneration, but this architecture key is not registered in vllm_omni/model_executor/models/registry.py (the commit only registers CovoAudioForCausalLM and CovoAudioModel as keys that map to the Covo class). Because OmniEngineArgs.create_model_config() injects model_arch directly into architectures, serving with vllm_omni/model_executor/stage_configs/covo_audio.yaml can fail model resolution at startup instead of loading Covo; the config should use a registered key or the registry should add CovoAudioForConditionalGeneration explicitly.
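For illustration, the resolution failure can be sketched with a simplified stand-in for the registry mapping (the dict shape, value strings, and `resolve_arch` helper are illustrative, not the actual vllm_omni registry API):

```python
# Simplified stand-in for the architecture registry. Only keys present here
# can be resolved from a stage config's model_arch; the stage config's key
# must therefore appear in the mapping.
_COVO_MODELS = {
    "CovoAudioForCausalLM": "CovoAudio",               # registered in the PR
    "CovoAudioModel": "CovoAudio",                     # registered in the PR
    "CovoAudioForConditionalGeneration": "CovoAudio",  # key the stage config uses
}

def resolve_arch(model_arch: str) -> str:
    """Map an architecture key from a stage config to its model class name."""
    if model_arch not in _COVO_MODELS:
        raise ValueError(f"Unregistered architecture: {model_arch}")
    return _COVO_MODELS[model_arch]
```

Without the third entry, `resolve_arch("CovoAudioForConditionalGeneration")` would raise at startup, which is the failure mode the comment describes.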


_QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"}
_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"}
_TTS_MODEL_STAGES: set[str] = _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES
_COVO_AUDIO_MODEL_STAGES = {"fused_thinker_talker"}


P2 Badge Restrict Covo TTS stage matching to Covo models

Marking fused_thinker_talker as a Covo TTS stage is too broad: the same model_stage is used by non-Covo pipelines (for example, MiMoAudio stage configs), so _find_tts_stage/_detect_tts_model_type will classify those models as covo_audio and route /v1/audio/speech through Covo-specific prompt construction. This can produce incorrect prompts or runtime failures for other models that share the stage name; model detection should include an architecture/model check instead of relying on stage name alone.
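A minimal sketch of the architecture-aware check the reviewer suggests (the `_find_tts_stage`/`_detect_tts_model_type` internals are not shown in this PR excerpt, so the function name and signature here are illustrative):

```python
# Sketch: require both the stage name AND a Covo architecture before routing
# a request through Covo-specific prompt construction.
_COVO_AUDIO_MODEL_STAGES = {"fused_thinker_talker"}
_COVO_AUDIO_ARCHS = {"CovoAudioForConditionalGeneration"}

def is_covo_tts_stage(model_stage: str, architectures: list[str]) -> bool:
    # Stage name alone is ambiguous (MiMoAudio also uses
    # "fused_thinker_talker"), so additionally require a Covo architecture.
    return (model_stage in _COVO_AUDIO_MODEL_STAGES
            and any(arch in _COVO_AUDIO_ARCHS for arch in architectures))
```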


@Dnoob Dnoob force-pushed the feat/covo-audio-chat branch 2 times, most recently from 1232884 to b372618 on March 28, 2026 06:12
Collaborator

@linyueqian linyueqian left a comment


Tested locally on A100-80G. A few findings:

Bug: model path resolution breaks with HF repo names

covo_audio_code2wav.py:26 uses vllm_config.model_config.model directly as a filesystem path. When the model is specified as a HF repo name (e.g. tencent/Covo-Audio-Chat), os.path.join(model_path, "token2wav", ...) resolves to a relative path that doesn't exist:

FileNotFoundError: [Errno 2] No such file or directory: 'tencent/Covo-Audio-Chat/token2wav/global_mean_var.npy'

Need to resolve to the local cache path first, e.g.:

if os.path.isdir(model_name):
    model_path = model_name
else:
    from huggingface_hub import snapshot_download
    model_path = snapshot_download(model_name)

Missing dependency: torchdiffeq

The vendored token2wav flow matching module requires torchdiffeq, but it's not declared in pyproject.toml. Stage 1 fails with ModuleNotFoundError: No module named 'torchdiffeq'.

Audio output cuts off abruptly

With the CI config (max_tokens: 512, ignore_eos: true), the text-to-audio test produces ~9.25s of audio that ends mid-sentence. The interleaving ratio (5 text + 15 audio tokens) means only ~384 of the 512 tokens are audio codes. Consider either increasing max_tokens or removing ignore_eos: true so the model can stop naturally.
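The arithmetic behind the truncation, using the 5:15 ratio from the comment above (the implied codec rate of roughly 41.5 codes/s is an inference from "384 codes ≈ 9.25 s", not a documented figure):

```python
# Token budget arithmetic for the interleaved output
# (5 text tokens per 15 audio tokens).
def audio_token_count(max_tokens: int, text_n: int = 5, audio_n: int = 15) -> int:
    """Audio codes produced within a max_tokens generation budget."""
    return max_tokens * audio_n // (text_n + audio_n)

codes = audio_token_count(512)    # 384 codes, matching the ~9.25 s observed
longer = audio_token_count(2048)  # 1536 codes -> roughly 37 s at ~41.5 codes/s
```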

Minor

  • PR is 42 commits behind main, needs rebase
  • test_audio_to_audio requires espeak-ng system package (via pyttsx3 for synthetic audio generation). Worth noting in the test plan or adding a skip condition.

test_text_to_audio passes after fixing the path issue. Model produces reasonable text + audio output.

Contributor Author

Dnoob commented Mar 30, 2026

Missing dependency: torchdiffeq

The vendored token2wav flow matching module requires torchdiffeq, but it's not declared in pyproject.toml. Stage 1 fails with ModuleNotFoundError: No module named 'torchdiffeq'.

For the torchdiffeq dependency, should it go into requirements/common.txt (which affects all users), or would you prefer a lazy import with a clear error message + doc note, since it's only used by the vendored token2wav module?

@linyueqian
Collaborator

@amy-why-3459 PTAL

@linyueqian
Collaborator

Thanks for the fixes on path resolution and max_tokens.

For torchdiffeq — go with a lazy import since it's only used by the vendored token2wav:

try:
    from torchdiffeq import odeint
except ImportError as e:
    raise ImportError(
        "Covo-Audio code2wav requires `torchdiffeq`. "
        "Install it with: pip install torchdiffeq"
    ) from e

Still outstanding:

  1. Rebase onto main (currently blocked)
  2. espeak-ng — add a skip condition for test_audio_to_audio

@amy-why-3459
Contributor

Thank you very much for your contribution. Could you please add a README file for the model?

from .token2latent import Token2latentFlowMatchingWithEmbed


class Token2WavDecoder(BaseModel):
Contributor


Can we put these function definitions in the model file to avoid creating too many folders?

Contributor Author


Sorry, I didn't consider this earlier and directly copied the structure from the model repo, which made it bloated. I've analyzed it and plan to consolidate the entire token2wav/ folder into a single token2wav.py file and remove unnecessary code. Will update in the next push.

@Dnoob Dnoob force-pushed the feat/covo-audio-chat branch from 8840f97 to 7d6696c on April 2, 2026 05:48
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


It seems there is a lot of dead code in this PR.

Contributor Author

Dnoob commented Apr 2, 2026

Thank you very much for your contribution. Could you please add a readme file for the model?

Already added

Collaborator

@lishunyang12 lishunyang12 left a comment


Left a few comments on the core model files. The vendored token2wav code is fine to skip lint-wise, but the non-vendored parts have a couple of issues.

inputs_embeds: torch.Tensor | None = None,
generate_audio: bool = True,
codec: torch.Tensor | None = None,
sampling_metadata: SamplingMetadata | None = None,
Collaborator


This materializes the entire safetensors weight iterator into a Python list just to filter by prefix. For a 7B model that's ~14 GB of extra peak memory. Use a generator instead:

Suggested change:

llm_weights = ((k, v) for k, v in weights if k.startswith(("llm", "encoder", "audio_adapter")))
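A toy demonstration of why the generator expression helps: the `(name, tensor)` stream is filtered lazily, so only one item is in flight at a time instead of the whole checkpoint being materialized in a list. The fake iterator and its keys are stand-ins, not the actual checkpoint layout.

```python
def fake_weight_iter():
    # Stand-ins for (name, tensor) pairs from a safetensors weight iterator.
    yield "llm.layers.0.weight", [0.0] * 4
    yield "vision.layers.0.weight", [1.0] * 4   # filtered out below
    yield "encoder.proj.weight", [2.0] * 4

# Generator expression: no intermediate list of all weights is ever built.
llm_weights = (
    (k, v) for k, v in fake_weight_iter()
    if k.startswith(("llm", "encoder", "audio_adapter"))
)
kept = [k for k, _ in llm_weights]
print(kept)  # ['llm.layers.0.weight', 'encoder.proj.weight']
```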

**kwargs,
)

return OmniOutput(
Collaborator


Several of these forward params (generate_audio, codec, logits_index, sampler, additional_information) are never used. Remove them — dead parameters in the forward signature are confusing for anyone reading the dispatch logic.

logger.info(
"Request %s: total_tokens=%d, text_tokens=%d, audio_tokens=%d",
request_id,
len(token_ids),
Collaborator


This is a lot of per-request logging at INFO level (token counts, full interleaving pattern). In production with concurrent requests this will flood the logs. Drop the pattern log to DEBUG, or remove it.

Dnoob added 6 commits April 4, 2026 09:31
Adapt Covo-Audio-Chat (Tencent, 7B end-to-end audio language model)
to vllm-omni with a 2-stage pipeline:

- Stage 0 (fused_thinker_talker): Whisper encoder + AudioAdapter +
  Qwen2.5-7B LLM → interleaved text + audio tokens
- Stage 1 (code2wav): BigVGAN vocoder → 24kHz audio waveform

Closes vllm-project#2004

Signed-off-by: Dnoob <dxpouo@gmail.com>
… for Covo-Audio

Signed-off-by: Dnoob <dxpouo@gmail.com>
…Audio-Chat

- Add prompt_utils.py with shared system prompt and prompt builders, removing 3 duplicates
- Add offline inference example (end2end.py + README)
- Fix online client: add system prompt, stop_token_ids, ignore_eos, detokenize=false
- Fix stage config detokenize setting for code2wav
- Use local sample audio for online example instead of S3 download
- Add --port 18091 to online README to match client config

Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
- Use generator instead of list comprehension in load_weights to reduce peak memory
- Remove unused forward parameters (generate_audio, codec, logits_index, sampler, additional_information)
- Remove debug logging from stage input processor

Signed-off-by: Dnoob <dxpouo@gmail.com>
@Dnoob Dnoob force-pushed the feat/covo-audio-chat branch from b0f956b to 9600536 on April 4, 2026 09:32
Dnoob added 2 commits April 4, 2026 10:33
- Fix MultiModalDataDict import path for vllm 0.19.0 compatibility
- Remove unsupported text-only test (model requires audio input)
- Use sample_audio.wav instead of pyttsx3 synthetic audio in test
- Remove espeak-ng dependency from test

Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Contributor Author

Dnoob commented Apr 4, 2026

@lishunyang12 @amy-why-3459 Addressed all review feedback, rebased on main, and updated PR description. PTAL, thanks!

Collaborator

@linyueqian linyueqian left a comment


LGTM

return "请回答这段音频里的问题。"


@pytest.mark.core_model
Collaborator


If this is to be added to nightly, please:
1. change the label to @pytest.mark.advanced_model
2. rename this script to test_covo_audio_expansion.py

}
]

outputs = omni_runner.generate(omni_inputs)
Collaborator


maybe you can use omni_runner_handler.send_request(request_config) in conftest.py, like tests/e2e/offline_inference/test_qwen3_omni.py

Comment thread .buildkite/test-nightly.yml Outdated
if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/offline_inference/test_covo_audio.py -m "core_model" --run-level "core_model"
Collaborator


I think you can modify the commands in 🌕 Omni Model Test with H100 instead of adding a separate new job.

Signed-off-by: Dnoob <dxpouo@gmail.com>
Contributor Author

Dnoob commented Apr 8, 2026

@yenuo26 Sorry, I forgot to add the test file under online_serving/.

I checked how other omni models organize their tests in this repo. The convention seems to be: offline advanced_model tests go in test-merge.yml, and online _expansion.py tests go in test-nightly.yml.

Adjusted accordingly:

  • Added tests/e2e/online_serving/test_covo_audio_expansion.py
  • Removed the separate Covo-Audio-Chat job and reused the existing "Omni Model Test with H100"

@hsliuustc0106 hsliuustc0106 enabled auto-merge (squash) April 9, 2026 07:02
Signed-off-by: Dnoob <dxpouo@gmail.com>
auto-merge was automatically disabled April 13, 2026 07:12

Head branch was pushed to by a user without write access

Signed-off-by: Dnoob <dxpouo@gmail.com>
@linyueqian
Collaborator

Fix CI, please.

Dnoob added 2 commits April 14, 2026 02:33
Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Contributor Author

Dnoob commented Apr 15, 2026

@linyueqian @hsliuustc0106 CI failures are fixed and all checks pass. PTAL, thanks!

| `HunyuanVideo15Pipeline` | HunyuanVideo-1.5-T2V | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` | ✅︎ | ✅︎ | | |
| `HunyuanVideo15ImageToVideoPipeline` | HunyuanVideo-1.5-I2V | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v`, `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v` | ✅︎ | ✅︎ | | |
| `VoxtralTTSForConditionalGeneration` | Voxtral TTS | `mistralai/Voxtral-4B-TTS-2603` | ✅︎ | ✅︎ | | |
| `CovoAudioForConditionalGeneration` | Covo-Audio-Chat | `tencent/Covo-Audio-Chat` | ✅︎ | | | |
Collaborator


nit: This PR includes online serving support (OpenAI-compatible client example + test_covo_audio_expansion.py), so the Online column should be ✅︎ instead of empty, to match the other models in this table.

Contributor Author


I don't see an "Online" column in this table, just NVIDIA GPU, AMD GPU, Ascend NPU, and Intel GPU. Which one did you mean?

self.o_dropout = Dropout(dropout)

self.cpu_config = AttentionConfig(True, True, True)
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
Copy link
Copy Markdown
Collaborator


nit (vendored code): torch.cuda.get_device_properties(torch.device("cuda")) hardcodes device index 0. In a multi-GPU setup where this vocoder runs on a non-default CUDA device, this may select the wrong flash-attention config. I understand this is vendored from upstream, so just flagging for awareness — no action required for this PR.

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


Blocker Scan

| Category | Result |
| --- | --- |
| Correctness | PASS |
| Reliability/Safety | PASS |
| Breaking Changes | PASS |
| Test Coverage | PASS (offline test in PR desc, CI green, nightly test added) |
| Documentation | PASS (supported_models.md, offline + online examples + README) |
| Security | PASS |

Merge Gate

  • DCO: PASS
  • pre-commit: PASS
  • Build: PASS
  • Buildkite CI (amd/intel): PASS
  • Mergeable: FAIL — CONFLICTING, needs rebase onto latest main

Summary

The code is well-structured after multiple review rounds. All previously flagged issues have been addressed:

  • HF repo path resolution fixed (snapshot_download fallback)
  • torchdiffeq lazy import with clear error message
  • Dead parameters removed from forward signatures
  • INFO-level log flooding fixed
  • llm_weights uses generator instead of list materialization
  • Vendored token2wav/ consolidated into single token2wav.py
  • CovoAudioForConditionalGeneration registered in model registry
  • CI test config updated with max_tokens: 2048

Left 2 nits as inline comments (supported_models table + vendored code awareness).

The only remaining blocker is the merge conflict. Please rebase onto latest main and I'll approve.

# Conflicts:
#	vllm_omni/entrypoints/openai/serving_speech.py
Contributor Author

Dnoob commented Apr 19, 2026

@hsliuustc0106 PTAL, thanks!

Comment thread pyproject.toml Outdated
[tool.ruff.lint.per-file-ignores]
"examples/**" = ["E501"] # Allow long lines in examples
"tests/**" = ["E501"] # Allow long lines in tests
"**/token2wav/**" = ["E501", "E721", "E741", "F401", "F403", "F405", "F841", "UP028"] # Vendored third-party code
Collaborator


why add this line?

@@ -0,0 +1,78 @@
#
Collaborator


Please follow the new pipeline now that #2383 has been merged.

@hsliuustc0106
Collaborator

Also, please add a new recipe.

Dnoob added 6 commits April 20, 2026 06:10
Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Signed-off-by: Dnoob <dxpouo@gmail.com>
Contributor Author

Dnoob commented Apr 21, 2026

@hsliuustc0106 All addressed and CI is green. PTAL, thanks!

@lishunyang12
Collaborator

Resolve conflict.


Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[New Model] Covo-Audio-Chat (End-to-End Audio LLM)

6 participants