Add Cosmos3 sound generation#4073

Merged: lishunyang12 merged 11 commits into vllm-project:main from MaciejBalaNV:mbala/cosmos3_sound on Jun 3, 2026

MaciejBalaNV (Contributor) commented Jun 2, 2026

Purpose

This PR is a follow-up to #3454 and adds sound generation to Cosmos3. It was already partially reviewed in MaciejBalaNV#1.

The Cosmos3 model is available at https://huggingface.co/nvidia/Cosmos3-Nano (and in more variants).

Test Plan

Unit tests

cd tests; python -m pytest -v -m "core_model and cpu"

This PR adds 23 new unit test cases covering the new model integration and the pipeline.

Serving tests

Host server with

vllm serve nvidia/Cosmos3-Nano --omni

Run a request with
curl -sS -X POST http://localhost:8000/v1/videos/sync \
  -H "Accept: video/mp4" \
  -F "prompt=A low-angle tracking shot follows a man riding a vintage black motorcycle across a lush green grassy yard. Sunlight filters through overhead trees, casting dappled shadows across the vibrating chrome exhaust and the rider's leather jacket. He kicks up small blades of grass as he maneuvers the bike. He gradually decelerates, the front fork compressing slightly as he brakes to a smooth halt beside another individual standing in the shade. The camera settles into a medium two-shot, capturing the rider lifting his visor to speak, his face framed by a matte helmet. The video is 8 seconds long and is of 24 FPS. This video is of 1280x720 resolution. Audio description: The rhythmic, mechanical chugging of a four-stroke motorcycle engine dominates the foreground, characterized by a throaty, guttural timbre. Periodic high-pitched revs punctuate the steady idle as the throttle is twisted. The sound of tires crunching softly over dry grass and twigs provides a textured background layer. As the vehicle slows, the engine note drops to a low-frequency rumble before clicking into neutral. A muffled, mid-range male voice begins speaking, accompanied by the metallic clink of a helmet visor snapping upward and the faint chirping of distant birds in an open-air environment." \
  -F "negative_prompt=blurry, distorted, low quality" \
  -F "size=1280x720" \
  -F "num_frames=193" \
  -F "fps=24" \
  -F "num_inference_steps=35" \
  -F "guidance_scale=4.0" \
  -F "seed=42" \
  -F "generate_sound=true" \
  -F "sound_duration=8" \
  -F 'extra_params={"use_resolution_template":false,"use_duration_template":false}' \
  -o cosmos3_t2v_with_sound.mp4
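
For readers who prefer scripting the request, here is a minimal Python sketch of the same call. It assumes the third-party `requests` package is installed and that the server accepts the same multipart form fields as the curl command above. The function names `build_sound_request_fields` and `generate_video_with_sound` are hypothetical helpers, not part of vllm-omni.

```python
import json


def build_sound_request_fields(prompt, *, size="1280x720", num_frames=193,
                               fps=24, sound_duration=8, seed=42):
    """Build the form fields used by the curl example (all values as strings)."""
    return {
        "prompt": prompt,
        "negative_prompt": "blurry, distorted, low quality",
        "size": size,
        "num_frames": str(num_frames),
        "fps": str(fps),
        "num_inference_steps": "35",
        "guidance_scale": "4.0",
        "seed": str(seed),
        "generate_sound": "true",
        "sound_duration": str(sound_duration),
        "extra_params": json.dumps(
            {"use_resolution_template": False, "use_duration_template": False}
        ),
    }


def generate_video_with_sound(fields, out_path, base_url="http://localhost:8000"):
    """POST the fields as multipart/form-data and save the returned MP4."""
    import requests  # third-party dependency, imported lazily (assumption: installed)

    response = requests.post(
        f"{base_url}/v1/videos/sync",
        headers={"Accept": "video/mp4"},
        # (None, value) makes requests send a plain form field, like curl -F.
        files={name: (None, value) for name, value in fields.items()},
        timeout=3600,
    )
    response.raise_for_status()
    with open(out_path, "wb") as handle:
        handle.write(response.content)
    return out_path
```

Calling `generate_video_with_sound(build_sound_request_fields("<prompt text>"), "cosmos3_t2v_with_sound.mp4")` against a running server would mirror the curl invocation; the long timeout is only a guess at a safe upper bound.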

Test Result

The unit tests pass, including all of the new Cosmos unit tests.

===== 3173 passed, 14 skipped, 1511 deselected, 54 warnings in 327.61s (0:05:27) ===


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5f84b67c0c

self.sound_gen = _as_bool(sound_gen_value) if sound_gen_value is not None else sound_dim_value is not None
from .sound_tokenizer import get_sound_dim, get_sound_latent_fps

self.sound_dim = int(sound_dim_value if sound_dim_value is not None else get_sound_dim(od_config))

[P2] Keep transformer sound_dim aligned with tokenizer config

When the bundled sound_tokenizer/config.json is the authoritative source for vocoder_input_dim/io_channels, Cosmos3SoundTokenizer.from_config() uses that component config, but this transformer initialization only looks at od_config/defaults. If the component config overrides or is the only place defining the AVAE latent width, the transformer sound_dim can differ from sound_tokenizer.latent_ch, and every sound request then fails in _prepare_sound_latents with the channel-mismatch check before generation starts. Please resolve sound_dim from the same component config path or synchronize it after loading the tokenizer.

Contributor

Changed: sound_dim and sound_latent_fps are now passed to the transformer as arguments taken from the initialized sound tokenizer.
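
The idea behind this fix can be illustrated with a small sketch: the transformer takes its sound dimensions from the already-loaded tokenizer, so the channel check cannot disagree with the tokenizer's latent width. The stub classes, the `latent_fps` attribute name, and the numeric values below are hypothetical stand-ins; only `latent_ch`, `sound_dim`, and `sound_latent_fps` are names that appear in this thread.

```python
from dataclasses import dataclass


@dataclass
class SoundTokenizerStub:
    """Hypothetical stand-in for the loaded sound tokenizer."""
    latent_ch: int
    latent_fps: float  # assumed attribute name


@dataclass
class TransformerStub:
    """Hypothetical stand-in for the transformer's sound settings."""
    sound_dim: int
    sound_latent_fps: float


def build_transformer(tokenizer: SoundTokenizerStub) -> TransformerStub:
    # Single source of truth: read both values from the initialized tokenizer.
    return TransformerStub(
        sound_dim=tokenizer.latent_ch,
        sound_latent_fps=tokenizer.latent_fps,
    )


def check_sound_latents(latent_shape, transformer: TransformerStub) -> None:
    """Raise if the latent channel count (dim 1) differs from sound_dim."""
    channels = latent_shape[1]
    if channels != transformer.sound_dim:
        raise ValueError(
            f"sound latents have {channels} channels, expected {transformer.sound_dim}"
        )


tokenizer = SoundTokenizerStub(latent_ch=64, latent_fps=25.0)  # illustrative values
transformer = build_transformer(tokenizer)
check_sound_latents((1, tokenizer.latent_ch, 200), transformer)  # passes by construction
```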

MaciejBalaNV force-pushed the mbala/cosmos3_sound branch from 5f84b67 to 55c1917 on June 2, 2026 08:53
}
if audio_sample_rate is not None:
result["audio_sample_rate"] = int(audio_sample_rate)
return result
Collaborator

The post-process function now returns a dict {video, audio, fps, audio_sample_rate} when sound is enabled, but still returns a bare tensor/ndarray for video-only paths. This asymmetry may surprise downstream consumers. Consider always returning a dict (even for video-only) or documenting the contract clearly so callers know to check isinstance(result, dict).

Contributor

@bastefaniak bastefaniak Jun 2, 2026

It seems to be consistent with LTX-2.3 and Wan2.2.
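
For callers that want a uniform shape regardless of which path ran, a small adapter is one option. The sketch below is a hypothetical downstream helper, not code from this PR; it assumes only the dict keys named in the review comment (`video`, `audio`, `fps`, `audio_sample_rate`).

```python
def normalize_postprocess_result(result):
    """Return a dict with a fixed key set for both sound and video-only results."""
    if isinstance(result, dict):
        # Sound-enabled path: copy the documented keys, defaulting missing ones to None.
        return {
            "video": result["video"],
            "audio": result.get("audio"),
            "fps": result.get("fps"),
            "audio_sample_rate": result.get("audio_sample_rate"),
        }
    # Video-only path: the bare tensor/ndarray is the video itself.
    return {"video": result, "audio": None, "fps": None, "audio_sample_rate": None}
```

With this wrapper, consumers can always read `out["video"]` and test `out["audio"] is not None` instead of branching on the result type.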

enable_fps_modulation: bool = True,
base_temporal_compression_factor: int | None = None,
) -> tuple[torch.Tensor, int | float]:
"""Generate mRoPE IDs for sound tokens as a (T, 1, 1) grid."""
Collaborator

Nit: del base_temporal_compression_factor immediately discards this parameter. Consider removing it from the signature entirely (or renaming to **_kwargs if part of a shared interface).

Contributor

deleted

DEFAULT_SOUND_TANH_OUTPUT_SCALE = 3.5
DEFAULT_SOUND_TANH_CLAMP = 0.995
SOUND_TOKENIZER_COMPONENT_NAME = "sound_tokenizer"
SOUND_TOKENIZER_CHECKPOINT_NAME = "diffusion_pytorch_model.safetensors"
Collaborator

This helper _config_get is identical to the one in audio_tokenizer/avae.py:52. Worth consolidating into a shared utility or importing from one place to avoid drift.

Contributor

I think they serve different purposes. avae._config_get finds the value of the first present key in a flat dict. sound_tokenizer._config_get is used by _config_path_get to recursively descend into the config tree looking for a value. They differ in signatures and semantics, so I'd keep them separate.
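
To make the distinction concrete, here is a rough sketch of the two lookup styles described above. These are not the actual helpers from the PR: the names `first_present` and `find_in_tree`, their signatures, and the depth-first search order are assumptions made for illustration.

```python
_MISSING = object()  # sentinel so that a stored None still counts as "found"


def first_present(config, keys, default=None):
    """Flat lookup: return the value of the first key present in a flat dict."""
    for key in keys:
        if key in config:
            return config[key]
    return default


def _search(node, key):
    """Depth-first search through nested dicts for the given key."""
    if isinstance(node, dict):
        if key in node:
            return node[key]
        for child in node.values():
            found = _search(child, key)
            if found is not _MISSING:
                return found
    return _MISSING


def find_in_tree(config, key, default=None):
    """Recursive lookup: descend into the config tree looking for a key."""
    found = _search(config, key)
    return default if found is _MISSING else found


flat = {"io_channels": 2}
nested = {"model": {"vocoder": {"vocoder_input_dim": 64}}}  # illustrative values
channels = first_present(flat, ("vocoder_input_dim", "io_channels"))
latent_width = find_in_tree(nested, "vocoder_input_dim")
```

The flat version takes a list of alternative keys, while the tree version takes one key and a nested structure, which is one way the two could end up with different signatures.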

return
if not isinstance(step_out, tuple):
raise ValueError("Cosmos3 multimodal diffusion step returned a non-tuple result.")
latents = step_out[0]
Collaborator

When sound_latents is provided, both video and sound noise predictions are packed into a single tensor and stepped together through the scheduler. This works for Euler-style schedulers where the step is stateless per-timestep, but worth confirming the approach is correct for any scheduler that might maintain internal state across steps (e.g. multistep solvers). If the scheduler is always stateless per-step here, a brief comment noting that assumption would help future readers.

Contributor

A comment was added near Cosmos3OmniDiffusersPipeline.diffuse._pack_joint.
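
The assumption discussed in this thread can be checked on toy data. The sketch below uses plain Python lists in place of tensors and a single-step Euler flow-matching update, x_next = x + (sigma_next - sigma) * v. Because that update acts on each element independently, stepping the packed (video, sound) vector gives the same result as stepping the two parts separately. The helper names are hypothetical, and the sketch says nothing about multistep solvers.

```python
def euler_step(sample, velocity, sigma, sigma_next):
    """One Euler flow-matching step, applied element by element."""
    dt = sigma_next - sigma
    return [x + dt * v for x, v in zip(sample, velocity)]


def pack(video, sound):
    """Concatenate flattened video and sound values into one vector."""
    return video + sound


def unpack(packed, n_video):
    """Split a packed vector back into its video and sound parts."""
    return packed[:n_video], packed[n_video:]


video, sound = [0.5, -1.0, 2.0], [0.25, 4.0]
v_video, v_sound = [1.0, 0.5, -2.0], [3.0, -0.5]

joint = euler_step(pack(video, sound), pack(v_video, v_sound), 1.0, 0.9)
joint_video, joint_sound = unpack(joint, len(video))

separate_video = euler_step(video, v_video, 1.0, 0.9)
separate_sound = euler_step(sound, v_sound, 1.0, 0.9)
```

Here `joint_video` equals `separate_video` and `joint_sound` equals `separate_sound`, since each output element depends only on its own input element and the shared step size.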

@hsliuustc0106
Collaborator

Thanks for adding sound generation to Cosmos3 — the code is well-structured with thorough test coverage (23 new tests).

One thing missing from the PR description: the model-addition checklist requires profiling data and a baseline comparison (latency, VRAM, output quality vs the upstream/nvidia implementation). For a new capability like this, even a single-row table with bs=1 wall time and peak VRAM measured against the official Cosmos3 repo would be helpful to confirm no regressions.

This isn't blocking given the solid unit test results, but it would strengthen confidence for merge.

@lishunyang12
Collaborator

lishunyang12 commented Jun 2, 2026

Let's update the recipe.

lishunyang12 added the "ready" label (label to trigger buildkite CI) on Jun 2, 2026
david6666666 pushed a commit to MaciejBalaNV/vllm-omni that referenced this pull request Jun 2, 2026
- _is_sound_request: keep only the two documented flags (generate_sound,
  sound_gen) and drop the four extra aliases, so a non-canonical key can't
  silently yield video-only output.
- _pack_joint: add a comment noting the joint (video, sound) scheduler step is
  valid for flow-matching schedulers (linear per-element update).

Addresses review comments on PR vllm-project#4073.

Signed-off-by: lishunyang12 <lishunyang12@163.com>
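
The flag handling described in this commit message could look roughly like the sketch below. It is not the code from the commit: the function names, the set of strings treated as true, and the `enable_sound` key in the usage lines are assumptions. Only the two flag names, `generate_sound` and `sound_gen`, come from the message.

```python
_SOUND_FLAGS = ("generate_sound", "sound_gen")  # the two documented flags
_TRUE_STRINGS = {"1", "true", "yes", "on"}      # assumed accepted spellings


def as_bool(value):
    """Interpret form-style values such as "true" or 1 as booleans."""
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return False  # None and unknown types count as "not requested"


def is_sound_request(params):
    """True only if one of the two documented flags is set to a true value."""
    return any(as_bool(params.get(flag)) for flag in _SOUND_FLAGS)


with_sound = is_sound_request({"generate_sound": "true"})
ignored_alias = is_sound_request({"enable_sound": True})  # hypothetical non-canonical key
```

A stricter variant could reject unknown sound-like keys with an error instead of ignoring them, which would surface typos rather than returning video-only output.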
david6666666 force-pushed the mbala/cosmos3_sound branch 2 times, most recently from a4f1e69 to b4e0379, on June 2, 2026 19:09
@david6666666
Collaborator

david6666666 commented Jun 2, 2026

Tested end-to-end against the official nvidia/Cosmos3-Nano examples.

  • Unit tests: tests/diffusion/models/cosmos3/ all pass (32 incl. the sound tokenizer + transformer tests).
  • T2V+sound / I2V+sound via /v1/videos/sync with generate_sound=true, sound_duration=7.875 (official params, 1280x720 / 189f / 35 steps): both return video + AAC 48 kHz stereo audio muxed in, matching the structure of the official example_t2vs_output.mp4 / example_i2vs_output.mp4. The AVAE sound tokenizer loads cleanly (sr=48000, 2ch).
| Case | Output | Official reference | Match |
| --- | --- | --- | --- |
| T2VS | 1280x720/189f + aac 48000Hz 2ch | 1280x720/189f + aac 48000Hz 2ch | Yes |
| I2VS | 1280x720/189f + aac 48000Hz 2ch | 1280x720/189f + aac 48000Hz 2ch | Yes |
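
As a side note on the parameters quoted above: 189 frames and sound_duration=7.875 are consistent with each other at 24 fps, since 189 / 24 = 7.875. That the official examples use 24 fps is an assumption here (24 is the fps in the PR description's request). The sketch below spells out the arithmetic with hypothetical helper names; whether the pipeline emits exactly this many audio samples is not stated in the thread.

```python
def clip_duration_seconds(num_frames: int, fps: float) -> float:
    """Video length in seconds for a given frame count and frame rate."""
    return num_frames / fps


def expected_audio_samples(duration_s: float, sample_rate: int) -> int:
    """Samples per channel needed to cover duration_s at sample_rate."""
    return round(duration_s * sample_rate)


duration = clip_duration_seconds(189, 24)            # 7.875 seconds
samples = expected_audio_samples(duration, 48_000)   # 378000 samples per channel
```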

@lishunyang12
Collaborator

The DCO check failed. Let's fix it.

@bastefaniak
Contributor

I tested the prompt from the PR description against cosmos-framework (the official repository with the PyTorch inference code) on an RTX 6000 Pro Blackwell. Both implementations used around 48 GB of memory, inference took around 420 s, and output quality was comparable.

Signed-off-by: Maciej Bala <mbala@nvidia.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
Signed-off-by: Maciej Bala <mbala@nvidia.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
bastefaniak and others added 7 commits June 2, 2026 19:36
Signed-off-by: lishunyang12 <lishunyang12@163.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
…lags

Signed-off-by: lishunyang12 <lishunyang12@163.com>
…nd tokenizer

Signed-off-by: lishunyang12 <lishunyang12@163.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
Signed-off-by: Bartosz Stefaniak <bstefaniak@nvidia.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
Signed-off-by: lishunyang12 <lishunyang12@163.com>
@lishunyang12
Collaborator

Let's also add recipe instructions for cosmos-super.

@lishunyang12
Collaborator

lishunyang12 commented Jun 2, 2026

We will be running against the official repo on B300 for DC usage. I assume we have one more PR for action generation, right? We will launch a new release towards the end of this week, and version 0.22.0 can have full support.

Signed-off-by: lishunyang12 <lishunyang12@163.com>
lishunyang12 added the "merge-test" label (label to trigger buildkite merge test CI) and removed the "ready" label (label to trigger buildkite CI) on Jun 2, 2026
Signed-off-by: lishunyang12 <lishunyang12@163.com>
@lishunyang12 lishunyang12 enabled auto-merge (squash) June 2, 2026 20:32
@lishunyang12
Collaborator

Tested end-to-end on B300: all requests returned HTTP 200.

Nano (sound): T2VS 120.7s · I2VS 123.4s

Super (64B, 2× B300): T2I 6s · T2V 197s · I2V 200s · T2VS 198s · I2VS 201s

Outputs match the official examples (1280×720/189f; AAC 48 kHz stereo for sound cases).
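
A check like the one reported here can be scripted with ffprobe. The sketch below is one possible approach, not the procedure the reviewers used: it assumes ffprobe (from FFmpeg) is on the PATH, and the helper names are hypothetical. `summarize_streams` formats the result in the same shorthand as the comparison above.

```python
import json
import subprocess


def probe_streams(path):
    """Return ffprobe's per-stream metadata for a media file as a list of dicts."""
    completed = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-show_entries",
            "stream=codec_type,codec_name,width,height,nb_frames,sample_rate,channels",
            "-of", "json",
            path,
        ],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(completed.stdout)["streams"]


def summarize_streams(streams):
    """Condense stream metadata into '1280x720/189f' style strings."""
    summary = {}
    for stream in streams:
        kind = stream.get("codec_type")
        if kind == "video":
            frames = stream.get("nb_frames", "?")  # some containers omit this field
            summary["video"] = f'{stream["width"]}x{stream["height"]}/{frames}f'
        elif kind == "audio":
            summary["audio"] = (
                f'{stream["codec_name"]} {stream["sample_rate"]}Hz {stream["channels"]}ch'
            )
    return summary
```

Running `summarize_streams(probe_streams("cosmos3_t2v_with_sound.mp4"))` on a generated file and on the official example, then comparing the two dicts, would reproduce the structural comparison without inspecting the files by hand.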

@lishunyang12 lishunyang12 merged commit 40b2959 into vllm-project:main Jun 3, 2026
6 checks passed
@bastefaniak bastefaniak mentioned this pull request Jun 3, 2026
5 tasks
Labels

merge-test label to trigger buildkite merge test CI

5 participants