From 1cc40592cd24f8b65b5cb569a1cc3f7beb4622c3 Mon Sep 17 00:00:00 2001 From: Maciej Bala Date: Thu, 14 May 2026 13:20:41 +0200 Subject: [PATCH 01/41] Added Cosmos3 model Signed-off-by: Maciej Bala --- docs/models/supported_models.md | 1 + .../diffusion/cache_acceleration/cache_dit.md | 21 + .../diffusion/cpu_offload_diffusion.md | 2 + .../diffusion/parallelism/cfg_parallel.md | 9 + docs/user_guide/diffusion_features.md | 6 + .../offline_inference/image_to_video.md | 42 +- .../offline_inference/text_to_image.md | 27 + .../offline_inference/text_to_video.md | 25 + .../examples/online_serving/image_to_video.md | 76 +- .../examples/online_serving/text_to_image.md | 36 + .../examples/online_serving/text_to_video.md | 109 +- .../image_to_video/README.md | 42 +- .../image_to_video/image_to_video.py | 70 +- .../offline_inference/text_to_image/README.md | 27 +- .../text_to_image/text_to_image.py | 3 +- .../text_to_video/text_to_video.md | 24 + .../text_to_video/text_to_video.py | 29 +- .../online_serving/image_to_video/README.md | 73 +- .../online_serving/text_to_image/README.md | 36 + .../online_serving/text_to_video/README.md | 65 +- tests/diffusion/cache/test_cache_dit.py | 19 + tests/diffusion/models/cosmos3/__init__.py | 2 + tests/diffusion/models/cosmos3/conftest.py | 191 ++ .../models/cosmos3/test_cosmos3_pipeline.py | 1108 ++++++++++ .../cosmos3/test_cosmos3_transformer.py | 577 +++++ tests/diffusion/test_diffusion_ipc.py | 25 + tests/e2e/accuracy/test_cosmos3_similarity.py | 155 ++ .../openai_api/test_image_server.py | 1 + .../openai_api/test_video_server.py | 116 ++ .../diffusion/attention/backends/sdpa.py | 3 + .../diffusion/cache/cache_dit_backend.py | 71 + vllm_omni/diffusion/diffusion_engine.py | 25 +- vllm_omni/diffusion/ipc.py | 14 +- .../diffusion/models/cosmos3/__init__.py | 16 + vllm_omni/diffusion/models/cosmos3/action.py | 217 ++ .../cosmos3/audio_tokenizer/__init__.py | 6 + .../cosmos3/audio_tokenizer/activations.py | 147 ++ 
.../alias_free_torch/__init__.py | 16 + .../audio_tokenizer/alias_free_torch/act.py | 32 + .../alias_free_torch/filter.py | 95 + .../alias_free_torch/resample.py | 48 + .../models/cosmos3/audio_tokenizer/avae.py | 271 +++ .../cosmos3/audio_tokenizer/bottlenecks.py | 133 ++ .../models/cosmos3/audio_tokenizer/config.py | 20 + .../models/cosmos3/audio_tokenizer/models.py | 614 ++++++ .../models/cosmos3/audio_tokenizer/modules.py | 418 ++++ .../audio_tokenizer/modules_encodec.py | 297 +++ .../diffusion/models/cosmos3/guardrails.py | 430 ++++ .../models/cosmos3/pipeline_cosmos3.py | 1848 +++++++++++++++++ .../models/cosmos3/sound_tokenizer.py | 232 +++ .../models/cosmos3/transformer_cosmos3.py | 1586 ++++++++++++++ vllm_omni/diffusion/registry.py | 7 + vllm_omni/engine/async_omni_engine.py | 1 + vllm_omni/entrypoints/openai/api_server.py | 13 +- .../entrypoints/openai/protocol/__init__.py | 2 + .../entrypoints/openai/protocol/videos.py | 22 + vllm_omni/entrypoints/openai/serving_chat.py | 2 + vllm_omni/entrypoints/openai/serving_video.py | 139 +- vllm_omni/inputs/data.py | 1 + 59 files changed, 9561 insertions(+), 82 deletions(-) create mode 100644 tests/diffusion/models/cosmos3/__init__.py create mode 100644 tests/diffusion/models/cosmos3/conftest.py create mode 100644 tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py create mode 100644 tests/diffusion/models/cosmos3/test_cosmos3_transformer.py create mode 100644 tests/diffusion/test_diffusion_ipc.py create mode 100644 tests/e2e/accuracy/test_cosmos3_similarity.py create mode 100644 vllm_omni/diffusion/models/cosmos3/__init__.py create mode 100644 vllm_omni/diffusion/models/cosmos3/action.py create mode 100644 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/__init__.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/activations.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/alias_free_torch/__init__.py create mode 100755 
vllm_omni/diffusion/models/cosmos3/audio_tokenizer/alias_free_torch/act.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/alias_free_torch/filter.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/alias_free_torch/resample.py create mode 100644 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/avae.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/bottlenecks.py create mode 100644 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/config.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/models.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/modules.py create mode 100755 vllm_omni/diffusion/models/cosmos3/audio_tokenizer/modules_encodec.py create mode 100644 vllm_omni/diffusion/models/cosmos3/guardrails.py create mode 100644 vllm_omni/diffusion/models/cosmos3/pipeline_cosmos3.py create mode 100644 vllm_omni/diffusion/models/cosmos3/sound_tokenizer.py create mode 100644 vllm_omni/diffusion/models/cosmos3/transformer_cosmos3.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index f30f3475888..880d7f3939d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -32,6 +32,7 @@ th { | `ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanPipeline` | Wan2.1-T2V, Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | +| `Cosmos3OmniDiffusersPipeline` | Cosmos3 T2I, T2V, I2V | local Diffusers-format Cosmos3 checkpoint (`$COSMOS3_MODEL`) | ✅︎ | | | | | `WanSpeechToVideoPipeline` | Wan2.2-S2V | `Wan-AI/Wan2.2-S2V-14B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `Wan22VACEPipeline` | Wan2.1-VACE | `Wan-AI/Wan2.1-VACE-1.3B-diffusers`, 
`Wan-AI/Wan2.1-VACE-14B-diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | | diff --git a/docs/user_guide/diffusion/cache_acceleration/cache_dit.md b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md index eaaca84ad6d..8e55e36bd57 100644 --- a/docs/user_guide/diffusion/cache_acceleration/cache_dit.md +++ b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md @@ -128,6 +128,22 @@ python image_edit.py \ See the [image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) for detailed configuration options. +For Cosmos3 text-to-video or image-to-video, use the video examples with the Cosmos3 pipeline class: + +```bash +cd examples/offline_inference/text_to_video +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python text_to_video.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --prompt "A small warehouse robot moves a blue box across a clean floor." \ + --cache-backend cache_dit \ + --num-inference-steps 35 +``` + +Cosmos3 Cache-DiT wraps the GEN denoising path. TeaCache is not implemented for Cosmos3. 
+ ### Online Serving ```bash @@ -138,6 +154,11 @@ vllm serve Qwen/Qwen-Image --omni --port 8091 --cache-backend cache_dit vllm serve Qwen/Qwen-Image --omni --port 8091 \ --cache-backend cache_dit \ --cache-config '{"Fn_compute_blocks": 1, "residual_diff_threshold": 0.12}' + +# Cosmos3 +vllm serve "$COSMOS3_MODEL" --omni --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --cache-backend cache_dit ``` --- diff --git a/docs/user_guide/diffusion/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md index 39dc366485e..d725502da1d 100644 --- a/docs/user_guide/diffusion/cpu_offload_diffusion.md +++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md @@ -194,6 +194,7 @@ Factory function `get_offload_backend()` selects the appropriate backend based o | OvisImagePipeline | `AIDC-AI/Ovis-Image-7B` | `OvisImageTransformer2DModel` | - | ✓ | `"transformer"` | | QwenImagePipeline | `Qwen/Qwen-Image` | `QwenImageTransformer2DModel` | ✓ | ✓ | `"transformer_blocks"` | | StableDiffusion3Pipeline | `stabilityai/stable-diffusion-3.5-medium` | `SD3Transformer2DModel` | - | ✓ | `"transformer_blocks"` | +| Cosmos3OmniDiffusersPipeline | `$COSMOS3_MODEL` | `Cosmos3VFMTransformer` | - | ✓ | `"gen_layers"` | | Wan22I2VPipeline | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | `WanTransformer3DModel` | ✓ | ✓ | `"blocks"` | | Wan22Pipeline | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `WanTransformer3DModel` | ✓ | ✓ | `"blocks"` | | BagelPipeline | `ByteDance-Seed/BAGEL-7B-MoT` | `Qwen2MoTModel` | - | ✓ | `"layers"`, `"customized modules"` | @@ -201,3 +202,4 @@ Factory function `get_offload_backend()` selects the appropriate backend based o **Notes:** - Model-Level Offloading is expected to be supported by all common diffusion models (DiT and encoders) naturally - Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attrs` pointing to transformer blocks +- Cosmos3 uses the singular `_layerwise_offload_blocks_attr` compatibility path and offloads 
GEN decoder layers. diff --git a/docs/user_guide/diffusion/parallelism/cfg_parallel.md b/docs/user_guide/diffusion/parallelism/cfg_parallel.md index 5541106680a..ce468d817cd 100644 --- a/docs/user_guide/diffusion/parallelism/cfg_parallel.md +++ b/docs/user_guide/diffusion/parallelism/cfg_parallel.md @@ -144,6 +144,15 @@ sampling_params = OmniDiffusionSamplingParams( ) ``` +For Cosmos3, use `guidance_scale` rather than `true_cfg_scale`: + +```python +sampling_params = OmniDiffusionSamplingParams( + num_inference_steps=35, + guidance_scale=4.0, +) +``` + 2. **Add negative prompt:** ```python outputs = omni.generate( diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 18792a9d665..606c8b9aeca 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -108,6 +108,7 @@ The following tables show which models support each feature: | Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution | |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Bagel** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Cosmos3 (T2I)** | ❌ | ✅ | ✅ (Ulysses) | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **FLUX.1-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-schnell** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | @@ -135,12 +136,14 @@ The following tables show which models support each feature: > Notes: > 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT. > 2. `Tongyi-MAI/Z-Image-Turbo` and `SII-GAIR/daVinci-MagiHuman-Base-1080p` are distilled models with minimal NFEs; CFG-Parallel is not necessary. +> 3. 
Cosmos3 T2I uses `Cosmos3OmniDiffusersPipeline` with `modalities=["image"]`. Model-level CPU offload is not supported; use layerwise offload. ### VideoGen | Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution | |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ | +| **Cosmos3 (T2V/I2V)** | ❌ | ✅ | ✅ (Ulysses) | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ✅ | ❌ | | **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | | **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **LTX-2.3** | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -148,6 +151,9 @@ The following tables show which models support each feature: | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +> Notes: +> 1. Cosmos3 T2V and I2V use `Cosmos3OmniDiffusersPipeline` with video output. I2V is selected when the request includes an input image. Model-level CPU offload is not supported; use layerwise offload. + **Frame Interpolation Support** - **Supported**: Wan2.2 text-to-video, image-to-video, and TI2V pipelines diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md index 6e105741a7e..5011ccf1978 100644 --- a/docs/user_guide/examples/offline_inference/image_to_video.md +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -3,7 +3,15 @@ Source . -This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models with vLLM-Omni's offline inference API. 
+This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models and Cosmos3 with vLLM-Omni's offline inference API. + +## Supported Models + +| Model | Default Resolution | Default Frames | Default Steps | Guidance | +|-------|--------------------|----------------|---------------|----------| +| `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | auto, 480p area | 81 | 50 | 5.0 | +| `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | auto, 480p area | 81 | 50 | 5.0 | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | auto, 720p area | 81 | 35 | 4.0 | ## Local CLI Usage @@ -51,20 +59,46 @@ python image_to_video.py \ --output i2v_output.mp4 ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python image_to_video.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --image cherry_blossom.jpg \ + --prompt "Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + --negative-prompt "blurry, distorted, low quality" \ + --height 720 \ + --width 1280 \ + --num-frames 81 \ + --guidance-scale 4.0 \ + --num-inference-steps 35 \ + --fps 24 \ + --output cosmos3_i2v_output.mp4 +``` + +For Cosmos3 I2V, the input image is resized and center-cropped by the pipeline. If `--height` and `--width` are omitted, this example chooses a 720p-area resolution from the input aspect ratio. Cosmos3 currently supports one prompt and one video per request, and model-level CPU offload is not supported; use `--enable-layerwise-offload` instead. + Key arguments: - `--model`: Model ID (I2V-A14B for MoE, TI2V-5B for unified T2V+I2V). +- `--model-class-name`: explicit pipeline class. Use `Cosmos3OmniDiffusersPipeline` for Cosmos3 checkpoints. - `--image`: Path to input image (required). 
- `--prompt`: Text description of desired motion/animation. - `--height/--width`: Output resolution (auto-calculated from image if not set). Dimensions should be multiples of 16. -- `--num-frames`: Number of frames (default 81). +- `--num-frames`: Number of frames (default is model-specific). - `--guidance-scale` and `--guidance-scale-high`: CFG scale (applied to low/high-noise stages for MoE). - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. -- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--flow-shift`: Scheduler flow shift. Defaults are model-specific. - `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. -- `--num-inference-steps`: Number of denoising steps (default 50). +- `--num-inference-steps`: Number of denoising steps (default is model-specific). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--frame-rate`: Generation frame rate for models that use it. Defaults to `--fps`. - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. 
diff --git a/docs/user_guide/examples/offline_inference/text_to_image.md b/docs/user_guide/examples/offline_inference/text_to_image.md index 3a97ffbf74b..e9bf48d7aa1 100644 --- a/docs/user_guide/examples/offline_inference/text_to_image.md +++ b/docs/user_guide/examples/offline_inference/text_to_image.md @@ -36,6 +36,7 @@ This folder provides several entrypoints for experimenting with text-to-image di | `black-forest-labs/FLUX.2-klein-4B` | 1024 x 1024 | 72.7 | 14.9 | | `black-forest-labs/FLUX.2-klein-9B` | 1024 x 1024 | 37.1 | 32.3 | | `black-forest-labs/FLUX.2-dev` | 1024 x 1024 | 65.7 | >80 (CPU offload required) | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | 1024 x 1024 | model/checkpoint dependent | local checkpoint | !!! info *Peak VRAM: based on basic single-card usage, batch size =1, without any acceleration/optimization features. FLUX.2-dev requires `--enable-cpu-offload` on a single 80 GiB GPU. @@ -74,6 +75,7 @@ python text_to_image.py \ | Argument | Type | Default | Description | | -------- | ---- | ------- | ----------- | +| `--model` | str | `"Qwen/Qwen-Image"` | Diffusion model name or local path | | `--prompt` | str | `"a cup of coffee on the table"` | Text description for image generation | | `--seed` | int | `142` | Integer seed for deterministic sampling | | `--negative-prompt` | str | `None` | Negative prompt for classifier-free conditional guidance | @@ -87,6 +89,9 @@ python text_to_image.py \ | `--vae-use-slicing` | flag | off | Enable VAE slicing for memory optimization | | `--vae-use-tiling` | flag | off | Enable VAE tiling for memory optimization | | `--cfg-parallel-size` | int | `1` | Set to `2` to enable CFG Parallel | +| `--ulysses-degree` | int | `1` | Ulysses sequence parallel degree for multi-GPU inference | +| `--ring-degree` | int | `1` | Ring sequence parallel degree for hybrid Ulysses + Ring inference | +| `--ulysses-mode` | str | `"strict"` | Ulysses SP mode: `"strict"` or `"advanced_uaa"` | | `--enable-cpu-offload` | 
flag | off | Enable CPU offloading for diffusion models | | `--lora-path` | str | — | Path to PEFT LoRA adapter folder | | `--lora-scale` | float | `1.0` | Scale factor for LoRA weights | @@ -160,6 +165,28 @@ python examples/offline_inference/text_to_image/text_to_image.py \ --output flux2-dev.png ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python text_to_image.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --prompt "A small warehouse robot carrying a blue box, clean product photography" \ + --negative-prompt "blurry, distorted, low quality" \ + --guidance-scale 7.0 \ + --num-inference-steps 50 \ + --height 1024 \ + --width 1024 \ + --num-images-per-prompt 1 \ + --output cosmos3_t2i.png +``` + +This script marks text-to-image requests with `modalities=["image"]`, which selects Cosmos3 T2I. Cosmos3 currently supports one prompt per request; use `--num-images-per-prompt` to request multiple images for that prompt. Model-level CPU offload is not supported for Cosmos3, so use `--enable-layerwise-offload` for offload instead. + ### Batch Requests (Multiple Prompts) You can pass multiple prompts in a single `generate` call. 
diff --git a/docs/user_guide/examples/offline_inference/text_to_video.md b/docs/user_guide/examples/offline_inference/text_to_video.md index a09dbfc979f..861af8ca1d4 100644 --- a/docs/user_guide/examples/offline_inference/text_to_video.md +++ b/docs/user_guide/examples/offline_inference/text_to_video.md @@ -14,6 +14,7 @@ For backend selection and SageAttention usage, see the [Diffusion Attention Back | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | 720x1280 | 81 | 40 | 4.0 | ~60 GiB | | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v` | 480x832 | 121 | 50 | 6.0 | 1×A100 80GB | | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` | 720x1280 | 121 | 50 | 6.0 | FP8 + VAE tiling required | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | 720x1280 | 81 | 35 | 4.0 | model/checkpoint dependent | ## Local CLI Usage @@ -50,6 +51,29 @@ python text_to_video.py \ --output ltx2_out.mp4 ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python text_to_video.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --prompt "A small warehouse robot moves a blue box across a clean floor." \ + --negative-prompt "blurry, distorted, low quality" \ + --height 720 \ + --width 1280 \ + --num-frames 81 \ + --guidance-scale 4.0 \ + --num-inference-steps 35 \ + --fps 24 \ + --output cosmos3_t2v_output.mp4 +``` + +Cosmos3 video generation currently supports one prompt and one video per request. The implementation supports `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--ulysses-degree`, `--tensor-parallel-size`, `--use-hsdp`, and `--enable-layerwise-offload`. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. 
+ ### HunyuanVideo-1.5 (480p) ```bash @@ -127,6 +151,7 @@ python text_to_video.py \ - `--audio-sample-rate`: audio sample rate for embedded audio (when the pipeline returns audio). - `--quantization`: quantization method (`fp8` for FP8, `gguf` for GGUF). - `--flow-shift`: scheduler flow_shift parameter. +- `--cache-backend`: cache acceleration backend; use `cache_dit` for supported models. ### Wan2.2-specific diff --git a/docs/user_guide/examples/online_serving/image_to_video.md b/docs/user_guide/examples/online_serving/image_to_video.md index 781f0c2a5ed..1ef5c9be318 100644 --- a/docs/user_guide/examples/online_serving/image_to_video.md +++ b/docs/user_guide/examples/online_serving/image_to_video.md @@ -3,7 +3,15 @@ Source . -This example demonstrates how to deploy the Wan2.2 image-to-video model for online video generation using vLLM-Omni. +This example demonstrates how to deploy image-to-video models, including Wan2.2 and Cosmos3, for online video generation using vLLM-Omni. + +## Supported Models + +| Model | Model ID | +|-------|----------| +| Wan2.2 I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | +| Wan2.2 TI2V | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | +| Cosmos3 I2V | `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | ## Start Server @@ -29,6 +37,22 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --allowed-local-media-path / +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed.
Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ## Async Job Behavior `POST /v1/videos` is asynchronous. It creates a video job and immediately @@ -59,6 +83,7 @@ file. Metadata is returned via response headers: - `X-Model`: model name used for generation - `X-Inference-Time-S`: wall-clock inference time in seconds +### Wan2.2 Sync Request ```bash curl -X POST http://localhost:8091/v1/videos/sync \ -F "prompt=A bear playing with yarn, smooth motion" \ @@ -79,6 +104,53 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -o sync_i2v_output.mp4 ``` +### Cosmos3 Sync Request + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "input_reference=@/path/to/cherry_blossom.jpg" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42" \ + -o cosmos3_i2v_output.mp4 +``` + +For async generation, send the same form fields to `POST /v1/videos`, poll `GET /v1/videos/{video_id}`, and download from `GET /v1/videos/{video_id}/content`. Cosmos3 currently supports one prompt and one video per request. 
+ +```bash +create_response=$(curl -s http://localhost:8091/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "input_reference=@/path/to/cherry_blossom.jpg" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42") + +video_id=$(echo "$create_response" | jq -r '.id') +while true; do + status=$(curl -s "http://localhost:8091/v1/videos/${video_id}" | jq -r '.status') + if [ "$status" = "completed" ]; then + break + fi + if [ "$status" = "failed" ]; then + echo "Video generation failed" + exit 1 + fi + sleep 2 +done + +curl -L "http://localhost:8091/v1/videos/${video_id}/content" -o cosmos3_i2v_output.mp4 +``` + ## Storage Generated video files are stored on local disk by the async video API. @@ -103,6 +175,7 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8 bash run_curl_image_to_video.sh # Or execute directly (OpenAI-style multipart) +# Note: frame interpolation specific arguments are relevant only for Wan2.2 models create_response=$(curl -s http://localhost:8091/v1/videos \ -H "Accept: application/json" \ -F "prompt=A bear playing with yarn, smooth motion" \ @@ -165,6 +238,7 @@ curl -X POST http://localhost:8091/v1/videos \ ### Generation with Parameters ```bash +# Note: frame interpolation specific arguments are relevant only for Wan2.2 models curl -X POST http://localhost:8091/v1/videos \ -F "prompt=A bear playing with yarn, smooth motion" \ -F "negative_prompt=low quality, blurry, static" \ diff --git a/docs/user_guide/examples/online_serving/text_to_image.md b/docs/user_guide/examples/online_serving/text_to_image.md index 69c1480e39f..894a1b4be6b 100644 --- a/docs/user_guide/examples/online_serving/text_to_image.md +++ b/docs/user_guide/examples/online_serving/text_to_image.md @@ -23,6 +23,21 @@ Or use the startup script: bash 
run_server.sh ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ### Start with Parallelism Acceleration Enable Tensor Parallelism and VAE Patch Parallelism for faster inference: @@ -71,6 +86,26 @@ curl -s http://localhost:8091/v1/chat/completions \ }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png ``` +#### Cosmos3 Images API + +The dedicated image endpoint sets `modalities=["image"]` internally, which selects Cosmos3 text-to-image. + +```bash +curl -X POST http://localhost:8091/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A small warehouse robot carrying a blue box, clean product photography", + "size": "1024x1024", + "n": 1, + "num_inference_steps": 50, + "guidance_scale": 7.0, + "negative_prompt": "blurry, distorted, low quality", + "seed": 42 + }' | jq -r '.data[0].b64_json' | base64 -d > cosmos3_t2i.png +``` + +Cosmos3 currently supports one prompt per request. Use `n` to request multiple images for that prompt. + ### Method 2: Using OpenAI Python SDK ```python @@ -248,6 +283,7 @@ directly. 
For image dimensions and count, use `size` and `n` rather than | `height` | int | None | Image height in pixels | | `width` | int | None | Image width in pixels | | `size` | str | None | Image size (e.g., "1024x1024") | +| `n` | int | 1 | Number of images for `/v1/images/generations` | | `num_inference_steps` | int | 50 | Number of denoising steps | | `true_cfg_scale` | float | 4.0 | Qwen-Image CFG scale | | `seed` | int | None | Random seed (reproducible) | diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md index b918aac19d0..d5aa8a154ff 100644 --- a/docs/user_guide/examples/online_serving/text_to_video.md +++ b/docs/user_guide/examples/online_serving/text_to_video.md @@ -13,6 +13,7 @@ This example demonstrates how to deploy text-to-video models for online video ge | Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | | Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | | LTX-2 | `Lightricks/LTX-2` | +| Cosmos3 T2V | `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | ## Wan2.2 T2V @@ -40,6 +41,23 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +## Cosmos3 T2V + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +### Start Server + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ## Async Job Behavior `POST /v1/videos` is asynchronous. 
It creates a video job and immediately @@ -85,6 +103,51 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -o sync_t2v_output.mp4 ``` +### Cosmos3 Sync Request + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=A small warehouse robot moves a blue box across a clean floor." \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42" \ + -o cosmos3_t2v_output.mp4 +``` + +For async generation, send the same form fields to `POST /v1/videos`, poll `GET /v1/videos/{video_id}`, and download from `GET /v1/videos/{video_id}/content`. Cosmos3 currently supports one prompt and one video per request. + +```bash +create_response=$(curl -s http://localhost:8091/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=A small warehouse robot moves a blue box across a clean floor." \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42") + +video_id=$(echo "$create_response" | jq -r '.id') +while true; do + status=$(curl -s "http://localhost:8091/v1/videos/${video_id}" | jq -r '.status') + if [ "$status" = "completed" ]; then + break + fi + if [ "$status" = "failed" ]; then + echo "Video generation failed" + exit 1 + fi + sleep 2 +done + +curl -L "http://localhost:8091/v1/videos/${video_id}/content" -o cosmos3_t2v_output.mp4 +``` + ## Storage Generated video files are stored on local disk by the async video API. 
@@ -153,6 +216,7 @@ curl -X POST http://localhost:8091/v1/videos \ ### Generation with Parameters ```bash +# Note: frame interpolation specific arguments are relevant only for Wan2.2 models curl -X POST http://localhost:8091/v1/videos \ -F "prompt=A cinematic view of a futuristic city at sunset" \ -F "width=832" \ @@ -173,32 +237,32 @@ curl -X POST http://localhost:8091/v1/videos \ ## Generation Parameters -| Parameter | Type | Default | Description | -| --------------------- | ------ | ------- | ------------------------------------------------ | -| `prompt` | str | - | Text description of the desired video | -| `seconds` | str | None | Clip duration in seconds | -| `size` | str | None | Output size in `WIDTHxHEIGHT` format | -| `negative_prompt` | str | None | Negative prompt | -| `width` | int | None | Video width in pixels | -| `height` | int | None | Video height in pixels | -| `num_frames` | int | None | Number of frames to generate | -| `fps` | int | None | Frames per second for output video | -| `num_inference_steps` | int | None | Number of denoising steps | -| `guidance_scale` | float | None | CFG guidance scale (low-noise stage) | -| `guidance_scale_2` | float | None | CFG guidance scale (high-noise stage, Wan2.2) | -| `boundary_ratio` | float | None | Boundary split ratio for low/high DiT (Wan2.2) | -| `flow_shift` | float | None | Scheduler flow shift (Wan2.2) | -| `seed` | int | None | Random seed (reproducible) | -| `lora` | object | None | LoRA configuration | -| `enable_frame_interpolation` | bool | false | Enable RIFE frame interpolation before MP4 encoding | -| `frame_interpolation_exp` | int | 1 | Interpolation exponent; 1=2x temporal resolution, 2=4x | -| `frame_interpolation_scale` | float | 1.0 | RIFE inference scale; use 0.5 for high-resolution inputs | -| `frame_interpolation_model_path` | str | None | Local directory or Hugging Face repo ID with `flownet.pkl`; defaults to `elfgum/RIFE-4.22.lite` | +| Parameter | Type | Default | Description 
| +| --------------------- | ------ | ------- |----------------------------------------------------------------------------------------------------------| +| `prompt` | str | - | Text description of the desired video | +| `seconds` | str | None | Clip duration in seconds | +| `size` | str | None | Output size in `WIDTHxHEIGHT` format | +| `negative_prompt` | str | None | Negative prompt | +| `width` | int | None | Video width in pixels | +| `height` | int | None | Video height in pixels | +| `num_frames` | int | None | Number of frames to generate | +| `fps` | int | None | Frames per second for output video | +| `num_inference_steps` | int | None | Number of denoising steps | +| `guidance_scale` | float | None | CFG guidance scale (low-noise stage) | +| `guidance_scale_2` | float | None | CFG guidance scale (high-noise stage, Wan2.2) | +| `boundary_ratio` | float | None | Boundary split ratio for low/high DiT (Wan2.2) | +| `flow_shift` | float | None | Scheduler flow shift | +| `seed` | int | None | Random seed (reproducible) | +| `lora` | object | None | LoRA configuration | +| `enable_frame_interpolation` | bool | false | Enable RIFE frame interpolation before MP4 encoding (Wan2.2) | +| `frame_interpolation_exp` | int | 1 | Interpolation exponent; 1=2x temporal resolution, 2=4x (Wan2.2) | +| `frame_interpolation_scale` | float | 1.0 | RIFE inference scale; use 0.5 for high-resolution inputs (Wan2.2) | +| `frame_interpolation_model_path` | str | None | Local directory or Hugging Face repo ID with `flownet.pkl`; defaults to `elfgum/RIFE-4.22.lite` (Wan2.2) | ## Frame Interpolation Frame interpolation is an optional post-processing step for `/v1/videos` and -`/v1/videos/sync`. It synthesizes intermediate frames between generated frames +`/v1/videos/sync`, supported by Wan2.2 models. It synthesizes intermediate frames between generated frames without rerunning the diffusion model. 
If the generated video has `N` frames, the interpolated output frame count is `(N - 1) * 2**exp + 1`. The encoder FPS is multiplied by `2**exp` so the output duration remains close to the original. @@ -210,6 +274,7 @@ device without blocking the FastAPI event loop. Example: generate 5 frames and interpolate to 9 frames: ```bash +# Note: frame interpolation specific arguments are relevant only for Wan2.2 models curl -X POST http://localhost:8091/v1/videos/sync \ -F "prompt=A dog running through a park" \ -F "num_frames=5" \ diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index a458850a02b..8de4cafce78 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -1,6 +1,14 @@ # Image-To-Video -This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models with vLLM-Omni's offline inference API. +This example demonstrates how to generate videos from images using Wan2.2 Image-to-Video models and Cosmos3 with vLLM-Omni's offline inference API. + +## Supported Models + +| Model | Default Resolution | Default Frames | Default Steps | Guidance | +|-------|--------------------|----------------|---------------|----------| +| `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | auto, 480p area | 81 | 50 | 5.0 | +| `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | auto, 480p area | 81 | 50 | 5.0 | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | auto, 720p area | 81 | 35 | 4.0 | ## Local CLI Usage @@ -48,20 +56,46 @@ python image_to_video.py \ --output i2v_output.mp4 ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. 
+ +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python image_to_video.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --image cherry_blossom.jpg \ + --prompt "Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + --negative-prompt "blurry, distorted, low quality" \ + --height 720 \ + --width 1280 \ + --num-frames 81 \ + --guidance-scale 4.0 \ + --num-inference-steps 35 \ + --fps 24 \ + --output cosmos3_i2v_output.mp4 +``` + +For Cosmos3 I2V, the input image is resized and center-cropped by the pipeline. If `--height` and `--width` are omitted, this example chooses a 720p-area resolution from the input aspect ratio. Cosmos3 currently supports one prompt and one video per request, and model-level CPU offload is not supported; use `--enable-layerwise-offload` instead. + Key arguments: - `--model`: Model ID (I2V-A14B for MoE, TI2V-5B for unified T2V+I2V). +- `--model-class-name`: explicit pipeline class. Use `Cosmos3OmniDiffusersPipeline` for Cosmos3 checkpoints. - `--image`: Path to input image (required). - `--prompt`: Text description of desired motion/animation. - `--height/--width`: Output resolution (auto-calculated from image if not set). Dimensions should be multiples of 16. -- `--num-frames`: Number of frames (default 81). +- `--num-frames`: Number of frames (default is model-specific). - `--guidance-scale` and `--guidance-scale-high`: CFG scale (applied to low/high-noise stages for MoE). - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. -- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--flow-shift`: Scheduler flow shift. Defaults are model-specific. - `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. -- `--num-inference-steps`: Number of denoising steps (default 50). 
+- `--num-inference-steps`: Number of denoising steps (default is model-specific). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). +- `--frame-rate`: Generation frame rate for models that use it. Defaults to `--fps`. - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 84fbf2a94ca..b89409e50e4 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -2,13 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Image-to-Video generation example using Wan2.2 I2V/TI2V models, LTX2, or HunyuanVideo-1.5. +Image-to-Video generation example using Wan2.2 I2V/TI2V models, LTX2, HunyuanVideo-1.5, or Cosmos3. 
Supports: - Wan2.2-I2V-A14B-Diffusers: MoE model with CLIP image encoder - Wan2.2-TI2V-5B-Diffusers: Unified T2V+I2V model (dense 5B) - LTX2 image-to-video pipeline - HunyuanVideo-1.5 I2V: SigLIP + VAE dual image conditioning +- Cosmos3: unified text-to-image, text-to-video, and image-to-video pipeline Usage: # Wan I2V-A14B (MoE) @@ -30,6 +31,13 @@ python image_to_video.py --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v \ --image input.jpg --prompt "A cat playing with yarn" \ --flow-shift 5.0 --guidance-scale 6.0 + + # Cosmos3 image-to-video + python image_to_video.py --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --image input.jpg --prompt "A cinematic dolly shot of a boat" \ + --height 720 --width 1280 --num-frames 81 \ + --num-inference-steps 35 --guidance-scale 4.0 --fps 24 """ import argparse @@ -60,7 +68,9 @@ def parse_profiler_config(value: str) -> dict[str, Any]: def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5).") + parser = argparse.ArgumentParser( + description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5, Cosmos3)." 
+ ) parser.add_argument( "--model", default="Wan-AI/Wan2.2-I2V-A14B-Diffusers", @@ -69,13 +79,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--model-class-name", default=None, - help="Override model class name (e.g., LTX2ImageToVideoPipeline).", + help="Override model class name (e.g., Cosmos3OmniDiffusersPipeline or LTX2ImageToVideoPipeline).", ) parser.add_argument("--image", required=True, help="Path to input image.") parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") parser.add_argument("--negative-prompt", default="", help="Negative prompt.") parser.add_argument("--seed", type=int, default=42, help="Random seed.") - parser.add_argument("--guidance-scale", type=float, default=5.0, help="CFG scale.") + parser.add_argument("--guidance-scale", type=float, default=None, help="CFG scale. Default: model-specific.") parser.add_argument( "--guidance-scale-high", type=float, default=None, help="Optional separate CFG for high-noise (MoE only)." ) @@ -83,8 +93,10 @@ def parse_args() -> argparse.Namespace: "--height", type=int, default=None, help="Video height (auto-calculated from image if not set)." ) parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).") - parser.add_argument("--num-frames", type=int, default=81, help="Number of frames.") - parser.add_argument("--num-inference-steps", type=int, default=50, help="Sampling steps.") + parser.add_argument("--num-frames", type=int, default=None, help="Number of frames. Default: model-specific.") + parser.add_argument( + "--num-inference-steps", type=int, default=None, help="Sampling steps. Default: model-specific." + ) parser.add_argument("--boundary-ratio", type=float, default=0.875, help="Boundary split ratio for MoE models.") parser.add_argument( "--frame-rate", @@ -93,7 +105,10 @@ def parse_args() -> argparse.Namespace: help="Optional generation frame rate (used by models like LTX2). 
Defaults to --fps.", ) parser.add_argument( - "--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." + "--flow-shift", + type=float, + default=None, + help="Scheduler flow_shift. Default: model-specific.", ) parser.add_argument( "--sample-solver", @@ -253,31 +268,51 @@ def calculate_dimensions( return height, width +def _is_cosmos3_model(model_name: str, model_class_name: str | None = None) -> bool: + combined = f"{model_name} {model_class_name or ''}".lower() + return "cosmos3" in combined or "cosmos-3" in combined + + def main(): args = parse_args() generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) model_name = str(args.model).lower() if args.model is not None else "" model_class_name = args.model_class_name is_ltx2 = "ltx2" in model_name or (model_class_name and "ltx2" in model_class_name.lower()) + is_cosmos3 = _is_cosmos3_model(model_name, model_class_name) if model_class_name is None and is_ltx2: model_class_name = "LTX2ImageToVideoPipeline" + elif model_class_name is None and is_cosmos3: + model_class_name = "Cosmos3OmniDiffusersPipeline" # Load input image image = PIL.Image.open(args.image).convert("RGB") - fps = args.fps if args.fps is not None else (24 if is_ltx2 else 16) + fps = args.fps if args.fps is not None else (24 if (is_ltx2 or is_cosmos3) else 16) frame_rate = args.frame_rate if args.frame_rate is not None else float(fps) - guidance_scale = args.guidance_scale if args.guidance_scale is not None else (4.0 if is_ltx2 else 5.0) + guidance_scale = ( + args.guidance_scale if args.guidance_scale is not None else (4.0 if (is_ltx2 or is_cosmos3) else 5.0) + ) num_frames = args.num_frames if args.num_frames is not None else (121 if is_ltx2 else 81) - num_inference_steps = args.num_inference_steps if args.num_inference_steps is not None else (40 if is_ltx2 else 50) + num_inference_steps = ( + args.num_inference_steps + if args.num_inference_steps is not None + else (40 
if is_ltx2 else (35 if is_cosmos3 else 50)) + ) # Calculate dimensions if not provided height = args.height width = args.width if height is None or width is None: - # Default to 480P area for Wan2.2 I2V, 512x768 area for LTX2 - max_area = 512 * 768 if is_ltx2 else 480 * 832 - mod_value = 32 if is_ltx2 else 16 + if is_ltx2: + max_area = 512 * 768 + mod_value = 32 + elif is_cosmos3: + max_area = 720 * 1280 + mod_value = 16 + else: + max_area = 480 * 832 + mod_value = 16 calc_height, calc_width = calculate_dimensions(image, max_area=max_area, mod_value=mod_value) height = height or calc_height width = width or calc_width @@ -358,8 +393,10 @@ def main(): print(f"\n{'=' * 60}") print("Generation Configuration:") print(f" Model: {args.model}") - print(f" Inference steps: {args.num_inference_steps}") - print(f" Frames: {args.num_frames}") + if model_class_name: + print(f" Model class: {model_class_name}") + print(f" Inference steps: {num_inference_steps}") + print(f" Frames: {num_frames}") print(f" Solver: {args.sample_solver}") print(f" kv_cache_dtype(config): {args.kv_cache_dtype}") print(f" kv_cache_skip_steps(config): {args.kv_cache_skip_steps}") @@ -368,7 +405,7 @@ def main(): f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}," f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}" ) - print(f" Video size: {args.width}x{args.height}") + print(f" Video size: {width}x{height}") print(f"{'=' * 60}\n") generation_start = time.perf_counter() @@ -377,6 +414,7 @@ def main(): { "prompt": args.prompt, "negative_prompt": args.negative_prompt, + "modalities": ["video"], "multi_modal_data": {"image": image}, }, OmniDiffusionSamplingParams( diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index c71773972b3..149e4260904 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ 
-34,6 +34,7 @@ This folder provides several entrypoints for experimenting with text-to-image di | `black-forest-labs/FLUX.2-klein-4B` | 1024 x 1024 | 72.7 | 14.9 | | `black-forest-labs/FLUX.2-klein-9B` | 1024 x 1024 | 37.1 | 32.3 | | `black-forest-labs/FLUX.2-dev` | 1024 x 1024 | 65.7 | >80 (CPU offload required) | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | 1024 x 1024 | model/checkpoint dependent | local checkpoint | | `HunyuanImage-3.0` | 1024 x 1024 | 80.0 (TP≥3) | 160 | !!! info @@ -73,11 +74,13 @@ python text_to_image.py \ | Argument | Type | Default | Description | | -------- | ---- | ------- | ----------- | +| `--model` | str | `"Qwen/Qwen-Image"` | Diffusion model name or local path | +| `--model-class-name` | str | `None` | Override pipeline class, for example `Cosmos3OmniDiffusersPipeline` | | `--prompt` | str | `"a cup of coffee on the table"` | Text description for image generation | | `--seed` | int | `142` | Integer seed for deterministic sampling | | `--negative-prompt` | str | `None` | Negative prompt for classifier-free conditional guidance | | `--cfg-scale` | float | `4.0` | True CFG scale (model-specific guidance strength) | -| `--guidance-scale` | float | `1.0` | Classifier-free guidance scale | +| `--guidance-scale` | float | `4.0` | Classifier-free guidance scale | | `--num-images-per-prompt` | int | `1` | Number of images per prompt (saved as `output`, `output_1`, ...) | | `--num-inference-steps` | int | `50` | Diffusion sampling steps (more steps = higher quality, slower) | | `--height` | int | `1024` | Output image height in pixels | @@ -177,6 +180,28 @@ python examples/offline_inference/text_to_image/text_to_image.py \ --output flux2-dev.png ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. 
+ +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python text_to_image.py \ + --model "$COSMOS3_MODEL" \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --prompt "A small warehouse robot carrying a blue box, clean product photography" \ + --negative-prompt "blurry, distorted, low quality" \ + --guidance-scale 7.0 \ + --num-inference-steps 50 \ + --height 1024 \ + --width 1024 \ + --num-images-per-prompt 1 \ + --output cosmos3_t2i.png +``` + +This script marks text-to-image requests with `modalities=["image"]`, which selects Cosmos3 T2I. Cosmos3 currently supports one prompt per request; use `--num-images-per-prompt` to request multiple images for that prompt. Model-level CPU offload is not supported for Cosmos3, so use `--enable-layerwise-offload` for offload instead. + ### Batch Requests (Multiple Prompts) You can pass multiple prompts in a single `generate` call. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index c0fd337bd93..6978b8bc1a9 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -49,7 +49,7 @@ def parse_args() -> argparse.Namespace: "Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo, Qwen/Qwen-Image-2512, stepfun-ai/NextStep-1.1, " "black-forest-labs/FLUX.1-dev, black-forest-labs/FLUX.2-klein-9B, " "black-forest-labs/FLUX.2-dev, tencent/HunyuanImage-3.0-Instruct, " - "meituan-longcat/LongCat-Image, OvisAI/Ovis-Image, " + "meituan-longcat/LongCat-Image, OvisAI/Ovis-Image, Cosmos3, " "stabilityai/stable-diffusion-3.5-medium, Tongyi-MAI/Z-Image-Turbo and etc.", ) parser.add_argument( @@ -456,6 +456,7 @@ def main(): { "prompt": args.prompt, "negative_prompt": args.negative_prompt, + "modalities": ["image"], }, OmniDiffusionSamplingParams( height=args.height, diff --git a/examples/offline_inference/text_to_video/text_to_video.md 
b/examples/offline_inference/text_to_video/text_to_video.md index f852e980a78..69ef1dadfe7 100644 --- a/examples/offline_inference/text_to_video/text_to_video.md +++ b/examples/offline_inference/text_to_video/text_to_video.md @@ -9,6 +9,7 @@ A unified script for text-to-video generation. Supports multiple models with mod | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | 720x1280 | 81 | 40 | 4.0 | ~60 GiB | | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v` | 480x832 | 121 | 50 | 6.0 | 1×A100 80GB | | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` | 720x1280 | 121 | 50 | 6.0 | FP8 + VAE tiling required | +| `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | 720x1280 | 81 | 35 | 4.0 | model/checkpoint dependent | ## Local CLI Usage @@ -45,6 +46,28 @@ python text_to_video.py \ --output ltx2_out.mp4 ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +python text_to_video.py \ + --model "$COSMOS3_MODEL" \ + --prompt "A small warehouse robot moves a blue box across a clean floor." \ + --negative-prompt "blurry, distorted, low quality" \ + --height 720 \ + --width 1280 \ + --num-frames 81 \ + --guidance-scale 4.0 \ + --num-inference-steps 35 \ + --fps 24 \ + --output cosmos3_t2v_output.mp4 +``` + +Cosmos3 video generation currently supports one prompt and one video per request. The implementation supports `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--ulysses-degree`, `--tensor-parallel-size`, `--use-hsdp`, and `--enable-layerwise-offload`. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ### HunyuanVideo-1.5 (480p) ```bash @@ -122,6 +145,7 @@ python text_to_video.py \ - `--audio-sample-rate`: audio sample rate for embedded audio (when the pipeline returns audio). 
- `--quantization`: quantization method (`fp8` for FP8, `gguf` for GGUF). - `--flow-shift`: scheduler flow_shift parameter. +- `--cache-backend`: `cache_dit` for supported models. ### Wan2.2-specific diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index b19f0095e64..b704f1b87eb 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -35,10 +35,26 @@ "fps": 24, "output": "hunyuan_video_15_output.mp4", }, + "cosmos3": { + "height": 720, + "width": 1280, + "num_frames": 81, + "num_inference_steps": 35, + "guidance_scale": 4.0, + "fps": 24, + "output": "cosmos3_t2v_output.mp4", + }, } -def _detect_preset(model: str) -> dict: +def _is_cosmos3_model(model: str, model_class_name: str | None = None) -> bool: + combined = f"{model} {model_class_name or ''}".lower() + return "cosmos3" in combined or "cosmos-3" in combined + + +def _detect_preset(model: str, model_class_name: str | None = None) -> dict: + if _is_cosmos3_model(model, model_class_name): + return _MODEL_PRESETS["cosmos3"] model_lower = model.lower() if "hunyuan" in model_lower: return _MODEL_PRESETS["hunyuan"] @@ -58,19 +74,19 @@ def parse_profiler_config(value: str) -> dict[str, Any]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Generate a video from a text prompt. " - "Supports Wan2.2, HunyuanVideo-1.5, and other text-to-video models." + "Supports Wan2.2, HunyuanVideo-1.5, Cosmos3, and other text-to-video models." ) parser.add_argument( "--model", default="Wan-AI/Wan2.2-T2V-A14B-Diffusers", help="Diffusers model ID or local path. 
" "Examples: Wan-AI/Wan2.2-T2V-A14B-Diffusers, " - "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v", + "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v, $COSMOS3_MODEL", ) parser.add_argument( "--model-class-name", default=None, - help="Override model class name (e.g., LTX2TwoStagesVideoPipeline).", + help="Override model class name (e.g., Cosmos3OmniDiffusersPipeline or LTX2TwoStagesVideoPipeline).", ) parser.add_argument("--prompt", default="A serene lakeside sunrise with mist over the water.", help="Text prompt.") parser.add_argument("--negative-prompt", default="", help="Negative prompt.") @@ -108,7 +124,7 @@ def parse_args() -> argparse.Namespace: type=str, default=None, choices=["cache_dit"], - help="Cache backend for acceleration (Wan2.2). Default: None.", + help="Cache backend for acceleration on supported models. Default: None.", ) parser.add_argument( "--enable-cache-dit-summary", @@ -312,7 +328,7 @@ def main(): print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") - prompt_dict = {"prompt": args.prompt} + prompt_dict = {"prompt": args.prompt, "modalities": ["video"]} if args.negative_prompt: prompt_dict["negative_prompt"] = args.negative_prompt @@ -323,6 +339,7 @@ def main(): guidance_scale=args.guidance_scale, num_inference_steps=args.num_inference_steps, num_frames=args.num_frames, + frame_rate=args.frame_rate if args.frame_rate is not None else float(args.fps), ) if args.guidance_scale_high is not None: sampling_kwargs["guidance_scale_2"] = args.guidance_scale_high diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md index 285eeb27983..6f82d3a2019 100644 --- a/examples/online_serving/image_to_video/README.md +++ b/examples/online_serving/image_to_video/README.md @@ -1,6 +1,14 @@ # Image-To-Video -This example demonstrates how to deploy the Wan2.2 image-to-video model for online video generation using vLLM-Omni. 
+This example demonstrates how to deploy image-to-video models, including Wan2.2 and Cosmos3, for online video generation using vLLM-Omni. + +## Supported Models + +| Model | Model ID | +|-------|----------| +| Wan2.2 I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | +| Wan2.2 TI2V | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | +| Cosmos3 I2V | `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | ## Start Server @@ -26,6 +34,22 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline \ + --allowed-local-media-path / +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. 
+ ### Ascend / Local LightX2V Example For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this: @@ -91,6 +115,53 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -o sync_i2v_output.mp4 ``` +### Cosmos3 Sync Request + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "input_reference=@/path/to/cherry_blossom.jpg" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42" \ + -o cosmos3_i2v_output.mp4 +``` + +For async generation, send the same form fields to `POST /v1/videos`, poll `GET /v1/videos/{video_id}`, and download from `GET /v1/videos/{video_id}/content`. Cosmos3 currently supports one prompt and one video per request. + +```bash +create_response=$(curl -s http://localhost:8091/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "input_reference=@/path/to/cherry_blossom.jpg" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42") + +video_id=$(echo "$create_response" | jq -r '.id') +while true; do + status=$(curl -s "http://localhost:8091/v1/videos/${video_id}" | jq -r '.status') + if [ "$status" = "completed" ]; then + break + fi + if [ "$status" = "failed" ]; then + echo "Video generation failed" + exit 1 + fi + sleep 2 +done + +curl -L "http://localhost:8091/v1/videos/${video_id}/content" -o cosmos3_i2v_output.mp4 +``` + For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`. 
Example matching the local LightX2V deployment above: diff --git a/examples/online_serving/text_to_image/README.md b/examples/online_serving/text_to_image/README.md index 17d377ea3e2..41062f718b2 100644 --- a/examples/online_serving/text_to_image/README.md +++ b/examples/online_serving/text_to_image/README.md @@ -20,6 +20,21 @@ Or use the startup script: bash run_server.sh ``` +### Cosmos3 + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ### Start with Parallelism Acceleration Enable Tensor Parallelism and VAE Patch Parallelism for faster inference: @@ -68,6 +83,26 @@ curl -s http://localhost:8091/v1/chat/completions \ }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png ``` +#### Cosmos3 Images API + +The dedicated image endpoint sets `modalities=["image"]` internally, which selects Cosmos3 text-to-image. + +```bash +curl -X POST http://localhost:8091/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A small warehouse robot carrying a blue box, clean product photography", + "size": "1024x1024", + "n": 1, + "num_inference_steps": 50, + "guidance_scale": 7.0, + "negative_prompt": "blurry, distorted, low quality", + "seed": 42 + }' | jq -r '.data[0].b64_json' | base64 -d > cosmos3_t2i.png +``` + +Cosmos3 currently supports one prompt per request. Use `n` to request multiple images for that prompt. 
+ ### Method 2: Using OpenAI Python SDK ```python @@ -226,6 +261,7 @@ count, use `size` and `n` rather than `height`, `width`, or | `height` | int | None | Image height in pixels | | `width` | int | None | Image width in pixels | | `size` | str | None | Image size (e.g., "1024x1024") | +| `n` | int | 1 | Number of images for `/v1/images/generations` | | `num_inference_steps` | int | 50 | Number of denoising steps | | `true_cfg_scale` | float | 4.0 | Qwen-Image CFG scale | | `seed` | int | None | Random seed (reproducible) | diff --git a/examples/online_serving/text_to_video/README.md b/examples/online_serving/text_to_video/README.md index c01e0602ff9..57922abd38a 100644 --- a/examples/online_serving/text_to_video/README.md +++ b/examples/online_serving/text_to_video/README.md @@ -10,6 +10,7 @@ This example demonstrates how to deploy text-to-video models for online video ge | Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | | Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | | LTX-2 | `Lightricks/LTX-2` | +| Cosmos3 T2V | `$COSMOS3_MODEL` with `Cosmos3OmniDiffusersPipeline` | ## Wan2.2 T2V @@ -37,6 +38,23 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +## Cosmos3 T2V + +Cosmos3 uses one pipeline for text-to-image, text-to-video, and image-to-video. Set `COSMOS3_MODEL` to a local Diffusers-format Cosmos3 checkpoint or model reference, and select the pipeline explicitly. + +### Start Server + +```bash +export COSMOS3_MODEL=/path/to/cosmos3-diffusers + +vllm serve "$COSMOS3_MODEL" \ + --omni \ + --port 8091 \ + --model-class-name Cosmos3OmniDiffusersPipeline +``` + +Use `--enable-layerwise-offload`, `--cache-backend cache_dit`, `--cfg-parallel-size 2`, `--usp`, `--tensor-parallel-size`, or `--use-hsdp` as needed. Do not use `--enable-cpu-offload`; Cosmos3 does not support model-level CPU offload. + ## Async Job Behavior `POST /v1/videos` is asynchronous. 
It creates a video job and immediately @@ -82,6 +100,51 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -o sync_t2v_output.mp4 ``` +### Cosmos3 Sync Request + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=A small warehouse robot moves a blue box across a clean floor." \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42" \ + -o cosmos3_t2v_output.mp4 +``` + +For async generation, send the same form fields to `POST /v1/videos`, poll `GET /v1/videos/{video_id}`, and download from `GET /v1/videos/{video_id}/content`. Cosmos3 currently supports one prompt and one video per request. + +```bash +create_response=$(curl -s http://localhost:8091/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=A small warehouse robot moves a blue box across a clean floor." \ + -F "negative_prompt=blurry, distorted, low quality" \ + -F "size=1280x720" \ + -F "num_frames=81" \ + -F "fps=24" \ + -F "num_inference_steps=35" \ + -F "guidance_scale=4.0" \ + -F "seed=42") + +video_id=$(echo "$create_response" | jq -r '.id') +while true; do + status=$(curl -s "http://localhost:8091/v1/videos/${video_id}" | jq -r '.status') + if [ "$status" = "completed" ]; then + break + fi + if [ "$status" = "failed" ]; then + echo "Video generation failed" + exit 1 + fi + sleep 2 +done + +curl -L "http://localhost:8091/v1/videos/${video_id}/content" -o cosmos3_t2v_output.mp4 +``` + ## Storage Generated video files are stored on local disk by the async video API. 
@@ -181,7 +244,7 @@ curl -X POST http://localhost:8091/v1/videos \ | `guidance_scale` | float | None | CFG guidance scale (low-noise stage) | | `guidance_scale_2` | float | None | CFG guidance scale (high-noise stage, Wan2.2) | | `boundary_ratio` | float | None | Boundary split ratio for low/high DiT (Wan2.2) | -| `flow_shift` | float | None | Scheduler flow shift (Wan2.2) | +| `flow_shift` | float | None | Scheduler flow shift | | `seed` | int | None | Random seed (reproducible) | | `lora` | object | None | LoRA configuration | diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py index 0b7ef723585..8499aa39e8c 100644 --- a/tests/diffusion/cache/test_cache_dit.py +++ b/tests/diffusion/cache/test_cache_dit.py @@ -18,6 +18,7 @@ cd_backend.enable_cache_for_ltx2, cd_backend.enable_cache_for_wan22, cd_backend.enable_cache_for_longcat_image, + cd_backend.enable_cache_for_cosmos3, ] SAMPLE_CACHE_CONFIG = DiffusionCacheConfig() @@ -38,3 +39,21 @@ def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler): mock_cache_dit.enable_cache.assert_called_once() adapter_kwargs = mock_block_adapter.call_args.kwargs assert adapter_kwargs["has_separate_cfg"] is True + + +@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter") +@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit") +def test_cosmos3_cache_dit_wraps_gen_layers(mock_cache_dit, mock_block_adapter): + """Cosmos3 should cache only the repeated GEN pathway blocks.""" + mock_pipeline = Mock() + gen_layers = object() + mock_pipeline.transformer.gen_layers = gen_layers + + cd_backend.enable_cache_for_cosmos3(mock_pipeline, SAMPLE_CACHE_CONFIG) + + mock_cache_dit.enable_cache.assert_called_once() + adapter_kwargs = mock_block_adapter.call_args.kwargs + assert adapter_kwargs["transformer"] is mock_pipeline.transformer + assert adapter_kwargs["blocks"] == [gen_layers] + assert adapter_kwargs["has_separate_cfg"] is True + assert 
adapter_kwargs["check_forward_pattern"] is False diff --git a/tests/diffusion/models/cosmos3/__init__.py b/tests/diffusion/models/cosmos3/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/tests/diffusion/models/cosmos3/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/models/cosmos3/conftest.py b/tests/diffusion/models/cosmos3/conftest.py new file mode 100644 index 00000000000..58d4af9bf85 --- /dev/null +++ b/tests/diffusion/models/cosmos3/conftest.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +from torch import nn + + +class StubScheduler: + def __init__(self, timesteps: list[int] | None = None, *, flow_shift: float = 1.0) -> None: + self.timesteps = torch.tensor(timesteps or [9, 3], dtype=torch.int64) + self.config = SimpleNamespace(num_train_timesteps=1000, flow_shift=flow_shift) + self.set_timesteps_calls: list[tuple[int, torch.device]] = [] + self.step_calls: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] + + def set_timesteps(self, num_steps: int, device: torch.device) -> None: + self.set_timesteps_calls.append((num_steps, device)) + self.timesteps = torch.arange(num_steps, 0, -1, dtype=torch.int64, device=device) + + def step(self, noise_pred: torch.Tensor, timestep: torch.Tensor, latents: torch.Tensor, **kwargs): + del kwargs + self.step_calls.append((noise_pred.clone(), timestep.clone(), latents.clone())) + return (latents + noise_pred,) + + +class _ModeLatentDist: + def __init__(self, latents: torch.Tensor) -> None: + self._latents = latents + + def mode(self) -> torch.Tensor: + return self._latents + + +class StubCosmos3VAE: + dtype = torch.float32 + + def 
__init__(self, z_dim: int = 2, *, temporal: int = 4, spatial: int = 8) -> None: + self.config = SimpleNamespace( + z_dim=z_dim, + scale_factor_temporal=temporal, + scale_factor_spatial=spatial, + latents_mean=[0.0] * z_dim, + latents_std=[1.0] * z_dim, + ) + + def encode(self, video: torch.Tensor): + latent_frames = (video.shape[2] - 1) // self.config.scale_factor_temporal + 1 + latent_height = video.shape[-2] // self.config.scale_factor_spatial + latent_width = video.shape[-1] // self.config.scale_factor_spatial + latents = torch.ones( + video.shape[0], + self.config.z_dim, + latent_frames, + latent_height, + latent_width, + dtype=video.dtype, + device=video.device, + ) + return SimpleNamespace(latent_dist=_ModeLatentDist(latents)) + + def decode(self, latents: torch.Tensor, return_dict: bool = False): + del return_dict + return (latents,) + + +class StubCosmos3Transformer(nn.Module): + def __init__( + self, + *, + latent_channel_size: int = 2, + sound_gen: bool = False, + sound_dim: int = 3, + action_gen: bool = False, + action_dim: int = 4, + ) -> None: + super().__init__() + self.latent_channel_size = latent_channel_size + self.sound_gen = sound_gen + self.sound_dim = sound_dim + self.action_gen = action_gen + self.action_dim = action_dim + self.cached_kv: Any | None = None + self.cached_freqs_gen: Any | None = None + self.calls: list[dict[str, Any]] = [] + self.reset_calls = 0 + + def reset_cache(self) -> None: + self.reset_calls += 1 + self.cached_kv = None + self.cached_freqs_gen = None + + def forward( + self, + *, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + text_ids: torch.Tensor, + text_mask: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + token = int(text_ids.reshape(-1)[0].item()) if text_ids.numel() else 0 + sound_latents = kwargs.get("sound_latents") + self.calls.append( + { + "token": token, + "timestep": timestep.clone(), + "text_mask": text_mask.clone(), + "cache_before": self.cached_kv, + 
"kwargs": dict(kwargs), + } + ) + if self.cached_kv is None: + marker = torch.tensor([token], dtype=torch.float32) + self.cached_kv = [(marker, marker + 100)] + self.cached_freqs_gen = (marker + 200, marker + 300) + action_latents = kwargs.get("action_latents") + outputs: list[torch.Tensor] = [torch.full_like(hidden_states, float(token))] + if action_latents is not None: + outputs.append(torch.full_like(action_latents, float(token + 20))) + if sound_latents is not None: + outputs.append(torch.full_like(sound_latents, float(token + 10))) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + +def passthrough_progress_bar(iterable): + return iterable + + +@pytest.fixture(autouse=True) +def fake_cosmos3_guardrails(monkeypatch: pytest.MonkeyPatch): + module = types.ModuleType("vllm_omni.diffusion.models.cosmos3.guardrails") + module.is_guardrails_enabled = lambda od_config: False + module.ensure_initialized = lambda od_config: None + module.check_text_safety = lambda text: None + module.check_video_safety = lambda video: video + monkeypatch.setitem(sys.modules, module.__name__, module) + return module + + +@pytest.fixture +def make_cosmos3_pipeline(): + def _make(): + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + pipeline = object.__new__(Cosmos3OmniDiffusersPipeline) + nn.Module.__init__(pipeline) + pipeline.od_config = SimpleNamespace() + pipeline.device = torch.device("cpu") + pipeline.dtype = torch.float32 + pipeline.transformer = StubCosmos3Transformer(latent_channel_size=2) + pipeline.vae = StubCosmos3VAE(z_dim=2) + pipeline.vae_scale_factor_temporal = 4 + pipeline.vae_scale_factor_spatial = 8 + pipeline.scheduler = StubScheduler([9, 3], flow_shift=1.0) + pipeline._base_scheduler_config = pipeline.scheduler.config + pipeline._engine_init_flow_shift = 1.0 + pipeline._current_flow_shift = 1.0 + pipeline._guidance_scale = None + pipeline._num_timesteps = None + pipeline.progress_bar = 
passthrough_progress_bar + pipeline._sound_tokenizer = None + return pipeline + + return _make + + +def make_sampling_params(**overrides: Any) -> SimpleNamespace: + values = { + "height": None, + "width": None, + "num_frames": None, + "num_inference_steps": None, + "guidance_scale": None, + "seed": 123, + "num_outputs_per_prompt": 1, + "frame_rate": None, + "resolved_frame_rate": None, + "max_sequence_length": None, + "extra_args": {}, + } + values.update(overrides) + return SimpleNamespace(**values) diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py new file mode 100644 index 00000000000..b068ea7e74a --- /dev/null +++ b/tests/diffusion/models/cosmos3/test_cosmos3_pipeline.py @@ -0,0 +1,1108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +from types import SimpleNamespace + +import pytest +import torch +from PIL import Image +from torch import nn + +from tests.diffusion.models.cosmos3.conftest import ( + StubScheduler, + make_sampling_params, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def _ids(value: int) -> torch.Tensor: + return torch.tensor([[value]], dtype=torch.long) + + +def _mask() -> torch.Tensor: + return torch.ones(1, 1, dtype=torch.long) + + +class TestRegistryIntegration: + def test_pipeline_registered_and_exported(self) -> None: + from vllm_omni.diffusion.cache.cache_dit_backend import CUSTOM_DIT_ENABLERS + from vllm_omni.diffusion.models import cosmos3 + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin + from vllm_omni.diffusion.registry import ( + _DIFFUSION_MODELS, + _DIFFUSION_POST_PROCESS_FUNCS, + _DIFFUSION_PRE_PROCESS_FUNCS, + ) + + assert issubclass(Cosmos3OmniDiffusersPipeline, 
nn.Module) + assert issubclass(Cosmos3OmniDiffusersPipeline, ProgressBarMixin) + assert Cosmos3OmniDiffusersPipeline.support_image_input is True + assert _DIFFUSION_MODELS["Cosmos3OmniDiffusersPipeline"] == ( + "cosmos3", + "pipeline_cosmos3", + "Cosmos3OmniDiffusersPipeline", + ) + assert _DIFFUSION_PRE_PROCESS_FUNCS["Cosmos3OmniDiffusersPipeline"] == "get_cosmos3_pre_process_func" + assert _DIFFUSION_POST_PROCESS_FUNCS["Cosmos3OmniDiffusersPipeline"] == "get_cosmos3_post_process_func" + assert "Cosmos3OmniDiffusersPipeline" in CUSTOM_DIT_ENABLERS + assert hasattr(cosmos3, "Cosmos3OmniDiffusersPipeline") + assert "Cosmos3OmniDiffusersPipeline" in cosmos3.__all__ + + +class TestPreAndPostProcess: + def test_preprocess_leaves_t2v_string_prompt_unchanged(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_pre_process_func + + request = SimpleNamespace( + prompts=["A robot walks through a warehouse."], + sampling_params=SimpleNamespace(height=None, width=None), + ) + + result = get_cosmos3_pre_process_func(SimpleNamespace())(request) + + assert result is request + assert result.prompts == ["A robot walks through a warehouse."] + assert result.sampling_params.height is None + assert result.sampling_params.width is None + + def test_preprocess_resizes_i2v_image_to_720p_aspect_and_stores_tensor(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_pre_process_func + + request = SimpleNamespace( + prompts=[ + { + "prompt": "A slow camera push.", + "multi_modal_data": {"image": Image.new("RGB", (320, 160), "red")}, + } + ], + sampling_params=SimpleNamespace(height=None, width=None), + ) + + result = get_cosmos3_pre_process_func(SimpleNamespace())(request) + prompt = result.prompts[0] + + assert result.sampling_params.height == 672 + assert result.sampling_params.width == 1344 + preprocessed = prompt["additional_information"]["preprocessed_image"] + assert isinstance(preprocessed, torch.Tensor) 
+ assert tuple(preprocessed.shape[-2:]) == (672, 1344) + + def test_preprocess_preserves_explicit_size_for_i2v(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_pre_process_func + + request = SimpleNamespace( + prompts=[ + { + "prompt": "A slow camera push.", + "multi_modal_data": {"image": Image.new("RGB", (320, 160), "red")}, + } + ], + sampling_params=SimpleNamespace(height=64, width=96), + ) + + result = get_cosmos3_pre_process_func(SimpleNamespace())(request) + + assert tuple(result.prompts[0]["additional_information"]["preprocessed_image"].shape[-2:]) == (64, 96) + + def test_postprocess_latent_passthrough_and_t2i_shape_validation(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import get_cosmos3_post_process_func + + func = get_cosmos3_post_process_func(SimpleNamespace()) + video = torch.zeros(1, 3, 1, 4, 4) + + assert func(video, output_type="latent") is video + + images = func({"image": video}) + assert len(images) == 1 + assert images[0].size == (4, 4) + + video_result = func({"video": video}) + assert "video" in video_result + + sound_result = func( + { + "video": video, + "audio": torch.ones(1, 2, 16), + "audio_sample_rate": 48000, + }, + sampling_params=SimpleNamespace(extra_args={"resolved_frame_rate": 12}), + ) + assert "video" in sound_result + assert sound_result["audio"].shape == (1, 2, 16) + assert sound_result["audio_sample_rate"] == 48000 + assert sound_result["fps"] == 12 + + with pytest.raises(ValueError, match="text-to-image postprocess expects"): + func({"image": torch.zeros(1, 3, 2, 4, 4)}) + + with pytest.raises(ValueError, match="both image and video"): + func({"image": video, "video": video}) + + with pytest.raises(ValueError, match="does not support audio output"): + func({"image": video, "audio": torch.ones(1, 2, 16)}) + + +class TestPipelineHelpers: + def test_get_sp_param_prefers_extra_args_then_direct_attribute(self) -> None: + from 
vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + sp = SimpleNamespace(extra_args={"flow_shift": 3.0}, flow_shift=2.0) + assert Cosmos3OmniDiffusersPipeline._get_sp_param(sp, "flow_shift", 1.0) == 3.0 + + sp = SimpleNamespace(extra_args={}, flow_shift=2.0) + assert Cosmos3OmniDiffusersPipeline._get_sp_param(sp, "flow_shift", 1.0) == 2.0 + + sp = SimpleNamespace(extra_args={}) + assert Cosmos3OmniDiffusersPipeline._get_sp_param(sp, "flow_shift", 1.0) == 1.0 + + def test_apply_metadata_templates_adds_duration_and_resolution(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + prompt = Cosmos3OmniDiffusersPipeline._apply_metadata_templates( + "A city street.", + num_frames=48, + frame_rate=24, + height=720, + width=1280, + ) + + assert prompt == ( + "A city street. The video is 2.0 seconds long and is of 24 FPS. This video is of 720x1280 resolution." + ) + + @pytest.mark.parametrize( + "tokenized", + [ + [1, 2], + (1, 2), + {"input_ids": [[1, 2]]}, + torch.tensor([1, 2]), + ], + ) + def test_normalize_token_ids_accepts_common_tokenizer_outputs(self, tokenized) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + assert Cosmos3OmniDiffusersPipeline._normalize_token_ids(tokenized) == [1, 2] + + def test_normalize_token_ids_rejects_unknown_or_non_integer_values(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + with pytest.raises(TypeError, match="must return token IDs"): + Cosmos3OmniDiffusersPipeline._normalize_token_ids(object()) + + with pytest.raises(TypeError, match="non-integer token"): + Cosmos3OmniDiffusersPipeline._normalize_token_ids([object()]) + + def test_tokenize_prompt_adds_generation_tokens_and_padding(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + + class FakeTokenizer: + 
eos_token_id = 99 + pad_token_id = 0 + + def __init__(self) -> None: + self.conversations = None + + def apply_chat_template(self, conversations, tokenize: bool, add_generation_prompt: bool): + self.conversations = conversations + assert tokenize is True + assert add_generation_prompt is True + return [10, 11] + + def convert_tokens_to_ids(self, token: str) -> int: + assert token == "<|vision_start|>" + return 88 + + tokenizer = FakeTokenizer() + pipeline.tokenizer = tokenizer + + input_ids, attention_mask = pipeline._tokenize_prompt( + "hello", + max_sequence_length=6, + use_system_prompt=True, + system_prompt="system", + ) + + assert input_ids.tolist() == [[10, 11, 99, 88, 0, 0]] + assert attention_mask.tolist() == [[1, 1, 1, 1, 0, 0]] + assert tokenizer.conversations == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + ] + + def test_format_and_tokenize_uses_video_and_image_metadata_modes(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured: list[tuple[str, bool, str | None]] = [] + + def fake_tokenize(text, max_sequence_length, use_system_prompt=False, system_prompt=None): + del max_sequence_length + captured.append((text, use_system_prompt, system_prompt)) + return _ids(len(captured)), _mask() + + pipeline._tokenize_prompt = fake_tokenize # type: ignore[method-assign] + + pipeline._format_and_tokenize_prompts( + "A robot", + "bad", + num_frames=48, + frame_rate=24, + height=720, + width=1280, + max_sequence_length=32, + sp=SimpleNamespace(extra_args={"negative_metadata_mode": "inverse"}), + use_system_prompt=True, + is_t2i=False, + ) + assert "The video is 2.0 seconds long" in captured[0][0] + assert "This video is of 720x1280 resolution" in captured[0][0] + assert "The video is not 2.0 seconds long" in captured[1][0] + assert captured[0][1] is True + + captured.clear() + pipeline._format_and_tokenize_prompts( + "A robot", + "bad", + num_frames=1, + frame_rate=24, + height=1024, + 
width=1024, + max_sequence_length=32, + sp=SimpleNamespace(extra_args={}), + use_system_prompt=False, + is_t2i=True, + ) + assert "This image is of 1024x1024 resolution" in captured[0][0] + assert "seconds long" not in captured[0][0] + assert captured[1][0] == "bad" + + @pytest.mark.parametrize( + ("key", "expected"), + [ + ("transformer.vae2llm.weight", "transformer.vae2llm.weight"), + ("model.embed_tokens.weight", "transformer.language_model.embed_tokens.weight"), + ("model.norm.weight", "transformer.language_model.norm.weight"), + ("model.norm_moe_gen.weight", "transformer.norm_moe_gen.weight"), + ( + "model.layers.3.self_attn.q_proj.weight", + "transformer.language_model.layers.3.self_attn.q_proj.weight", + ), + ( + "model.layers.3.self_attn.q_proj_moe_gen.weight", + "transformer.gen_layers.3.cross_attention.q_proj.weight", + ), + ( + "model.layers.3.mlp_moe_gen.down_proj.weight", + "transformer.gen_layers.3.mlp.down_proj.weight", + ), + ("sound2llm.weight", "transformer.sound2llm.weight"), + ("llm2sound.bias", "transformer.llm2sound.bias"), + ("sound_modality_embed", "transformer.sound_modality_embed"), + ("sound_modality_embed.weight", "transformer.sound_modality_embed"), + ("action2llm.fc.weight", "transformer.action2llm.fc.weight"), + ("llm2action.bias.weight", "transformer.llm2action.bias.weight"), + ("action_modality_embed", "transformer.action_modality_embed"), + ("action_modality_embed.weight", "transformer.action_modality_embed"), + ("action_pos_embed.weight", None), + ("lm_head.weight", None), + ("other.weight", None), + ], + ) + def test_remap_ckpt_key(self, key: str, expected: str | None) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + assert Cosmos3OmniDiffusersPipeline._remap_ckpt_key(key) == expected + + def test_prepare_latents_shape_uses_cosmos_temporal_and_spatial_factors(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + + latents = 
pipeline._prepare_latents( + height=16, + width=24, + num_frames=5, + generator=torch.Generator(device="cpu").manual_seed(0), + ) + + assert latents.shape == (1, 2, 2, 2, 3) + assert latents.dtype == torch.float32 + + def test_sound_request_detection_uses_prompt_and_extra_args(self) -> None: + from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import ( + Cosmos3OmniDiffusersPipeline, + ) + + assert Cosmos3OmniDiffusersPipeline._is_sound_request( + {"prompt": "x", "generate_sound": True}, + SimpleNamespace(extra_args={}), + ) + assert Cosmos3OmniDiffusersPipeline._is_sound_request( + {"prompt": "x"}, + SimpleNamespace(extra_args={"enable_sound_generation": "true"}), + ) + assert not Cosmos3OmniDiffusersPipeline._is_sound_request( + {"prompt": "x"}, + SimpleNamespace(extra_args={"generate_sound": False}), + ) + + def test_prepare_sound_latents_uses_lazy_tokenizer_and_duration(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + + class FakeSoundTokenizer: + sample_rate = 10 + latent_ch = 3 + + def get_latent_num_samples(self, samples: int) -> int: + assert samples == 20 + return 5 + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + return torch.ones(latents.shape[0], 2, 7) + + pipeline._sound_tokenizer = FakeSoundTokenizer() + + target_samples, duration, sample_rate = pipeline._resolve_sound_target_samples( + SimpleNamespace(extra_args={"sound_duration": 2.0}), + num_frames=9, + frame_rate=3.0, + ) + latents, latent_frames = pipeline._prepare_sound_latents( + target_samples, + torch.Generator(device="cpu").manual_seed(0), + ) + audio = pipeline._decode_sound_latents(torch.zeros(1, 3, 5), target_audio_samples=5) + + assert (target_samples, duration, sample_rate) == (20, 2.0, 10) + assert latents.shape == (1, 3, 5) + assert latent_frames == 5 + assert audio.shape == (1, 2, 5) + + def 
test_init_eagerly_loads_sound_tokenizer_when_transformer_supports_sound( + self, + tmp_path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + import vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 as cosmos3_module + from vllm_omni.diffusion.models.cosmos3 import sound_tokenizer + + class FakeTokenizer: + @classmethod + def from_pretrained(cls, *args, **kwargs): + return cls() + + class FakeVAE: + config = SimpleNamespace(scale_factor_temporal=4, scale_factor_spatial=8) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + return cls() + + def to(self, device): + self.device = device + return self + + class FakeScheduler: + config = SimpleNamespace(flow_shift=1.0) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + return cls() + + class FakeTransformer: + sound_gen = True + + fake_sound_tokenizer = object() + calls = [] + + def fake_from_config(od_config): + calls.append(od_config) + return fake_sound_tokenizer + + monkeypatch.setattr(cosmos3_module, "AutoTokenizer", FakeTokenizer) + monkeypatch.setattr(cosmos3_module, "DistributedAutoencoderKLWan", FakeVAE) + monkeypatch.setattr(cosmos3_module, "UniPCMultistepScheduler", FakeScheduler) + monkeypatch.setattr(cosmos3_module, "Cosmos3VFMTransformer", lambda *args, **kwargs: FakeTransformer()) + monkeypatch.setattr(sound_tokenizer.Cosmos3SoundTokenizer, "from_config", staticmethod(fake_from_config)) + monkeypatch.setattr( + cosmos3_module.Cosmos3OmniDiffusersPipeline, + "setup_diffusion_pipeline_profiler", + lambda self, **kwargs: None, + ) + + od_config = SimpleNamespace( + model=str(tmp_path), + dtype=torch.float32, + enable_cpu_offload=False, + flow_shift=None, + enable_diffusion_pipeline_profiler=False, + ) + pipeline = cosmos3_module.Cosmos3OmniDiffusersPipeline(od_config=od_config) + + assert calls == [od_config] + assert pipeline._sound_tokenizer is fake_sound_tokenizer + source = pipeline.weights_sources[0] + assert source.subfolder is None + assert source.prefix == "transformer." 
+ assert source.allow_patterns_overrides == ["transformer/*.safetensors"] + + def test_prepare_latents_i2v_conditions_first_latent_frame(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + + def fake_encode(image_tensor, num_frames, height, width): + del image_tensor, num_frames, height, width + return torch.full((1, 2, 2, 2, 3), 5.0) + + pipeline._encode_conditioning_video = fake_encode # type: ignore[method-assign] + + latents, velocity_mask, image_latent = pipeline._prepare_latents_i2v( + image_tensor=torch.zeros(1, 3, 16, 24), + height=16, + width=24, + num_frames=5, + generator=torch.Generator(device="cpu").manual_seed(0), + ) + + assert latents.shape == (1, 2, 2, 2, 3) + torch.testing.assert_close(latents[:, :, 0], torch.full((1, 2, 2, 3), 5.0)) + assert velocity_mask.tolist() == [[[[[0.0]], [[1.0]]]]] + torch.testing.assert_close(image_latent, torch.full((1, 2, 1, 2, 3), 5.0)) + + def test_prepare_action_latents_policy_uses_noise_and_raw_dim_mask(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(action_gen=True, action_dim=4) + + action, velocity_mask, clean, raw_dim = pipeline._prepare_action_latents( + mode="policy", + action_chunk_size=3, + raw_action_dim=2, + generator=torch.Generator(device="cpu").manual_seed(0), + sp=SimpleNamespace(extra_args={}), + ) + + assert action.shape == (1, 3, 4) + assert raw_dim == 2 + assert velocity_mask.tolist() == [[[1.0], [1.0], [1.0]]] + torch.testing.assert_close(action[:, :, 2:], torch.zeros(1, 3, 2)) + torch.testing.assert_close(clean, torch.zeros(1, 3, 4)) + + def test_prepare_action_latents_forward_dynamics_conditions_supplied_actions(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(action_gen=True, action_dim=4) + + action, velocity_mask, clean, raw_dim = pipeline._prepare_action_latents( + mode="forward_dynamics", + 
action_chunk_size=2, + raw_action_dim=None, + generator=torch.Generator(device="cpu").manual_seed(0), + sp=SimpleNamespace(extra_args={"action": [[1.0, 2.0], [3.0, 4.0]]}), + ) + + assert raw_dim == 2 + assert velocity_mask.tolist() == [[[0.0], [0.0]]] + torch.testing.assert_close(action, clean) + torch.testing.assert_close(action[0, :, :2], torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + + def test_set_flow_shift_rebuilds_only_when_target_changes(self, make_cosmos3_pipeline, monkeypatch) -> None: + import vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 as cosmos3_module + + pipeline = make_cosmos3_pipeline() + + class FakeUniPCMultistepScheduler: + calls: list[tuple[object, float]] = [] + + @classmethod + def from_config(cls, config, flow_shift: float): + cls.calls.append((config, flow_shift)) + return StubScheduler([1], flow_shift=flow_shift) + + monkeypatch.setattr(cosmos3_module, "UniPCMultistepScheduler", FakeUniPCMultistepScheduler) + original_scheduler = pipeline.scheduler + + pipeline._set_flow_shift(1.0) + assert pipeline.scheduler is original_scheduler + assert FakeUniPCMultistepScheduler.calls == [] + + pipeline._set_flow_shift(3.0) + assert pipeline.scheduler is not original_scheduler + assert pipeline._current_flow_shift == 3.0 + assert FakeUniPCMultistepScheduler.calls == [(pipeline._base_scheduler_config, 3.0)] + + +class TestDiffuse: + def test_diffuse_without_cfg_runs_one_cond_forward_per_step(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + latents = torch.zeros(1, 2, 2, 1, 1) + + result = pipeline.diffuse( + latents=latents, + timesteps=torch.tensor([7, 3]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={"video_shape": (2, 1, 1), "fps": 24.0}, + ) + + assert pipeline.transformer.reset_calls == 1 + assert [call["token"] for call in pipeline.transformer.calls] == [2, 2] + torch.testing.assert_close(result, torch.full_like(latents, 4.0)) + + def 
test_diffuse_sequential_cfg_uses_separate_caches_and_interval_skip(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + latents = torch.zeros(1, 2, 1, 1, 1) + + result = pipeline.diffuse( + latents=latents, + timesteps=torch.tensor([900, 100]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=3.0, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0}, + guidance_interval=(500.0, 1000.0), + ) + + assert [call["token"] for call in pipeline.transformer.calls] == [2, 1, 2] + assert pipeline.transformer.calls[0]["cache_before"] is None + assert pipeline.transformer.calls[1]["cache_before"] is None + assert pipeline.transformer.calls[2]["cache_before"] is not None + torch.testing.assert_close(result, torch.full_like(latents, 6.0)) + + def test_diffuse_cfg_parallel_uses_scale_one_outside_guidance_interval( + self, + make_cosmos3_pipeline, + ) -> None: + pipeline = make_cosmos3_pipeline() + pipeline._cfg_parallel_active = lambda: True # type: ignore[method-assign] + latents = torch.zeros(1, 2, 1, 1, 1) + calls = [] + + def fake_predict_noise_maybe_with_cfg(**kwargs): + calls.append(kwargs) + return torch.ones_like(latents) + + pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign] + + result = pipeline.diffuse( + latents=latents, + timesteps=torch.tensor([900, 100]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=4.0, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0}, + guidance_interval=(500.0, 1000.0), + ) + + assert [call["true_cfg_scale"] for call in calls] == [4.0, 1.0] + assert calls[0]["positive_kwargs"]["text_ids"].item() == 2 + assert calls[0]["negative_kwargs"]["text_ids"].item() == 1 + torch.testing.assert_close(result, torch.full_like(latents, 2.0)) + + def test_diffuse_i2v_masks_conditioned_frame_and_reinjects_image_latent(self, make_cosmos3_pipeline) -> None: + pipeline = 
make_cosmos3_pipeline() + latents = torch.zeros(1, 2, 2, 1, 1) + velocity_mask = torch.tensor([[[[[0.0]], [[1.0]]]]]) + image_latent = torch.full((1, 2, 1, 1, 1), 7.0) + + result = pipeline.diffuse( + latents=latents, + timesteps=torch.tensor([7]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={"video_shape": (2, 1, 1), "fps": 24.0}, + velocity_mask=velocity_mask, + image_latent=image_latent, + ) + + torch.testing.assert_close(result[:, :, 0:1], image_latent) + torch.testing.assert_close(result[:, :, 1:2], torch.full((1, 2, 1, 1, 1), 2.0)) + + def test_diffuse_with_sound_steps_video_and_sound_jointly(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + latents = torch.zeros(1, 2, 1, 1, 1) + sound_latents = torch.zeros(1, 3, 2) + + video_result, sound_result = pipeline.diffuse( + latents=latents, + sound_latents=sound_latents, + timesteps=torch.tensor([7, 3]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={"video_shape": (1, 1, 1), "fps": 24.0}, + ) + + torch.testing.assert_close(video_result, torch.full_like(latents, 4.0)) + torch.testing.assert_close(sound_result, torch.full_like(sound_latents, 24.0)) + assert pipeline.scheduler.step_calls[0][0].shape == (1, latents.numel() + sound_latents.numel()) + + def test_diffuse_with_action_steps_video_and_action_jointly(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + latents = torch.zeros(1, 2, 1, 1, 1) + action_latents = torch.zeros(1, 3, 4) + + video_result, action_result = pipeline.diffuse( + latents=latents, + action_latents=action_latents, + action_velocity_mask=torch.ones(1, 3, 1), + 
action_condition_latents=torch.zeros(1, 3, 4), + timesteps=torch.tensor([7, 3]), + cond_ids=_ids(2), + cond_mask=_mask(), + uncond_ids=_ids(1), + uncond_mask=_mask(), + guidance_scale=1.0, + shared_kwargs={ + "video_shape": (1, 1, 1), + "fps": 24.0, + "action_domain_ids": torch.tensor([0]), + "action_noisy_mask": torch.ones(1, 3, 1), + }, + ) + + torch.testing.assert_close(video_result, torch.full_like(latents, 4.0)) + torch.testing.assert_close(action_result, torch.full_like(action_latents, 44.0)) + assert pipeline.scheduler.step_calls[0][0].shape == (1, latents.numel() + action_latents.numel()) + + +class TestForwardRouting: + def _install_forward_stubs(self, pipeline): + captured: dict[str, object] = {"diffuse_calls": [], "prepare_calls": []} + + def fake_format( + prompt, + negative_prompt, + num_frames, + frame_rate, + height, + width, + max_sequence_length, + sp, + use_system_prompt=False, + is_t2i=False, + ): + captured["format"] = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "frame_rate": frame_rate, + "height": height, + "width": width, + "max_sequence_length": max_sequence_length, + "use_system_prompt": use_system_prompt, + "is_t2i": is_t2i, + "sp": sp, + } + return _ids(2), _mask(), _ids(1), _mask() + + def fake_prepare(height, width, num_frames, generator): + captured["prepare_calls"].append((height, width, num_frames, generator.initial_seed())) + return torch.zeros(1, 2, 1, 1, 1) + + def fake_set_flow_shift(target): + captured.setdefault("flow_shifts", []).append(target) + pipeline._current_flow_shift = target + + def fake_set_scheduler_timesteps(num_inference_steps): + captured.setdefault("scheduler_steps", []).append(num_inference_steps) + pipeline.scheduler.timesteps = torch.tensor([7]) + + def fake_diffuse(**kwargs): + captured["diffuse_calls"].append(kwargs) + outputs = [kwargs["latents"] + len(captured["diffuse_calls"])] + if kwargs.get("action_latents") is not None: + 
outputs.append(kwargs["action_latents"] + 3.0) + if kwargs.get("sound_latents") is not None: + outputs.append(kwargs["sound_latents"] + 2.0) + return outputs[0] if len(outputs) == 1 else tuple(outputs) + + pipeline._format_and_tokenize_prompts = fake_format # type: ignore[method-assign] + pipeline._prepare_latents = fake_prepare # type: ignore[method-assign] + pipeline._set_flow_shift = fake_set_flow_shift # type: ignore[method-assign] + pipeline._set_scheduler_timesteps = fake_set_scheduler_timesteps # type: ignore[method-assign] + pipeline.diffuse = fake_diffuse # type: ignore[method-assign] + pipeline._decode_latents = lambda latents: latents # type: ignore[method-assign] + return captured + + def _install_sound_stubs(self, pipeline): + sound_latents = torch.zeros(1, 3, 4) + decoded_audio = torch.ones(1, 2, 20) + + def fake_resolve_sound_target_samples(sp, num_frames, frame_rate): + del sp, num_frames, frame_rate + return 20, 2.0, 10 + + def fake_prepare_sound_latents(target_samples, generator): + del target_samples, generator + return sound_latents, 4 + + pipeline._resolve_sound_target_samples = fake_resolve_sound_target_samples # type: ignore[method-assign] + pipeline._prepare_sound_latents = fake_prepare_sound_latents # type: ignore[method-assign] + pipeline._decode_sound_latents = lambda latents, target_samples: decoded_audio # type: ignore[method-assign] + return sound_latents, decoded_audio + + def test_forward_uses_t2i_defaults_and_generates_multiple_outputs(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + req = SimpleNamespace( + prompts=[{"prompt": "A painted robot", "modalities": ["image"]}], + sampling_params=make_sampling_params(num_outputs_per_prompt=2), + ) + + output = pipeline.forward(req) + + assert captured["flow_shifts"] == [3.0] + assert captured["scheduler_steps"] == [50, 50] + assert captured["format"]["is_t2i"] is True + assert captured["format"]["height"] == 
1024 + assert captured["format"]["width"] == 1024 + assert captured["format"]["num_frames"] == 1 + assert len(captured["diffuse_calls"]) == 2 + assert captured["diffuse_calls"][0]["guidance_interval"] == (400.0, 1000.0) + assert output.output["image"].shape[0] == 2 + + def test_forward_uses_t2v_defaults_and_engine_flow_shift(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + req = SimpleNamespace( + prompts=[{"prompt": "A warehouse robot", "modalities": ["video"]}], + sampling_params=make_sampling_params(), + ) + + pipeline.forward(req) + + assert captured["flow_shifts"] == [1.0] + assert captured["scheduler_steps"] == [35] + assert captured["format"]["is_t2i"] is False + assert captured["format"]["height"] == 720 + assert captured["format"]["width"] == 1280 + assert captured["format"]["num_frames"] == 81 + assert captured["diffuse_calls"][0]["guidance_interval"] is None + + def test_forward_defaults_to_video_without_modalities(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + req = SimpleNamespace( + prompts=["A warehouse robot"], + sampling_params=make_sampling_params(), + ) + + output = pipeline.forward(req) + + assert captured["format"]["is_t2i"] is False + assert "video" in output.output + + def test_forward_selects_i2v_latents_for_image_conditioning(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + captured = self._install_forward_stubs(pipeline) + image_tensor = torch.zeros(1, 3, 16, 16) + velocity_mask = torch.tensor([[[[[0.0]], [[1.0]]]]]) + image_latent = torch.full((1, 2, 1, 1, 1), 5.0) + + def fake_prepare_i2v(image, height, width, num_frames, generator): + captured["i2v_prepare"] = (image, height, width, num_frames, generator.initial_seed()) + return torch.zeros(1, 2, 2, 1, 1), velocity_mask, image_latent + + def fail_prepare(*args, **kwargs): + del args, kwargs + raise 
AssertionError("T2V latent preparation should not run for an I2V request") + + pipeline._prepare_latents = fail_prepare # type: ignore[method-assign] + pipeline._prepare_latents_i2v = fake_prepare_i2v # type: ignore[method-assign] + req = SimpleNamespace( + prompts=[ + { + "prompt": "A robot starts moving.", + "modalities": ["video"], + "negative_prompt": "bad", + "additional_information": {"preprocessed_image": image_tensor}, + } + ], + sampling_params=make_sampling_params(height=16, width=16, num_frames=5), + ) + + pipeline.forward(req) + + prepared_image, prepared_height, prepared_width, prepared_frames, _ = captured["i2v_prepare"] + assert prepared_image is image_tensor + assert prepared_height == 16 + assert prepared_width == 16 + assert prepared_frames == 5 + diffuse_call = captured["diffuse_calls"][0] + assert diffuse_call["velocity_mask"] is velocity_mask + assert diffuse_call["image_latent"] is image_latent + assert diffuse_call["shared_kwargs"]["noisy_frame_mask"] is velocity_mask + + def test_forward_policy_action_returns_custom_output(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + captured = self._install_forward_stubs(pipeline) + image_tensor = torch.zeros(1, 3, 16, 16) + req = SimpleNamespace( + prompts=[ + { + "prompt": "Pick the block.", + "modalities": ["video"], + "additional_information": {"preprocessed_image": image_tensor}, + } + ], + sampling_params=make_sampling_params( + height=16, + width=16, + extra_args={ + "action_mode": "policy", + "action_chunk_size": 2, + "raw_action_dim": 2, + "domain_name": "bridge_orig_lerobot", + }, + ), + ) + + output = pipeline.forward(req) + + diffuse_call = captured["diffuse_calls"][0] + assert diffuse_call["action_latents"].shape == (1, 2, 4) + assert diffuse_call["action_velocity_mask"].tolist() == [[[1.0], [1.0]]] + assert 
diffuse_call["shared_kwargs"]["action_domain_ids"].tolist() == [7] + assert diffuse_call["shared_kwargs"]["action_start_frame_offset"] == 1 + assert output.custom_output["action"].shape == (1, 2, 2) + assert output.custom_output["raw_action_dim"] == 2 + assert output.custom_output["action_mode"] == "policy" + assert output.custom_output["domain_id"] == 7 + + def test_forward_action_defaults_to_reference_chunk_size(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + captured = self._install_forward_stubs(pipeline) + req = SimpleNamespace( + prompts=[ + { + "prompt": "Pick the block.", + "modalities": ["video"], + "additional_information": {"preprocessed_image": torch.zeros(1, 3, 16, 16)}, + } + ], + sampling_params=make_sampling_params( + height=16, + width=16, + extra_args={ + "action_mode": "policy", + "raw_action_dim": 2, + "domain_id": 0, + }, + ), + ) + + pipeline.forward(req) + + assert captured["format"]["num_frames"] == 17 + assert captured["diffuse_calls"][0]["action_latents"].shape == (1, 16, 4) + + def test_forward_video_sound_decodes_and_returns_audio_payload(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + captured = self._install_forward_stubs(pipeline) + sound_latents = torch.zeros(1, 3, 4) + decoded_audio = torch.ones(1, 2, 20) + + def fake_resolve_sound_target_samples(sp, num_frames, frame_rate): + del sp, num_frames, frame_rate + return 20, 2.0, 10 + + def fake_prepare_sound_latents(target_samples, generator): + del target_samples, generator + return sound_latents, 4 + + pipeline._resolve_sound_target_samples = fake_resolve_sound_target_samples # type: ignore[method-assign] + pipeline._prepare_sound_latents = fake_prepare_sound_latents # type: ignore[method-assign] + 
pipeline._decode_sound_latents = lambda latents, target_samples: decoded_audio # type: ignore[method-assign] + + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"], "generate_sound": True}], + sampling_params=make_sampling_params(num_frames=9, frame_rate=3.0), + ) + + output = pipeline.forward(req) + + assert captured["diffuse_calls"][0]["sound_latents"] is sound_latents + assert output.output["audio"] is decoded_audio + assert output.output["audio_sample_rate"] == 10 + assert "video" in output.output + + def test_forward_decode_info_logs_only_on_rank_zero( + self, + make_cosmos3_pipeline, + monkeypatch: pytest.MonkeyPatch, + caplog, + ) -> None: + from vllm_omni.diffusion.models.cosmos3 import pipeline_cosmos3 as cosmos3_pipeline + + monkeypatch.setattr(cosmos3_pipeline, "_is_rank_zero", lambda: True) + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + self._install_forward_stubs(pipeline) + self._install_sound_stubs(pipeline) + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"], "generate_sound": True}], + sampling_params=make_sampling_params(num_frames=9, frame_rate=3.0), + ) + + target_logger = logging.getLogger(cosmos3_pipeline.logger.name) + target_logger.addHandler(caplog.handler) + prev_level = target_logger.level + target_logger.setLevel(logging.INFO) + try: + pipeline.forward(req) + finally: + target_logger.removeHandler(caplog.handler) + target_logger.setLevel(prev_level) + + messages = [record.getMessage() for record in caplog.records if record.name == cosmos3_pipeline.logger.name] + assert "Decoding video..." in messages + assert any(message.startswith("Video decoded in ") for message in messages) + assert any(message.startswith("Total pipeline time: ") for message in messages) + assert "Decoding sound..." 
in messages + + def test_forward_decode_info_logs_suppressed_on_nonzero_rank( + self, + make_cosmos3_pipeline, + monkeypatch: pytest.MonkeyPatch, + caplog, + ) -> None: + from vllm_omni.diffusion.models.cosmos3 import pipeline_cosmos3 as cosmos3_pipeline + + monkeypatch.setattr(cosmos3_pipeline, "_is_rank_zero", lambda: False) + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + self._install_forward_stubs(pipeline) + _, decoded_audio = self._install_sound_stubs(pipeline) + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"], "generate_sound": True}], + sampling_params=make_sampling_params(num_frames=9, frame_rate=3.0), + ) + + target_logger = logging.getLogger(cosmos3_pipeline.logger.name) + target_logger.addHandler(caplog.handler) + prev_level = target_logger.level + target_logger.setLevel(logging.INFO) + try: + output = pipeline.forward(req) + finally: + target_logger.removeHandler(caplog.handler) + target_logger.setLevel(prev_level) + + messages = [record.getMessage() for record in caplog.records if record.name == cosmos3_pipeline.logger.name] + assert output.output["audio"] is decoded_audio + assert not any( + message == "Decoding video..." + or message.startswith("Video decoded in ") + or message.startswith("Total pipeline time: ") + or message == "Decoding sound..." 
+ for message in messages + ) + + def test_forward_rejects_multiple_prompts(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + req = SimpleNamespace( + prompts=["one", "two"], + sampling_params=make_sampling_params(), + ) + + with pytest.raises(ValueError, match="currently supports a single prompt"): + pipeline.forward(req) + + def test_forward_rejects_conflicting_modalities(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + req = SimpleNamespace( + prompts=[{"prompt": "one", "modalities": ["image", "video"]}], + sampling_params=make_sampling_params(), + ) + + with pytest.raises(ValueError, match="cannot request both image and video"): + pipeline.forward(req) + + def test_forward_rejects_sound_for_text_to_image(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, sound_gen=True, sound_dim=3) + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["image"], "generate_sound": True}], + sampling_params=make_sampling_params(), + ) + + with pytest.raises(ValueError, match="only for video outputs"): + pipeline.forward(req) + + def test_forward_rejects_action_without_action_modules(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"]}], + sampling_params=make_sampling_params(extra_args={"action_mode": "policy", "raw_action_dim": 2}), + ) + + with pytest.raises(ValueError, match="without action modules"): + pipeline.forward(req) + + def test_forward_rejects_action_without_explicit_domain(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__(latent_channel_size=2, action_gen=True, action_dim=4) + req = SimpleNamespace( + prompts=[ + { + "prompt": "A robot", + "modalities": ["video"], + "additional_information": 
{"preprocessed_image": torch.zeros(1, 3, 16, 16)}, + } + ], + sampling_params=make_sampling_params( + height=16, + width=16, + extra_args={"action_mode": "policy", "raw_action_dim": 2}, + ), + ) + + with pytest.raises(ValueError, match=r"domain_id.*domain_name"): + pipeline.forward(req) + + def test_forward_rejects_action_with_sound(self, make_cosmos3_pipeline) -> None: + pipeline = make_cosmos3_pipeline() + pipeline.transformer = pipeline.transformer.__class__( + latent_channel_size=2, + action_gen=True, + action_dim=4, + sound_gen=True, + sound_dim=3, + ) + req = SimpleNamespace( + prompts=[{"prompt": "A robot", "modalities": ["video"], "generate_sound": True}], + sampling_params=make_sampling_params(extra_args={"action_mode": "policy", "raw_action_dim": 2}), + ) + + with pytest.raises(ValueError, match=r"action\+sound"): + pipeline.forward(req) diff --git a/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py new file mode 100644 index 00000000000..49b5821347c --- /dev/null +++ b/tests/diffusion/models/cosmos3/test_cosmos3_transformer.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +def test_compute_mrope_position_ids_text_offsets_all_axes() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_text, + ) + + ids, next_offset = compute_mrope_position_ids_text(num_tokens=3, temporal_offset=5) + + assert ids.tolist() == [[5, 6, 7], [5, 6, 7], [5, 6, 7]] + assert next_offset == 8 + + +def test_compute_mrope_position_ids_vision_without_fps_modulation() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + 
compute_mrope_position_ids_vision, + ) + + ids, next_offset = compute_mrope_position_ids_vision( + grid_t=2, + grid_h=2, + grid_w=3, + temporal_offset=10, + fps=None, + ) + + assert ids.shape == (3, 12) + assert ids[0].tolist() == [10] * 6 + [11] * 6 + assert ids[1].tolist() == [0, 0, 0, 1, 1, 1] * 2 + assert ids[2].tolist() == [0, 1, 2, 0, 1, 2] * 2 + assert next_offset == 12 + + +def test_compute_mrope_position_ids_vision_with_fps_modulation() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_vision, + ) + + ids, next_offset = compute_mrope_position_ids_vision( + grid_t=2, + grid_h=1, + grid_w=1, + temporal_offset=10, + fps=12.0, + base_fps=24.0, + temporal_compression_factor=4, + ) + + torch.testing.assert_close(ids[0], torch.tensor([10.0, 12.0])) + assert ids.dtype == torch.float32 + assert next_offset == 13 + + +def test_compute_mrope_position_ids_sound_uses_sound_latent_fps() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_sound, + ) + + ids, next_offset = compute_mrope_position_ids_sound( + grid_t=3, + temporal_offset=10, + sound_latent_fps=24.0, + base_fps=24.0, + base_temporal_compression_factor=4, + ) + + torch.testing.assert_close(ids[0], torch.tensor([10.0, 10.25, 10.5])) + assert ids[1].tolist() == [0.0, 0.0, 0.0] + assert ids[2].tolist() == [0.0, 0.0, 0.0] + assert next_offset == 11 + + +def test_compute_mrope_position_ids_action_uses_start_frame_offset() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import ( + compute_mrope_position_ids_action, + ) + + ids, next_offset = compute_mrope_position_ids_action( + grid_t=3, + temporal_offset=10, + action_fps=None, + start_frame_offset=1, + ) + + assert ids.tolist() == [[11, 12, 13], [0, 0, 0], [0, 0, 0]] + assert next_offset == 14 + + +@pytest.mark.parametrize( + ("key", "value"), + [ + ("qk_norm_for_diffusion", False), + ("qk_norm_for_text", False), + 
("position_embedding_type", "rotary"), + ("unified_3d_mrope_reset_spatial_ids", False), + ("joint_attn_implementation", "one_way"), + ], +) +def test_validate_supported_config_rejects_unsupported_flags(key: str, value) -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + with pytest.raises(ValueError, match=f"{key}="): + Cosmos3VFMTransformer._validate_supported_config({key: value}) + + +def test_validate_supported_config_accepts_defaults() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + Cosmos3VFMTransformer._validate_supported_config({}) + Cosmos3VFMTransformer._validate_supported_config(None) + + +def test_cosmos3_hsdp_conditions_match_und_and_gen_blocks() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = nn.Module() + model.language_model.layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(2)]) + model.gen_layers = nn.ModuleList([nn.Linear(2, 2)]) + model.norm_moe_gen = nn.LayerNorm(2) + + conditions = model._hsdp_shard_conditions + matched = [ + name for name, module in model.named_modules() if any(condition(name, module) for condition in conditions) + ] + + assert matched == [ + "language_model.layers.0", + "language_model.layers.1", + "gen_layers.0", + ] + + +def test_cosmos3_transformer_exposes_layerwise_offload_and_repeated_blocks() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + assert Cosmos3VFMTransformer._layerwise_offload_blocks_attr == "gen_layers" + assert Cosmos3VFMTransformer._repeated_blocks == ["Cosmos3GenDecoderLayer"] + + +def test_patchify_unpatchify_round_trip_crops_padding() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + 
nn.Module.__init__(model) + model.latent_patch_size = 2 + model.latent_channel_size = 3 + + latents = torch.arange(1 * 3 * 1 * 3 * 5, dtype=torch.float32).reshape(1, 3, 1, 3, 5) + + tokens = model.patchify(latents, t=1, h=3, w=5) + restored = model.unpatchify(tokens, t=1, h=3, w=5) + + assert tokens.shape == (1, 6, 12) + torch.testing.assert_close(restored, latents) + + +def _tiny_cosmos3_config(**overrides): + config = { + "hidden_size": 8, + "num_hidden_layers": 0, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 4, + "intermediate_size": 16, + "vocab_size": 32, + "latent_patch_size": 1, + "latent_channel": 2, + "rope_scaling": {"mrope_section": [1, 1, 0]}, + } + config.update(overrides) + return config + + +def test_sound_modules_created_only_when_sound_config_present() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + tiny = _tiny_cosmos3_config() + + no_sound = Cosmos3VFMTransformer(SimpleNamespace(tf_model_config=tiny, dtype=torch.float32)) + explicit_disabled = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "sound_gen": False, "sound_dim": 3}, + dtype=torch.float32, + ) + ) + with_sound = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "sound_gen": True, "sound_dim": 3}, + dtype=torch.float32, + ) + ) + + assert no_sound.sound_gen is False + assert not hasattr(no_sound, "sound2llm") + assert explicit_disabled.sound_gen is False + assert not hasattr(explicit_disabled, "sound2llm") + assert with_sound.sound_gen is True + assert with_sound.sound2llm.in_features == 3 + assert with_sound.llm2sound.out_features == 3 + assert tuple(with_sound.sound_modality_embed.shape) == (8,) + + +def test_action_modules_created_only_when_action_config_present() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + tiny = _tiny_cosmos3_config() + + no_action = 
Cosmos3VFMTransformer(SimpleNamespace(tf_model_config=tiny, dtype=torch.float32)) + explicit_disabled = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "action_gen": False, "max_action_dim": 6}, + dtype=torch.float32, + ) + ) + with_action = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config={**tiny, "action_gen": True, "max_action_dim": 6, "num_embodiment_domains": 9}, + dtype=torch.float32, + ) + ) + + assert no_action.action_gen is False + assert not hasattr(no_action, "action2llm") + assert explicit_disabled.action_gen is False + assert not hasattr(explicit_disabled, "action2llm") + assert with_action.action_gen is True + assert with_action.action_dim == 6 + assert with_action.action2llm.num_domains == 9 + assert tuple(with_action.action_modality_embed.shape) == (8,) + + +def test_sound_latent_fps_derives_from_sound_tokenizer_config() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + tiny = _tiny_cosmos3_config(sound_gen=True, sound_dim=3) + + derived = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config=tiny, + custom_pipeline_args={"sound_sample_rate": 32000, "sound_hop_size": 800}, + dtype=torch.float32, + ) + ) + explicit = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config=tiny, + custom_pipeline_args={ + "sound_sample_rate": 32000, + "sound_hop_size": 800, + "sound_latent_fps": 12.5, + }, + dtype=torch.float32, + ) + ) + + assert derived.sound_latent_fps == 40.0 + assert explicit.sound_latent_fps == 12.5 + + +def test_pack_unpack_sound_round_trip_and_shape_validation() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.sound_dim = 3 + + latents = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4) + tokens = model.pack_sound(latents) + restored = model.unpack_sound(tokens) + + assert tokens.shape == (2, 4, 3) + 
torch.testing.assert_close(restored, latents) + with pytest.raises(ValueError, match="channel mismatch"): + model.pack_sound(torch.zeros(1, 4, 2)) + + +def test_pack_unpack_action_round_trip_and_shape_validation() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.action_dim = 3 + + latents = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3) + tokens = model.pack_action(latents) + restored = model.unpack_action(tokens) + + assert tokens.shape == (2, 4, 3) + torch.testing.assert_close(restored, latents) + with pytest.raises(ValueError, match="dimension mismatch"): + model.pack_action(torch.zeros(1, 2, 4)) + + +def test_forward_with_sound_returns_video_and_sound_predictions() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config=_tiny_cosmos3_config(sound_gen=True, sound_dim=3, sound_latent_fps=24.0), + dtype=torch.float32, + ) + ) + + video = torch.zeros(1, 2, 1, 2, 2) + sound = torch.zeros(1, 3, 4) + output = model( + hidden_states=video, + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 2, 2), + fps=24.0, + sound_latents=sound, + ) + + assert isinstance(output, tuple) + video_pred, sound_pred = output + assert video_pred.shape == video.shape + assert sound_pred.shape == sound.shape + + +def test_forward_with_action_returns_video_and_action_predictions() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config=_tiny_cosmos3_config( + action_gen=True, + max_action_dim=3, + num_embodiment_domains=4, + ), + dtype=torch.float32, + ) + ) + + video = torch.zeros(1, 2, 1, 2, 2) + action = 
torch.zeros(1, 5, 3) + output = model( + hidden_states=video, + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 2, 2), + fps=24.0, + action_latents=action, + action_domain_ids=torch.tensor([2]), + action_noisy_mask=torch.ones(1, 5, 1), + ) + + assert isinstance(output, tuple) + video_pred, action_pred = output + assert video_pred.shape == video.shape + assert action_pred.shape == action.shape + + +def test_forward_with_sound_ulysses_error_mentions_combined_sequence(monkeypatch: pytest.MonkeyPatch) -> None: + import vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 as cosmos3_module + + model = cosmos3_module.Cosmos3VFMTransformer( + SimpleNamespace( + tf_model_config=_tiny_cosmos3_config(sound_gen=True, sound_dim=3), + dtype=torch.float32, + ) + ) + monkeypatch.setattr(cosmos3_module, "_get_ulysses_state", lambda: (2, 0, None)) + + with pytest.raises( + ValueError, + match=r"GEN sequence length \(3 = video tokens 2 \+ sound tokens 1\).*combined media sequence", + ): + model( + hidden_states=torch.zeros(1, 2, 1, 1, 2), + timestep=torch.tensor([1.0]), + text_ids=torch.tensor([[1, 2]], dtype=torch.long), + text_mask=torch.ones(1, 2, dtype=torch.long), + video_shape=(1, 1, 2), + fps=24.0, + sound_latents=torch.zeros(1, 3, 1), + ) + + +def test_reset_cache_clears_und_and_gen_cache() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.cached_kv = object() + model.cached_freqs_gen = object() + + model.reset_cache() + + assert model.cached_kv is None + assert model.cached_freqs_gen is None + + +def test_compute_rope_freqs_pads_text_and_offsets_vision_positions() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + class FakeRotary: + def __init__(self) -> None: + self.position_ids: 
list[torch.Tensor] = [] + + def __call__(self, x, position_ids): + del x + self.position_ids.append(position_ids.detach().cpu()) + batch = position_ids.shape[1] + seq = position_ids.shape[2] + return torch.zeros(batch, seq, 4), torch.ones(batch, seq, 4) + + rotary = FakeRotary() + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = SimpleNamespace(rotary_emb=rotary) + model.temporal_modality_margin = 100 + model.base_fps = 24.0 + model.temporal_compression_factor = 4 + model.enable_fps_modulation = False + + freqs_und, freqs_gen = model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.long), + t=2, + hp=1, + wp=1, + fps=24.0, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + text_pos, vision_pos = rotary.position_ids + assert text_pos[:, 0, :].tolist() == [[0, 1, 0], [0, 1, 0], [0, 1, 0]] + assert text_pos[:, 1, :].tolist() == [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + assert vision_pos[0, 0].tolist() == [102, 103] + assert vision_pos[0, 1].tolist() == [101, 102] + assert freqs_und[0].shape == (2, 3, 1, 4) + assert freqs_gen[0].shape == (2, 2, 1, 4) + + +def test_compute_rope_freqs_appends_sound_positions_after_vision() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + class FakeRotary: + def __init__(self) -> None: + self.position_ids: list[torch.Tensor] = [] + + def __call__(self, x, position_ids): + del x + self.position_ids.append(position_ids.detach().cpu()) + batch = position_ids.shape[1] + seq = position_ids.shape[2] + return torch.zeros(batch, seq, 4), torch.ones(batch, seq, 4) + + rotary = FakeRotary() + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = SimpleNamespace(rotary_emb=rotary) + model.temporal_modality_margin = 100 + model.base_fps = 24.0 + model.temporal_compression_factor = 4 + model.enable_fps_modulation = True + model.sound_latent_fps = 24.0 + + 
model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1]], dtype=torch.long), + t=2, + hp=1, + wp=1, + fps=24.0, + device=torch.device("cpu"), + dtype=torch.float32, + t_sound=3, + ) + + _, gen_pos = rotary.position_ids + assert gen_pos.shape == (3, 1, 5) + torch.testing.assert_close( + gen_pos[0, 0], + torch.tensor([102.0, 103.0, 102.0, 102.25, 102.5]), + ) + + +def test_compute_rope_freqs_appends_action_positions_between_vision_and_sound() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + class FakeRotary: + def __init__(self) -> None: + self.position_ids: list[torch.Tensor] = [] + + def __call__(self, x, position_ids): + del x + self.position_ids.append(position_ids.detach().cpu()) + batch = position_ids.shape[1] + seq = position_ids.shape[2] + return torch.zeros(batch, seq, 4), torch.ones(batch, seq, 4) + + rotary = FakeRotary() + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = SimpleNamespace(rotary_emb=rotary) + model.temporal_modality_margin = 100 + model.base_fps = 24.0 + model.temporal_compression_factor = 4 + model.enable_fps_modulation = False + model.sound_latent_fps = 24.0 + + model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1]], dtype=torch.long), + t=2, + hp=1, + wp=1, + fps=24.0, + device=torch.device("cpu"), + dtype=torch.float32, + t_action=2, + action_start_frame_offset=1, + t_sound=1, + ) + + _, gen_pos = rotary.position_ids + assert gen_pos.shape == (3, 1, 5) + assert gen_pos[0, 0].tolist() == [102, 103, 103, 104, 102] + + +def test_compute_rope_freqs_promotes_mixed_video_sound_position_dtypes() -> None: + from vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 import Cosmos3VFMTransformer + + class FakeRotary: + def __init__(self) -> None: + self.position_ids: list[torch.Tensor] = [] + + def __call__(self, x, position_ids): + del x + self.position_ids.append(position_ids.detach().cpu()) + batch = position_ids.shape[1] + seq = 
position_ids.shape[2] + return torch.zeros(batch, seq, 4), torch.ones(batch, seq, 4) + + rotary = FakeRotary() + model = object.__new__(Cosmos3VFMTransformer) + nn.Module.__init__(model) + model.language_model = SimpleNamespace(rotary_emb=rotary) + model.temporal_modality_margin = 100 + model.base_fps = 24.0 + model.temporal_compression_factor = 4 + model.enable_fps_modulation = True + model.sound_latent_fps = 24.0 + + model._compute_rope_freqs( + text_mask=torch.tensor([[1, 1]], dtype=torch.long), + t=1, + hp=1, + wp=1, + fps=None, + device=torch.device("cpu"), + dtype=torch.float32, + t_sound=3, + ) + + _, gen_pos = rotary.position_ids + assert gen_pos.dtype == torch.float32 + torch.testing.assert_close( + gen_pos[0, 0], + torch.tensor([102.0, 102.0, 102.25, 102.5]), + ) diff --git a/tests/diffusion/test_diffusion_ipc.py b/tests/diffusion/test_diffusion_ipc.py new file mode 100644 index 00000000000..43e96b834f6 --- /dev/null +++ b/tests/diffusion/test_diffusion_ipc.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.ipc import pack_diffusion_output_shm, unpack_diffusion_output_shm + + +def test_diffusion_output_dict_tensors_round_trip_through_shm() -> None: + image = torch.arange(300_000, dtype=torch.float32) + video = torch.arange(300_000, dtype=torch.float32) * 2 + output = DiffusionOutput(output={"image": image, "video": video, "metadata": {"keep": "inline"}}) + + pack_diffusion_output_shm(output) + + assert output.output["image"]["__tensor_shm__"] is True + assert output.output["video"]["__tensor_shm__"] is True + assert output.output["metadata"] == {"keep": "inline"} + + unpack_diffusion_output_shm(output) + + torch.testing.assert_close(output.output["image"], image) + torch.testing.assert_close(output.output["video"], video) + assert output.output["metadata"] == {"keep": "inline"} diff 
def _model_name() -> str:
    """Return the model configured through ``MODEL_ENV_VAR``.

    The calling test is skipped when the environment variable is unset or
    empty.
    """
    configured = os.environ.get(MODEL_ENV_VAR)
    if configured:
        return configured
    pytest.skip(f"Set {MODEL_ENV_VAR} to run Cosmos3 full-model smoke tests.")


