Closed
47 commits
1cc4059
Added Cosmos3 model
MaciejBalaNV May 14, 2026
ee77c61
Small qol improvements
MaciejBalaNV May 14, 2026
0d0542f
Updated docs for Cosmos3
MaciejBalaNV May 14, 2026
bd4ecb3
Cleared up docs
MaciejBalaNV May 14, 2026
921cc4b
Fixed sound quality issues
MaciejBalaNV May 14, 2026
75d9888
Linter fixes
MaciejBalaNV May 14, 2026
9e9e453
extra cleanup
MaciejBalaNV May 14, 2026
3c7cd31
Updated examples to refer to HF repo
MaciejBalaNV May 15, 2026
c36b23c
Improved guardrails
MaciejBalaNV May 15, 2026
095f6b5
Linter fixes
MaciejBalaNV May 15, 2026
0585095
Reworked guardrail error
MaciejBalaNV May 18, 2026
144ffdf
Simplify sound tokenizer and bring parity with diffusers
MaciejBalaNV May 18, 2026
6775a91
Rename _layerwise_offload_blocks_attr
MaciejBalaNV May 18, 2026
1c7f785
Merge branch 'main' into mbala/cosmos3_model
MaciejBalaNV May 18, 2026
ab3c262
Added yaml example without guardrails
MaciejBalaNV May 18, 2026
8948494
Updated examples; improved action generation
MaciejBalaNV May 19, 2026
1a93c89
Merge branch 'main' into mbala/cosmos3_model
MaciejBalaNV May 19, 2026
4fa8f65
Doc cleanup
MaciejBalaNV May 19, 2026
3ac4c71
Doc cleanup v2
MaciejBalaNV May 19, 2026
b16e4c7
Updated deploy config
MaciejBalaNV May 19, 2026
149700e
Removed examples for now
MaciejBalaNV May 20, 2026
a063830
Simplify tests
MaciejBalaNV May 20, 2026
386d40f
Scope Cosmos3 to core generation
MaciejBalaNV May 20, 2026
84dc378
Merge branch 'vllm-project:main' into mbala/cosmos3_model
MaciejBalaNV May 20, 2026
c3500c0
Add back some of the unnecessarily deleted code
MaciejBalaNV May 20, 2026
8536f5b
Linter fixes
MaciejBalaNV May 20, 2026
bcf609e
Simplified the guardrails with cosmos-guardrail package
MaciejBalaNV May 22, 2026
3bf5f73
Remove the Cosmos3 model from pipeline due to loading issues
MaciejBalaNV May 22, 2026
c4b8886
Removed introduced guardrail error
MaciejBalaNV May 22, 2026
0d325bf
Improved error message
MaciejBalaNV May 22, 2026
397d2fe
Remove custom RMSNorm
MaciejBalaNV May 22, 2026
82b8a2c
Added Cosmos3 as a pipeline
MaciejBalaNV May 22, 2026
3c7305a
Merge branch 'main' into mbala/cosmos3_model
MaciejBalaNV May 22, 2026
2e88a23
Cleanup
MaciejBalaNV May 22, 2026
bd12fd4
Adapted to new checkpoint format
MaciejBalaNV May 25, 2026
49fb11a
Merge branch 'main' into mbala/cosmos3_model
MaciejBalaNV May 25, 2026
2c8ec84
Linter fixes
MaciejBalaNV May 25, 2026
6d16f65
Cleaned up the code
MaciejBalaNV May 25, 2026
85dfeb2
Fixed a test
MaciejBalaNV May 25, 2026
c6185a1
Fixed multi-GPU with deploy configs
MaciejBalaNV May 26, 2026
64bfa80
Merge branch 'main' into mbala/cosmos3_model
MaciejBalaNV May 28, 2026
ed5e391
Linter fix
MaciejBalaNV May 28, 2026
5a4eb97
Reverted c6185a1324397b3729b1c67f4b59b8438d317957
MaciejBalaNV May 28, 2026
d9a35ab
Fix CUDNN_ATTN for GQA
MaciejBalaNV May 28, 2026
7c91b3c
Answered review comments
MaciejBalaNV May 28, 2026
b78f881
Add Cosmos3 sound generation
MaciejBalaNV May 28, 2026
2d0725e
Added action generation
MaciejBalaNV May 28, 2026
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -33,6 +33,7 @@ th {
| `ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `WanPipeline` | Wan2.1-T2V, Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `Cosmos3OmniDiffusersPipeline` | Cosmos3 T2I, T2V, I2V, T2V with sound, action policy | `nvidia/Cosmos3-Nano` | ✅︎ | | | |
| `WanSpeechToVideoPipeline` | Wan2.2-S2V | `Wan-AI/Wan2.2-S2V-14B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `Wan22VACEPipeline` | Wan2.1-VACE | `Wan-AI/Wan2.1-VACE-1.3B-diffusers`, `Wan-AI/Wan2.1-VACE-14B-diffusers` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | |
4 changes: 4 additions & 0 deletions docs/user_guide/diffusion_features.md
@@ -133,10 +133,12 @@ The following tables show which models support each feature:
| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ (decode) | ❌ | ❌ |
| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ❌ | ✅ | ❌ | ✅ (decode) | ✅ | ❌ |
| **ERNIE-Image** | ❌ | ✅ | ✅ | ❓ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **Cosmos3** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |

> Notes:
> 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT.
> 2. `Tongyi-MAI/Z-Image-Turbo` and `SII-GAIR/daVinci-MagiHuman-Base-1080p` are distilled models with minimal NFEs; CFG-Parallel is not necessary.
> 3. Cosmos3 T2I uses `Cosmos3OmniDiffusersPipeline` with `modalities=["image"]`. Model-level CPU offload is not supported; use layerwise offload.

### VideoGen

@@ -149,6 +151,8 @@ The following tables show which models support each feature:
| **Helios** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **Cosmos3** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ (encode/decode) | ✅ | ❌ |


**Frame Interpolation Support**

19 changes: 19 additions & 0 deletions tests/diffusion/cache/test_cache_dit.py
@@ -25,6 +25,7 @@
    cd_backend.enable_cache_for_helios,
    cd_backend.enable_cache_for_wan22,
    cd_backend.enable_cache_for_longcat_image,
    cd_backend.enable_cache_for_cosmos3,
]

SAMPLE_CACHE_CONFIG = DiffusionCacheConfig()
@@ -47,6 +48,24 @@ def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler):
    assert adapter_kwargs["has_separate_cfg"] is True


@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter")
@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
def test_cosmos3_cache_dit_wraps_gen_layers(mock_cache_dit, mock_block_adapter):
    """Cosmos3 should cache only the repeated GEN pathway blocks."""
    mock_pipeline = Mock()
    gen_layers = object()
    mock_pipeline.transformer.gen_layers = gen_layers

    cd_backend.enable_cache_for_cosmos3(mock_pipeline, SAMPLE_CACHE_CONFIG)

    mock_cache_dit.enable_cache.assert_called_once()
    adapter_kwargs = mock_block_adapter.call_args.kwargs
    assert adapter_kwargs["transformer"] is mock_pipeline.transformer
    assert adapter_kwargs["blocks"] == [gen_layers]
    assert adapter_kwargs["has_separate_cfg"] is True
    assert adapter_kwargs["check_forward_pattern"] is False
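
The assertion pattern in this test can be tried in isolation. The sketch below uses only `unittest.mock`. `enable_cache_sketch`, `block_adapter_cls`, and `cache_backend` are hypothetical stand-ins, not the real `cd_backend.enable_cache_for_cosmos3` or its collaborators; the keyword arguments simply mirror what the test asserts. It shows how `call_args.kwargs` exposes the keyword arguments that were passed to a mocked constructor.

```python
from unittest.mock import Mock


def enable_cache_sketch(pipeline, block_adapter_cls, cache_backend):
    """Hypothetical enabler: wrap only the transformer's gen_layers in one adapter."""
    adapter = block_adapter_cls(
        transformer=pipeline.transformer,
        blocks=[pipeline.transformer.gen_layers],
        has_separate_cfg=True,
        check_forward_pattern=False,
    )
    cache_backend.enable_cache(adapter)


# Mocks stand in for the pipeline, the adapter class, and the cache backend.
pipeline = Mock()
gen_layers = object()  # sentinel, compared by identity
pipeline.transformer.gen_layers = gen_layers
block_adapter_cls = Mock()
cache_backend = Mock()

enable_cache_sketch(pipeline, block_adapter_cls, cache_backend)

# call_args holds the most recent call; .kwargs is its keyword-argument dict.
adapter_kwargs = block_adapter_cls.call_args.kwargs
```

Using a bare `object()` as the sentinel keeps the `blocks` comparison strict: the list only compares equal if it contains that exact object.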


# This test is skipped on ROCm since rocm_unquantized_gemm doesn't support CPU backend
@pytest.mark.skipif(
    current_omni_platform.is_rocm(),
2 changes: 2 additions & 0 deletions tests/diffusion/models/cosmos3/__init__.py
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
192 changes: 192 additions & 0 deletions tests/diffusion/models/cosmos3/conftest.py
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import sys
import types
from types import SimpleNamespace
from typing import Any

import pytest
import torch
from torch import nn


class StubScheduler:
    def __init__(self, timesteps: list[int] | None = None, *, flow_shift: float = 1.0) -> None:
        self.timesteps = torch.tensor(timesteps or [9, 3], dtype=torch.int64)
        self.config = SimpleNamespace(num_train_timesteps=1000, flow_shift=flow_shift)
        self.set_timesteps_calls: list[tuple[int, torch.device]] = []
        self.step_calls: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []

    def set_timesteps(self, num_steps: int, device: torch.device) -> None:
        self.set_timesteps_calls.append((num_steps, device))
        self.timesteps = torch.arange(num_steps, 0, -1, dtype=torch.int64, device=device)

    def step(self, noise_pred: torch.Tensor, timestep: torch.Tensor, latents: torch.Tensor, **kwargs):
        del kwargs
        self.step_calls.append((noise_pred.clone(), timestep.clone(), latents.clone()))
        return (latents + noise_pred,)
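
To see what a sampling loop would observe when driven by a recording stub like this, here is a minimal pure-Python analog. It is a sketch under stated assumptions: `RecordingScheduler`, `run_loop`, and the constant `predict` callback are hypothetical, plain floats replace tensors, and the loop is a generic denoising loop rather than the pipeline's actual one.

```python
class RecordingScheduler:
    """Pure-Python analog of StubScheduler (hypothetical; floats instead of tensors)."""

    def __init__(self):
        self.timesteps = [9, 3]
        self.step_calls = []

    def set_timesteps(self, num_steps):
        # Descending integer timesteps: num_steps, ..., 1
        self.timesteps = list(range(num_steps, 0, -1))

    def step(self, noise_pred, timestep, latents):
        # Record the inputs, then apply a trivially checkable update.
        self.step_calls.append((noise_pred, timestep, latents))
        return (latents + noise_pred,)


def run_loop(scheduler, num_steps, latents, predict):
    """Generic loop: one prediction and one scheduler step per timestep."""
    scheduler.set_timesteps(num_steps)
    for t in scheduler.timesteps:
        latents = scheduler.step(predict(latents, t), t, latents)[0]
    return latents


scheduler = RecordingScheduler()
final = run_loop(scheduler, 3, 0.0, lambda latents, t: 1.0)
```

Because the update is just `latents + noise_pred`, the final value equals the sum of the predictions, and `step_calls` gives the exact order of timesteps the loop visited.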


class _ModeLatentDist:
    def __init__(self, latents: torch.Tensor) -> None:
        self._latents = latents

    def mode(self) -> torch.Tensor:
        return self._latents


class StubCosmos3VAE:
    dtype = torch.float32

    def __init__(self, z_dim: int = 2, *, temporal: int = 4, spatial: int = 8) -> None:
        self.config = SimpleNamespace(
            z_dim=z_dim,
            scale_factor_temporal=temporal,
            scale_factor_spatial=spatial,
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
        )

    def encode(self, video: torch.Tensor):
        latent_frames = (video.shape[2] - 1) // self.config.scale_factor_temporal + 1
        latent_height = video.shape[-2] // self.config.scale_factor_spatial
        latent_width = video.shape[-1] // self.config.scale_factor_spatial
        latents = torch.ones(
            video.shape[0],
            self.config.z_dim,
            latent_frames,
            latent_height,
            latent_width,
            dtype=video.dtype,
            device=video.device,
        )
        return SimpleNamespace(latent_dist=_ModeLatentDist(latents))

    def decode(self, latents: torch.Tensor, return_dict: bool = False):
        del return_dict
        return (latents,)
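
The shape arithmetic in `encode` can be checked by hand. The helper below is a small sketch that reproduces only that arithmetic; `latent_shape` is a hypothetical name, and the defaults match the stub's `temporal=4` and `spatial=8`.

```python
def latent_shape(num_frames, height, width, *, temporal=4, spatial=8):
    """Return (latent_frames, latent_height, latent_width) as computed in the stub's encode."""
    # The first frame maps to its own latent frame; the rest are grouped by `temporal`.
    latent_frames = (num_frames - 1) // temporal + 1
    return latent_frames, height // spatial, width // spatial


shape_video = latent_shape(17, 64, 96)  # 17 frames: (17 - 1) // 4 + 1 = 5 latent frames
shape_image = latent_shape(1, 64, 96)   # a single image still yields one latent frame
```

The `(num_frames - 1) // temporal + 1` form is why a one-frame input (the T2I case) produces exactly one latent frame instead of zero.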


class StubCosmos3Transformer(nn.Module):
    def __init__(
        self,
        *,
        latent_channel_size: int = 2,
        sound_gen: bool = False,
        sound_dim: int = 3,
        action_gen: bool = False,
        action_dim: int = 4,
    ) -> None:
        super().__init__()
        self.latent_channel_size = latent_channel_size
        self.sound_gen = sound_gen
        self.sound_dim = sound_dim
        self.action_gen = action_gen
        self.action_dim = action_dim
        self.cached_kv: Any | None = None
        self.cached_freqs_gen: Any | None = None
        self.calls: list[dict[str, Any]] = []
        self.reset_calls = 0

    def reset_cache(self) -> None:
        self.reset_calls += 1
        self.cached_kv = None
        self.cached_freqs_gen = None

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        text_ids: torch.Tensor,
        text_mask: torch.Tensor,
        **kwargs: Any,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        token = int(text_ids.reshape(-1)[0].item()) if text_ids.numel() else 0
        sound_latents = kwargs.get("sound_latents")
        self.calls.append(
            {
                "token": token,
                "timestep": timestep.clone(),
                "text_mask": text_mask.clone(),
                "cache_before": self.cached_kv,
                "kwargs": dict(kwargs),
            }
        )
        if self.cached_kv is None:
            marker = torch.tensor([token], dtype=torch.float32)
            self.cached_kv = [(marker, marker + 100)]
            self.cached_freqs_gen = (marker + 200, marker + 300)
        action_latents = kwargs.get("action_latents")
        outputs: list[torch.Tensor] = [torch.full_like(hidden_states, float(token))]
        if action_latents is not None:
            outputs.append(torch.full_like(action_latents, float(token + 20)))
        if sound_latents is not None:
            outputs.append(torch.full_like(sound_latents, float(token + 10)))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
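
The stub's `forward` returns a bare value when only video latents are requested and a tuple ordered video, action, sound otherwise. A caller therefore has to know which extra modalities it asked for. The sketch below shows one way to normalize that return shape; `split_outputs` is a hypothetical helper that is not part of this PR, and floats stand in for tensors.

```python
def split_outputs(result, *, has_action, has_sound):
    """Hypothetical helper: name the outputs of a forward call with a variable return shape."""
    outputs = result if isinstance(result, tuple) else (result,)
    names = ["video"]
    if has_action:
        names.append("action")  # action output comes before sound in the stub
    if has_sound:
        names.append("sound")
    if len(outputs) != len(names):
        raise ValueError(f"expected {len(names)} outputs, got {len(outputs)}")
    return dict(zip(names, outputs))


# Marker values follow the stub's scheme: token, token + 20 (action), token + 10 (sound).
token = 7
only_video = split_outputs(float(token), has_action=False, has_sound=False)
all_three = split_outputs(
    (float(token), float(token + 20), float(token + 10)),
    has_action=True,
    has_sound=True,
)
```

The distinct offsets (`+20` for action, `+10` for sound) make it easy for a test to detect outputs that were routed to the wrong modality.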


def passthrough_progress_bar(iterable):
    return iterable


@pytest.fixture(autouse=True)
def fake_cosmos3_guardrails(monkeypatch: pytest.MonkeyPatch):
    module = types.ModuleType("vllm_omni.diffusion.models.cosmos3.guardrails")
    module.is_guardrails_enabled = lambda od_config, sampling_params=None: False
    module.ensure_initialized = lambda od_config: None
    module.check_text_safety = lambda text: None
    module.check_video_safety = lambda video: video
    monkeypatch.setitem(sys.modules, module.__name__, module)
    return module
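
The fixture works because Python's import system consults `sys.modules` before searching for a module on disk. The sketch below demonstrates the same idea with the standard library only, using `unittest.mock.patch.dict` in place of pytest's `monkeypatch.setitem`. The module name `example_pkg_guardrails` is a hypothetical placeholder, chosen as a top-level name so that no parent package needs to exist.

```python
import importlib
import sys
import types
from unittest.mock import patch

MODULE_NAME = "example_pkg_guardrails"  # hypothetical, not a real package

# Build an in-memory module and give it no-op replacements.
fake = types.ModuleType(MODULE_NAME)
fake.is_guardrails_enabled = lambda config, sampling_params=None: False
fake.check_text_safety = lambda text: None

with patch.dict(sys.modules, {MODULE_NAME: fake}):
    # The import resolves to the injected object because sys.modules is checked first.
    imported = importlib.import_module(MODULE_NAME)
    enabled = imported.is_guardrails_enabled(object())
    same_object = imported is fake

# patch.dict restores sys.modules on exit, so the fake does not leak into other tests.
restored = MODULE_NAME not in sys.modules
```

Marking the fixture `autouse=True`, as the conftest does, applies the replacement to every test in the directory without each test having to request it.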


@pytest.fixture
def make_cosmos3_pipeline():
    def _make():
        from vllm_omni.diffusion.models.cosmos3.pipeline_cosmos3 import (
            Cosmos3OmniDiffusersPipeline,
        )

        pipeline = object.__new__(Cosmos3OmniDiffusersPipeline)
        nn.Module.__init__(pipeline)
        pipeline.od_config = SimpleNamespace()
        pipeline.device = torch.device("cpu")
        pipeline.dtype = torch.float32
        pipeline.transformer = StubCosmos3Transformer(latent_channel_size=2)
        pipeline.vae = StubCosmos3VAE(z_dim=2)
        pipeline.vae_scale_factor_temporal = 4
        pipeline.vae_scale_factor_spatial = 8
        pipeline.scheduler = StubScheduler([9, 3], flow_shift=1.0)
        pipeline._base_scheduler_config = pipeline.scheduler.config
        pipeline._engine_init_flow_shift = 1.0
        pipeline._current_flow_shift = 1.0
        pipeline._guidance_scale = None
        pipeline._num_timesteps = None
        pipeline.progress_bar = passthrough_progress_bar
        pipeline._sound_tokenizer = None
        return pipeline

    return _make
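
The factory builds a pipeline without running its constructor, presumably to avoid loading real weights in unit tests. The sketch below isolates that technique with a plain class; `HeavyPipeline` is a hypothetical example and does not mirror the real pipeline's constructor.

```python
class HeavyPipeline:
    """Hypothetical class whose constructor is too expensive to run in a unit test."""

    def __init__(self):
        raise RuntimeError("would load checkpoints")

    def describe(self):
        return f"{self.name} on {self.device}"


# object.__new__ allocates the instance without calling __init__.
pipeline = object.__new__(HeavyPipeline)

# The test then sets only the attributes the code under test will read.
pipeline.name = "stub"
pipeline.device = "cpu"

description = pipeline.describe()
```

In the conftest the extra `nn.Module.__init__(pipeline)` call matters: `torch.nn.Module` refuses to register submodules assigned as attributes until its own initializer has run, so it has to be invoked before `pipeline.transformer` is set.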


def make_sampling_params(**overrides: Any) -> SimpleNamespace:
    values = {
        "height": None,
        "width": None,
        "num_frames": None,
        "num_inference_steps": None,
        "guidance_scale": None,
        "generator": None,
        "seed": 123,
        "num_outputs_per_prompt": 1,
        "frame_rate": None,
        "resolved_frame_rate": None,
        "max_sequence_length": None,
        "extra_args": {},
    }
    values.update(overrides)
    return SimpleNamespace(**values)