def _server_args() -> list[str]:
    """Return the CLI arguments passed to ``OmniServer`` by the smoke tests."""
    options = (
        ("--num-gpus", "1"),
        ("--model-class-name", "Cosmos3OmniDiffusersPipeline"),
        ("--stage-init-timeout", "900"),
        ("--init-timeout", "1200"),
    )
    # Flatten the (flag, value) pairs into the flat argv-style list.
    return [token for pair in options for token in pair]


def _image_data_url(image: Image.Image) -> str:
    """Encode ``image`` as a base64 PNG ``data:`` URL."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        payload = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + payload.decode("ascii")
server: + response = requests.post( + f"http://{server.host}:{server.port}/v1/images/generations", + json={ + "model": server.model, + "prompt": PROMPT, + "negative_prompt": NEGATIVE_PROMPT, + "size": f"{WIDTH}x{HEIGHT}", + "n": 1, + "response_format": "b64_json", + "num_inference_steps": NUM_INFERENCE_STEPS, + "guidance_scale": 1.0, + "seed": SEED, + }, + timeout=1800, + ) + + response.raise_for_status() + payload = response.json() + assert len(payload["data"]) == 1 + image = Image.open(io.BytesIO(base64.b64decode(payload["data"][0]["b64_json"]))).convert("RGB") + image.save(output_dir / "cosmos3_t2i.png") + assert image.size == (WIDTH, HEIGHT) + + +@pytest.mark.benchmark +@hardware_test(res={"cuda": "H100"}, num_cards=1) +def test_cosmos3_t2v_sync_serving_smoke(accuracy_artifact_root: Path) -> None: + if not torch.cuda.is_available(): + pytest.skip("Cosmos3 full-model smoke tests require CUDA.") + + model = _model_name() + output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID) + with OmniServer(model, _server_args(), use_omni=True) as server: + response = requests.post( + f"http://{server.host}:{server.port}/v1/videos/sync", + data={ + "model": server.model, + "prompt": PROMPT, + "negative_prompt": NEGATIVE_PROMPT, + "size": f"{WIDTH}x{HEIGHT}", + "num_frames": "1", + "fps": "1", + "num_inference_steps": str(NUM_INFERENCE_STEPS), + "guidance_scale": "1.0", + "seed": str(SEED), + }, + timeout=1800, + ) + + response.raise_for_status() + assert response.headers["content-type"].startswith("video/mp4") + assert response.content + (output_dir / "cosmos3_t2v.mp4").write_bytes(response.content) + + +@pytest.mark.benchmark +@hardware_test(res={"cuda": "H100"}, num_cards=1) +def test_cosmos3_i2v_sync_serving_smoke(accuracy_artifact_root: Path) -> None: + if not torch.cuda.is_available(): + pytest.skip("Cosmos3 full-model smoke tests require CUDA.") + + model = _model_name() + output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID) + reference = 
Image.new("RGB", (96, 64), color=(40, 80, 160)) + with OmniServer(model, _server_args(), use_omni=True) as server: + response = requests.post( + f"http://{server.host}:{server.port}/v1/videos/sync", + data={ + "model": server.model, + "prompt": "The blue rectangle moves slowly forward.", + "negative_prompt": NEGATIVE_PROMPT, + "image_reference": json.dumps({"image_url": _image_data_url(reference)}), + "size": f"{WIDTH}x{HEIGHT}", + "num_frames": "5", + "fps": "1", + "num_inference_steps": str(NUM_INFERENCE_STEPS), + "guidance_scale": "1.0", + "seed": str(SEED), + }, + timeout=1800, + ) + + response.raise_for_status() + assert response.headers["content-type"].startswith("video/mp4") + assert response.content + (output_dir / "cosmos3_i2v.mp4").write_bytes(response.content) diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 40adb7a9151..cdf6ca3f8a7 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -495,6 +495,7 @@ def test_generate_single_image(test_client): img_bytes = base64.b64decode(data["data"][0]["b64_json"]) img = Image.open(io.BytesIO(img_bytes)) assert img.size == (64, 64) # Our mock returns 64x64 images + assert test_client.app.state.engine_client.captured_prompt["modalities"] == ["image"] def test_generate_images_async_omni_sampling_params(async_omni_test_client): diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index a29f4493c28..57a09872397 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -13,6 +13,7 @@ import time from types import SimpleNamespace +import numpy as np import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -243,6 +244,7 @@ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs): _wait_for_status(test_client, 
video_id, VideoGenerationStatus.COMPLETED.value) engine = test_client.app.state.openai_serving_video._engine_client + assert engine.captured_prompt["modalities"] == ["video"] captured = engine.captured_sampling_params_list[0] assert captured.num_outputs_per_prompt == 1 assert captured.width == 640 @@ -398,6 +400,8 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture): "true_cfg_scale": "4.0", "boundary_ratio": "0.7", "flow_shift": "0.25", + "generate_sound": "true", + "sound_duration": "2.5", }, ) @@ -412,6 +416,8 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture): assert captured.true_cfg_scale == 4.0 assert captured.boundary_ratio == 0.7 assert captured.extra_args["flow_shift"] == 0.25 + assert captured.extra_args["generate_sound"] is True + assert captured.extra_args["sound_duration"] == 2.5 def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_client, mocker: MockerFixture): @@ -622,6 +628,109 @@ async def _generate(prompt, request_id, sampling_params_list): assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3} assert completed["peak_memory_mb"] == 4096.5 + assert completed["action"] is None + + +def test_video_generation_response_exposes_action_payload(mocker: MockerFixture): + engine = FakeAsyncOmni() + handler = OmniOpenAIServingVideo.for_diffusion( + diffusion_engine=engine, + model_name="Cosmos3-8B-UVA", + ) + + async def _generate(prompt, request_id, sampling_params_list): + del prompt, request_id, sampling_params_list + yield MockVideoResult( + [object()], + custom_output={ + "action": np.array([[[1.5, 2.5], [3.5, 4.5]]], dtype=np.float32), + "raw_action_dim": 2, + "action_mode": "policy", + "domain_id": 7, + }, + ) + + engine.generate = _generate + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video.encode_video_base64", + return_value="encoded-video", + ) + + response = asyncio.run( + handler.generate_videos( + VideoGenerationRequest(prompt="predict 
def test_video_job_persists_action_metadata(test_client, mocker: MockerFixture):
    """Check that an engine-emitted action payload is kept on the video job.

    The completed job returned by ``_wait_for_status`` and the first entry of
    the ``/v1/videos`` listing must both expose the same action dictionary.
    """
    serving_engine = test_client.app.state.openai_serving_video._engine_client
    action_fields = {"raw_action_dim": 2, "action_mode": "policy", "domain_id": 7}

    async def _generate(prompt, request_id, sampling_params_list):
        serving_engine.captured_prompt = prompt
        serving_engine.captured_sampling_params_list = sampling_params_list
        batched_action = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)
        yield MockVideoResult(
            [object()],
            custom_output={"action": batched_action, **action_fields},
        )

    serving_engine.generate = _generate
    mocker.patch(
        "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
        return_value=b"fake-video",
    )

    create_response = test_client.post("/v1/videos", data={"prompt": "profile me"})
    assert create_response.status_code == 200
    job_id = create_response.json()["id"]
    finished_job = _wait_for_status(test_client, job_id, VideoGenerationStatus.COMPLETED.value)

    expected_action = {
        "data": [[1.0, 2.0], [3.0, 4.0]],
        "shape": [2, 2],
        "dtype": "float32",
        **action_fields,
    }
    assert finished_job["action"] == expected_action

    listing = test_client.get("/v1/videos").json()
    assert listing["data"][0]["action"] == expected_action
expected_count=1) + + assert actions[0] is not None + assert actions[0].data == [[1.0, 2.0], [3.0, 4.0]] + assert actions[0].shape == [2, 2] def test_missing_handler_returns_503(): @@ -755,6 +864,9 @@ def test_invalid_uploaded_input_reference_returns_400(test_client): def test_video_request_validation(): req = VideoGenerationRequest(prompt="test") assert req.prompt == "test" + assert req.generate_sound is False + assert req.sound_duration is None + assert VideoGenerationRequest(prompt="test", generate_sound=True, sound_duration=1.5).generate_sound is True with pytest.raises(ValueError): VideoGenerationRequest(prompt="test", size="invalid") @@ -767,6 +879,8 @@ def test_video_request_validation(): VideoGenerationRequest(prompt="test", frame_interpolation_exp=0) with pytest.raises(ValueError): VideoGenerationRequest(prompt="test", frame_interpolation_scale=0) + with pytest.raises(ValueError): + VideoGenerationRequest(prompt="test", sound_duration=0) def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerFixture): @@ -1063,6 +1177,8 @@ def test_sync_t2v_returns_video_bytes(test_client, mocker: MockerFixture): assert float(response.headers["x-inference-time-s"]) >= 0 assert json.loads(response.headers["x-stage-durations"]) == {} assert float(response.headers["x-peak-memory-mb"]) == 0.0 + engine = test_client.app.state.openai_serving_video._engine_client + assert engine.captured_prompt["modalities"] == ["video"] def test_sync_t2v_returns_profiler_headers(test_client, mocker: MockerFixture): diff --git a/vllm_omni/diffusion/attention/backends/sdpa.py b/vllm_omni/diffusion/attention/backends/sdpa.py index ab71e753b25..c650313698d 100644 --- a/vllm_omni/diffusion/attention/backends/sdpa.py +++ b/vllm_omni/diffusion/attention/backends/sdpa.py @@ -91,6 +91,8 @@ def __init__( self.softmax_scale = softmax_scale if backend_kwargs: logger.warning("SDPAImpl ignoring backend_kwargs: %s", list(backend_kwargs.keys())) + self.num_heads = num_heads + 
def enable_cache_for_cosmos3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Cosmos3 (T2V and I2V).

    Cosmos3 has a dual-pathway architecture (UND + GEN) but only the GEN
    pathway (``gen_layers``) runs at every denoising step. The UND pathway
    computes once and its K/V are cached by the pipeline itself; no cache-dit
    needed there. We wrap only ``gen_layers`` via ``BlockAdapter``.

    Args:
        pipeline: The Cosmos3 pipeline instance. Its ``transformer`` and
            ``transformer.gen_layers`` attributes are handed to cache-dit.
        cache_config: DiffusionCacheConfig instance with cache configuration.
            This function reads ``enable_taylorseer``, ``taylorseer_order``,
            ``scm_steps_mask_policy`` and ``scm_steps_policy`` from it.

    Returns:
        A refresh function that can be called to update cache context with new num_inference_steps.
        NOTE(review): the returned closure is declared as
        ``(pipeline, num_inference_steps, verbose=True)``, which is wider than the
        ``Callable[[int], None]`` annotation suggests — confirm against callers.
    """
    db_cache_config = _build_db_cache_config(cache_config)

    # The TaylorSeer calibrator is optional; ``None`` is passed through when disabled.
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    logger.info(
        f"Enabling cache-dit on Cosmos3 gen_layers: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    cache_dit.enable_cache(
        BlockAdapter(
            transformer=pipeline.transformer,
            blocks=[pipeline.transformer.gen_layers],
            # Cosmos3 GEN blocks return only hidden_states. Per-layer UND K/V
            # conditioning uses the transformer's cache-dit fallback path.
            forward_pattern=[ForwardPattern.Pattern_3],
            params_modifiers=[
                ParamsModifier(
                    cache_config=db_cache_config,
                    calibrator_config=calibrator_config,
                ),
            ],
            check_forward_pattern=False,
            # NOTE(review): presumably tells cache-dit that the CFG branches run as
            # separate forward passes — confirm against the cache-dit docs.
            has_separate_cfg=True,
        ),
        cache_config=db_cache_config,
        calibrator_config=calibrator_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        # Without a steps-computation-mask policy only the step count is refreshed;
        # otherwise a fresh DBCacheConfig carrying the mask and policy is supplied.
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
def _move_tensor_tree_to_cpu(value: object) -> object:
    """Return ``value`` with every tensor found in it moved to the CPU.

    Dicts, lists and tuples are rebuilt recursively (a tuple comes back as a
    plain ``tuple``). A tensor that is already on the CPU is returned as the
    same object, and any other value is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        if value.device.type == "cpu":
            return value
        return value.cpu()
    if isinstance(value, dict):
        moved = {}
        for name, child in value.items():
            moved[name] = _move_tensor_tree_to_cpu(child)
        return moved
    if isinstance(value, (list, tuple)):
        children = [_move_tensor_tree_to_cpu(child) for child in value]
        return children if isinstance(value, list) else tuple(children)
    return value
def _pack_value_if_large(val: object) -> object:
    """Pack the tensors found in ``val`` through ``_pack_tensor_if_large``.

    Tensors are handed to ``_pack_tensor_if_large``, which returns either the
    tensor itself or an SHM handle dict. Dicts, lists and tuples are rebuilt
    recursively so that tensors nested in any of these containers are packed
    too; this mirrors the container types the diffusion engine walks when it
    moves pipeline outputs to the CPU. Any other value is returned unchanged.
    """
    if isinstance(val, torch.Tensor):
        return _pack_tensor_if_large(val)
    if isinstance(val, dict):
        return {key: _pack_value_if_large(value) for key, value in val.items()}
    if isinstance(val, list):
        return [_pack_value_if_large(item) for item in val]
    if isinstance(val, tuple):
        return tuple(_pack_value_if_large(item) for item in val)
    return val


def _unpack_if_shm_handle(val: object) -> object:
    """Reconstruct a tensor from an SHM handle dict, or return as-is.

    A dict flagged with ``__tensor_shm__`` is treated as a handle and turned
    back into a tensor. Other dicts, lists and tuples are rebuilt recursively,
    which keeps this function the exact inverse of ``_pack_value_if_large``.
    """
    if isinstance(val, dict):
        # The handle check must come first: a handle is itself a dict and must
        # not be descended into.
        if val.get("__tensor_shm__"):
            return _tensor_from_shm(val)
        return {key: _unpack_if_shm_handle(value) for key, value in val.items()}
    if isinstance(val, list):
        return [_unpack_if_shm_handle(item) for item in val]
    if isinstance(val, tuple):
        return tuple(_unpack_if_shm_handle(item) for item in val)
    return val
ACTION_MODE_POLICY = "policy"
ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics"
ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics"
ACTION_MODES = {
    ACTION_MODE_POLICY,
    ACTION_MODE_FORWARD_DYNAMICS,
    ACTION_MODE_INVERSE_DYNAMICS,
}


# Lower-case embodiment/domain name -> integer domain id. Several names map to
# the same id.
EMBODIMENT_TO_DOMAIN_ID: dict[str, int] = {
    "no_action": 0,
    "av": 1,
    "camera_pose": 2,
    "hand_pose": 3,
    "pusht": 4,
    "libero": 5,
    "umi": 6,
    "bridge_orig_lerobot": 7,
    "droid_lerobot": 8,
    "robomind-franka": 8,
    "galbot": 9,
    "robomind-franka-dual": 12,
    "robomind-ur": 13,
    "agibotworld": 15,
    "agibot_gear_gripper": 15,
    "agibot_gear_gripper_ext": 15,
    "fractal": 20,
}


# Resolution key -> aspect-ratio key -> target size. ``find_closest_target_size``
# unpacks each tuple as (width, height).
VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
    "256": {
        "1,1": (256, 256),
        "4,3": (320, 256),
        "3,4": (256, 320),
        "16,9": (320, 192),
        "9,16": (192, 320),
    },
    "480": {
        "1,1": (640, 640),
        "4,3": (736, 544),
        "3,4": (544, 736),
        "16,9": (832, 480),
        "9,16": (480, 832),
    },
    "704": {
        "1,1": (960, 960),
        "4,3": (1088, 832),
        "3,4": (832, 1088),
        "16,9": (1280, 704),
        "9,16": (704, 1280),
    },
    "720": {
        "1,1": (960, 960),
        "4,3": (1104, 832),
        "3,4": (832, 1104),
        "16,9": (1280, 720),
        "9,16": (720, 1280),
    },
}


def normalize_action_mode(mode: Any) -> str | None:
    """Return ``mode`` as a stripped, lower-case action-mode string.

    Returns:
        ``None`` when ``mode`` is ``None`` or blank, otherwise one of
        ``ACTION_MODES``.

    Raises:
        ValueError: If the normalized string is not a supported action mode.
    """
    if mode is None:
        return None
    normalized = str(mode).strip().lower()
    if not normalized:
        return None
    if normalized not in ACTION_MODES:
        raise ValueError(f"Unsupported Cosmos3 action_mode={mode!r}; expected one of {sorted(ACTION_MODES)}.")
    return normalized


def resolve_domain_id(
    *,
    domain_id: Any = None,
    domain_name: Any = None,
    require_explicit: bool = False,
) -> int:
    """Resolve the integer domain id from an explicit id or a domain name.

    An explicit ``domain_id`` wins over ``domain_name``. When neither is given
    the result is ``0`` unless ``require_explicit`` is set.

    Raises:
        ValueError: If ``domain_id`` is a non-integral float or negative, if
            ``domain_name`` is unknown, or if nothing was supplied while
            ``require_explicit`` is true.
    """
    if domain_id is not None:
        # int() would silently truncate 7.9 to 7 and select a different domain.
        if isinstance(domain_id, (float, np.floating)) and not float(domain_id).is_integer():
            raise ValueError(f"Cosmos3 domain_id must be an integer, got {domain_id!r}.")
        resolved = int(domain_id)
        if resolved < 0:
            raise ValueError(f"Cosmos3 domain_id must be non-negative, got {resolved}.")
        return resolved

    if domain_name is None or str(domain_name).strip() == "":
        if require_explicit:
            raise ValueError(
                "Cosmos3 action generation requires extra_args['domain_id'] or non-empty extra_args['domain_name']."
            )
        return 0

    key = str(domain_name).strip().lower()
    if key not in EMBODIMENT_TO_DOMAIN_ID:
        raise ValueError(
            f"Unknown Cosmos3 action domain_name={domain_name!r}; "
            f"expected one of {sorted(EMBODIMENT_TO_DOMAIN_ID)} or pass domain_id directly."
        )
    return EMBODIMENT_TO_DOMAIN_ID[key]


def action_condition_indexes(mode: str, action_length: int) -> list[int]:
    """Return the action-step indexes that are given as conditioning.

    Forward dynamics conditions on every action step; policy and inverse
    dynamics condition on none.
    """
    mode = normalize_action_mode(mode)
    if mode == ACTION_MODE_FORWARD_DYNAMICS:
        return list(range(action_length))
    if mode in {ACTION_MODE_POLICY, ACTION_MODE_INVERSE_DYNAMICS}:
        return []
    raise AssertionError(f"Unexpected action mode: {mode!r}")


def vision_condition_indexes(mode: str, video_length: int, temporal_compression_factor: int) -> list[int]:
    """Return the latent-frame indexes that are given as conditioning.

    Policy and forward dynamics condition on the first latent frame only;
    inverse dynamics conditions on all
    ``(video_length - 1) // temporal_compression_factor + 1`` latent frames.
    """
    mode = normalize_action_mode(mode)
    latent_frames = (video_length - 1) // temporal_compression_factor + 1
    if mode in {ACTION_MODE_POLICY, ACTION_MODE_FORWARD_DYNAMICS}:
        return [0]
    if mode == ACTION_MODE_INVERSE_DYNAMICS:
        return list(range(latent_frames))
    raise AssertionError(f"Unexpected action mode: {mode!r}")


def action_start_frame_offset(mode: str, action_length: int, video_length: int) -> int:
    """Return ``1`` when there is one action fewer than frames, ``0`` when equal.

    ``mode`` is accepted for interface symmetry but not used.

    Raises:
        ValueError: If ``action_length`` is neither ``video_length - 1`` nor
            ``video_length``.
    """
    del mode
    if action_length == video_length - 1:
        return 1
    if action_length == video_length:
        return 0
    raise ValueError(
        "Cosmos3 action_chunk_size must equal num_frames - 1 or num_frames; "
        f"got action_chunk_size={action_length}, num_frames={video_length}."
    )


def build_action_condition_mask(
    mode: str,
    action_length: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return a ``[1, action_length, 1]`` mask with ones at conditioned action steps."""
    mask = torch.zeros(1, action_length, 1, device=device, dtype=dtype)
    conditioned = action_condition_indexes(mode, action_length)
    if conditioned:
        mask[:, conditioned, :] = 1.0
    return mask


def build_vision_condition_mask(
    mode: str,
    video_length: int,
    temporal_compression_factor: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return a ``[1, 1, latent_frames, 1, 1]`` mask with ones at conditioned latent frames."""
    latent_frames = (video_length - 1) // temporal_compression_factor + 1
    mask = torch.zeros(1, 1, latent_frames, 1, 1, device=device, dtype=dtype)
    conditioned = vision_condition_indexes(mode, video_length, temporal_compression_factor)
    if conditioned:
        mask[:, :, conditioned, :, :] = 1.0
    return mask


def pad_action_to_dim(action: torch.Tensor, action_dim: int) -> torch.Tensor:
    """Zero-pad the last dimension of ``action`` up to ``action_dim``.

    The input is returned as-is when it already has ``action_dim`` features.

    Raises:
        ValueError: If the last dimension is larger than ``action_dim``.
    """
    current_dim = action.shape[-1]
    if current_dim > action_dim:
        raise ValueError(f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}.")
    if current_dim == action_dim:
        return action
    padding = action.new_zeros((*action.shape[:-1], action_dim - current_dim))
    return torch.cat([action, padding], dim=-1)


def load_action_tensor(action: Any = None, action_path: str | Path | None = None) -> torch.Tensor:
    """Return a float32 ``[T, D]`` action tensor from a value or a JSON file.

    ``action`` takes precedence; ``action_path`` is only read when ``action``
    is ``None``. A leading batch dimension of size one is squeezed away.

    Raises:
        ValueError: If neither argument is given or the result is not 2-D.
    """
    if action is None and action_path is None:
        raise ValueError(
            "Cosmos3 forward_dynamics action mode requires extra_args['action'] or extra_args['action_path']."
        )
    if action is None:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        action = json.loads(Path(str(action_path)).read_text(encoding="utf-8"))
    if isinstance(action, torch.Tensor):
        tensor = action.detach().to(dtype=torch.float32)
    else:
        tensor = torch.as_tensor(np.asarray(action), dtype=torch.float32)
    if tensor.ndim == 3 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    if tensor.ndim != 2:
        raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.")
    return tensor


def find_closest_target_size(h: int, w: int, resolution: str | int) -> tuple[int, int]:
    """Return the table size whose height/width ratio is closest to ``h / w``.

    The result is the ``(width, height)`` tuple taken from
    ``VIDEO_RES_SIZE_INFO[str(resolution)]``; on a tie the first table entry
    wins.

    Raises:
        ValueError: If ``resolution`` is unknown or ``h``/``w`` is not positive.
    """
    key = str(resolution)
    if key not in VIDEO_RES_SIZE_INFO:
        raise ValueError(
            f"Unknown Cosmos3 action resolution={resolution!r}; expected one of {sorted(VIDEO_RES_SIZE_INFO)}."
        )
    if h <= 0 or w <= 0:
        raise ValueError(f"Cosmos3 action input size must be positive, got h={h}, w={w}.")
    input_ratio = h / w
    # min() keeps the first candidate on ties, matching a strict "<" scan.
    return min(
        VIDEO_RES_SIZE_INFO[key].values(),
        key=lambda size: abs(input_ratio - size[1] / size[0]),
    )
# https://github.com/jaywalnut310/vits/blob/main/commons.py
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(
    input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int]
) -> torch.Tensor:
    """Gated activation: tanh of the first channels times sigmoid of the rest.

    The two inputs are summed, then split along dim 1 at ``n_channels[0]``.
    """
    split_at = n_channels[0]
    summed = input_a + input_b  # [B,2*C,T]
    tanh_part = torch.tanh(summed[:, :split_at, :])  # [B,C,T]
    gate_part = torch.sigmoid(summed[:, split_at:, :])  # [B,C,T]
    return tanh_part * gate_part  # [B,C,T]


# about 10% faster training. no_div_by_zero (1e-9) baked in
@torch.jit.script
def fused_snake(x: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Compute ``x + sin(x * alpha)^2 / (beta + 1e-9)`` elementwise."""
    inv_beta = 1.0 / (beta + 1e-9)
    periodic = pow(sin(x * alpha), 2)
    return x + inv_beta * periodic


class Snake(nn.Module):
    """
    Sine-based periodic activation with one learnable frequency per channel.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(1, 256, 16)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features: int, alpha: float = 1.0, alpha_trainable: bool = True, alpha_logscale: bool = True
    ) -> None:
        """
        Create the per-channel ``alpha`` parameter.
        INPUT:
            - in_features: number of channels (size of ``alpha``)
            - alpha: initial scale for the linear-scale parameter. In log-scale
              mode the parameter starts as ``zeros * alpha``, i.e. all zeros,
              so the effective alpha ``exp(0)`` is 1 whatever ``alpha`` is.
            - alpha_trainable: whether ``alpha`` receives gradients
            - alpha_logscale: store ``alpha`` in log scale (exponentiated in forward)
        """
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        # log scale starts from zeros, linear scale from ones
        initial = torch.zeros(in_features) if self.alpha_logscale else torch.ones(in_features)
        self.alpha = Parameter(initial * alpha)

        self.alpha.requires_grad = alpha_trainable

        # Kept as an attribute; the epsilon actually used lives inside fused_snake.
        self.no_div_by_zero = 0.000000001

    def forward(self: "Snake", x: torch.Tensor) -> torch.Tensor:
        """
        Apply the activation elementwise.
        Snake := x + 1/a * sin^2 (xa)
        """
        scale = self.alpha.unsqueeze(0).unsqueeze(-1)  # [1,C,1]
        if self.alpha_logscale:
            scale = torch.exp(scale)  # [1,C,1]

        # The same tensor serves as frequency and as magnitude divisor.
        return fused_snake(x, scale, scale)  # [B,C,T]


class SnakeBeta(nn.Module):
    """
    Snake variant with separate per-channel frequency and magnitude parameters.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - Modified from the paper by Liu Ziyin, Tilman Hartwig, and Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(1, 256, 16)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features: int, alpha: float = 1.0, alpha_trainable: bool = True, alpha_logscale: bool = True
    ) -> None:
        """
        Create the per-channel ``alpha`` and ``beta`` parameters.
        INPUT:
            - in_features: number of channels (size of ``alpha`` and ``beta``)
            - alpha: initial scale applied to both parameters in linear-scale
              mode. In log-scale mode both start as ``zeros * alpha``, i.e. all
              zeros, so their effective value ``exp(0)`` is 1.
            - alpha_trainable: whether ``alpha`` and ``beta`` receive gradients
            - alpha_logscale: store both parameters in log scale
        """
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        # log scale starts from zeros, linear scale from ones
        make_initial = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = Parameter(make_initial(in_features) * alpha)
        self.beta = Parameter(make_initial(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Kept as an attribute; the epsilon actually used lives inside fused_snake.
        self.no_div_by_zero = 0.000000001

    def forward(self: "SnakeBeta", x: torch.Tensor) -> torch.Tensor:
        """
        Apply the activation elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        frequency = self.alpha.unsqueeze(0).unsqueeze(-1)  # [1,C,1]
        magnitude = self.beta.unsqueeze(0).unsqueeze(-1)  # [1,C,1]
        if self.alpha_logscale:
            frequency = torch.exp(frequency)  # [1,C,1]
            magnitude = torch.exp(magnitude)  # [1,C,1]

        return fused_snake(x, frequency, magnitude)  # [B,C,T]
class Activation1d(nn.Module):
    """Run an activation between an ``UpSample1d`` stage and a ``DownSample1d`` stage."""

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        """
        Args:
            activation: module applied between the two resampling stages.
            up_ratio: ratio passed to ``UpSample1d``.
            down_ratio: ratio passed to ``DownSample1d``.
            up_kernel_size: kernel size passed to ``UpSample1d``.
            down_kernel_size: kernel size passed to ``DownSample1d``.
        """
        super().__init__()
        self.up_ratio, self.down_ratio = up_ratio, down_ratio
        # Submodules are registered in the same order as before: act, upsample, downsample.
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Return ``downsample(act(upsample(x)))`` for ``x`` of shape [B,C,T]."""
        upsampled = self.upsample(x)
        activated = self.act(upsampled)
        return self.downsample(activated)
if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adapted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x), with sinc(0) == 1.
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x
        )


# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: cutoff value used as ``2 * cutoff`` in the sinc argument; ``0`` yields an
            all-zero kernel.
        half_width: transition half-width; it only influences the Kaiser ``beta``.
        kernel_size: number of taps (even or odd).

    Returns:
        Floating-point tensor of shape [1,1,kernel_size]. For a non-zero ``cutoff`` the
        taps are normalized to sum to 1.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)  # [kernel_size]

    if cutoff == 0:
        # Always return a float kernel of the right shape. The previous code derived the
        # zeros from an integer time axis for odd kernel sizes.
        kernel = torch.zeros(kernel_size, dtype=window.dtype)  # [kernel_size]
    else:
        # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
        if even:
            time = torch.arange(-half_size, half_size) + 0.5  # [kernel_size]
        else:
            time = torch.arange(kernel_size) - half_size  # [kernel_size]
        kernel = 2 * cutoff * window * sinc(2 * cutoff * time)  # [kernel_size]
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        kernel /= kernel.sum()

    return kernel.view(1, 1, kernel_size)  # [1,1,kernel_size]
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter using a kernel from ``kaiser_sinc_filter1d``."""

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        """
        Args:
            cutoff: cutoff passed to ``kaiser_sinc_filter1d``; must lie in [0, 0.5].
            half_width: transition half-width passed to ``kaiser_sinc_filter1d``.
            stride: stride of the filtering convolution.
            padding: whether ``forward`` pads the input before filtering.
            padding_mode: ``F.pad`` mode used when ``padding`` is true.
            kernel_size: number of taps.

        Raises:
            ValueError: if ``cutoff`` is negative or greater than 0.5.
        """
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        half = kernel_size // 2

        self.kernel_size = kernel_size
        self.even = is_even
        # An even kernel pads one sample less on the left.
        self.pad_left = half - int(is_even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x):  # x: [B,C,T]
        """Filter each channel of ``x`` with the same kernel; returns [B,C,T//stride]."""
        _, num_channels, _ = x.shape

        padded = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) if self.padding else x
        # One copy of the [1,1,K] kernel per channel, applied depthwise (groups == channels).
        per_channel_kernel = self.filter.expand(num_channels, -1, -1)
        return F.conv1d(padded, per_channel_kernel, stride=self.stride, groups=num_channels)
class UpSample1d(nn.Module):
    """Upsample [B,C,T] by ``ratio`` with a transposed depthwise convolution."""

    def __init__(self, ratio=2, kernel_size=None):
        """
        Args:
            ratio: integer upsampling factor (also the transposed-conv stride).
            kernel_size: number of filter taps; defaults to ``int(6 * ratio // 2) * 2``.
        """
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        # Replicate padding added on each side before the transposed convolution.
        self.pad = self.kernel_size // ratio - 1
        # Number of output samples cropped on each side afterwards.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        kernel = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        self.register_buffer("filter", kernel)

    # x: [B,C,T]
    def forward(self, x):  # x: [B,C,T]
        """Return the upsampled signal of shape [B,C,T*ratio]."""
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")  # [B,C,T+2*pad]
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )  # [B,C,T*ratio+pad_left+pad_right]
        # ``x[..., a:-0]`` is empty, so only apply a right crop when there is one
        # (pad_right is 0 when kernel_size == ratio).
        if self.pad_right:
            x = x[..., self.pad_left : -self.pad_right]  # [B,C,T*ratio]
        else:
            x = x[..., self.pad_left :]  # [B,C,T*ratio]

        return x  # [B,C,T*ratio]


class DownSample1d(nn.Module):
    """Downsample [B,C,T] by ``ratio`` with a strided ``LowPassFilter1d``."""

    def __init__(self, ratio=2, kernel_size=None):
        """
        Args:
            ratio: integer downsampling factor (used as the low-pass stride).
            kernel_size: number of filter taps; defaults to ``int(6 * ratio // 2) * 2``.
        """
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, kernel_size=self.kernel_size
        )

    def forward(self, x):  # x: [B,C,T]
        """Return the low-pass filtered, strided signal of shape [B,C,T//ratio]."""
        xx = self.lowpass(x)  # [B,C,T//ratio]

        return xx  # [B,C,T//ratio]
def _default_avae_config(
    *,
    sample_rate: int,
    audio_channels: int,
    io_channels: int,
    hop_size: int,
) -> AttrDict:
    """Return the built-in AVAE configuration used when no config file is given.

    Only ``sampling_rate``, ``stereo``, ``hop_size``, ``dec_out_channels`` and
    ``vocoder_input_dim`` depend on the arguments; every other entry is a constant.
    """
    settings = {
        "model_type": "autoencoder_v2",
        "sampling_rate": sample_rate,
        "stereo": audio_channels == 2,
        "use_wav_as_input": True,
        "normalize_volume": True,
        "hop_size": hop_size,
        "input_channels": 1,
        # Encoder
        "enc_type": "spec_convnext",
        "enc_dim": 192,
        "enc_intermediate_dim": 768,
        "enc_num_layers": 12,
        "enc_num_blocks": 2,
        "enc_n_fft": 64,
        "enc_hop_length": 16,
        "enc_latent_dim": 128,
        "enc_c_mults": [1, 2, 4],
        "enc_strides": [4, 4, 8],
        "enc_identity_init": False,
        "enc_use_snake": True,
        # Decoder
        "dec_type": "oobleck",
        "dec_dim": 320,
        "dec_c_mults": [1, 2, 4, 8, 16],
        "dec_strides": [2, 4, 4, 8, 8],
        "dec_use_snake": True,
        "dec_final_tanh": False,
        "dec_out_channels": audio_channels,
        "dec_anti_aliasing": False,
        "dec_use_nearest_upsample": False,
        "dec_use_tanh_at_final": False,
        # Bottleneck and shared options
        "bottleneck_type": "vae",
        "bottleneck": {"type": "vae"},
        "activation": "snakebeta",
        "snake_logscale": True,
        "anti_aliasing": False,
        "use_cuda_kernel": False,
        "causal": False,
        "padding_mode": "zeros",
        "vocoder_input_dim": io_channels,
    }
    return AttrDict(settings)


def _load_config(
    config_path: str | Path | None,
    *,
    sample_rate: int,
    audio_channels: int,
    io_channels: int,
    hop_size: int,
) -> AttrDict:
    """Load the AVAE config from a JSON file, or fall back to the built-in default.

    When ``config_path`` is truthy the JSON file is returned as-is and the keyword
    arguments are ignored; otherwise they parameterize ``_default_avae_config``.
    """
    if not config_path:
        return _default_avae_config(
            sample_rate=sample_rate,
            audio_channels=audio_channels,
            io_channels=io_channels,
            hop_size=hop_size,
        )

    with open(config_path, encoding="utf-8") as config_file:
        raw_config = json.load(config_file)
    return AttrDict(raw_config)
def _load_checkpoint(path: str | Path, map_location: torch.device | str) -> dict[str, torch.Tensor]:
    """Load an AVAE checkpoint and return its tensor state dict.

    ``.safetensors`` files are read with ``safetensors``; anything else goes through
    ``torch.load``. If the loaded dict nests the weights under ``generator``,
    ``state_dict`` or ``model`` (checked in that order), that nested dict is used.
    Non-tensor entries are dropped.

    Raises:
        ImportError: a ``.safetensors`` file was given but ``safetensors`` is missing.
        TypeError: the checkpoint is not a dict.
        RuntimeError: the selected dict has entries but none of them is a tensor.
    """
    path = Path(path)
    if path.suffix == ".safetensors":
        try:
            from safetensors.torch import load_file
        except ImportError as exc:
            raise ImportError("Loading AVAE .safetensors checkpoints requires safetensors.") from exc
        checkpoint = load_file(str(path), device=str(map_location))
    else:
        # Checkpoints are external files: restrict unpickling to tensors and plain
        # containers instead of executing arbitrary pickled objects.
        checkpoint = torch.load(path, map_location=map_location, weights_only=True)

    if not isinstance(checkpoint, dict):
        raise TypeError(f"AVAE checkpoint must be a dict, got {type(checkpoint)!r}.")

    for key in ("generator", "state_dict", "model"):
        value = checkpoint.get(key)
        if isinstance(value, dict):
            checkpoint = value
            break

    if not all(isinstance(value, torch.Tensor) for value in checkpoint.values()):
        tensor_items = {key: value for key, value in checkpoint.items() if isinstance(value, torch.Tensor)}
        if not tensor_items:
            raise RuntimeError(f"No tensor state dict found in AVAE checkpoint keys: {list(checkpoint.keys())[:16]}")
        checkpoint = tensor_items

    return checkpoint


def _strip_prefixes(
    state_dict: dict[str, torch.Tensor],
    model_state: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Rename checkpoint keys so they line up with the local model's keys.

    Leading ``module.``, ``generator.`` and ``model.`` prefixes are peeled off one at a
    time. The first form (original key first, then each progressively shorter form) that
    exists in ``model_state`` is used; if none matches, the fully stripped key is used.
    """
    prefixes = ("module.", "generator.", "model.")
    normalized: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        candidates = [key]
        current = key
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if current.startswith(prefix):
                    current = current[len(prefix) :]
                    candidates.append(current)
                    changed = True
                    break
        selected = next((candidate for candidate in candidates if candidate in model_state), candidates[-1])
        normalized[selected] = value
    return normalized
class Cosmos3AVAEAudioTokenizer(nn.Module):
    """AVAE tokenizer/decoder for Cosmos3 audio latents.

    Wraps the generator returned by ``load_generator``: loads its weights from a
    checkpoint, freezes it in eval mode, and exposes ``encode`` (audio -> latent) and
    ``decode`` (latent -> audio), with an optional tanh squashing of the latents.
    """

    def __init__(
        self,
        *,
        checkpoint_path: str | Path,
        config_path: str | Path | None = None,
        sample_rate: int = 48000,
        audio_channels: int = 2,
        io_channels: int = 64,
        hop_size: int = 1920,
        normalize_latents: bool = True,
        normalization_type: str = "none",
        tanh_input_scale: float = 1.5,
        tanh_output_scale: float = 3.5,
        tanh_clamp: float = 0.995,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str = "cuda",
    ) -> None:
        """Build the model, load the checkpoint and freeze the weights.

        Args:
            checkpoint_path: weights file handed to ``_load_checkpoint``.
            config_path: optional JSON config; when falsy the built-in default is used.
            sample_rate, audio_channels, io_channels, hop_size: forwarded to
                ``_load_config`` (they only shape the built-in default config).
            normalize_latents: with ``normalization_type == "none"`` this switches the
                normalization to ``"tanh"``.
            normalization_type: ``"tanh"`` or ``"none"``.
            tanh_input_scale, tanh_output_scale, tanh_clamp: constants of the tanh
                normalization and its inverse.
            dtype: dtype the model is cast to after loading.
            device: device the model is created on and inputs are moved to.

        Raises:
            RuntimeError: no checkpoint key matches a key of the local model.
        """
        super().__init__()
        self.sample_rate = int(sample_rate)
        self.audio_channels = int(audio_channels)
        self.latent_ch = int(io_channels)
        self.hop_size = int(hop_size)
        self.dtype = dtype
        self.device = torch.device(device)
        # Peak normalization in ``encode`` is always on; there is no constructor switch.
        self.normalize_volume = True

        # With the defaults ("none" + normalize_latents=True) this selects "tanh".
        if normalization_type == "none" and normalize_latents:
            normalization_type = "tanh"
        self.normalization_type = normalization_type
        self.tanh_input_scale = float(tanh_input_scale)
        self.tanh_output_scale = float(tanh_output_scale)
        self.tanh_clamp = float(tanh_clamp)

        config = _load_config(
            config_path,
            sample_rate=self.sample_rate,
            audio_channels=self.audio_channels,
            io_channels=self.latent_ch,
            hop_size=self.hop_size,
        )
        self.model = load_generator(config.model_type, config, self.device)
        # Rename checkpoint keys against the freshly built model's key set.
        state_dict = _strip_prefixes(
            _load_checkpoint(checkpoint_path, self.device),
            self.model.state_dict(),
        )
        matching_keys = set(state_dict).intersection(self.model.state_dict())
        if not matching_keys:
            raise RuntimeError("AVAE checkpoint did not contain any keys matching the local AVAE model.")
        # strict=False: missing/unexpected keys are only counted in the log line below.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        if _is_rank_zero():
            logger.info(
                "Loaded Cosmos3 AVAE checkpoint from %s; missing=%d unexpected=%d",
                checkpoint_path,
                len(missing),
                len(unexpected),
            )

        # Inference only: eval mode, no gradients, weight norm removed when the model
        # offers it, then cast to the target dtype.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False
        if hasattr(self.model, "remove_weight_norm"):
            self.model.remove_weight_norm()
        self.model.to(dtype=self.dtype)

    @property
    def temporal_compression_factor(self) -> int:
        """Value of ``hop_size``, used as audio samples per latent frame."""
        return self.hop_size

    def get_latent_num_samples(self, num_audio_samples: int) -> int:
        """Return ``num_audio_samples // hop_size`` (floor division)."""
        # NOTE(review): ``encode`` pads up to a multiple of hop_size, while this rounds
        # down — confirm the two are meant to differ for non-multiple lengths.
        return int(num_audio_samples) // self.temporal_compression_factor

    def get_audio_num_samples(self, num_latent_samples: int) -> int:
        """Return ``num_latent_samples * hop_size``."""
        return int(num_latent_samples) * self.temporal_compression_factor

    def _normalize_latent(self, latent: torch.Tensor) -> torch.Tensor:
        """Apply ``tanh(latent / input_scale) * output_scale`` in float32, or pass through.

        Raises:
            ValueError: ``normalization_type`` is neither ``"tanh"`` nor ``"none"``.
        """
        if self.normalization_type == "tanh":
            in_dtype = latent.dtype
            return (torch.tanh(latent.float() / self.tanh_input_scale) * self.tanh_output_scale).to(in_dtype)
        if self.normalization_type != "none":
            raise ValueError(f"Unsupported AVAE normalization_type={self.normalization_type!r}.")
        return latent

    def _denormalize_latent(self, latent: torch.Tensor) -> torch.Tensor:
        """Invert ``_normalize_latent``; the atanh argument is clamped to +/- ``tanh_clamp``.

        Raises:
            ValueError: ``normalization_type`` is neither ``"tanh"`` nor ``"none"``.
        """
        if self.normalization_type == "tanh":
            in_dtype = latent.dtype
            # Clamp before atanh so values at or beyond +/-1 cannot produce inf/nan.
            latent = torch.clamp(
                latent.float() / self.tanh_output_scale,
                -self.tanh_clamp,
                self.tanh_clamp,
            )
            return (torch.atanh(latent) * self.tanh_input_scale).to(in_dtype)
        if self.normalization_type != "none":
            raise ValueError(f"Unsupported AVAE normalization_type={self.normalization_type!r}.")
        return latent

    @torch.no_grad()
    def encode(self, audio: torch.Tensor, force_pad: bool = False) -> torch.Tensor:
        """Encode audio of shape [B, C, T] into normalized latents.

        The result is cast back to the input dtype but stays on ``self.device``.

        Raises:
            ValueError: ``audio`` is not 3-dimensional.
        """
        in_dtype = audio.dtype
        x = audio.to(self.device)
        if x.ndim != 3:
            raise ValueError(f"AVAE audio input must be [B, C, T], got {tuple(x.shape)}.")
        # Match the configured channel count: duplicate mono for a stereo model, or
        # drop surplus channels.
        if x.shape[1] == 1 and self.audio_channels == 2:
            x = x.repeat(1, 2, 1)
        elif x.shape[1] > self.audio_channels:
            x = x[:, : self.audio_channels]
        if self.normalize_volume:
            # Per-sample peak normalization to 0.95 (peak taken over channels and time).
            x = x / (x.abs().amax(dim=(-2, -1), keepdim=True) + 1e-5) * 0.95
        # The model is put in eval mode in __init__, so this branch is taken unless
        # something switches it back to training.
        if force_pad or not self.model.training:
            # Zero-pad on the right up to the next multiple of hop_size.
            pad_amount = (self.hop_size - (x.shape[-1] % self.hop_size)) % self.hop_size
            if pad_amount:
                x = F.pad(x, (0, pad_amount), mode="constant", value=0)
        encoded = self.model.encode(x.to(self.dtype))
        # The model may return a dict holding "latent" or the latent tensor itself.
        latent = encoded["latent"] if isinstance(encoded, dict) else encoded
        return self._normalize_latent(latent).to(in_dtype)

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode normalized latents into audio clamped to [-1, 1], in the input dtype.

        Raises:
            RuntimeError: the model's ``decode`` result is not a dict with ``decoder_out``.
        """
        in_dtype = latent.dtype
        z = self._denormalize_latent(latent.to(self.device)).to(self.dtype)
        decoded = self.model.decode(z)
        if not isinstance(decoded, dict) or "decoder_out" not in decoded:
            raise RuntimeError("AVAE decoder did not return decoder_out.")
        audio = decoded["decoder_out"].clamp(-1.0, 1.0)
        return audio.to(in_dtype)
b/vllm_omni/diffusion/models/cosmos3/audio_tokenizer/bottlenecks.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Bottleneck modules for AVAE tokenizer. + +This cleaned-up version only includes VAEBottleneck which is used +by the spec_convnext encoder + oobleck decoder + vae configuration. +""" + +from typing import Any + +import torch +from torch import Tensor, nn + + +# Base class +class Bottleneck(nn.Module): + """Base class for bottleneck modules.""" + + def __init__(self: "Bottleneck", is_discrete: bool = False) -> None: + super().__init__() + self.is_discrete = is_discrete + + def encode( + self: "Bottleneck", x: Tensor, return_info: bool = False, **kwargs: Any + ) -> Tensor | tuple[Tensor, dict[str, Any]]: + raise NotImplementedError + + def decode(self: "Bottleneck", x: Tensor, return_info: bool = False) -> Tensor | tuple[Tensor, dict[str, Any]]: + raise NotImplementedError + + +def vae_sample(mean: Tensor, scale: Tensor) -> tuple[Tensor, Tensor]: + """ + Sample from VAE latent distribution. + + Args: + mean: Mean of the latent distribution + scale: Scale parameter (will be passed through softplus) + + Returns: + latents: Sampled latents + kl: KL divergence loss + """ + stdev = nn.functional.softplus(scale) + 1e-4 # [B,C,T] + var = stdev * stdev # [B,C,T] + logvar = torch.log(var) # [B,C,T] + latents = torch.randn_like(mean) * stdev + mean # [B,C,T] + + kl = (mean * mean + var - logvar - 1).sum(1).mean() # scalar + + return latents, kl + + +class VAEBottleneck(Bottleneck): + """ + Variational Autoencoder (VAE) bottleneck. + + Applies VAE reparameterization trick during encoding. + """ + + def __init__(self: "VAEBottleneck") -> None: + super().__init__(is_discrete=False) + + def encode( + self: "VAEBottleneck", x: Tensor, return_info: bool = False, **kwargs: Any + ) -> Tensor | tuple[Tensor, dict[str, Any]]: + """ + Encode input through VAE bottleneck. 
+ + Args: + x: Input tensor with shape [B, C*2, T] where C*2 contains + concatenated mean and scale parameters + return_info: Whether to return additional info dict + + Returns: + Sampled latents (and optionally info dict with KL divergence) + """ + info: dict[str, Any] = {} + + mean, scale = x.chunk(2, dim=1) # mean,scale: [B,C,T] + x, kl = vae_sample(mean, scale) # x: [B,C,T] + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self: "VAEBottleneck", x: Tensor, return_info: bool = False) -> Tensor | tuple[Tensor, dict[str, Any]]: + """ + Decode from latents (identity operation for VAE). + + Args: + x: Latent tensor + return_info: Whether to return additional info dict + + Returns: + Latents (and optionally empty info dict) + """ + info: dict[str, Any] = {} + if return_info: + return x, info + else: + return x + + +def create_bottleneck_from_config(bottleneck_config: dict[str, Any]) -> Bottleneck: + """ + Create a bottleneck module from configuration. + + Args: + bottleneck_config: Dictionary with 'type' key specifying bottleneck type + + Returns: + Bottleneck module instance + + Note: + This cleaned version only supports 'vae' bottleneck type. + """ + bottleneck_type = bottleneck_config.get("type", None) + + assert bottleneck_type is not None, "type must be specified in bottleneck config" + + if bottleneck_type == "vae": + bottleneck = VAEBottleneck() + else: + raise NotImplementedError( + f"Bottleneck type '{bottleneck_type}' not supported in cleaned AVAE. " + f"Only 'vae' is supported for the spec_convnext + oobleck + vae configuration." 
class AttrDict(dict):
    """Dict whose entries are also readable and writable as attributes.

    Nested plain dicts (including dicts inside lists) are converted to ``AttrDict``
    at construction time.
    """

    def __init__(self: "AttrDict", *args: Any, **kwargs: Any) -> None:
        super().__init__()
        for key, value in dict(*args, **kwargs).items():
            self[key] = self._convert(value)
        # Share storage between item access and attribute access.
        self.__dict__ = self

    @classmethod
    def _convert(cls, value: Any) -> Any:
        """Recursively wrap plain dicts; lists are rebuilt, anything else is returned as-is."""
        if isinstance(value, list):
            return [cls._convert(item) for item in value]
        needs_wrapping = isinstance(value, dict) and not isinstance(value, AttrDict)
        return cls(value) if needs_wrapping else value
def load_generator(model_type: str, h: AttrDict, device: torch.device | str) -> nn.Module:
    """
    Load generator model based on model_type.

    Cleaned version only supports 'autoencoder_v2' type.

    Args:
        model_type: generator type name; must be 'autoencoder_v2'.
        h: configuration passed to the model constructor.
        device: device the model is moved to.

    Raises:
        NotImplementedError: for any other ``model_type``.
    """
    if model_type in ["autoencoder_v2"]:
        generator = LatentAutoEncoderV2(h).to(device)
    else:
        raise NotImplementedError(
            f"Model type '{model_type}' not supported in cleaned AVAE. Only 'autoencoder_v2' is supported."
        )

    return generator


class TrimPadding(nn.Module):
    """
    Used for causal convolution support of a conv layer wrapped with nn.Sequential

    Drops the last ``padding`` time steps of a [B,C,T] tensor.
    """

    def __init__(self: "TrimPadding", padding: int) -> None:
        super().__init__()
        self.padding = padding

    def forward(self: "TrimPadding", x: torch.Tensor) -> torch.Tensor:
        # ``x[:, :, :-0]`` would be empty, so a zero padding must be a no-op.
        if self.padding == 0:
            return x  # [B,C,T]
        return x[:, :, : -self.padding]  # [B,C,T-padding]
class SpectrogramConvNeXtEncoder(nn.Module):
    """
    Spectrogram Encoder with ConvNeXtBlocks

    This encoder converts input waveforms into complex STFT spectrograms, concatenates
    their real and imaginary parts along the channel dimension, and encodes them using a
    sequence of ConvNeXtBlocks and downsampling layers.

    Args (read from the config object ``h``):
        input_channels / stereo: number of input audio channels (doubled when ``stereo``).
        enc_dim: Base number of channels for the encoder.
        enc_latent_dim: Dimensionality of the final latent representation
            (falls back to ``vocoder_input_dim`` when absent).
        enc_c_mults (List[int]): Channel multipliers at each depth of the encoder.
        enc_strides (List[int]): Downsampling strides for each depth.
        enc_num_blocks (int): Number of ConvNeXtBlocks to stack per depth.
        enc_identity_init (bool): Forwarded to ``ConvNeXtBlock`` as ``identity_init``.
        enc_n_fft (int): Number of FFT points for spectrogram computation.
        enc_hop_length (int): Hop length for the STFT.
        enc_use_snake (bool): Forwarded to ``ConvNeXtBlock`` as ``use_snake``.
        causal (bool): If True, uses causal convolutions.
        padding_mode (str): Padding mode for the non-causal downsampling convolutions.

    Inputs:
        x (torch.Tensor): Input waveform tensor of shape `[batch, in_channels, time]`.

    Outputs:
        torch.Tensor: Encoded representation of shape `[batch, time_out, latent_dim]`.

    NOTE: output is in [B, T, C] to be consistent with other encoders
    """

    def __init__(self: "SpectrogramConvNeXtEncoder", h: AttrDict, **kwargs: Any) -> None:
        super().__init__()

        self.in_channels = h.input_channels
        if getattr(h, "stereo", False):
            self.in_channels *= 2

        # if "enc_latent_dim" is found in v2 config, set it as latent_dim
        if hasattr(h, "enc_latent_dim"):
            self.latent_dim = h.enc_latent_dim
        else:
            # if not found, fallback to v1 logic
            self.latent_dim = h.vocoder_input_dim
            if h.model_type == "vae":
                self.latent_dim *= 2

        self.channels = h.enc_dim

        self.c_mults = h.enc_c_mults
        self.strides = h.enc_strides
        self.num_blocks = h.enc_num_blocks
        self.identity_init = h.enc_identity_init
        self.causal = h.causal
        self.padding_mode = h.padding_mode

        self.use_snake = h.enc_use_snake

        # Basic checks
        assert len(self.c_mults) == len(self.strides), (
            f"The length of c_mults and strides must match. Got {len(self.c_mults)} vs {len(self.strides)}."
        )

        # Spectrogram function: the STFT window length equals n_fft.
        self.n_fft = h.enc_n_fft
        self.hop_length = h.enc_hop_length
        self.spectrogram_fn = partial(
            self.spectrogram,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.n_fft,
            window_fn=torch.hann_window,
        )

        # ---------------------------------------------------------------------
        # 1) Initial projection: a 1x1 conv over the spectrogram channels.
        #    Each audio channel contributes n_fft + 2 input channels
        #    (n_fft//2 + 1 real bins plus n_fft//2 + 1 imaginary bins).
        # ---------------------------------------------------------------------
        layers = []
        layers.append(
            WNConv1d((self.n_fft + 2) * self.in_channels, self.c_mults[0] * self.channels, kernel_size=1, bias=False)
        )

        # ---------------------------------------------------------------------
        # 2) Stages: For each i in range(len(c_mults)):
        #    - Stack num_blocks of ConvNeXtBlock
        #    - Downsample via stride convolution
        # ---------------------------------------------------------------------
        for i in range(len(self.c_mults)):
            dim_in = self.c_mults[i] * self.channels
            # Determine output dimension for the block
            if i < len(self.c_mults) - 1:  # If not the last block
                dim_out = self.c_mults[i + 1] * self.channels
            else:  # For the last block, dim_out is c_mults[-1] * channels
                dim_out = self.c_mults[-1] * self.channels
            ds_rate = self.strides[i]

            # (a) Repeated ConvNeXtBlocks
            for _ in range(self.num_blocks):
                layers.append(
                    ConvNeXtBlock(
                        dim=dim_in,
                        intermediate_dim=dim_in * 4,
                        identity_init=self.identity_init,
                        use_snake=self.use_snake,
                        causal=self.causal,
                    )
                )

            # (b) Downsampling convolution
            layers.append(self._create_downsample_layer(dim_in, dim_out, ds_rate, self.causal, self.padding_mode))

        # ---------------------------------------------------------------------
        # 3) Final projection from the last channel dimension to latent_dim.
        # ---------------------------------------------------------------------
        layers.append(WNConv1d(self.c_mults[-1] * self.channels, self.latent_dim, kernel_size=1, bias=False))

        # The order of ``layers`` fixes the state-dict key indices; do not reorder.
        self.layers = nn.Sequential(*layers)

    def spectrogram(
        self: "SpectrogramConvNeXtEncoder",
        wav: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        window_fn: Callable[[int], torch.Tensor] = torch.hann_window,
    ) -> Tensor:
        """
        Compute a one-sided complex STFT with autocast disabled.

        wav: [B_ch,T_audio] where B_ch = batch * channels (channel folded into batch)
        returns: [B_ch,n_fft//2+1,T_frames] complex
        """
        # center=False below, so the input is padded manually by n_fft - hop_length
        # samples in total, split between the left and right sides.
        pad_size_l = (n_fft - hop_length) // 2
        pad_size_r = (n_fft - hop_length) - pad_size_l
        with torch.autocast(device_type=wav.device.type, enabled=False):
            wav = F.pad(wav, (pad_size_l, pad_size_r)).float()  # [B_ch,T_audio+pad]
            spec = torch.stft(
                wav,
                n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window_fn(win_length).to(wav),
                center=False,
                normalized=False,
                onesided=True,
                return_complex=True,
            )  # [B_ch,n_fft//2+1,T_frames]
        return spec  # [B_ch,n_fft//2+1,T_frames]

    def _create_downsample_layer(
        self: "SpectrogramConvNeXtEncoder",
        in_channels: int,
        out_channels: int,
        stride: int,
        causal: bool,
        padding_mode: str,
    ) -> nn.Module:
        """Return the strided convolution (kernel size ``2 * stride``) that ends a stage."""
        if (
            causal
        ):  # use EnCodec's SConv1d for convenience; ``padding_mode`` is not forwarded on this path
            downsample_layer = SConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                causal=True,
                norm="weight_norm",
            )
        else:  # original non-causal implementation
            downsample_layer = WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                padding_mode=padding_mode,
            )
        return downsample_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B,C,T_audio] waveform (mono: C=1, stereo: C=2)

        Returns:
            [B,T_latent,latent_dim]
        """

        # Handle stereo input by merging channel dim into batch dim
        batch, channels, length = x.shape
        if channels > 1:  # Stereo case
            x = x.reshape(batch * channels, 1, length)  # [B*C,1,T_audio] (channel folded into batch)

        # Compute the spectrogram
        with torch.autocast(device_type=x.device.type, enabled=False):
            spec = self.spectrogram_fn(x.float().squeeze(1))  # [B*C,n_fft//2+1,T_frames] complex
            # view_as_real exposes the real and imaginary parts on a trailing axis of
            # size 2, so despite their names ``mag`` holds the real part and ``ph`` the
            # imaginary part.
            mag, ph = torch.view_as_real(spec).chunk(2, dim=-1)  # each [B*C,n_fft//2+1,T_frames,1]
            spectrogram = torch.cat([mag, ph], dim=1).squeeze(-1)  # [B*C,n_fft+2,T_frames]

        # Cast spectrogram back to original dtype
        spectrogram = spectrogram.to(x.dtype)  # [B*C,n_fft+2,T_frames]

        # Restore stereo structure if needed
        if channels > 1:  # Stereo case
            freq = spectrogram.shape[1]  # Get the frequency dimension
            spectrogram = spectrogram.reshape(
                batch, channels * freq, *spectrogram.shape[2:]
            )  # [B,(n_fft+2)*C,T_frames]

        # forward pass the encoder
        output = self.layers(spectrogram)  # [B,latent_dim,T_latent]

        return output.transpose(1, 2)  # [B,T_latent,latent_dim]

    def remove_weight_norm(self: "SpectrogramConvNeXtEncoder") -> None:
        """Strip weight normalization from every submodule that carries it."""
        for module in self.modules():
            if hasattr(module, "parametrizations"):  # for new WN implementation using parameterizations
                try:
                    remove_parametrizations(module, "weight")
                except ValueError:
                    # Raised when "weight" is not parametrized; nothing to remove.
                    pass
            elif hasattr(module, "weight"):
                try:
                    remove_weight_norm(module)
                except ValueError:
                    # Raised when the module has no (legacy) weight norm; nothing to remove.
                    pass
class OobleckDecoder(nn.Module):
    """
    Oobleck Decoder for audio synthesis.

    Decodes latent representations into audio waveforms using
    upsampling blocks with optional Snake activation and anti-aliasing.
    """

    def __init__(
        self: "OobleckDecoder",
        h: AttrDict,
    ) -> None:
        """Build the layer stack from the config object ``h``.

        Reads ``vocoder_input_dim``, ``input_channels``, ``stereo``, ``dec_dim``,
        ``dec_c_mults``, ``dec_strides``, ``dec_use_snake``, ``dec_use_nearest_upsample``,
        ``dec_anti_aliasing``, ``causal``, ``dec_use_tanh_at_final`` and ``padding_mode``.
        """
        super().__init__()

        self.h = h

        latent_dim = self.h.vocoder_input_dim

        # Output channel count is derived from input_channels (doubled for stereo).
        # NOTE(review): configs may also carry ``dec_out_channels`` / ``dec_final_tanh``;
        # neither is read here — confirm that is intended.
        out_channels = self.h.input_channels
        if getattr(h, "stereo", False):
            out_channels *= 2

        channels = self.h.dec_dim
        c_mults = self.h.dec_c_mults
        strides = self.h.dec_strides
        use_snake = self.h.dec_use_snake
        use_nearest_upsample = self.h.dec_use_nearest_upsample
        antialias_activation = self.h.dec_anti_aliasing
        causal = self.h.causal
        final_tanh = self.h.dec_use_tanh_at_final
        padding_mode = self.h.padding_mode

        # Prepend a multiplier of 1 so c_mults[0] is the width feeding the final conv.
        c_mults = [1, *c_mults]

        self.depth = len(c_mults)

        # Padding for the first convolution layer
        self.first_padding = 6 if causal else 3
        first_conv = WNConv1d(
            in_channels=latent_dim,
            out_channels=c_mults[-1] * channels,
            kernel_size=7,
            padding=self.first_padding,
            padding_mode=padding_mode,
        )

        # Causal mode: pad by kernel_size - 1 and drop the trailing samples afterwards.
        if causal:
            first_conv = nn.Sequential(first_conv, TrimPadding(self.first_padding))

        layers = [first_conv]

        # Walk from the widest stage down to the narrowest.
        for i in range(self.depth - 1, 0, -1):
            layers += [
                OobleckDecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                    causal=causal,
                    padding_mode=padding_mode,
                )
            ]

        # Padding for the final convolution layer
        self.final_padding = 6 if causal else 3
        final_conv = WNConv1d(
            in_channels=c_mults[0] * channels,
            out_channels=out_channels,
            kernel_size=7,
            padding=self.final_padding,
            padding_mode=padding_mode,
            bias=False,
        )

        if causal:
            final_conv = nn.Sequential(final_conv, TrimPadding(self.final_padding))

        layers += [
            get_activation(
                "snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels
            ),
            final_conv,
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        # The order of ``layers`` fixes the state-dict key indices; do not reorder.
        self.layers = nn.Sequential(*layers)

    def forward(self: "OobleckDecoder", x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B,latent_dim,T_latent]

        Returns:
            [B,C,T_audio]
        """
        x = self.layers(x)  # [B,C,T_audio]
        return x  # [B,C,T_audio]

    def remove_weight_norm(self: "OobleckDecoder") -> None:
        """Strip weight normalization from every submodule that carries it."""
        for module in self.modules():
            if hasattr(module, "parametrizations"):  # for new WN implementation using parameterizations
                try:
                    remove_parametrizations(module, "weight")
                except ValueError:
                    # Raised when "weight" is not parametrized; nothing to remove.
                    pass
            elif hasattr(module, "weight"):
                try:
                    remove_weight_norm(module)
                except ValueError:
                    # Raised when the module has no (legacy) weight norm; nothing to remove.
                    pass
+ """ + + def __init__(self: "LatentAutoEncoderV2", h: AttrDict) -> None: + super().__init__() + self.h = h + + # Set up basic model properties + self.stereo = getattr(self.h, "stereo", False) + + # Determine input type + self.input_type = None + if getattr(self.h, "use_wav_as_input", False): + self.input_type = "waveform" + self.h.input_channels = 1 + elif getattr(self.h, "use_linear_spec_as_input", False): + self.input_type = "linear" + self.h.input_channels = self.h.num_linears + elif getattr(self.h, "use_discrete_code_as_input", False): + self.input_type = "discrete_code" + self.h.input_channels = 1 + else: + self.input_type = "mel" + self.h.input_channels = self.h.num_mels + + # hop_size defines the down/up sampling factor of the autoencoder + self.hop_size = self.h.hop_size + + # Initialize encoder + self.enc_type = getattr(self.h, "enc_type", "convnext") + + # Define encoder (only spec_convnext supported in cleaned version) + if self.enc_type == "spec_convnext": + self.encoder = SpectrogramConvNeXtEncoder(self.h) + else: + raise NotImplementedError( + f"Encoder type '{self.enc_type}' not supported in cleaned AVAE. Only 'spec_convnext' is supported." + ) + + # Initialize encoder projector (Identity for spec_convnext) + self.encoder_proj = nn.Identity() + + # Initialize bottleneck from config + from .bottlenecks import create_bottleneck_from_config + + if hasattr(self.h, "bottleneck"): + self.bottleneck = create_bottleneck_from_config(self.h.bottleneck) + else: + raise ValueError("Bottleneck configuration must be specified") + + # Check for encoder-only mode + self.encoder_only = getattr(self.h, "encoder_only", False) + + if not self.encoder_only: + # Initialize decoder + self.dec_type = getattr(self.h, "dec_type", "oobleck") + if self.dec_type == "oobleck": + self.decoder = OobleckDecoder(self.h) + else: + raise NotImplementedError( + f"Decoder type '{self.dec_type}' not supported in cleaned AVAE. Only 'oobleck' is supported." 
def calculate_latent_lengths(self: "LatentAutoEncoderV2", audio_lengths: torch.Tensor) -> torch.Tensor:
    """
    Calculates the latent lengths given the original audio lengths.

    For waveform input the latent length is ``ceil(audio_length / hop_size)``.
    For every other input type the lengths are returned unchanged (same tensor object).

    Args:
        audio_lengths (torch.Tensor): A tensor of shape [B] containing the lengths of the original audio samples.

    Returns:
        torch.Tensor: A tensor of shape [B] containing the corresponding latent lengths
        (int64 in the waveform case).
    """
    if self.input_type != "waveform":
        # The latent length is same as audio_lengths
        return audio_lengths  # [B]

    hop_size = self.hop_size
    integral_input = not (audio_lengths.is_floating_point() or audio_lengths.is_complex())
    if not (integral_input and isinstance(hop_size, int)):
        # Non-integer inputs keep the floating-point formulation.
        return torch.ceil(audio_lengths.float() / hop_size).long()  # [B]

    # Exact integer ceil-division: ceil(a / b) == -floor(-a / b).
    # Going through float32 would silently round lengths above 2**24 samples
    # and could under-count the number of latent frames by one.
    lengths = audio_lengths.long()
    return -torch.div(-lengths, hop_size, rounding_mode="floor")  # [B]
+ + Returns: + dict[str, torch.Tensor]: Dictionary of output tensors including: + - encoder_out: Raw encoder output + - latent: Bottleneck latent representation + - decoder_out: Decoded output (if decoder exists) + - Additional outputs specific to the bottleneck type + """ + return_dict = {} + + # Encoder + encoder_out = self.encoder(x) # [B,T_latent,enc_latent_dim] + encoder_out_proj = self.encoder_proj(encoder_out) # [B,T_latent,enc_latent_dim] + + # Apply bottleneck after reshaping to [B, C, T] again + latent, bottleneck_enc_info = self.bottleneck.encode( + encoder_out_proj.transpose(1, 2), + return_info=True, # transpose: [B,enc_latent_dim,T_latent] + ) # [B,C,T_latent] + + # Update return dictionary + return_dict.update( + {"encoder_out": encoder_out.transpose(1, 2), "latent": latent} # encoder_out: [B,enc_latent_dim,T_latent] + ) + # Add bottleneck-specific info to return dict + for k, v in bottleneck_enc_info.items(): + return_dict[k] = v + + # Decode (if decoder exists) + if self.decoder is not None: + # Apply bottleneck decode + decoded_latent, bottleneck_dec_info = self.bottleneck.decode(latent, return_info=True) # [B,C,T_latent] + # Apply decoder + decoder_out = self.decoder(decoded_latent) # [B,C,T_audio] + + # Update return dictionary + return_dict["decoder_out"] = decoder_out # [B,C,T_audio] + # Add bottleneck-specific info to return dict + for k, v in bottleneck_dec_info.items(): + return_dict[k] = v + + return return_dict + + def encode(self: "LatentAutoEncoderV2", x: torch.Tensor) -> dict[str, torch.Tensor]: + """ + Encodes input x into latent representation using encoder and bottleneck. + + Args: + x (torch.Tensor): Input tensor with shape [B, C, T]. 
def decode(self: "LatentAutoEncoderV2", latent: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Decodes continuous latent representation into output using bottleneck and decoder.

    Args:
        latent (torch.Tensor): continuous latent representation with shape [B, C, T].

    Returns:
        dict[str, torch.Tensor]: Dictionary containing:
            - decoder_out: The output from the decoder
            - Additional outputs from the bottleneck decode process

    Raises:
        RuntimeError: If the model has no decoder (``self.decoder`` is ``None``).
    """
    # Fail early and clearly instead of crashing on `None(...)` further down.
    if self.decoder is None:
        raise RuntimeError("decode() is unavailable: this autoencoder was built without a decoder")

    # Apply bottleneck decode
    decoded_latent, bottleneck_info = self.bottleneck.decode(latent, return_info=True)  # [B,C,T_latent]

    # Apply decoder
    return_dict = {"decoder_out": self.decoder(decoded_latent)}  # decoder_out: [B,C,T_audio]

    # Add bottleneck-specific info to return dict
    return_dict.update(bottleneck_info)
    return return_dict
class LayerNorm(nn.Module):
    """
    LayerNorm over the last dimension with an optional bias.

    The normalization is evaluated in at least float32 precision and the
    result is cast back to the input dtype.

    Args:
        size: Number of features in the last dimension.
        gamma0: Accepted for interface compatibility; not used by this module.
        eps: Value added to the variance for numerical stability.
        use_bias: If True a learnable bias is created, otherwise ``bias`` is None.
    """

    def __init__(self, size: int, gamma0: float = 1, eps: float = 1e-5, use_bias: bool = False) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size))
        self.bias = nn.Parameter(torch.zeros(size)) if use_bias else None
        self.eps = eps
        self.size = size

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Forward pass.

        Args:
            tensor: Input tensor of shape (B, T, C)

        Returns:
            Normalized tensor with the same shape and dtype as the input
        """
        dtype = tensor.dtype
        # Half/bfloat16 inputs are promoted to float32; float64 inputs stay float64.
        compute_dtype = torch.promote_types(dtype, torch.float32)
        weight = self.weight.to(compute_dtype)
        bias = None if self.bias is None else self.bias.to(compute_dtype)

        # Upcast explicitly and disable autocast on the tensor's own device, so the
        # precision does not depend on an ambient autocast context or on CUDA being present.
        with torch.autocast(device_type=tensor.device.type, enabled=False):
            normalized = F.layer_norm(tensor.to(compute_dtype), self.weight.shape, weight, bias, self.eps)
        return normalized.to(dtype)
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + identity_init: bool = False, + use_snake: bool = False, + causal: bool = False, + ): + super().__init__() + self.causal = causal + + if causal: + # Causal padding: Only pad on the left + self.dwconv = nn.Sequential( + nn.ConstantPad1d((6, 0), 0), # causal padding + nn.Conv1d(dim, dim, kernel_size=7, groups=dim), + ) + else: + # Non-causal padding: Symmetric padding + self.dwconv = nn.Sequential( + nn.ConstantPad1d((3, 3), 0), # symmetric padding (kernel_size // 2 on both sides) + nn.Conv1d(dim, dim, kernel_size=7, groups=dim), + ) + + self.norm = LayerNorm(dim) + self.pwconv1 = nn.Conv1d(dim, intermediate_dim, 1) # pointwise/1x1 convs + self.act = activations.SnakeBeta(intermediate_dim) if use_snake else nn.GELU() + + if identity_init: + self.pwconv2 = zero_module(nn.Conv1d(intermediate_dim, dim, 1)) + else: + self.pwconv2 = nn.Conv1d(intermediate_dim, dim, 1) + + def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor: + """ + Forward pass. + + Args: + x: Input tensor of shape (B, C, T) + mask: Optional mask tensor + + Returns: + Output tensor of shape (B, C, T) + """ + residual = x # [B,C,T] + x = self.dwconv(may_mask(x, mask)) # [B,C,T] + x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) # [B,C,T] -> [B,T,C] -> [B,C,T] + x = self.pwconv1(x) # [B,intermediate_dim,T] + x = self.act(x) # [B,intermediate_dim,T] + x = self.pwconv2(x) # [B,C,T] + x = residual + x # [B,C,T] + return may_mask(x, mask) # [B,C,T] + + def remove_weight_norm(self) -> None: + """No weight norm is applied in ConvNeXtBlock.""" + pass + + +def get_activation( + activation: Literal["elu", "snake", "none"], + antialias: bool = False, + channels: int | None = None, + use_cuda_kernel: bool = False, +) -> nn.Module: + """ + Get activation module by name. 
+ + Args: + activation: Activation type ('elu', 'snake', or 'none') + antialias: Whether to wrap with anti-aliasing + channels: Number of channels (required for snake activation) + use_cuda_kernel: Whether to use CUDA kernel (not supported) + + Returns: + Activation module + """ + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = activations.SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + # select which Activation1d, lazy-load cuda version to ensure backward compatibility + if use_cuda_kernel: + raise NotImplementedError("CUDA kernels not supported in this port") + else: + Activation1d = TorchActivation1d + + act = Activation1d(act) + + return act + + +class ResidualUnit(nn.Module): + """ + Residual unit with dilated convolutions. + Used in OobleckDecoderBlock. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + dilation: Dilation rate + kernel_size: Convolution kernel size (default: 7) + use_snake: Whether to use Snake activation (default: False) + antialias_activation: Whether to use anti-aliasing (default: False) + causal: Whether to use causal convolutions (default: False) + padding_mode: Padding mode for convolutions (default: 'zeros') + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + dilation: int, + kernel_size: int = 7, + use_snake: bool = False, + antialias_activation: bool = False, + causal: bool = False, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.dilation = dilation + self.causal = causal + self.kernel_size = kernel_size + + if causal: + self.padding = dilation * (kernel_size - 1) + else: + self.padding = (dilation * (kernel_size - 1)) // 2 + + # original non-causal impl used zero padding (DAC, SAVAE) + # Reflect padding may reduce edge artifacts (EnCodec's default), but + # it increases VRAM usage during training. 
def forward(self, x: Tensor) -> Tensor:
    """
    Apply the residual unit.

    Args:
        x: Input tensor of shape (B, C, T)

    Returns:
        Output tensor of shape (B, C, T)
    """
    res = x  # [B,C,T]

    # apply conv layers
    x = self.layers(x)  # [B,C,T] (padded if causal)

    # In causal mode drop the trailing `padding` samples so the output lines up with the input.
    # Guard against padding == 0 (kernel_size == 1): `x[..., :-0]` would be an empty slice.
    if self.causal and self.padding > 0:
        x = x[:, :, : -self.padding]  # [B,C,T]

    return x + res  # [B,C,T]
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    stride: int,
    use_snake: bool = False,
    antialias_activation: bool = False,
    use_nearest_upsample: bool = False,
    causal: bool = False,
    padding_mode: str = "zeros",
) -> None:
    """
    Assemble the block: activation, upsampling layer, then three residual units.

    The residual units use dilations 1, 3 and 9 in that order. Only the leading
    activation receives ``antialias_activation``; the residual units are built
    with their default for that option.
    """
    super().__init__()

    self.causal = causal

    activation_name = "snake" if use_snake else "elu"
    stages = [
        get_activation(activation_name, antialias=antialias_activation, channels=in_channels),
        self._create_upsample_layer(in_channels, out_channels, stride, use_nearest_upsample, causal, padding_mode),
    ]
    # Modules are created in sequence order, matching their index inside `self.layers`.
    stages.extend(
        ResidualUnit(
            in_channels=out_channels,
            out_channels=out_channels,
            dilation=dilation,
            use_snake=use_snake,
            causal=causal,
            padding_mode=padding_mode,
        )
        for dilation in (1, 3, 9)
    )

    self.layers = nn.Sequential(*stages)
def remove_weight_norm(self) -> None:
    """
    Remove weight normalization from every submodule of this block.

    All descendants are visited (not just the direct children of ``self.layers``),
    so convolutions nested inside residual units or wrapper modules are handled too.
    Both weight-norm flavours are supported: the parametrization-based one and the
    hook-based ``torch.nn.utils.weight_norm``. Modules without weight norm are skipped.
    """
    from torch.nn.utils import parametrize
    from torch.nn.utils import remove_weight_norm as _remove_hook_weight_norm

    for module in self.modules():
        if parametrize.is_parametrized(module, "weight"):
            # Parametrization-based weight norm: bake the current weight in place.
            parametrize.remove_parametrizations(module, "weight")
            continue
        try:
            _remove_hook_weight_norm(module)
        except (ValueError, AttributeError):
            # Module doesn't have weight norm or is not a module with weight norm
            continue
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs) -> nn.Module:
    """Return the normalization module that should follow ``module``.

    Args:
        module: The convolution the normalization will be applied after. Must be a
            convolution module for ``"layer_norm"`` and ``"time_group_norm"``, since its
            ``out_channels`` sizes the normalization.
        causal: Whether the surrounding layer is causal.
        norm: One of ``CONV_NORMALIZATIONS``.
        **norm_kwargs: Extra keyword arguments forwarded to the normalization constructor.

    Returns:
        ``ConvLayerNorm`` for ``"layer_norm"``, a single-group ``nn.GroupNorm`` for
        ``"time_group_norm"``, and ``nn.Identity`` for every other accepted value.

    Raises:
        ValueError: If ``norm`` is ``"time_group_norm"`` and ``causal`` is True.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == "layer_norm":
        # `_ConvNd` is the common base class of PyTorch's Conv*/ConvTranspose* modules.
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == "time_group_norm":
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Strip ``paddings = (left, right)`` samples from the last axis of ``x``.

    A right padding of zero is handled correctly (no empty slice). Only for 1d!
    """
    left, right = paddings
    assert not (left < 0 or right < 0), (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left : total - right]
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args: Positional arguments forwarded to ``nn.ConvTranspose2d``.
        norm: Normalization name, handled by ``apply_parametrization_norm`` and ``get_norm_module``.
        norm_kwargs: Extra keyword arguments for the normalization module; ``None`` means none.
        **kwargs: Keyword arguments forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, *args, norm: str = "none", norm_kwargs: dict[str, Any] | None = None, **kwargs):
        super().__init__()
        # A fresh dict per call instead of a shared mutable default.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Recorded like the other Norm* wrappers in this module do.
        self.norm_type = norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution followed by its normalization."""
        x = self.convtr(x)  # [B,C_out,H_out,W_out]
        x = self.norm(x)  # [B,C_out,H_out,W_out]
        return x  # [B,C_out,H_out,W_out]
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: dict[str, Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn( + "SConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B,C,T] + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) # [B,C,T+padding_total+extra_padding] + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) # [B,C,T+padding_total+extra_padding] + return self.conv(x) # [B,C_out,T_out] + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: dict[str, Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, out_channels, kernel_size, stride, causal=causal, norm=norm, norm_kwargs=norm_kwargs + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1.0, ( + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + ) + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B,C,T] + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) # [B,C_out,T*stride+padding_total] + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. 
+ if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) # [B,C_out,T_out] + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) # [B,C_out,T_out] + return y # [B,C_out,T_out] diff --git a/vllm_omni/diffusion/models/cosmos3/guardrails.py b/vllm_omni/diffusion/models/cosmos3/guardrails.py new file mode 100644 index 00000000000..a085c3f3a59 --- /dev/null +++ b/vllm_omni/diffusion/models/cosmos3/guardrails.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cosmos3 guardrail hooks for vllm-omni. + +Text: Blocklist (keyword matching) + Qwen3Guard (0.6B LLM classifier) +Video: SigLIP-based content safety filter + RetinaFace face blur + +Enable via custom_pipeline_args or the test script: + python test_cosmos3.py --model ... 
class SafetyClassifier(nn.Module):
    """MLP head: two hidden stages (Linear -> BatchNorm1d -> ReLU) of width 512 and 256,
    followed by a final Linear producing ``num_classes`` logits.
    """

    def __init__(self, input_size: int = 1152, num_classes: int = 7):
        super().__init__()
        stack: list[nn.Module] = []
        width = input_size
        # Hidden stages are appended in order, so sequence indices (and parameter names) are fixed.
        for hidden in (512, 256):
            stack += [nn.Linear(width, hidden), nn.BatchNorm1d(hidden), nn.ReLU()]
            width = hidden
        stack.append(nn.Linear(width, num_classes))
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of feature vectors."""
        logits = self.layers(x)
        return logits
_pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray: + h, w = face_img.shape[:2] + if h == 0 or w == 0: + return face_img + temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR) + return cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST) + + +# --------------------------------------------------------------------------- +# Default guardrail builders +# --------------------------------------------------------------------------- +def _download_checkpoint() -> str: + from huggingface_hub import snapshot_download + + return snapshot_download(GUARDRAIL_HF_REPO, revision=GUARDRAIL_HF_REVISION) + + +def _build_text_guardrail(offload_to_cpu: bool) -> TextGuardrailFn: + checkers: list[Callable[[str], tuple[bool, str]]] = [] + + # 1. Blocklist + try: + import nltk + from better_profanity import profanity as profanity_filter + + ckpt_dir = _download_checkpoint() + blocklist_dir = os.path.join(ckpt_dir, "blocklist") + nltk.data.path.append(os.path.join(blocklist_dir, "nltk_data")) + + def _read_keywords(dirpath: str) -> list[str]: + words: list[str] = [] + if not os.path.isdir(dirpath): + return words + for fname in sorted(os.listdir(dirpath)): + fpath = os.path.join(dirpath, fname) + if os.path.isfile(fpath): + with open(fpath) as f: + words.extend(line.strip() for line in f if line.strip()) + return words + + blocklist_words = _read_keywords(os.path.join(blocklist_dir, "custom")) + whitelist_words = _read_keywords(os.path.join(blocklist_dir, "whitelist")) + profanity_filter.load_censor_words(custom_words=blocklist_words, whitelist_words=whitelist_words) + + def _blocklist_check(prompt: str) -> tuple[bool, str]: + if profanity_filter.contains_profanity(prompt): + return False, "Blocked by keyword filter" + return True, "" + + checkers.append(_blocklist_check) + if _is_rank_zero(): + logger.info("Blocklist guardrail loaded (%d keywords)", len(blocklist_words)) + except ImportError: + logger.warning("better-profanity or nltk not 
installed; skipping blocklist guardrail") + + # 2. Qwen3Guard + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_id = "Qwen/Qwen3Guard-Gen-0.6B" + qwen_tokenizer = AutoTokenizer.from_pretrained(model_id) + device = "cpu" if offload_to_cpu else "cuda" + qwen_model = ( + AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + ) + .to(device) + .eval() + ) + + def _qwen_check(prompt: str) -> tuple[bool, str]: + conversations = [{"role": "user", "content": prompt}] + input_ids = qwen_tokenizer.apply_chat_template( + conversations, + tokenize=True, + return_tensors="pt", + add_generation_prompt=True, + ).to(device) + with torch.no_grad(): + output_ids = qwen_model.generate(input_ids, max_new_tokens=128) + response = qwen_tokenizer.decode( + output_ids[0][input_ids.shape[1] :], + skip_special_tokens=True, + ) + if "unsafe" in response.lower(): + return False, f"Qwen3Guard: {response.strip()}" + return True, "" + + checkers.append(_qwen_check) + if _is_rank_zero(): + logger.info("Qwen3Guard guardrail loaded") + except ImportError: + logger.warning("transformers not installed; skipping Qwen3Guard") + + def text_guardrail(prompt: str) -> None: + for checker in checkers: + is_safe, msg = checker(prompt) + if not is_safe: + raise ValueError(f"Guardrail blocked prompt: {msg}") + + return text_guardrail + + +def _build_video_guardrail(offload_to_cpu: bool) -> VideoGuardrailFn: + ckpt_dir = _download_checkpoint() + safety_checker: Callable[[np.ndarray], tuple[bool, str]] | None = None + face_blurrer: Callable[[np.ndarray], np.ndarray] | None = None + + # 1. 
Video content safety filter: SigLIP so400m + SafetyClassifier + try: + from PIL import Image + from transformers import SiglipModel, SiglipProcessor + + device = "cpu" if offload_to_cpu else "cuda" + siglip_id = "google/siglip-so400m-patch14-384" + siglip_model = SiglipModel.from_pretrained(siglip_id).to(device, dtype=torch.float32).eval() + siglip_processor = SiglipProcessor.from_pretrained(siglip_id) + + classifier = SafetyClassifier(input_size=1152, num_classes=7) + ckpt_path = os.path.join(ckpt_dir, "video_content_safety_filter", "safety_filter.pt") + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + # Checkpoint keys have "network." prefix from the VideoSafetyModel wrapper. + state = {k.removeprefix("network."): v for k, v in checkpoint["model"].items()} + classifier.load_state_dict(state) + classifier = classifier.to(device, dtype=torch.float32).eval() + + def _safety_check(frames: np.ndarray) -> tuple[bool, str]: + nonlocal siglip_model, classifier + if offload_to_cpu: + siglip_model = siglip_model.to("cuda") + classifier = classifier.to("cuda") + + unsafe_count = 0 + total = len(frames) + for frame in frames: + img = Image.fromarray(frame) + inputs = siglip_processor(images=img, return_tensors="pt").to("cuda", dtype=torch.float32) + with torch.no_grad(): + features = siglip_model.get_image_features(**inputs) + features = features / features.norm(dim=-1, keepdim=True) + logits = classifier(features) + pred = logits.argmax(dim=-1).item() + class_name = CLASS_IDX_TO_NAME.get(pred, "Unknown") + if class_name != "Safe": + unsafe_count += 1 + + if offload_to_cpu: + siglip_model = siglip_model.to("cpu") + classifier = classifier.to("cpu") + + if unsafe_count / total > CUTOFF_UNSAFE_FRAMES_PERCENT / 100: + return False, f"Video content safety: {unsafe_count}/{total} frames unsafe" + return True, "" + + safety_checker = _safety_check + if _is_rank_zero(): + logger.info("Video content safety filter loaded (SigLIP so400m + classifier)") + 
except (ImportError, FileNotFoundError) as e: + logger.warning("Could not load video safety filter: %s", e) + + # 2. Face blur: RetinaFace + pixelation + try: + from retinaface.data import cfg_re50 + from retinaface.layers.functions.prior_box import PriorBox + from retinaface.models.retinaface import RetinaFace + from retinaface.utils.nms.py_cpu_nms import py_cpu_nms + + face_ckpt = os.path.join(ckpt_dir, "face_blur_filter", "Resnet50_Final.pth") + if not os.path.exists(face_ckpt): + raise FileNotFoundError(face_ckpt) + + cfg = dict(cfg_re50) + cfg["pretrain"] = False + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + retinaface_net = RetinaFace(cfg=cfg, phase="test") + + # Load weights (strip 'module.' prefix if present) + pretrained_dict = torch.load(face_ckpt, map_location="cpu", weights_only=True) + if "state_dict" in pretrained_dict: + pretrained_dict = pretrained_dict["state_dict"] + pretrained_dict = { + k.replace("module.", "", 1) if k.startswith("module.") else k: v for k, v in pretrained_dict.items() + } + retinaface_net.load_state_dict(pretrained_dict, strict=False) + retinaface_device = "cpu" if offload_to_cpu else "cuda" + retinaface_net = retinaface_net.to(retinaface_device, dtype=torch.float32).eval() + + CONF_THRESH = 0.7 + NMS_THRESH = 0.4 + TOP_K = 5000 + KEEP_TOP_K = 750 + + def _decode_batch(loc, priors, variances): + batch_size = loc.size(0) + p = priors.unsqueeze(0).expand(batch_size, -1, -1) + boxes = torch.cat( + ( + p[:, :, :2] + loc[:, :, :2] * variances[0] * p[:, :, 2:], + p[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]), + ), + dim=2, + ) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + def _face_blur(frames: np.ndarray) -> np.ndarray: + nonlocal retinaface_net + if offload_to_cpu: + retinaface_net = retinaface_net.to("cuda") + + prior_data = None + scale = None + result_frames = [] + + for frame in frames: + frame_t = torch.from_numpy(frame).to("cuda", 
dtype=torch.float32) + frame_t = frame_t.permute(2, 0, 1).unsqueeze(0) # [1, C, H, W] + frame_t = frame_t[:, [2, 1, 0], :, :] # RGB → BGR + means = torch.tensor([104.0, 117.0, 123.0], device="cuda", dtype=torch.float32).view(1, 3, 1, 1) + frame_t = frame_t - means + + h, w = frame_t.shape[2], frame_t.shape[3] + if prior_data is None: + priorbox = PriorBox(cfg, image_size=(h, w)) + prior_data = priorbox.forward().to("cuda", dtype=torch.float32) + if scale is None: + scale = torch.tensor([w, h, w, h], device="cuda", dtype=torch.float32) + + with torch.no_grad(): + loc, conf, _ = retinaface_net(frame_t) + + boxes = _decode_batch(loc, prior_data, cfg["variance"]) + boxes = (boxes * scale).squeeze(0).cpu().numpy() + scores = conf.squeeze(0)[:, 1].cpu().numpy() + + # Filter by confidence + inds = np.where(scores > CONF_THRESH)[0] + boxes_f = boxes[inds] + scores_f = scores[inds] + order = scores_f.argsort()[::-1][:TOP_K] + boxes_f = boxes_f[order] + scores_f = scores_f[order] + + # NMS + dets = np.hstack((boxes_f, scores_f[:, np.newaxis])).astype(np.float32) + keep = py_cpu_nms(dets, NMS_THRESH) + dets = dets[keep][:KEEP_TOP_K] + + out_frame = frame.copy() + for det in dets: + x1, y1, x2, y2 = map(int, det[:4]) + if x2 - x1 < 20 or y2 - y1 < 20: + continue + max_h, max_w = out_frame.shape[:2] + y1c, y2c = max(y1, 0), min(y2, max_h) + x1c, x2c = max(x1, 0), min(x2, max_w) + out_frame[y1c:y2c, x1c:x2c] = _pixelate_face(out_frame[y1c:y2c, x1c:x2c]) + + result_frames.append(out_frame) + + if offload_to_cpu: + retinaface_net = retinaface_net.to("cpu") + + return np.array(result_frames) + + face_blurrer = _face_blur + if _is_rank_zero(): + logger.info("Face blur filter loaded (RetinaFace Resnet50)") + except (ImportError, FileNotFoundError) as e: + logger.warning("Could not load face blur filter: %s", e) + + def video_guardrail(frames: np.ndarray) -> np.ndarray: + if safety_checker is not None: + is_safe, msg = safety_checker(frames) + if not is_safe: + raise 
def check_video_safety(video_tensor: torch.Tensor) -> torch.Tensor:
    """Run the registered video guardrail over a decoded video tensor.

    When no video guardrail is registered the input is returned unchanged
    (same object). Otherwise every clip is clamped to ``[-1, 1]``, converted
    to ``uint8`` frames with channels last, passed through the guardrail, and
    converted back to a float tensor in ``[-1, 1]``.

    Args:
        video_tensor: Decoded video, either a single clip with channels on
            axis 0 followed by frames, height and width, or a 5-D batch of
            such clips with the batch on axis 0.

    Returns:
        A float32 tensor with the same number of dimensions as the input,
        placed on the input's device. Values have gone through an 8-bit
        round trip, so they are quantized relative to the input.

    Raises:
        Whatever the registered guardrail raises; the default guardrail in
        this module raises ``ValueError`` when a video is blocked.
    """
    if _video_guardrail is None:
        return video_tensor

    batched = video_tensor.dim() == 5
    clips = video_tensor.detach().cpu().float()
    if not batched:
        clips = clips.unsqueeze(0)

    # Check every clip in the batch. Previously only the first clip was
    # inspected and returned, silently dropping the rest of the batch.
    checked: list[torch.Tensor] = []
    for clip in clips:
        unit = clip.clamp(-1, 1) * 0.5 + 0.5
        frames_np = (unit.permute(1, 2, 3, 0).numpy() * 255).round().astype(np.uint8)

        frames_np = _video_guardrail(frames_np)

        # Convert back to [-1, 1] to match the VAE output range.
        restored = torch.from_numpy(frames_np.copy()).float() / 127.5 - 1.0
        checked.append(restored.permute(3, 0, 1, 2))

    if not checked:
        # Empty batch: there is nothing to check or convert.
        return video_tensor

    result = torch.stack(checked, dim=0) if batched else checked[0]
    return result.to(video_tensor.device)
+* ``multi_modal_data['image']`` present on the prompt: **I2V** + (handled by :func:`get_cosmos3_pre_process_func`) + +""" + +from __future__ import annotations + +import os +import time +from collections.abc import Iterable +from typing import Any, ClassVar + +import numpy as np +import PIL.Image +import torch +from diffusers import UniPCMultistepScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from torch import nn +from transformers import AutoTokenizer +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import DistributedAutoencoderKLWan +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero +from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .action import ( + ACTION_MODE_FORWARD_DYNAMICS, + ACTION_MODE_INVERSE_DYNAMICS, + ACTION_MODE_POLICY, + action_start_frame_offset, + build_action_condition_mask, + build_vision_condition_mask, + find_closest_target_size, + load_action_tensor, + normalize_action_mode, + pad_action_to_dim, + resolve_domain_id, +) +from .transformer_cosmos3 import Cosmos3VFMTransformer + +logger = init_logger(__name__) + +COSMOS3_DEFAULT_NEGATIVE_PROMPT = "" +COSMOS3_DURATION_TEMPLATE = "The video is {duration:.1f} seconds long and is of 
{fps:.0f} FPS." +COSMOS3_RESOLUTION_TEMPLATE = "This video is of {height}x{width} resolution." +COSMOS3_IMAGE_RESOLUTION_TEMPLATE = "This image is of {height}x{width} resolution." +COSMOS3_INVERSE_DURATION_TEMPLATE = "The video is not {duration:.1f} seconds long and is not of {fps:.0f} FPS." +COSMOS3_INVERSE_RESOLUTION_TEMPLATE = "This video is not of {height}x{width} resolution." +COSMOS3_INVERSE_IMAGE_RESOLUTION_TEMPLATE = "This image is not of {height}x{width} resolution." +COSMOS3_SYSTEM_PROMPT = "You are a helpful assistant who will generate videos from a given prompt." +COSMOS3_T2I_SYSTEM_PROMPT = "You are a helpful assistant who will generate images from a given prompt." + + +# --------------------------------------------------------------------------- +# Post-process function (registered in registry.py) +# --------------------------------------------------------------------------- +def get_cosmos3_pre_process_func(od_config: OmniDiffusionConfig): + """Pre-process function for both T2V and I2V. + + For T2V (no image in ``multi_modal_data``), the request is returned + unchanged after the optional guardrails check. For I2V (image present), + the conditioning image is loaded, aspect-resized + center-cropped, and + stored back on the prompt as ``additional_information.preprocessed_image``. 
+ """ + from .guardrails import check_text_safety, ensure_initialized, is_guardrails_enabled + + video_processor = VideoProcessor(vae_scale_factor=16) + guardrails_on = is_guardrails_enabled(od_config) + if guardrails_on: + ensure_initialized(od_config) + + def _extra_args(request: OmniDiffusionRequest) -> dict[str, Any]: + extra = getattr(getattr(request, "sampling_params", None), "extra_args", None) + return extra if isinstance(extra, dict) else {} + + def _request_action_mode(request: OmniDiffusionRequest) -> str | None: + return normalize_action_mode(_extra_args(request).get("action_mode")) + + def _set_action_size_from_image(request: OmniDiffusionRequest, image: PIL.Image.Image) -> tuple[int, int]: + sp = request.sampling_params + if sp.height is not None and sp.width is not None: + return int(sp.height), int(sp.width) + + extra = _extra_args(request) + resolution = extra.get("resolution", extra.get("image_size", 480)) + target_w, target_h = find_closest_target_size(image.height, image.width, resolution) + if sp.height is None: + sp.height = target_h + if sp.width is None: + sp.width = target_w + return int(sp.height), int(sp.width) + + def _pil_to_rgb(value: Any) -> PIL.Image.Image: + if isinstance(value, str): + return PIL.Image.open(value).convert("RGB") + if isinstance(value, PIL.Image.Image): + return value.convert("RGB") + raise TypeError(f"Cosmos3 action preprocessing expected PIL image or image path, got {type(value)!r}.") + + def _resize_and_pad_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> PIL.Image.Image: + scale = min(target_w / image.width, target_h / image.height, 1.0) + resize_w = max(1, int(scale * image.width + 0.5)) + resize_h = max(1, int(scale * image.height + 0.5)) + if (resize_w, resize_h) != image.size: + image = image.resize((resize_w, resize_h), PIL.Image.Resampling.BICUBIC) + + array = np.asarray(image) + pad_h = target_h - resize_h + pad_w = target_w - resize_w + if pad_h < 0 or pad_w < 0: + raise ValueError( 
+ f"Cosmos3 action image resize exceeded target size: resized={(resize_h, resize_w)}, " + f"target={(target_h, target_w)}." + ) + if pad_h == 0 and pad_w == 0: + return image + pad_mode = "reflect" if pad_h < resize_h and pad_w < resize_w else "edge" + padded = np.pad(array, ((0, pad_h), (0, pad_w), (0, 0)), mode=pad_mode) + return PIL.Image.fromarray(padded) + + def _preprocess_action_image(image: PIL.Image.Image, target_h: int, target_w: int) -> torch.Tensor: + image = _resize_and_pad_action_image(image, target_h, target_w) + return video_processor.preprocess(image, height=target_h, width=target_w) + + def _preprocess_action_video(frames: list[Any], target_h: int, target_w: int) -> torch.Tensor: + if not frames: + raise ValueError("Cosmos3 action video input must contain at least one frame.") + processed = [_preprocess_action_image(_pil_to_rgb(frame), target_h, target_w).squeeze(0) for frame in frames] + return torch.stack(processed, dim=1).unsqueeze(0).contiguous() + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + action_mode = _request_action_mode(request) + if guardrails_on: + for prompt in request.prompts: + text = prompt if isinstance(prompt, str) else prompt.get("prompt", "") + check_text_safety(text) + + for i, prompt in enumerate(request.prompts): + if isinstance(prompt, str): + continue + multi_modal_data = prompt.get("multi_modal_data", {}) or {} + raw_image = multi_modal_data.get("image") + raw_video = multi_modal_data.get("video") + if raw_image is None and not (action_mode is not None and raw_video is not None): + continue + + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + if not isinstance(raw_video, list) or not raw_video: + raise TypeError("Cosmos3 action video input must be a non-empty list of PIL images or image paths.") + image = _pil_to_rgb(raw_video[0]) + else: + image = _pil_to_rgb(raw_image) + + # Auto-calculate H/W from aspect ratio (720p max 
def get_cosmos3_post_process_func(od_config: OmniDiffusionConfig):
    """Build the post-process callable for Cosmos3 pipeline outputs.

    The returned ``post_process_func`` accepts the pipeline output as a bare
    tensor, a dict with an ``"image"`` or ``"video"`` entry (plus optional
    ``"audio"`` / ``"audio_sample_rate"``), or a ``(video, audio)`` /
    ``(video, audio, sample_rate)`` tuple, optionally runs the video
    guardrail, and converts the result with ``VideoProcessor``.

    Args:
        od_config: Diffusion config; only used here to decide whether
            guardrails are enabled.

    Returns:
        The ``post_process_func`` closure.
    """
    from .guardrails import check_video_safety, is_guardrails_enabled

    video_processor = VideoProcessor(vae_scale_factor=16)
    # Evaluated once when the closure is built, not per request.
    guardrails_on = is_guardrails_enabled(od_config)

    def _sampling_param(sampling_params, key: str, default=None):
        # Look up ``key`` in ``sampling_params.extra_args`` first, then as a
        # plain attribute; ``None`` at either level counts as "not set".
        extra = getattr(sampling_params, "extra_args", None)
        if isinstance(extra, dict) and extra.get(key) is not None:
            return extra[key]
        value = getattr(sampling_params, key, None)
        return default if value is None else value

    def _resolve_output_fps(sampling_params):
        # First truthy value wins; a value of 0 therefore falls through to
        # the next candidate and finally to 24.0.
        fps = (
            _sampling_param(sampling_params, "resolved_frame_rate")
            or _sampling_param(sampling_params, "frame_rate")
            or _sampling_param(sampling_params, "fps")
            or 24.0
        )
        try:
            fps_value = float(fps)
        except (TypeError, ValueError):
            fps_value = 24.0
        if fps_value <= 0:
            fps_value = 24.0
        # Whole-number rates are returned as int, fractional ones as float.
        return int(fps_value) if fps_value.is_integer() else fps_value

    def post_process_func(
        output: torch.Tensor | dict[str, torch.Tensor] | tuple,
        output_type: str = "np",
        sampling_params=None,
    ):
        """Convert a raw pipeline output into its final form.

        Returns ``output`` untouched when ``output_type == "latent"``. For an
        ``"image"`` dict payload, returns the result of
        ``video_processor.postprocess(..., output_type="pil")`` regardless of
        ``output_type``. Otherwise returns a dict with a ``"video"`` entry
        and, only when audio is present, ``"audio"``, ``"fps"`` and
        (if provided) ``"audio_sample_rate"``.

        Raises:
            ValueError: for a dict with both or neither of ``"image"`` /
                ``"video"``, a tuple whose length is not 2 or 3, an image
                payload accompanied by audio, or an image payload that is
                not 5-D with a size-1 axis 2.
        """
        if output_type == "latent":
            return output

        audio = None
        audio_sample_rate = None
        if isinstance(output, dict):
            if "image" in output and "video" in output:
                raise ValueError("Cosmos3 output cannot contain both image and video payloads.")
            if "image" in output:
                video = output["image"]
            elif "video" in output:
                video = output["video"]
            else:
                raise ValueError("Cosmos3 postprocess expected an 'image' or 'video' output payload.")
            audio = output.get("audio")
            audio_sample_rate = output.get("audio_sample_rate")
        elif isinstance(output, tuple):
            if len(output) == 3:
                video, audio, audio_sample_rate = output
            elif len(output) == 2:
                video, audio = output
            else:
                raise ValueError(
                    "Cosmos3 postprocess expects output tensor, output dict, or (video, audio[, sample_rate]) tuple."
                )
        else:
            video = output

        # Image branch: only reachable for dict outputs carrying "image".
        if isinstance(output, dict) and "image" in output:
            if audio is not None:
                raise ValueError("Cosmos3 text-to-image postprocess does not support audio output.")
            if video.ndim != 5 or video.shape[2] != 1:
                raise ValueError(
                    "Cosmos3 text-to-image postprocess expects decoded output "
                    f"with shape [B, C, 1, H, W], got {tuple(video.shape)}."
                )
            image = video.squeeze(2)  # [B, 3, H, W]
            if guardrails_on:
                # check_video_safety expects a 5D tensor; re-add T axis.
                checked = check_video_safety(image.unsqueeze(2))
                image = checked.squeeze(2)
            return video_processor.postprocess(image, output_type="pil")
        if guardrails_on:
            video = check_video_safety(video)
        result = {"video": video_processor.postprocess_video(video, output_type=output_type)}
        # Video-only output: no audio, fps or sample-rate keys are added.
        if audio is None:
            return result
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu()
        result["audio"] = audio
        result["fps"] = _resolve_output_fps(sampling_params)
        if audio_sample_rate is not None:
            result["audio_sample_rate"] = int(audio_sample_rate)
        return result

    return post_process_func
Latent + T-dim is forced to 1, T2I-specific scheduler defaults are applied (50 steps, + flow_shift=3.0, guidance_interval=[400, 1000]), the duration + template is suppressed, and post-process emits PIL images. + * **I2V** when the request supplies a preprocessed image via + ``multi_modal_data['image']`` (handled by + :func:`get_cosmos3_pre_process_func`) and the requested output modality + is not image. + Frame 0 of the initial latent is set to the VAE-encoded conditioning + image, frame-0 noise predictions are masked to zero, and the clean + image latent is re-injected at frame 0 after each scheduler step. + * **T2V** otherwise (default video generation). + """ + + support_image_input: ClassVar[bool] = True + color_format: ClassVar[str] = "RGB" + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ) -> None: + super().__init__() + if od_config.enable_cpu_offload: + raise ValueError( + "Cosmos3 has no separate text encoder, so CPU offloading " + "(transformer↔encoder swapping) is not supported. " + "Use --enable-layerwise-offload instead." + ) + self.od_config = od_config + self.device = get_local_device() + self.dtype = getattr(od_config, "dtype", torch.bfloat16) + + model_path = od_config.model + local_files_only = os.path.exists(model_path) + + # --- Tokenizer --- + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="text_tokenizer", + local_files_only=local_files_only, + ) + + # --- VAE --- + self.vae = DistributedAutoencoderKLWan.from_pretrained( + model_path, + subfolder="vae", + torch_dtype=torch.bfloat16, + local_files_only=local_files_only, + ).to(self.device) + + if not hasattr(self.vae.config, "scale_factor_temporal"): + raise ValueError( + "Cosmos3 Diffusers VAE config must define scale_factor_temporal " + "so transformer mRoPE temporal positions can be computed correctly." 
+ ) + self.vae_scale_factor_temporal = int(self.vae.config.scale_factor_temporal) + self.vae_scale_factor_spatial = getattr(self.vae.config, "scale_factor_spatial", 16) + + # --- Transformer (weights loaded later via weights_sources) --- + self.transformer = Cosmos3VFMTransformer( + od_config=od_config, + temporal_compression_factor=self.vae_scale_factor_temporal, + ) + + # --- Scheduler --- + # Load from checkpoint to preserve solver_order, timestep_spacing, + # beta_schedule, sigma bounds, flow_shift, etc. Only override + # flow_shift when explicitly requested by the user. + self.scheduler = UniPCMultistepScheduler.from_pretrained( + model_path, + subfolder="scheduler", + local_files_only=local_files_only, + ) + if od_config.flow_shift is not None: + self.scheduler = UniPCMultistepScheduler.from_config(self.scheduler.config, flow_shift=od_config.flow_shift) + + # --- Video processor for post-decode --- + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # --- Weight sources for DiffusersPipelineLoader --- + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder=None, + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + allow_patterns_overrides=["transformer/*.safetensors"], + ), + ] + + # Snapshot the loaded scheduler config so we can rebuild the + # scheduler at request time when a per-request flow_shift override + # is supplied (T2I uses shift=3.0; T2V/I2V use the engine default). + self._base_scheduler_config = self.scheduler.config + # ``_engine_init_flow_shift`` is the shift the engine was configured + # with at init time (after the optional ``od_config.flow_shift`` + # override). This is the value T2V/I2V requests fall back to. + # ``_current_flow_shift`` tracks the shift the scheduler *currently* + # uses, since per-request rebuilds in ``_set_flow_shift`` must be + # detectable on the next request to restore the prior shift. 
+ self._engine_init_flow_shift = float(getattr(self.scheduler.config, "flow_shift", 1.0) or 1.0) + self._current_flow_shift = self._engine_init_flow_shift + + self._guidance_scale = None + self._num_timesteps = None + self._loaded_weight_names: set[str] = set() + self._sound_tokenizer = None + if getattr(self.transformer, "sound_gen", False): + self._get_sound_tokenizer() + + self.setup_diffusion_pipeline_profiler( + enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler + ) + + # -- Weight loading -------------------------------------------------------- + + @staticmethod + def _remap_ckpt_key(key: str) -> str | None: + """Remap a Diffusers transformer key to the model parameter namespace. + + Checkpoint keys arrive with a synthetic ``transformer.`` prefix from + ``weights_sources``. The source checkpoint itself uses the Diffusers + transformer namespace: top-level projections plus ``model.*`` for the + Qwen3-VL backbone. UND and GEN components share each layer in the + source and are split into separate module lists here. + + Returns the remapped name under ``transformer.``, or None to skip. + """ + k = key + # Strip the weights_sources prefix + if k.startswith("transformer."): + k = k[len("transformer.") :] + + # Top-level generation components. 
+ if k.startswith( + ( + "vae2llm.", + "llm2vae.", + "time_embedder.", + "sound2llm.", + "llm2sound.", + "action2llm.", + "llm2action.", + ) + ): + return f"transformer.{k}" + if k in ("sound_modality_embed", "sound_modality_embed.weight"): + return "transformer.sound_modality_embed" + if k in ("action_modality_embed", "action_modality_embed.weight"): + return "transformer.action_modality_embed" + if k.startswith("action_pos_embed."): + return None + + # Skip lm_head + if k.startswith("lm_head."): + return None + + # embed_tokens / norm → language_model.* + if k.startswith("model.embed_tokens."): + return f"transformer.language_model.{k[len('model.') :]}" + if k.startswith("model.norm."): + return f"transformer.language_model.{k[len('model.') :]}" + + # norm_moe_gen → top level + if k.startswith("model.norm_moe_gen."): + return f"transformer.{k[len('model.') :]}" + + if not k.startswith("model.layers."): + return None + k = k[len("model.") :] + + if not k.startswith("layers."): + return None + + parts = k.split(".", 2) # ['layers', '{i}', '{rest}'] + if len(parts) != 3: + return None + layer_idx = parts[1] + rest = parts[2] + + und_lp = f"transformer.language_model.layers.{layer_idx}" + gen_lp = f"transformer.gen_layers.{layer_idx}" + + _LAYER_MAP = { + # UND attention + "self_attn.q_proj.": f"{und_lp}.self_attn.q_proj.", + "self_attn.k_proj.": f"{und_lp}.self_attn.k_proj.", + "self_attn.v_proj.": f"{und_lp}.self_attn.v_proj.", + "self_attn.o_proj.": f"{und_lp}.self_attn.o_proj.", + "self_attn.q_norm.": f"{und_lp}.self_attn.q_norm.", + "self_attn.k_norm.": f"{und_lp}.self_attn.k_norm.", + # GEN attention + "self_attn.q_proj_moe_gen.": f"{gen_lp}.cross_attention.q_proj.", + "self_attn.k_proj_moe_gen.": f"{gen_lp}.cross_attention.k_proj.", + "self_attn.v_proj_moe_gen.": f"{gen_lp}.cross_attention.v_proj.", + "self_attn.o_proj_moe_gen.": f"{gen_lp}.cross_attention.o_proj.", + "self_attn.q_norm_moe_gen.": f"{gen_lp}.cross_attention.q_norm.", + 
"self_attn.k_norm_moe_gen.": f"{gen_lp}.cross_attention.k_norm.", + # Norms + "input_layernorm.": f"{und_lp}.input_layernorm.", + "post_attention_layernorm.": f"{und_lp}.post_attention_layernorm.", + "input_layernorm_moe_gen.": f"{gen_lp}.input_layernorm.", + "post_attention_layernorm_moe_gen.": f"{gen_lp}.post_attention_layernorm.", + # UND MLP + "mlp.gate_proj.": f"{und_lp}.mlp.gate_proj.", + "mlp.up_proj.": f"{und_lp}.mlp.up_proj.", + "mlp.down_proj.": f"{und_lp}.mlp.down_proj.", + # GEN MLP + "mlp_moe_gen.gate_proj.": f"{gen_lp}.mlp.gate_proj.", + "mlp_moe_gen.up_proj.": f"{gen_lp}.mlp.up_proj.", + "mlp_moe_gen.down_proj.": f"{gen_lp}.mlp.down_proj.", + } + + for pattern, replacement in _LAYER_MAP.items(): + if rest.startswith(pattern): + suffix = rest[len(pattern) :] + return replacement + suffix + + return None + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Stream-remap checkpoint weights and load via AutoWeightsLoader. + + Handles quantization, TP-aware weight_loader, and buffer loading. + Returns the set of loaded parameter names for strict validation. 
+ """ + state = self.state_dict() + allowed = set(state.keys()) + tp_aware = {n for n, p in self.named_parameters() if hasattr(p, "weight_loader")} + + def _remapped_weights() -> Iterable[tuple[str, torch.Tensor]]: + total = kept = 0 + for name, tensor in weights: + total += 1 + remapped = self._remap_ckpt_key(name) + if remapped is not None and (remapped in allowed or remapped in tp_aware): + kept += 1 + yield remapped, tensor + if _is_rank_zero(): + logger.info( + "Cosmos3 weight remap: kept %d/%d tensors", + kept, + total, + ) + + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(_remapped_weights()) + self.transformer.post_load_weights() + self.transformer.eval() + self._loaded_weight_names = set(loaded) + if getattr(self.transformer, "sound_gen", False): + sound_markers = ("sound2llm.", "llm2sound.", "sound_modality_embed") + missing = [marker.rstrip(".") for marker in sound_markers if not any(marker in name for name in loaded)] + if missing: + raise ValueError( + "Cosmos3 transformer config enables sound generation, but " + f"the checkpoint is missing sound weights for {missing}. " + "Use a sound-capable transformer checkpoint." + ) + if getattr(self.transformer, "action_gen", False): + action_markers = ("action2llm.", "llm2action.", "action_modality_embed") + missing = [marker.rstrip(".") for marker in action_markers if not any(marker in name for name in loaded)] + if missing: + raise ValueError( + "Cosmos3 transformer config enables action generation, but " + f"the checkpoint is missing action weights for {missing}. " + "Use an action-capable transformer checkpoint." + ) + return loaded + + def predict_noise(self, **kwargs) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Override CFGParallelMixin.predict_noise for Cosmos3. + + The transformer returns the raw prediction: video-only as a tensor, + or a tuple in video, action, sound order for multimodal generation. 
+ """ + return self.transformer(**kwargs) + + @staticmethod + def _cfg_parallel_active() -> bool: + try: + return get_classifier_free_guidance_world_size() > 1 + except Exception: + return False + + @staticmethod + def _get_sp_param(sp, key: str, default=None): + """Read a runtime control from sampling params. + + Order of precedence: + 1. ``sp.extra_args[key]`` - preferred path; the OpenAI image/video + endpoints surface custom controls here (see e.g. + ``serving_video.py`` writing ``extra_args['flow_shift']``). + 2. direct attribute on ``sp`` - backward compat for callers that + set attributes directly. + 3. ``default``. + + Skipping this helper would cause API-driven overrides like + ``request.flow_shift`` (forwarded as ``extra_args['flow_shift']``) to + be silently ignored. + """ + extra = getattr(sp, "extra_args", None) + if isinstance(extra, dict) and extra.get(key) is not None: + return extra[key] + val = getattr(sp, key, None) + if val is not None: + return val + return default + + @staticmethod + def _truthy(value) -> bool: + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + @classmethod + def _get_prompt_param(cls, prompt_data, key: str, default=None): + if not isinstance(prompt_data, dict): + return default + if prompt_data.get(key) is not None: + return prompt_data[key] + additional = prompt_data.get("additional_information") + if isinstance(additional, dict) and additional.get(key) is not None: + return additional[key] + return default + + @classmethod + def _is_sound_request(cls, prompt_data, sp) -> bool: + keys = ( + "sound_gen", + "generate_sound", + "enable_sound_generation", + "return_audio", + "output_audio", + "generate_audio", + ) + for key in keys: + if cls._truthy(cls._get_prompt_param(prompt_data, key, None)): + return True + if cls._truthy(cls._get_sp_param(sp, key, None)): + return True + return False + + @classmethod + def _get_action_mode(cls, prompt_data, sp) -> str | None: 
+ return normalize_action_mode( + cls._get_sp_param(sp, "action_mode", cls._get_prompt_param(prompt_data, "action_mode", None)) + ) + + def _get_sound_tokenizer(self): + if not hasattr(self, "_sound_tokenizer"): + self._sound_tokenizer = None + if self._sound_tokenizer is None: + from .sound_tokenizer import Cosmos3SoundTokenizer + + self._sound_tokenizer = Cosmos3SoundTokenizer.from_config(self.od_config) + return self._sound_tokenizer + + @staticmethod + def _is_t2i_request(req: OmniDiffusionRequest) -> bool: + """Detect text-to-image mode from request-level prompt modalities.""" + if not req.prompts: + return False + first_prompt = req.prompts[0] + modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else [] + if modalities is None: + modalities = [] + if isinstance(modalities, str): + modalities = [modalities] + if "image" in modalities and "video" in modalities: + raise ValueError("Cosmos3 prompt modalities cannot request both image and video output.") + return "image" in modalities + + def _set_flow_shift(self, target_shift: float) -> None: + """Set the UniPC ``flow_shift`` to a concrete target value. + + The scheduler is rebuilt from the saved base config if + the target differs from the current shift. Tracking + ``self._current_flow_shift`` explicitly is required because the + previous mode may have rebuilt the scheduler - we cannot rely on + ``self.scheduler.config.flow_shift`` reflecting the last requested + target if a rebuild was skipped via the equality check. 
+ """ + target = float(target_shift) + if target == float(self._current_flow_shift): + return + self.scheduler = UniPCMultistepScheduler.from_config(self._base_scheduler_config, flow_shift=target) + self._current_flow_shift = target + + def _set_scheduler_timesteps(self, num_inference_steps: int) -> None: + for name, value in vars(self.scheduler).items(): + if isinstance(value, torch.Tensor) and value.device.type != "cpu": + setattr(self.scheduler, name, value.cpu()) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale is not None and self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + # -- Prompt formatting ----------------------------------------------------- + + @staticmethod + def _apply_metadata_templates( + prompt: str, + num_frames: int, + frame_rate: float, + height: int, + width: int, + duration_template: str | None = COSMOS3_DURATION_TEMPLATE, + resolution_template: str | None = COSMOS3_RESOLUTION_TEMPLATE, + force_duration_template: bool = False, + ) -> str: + """Append duration and resolution metadata to a prompt. + + Strips trailing dot and appends ``".